A Coding Guide on LLM Post Training with TRL from Supervised Fine Tuning to DPO and GRPO Reasoning
In this tutorial, we walk through a complete, hands-on journey of post-training large language models using the TRL (Transformer Reinforcement Learning) library ecosystem. We start from a lightweight instruction-tuned model (Qwen2.5-0.5B-Instruct) and progressively apply four key techniques: Supervised Fine-Tuning (SFT), Reward Modeling (RM), Direct Preference Optimization (DPO), and Group Relative Policy Optimization (GRPO). We also use efficient methods like LoRA to make training feasible even on limited hardware, such as Google Colab's T4 GPU. Step by step, we build intuition for how modern alignment pipelines work, from teaching models how to respond to shaping their behavior with preferences and verifiable rewards.
import subprocess, sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U",
"torchao>=0.16",
"trl>=0.20",
"transformers>=4.45",
"datasets",
"peft>=0.13",
"accelerate",
"bitsandbytes",
])
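import sys as _sys
# Drop any torchao/peft modules imported before the upgrade so the fresh versions load.
for _m in [m for m in list(_sys.modules) if m.startswith(("torchao", "peft"))]:
    _sys.modules.pop(_m, None)
try:
    import torchao
except Exception:
    # Fallback: register a minimal stub module if torchao cannot be imported.
    import types
    _fake = types.ModuleType("torchao")
    _fake.__version__ = "0.16.1"
    _sys.modules["torchao"] = _fake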
import os, re, gc, torch, warnings
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
print(f"torch={torch.__version__} cuda={torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} "
          f"({torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB)")
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BF16_OK = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
LORA_CFG = LoraConfig(
r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
task_type="CAUSAL_LM",
)
def cleanup():
    """Release VRAM between training stages (Colab T4 is tight)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def chat_generate(model, tokenizer, prompt, max_new_tokens=120):
    """Helper: format as chat, generate, and decode just the assistant turn."""
    msgs = [{"role": "user", "content": prompt}]
    ids = tokenizer.apply_chat_template(
        msgs, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)
    with torch.no_grad():
        out = model.generate(
            ids, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=0.7, top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens so only the newly generated text is decoded.
    return tokenizer.decode(out[0][ids.shape[-1]:], skip_special_tokens=True)
We install and configure the full training stack, pinning compatible versions of TRL (Transformer Reinforcement Learning library), Transformers, and PEFT; if torchao fails to import after the upgrade, we register a stub module so the rest of the stack can still load. We set environment variables, check for a GPU, and define reusable pieces such as the shared LoRA configuration. We also write two helpers used by every later stage: one that frees memory between runs and one for chat-style generation.
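To make the LoRA savings concrete, here is a minimal sketch that counts the extra trainable parameters `LORA_CFG` adds (r=8 on `q_proj`, `k_proj`, `v_proj`, `o_proj`). It assumes Qwen2.5-0.5B's shape: hidden size 896, 24 layers, and grouped-query attention with 2 key/value heads of dimension 64, so `k_proj` and `v_proj` map 896 → 128. Those numbers are assumptions taken from the model's config, not from this tutorial, so verify them with `AutoConfig` before relying on the total.

```python
# Sketch: count trainable parameters added by LoRA adapters.
# ASSUMED Qwen2.5-0.5B shapes (verify with transformers.AutoConfig).
HIDDEN = 896
NUM_LAYERS = 24
KV_DIM = 2 * 64  # num_key_value_heads * head_dim (assumed)

def lora_params_for_linear(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) next to a frozen d_out x d_in weight."""
    return r * d_in + d_out * r

def lora_params_total(r: int = 8) -> int:
    # (d_in, d_out) for each targeted projection in one decoder layer
    shapes = {
        "q_proj": (HIDDEN, HIDDEN),
        "k_proj": (HIDDEN, KV_DIM),
        "v_proj": (HIDDEN, KV_DIM),
        "o_proj": (HIDDEN, HIDDEN),
    }
    per_layer = sum(lora_params_for_linear(i, o, r) for i, o in shapes.values())
    return per_layer * NUM_LAYERS

lora_alpha, r = 16, 8
print(f"LoRA params at r={r}: {lora_params_total(r):,}")  # 1,081,344
print(f"update scaling alpha/r = {lora_alpha / r}")        # 2.0
```

Roughly a million trainable weights against about half a billion frozen ones is what lets these runs fit on a T4. In PEFT's default (non-rsLoRA) setting, the low-rank update B·A is scaled by `lora_alpha / r`, which is 2.0 here.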
print("\n" + "="*72 + "\nPART 1 — Supervised Fine-Tuning (SFT)\n" + "="*72)
from trl import SFTTrainer, SFTConfig
sft_ds = load_dataset("trl-lib/Capybara", split="train[:300]")
print(f"SFT dataset rows: {len(sft_ds)}")
print(f"Example messages: {sft_ds[0]['messages'][:1]}")
sft_args = SFTConfig(
output_dir="./sft_out",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
learning_rate=2e-4,
logging_steps=10,
save_strategy="no",
bf16=BF16_OK, fp16=not BF16_OK,
max_length=768,
gradient_checkpointing=True,
report_to="none",
)
sft_trainer = SFTTrainer(
    model=MODEL_NAME,
    args=sft_args,
    train_dataset=sft_ds,
    peft_config=LORA_CFG,
)
sft_trainer.train()
print("\n[SFT inference]")
print("Q: Explain the bias-variance tradeoff in two sentences.")
print("A:", chat_generate(sft_trainer.model, sft_trainer.processing_class,
                         "Explain the bias-variance tradeoff in two sentences."))
sft_trainer.save_model("./sft_out/final")
del sft_trainer; cleanup()
We begin with supervised fine-tuning: we load 300 conversations from the Capybara dataset and configure the SFT trainer. The model learns to imitate high-quality responses, with LoRA keeping the adaptation cheap enough for limited hardware. We then run a quick inference check to confirm it produces instruction-style answers.
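Under the hood, SFT minimizes next-token cross-entropy on the target text. The sketch below shows that objective on made-up numbers in plain Python. The mask stands in for options that train only on assistant or completion tokens; which tokens `SFTTrainer` actually masks depends on the dataset format, the config, and the TRL version, so treat this as an illustration rather than the trainer's code.

```python
import math

def masked_nll(token_logprobs, loss_mask):
    """Mean negative log-likelihood over positions where loss_mask == 1.

    token_logprobs[i] is log p(token_i | tokens_<i) under the model;
    positions with mask 0 (e.g. prompt/user tokens) contribute nothing.
    """
    if len(token_logprobs) != len(loss_mask):
        raise ValueError("logprobs and mask must have the same length")
    kept = [-lp for lp, m in zip(token_logprobs, loss_mask) if m]
    if not kept:
        raise ValueError("mask selects no tokens")
    return sum(kept) / len(kept)

# Toy sequence: 2 prompt tokens (masked out) followed by 2 response tokens.
logprobs = [math.log(0.9), math.log(0.8), math.log(0.5), math.log(0.25)]
mask = [0, 0, 1, 1]
print(masked_nll(logprobs, mask))  # ≈ 1.0397, i.e. (ln 2 + ln 4) / 2
```

Lowering this loss means raising the probability the model assigns to each kept response token, which is exactly "imitating" the reference answers.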
print("\n" + "="*72 + "\nPART 2 — Reward Modeling\n" + "="*72)
from trl import RewardTrainer, RewardConfig
rm_ds = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:300]")
print(f"RM dataset rows: {len(rm_ds)} keys: {list(rm_ds[0].keys())}")
rm_args = RewardConfig(
output_dir="./rm_out",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=2,
learning_rate=1e-4,
logging_steps=10,
save_strategy="no",
bf16=BF16_OK, fp16=not BF16_OK,
max_length=512,
gradient_checkpointing=True,
report_to="none",
)
rm_lora = LoraConfig(
r=8, lora_alpha=16, lora_dropout=0.05, bias="none",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
task_type="SEQ_CLS",
)
rm_trainer = RewardTrainer(
    model=MODEL_NAME,
    args=rm_args,
    train_dataset=rm_ds,
    peft_config=rm_lora,
)
rm_trainer.train()
del rm_trainer; cleanup()
Next, we move to reward modeling, where we train a model to score responses based on human preference data. We configure a sequence-classification setup (task_type="SEQ_CLS", so the model outputs a single scalar score) and train on chosen vs. rejected pairs. This is the kind of reward signal that can guide alignment in RLHF-style methods; in this notebook the reward model is trained for illustration and is not reused by the later stages.
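The reward model is trained so the chosen response scores higher than the rejected one. A common formulation, which we assume here, is the Bradley-Terry pairwise loss -log σ(r_chosen - r_rejected). TRL's `RewardTrainer` follows this idea, but its exact implementation (for example optional margin or reward-centering terms) may differ by version; this is a plain-Python sketch, not its code.

```python
import math

def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def pairwise_reward_loss(r_chosen: float, r_rejected: float) -> float:
    """Bradley-Terry loss: -log sigmoid(r_chosen - r_rejected) == softplus(-(margin))."""
    return softplus(-(r_chosen - r_rejected))

# Equal scores: the model has no preference, loss = ln 2.
print(pairwise_reward_loss(1.0, 1.0))
# Chosen scored higher: small loss. Rejected scored higher: large loss.
print(pairwise_reward_loss(2.0, 0.0), pairwise_reward_loss(0.0, 2.0))
```

Only the difference between the two scores matters, so the reward model learns a ranking rather than an absolute scale; that is why reward values from different models are not directly comparable.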
print("\n" + "="*72 + "\nPART 3 — Direct Preference Optimization (DPO)\n" + "="*72)
from trl import DPOTrainer, DPOConfig
dpo_ds = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:300]")
dpo_args = DPOConfig(
output_dir="./dpo_out",
num_train_epochs=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
learning_rate=5e-6,
logging_steps=10,
save_strategy="no",
bf16=BF16_OK, fp16=not BF16_OK,
max_length=512,
max_prompt_length=256,
beta=0.1,
gradient_checkpointing=True,
report_to="none",
)
dpo_trainer = DPOTrainer(
    model=MODEL_NAME,
    args=dpo_args,
    train_dataset=dpo_ds,
    peft_config=LORA_CFG,
)
dpo_trainer.train()
del dpo_trainer; cleanup()
We implement Direct Preference Optimization to optimize the model directly on preference data, without a separate reward model. We use a low learning rate (5e-6) and control divergence from the reference model with the beta parameter. This trains the model to favor the chosen responses over the rejected ones.
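To see what `beta` does, here is a minimal sketch of the standard sigmoid DPO loss for one preference pair: -log σ(β[(log π(y_w|x) - log π_ref(y_w|x)) - (log π(y_l|x) - log π_ref(y_l|x))]). The log-probabilities below are made-up numbers; in practice they are summed token log-probs from the policy and a frozen reference model (with a PEFT adapter, the reference can be the base model with the adapter disabled). This mirrors the default sigmoid variant conceptually and is not `DPOTrainer`'s code.

```python
import math

def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) = -softplus(-x)."""
    return -(max(-x, 0.0) + math.log1p(math.exp(-abs(x))))

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Sigmoid DPO loss for one pair.

    Each argument is the summed log-probability of a full response given the
    prompt, under the trained policy (pi_*) or the frozen reference (ref_*).
    """
    chosen_logratio = pi_chosen - ref_chosen
    rejected_logratio = pi_rejected - ref_rejected
    return -log_sigmoid(beta * (chosen_logratio - rejected_logratio))

# Policy identical to the reference: loss = ln 2 (no preference learned yet).
print(dpo_loss(-12.0, -20.0, -12.0, -20.0))
# Policy raised the chosen response by 5 nats relative to the reference.
print(dpo_loss(-10.0, -20.0, -15.0, -20.0, beta=0.1))
```

Because `beta` multiplies the log-ratio margin, a small value such as 0.1 requires a larger shift away from the reference before the loss saturates, which is one way to read it as a brake on divergence.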
print("\n" + "="*72 + "\nPART 4 — GRPO with verifiable math rewards\n" + "="*72)
from trl import GRPOTrainer, GRPOConfig
import random
random.seed(0)
def make_math_problem():
    a, b = random.randint(1, 50), random.randint(1, 50)
    op = random.choice(["+", "-", "*"])
    expr = f"{a} {op} {b}"
    return {
        "prompt": f"Solve this and end your answer with only the final number. {expr} =",
        # eval is safe here: expr is built only from two ints and a fixed operator.
        "answer": str(eval(expr)),
    }
grpo_ds = Dataset.from_list([make_math_problem() for _ in range(200)])
print(f"GRPO dataset rows: {len(grpo_ds)}")
print(f"Example: {grpo_ds[0]}")
def correctness_reward(completions, **kwargs):
    """+1 if the last number in the completion matches the gold answer."""
    answers = kwargs["answer"]  # the dataset's "answer" column, passed in by GRPOTrainer
    rewards = []
    for c, gold in zip(completions, answers):
        nums = re.findall(r"-?\d+", c)
        rewards.append(1.0 if nums and nums[-1] == gold else 0.0)
    return rewards

def brevity_reward(completions, **kwargs):
    """Small bonus for short answers; discourages rambling."""
    return [max(0.0, 1.0 - len(c) / 200) * 0.2 for c in completions]
grpo_args = GRPOConfig(
output_dir="./grpo_out",
learning_rate=1e-5,
per_device_train_batch_size=2,
gradient_accumulation_steps=2,
num_generations=4,
max_prompt_length=128,
max_completion_length=96,
logging_steps=2,
save_strategy="no",
bf16=BF16_OK, fp16=not BF16_OK,
gradient_checkpointing=True,
max_steps=15,
report_to="none",
)
grpo_trainer = GRPOTrainer(
    model=MODEL_NAME,
    args=grpo_args,
    train_dataset=grpo_ds,
    reward_funcs=[correctness_reward, brevity_reward],
    peft_config=LORA_CFG,
)
grpo_trainer.train()
print("\n[GRPO inference]")
for q in ["What is 17 + 28?", "What is 9 * 7?", "What is 100 - 47?"]:
    a = chat_generate(grpo_trainer.model, grpo_trainer.processing_class, q, 60)
    print(f"Q: {q}\nA: {a}\n")
del grpo_trainer; cleanup()
print("\n✓ Tutorial complete: you've trained with four post-training algorithms!")
We apply GRPO by generating several completions per prompt (num_generations=4) and scoring them with custom reward functions. Both rewards are deterministic, one for correctness and one for brevity, so the model learns from verifiable signals rather than from a learned reward model. Finally, we query the model with arithmetic questions to check its answers after the short 15-step run.
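The core GRPO idea is that each prompt's completions are scored and then compared only with each other. The sketch below reuses the same last-number rule as `correctness_reward` and normalizes rewards within one group as (r - mean) / (std + ε). It is a simplified illustration: TRL's `GRPOTrainer` first combines all reward functions, and its normalization details (sample vs. population std, the epsilon, optional reward scaling) may differ by version. The helper names `correctness` and `group_advantages` are hypothetical.

```python
import re
import statistics

def correctness(completions, answers):
    """Hypothetical helper: 1.0 when the last integer in a completion equals the gold string."""
    scores = []
    for text, gold in zip(completions, answers):
        nums = re.findall(r"-?\d+", text)
        scores.append(1.0 if nums and nums[-1] == gold else 0.0)
    return scores

def group_advantages(rewards, eps=1e-4):
    """Hypothetical helper: normalize rewards within one prompt's group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption of this sketch
    return [(r - mean) / (std + eps) for r in rewards]

# Four sampled completions for "17 + 28 =" (num_generations=4), gold answer "45".
group = ["The answer is 45", "45", "I think it is 44", "17 + 28 = 46"]
rewards = correctness(group, ["45"] * len(group))
print(rewards)                      # [1.0, 1.0, 0.0, 0.0]
print(group_advantages(rewards))    # roughly [1.0, 1.0, -1.0, -1.0]
print(group_advantages([1.0] * 4))  # [0.0, 0.0, 0.0, 0.0]
```

The last line shows why group size and task difficulty matter: when every completion in a group earns the same reward, the advantages are all zero and that prompt contributes no learning signal.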
In conclusion, we implemented four major post-training paradigms that shape today's LLM alignment workflows. We saw how each method builds conceptually on the previous one (in this notebook every stage starts from the same base checkpoint rather than chaining outputs): structured learning in SFT, preference modeling in RM, simpler preference optimization with DPO, and finally learning from verifiable rewards with GRPO. We also showed that these training techniques are not limited to massive infrastructure; they can be prototyped efficiently with the right tools and abstractions. This gives us a strong foundation for further experimentation: customizing reward functions, scaling to larger models, and designing our own aligned AI systems.
The post A Coding Guide on LLM Post Training with TRL from Supervised Fine Tuning to DPO and GRPO Reasoning appeared first on MarkTechPost.
