Compare commits: 82b2a6525f...7041fbc9d6 (10 commits)

Commits:

- 7041fbc9d6
- 90a6de94c7
- fbea53267b
- 272b648c03
- 7832290687
- 80e7e7b23c
- f987b3c877
- 96a6b0fa33
- fa9b621cc9
- 52aa8759a2
.gitignore (vendored) — 4 changed lines

@@ -175,4 +175,6 @@ data/
 wandb/
 logs/
 eval_results/
-results/
+results/
+
+.vscode/
README.md — 14 changed lines

@@ -57,7 +57,7 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
 Next, install vLLM:
 
 ```shell
-uv pip install vllm==0.7.1 --link-mode=copy
+uv pip install vllm==0.7.2 --link-mode=copy
 ```
 
 This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:

@@ -126,6 +126,14 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
     --per_device_train_batch_size=1 --num_train_epochs=5
 ```
 
+If you also wish to override the Weights and Biases default settings, you can do so as follows:
+
+```shell
+accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
+    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
+    --wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
+```
+
 > [!NOTE]
 > The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
 

@@ -141,10 +149,10 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
 
 ### GRPO
 
-To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, one a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then overwrite `num_processes` to run on 7 devices:
+To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, one a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then overwrite `num_processes` to run on 7 devices:
 
 ```shell
-ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
     --num_processes=7 src/open_r1/grpo.py \
     --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
 ```
Recipe configs (three GRPO demo YAMLs, identified below by their output_dir):

@@ -31,12 +31,12 @@ lr_scheduler_type: cosine
 max_prompt_length: 512
 max_completion_length: 1024
 max_steps: -1
-num_generations: 2
+num_generations: 7
 num_train_epochs: 1
 output_dir: data/DeepSeek-R1-Distill-Qwen-7B-GRPO
 overwrite_output_dir: true
-per_device_eval_batch_size: 4
-per_device_train_batch_size: 2
+per_device_eval_batch_size: 32
+per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb

@@ -33,12 +33,12 @@ lr_scheduler_type: cosine
 max_prompt_length: 512
 max_completion_length: 1024
 max_steps: -1
-num_generations: 2
+num_generations: 7
 num_train_epochs: 1
 output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
 overwrite_output_dir: true
-per_device_eval_batch_size: 4
-per_device_train_batch_size: 2
+per_device_eval_batch_size: 32
+per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb

@@ -37,8 +37,8 @@ num_generations: 7
 num_train_epochs: 1
 output_dir: data/Qwen-2.5-7B-Simple-RL
 overwrite_output_dir: true
-per_device_eval_batch_size: 2
-per_device_train_batch_size: 2
+per_device_eval_batch_size: 16
+per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb
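A quick arithmetic sketch of how these new values fit the 7-GPU GRPO setup described in the README hunk above. It assumes (not stated in this diff) that TRL's GRPOTrainer requires the global train batch, num_processes × per_device_train_batch_size, to be divisible by num_generations:

```python
# Rough sanity check, not part of the diff. Assumption: TRL's GRPOTrainer wants
# num_processes * per_device_train_batch_size to be divisible by num_generations.
num_processes = 7                 # training GPUs; the 8th GPU runs vLLM
per_device_train_batch_size = 16  # new value from the recipe configs above
num_generations = 7               # new value from the recipe configs above

global_batch = num_processes * per_device_train_batch_size  # 112 completions per step
assert global_batch % num_generations == 0
print(global_batch // num_generations)  # 16 unique prompts per training step
```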
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@ -87,14 +88,14 @@ async def process_example(example, session, args, output_file, pbar):
|
||||
return None
|
||||
|
||||
|
||||
async def load_processed_uuids(output_file):
|
||||
async def load_processed_uuids(output_file, uuid_column):
|
||||
processed_uuids = set()
|
||||
if os.path.exists(output_file):
|
||||
async with aiofiles.open(output_file, mode="r") as f:
|
||||
async for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
processed_uuids.add(data["uuid"])
|
||||
processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return processed_uuids
|
||||
@ -120,7 +121,9 @@ async def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = load_dataset(args.dataset_name, split="train").shuffle()
|
||||
processed_uuids = await load_processed_uuids(args.output_file)
|
||||
processed_uuids = await load_processed_uuids(args.output_file, args.uuid_column)
|
||||
if processed_uuids:
|
||||
print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")
|
||||
|
||||
if not os.path.exists(args.output_file):
|
||||
async with aiofiles.open(args.output_file, mode="w") as f:
|
||||
@ -129,7 +132,7 @@ async def main():
|
||||
active_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
pbar = tqdm(
|
||||
total=len(dataset),
|
||||
total=len(dataset) - len(processed_uuids),
|
||||
desc="Generating responses",
|
||||
unit="row",
|
||||
mininterval=2,
|
||||
@ -142,7 +145,8 @@ async def main():
|
||||
connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),
|
||||
) as session:
|
||||
for example in dataset:
|
||||
if example["uuid"] not in processed_uuids:
|
||||
uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()
|
||||
if uuid not in processed_uuids:
|
||||
# Wait if we've hit the concurrency limit
|
||||
while len(active_tasks) >= args.max_concurrent:
|
||||
done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
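For readers skimming the hunks above: the script now keys resumption on an MD5 hash of a configurable column rather than a literal `uuid` field. A minimal synchronous sketch of that idea (the file name, column name, and the plain `open` in place of `aiofiles` are illustrative assumptions, not the repo's code):

```python
import hashlib
import json
import os


def row_key(value) -> str:
    # Hash the identifying column so any dataset column can serve as a uuid.
    return hashlib.md5(str(value).encode()).hexdigest()


def load_processed_keys(output_file: str, uuid_column: str) -> set[str]:
    # Collect hashes of rows already written to the JSONL output, skipping bad lines.
    processed = set()
    if os.path.exists(output_file):
        with open(output_file) as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                processed.add(row_key(data[uuid_column]))
    return processed


processed = load_processed_keys("generations.jsonl", uuid_column="problem")
example = {"problem": "What is 2 + 2?"}
if row_key(example["problem"]) not in processed:
    print("needs generation")
```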
setup.py — 1 changed line (one dependency line added within the span below)

@@ -58,6 +58,7 @@ _deps = [
     "math-verify==0.5.2", # Used for math verification in grpo
     "packaging>=23.0",
     "parameterized>=0.9.0",
     "peft>=0.14.0",
     "pytest",
     "ruff>=0.9.0",
     "safetensors>=0.3.3",
Evaluation Slurm script:

@@ -81,7 +81,7 @@ echo "Uploading details to Hugging Face Hub..."
 DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
 echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
 TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
-python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
+python scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
 
 echo "Cleaning up ..."
 rm -rf $OUTPUT_DIR
Trainer config dataclasses (GRPOConfig and SFTConfig):

@@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The entity to store runs under.")},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The project to store runs under.")},
+    )
 
 
 @dataclass

@@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The entity to store runs under.")},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The project to store runs under.")},
+    )
GRPO training script (src/open_r1/grpo.py):

@@ -30,9 +30,11 @@ from open_r1.rewards import (
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward,
+    len_reward,
     reasoning_steps_reward,
 )
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.wandb_logging import init_wandb_training
 from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
 
 

@@ -46,7 +48,7 @@ class GRPOScriptArguments(ScriptArguments):
 
     Args:
         reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'.
+            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
         cosine_min_value_wrong (`float`):
             Minimum reward for cosine scaling for wrong answers.
         cosine_max_value_wrong (`float`):

@@ -62,7 +64,7 @@ class GRPOScriptArguments(ScriptArguments):
     reward_funcs: list[str] = field(
         default_factory=lambda: ["accuracy", "format"],
         metadata={
-            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
+            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
         },
     )
     cosine_min_value_wrong: float = field(

@@ -130,7 +132,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
     # Check for last checkpoint
     last_checkpoint = None

@@ -139,6 +141,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
 
+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
 

@@ -158,6 +163,7 @@ def main(script_args, training_args, model_args):
             ngram_size=script_args.repetition_n_grams,
             max_penalty=script_args.repetition_max_penalty,
         ),
+        "length": len_reward,
     }
     reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
 
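Schematically, the grpo.py hunks above wire the new reward in through a name-to-callable registry selected by the `reward_funcs` list (e.g. set `reward_funcs: [accuracy, format, length]` in a recipe YAML). A toy version of that dispatch pattern, with placeholder reward functions rather than the repo's implementations:

```python
from typing import Callable


def len_reward_stub(completions, solutions, **kwargs):
    # Placeholder standing in for open_r1.rewards.len_reward.
    return [0.0 for _ in completions]


# Map config names to reward callables; "length" is the newly registered entry.
REWARD_FUNCS_REGISTRY: dict[str, Callable] = {
    "accuracy": lambda completions, **kw: [1.0 for _ in completions],  # dummy
    "format": lambda completions, **kw: [1.0 for _ in completions],    # dummy
    "length": len_reward_stub,
}

reward_funcs_names = ["accuracy", "format", "length"]  # as chosen in the recipe
reward_funcs = [REWARD_FUNCS_REGISTRY[name] for name in reward_funcs_names]
print([f.__name__ if hasattr(f, "__name__") else "lambda" for f in reward_funcs])
```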
Reward functions (open_r1.rewards):

@@ -2,6 +2,7 @@
 
 import math
 import re
+from typing import Dict
 
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify

@@ -74,6 +75,79 @@ def reasoning_steps_reward(completions, **kwargs):
     return [min(1.0, count / 3) for count in matches]
 
 
+def len_reward(completions: list[Dict[str, str]], solutions: list[str], **kwargs) -> float:
+    """Compute length-based rewards to discourage overthinking and promote token efficiency.
+
+    Taken from from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
+
+    Args:
+        completions: List of model completions
+        solutions: List of ground truth solutions
+
+    Returns:
+        List of rewards where:
+        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
+        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
+    """
+    contents = [completion[0]["content"] for completion in completions]
+
+    # First check correctness of answers
+    correctness = []
+    for content, sol in zip(contents, solutions):
+        gold_parsed = parse(
+            sol,
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+        if len(gold_parsed) == 0:
+            # Skip unparseable examples
+            correctness.append(True)  # Treat as correct to avoid penalizing
+            print("Failed to parse gold solution: ", sol)
+            continue
+
+        answer_parsed = parse(
+            content,
+            extraction_config=[
+                LatexExtractionConfig(
+                    normalization_config=NormalizationConfig(
+                        nits=False,
+                        malformed_operators=False,
+                        basic_latex=True,
+                        equations=True,
+                        boxed=True,
+                        units=True,
+                    ),
+                    boxed_match_priority=0,
+                    try_extract_without_anchor=False,
+                )
+            ],
+            extraction_mode="first_match",
+        )
+        correctness.append(verify(answer_parsed, gold_parsed))
+
+    # Calculate lengths
+    lengths = [len(content) for content in contents]
+    min_len = min(lengths)
+    max_len = max(lengths)
+
+    # If all responses have the same length, return zero rewards
+    if max_len == min_len:
+        return [0.0] * len(completions)
+
+    rewards = []
+    for length, is_correct in zip(lengths, correctness):
+        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
+
+        if is_correct:
+            reward = lambda_val
+        else:
+            reward = min(0, lambda_val)
+
+        rewards.append(float(reward))
+
+    return rewards
+
+
 def get_cosine_scaled_reward(
     min_value_wrong: float = -1.0,
     max_value_wrong: float = -0.5,
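As a hand-computed check of the scaling in `len_reward` above (the lengths are illustrative, and both completions are assumed to be judged correct):

```python
# Hand-computed sketch of the len_reward scaling; no repo imports needed.
lengths = [10, 30]  # made-up character counts for two completions
min_len, max_len = min(lengths), max(lengths)
for length in lengths:
    lam = 0.5 - (length - min_len) / (max_len - min_len)
    reward = lam                 # correct answer: reward = lam
    # an incorrect answer would instead get: reward = min(0, lam)
    print(length, reward)        # 10 -> 0.5, 30 -> -0.5
```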
SFT training script (src/open_r1/sft.py):

@@ -48,6 +48,7 @@ from transformers.trainer_utils import get_last_checkpoint
 
 from open_r1.configs import SFTConfig
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.wandb_logging import init_wandb_training
 from trl import (
     ModelConfig,
     ScriptArguments,

@@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
     # Check for last checkpoint
     last_checkpoint = None

@@ -97,6 +98,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
 
+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     ################
     # Load datasets
     ################
src/open_r1/utils/wandb_logging.py (new file) — 11 lines

@@ -0,0 +1,11 @@
+import os
+
+
+def init_wandb_training(training_args):
+    """
+    Helper function for setting up Weights & Biases logging tools.
+    """
+    if training_args.wandb_entity is not None:
+        os.environ["WANDB_ENTITY"] = training_args.wandb_entity
+    if training_args.wandb_project is not None:
+        os.environ["WANDB_PROJECT"] = training_args.wandb_project
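A hypothetical usage sketch for the new helper (it assumes the open_r1 package is importable; the SimpleNamespace stands in for the parsed training args, and the entity/project values mirror the README example above):

```python
import os
from types import SimpleNamespace

from open_r1.utils.wandb_logging import init_wandb_training

# The helper only reads wandb_entity and wandb_project, so a namespace suffices here.
training_args = SimpleNamespace(wandb_entity="huggingface", wandb_project="open-r1")
init_wandb_training(training_args)
print(os.environ["WANDB_ENTITY"], os.environ["WANDB_PROJECT"])
```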
Reward function tests:

@@ -5,6 +5,7 @@ from open_r1.rewards import (
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward,
+    len_reward,
     reasoning_steps_reward,
 )
 

@@ -110,6 +111,75 @@ class TestRewards(unittest.TestCase):
         rewards = format_reward(completion)
         self.assertEqual(rewards[0], 1.0)
 
+    def test_same_length_responses(self):
+        """Test len_reward when all responses have the same length."""
+        completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertEqual(rewards, [0.0, 0.0])
+
+    def test_different_lengths_correct_answers(self):
+        """Test len_reward with different length correct answers."""
+        completions = [
+            [{"content": r"\boxed{\frac{63}{400}}"}],  # shorter
+            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # longer
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should get higher reward
+        self.assertAlmostEqual(rewards[0], 0.5)  # shortest correct answer gets maximum reward
+
+    def test_different_lengths_incorrect_answers(self):
+        """Test len_reward with different length incorrect answers."""
+        completions = [
+            [{"content": r"\boxed{\frac{64}{400}}"}],  # shorter
+            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # longer
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertLessEqual(rewards[0], 0.0)  # incorrect answers should get non-positive rewards
+        self.assertLessEqual(rewards[1], 0.0)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still be penalized less
+
+    def test_mixed_correctness(self):
+        """Test len_reward with mix of correct and incorrect answers of different lengths."""
+        completions = [
+            [{"content": r"\boxed{\frac{63}{400}}"}],  # correct, shorter
+            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # correct, longer
+            [{"content": r"\boxed{\frac{64}{400}}"}],  # incorrect, shorter
+            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # incorrect, longer
+        ]
+        solutions = [r"\frac{63}{400}"] * 4
+
+        rewards = len_reward(completions, solutions)
+
+        # Shortest correct answer should get positive reward
+        self.assertGreater(rewards[0], 0.0)
+
+        # Longer correct answer might get negative reward:
+        self.assertGreater(rewards[2], rewards[1])
+        self.assertGreaterEqual(rewards[1], rewards[3])
+
+        # Incorrect answers should get non-positive rewards
+        self.assertLessEqual(rewards[2], 0.0)
+        self.assertLessEqual(rewards[3], 0.0)
+
+        # Shorter answers should get better rewards within their correctness category
+        self.assertGreater(rewards[0], rewards[1])  # correct answers
+        self.assertGreater(rewards[2], rewards[3])  # incorrect answers
+
+    def test_unparseable_solution(self):
+        """Test len_reward with unparseable solution."""
+        completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
+        solutions = ["unparseable_latex", "unparseable_latex"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still get better reward
+        self.assertAlmostEqual(rewards[0], 0.5)  # treated as correct, shortest gets maximum reward
+
+
 class TestRepetitionPenaltyReward(unittest.TestCase):
     def test_positive_max_penalty_raises_value_error(self):