RLHF (Beta)

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback.

Overview

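Axolotl supports several of these methods, including but not limited to:

  • Direct Preference Optimization (DPO)
  • Identity Preference Optimization (IPO)
  • Kahneman-Tversky Optimization (KTO)
  • Odds Ratio Preference Optimization (ORPO)
  • Group Relative Policy Optimization (GRPO)
  • Group Reward-Decoupled Policy Optimization (GDPO)
  • Simple Preference Optimization (SimPO)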

RLHF using Axolotl

Important

This is a beta feature, and many parts of it are not fully implemented. You are encouraged to open PRs to improve the integration and functionality.

We rely on the TRL library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported dataset loading strategies and prompt formats.

Tip

To see what each method supports, look in src/axolotl/prompt_strategies/{method}, where {method} is one of the supported methods. The value for type: takes the form {module}.{function_name}, where {module} is a file in that directory (for example, chatml.intel).
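If you would rather list the candidates than browse the files, the sketch below scans a method's directory with Python's ast module and collects {module}.{function} names. The helper list_type_candidates is hypothetical (not part of axolotl), and it assumes every public top-level function in those modules is a loader you can pass as type:. Internal helpers may show up too, so treat the output as a starting point.

```python
import ast
from pathlib import Path


def list_type_candidates(method_dir: str) -> list[str]:
    """Return sorted "module.function" names for public top-level functions.

    Assumption (hypothetical helper): each public function in a module under
    src/axolotl/prompt_strategies/{method}/ is a loader usable as `type:`.
    """
    candidates = []
    for path in sorted(Path(method_dir).glob("*.py")):
        if path.stem.startswith("_"):
            continue  # skip __init__.py and private modules
        tree = ast.parse(path.read_text(encoding="utf-8"))
        for node in tree.body:  # top-level statements only
            if isinstance(node, ast.FunctionDef) and not node.name.startswith("_"):
                candidates.append(f"{path.stem}.{node.name}")
    return sorted(candidates)
```

For example, list_type_candidates("src/axolotl/prompt_strategies/dpo") should include names such as chatml.intel if the directory layout matches this assumption.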

DPO

Example config:

rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml

DPO supports the following types with the following dataset format:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

chatml.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

llama3.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
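Since the chatml.* and llama3.* types above share the same field layouts, a quick pre-flight check can catch malformed rows before a long run. This is a hypothetical helper, not part of axolotl: the field table is transcribed from the formats listed above, system is treated as optional, and only key presence is checked, not the values.

```python
import json

# Required keys per type suffix, copied from the DPO formats above.
# "system" is optional for every type, so it is not listed.
DPO_REQUIRED_FIELDS = {
    "argilla": {"instruction", "chosen_response", "rejected_response"},
    "argilla_chat": {"chosen", "rejected"},
    "icr": {"input", "chosen", "rejected"},
    "intel": {"question", "chosen", "rejected"},
    "prompt_pairs": {"prompt", "chosen", "rejected"},
    "ultra": {"prompt", "chosen", "rejected"},
}


def missing_fields(row: dict, dataset_type: str) -> set[str]:
    """Return required keys absent from `row` for a type like "chatml.intel"."""
    suffix = dataset_type.split(".", 1)[1]  # "chatml.intel" -> "intel"
    return DPO_REQUIRED_FIELDS[suffix] - set(row)


def check_jsonl(path: str, dataset_type: str) -> list[tuple[int, set[str]]]:
    """Return (line_number, missing_keys) for every malformed line of a JSONL file."""
    problems = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            missing = missing_fields(json.loads(line), dataset_type)
            if missing:
                problems.append((lineno, missing))
    return problems
```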

zephyr.nectar

{
    "prompt": "...",
    "answers": [
        {
            "answer": "...",
            "rank": 1
        },
        {
            "answer": "...",
            "rank": 2
        }
        // ... more answers with ranks
    ]
}

chat_template.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

chat_template.default

rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]

Sample input format:

{
    "messages": [
        {
            "role": "system",
            "content": "..."
        },
        {
            "role": "user",
            "content": "..."
        },
        // ... more messages
    ],
    "chosen": {
        "role": "assistant",
        "content": "..."
    },
    "rejected": {
        "role": "assistant",
        "content": "..."
    }
}
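Each row therefore carries a shared history plus one chosen and one rejected assistant turn. The sketch below shows one way to expand such a row into two complete conversations using the field names from the config above; build_preference_conversations is a hypothetical name, and this illustrates the data layout rather than axolotl's internal code.

```python
def build_preference_conversations(
    row: dict,
    field_messages: str = "messages",
    field_chosen: str = "chosen",
    field_rejected: str = "rejected",
) -> tuple[list[dict], list[dict]]:
    """Append the chosen and rejected assistant turns to the shared history."""
    history = list(row[field_messages])  # copy so the input row is not mutated
    chosen = history + [row[field_chosen]]
    rejected = history + [row[field_rejected]]
    return chosen, rejected


row = {
    "messages": [
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    "chosen": {"role": "assistant", "content": "4"},
    "rejected": {"role": "assistant", "content": "5"},
}
chosen, rejected = build_preference_conversations(row)
```

Both lists would then be rendered with the model's chat template before tokenization.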

user_defined.default

For custom behaviors, pass a mapping to type: that names your own fields and format strings:

rl: dpo
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_chosen: "chosen"
      field_rejected: "rejected"
      prompt_format: "{prompt}"
      chosen_format: "{chosen}"
      rejected_format: "{rejected}"

Each input row is a simple JSON object whose field names are set by the config above.

{
    "system": "...",  // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
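The *_format strings act as templates filled from each row. The sketch below assumes Python str.format-style substitution with the placeholder names {prompt}, {system}, {chosen} and {rejected}, and falls back to an empty system string when the optional field is missing; the helper name and the example template tokens are hypothetical.

```python
def apply_user_defined_formats(row: dict, cfg: dict) -> dict:
    """Render prompt/chosen/rejected strings from a user-defined type mapping."""
    prompt_fields = {
        "prompt": row[cfg["field_prompt"]],
        # "system" is optional in the input rows; default to "" (assumption)
        "system": row.get(cfg["field_system"]) or "",
    }
    return {
        # str.format ignores keyword arguments the template does not use
        "prompt": cfg["prompt_format"].format(**prompt_fields),
        "chosen": cfg["chosen_format"].format(chosen=row[cfg["field_chosen"]]),
        "rejected": cfg["rejected_format"].format(rejected=row[cfg["field_rejected"]]),
    }


cfg = {
    "field_prompt": "prompt",
    "field_system": "system",
    "field_chosen": "chosen",
    "field_rejected": "rejected",
    "prompt_format": "<|system|>{system}<|user|>{prompt}<|assistant|>",  # hypothetical tokens
    "chosen_format": "{chosen}</s>",
    "rejected_format": "{rejected}</s>",
}
out = apply_user_defined_formats({"prompt": "Hi", "chosen": "Hello!", "rejected": "Go away."}, cfg)
```

Under this assumption, literal braces in a template would need to be doubled ({{ and }}).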

IPO

As IPO is just DPO with a different loss function, all supported dataset formats for DPO are also supported for IPO.

rl: ipo
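To make the difference concrete, both losses act on the same margin h: the policy's log-probability gap between the chosen and rejected responses minus the reference model's gap. The scalar sketch below uses the textbook forms, DPO as -log σ(βh) and IPO as (h - 1/(2β))², assuming β corresponds to rl_beta. It is for intuition only, not TRL's implementation.

```python
import math


def preference_margin(pi_chosen: float, pi_rejected: float,
                      ref_chosen: float, ref_rejected: float) -> float:
    """h = (log pi(y_w|x) - log pi(y_l|x)) - (log ref(y_w|x) - log ref(y_l|x))."""
    return (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)


def dpo_loss(h: float, beta: float) -> float:
    """-log(sigmoid(beta * h)), computed in a numerically stable way."""
    x = beta * h
    return math.log1p(math.exp(-x)) if x >= 0 else -x + math.log1p(math.exp(x))


def ipo_loss(h: float, beta: float) -> float:
    """Squared distance of the margin from the target 1 / (2 * beta)."""
    return (h - 1.0 / (2.0 * beta)) ** 2
```

The design difference shows up in the minimum: the DPO loss keeps decreasing as h grows, while the IPO loss is smallest at h = 1/(2β), which limits how far the policy is pushed away from the reference.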

ORPO

Paper: https://arxiv.org/abs/2403.07691

rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla

ORPO supports the following types with the following dataset format:

chat_template.argilla

{
    "system": "...",  // optional
    "prompt": "...",  // optional; if present, used as the user message for single-turn data instead of the user turn in the lists below

    // chosen and rejected must be identical except for the final assistant message, with an even number of alternating user/assistant turns
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}
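ORPO has no reference model; it adds an odds-ratio penalty to the usual SFT loss on the chosen response. The sketch below follows the paper's definition odds(y|x) = P(y|x) / (1 - P(y|x)), with P taken from the length-averaged log-probability. Treating orpo_alpha as the weight on the odds-ratio term is our reading of the config above, not something this sketch verifies.

```python
import math


def neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x))."""
    return math.log1p(math.exp(-x)) if x >= 0 else -x + math.log1p(math.exp(x))


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) where p = exp(avg_logp) and avg_logp < 0."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_loss(nll_chosen: float, avg_logp_chosen: float,
              avg_logp_rejected: float, alpha: float) -> float:
    """SFT loss on the chosen response plus alpha * -log(sigmoid(log odds ratio))."""
    log_odds_ratio = log_odds(avg_logp_chosen) - log_odds(avg_logp_rejected)
    return nll_chosen + alpha * neg_log_sigmoid(log_odds_ratio)
```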

KTO

rl: kto
rl_beta: 0.1  # default
kto_desirable_weight: 1.0  # default
kto_undesirable_weight: 1.0  # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true

KTO supports the following types with the following dataset format:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."}
    ],
    "completion": [
        {"role": "assistant", "content": "..."}
    ]
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

llama3.argilla_chat

{
    "completion": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

user_defined.default

For custom behaviors, pass a mapping to type: that names your own fields and format strings:

rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"

Each input row is a simple JSON object whose field names are set by the config above.

{
    "system": "...",  // optional
    "prompt": "...",
    "completion": "...",
    "label": "..."
}
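Because KTO learns from unpaired examples labeled desirable or undesirable, a paired preference dataset can be split into two KTO rows per pair. The helper below is hypothetical and assumes a boolean label (True for the chosen completion) with the default field names from the config above; check what your dataset type expects for label before relying on it.

```python
def preference_to_kto(rows: list[dict]) -> list[dict]:
    """Split each prompt/chosen/rejected row into two unpaired KTO rows."""
    kto_rows = []
    for row in rows:
        base = {"prompt": row["prompt"]}
        if row.get("system"):
            base["system"] = row["system"]  # optional field carries over unchanged
        kto_rows.append({**base, "completion": row["chosen"], "label": True})
        kto_rows.append({**base, "completion": row["rejected"], "label": False})
    return kto_rows
```

This split produces equal numbers of desirable and undesirable rows; kto_desirable_weight and kto_undesirable_weight are the knobs to adjust if your real data is imbalanced.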

GRPO

Tip

Check out our GRPO cookbook.

In the latest GRPO implementation, vLLM is used to significantly speed up trajectory generation during training. This example uses 4 GPUs: 2 for training and 2 for vLLM.

Important

Make sure you’ve installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. pip install axolotl[vllm].

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    tensor_parallel_size: 2
    gpu_memory_utilization: 0.85
    dtype: auto
    # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
    use_vllm: true
    vllm_server_host: 0.0.0.0
    vllm_server_port: 8000
    vllm_server_timeout: 300

Save the config above as grpo.yaml, then start the vLLM server on the last two GPUs:

CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml

The vLLM server will now start up. In another terminal, launch training on the remaining two GPUs:

CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2

Note

Due to TRL’s implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use CUDA_VISIBLE_DEVICES=2,3 for the vLLM instance.

Reward functions

GRPO uses custom reward functions and dataset transformations, which you define in a local Python file that axolotl imports by name.

For example, to load OpenAI’s GSM8K and use a random reward for completions:

# rewards.py
import random

def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}

rl: grpo

trl:
    beta: 0.001
    max_completion_length: 256
    use_vllm: True
    num_generations: 4
    reward_funcs: ["rewards.rand_reward_func"]    # format: '{file_name}.{fn_name}'
    reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'
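The random reward only exercises the pipeline. A more useful reward for this dataset compares the last number in each completion with the answer column that oai_gsm8k_transform produces. This sketch assumes that TRL passes extra dataset columns to reward functions as keyword arguments (here answer) and that completions arrive as lists of message dicts when prompts are conversational; the regex extraction is a deliberate simplification (for example, 3.0 and 3 are treated as different).

```python
# rewards.py (continued)
import re

_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def _completion_text(completion) -> str:
    # Conversational prompts yield completions like [{"role": "assistant", "content": "..."}]
    if isinstance(completion, list):
        return completion[-1]["content"]
    return completion


def gsm8k_correctness_reward(completions, answer, **kwargs) -> list[float]:
    """1.0 if the last number in a completion equals the reference answer, else 0.0."""
    rewards = []
    for completion, expected in zip(completions, answer):
        numbers = _NUMBER.findall(_completion_text(completion))
        predicted = numbers[-1].replace(",", "") if numbers else None
        rewards.append(1.0 if predicted == expected else 0.0)
    return rewards
```

It would be referenced as reward_funcs: ["rewards.gsm8k_correctness_reward"].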

To see other examples of custom reward functions, please see TRL GRPO Docs.

To see all configs, please see TRLConfig.

OpenEnv Rollout Functions

GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.

For example, to implement a simple math-solving environment with step-by-step verification:

# math_env.py
import re

def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Args:
        model: The language model
        processing_class: The tokenizer/processing_class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration

    Returns:
        List of completion strings
    """
    completions = []

    for prompt in prompts:
        # Apply chat template to prompt
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate step-by-step solution
        full_response = ""
        for step in range(5):  # Max 5 reasoning steps
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            )

            # Check if solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break
            full_response += step_text + "\n"

        completions.append(full_response)

    return completions

def math_reward(prompts, completions, answer, **kwargs):
    """Reward function that checks mathematical correctness.

    `answer` matches the dataset column produced by math_transform."""
    rewards = []
    for completion, correct_answer in zip(completions, answer):
        # Extract predicted answer
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""

        # Compare with correct answer
        reward = 1.0 if predicted == str(correct_answer) else 0.0
        rewards.append(reward)

    return rewards

def math_transform(cfg, *args, **kwargs):
    """Transform dataset to GRPO format with answer field"""
    def transform_fn(example, processing_class=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": example["answer"].split("####")[-1].strip().replace(",", ""),  # keep only the final answer
        }
    return transform_fn, {"remove_columns": ["question"]}

rl: grpo

trl:
  beta: 0.001
  max_completion_length: 512
  num_generations: 4
  rollout_func: "math_env.math_solver_rollout"  # Custom rollout function
  reward_funcs: ["math_env.math_reward"]
  reward_weights: [1.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: math_env.math_transform

The rollout_func parameter accepts a fully qualified name (e.g., module_name.function_name) that points to a callable function in your local directory. The function receives:

  • model: The language model
  • processing_class: The tokenizer/processing class
  • prompts: List of prompt dictionaries
  • generation_config (optional): Generation configuration

The function should return a list of completion strings.

For more OpenEnv examples, see TRL OpenEnv Documentation.

GRPO with DAPO/Dr. GRPO loss

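The DAPO paper and, subsequently, the Dr. GRPO paper proposed alternative loss functions for GRPO that correct a length bias: because GRPO averages each response's loss over its own length, tokens in longer responses receive less weight, so long incorrect responses are penalized less than short ones.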

trl:
  loss_type: dr_grpo
  # Normalizes the loss by max_completion_length (TRL default: 256)
  max_completion_length: 256

For more information, see GRPO docs.
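The variants differ mainly in how per-token losses are averaged. The plain-Python sketch below contrasts them as we understand the papers: GRPO averages within each sequence first, DAPO averages over all tokens in the batch, and Dr. GRPO divides the token sum by a constant (number of sequences times max_completion_length). Treat it as an approximation of the behavior, not TRL's code.

```python
def grpo_loss(per_token_losses: list[list[float]]) -> float:
    """Mean over sequences of each sequence's mean token loss."""
    per_sequence = [sum(seq) / len(seq) for seq in per_token_losses]
    return sum(per_sequence) / len(per_sequence)


def dapo_loss(per_token_losses: list[list[float]]) -> float:
    """Token-level mean: every token in the batch has equal weight."""
    total_tokens = sum(len(seq) for seq in per_token_losses)
    return sum(sum(seq) for seq in per_token_losses) / total_tokens


def dr_grpo_loss(per_token_losses: list[list[float]], max_completion_length: int) -> float:
    """Token sum divided by a constant that does not depend on actual lengths."""
    total = sum(sum(seq) for seq in per_token_losses)
    return total / (len(per_token_losses) * max_completion_length)


# Two completions: a short one (2 tokens) and a longer one (4 tokens).
losses = [[2.0, 2.0], [1.0, 1.0, 1.0, 1.0]]
```

With GRPO, each token of the 2-token completion carries twice the weight of a token in the 4-token completion; with DAPO and Dr. GRPO, every token carries the same weight.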

Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

trl:
  use_data_producer: true     # Enable data producer protocol
  use_vllm: true
  async_prefetch: true         # Generate rollouts in background thread
  prefetch_depth: 1            # Number of rollouts to prefetch
  vllm_sync_interval: 2        # Sync weights to vLLM every N steps

Note

Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by vllm_importance_sampling_correction: true (default when async is enabled).
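The overlap can be pictured as a bounded producer/consumer queue: a background thread keeps up to prefetch_depth generated batches waiting while the trainer consumes them. The sketch below uses only the standard library, with a hypothetical generate_fn standing in for vLLM generation; it omits weight syncing and importance sampling and is not axolotl's implementation.

```python
import queue
import threading
from typing import Callable, Iterator


def prefetching_batches(generate_fn: Callable[[int], object], num_steps: int,
                        prefetch_depth: int = 1) -> Iterator[object]:
    """Yield generate_fn(step) for each step, generating ahead in a background thread."""
    buffer: queue.Queue = queue.Queue(maxsize=prefetch_depth)
    sentinel = object()
    errors: list[BaseException] = []

    def producer() -> None:
        try:
            for step in range(num_steps):
                buffer.put(generate_fn(step))  # blocks once prefetch_depth batches are waiting
        except BaseException as exc:  # hand generation errors to the training loop
            errors.append(exc)
        finally:
            buffer.put(sentinel)

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while (item := buffer.get()) is not sentinel:
        yield item  # the caller trains on this batch while the next one is generated
    thread.join()
    if errors:
        raise errors[0]
```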

vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true         # Enable native LoRA sync

When vllm_lora_sync: true is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

Then start training on a separate GPU:

CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml

Tip

LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.

Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

trl:
  streaming_partial_batch: true

Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

trl:
  vllm_importance_sampling_correction: true   # Enable IS correction
  importance_sampling_level: token             # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5              # Mask sequences with IS ratio below this

  • importance_sampling_level: token applies per-token IS ratios (recommended with Liger kernel)
  • importance_sampling_level: sequence applies per-sequence IS ratios
  • off_policy_mask_threshold masks out sequences where the IS ratio indicates they are too far off-policy
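For intuition, the per-token ratio is exp(log π_current − log π_old) for each generated token. The sketch below computes token-level ratios, one possible sequence-level ratio (the geometric mean of the token ratios, an assumption, since implementations differ), and the masking rule described above. It is a plain-Python illustration, not TRL's or axolotl's implementation.

```python
import math


def token_is_ratios(logp_current: list[float], logp_old: list[float]) -> list[float]:
    """Per-token importance ratios pi_current / pi_old."""
    return [math.exp(c - o) for c, o in zip(logp_current, logp_old)]


def sequence_is_ratio(logp_current: list[float], logp_old: list[float]) -> float:
    """Sequence-level ratio, here taken as the geometric mean of the token ratios."""
    diffs = [c - o for c, o in zip(logp_current, logp_old)]
    return math.exp(sum(diffs) / len(diffs))


def off_policy_mask(seq_ratios: list[float], threshold: float = 0.5) -> list[bool]:
    """True keeps a sequence; False masks it out as too far off-policy."""
    return [r >= threshold for r in seq_ratios]
```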

Replay Buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

trl:
  replay_buffer_size: 100       # Max cached groups (0 = disabled)
  replay_recompute_logps: true  # Recompute log-probs for replayed data (recommended)

Note

When replay_recompute_logps: true (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
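A minimal sketch of that policy under stated assumptions: a group is the list of rewards for one prompt's generations, "learning signal" means non-zero reward variance, and the oldest groups are evicted first once replay_buffer_size is reached. The class and method names are hypothetical, and the log-probability recomputation step is left out.

```python
import random
from collections import deque
from statistics import pvariance


class ReplayBuffer:
    """Hypothetical sketch: cache groups with reward variance, reuse them for zero-signal groups."""

    def __init__(self, size: int, seed: int = 0):
        # maxlen=size evicts the oldest group first; size=0 stores nothing (disabled)
        self.groups: deque = deque(maxlen=size)
        self.rng = random.Random(seed)

    def process(self, batch: list[dict]) -> list[dict]:
        """Each group is a dict with a "rewards" list; returns the (possibly patched) batch."""
        patched = []
        for group in batch:
            if pvariance(group["rewards"]) > 0:
                self.groups.append(group)  # has learning signal: remember it
                patched.append(group)
            elif self.groups:
                patched.append(self.rng.choice(self.groups))  # swap in a cached group
            else:
                patched.append(group)  # nothing cached yet: keep the original
        return patched
```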

Deferred Re-rolling

Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.

trl:
  reroll_start_fraction: 0.5    # Start re-rolling after 50% of training
  reroll_max_groups: 1          # Max groups to replace per batch

Zero-Advantage Batch Skipping

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as skipped_zero_adv_batches=1.

trl:
  skip_zero_advantage_batches: true   # default
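Advantages are zero when every generation in a group earns the same reward, because group-relative advantages are centered on the group mean. The sketch below uses the common (r − mean) / (std + ε) form; the ε value and the use of the population standard deviation are assumptions, and the skip check is illustrative.

```python
from statistics import mean, pstdev


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize rewards within one prompt's group of generations."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def should_skip(micro_batch_advantages: list[list[float]]) -> bool:
    """Skip the forward/backward pass when no advantage in the micro-batch is non-zero."""
    return all(a == 0.0 for group in micro_batch_advantages for a in group)
```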

Parallel Reward Workers

Reward functions that use signal.alarm() (e.g., math_verify) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.

trl:
  reward_num_workers: 4         # Number of subprocess workers (1 = no parallelism)

Full Async GRPO Example

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    gpu_memory_utilization: 0.35
    dtype: auto

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true

datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train

gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true

# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml

Multi-GPU Async GRPO

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.

FSDP:

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false

DeepSpeed ZeRO-3:

deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true   # Required for ZeRO-3

# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on the remaining GPUs (1 and 2)
CUDA_VISIBLE_DEVICES=1,2 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml

Important

With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.

GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the reward advantage collapse problem by normalizing each reward function independently before combining them.

Tip

Use GDPO when training with multiple reward functions. With a single reward function, GRPO and GDPO produce equivalent results.

Paper: https://arxiv.org/pdf/2501.05242

GDPO uses TRL’s native multi_objective_aggregation parameter under the hood. When you set rl: gdpo, axolotl automatically configures TRL to use normalize_then_sum aggregation.

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    tensor_parallel_size: 2
    gpu_memory_utilization: 0.85

rl: gdpo

trl:
    beta: 0.001
    max_completion_length: 256
    use_vllm: true
    num_generations: 4
    reward_funcs:
        - rewards.format_reward
        - rewards.correctness_reward
    reward_weights: [1.0, 2.0]

datasets:
    - path: openai/gsm8k
      name: main
      type: rewards.oai_gsm8k_transform

You can also use GRPO with explicit aggregation control:

rl: grpo
trl:
    multi_objective_aggregation: normalize_then_sum  # GDPO behavior
    # or: sum_then_normalize  # Default GRPO behavior

GDPO vs GRPO

| Aspect | GRPO | GDPO |
|---|---|---|
| Aggregation | sum_then_normalize | normalize_then_sum |
| Multi-reward | May collapse advantages | Preserves reward signals |
| Single reward | Standard behavior | Equivalent to GRPO |

Why GDPO?

When using multiple rewards with GRPO, different reward combinations can produce identical advantages:

# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3  ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3

GDPO normalizes each reward independently, preserving their relative differences.
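The effect is easiest to see with numbers. The sketch below computes group advantages both ways for four generations scored by two binary rewards (format, correctness) with equal weights. Assumptions: population standard deviation with a small ε, and no extra batch-level normalization after summing, so TRL's exact arithmetic may differ.

```python
from statistics import mean, pstdev

EPS = 1e-4


def normalize(values: list[float]) -> list[float]:
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / (sigma + EPS) for v in values]


def sum_then_normalize(rewards: list[list[float]], weights: list[float]) -> list[float]:
    """GRPO-style: weighted sum per generation, then normalize within the group."""
    totals = [sum(w * r for w, r in zip(weights, row)) for row in rewards]
    return normalize(totals)


def normalize_then_sum(rewards: list[list[float]], weights: list[float]) -> list[float]:
    """GDPO-style: normalize each reward column within the group, then weighted sum."""
    columns = [normalize(list(col)) for col in zip(*rewards)]
    return [sum(w * col[i] for w, col in zip(weights, columns)) for i in range(len(rewards))]


# One group of four generations; each row is [format_reward, correctness_reward].
group = [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
grpo_adv = sum_then_normalize(group, [1.0, 1.0])
gdpo_adv = normalize_then_sum(group, [1.0, 1.0])
```

Here GRPO gives the format-only generation (row 0) and the correctness-only generation (row 3) the same advantage, because both sum to 1, while GDPO assigns them different values.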

Reward Functions

GDPO uses the same reward function format as GRPO:

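# rewards.py
def _text(completion) -> str:
    # With conversational prompts, each completion is a list like [{"role": "assistant", "content": "..."}]
    return completion[-1]["content"] if isinstance(completion, list) else completion

def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(_text(c)) > 10 else 0.0 for c in completions]

def correctness_reward(completions, answer, **kwargs) -> list[float]:
    # `answer` is the dataset column produced by the transform (e.g. oai_gsm8k_transform)
    rewards = []
    for completion, expected in zip(completions, answer):
        # Placeholder scoring; replace with your own logic
        score = 1.0 if expected in _text(completion) else 0.0
        rewards.append(score)
    return rewards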

Sequence Parallelism

GDPO supports sequence parallelism for long-context training:

rl: gdpo
context_parallel_size: 2

SimPO

SimPO uses TRL's CPOTrainer with an alternative loss function.

rl: simpo
rl_beta: 0.1  # default in CPOTrainer
cpo_alpha: 1.0  # default in CPOTrainer
simpo_gamma: 0.5  # default in CPOTrainer

This method uses the same dataset format as DPO.
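For reference, SimPO replaces DPO's reference-model ratio with length-normalized log-probabilities and a target margin γ. The scalar sketch below computes -log σ(β·(avg log p(chosen) − avg log p(rejected)) − γ), assuming rl_beta and simpo_gamma map to β and γ; the extra NLL term weighted by cpo_alpha is included as a hedged reading of CPOTrainer's cpo_alpha option.

```python
import math


def neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x))."""
    return math.log1p(math.exp(-x)) if x >= 0 else -x + math.log1p(math.exp(x))


def simpo_loss(avg_logp_chosen: float, avg_logp_rejected: float,
               beta: float = 0.1, gamma: float = 0.5,
               cpo_alpha: float = 1.0, nll_chosen: float = 0.0) -> float:
    """SimPO preference term plus an optional cpo_alpha-weighted NLL on the chosen response."""
    margin = beta * (avg_logp_chosen - avg_logp_rejected) - gamma
    return neg_log_sigmoid(margin) + cpo_alpha * nll_chosen
```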

Using local dataset files

datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
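The file should contain one JSON object per line in the chosen type's format. A short sketch that writes a tiny orca_rlhf.jsonl matching the chatml.intel fields listed earlier (the example row is made up):

```python
import json
from pathlib import Path

rows = [
    {
        "system": "You are a helpful assistant.",
        "question": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France does not have a capital.",
    },
]

# JSON Lines: one object per line; ensure_ascii=False keeps non-ASCII text readable
with Path("orca_rlhf.jsonl").open("w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
```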

TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure: no additional reference model needs to be loaded, because reference log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

# Load a separate reference model when training an adapter.
rl_adapter_ref_model: true