Compare commits

...

10 Commits

Author SHA1 Message Date
Edward Beeching
7041fbc9d6
Update setup.py (#315)
Adds peft as a temporary dependency due to https://github.com/huggingface/trl/issues/2849
2025-02-13 15:04:03 +01:00
Kashif Rasul
90a6de94c7
Revert "Weighted reward functions (#213)" (#317)
This reverts commit fbea53267b9676fc89e92c9a24c83cb23e0884d0.
2025-02-13 15:00:05 +01:00
Almaz Zinollayev
fbea53267b
Weighted reward functions (#213)
* [Weighted reward functions] Adding functionality to weigh rewards. Tests.

* [Weighted reward functions] Adding @wraps decorator to preserve reward function metadata

* style

* Changing grpo.py tests to run if cuda is available

* style

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-02-13 14:08:27 +01:00
lewtun
272b648c03
Fix logging import (#316) 2025-02-13 12:01:09 +01:00
Kashif Rasul
7832290687
[Rewards] add kimi len_reward (#292)
* add kimi len_reward

* add to REWARD_FUNCS_REGISTRY

* fix formatting

* Update src/open_r1/grpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/grpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/grpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* missing import

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-02-13 11:51:09 +01:00
Edward Beeching
80e7e7b23c
move details script and fix wandb logging (#314) 2025-02-13 11:13:00 +01:00
Edward Beeching
f987b3c877
bump vLLM version to 0.7.2 (#311)
vLLM has made a number of throughput improvements in version 0.7.2, so it's worth bumping the version, particularly for GRPO training runs.
2025-02-13 10:48:11 +01:00
lewtun
96a6b0fa33
Enable Weights & Biases defaults to be overridden in training (#294)
* Enable WandB defaults to be set

* Fix
2025-02-12 13:01:07 +01:00
Anton Lozhkov
fa9b621cc9
Fix uuid in the data generator (#284)
* fix uuid issues
2025-02-11 14:08:46 +01:00
Quentin Gallouédec
52aa8759a2
new grpo logic (#274) 2025-02-11 09:35:06 +01:00
15 changed files with 218 additions and 22 deletions

4
.gitignore vendored
View File

@ -175,4 +175,6 @@ data/
wandb/
logs/
eval_results/
results/
results/
.vscode/

View File

@ -57,7 +57,7 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
Next, install vLLM:
```shell
uv pip install vllm==0.7.1 --link-mode=copy
uv pip install vllm==0.7.2 --link-mode=copy
```
This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
@ -126,6 +126,14 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
--per_device_train_batch_size=1 --num_train_epochs=5
```
If you also wish to override the Weights & Biases default settings, you can do so as follows:
```shell
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
```
> [!NOTE]
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
@ -141,10 +149,10 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
### GRPO
To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then override `num_processes` to run on 7 devices:
To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then override `num_processes` to run on 7 devices:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
```

View File

@ -31,12 +31,12 @@ lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 2
num_generations: 7
num_train_epochs: 1
output_dir: data/DeepSeek-R1-Distill-Qwen-7B-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
per_device_eval_batch_size: 32
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb

View File

@ -33,12 +33,12 @@ lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 2
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
per_device_eval_batch_size: 32
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb

View File

@ -37,8 +37,8 @@ num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B-Simple-RL
overwrite_output_dir: true
per_device_eval_batch_size: 2
per_device_train_batch_size: 2
per_device_eval_batch_size: 16
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb

View File

@ -1,5 +1,6 @@
import argparse
import asyncio
import hashlib
import json
import os
import random
@ -87,14 +88,14 @@ async def process_example(example, session, args, output_file, pbar):
return None
async def load_processed_uuids(output_file):
async def load_processed_uuids(output_file, uuid_column):
processed_uuids = set()
if os.path.exists(output_file):
async with aiofiles.open(output_file, mode="r") as f:
async for line in f:
try:
data = json.loads(line)
processed_uuids.add(data["uuid"])
processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())
except json.JSONDecodeError:
continue
return processed_uuids
@ -120,7 +121,9 @@ async def main():
args = parser.parse_args()
dataset = load_dataset(args.dataset_name, split="train").shuffle()
processed_uuids = await load_processed_uuids(args.output_file)
processed_uuids = await load_processed_uuids(args.output_file, args.uuid_column)
if processed_uuids:
print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")
if not os.path.exists(args.output_file):
async with aiofiles.open(args.output_file, mode="w") as f:
@ -129,7 +132,7 @@ async def main():
active_tasks: Set[asyncio.Task] = set()
pbar = tqdm(
total=len(dataset),
total=len(dataset) - len(processed_uuids),
desc="Generating responses",
unit="row",
mininterval=2,
@ -142,7 +145,8 @@ async def main():
connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),
) as session:
for example in dataset:
if example["uuid"] not in processed_uuids:
uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()
if uuid not in processed_uuids:
# Wait if we've hit the concurrency limit
while len(active_tasks) >= args.max_concurrent:
done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
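
The resume logic introduced above replaces the hard-coded `uuid` field with an MD5 hash of a configurable `uuid_column` argument, so any dataset column can act as a row identifier. A minimal standalone sketch of that pattern (the file name and column name below are placeholders, not the script's defaults):
```python
import hashlib
import json
import os


def stable_id(value) -> str:
    # Hash the chosen column so any string-convertible value can serve as a row ID.
    return hashlib.md5(str(value).encode()).hexdigest()


def load_processed_ids(output_file: str, uuid_column: str) -> set[str]:
    """Collect IDs of rows already written so a rerun can skip them."""
    processed = set()
    if os.path.exists(output_file):
        with open(output_file) as f:
            for line in f:
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    continue  # ignore partially written lines
                processed.add(stable_id(row[uuid_column]))
    return processed


# Hypothetical usage: skip rows whose hashed "problem" column was already processed.
rows = [{"problem": "1+1"}, {"problem": "2+2"}]
done = load_processed_ids("generations.jsonl", "problem")
todo = [r for r in rows if stable_id(r["problem"]) not in done]
```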

View File

@ -58,6 +58,7 @@ _deps = [
"math-verify==0.5.2", # Used for math verification in grpo
"packaging>=23.0",
"parameterized>=0.9.0",
"peft>=0.14.0",
"pytest",
"ruff>=0.9.0",
"safetensors>=0.3.3",

View File

@ -81,7 +81,7 @@ echo "Uploading details to Hugging Face Hub..."
DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
python scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
echo "Cleaning up ..."
rm -rf $OUTPUT_DIR

View File

@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
@dataclass
@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
)
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
wandb_entity: Optional[str] = field(
default=None,
metadata={"help": ("The entity to store runs under.")},
)
wandb_project: Optional[str] = field(
default=None,
metadata={"help": ("The project to store runs under.")},
)
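
The new `wandb_entity`/`wandb_project` fields are plain dataclass fields on the config classes, so they can be set from YAML or from the command line. A rough sketch of how such fields surface as CLI flags, using plain argparse as a stand-in for `TrlParser` (which handles this in the actual code):
```python
# Sketch only: dataclass fields such as wandb_entity/wandb_project become
# optional CLI flags, which is what lets the README pass
# --wandb_entity huggingface --wandb_project open-r1 at launch time.
import argparse
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class WandbArgs:
    wandb_entity: Optional[str] = None   # W&B team/user to store runs under
    wandb_project: Optional[str] = None  # W&B project to store runs under


def parse_wandb_args(argv: list[str]) -> WandbArgs:
    parser = argparse.ArgumentParser()
    for f in fields(WandbArgs):
        parser.add_argument(f"--{f.name}", type=str, default=f.default)
    return WandbArgs(**vars(parser.parse_args(argv)))


print(parse_wandb_args(["--wandb_project", "open-r1"]))
```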

View File

@ -30,9 +30,11 @@ from open_r1.rewards import (
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
@ -46,7 +48,7 @@ class GRPOScriptArguments(ScriptArguments):
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'.
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
@ -62,7 +64,7 @@ class GRPOScriptArguments(ScriptArguments):
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
},
)
cosine_min_value_wrong: float = field(
@ -130,7 +132,7 @@ def main(script_args, training_args, model_args):
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
logger.info(f"Training parameters {training_args}")
# Check for last checkpoint
last_checkpoint = None
@ -139,6 +141,9 @@ def main(script_args, training_args, model_args):
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
@ -158,6 +163,7 @@ def main(script_args, training_args, model_args):
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
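
The `REWARD_FUNCS_REGISTRY` change above wires the new `length` option into the same name-to-callable lookup used for the other rewards. A minimal sketch of that pattern with placeholder reward functions (the real implementations live in the rewards module):
```python
# Reward names passed via the script arguments are resolved to callables
# before being handed to the trainer. The functions here are placeholders.
from typing import Callable


def accuracy_reward(completions, **kwargs):
    # Placeholder: the real function verifies answers with math_verify.
    return [1.0 for _ in completions]


def len_reward(completions, **kwargs):
    # Placeholder: the real function applies the Kimi 1.5 length penalty.
    return [0.0 for _ in completions]


REWARD_FUNCS_REGISTRY: dict[str, Callable] = {
    "accuracy": accuracy_reward,
    "length": len_reward,
}

selected = ["accuracy", "length"]  # e.g. reward_funcs from the script arguments
reward_funcs = [REWARD_FUNCS_REGISTRY[name] for name in selected]
print([fn.__name__ for fn in reward_funcs])
```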

View File

@ -2,6 +2,7 @@
import math
import re
from typing import Dict
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
@ -74,6 +75,79 @@ def reasoning_steps_reward(completions, **kwargs):
return [min(1.0, count / 3) for count in matches]
def len_reward(completions: list[Dict[str, str]], solutions: list[str], **kwargs) -> list[float]:
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Args:
completions: List of model completions
solutions: List of ground truth solutions
Returns:
List of rewards where:
- For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
- For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
"""
contents = [completion[0]["content"] for completion in completions]
# First check correctness of answers
correctness = []
for content, sol in zip(contents, solutions):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) == 0:
# Skip unparseable examples
correctness.append(True) # Treat as correct to avoid penalizing
print("Failed to parse gold solution: ", sol)
continue
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
correctness.append(verify(answer_parsed, gold_parsed))
# Calculate lengths
lengths = [len(content) for content in contents]
min_len = min(lengths)
max_len = max(lengths)
# If all responses have the same length, return zero rewards
if max_len == min_len:
return [0.0] * len(completions)
rewards = []
for length, is_correct in zip(lengths, correctness):
lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
if is_correct:
reward = lambda_val
else:
reward = min(0, lambda_val)
rewards.append(float(reward))
return rewards
def get_cosine_scaled_reward(
min_value_wrong: float = -1.0,
max_value_wrong: float = -0.5,
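
As a sanity check of the formula documented in `len_reward` above, here is a small numeric re-computation of the length-scaling term (an illustrative sketch, not the open_r1 implementation; correctness flags are taken as given rather than parsed with math_verify):
```python
def length_rewards(lengths: list[int], correctness: list[bool]) -> list[float]:
    min_len, max_len = min(lengths), max(lengths)
    if max_len == min_len:
        return [0.0] * len(lengths)  # identical lengths carry no signal
    rewards = []
    for length, is_correct in zip(lengths, correctness):
        lam = 0.5 - (length - min_len) / (max_len - min_len)
        rewards.append(lam if is_correct else min(0.0, lam))
    return rewards


# Shortest correct completion gets +0.5; a long incorrect one gets -0.5.
print(length_rewards([20, 200], [True, False]))  # [0.5, -0.5]
```
The shortest correct completion receives the maximum reward of 0.5, while incorrect completions are clipped to at most 0.0, matching the assertions added in the test file below.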

View File

@ -48,6 +48,7 @@ from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import (
ModelConfig,
ScriptArguments,
@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
logger.info(f"Training parameters {training_args}")
# Check for last checkpoint
last_checkpoint = None
@ -97,6 +98,9 @@ def main(script_args, training_args, model_args):
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
################
# Load datasets
################

View File

@ -0,0 +1,11 @@
import os
def init_wandb_training(training_args):
"""
Helper function for setting up Weights & Biases logging tools.
"""
if training_args.wandb_entity is not None:
os.environ["WANDB_ENTITY"] = training_args.wandb_entity
if training_args.wandb_project is not None:
os.environ["WANDB_PROJECT"] = training_args.wandb_project

View File

@ -5,6 +5,7 @@ from open_r1.rewards import (
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)
@ -110,6 +111,75 @@ class TestRewards(unittest.TestCase):
rewards = format_reward(completion)
self.assertEqual(rewards[0], 1.0)
def test_same_length_responses(self):
"""Test len_reward when all responses have the same length."""
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
rewards = len_reward(completions, solutions)
self.assertEqual(rewards, [0.0, 0.0])
def test_different_lengths_correct_answers(self):
"""Test len_reward with different length correct answers."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}], # shorter
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # longer
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
rewards = len_reward(completions, solutions)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward
self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward
def test_different_lengths_incorrect_answers(self):
"""Test len_reward with different length incorrect answers."""
completions = [
[{"content": r"\boxed{\frac{64}{400}}"}], # shorter
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # longer
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
rewards = len_reward(completions, solutions)
self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards
self.assertLessEqual(rewards[1], 0.0)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less
def test_mixed_correctness(self):
"""Test len_reward with mix of correct and incorrect answers of different lengths."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}], # correct, shorter
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # correct, longer
[{"content": r"\boxed{\frac{64}{400}}"}], # incorrect, shorter
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # incorrect, longer
]
solutions = [r"\frac{63}{400}"] * 4
rewards = len_reward(completions, solutions)
# Shortest correct answer should get positive reward
self.assertGreater(rewards[0], 0.0)
# Longer correct answer might get negative reward:
self.assertGreater(rewards[2], rewards[1])
self.assertGreaterEqual(rewards[1], rewards[3])
# Incorrect answers should get non-positive rewards
self.assertLessEqual(rewards[2], 0.0)
self.assertLessEqual(rewards[3], 0.0)
# Shorter answers should get better rewards within their correctness category
self.assertGreater(rewards[0], rewards[1]) # correct answers
self.assertGreater(rewards[2], rewards[3]) # incorrect answers
def test_unparseable_solution(self):
"""Test len_reward with unparseable solution."""
completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
solutions = ["unparseable_latex", "unparseable_latex"]
rewards = len_reward(completions, solutions)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward
self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward
class TestRepetitionPenaltyReward(unittest.TestCase):
def test_positive_max_penalty_raises_value_error(self):