RLHF (Beta)
Overview
Reinforcement Learning from Human Feedback (RLHF) is a method whereby a language model is optimized using human feedback data. The methods covered here are DPO, IPO, ORPO, KTO, GRPO, GDPO, and SimPO.
RLHF using Axolotl
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
We rely on the TRL library for implementations of various RL training methods, which we wrap and expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.
You can find what each method supports by going into src/axolotl/prompt_strategies/{method} where {method} is one of our supported methods. The type: can be retrieved from {method}.{function_name}.
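As a rough illustration of the lookup described above, the sketch below maps a `type:` value to a file under `src/axolotl/prompt_strategies/{method}` and a function name. The helper `locate_strategy`, the `.py` file layout, and the fallback to a function named `default` for undotted types are assumptions made for illustration, not axolotl's actual loader.

```python
from pathlib import PurePosixPath


def locate_strategy(method: str, dataset_type: str) -> tuple[str, str]:
    """Map a `type:` string to (assumed module file, function name)."""
    module_name, dot, function_name = dataset_type.rpartition(".")
    if not dot:
        # Assumption: an undotted type such as "chatml" resolves to `default`
        module_name, function_name = dataset_type, "default"
    base = PurePosixPath("src/axolotl/prompt_strategies")
    return str(base / method / f"{module_name}.py"), function_name


print(locate_strategy("dpo", "chatml.intel"))
```

Under these assumptions, `type: chatml.intel` with `rl: dpo` points at the `intel` function in the `dpo/chatml.py` strategy file.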
DPO
Example config:
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
DPO supports the following types with the following dataset format:
chatml.argilla
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
chatml.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
chatml.icr
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
chatml.intel
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
chatml.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
chatml.ultra
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
llama3.argilla
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
llama3.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
llama3.icr
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
llama3.intel
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
llama3.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
llama3.ultra
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
zephyr.nectar
{
"prompt": "...",
"answers": [
{
"answer": "...",
"rank": 1
},
{
"answer": "...",
"rank": 2
}
// ... more answers with ranks
]
}
chat_template.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
chat_template.default
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
Sample input format:
{
"messages": [
{
"role": "system",
"content": "..."
},
{
"role": "user",
"content": "..."
},
// ... more messages
],
"chosen": {
"role": "assistant",
"content": "..."
},
"rejected": {
"role": "assistant",
"content": "..."
}
}
user_defined.default
For custom behaviors, set the field names and format strings directly under type:
rl: dpo
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_chosen: "chosen"
      field_rejected: "rejected"
      prompt_format: "{prompt}"
      chosen_format: "{chosen}"
      rejected_format: "{rejected}"
The input format is a simple JSON input with customizable fields based on the above config.
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
IPO
As IPO is just DPO with a different loss function, all supported dataset formats for DPO are also supported for IPO.
rl: ipo
ORPO
Paper: https://arxiv.org/abs/2403.07691
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
ORPO supports the following types with the following dataset format:
chat_template.argilla
{
"system": "...", // optional
"prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
// chosen and rejected must be identical up to the final message and contain an even number of alternating user/assistant turns
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
KTO
rl: kto
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
remove_unused_columns: false
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
KTO supports the following types with the following dataset format:
chatml.argilla
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
chatml.argilla_chat
{
"chosen": [
{"role": "user", "content": "..."}
],
"completion": [
{"role": "assistant", "content": "..."}
]
}
chatml.intel
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
chatml.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
chatml.ultra
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
llama3.argilla
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
llama3.argilla_chat
{
"completion": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
llama3.intel
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
llama3.prompt_pairs
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
llama3.ultra
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
user_defined.default
For custom behaviors, set the field names and format strings directly under type:
rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"
The input format is a simple JSON input with customizable fields based on the above config.
{
"system": "...", // optional
"prompt": "...",
"completion": "...",
"label": "..."
}
GRPO
Check out our GRPO cookbook.
In the latest GRPO implementation, vLLM is used to significantly speed up trajectory generation during training. This example uses 4 GPUs: 2 for training and 2 for vLLM.
Make sure you’ve installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. pip install axolotl[vllm].
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
Your vLLM instance will now attempt to spin up, and it’s time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
Due to TRL’s implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use CUDA_VISIBLE_DEVICES=2,3 for the vLLM instance.
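The "last N GPUs for vLLM" rule can be sketched as a small helper that derives both `CUDA_VISIBLE_DEVICES` values from the GPU counts. The function `split_gpus` is a hypothetical convenience for illustration and is not part of axolotl.

```python
def split_gpus(total_gpus: int, vllm_gpus: int) -> tuple[str, str]:
    """Return (training devices, vLLM devices); vLLM takes the last N GPUs."""
    if not 0 < vllm_gpus < total_gpus:
        raise ValueError("need at least one GPU for training and one for vLLM")
    ids = [str(i) for i in range(total_gpus)]
    boundary = total_gpus - vllm_gpus
    return ",".join(ids[:boundary]), ",".join(ids[boundary:])


train_devices, vllm_devices = split_gpus(total_gpus=4, vllm_gpus=2)
num_processes = len(train_devices.split(","))
print(f"CUDA_VISIBLE_DEVICES={vllm_devices} axolotl vllm-serve grpo.yaml")
print(f"CUDA_VISIBLE_DEVICES={train_devices} axolotl train grpo.yaml --num-processes {num_processes}")
```

With 4 GPUs and 2 reserved for vLLM, this prints the same two commands used in the example above.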
Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.
For example, to load OpenAI’s GSM8K and use a random reward for completions:
# rewards.py
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    # Assign every completion a random reward in [0, 1]
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        # GSM8K answers end with "#### <final answer>"; keep only the final answer
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }

    return transform_fn, {"remove_columns": ["question"]}
rl: grpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
To see other examples of custom reward functions, please see TRL GRPO Docs.
To see all configs, please see TRLConfig.
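As a less trivial alternative to the random reward, here is a minimal sketch of an exact-match reward for the GSM8K setup above. It assumes that the extra `answer` column produced by `oai_gsm8k_transform` is forwarded to the reward function as a keyword argument of the same name, and that completions arrive either as plain strings or as lists of chat messages; check both against the TRL GRPO docs. The function name `gsm8k_exact_match` is hypothetical.

```python
import re


def _completion_text(completion) -> str:
    """Accept a plain string or a list of chat messages."""
    if isinstance(completion, str):
        return completion
    return "".join(message.get("content", "") for message in completion)


def gsm8k_exact_match(completions, answer, **kwargs) -> list[float]:
    """Reward 1.0 when the last number in the completion equals the label."""
    rewards = []
    for completion, target in zip(completions, answer):
        text = _completion_text(completion).replace(",", "")
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        rewards.append(1.0 if numbers and numbers[-1] == str(target) else 0.0)
    return rewards
```

If saved in rewards.py, it could be referenced as `reward_funcs: ["rewards.gsm8k_exact_match"]`.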
OpenEnv Rollout Functions
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
# math_env.py
import re


def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Args:
        model: The language model
        processing_class: The tokenizer/processing_class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration

    Returns:
        List of completion strings
    """
    completions = []
    for prompt in prompts:
        # Apply chat template to prompt (tokenize=False returns the formatted string)
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Generate step-by-step solution
        full_response = ""
        for step in range(5):  # Max 5 reasoning steps
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            # Decode only the newly generated tokens
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            )
            # Check if solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break
            full_response += step_text + "\n"
        completions.append(full_response)
    return completions


def math_reward(prompts, completions, answers, **kwargs):
    """Reward function that checks mathematical correctness"""
    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract predicted answer
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""
        # Compare with correct answer
        reward = 1.0 if predicted == str(correct_answer) else 0.0
        rewards.append(reward)
    return rewards


def math_transform(cfg, *args, **kwargs):
    """Transform dataset to GRPO format with answer field"""
    def transform_fn(example, processing_class=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": str(example["answer"]),
        }

    return transform_fn, {"remove_columns": ["question"]}
rl: grpo
trl:
  beta: 0.001
  max_completion_length: 512
  num_generations: 4
  rollout_func: "math_env.math_solver_rollout" # Custom rollout function
  reward_funcs: ["math_env.math_reward"]
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: math_env.math_transform
The rollout_func parameter accepts a fully qualified name (e.g., module_name.function_name) that points to a callable function in your local directory. The function receives:
- model: The language model
- processing_class: The tokenizer/processing class
- prompts: List of prompt dictionaries
- generation_config (optional): Generation configuration
And should return a list of completion strings.
For more OpenEnv examples, see TRL OpenEnv Documentation.
GRPO with DAPO/Dr. GRPO loss
The DAPO paper, and subsequently the Dr. GRPO paper, proposed alternative loss functions for GRPO that mitigate the penalty on longer responses.
trl:
  loss_type: dr_grpo
  # Normalizes loss based on max completion length (default: 256)
  max_completion_length:
For more information, see GRPO docs.
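To make the difference concrete, the sketch below compares three ways of reducing per-token losses to a scalar: a per-sequence mean (the original GRPO form), a token-level mean over the whole batch (as proposed in the DAPO paper), and division by a constant `max_completion_length` (as in Dr. GRPO). This is a simplified illustration in plain Python with hypothetical function names, not the TRL implementation.

```python
def grpo_loss(token_losses: list[list[float]]) -> float:
    """Average within each sequence first, then across sequences."""
    per_sequence = [sum(seq) / len(seq) for seq in token_losses]
    return sum(per_sequence) / len(per_sequence)


def dapo_loss(token_losses: list[list[float]]) -> float:
    """Average over all tokens in the batch, so every token counts equally."""
    total_tokens = sum(len(seq) for seq in token_losses)
    return sum(sum(seq) for seq in token_losses) / total_tokens


def dr_grpo_loss(token_losses: list[list[float]], max_completion_length: int) -> float:
    """Divide by a constant, independent of the actual sequence lengths."""
    total = sum(sum(seq) for seq in token_losses)
    return total / (len(token_losses) * max_completion_length)


batch = [[0.5, 0.5], [1.0] * 8]  # one short and one long completion
print(grpo_loss(batch))                              # 0.75
print(dapo_loss(batch))                              # 0.9
print(dr_grpo_loss(batch, max_completion_length=8))  # 0.5625
```

In the per-sequence mean, each token of the long completion carries a weight of 1/8 versus 1/2 for the short one; the other two reductions weight all tokens equally.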
Async GRPO
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
trl:
  use_data_producer: true # Enable data producer protocol
  use_vllm: true
  async_prefetch: true # Generate rollouts in background thread
  prefetch_depth: 1 # Number of rollouts to prefetch
  vllm_sync_interval: 2 # Sync weights to vLLM every N steps
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by vllm_importance_sampling_correction: true (default when async is enabled).
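The overlap of generation and training can be pictured as a producer/consumer pattern with a bounded queue whose size plays the role of `prefetch_depth`. The sketch below is a toy illustration using only the standard library; `generate_rollout` and `train_on` are hypothetical stand-ins for vLLM generation and a training step, not axolotl or TRL functions.

```python
import queue
import threading


def generate_rollout(step: int) -> str:
    """Stand-in for rollout generation (hypothetical)."""
    return f"rollout-{step}"


def train_on(rollout: str) -> str:
    """Stand-in for one training step (hypothetical)."""
    return f"trained-on-{rollout}"


def run(num_steps: int, prefetch_depth: int = 1) -> list[str]:
    # The bounded queue limits how far generation can run ahead of training
    buffer: queue.Queue = queue.Queue(maxsize=prefetch_depth)

    def producer() -> None:
        for step in range(num_steps):
            buffer.put(generate_rollout(step))  # blocks while the queue is full

    threading.Thread(target=producer, daemon=True).start()
    # The main thread trains while the producer prepares the next rollout
    return [train_on(buffer.get()) for _ in range(num_steps)]


print(run(num_steps=3))
```

A larger queue lets generation run further ahead, at the cost of training on rollouts produced by older weights.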
vLLM LoRA Sync
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
  vllm_lora_sync: true # Enable native LoRA sync
When vllm_lora_sync: true is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
Then start training on a separate GPU:
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
Streaming Partial Batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
trl:
  streaming_partial_batch: true
Importance Sampling Correction
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
trl:
  vllm_importance_sampling_correction: true # Enable IS correction
  importance_sampling_level: token # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
- importance_sampling_level: token applies per-token IS ratios (recommended with Liger kernel)
- importance_sampling_level: sequence applies per-sequence IS ratios
- off_policy_mask_threshold masks out sequences where the IS ratio indicates they are too far off-policy
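The options above can be illustrated with a small sketch that computes IS ratios from log-probabilities. It assumes the ratio is the current-policy probability divided by the probability under the weights that generated the sample, and that the sequence-level ratio is the product of the token ratios; how the trainer actually aggregates and applies the mask may differ. All function names are hypothetical.

```python
import math


def token_ratios(new_logps: list[float], old_logps: list[float]) -> list[float]:
    """Per-token IS ratio: exp(log p_new - log p_old)."""
    return [math.exp(n - o) for n, o in zip(new_logps, old_logps)]


def sequence_ratio(new_logps: list[float], old_logps: list[float]) -> float:
    """Per-sequence IS ratio, taken here as the product of the token ratios."""
    return math.exp(sum(n - o for n, o in zip(new_logps, old_logps)))


def keep_sequence(new_logps, old_logps, threshold: float = 0.5) -> bool:
    """Mask (drop) sequences whose ratio falls below the threshold."""
    return sequence_ratio(new_logps, old_logps) >= threshold


old = [-1.0, -2.0]
print(keep_sequence([-1.0, -2.0], old))  # True: identical policies, ratio 1.0
print(keep_sequence([-2.0, -3.0], old))  # False: ratio exp(-2) is below 0.5
```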
Replay Buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
trl:
  replay_buffer_size: 100 # Max cached groups (0 = disabled)
  replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
When replay_recompute_logps: true (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
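The caching behaviour can be sketched with a bounded deque: groups with non-zero reward variance are stored, and zero-variance groups in later batches are swapped for cached ones. This is a toy model under stated assumptions (a group is a dict with a `rewards` list, and a replayed group is removed from the cache); it does not show log-probability recomputation and is not the actual implementation.

```python
from collections import deque
from statistics import pvariance


def has_signal(group: dict) -> bool:
    """A group carries learning signal when its rewards are not all equal."""
    return pvariance(group["rewards"]) > 0


class ReplayBuffer:
    def __init__(self, max_size: int):
        # maxlen=0 silently discards everything, i.e. the buffer is disabled
        self.groups = deque(maxlen=max_size)

    def fill(self, batch: list[dict]) -> list[dict]:
        """Replace zero-signal groups with cached ones, then cache new signal."""
        mixed = []
        for group in batch:
            if not has_signal(group) and self.groups:
                group = self.groups.popleft()
            mixed.append(group)
        self.groups.extend(g for g in batch if has_signal(g))
        return mixed


buffer = ReplayBuffer(max_size=2)
useful = {"prompt": "a", "rewards": [0.0, 1.0]}
flat = {"prompt": "b", "rewards": [1.0, 1.0]}
buffer.fill([useful, flat])        # nothing cached yet, batch passes through
print(buffer.fill([flat]))         # the flat group is replaced by the cached one
```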
Deferred Re-rolling
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
trl:
  reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
  reroll_max_groups: 1 # Max groups to replace per batch
Zero-Advantage Batch Skipping
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as skipped_zero_adv_batches=1.
trl:
  skip_zero_advantage_batches: true # default
Parallel Reward Workers
Reward functions that use signal.alarm() (e.g., math_verify) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
trl:
  reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
Full Async GRPO Example
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.35
  dtype: auto
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true
datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train
gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
Multi-GPU Async GRPO
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
FSDP:
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
DeepSpeed ZeRO-3:
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true # Required for ZeRO-3
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPUs 1,2 (the GPUs not used by vLLM)
CUDA_VISIBLE_DEVICES=1,2 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the reward advantage collapse problem by normalizing each reward function independently before combining them.
Use GDPO when training with multiple reward functions. With a single reward function, GRPO and GDPO produce equivalent results.
Paper: https://arxiv.org/pdf/2501.05242
GDPO uses TRL’s native multi_objective_aggregation parameter under the hood. When you set rl: gdpo, axolotl automatically configures TRL to use normalize_then_sum aggregation.
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
You can also use GRPO with explicit aggregation control:
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum # GDPO behavior
  # or: sum_then_normalize # Default GRPO behavior
GDPO vs GRPO
| Aspect | GRPO | GDPO |
|---|---|---|
| Aggregation | sum_then_normalize | normalize_then_sum |
| Multi-reward | May collapse advantages | Preserves reward signals |
| Single reward | Standard behavior | Equivalent to GRPO |
Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
GDPO normalizes each reward independently, preserving their relative differences.
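The two aggregation orders can be compared on a small group. The sketch below uses a hypothetical group of four completions scored by two binary rewards (format, correctness) with equal weights, and leaves out any further normalization the trainer may apply. The function names mirror the config values but are written here for illustration only.

```python
from statistics import mean, pstdev


def normalize(values, eps: float = 1e-6) -> list[float]:
    """Standardize values within one group."""
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / (sigma + eps) for v in values]


def sum_then_normalize(rewards, weights) -> list[float]:
    """GRPO-style: combine rewards per completion, then normalize the totals."""
    totals = [sum(w * r for w, r in zip(weights, row)) for row in rewards]
    return normalize(totals)


def normalize_then_sum(rewards, weights) -> list[float]:
    """GDPO-style: normalize each reward column, then combine."""
    columns = [normalize(column) for column in zip(*rewards)]
    return [sum(w * z for w, z in zip(weights, row)) for row in zip(*columns)]


# (format, correctness) per completion; hypothetical values
group = [(1, 0), (1, 0), (1, 1), (0, 1)]
weights = [1.0, 1.0]
print(sum_then_normalize(group, weights))
print(normalize_then_sum(group, weights))
```

The first and last completions both sum to 1, so summing first gives them the same advantage. Normalizing each reward first separates them, because the format reward is earned by three of the four completions while the correctness reward is earned by only two.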
Reward Functions
GDPO uses the same reward function format as GRPO:
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    # Reward completions whose length exceeds 10
    return [1.0 if len(c) > 10 else 0.0 for c in completions]


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Your scoring logic here
        score = 0.0  # placeholder: replace with your own scoring logic
        rewards.append(score)
    return rewards
Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
rl: gdpo
context_parallel_size: 2
SimPO
SimPO uses CPOTrainer with an alternative loss function.
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
This method uses the same dataset format as DPO.
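For intuition about what `rl_beta` and `simpo_gamma` control, here is a minimal sketch of the SimPO preference loss for a single pair, written from the published formulation: length-averaged log-probabilities, scaled by beta, with a target margin gamma. It omits the additional term weighted by `cpo_alpha`, and the function name is hypothetical; it is not the CPOTrainer code.

```python
import math


def simpo_loss(chosen_logps, rejected_logps, beta: float = 0.1, gamma: float = 0.5) -> float:
    """-log(sigmoid(beta * (avg_chosen - avg_rejected) - gamma)) for one pair."""
    avg_chosen = sum(chosen_logps) / len(chosen_logps)
    avg_rejected = sum(rejected_logps) / len(rejected_logps)
    margin = beta * (avg_chosen - avg_rejected) - gamma
    return math.log1p(math.exp(-margin))  # equals -log(sigmoid(margin))


# Per-token log-probabilities for a hypothetical chosen/rejected pair
print(simpo_loss([-0.2, -0.1], [-3.0, -4.0, -5.0]))
```

Because the log-probabilities are averaged over tokens, no reference model is needed and longer responses are not favored just for their length.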
Using local dataset files
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
TRL auto-unwrapping for PEFT
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded and reference model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
# load ref model when adapter training.
rl_adapter_ref_model: true