Compare commits

10 commits: 82b2a6525f ... 7041fbc9d6

- 7041fbc9d6
- 90a6de94c7
- fbea53267b
- 272b648c03
- 7832290687
- 80e7e7b23c
- f987b3c877
- 96a6b0fa33
- fa9b621cc9
- 52aa8759a2

.gitignore (vendored): 4 changes

@@ -175,4 +175,6 @@ data/
 wandb/
 logs/
 eval_results/
 results/
+
+.vscode/

README.md: 14 changes

@@ -57,7 +57,7 @@ uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --u
 Next, install vLLM:
 
 ```shell
-uv pip install vllm==0.7.1 --link-mode=copy
+uv pip install vllm==0.7.2 --link-mode=copy
 ```
 
 This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
@@ -126,6 +126,14 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
     --per_device_train_batch_size=1 --num_train_epochs=5
 ```
 
+If you also wish to override the Weights and Biases default settings, you can do so as follows:
+
+```shell
+accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
+    --config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
+    --wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
+```
+
 > [!NOTE]
 > The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
 
@@ -141,10 +149,10 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
 
 ### GRPO
 
-To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero3.yaml` config and then overwrite `num_processes` to run on 7 devices:
+To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, use the `recipes/accelerate_configs/zero2.yaml` config and then overwrite `num_processes` to run on 7 devices:
 
 ```shell
-ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
     --num_processes=7 src/open_r1/grpo.py \
     --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
 ```

@@ -31,12 +31,12 @@ lr_scheduler_type: cosine
 max_prompt_length: 512
 max_completion_length: 1024
 max_steps: -1
-num_generations: 2
+num_generations: 7
 num_train_epochs: 1
 output_dir: data/DeepSeek-R1-Distill-Qwen-7B-GRPO
 overwrite_output_dir: true
-per_device_eval_batch_size: 4
-per_device_train_batch_size: 2
+per_device_eval_batch_size: 32
+per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb

@@ -33,12 +33,12 @@ lr_scheduler_type: cosine
 max_prompt_length: 512
 max_completion_length: 1024
 max_steps: -1
-num_generations: 2
+num_generations: 7
 num_train_epochs: 1
 output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
 overwrite_output_dir: true
-per_device_eval_batch_size: 4
-per_device_train_batch_size: 2
+per_device_eval_batch_size: 32
+per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb

@@ -37,8 +37,8 @@ num_generations: 7
 num_train_epochs: 1
 output_dir: data/Qwen-2.5-7B-Simple-RL
 overwrite_output_dir: true
-per_device_eval_batch_size: 2
-per_device_train_batch_size: 2
+per_device_eval_batch_size: 16
+per_device_train_batch_size: 16
 push_to_hub: true
 report_to:
 - wandb

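For context, the per-device batch sizes above multiply across the training processes. A rough, illustrative sketch of that arithmetic (not part of the diff; it assumes the 7 training processes from the README's GRPO setup, and ignores gradient accumulation, which is configured elsewhere in these recipes):

```python
# Back-of-the-envelope effective batch size for the updated GRPO recipes.
# Assumes 7 training processes (8 GPUs minus 1 reserved for vLLM); gradient
# accumulation steps are not shown in this diff and would scale the result.
per_device_train_batch_size = 16
num_processes = 7
effective_batch_size = per_device_train_batch_size * num_processes
print(effective_batch_size)  # 112 sequences per optimizer step, before accumulation
```
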
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import hashlib
 import json
 import os
 import random
@@ -87,14 +88,14 @@ async def process_example(example, session, args, output_file, pbar):
         return None
 
 
-async def load_processed_uuids(output_file):
+async def load_processed_uuids(output_file, uuid_column):
     processed_uuids = set()
     if os.path.exists(output_file):
         async with aiofiles.open(output_file, mode="r") as f:
            async for line in f:
                try:
                    data = json.loads(line)
-                   processed_uuids.add(data["uuid"])
+                   processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())
                except json.JSONDecodeError:
                    continue
    return processed_uuids
@@ -120,7 +121,9 @@ async def main():
     args = parser.parse_args()
 
     dataset = load_dataset(args.dataset_name, split="train").shuffle()
-    processed_uuids = await load_processed_uuids(args.output_file)
+    processed_uuids = await load_processed_uuids(args.output_file, args.uuid_column)
+    if processed_uuids:
+        print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")
 
     if not os.path.exists(args.output_file):
         async with aiofiles.open(args.output_file, mode="w") as f:
@@ -129,7 +132,7 @@ async def main():
     active_tasks: Set[asyncio.Task] = set()
 
     pbar = tqdm(
-        total=len(dataset),
+        total=len(dataset) - len(processed_uuids),
         desc="Generating responses",
         unit="row",
         mininterval=2,
@@ -142,7 +145,8 @@ async def main():
         connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),
     ) as session:
         for example in dataset:
-            if example["uuid"] not in processed_uuids:
+            uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()
+            if uuid not in processed_uuids:
                 # Wait if we've hit the concurrency limit
                 while len(active_tasks) >= args.max_concurrent:
                     done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)

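The change above keys the resume logic on an MD5 hash of a configurable column instead of a hard-coded `uuid` field. A minimal standalone sketch of the idea (the `question` column and sample rows are purely illustrative, not taken from the script):

```python
import hashlib

# Hypothetical rows; any column can serve as the resume key.
rows = [{"question": "What is 2+2?"}, {"question": "Prove sqrt(2) is irrational."}]
uuid_column = "question"

# Hash the key column the same way when writing and when resuming,
# so previously processed rows are recognized after a restart.
processed = {hashlib.md5(str(r[uuid_column]).encode()).hexdigest() for r in rows[:1]}

for row in rows:
    key = hashlib.md5(str(row[uuid_column]).encode()).hexdigest()
    if key in processed:
        continue  # already generated, skip on resume
    print("would generate a completion for:", row[uuid_column])
```
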
setup.py: 1 change

@@ -58,6 +58,7 @@ _deps = [
     "math-verify==0.5.2",  # Used for math verification in grpo
     "packaging>=23.0",
     "parameterized>=0.9.0",
+    "peft>=0.14.0",
     "pytest",
     "ruff>=0.9.0",
     "safetensors>=0.3.3",

@@ -81,7 +81,7 @@ echo "Uploading details to Hugging Face Hub..."
 DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
 echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
 TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
-python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
+python scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
 
 echo "Cleaning up ..."
 rm -rf $OUTPUT_DIR

@@ -40,6 +40,14 @@ class GRPOConfig(trl.GRPOConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The entity to store runs under.")},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The project to store runs under.")},
+    )
 
 
 @dataclass
@@ -64,3 +72,11 @@ class SFTConfig(trl.SFTConfig):
     )
     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The entity to store runs under.")},
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": ("The project to store runs under.")},
+    )

@@ -30,9 +30,11 @@ from open_r1.rewards import (
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward,
+    len_reward,
     reasoning_steps_reward,
 )
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.wandb_logging import init_wandb_training
 from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
 
 
@@ -46,7 +48,7 @@ class GRPOScriptArguments(ScriptArguments):
 
     Args:
         reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'.
+            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
         cosine_min_value_wrong (`float`):
             Minimum reward for cosine scaling for wrong answers.
         cosine_max_value_wrong (`float`):
@@ -62,7 +64,7 @@ class GRPOScriptArguments(ScriptArguments):
     reward_funcs: list[str] = field(
         default_factory=lambda: ["accuracy", "format"],
         metadata={
-            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
+            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
         },
     )
     cosine_min_value_wrong: float = field(
@@ -130,7 +132,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
     # Check for last checkpoint
     last_checkpoint = None
@@ -139,6 +141,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
 
+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     # Load the dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
 
@@ -158,6 +163,7 @@ def main(script_args, training_args, model_args):
             ngram_size=script_args.repetition_n_grams,
             max_penalty=script_args.repetition_max_penalty,
         ),
+        "length": len_reward,
     }
     reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
 

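With `len_reward` registered under the `"length"` key, it can be selected through `reward_funcs` like any other reward. The snippet below is only a sketch of that registry pattern, with stub functions standing in for the real reward implementations:

```python
# Sketch of the reward-registry pattern shown in the hunk above: string keys
# from the config select which reward callables get passed to the trainer.
def accuracy_reward(completions, **kwargs): ...
def len_reward(completions, solutions, **kwargs): ...

REWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward, "length": len_reward}

selected = ["accuracy", "length"]  # e.g. the reward_funcs list in a recipe YAML
reward_funcs = [REWARD_FUNCS_REGISTRY[name] for name in selected]
print([f.__name__ for f in reward_funcs])  # ['accuracy_reward', 'len_reward']
```
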
@@ -2,6 +2,7 @@
 
 import math
 import re
+from typing import Dict
 
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify
@@ -74,6 +75,79 @@ def reasoning_steps_reward(completions, **kwargs):
     return [min(1.0, count / 3) for count in matches]
 
 
+def len_reward(completions: list[Dict[str, str]], solutions: list[str], **kwargs) -> float:
+    """Compute length-based rewards to discourage overthinking and promote token efficiency.
+
+    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
+
+    Args:
+        completions: List of model completions
+        solutions: List of ground truth solutions
+
+    Returns:
+        List of rewards where:
+        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
+        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
+    """
+    contents = [completion[0]["content"] for completion in completions]
+
+    # First check correctness of answers
+    correctness = []
+    for content, sol in zip(contents, solutions):
+        gold_parsed = parse(
+            sol,
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+        if len(gold_parsed) == 0:
+            # Skip unparseable examples
+            correctness.append(True)  # Treat as correct to avoid penalizing
+            print("Failed to parse gold solution: ", sol)
+            continue
+
+        answer_parsed = parse(
+            content,
+            extraction_config=[
+                LatexExtractionConfig(
+                    normalization_config=NormalizationConfig(
+                        nits=False,
+                        malformed_operators=False,
+                        basic_latex=True,
+                        equations=True,
+                        boxed=True,
+                        units=True,
+                    ),
+                    boxed_match_priority=0,
+                    try_extract_without_anchor=False,
+                )
+            ],
+            extraction_mode="first_match",
+        )
+        correctness.append(verify(answer_parsed, gold_parsed))
+
+    # Calculate lengths
+    lengths = [len(content) for content in contents]
+    min_len = min(lengths)
+    max_len = max(lengths)
+
+    # If all responses have the same length, return zero rewards
+    if max_len == min_len:
+        return [0.0] * len(completions)
+
+    rewards = []
+    for length, is_correct in zip(lengths, correctness):
+        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
+
+        if is_correct:
+            reward = lambda_val
+        else:
+            reward = min(0, lambda_val)
+
+        rewards.append(float(reward))
+
+    return rewards
+
+
 def get_cosine_scaled_reward(
     min_value_wrong: float = -1.0,
     max_value_wrong: float = -0.5,

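To make the docstring's formula concrete, here is a small worked example of the length scaling alone. It is not the repository's code; it just re-derives `lambda = 0.5 - (len - min_len)/(max_len - min_len)` for a toy batch where the first two completions are assumed correct and the last incorrect:

```python
# Toy illustration of the Kimi-style length reward computed by len_reward above.
lengths = [100, 300, 500]          # character lengths of three completions
correct = [True, True, False]      # assumed correctness of each completion

min_len, max_len = min(lengths), max(lengths)
rewards = []
for length, is_correct in zip(lengths, correct):
    lam = 0.5 - (length - min_len) / (max_len - min_len)  # ranges over [-0.5, 0.5]
    rewards.append(lam if is_correct else min(0.0, lam))   # incorrect answers are capped at 0

print(rewards)  # [0.5, 0.0, -0.5]: the shortest correct answer gets the largest reward
```
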
@@ -48,6 +48,7 @@ from transformers.trainer_utils import get_last_checkpoint
 
 from open_r1.configs import SFTConfig
 from open_r1.utils.callbacks import get_callbacks
+from open_r1.utils.wandb_logging import init_wandb_training
 from trl import (
     ModelConfig,
     ScriptArguments,
@@ -88,7 +89,7 @@ def main(script_args, training_args, model_args):
     )
     logger.info(f"Model parameters {model_args}")
     logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
+    logger.info(f"Training parameters {training_args}")
 
     # Check for last checkpoint
     last_checkpoint = None
@@ -97,6 +98,9 @@ def main(script_args, training_args, model_args):
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
 
+    if "wandb" in training_args.report_to:
+        init_wandb_training(training_args)
+
     ################
     # Load datasets
     ################

src/open_r1/utils/wandb_logging.py (new file): 11 changes

@@ -0,0 +1,11 @@
+import os
+
+
+def init_wandb_training(training_args):
+    """
+    Helper function for setting up Weights & Biases logging tools.
+    """
+    if training_args.wandb_entity is not None:
+        os.environ["WANDB_ENTITY"] = training_args.wandb_entity
+    if training_args.wandb_project is not None:
+        os.environ["WANDB_PROJECT"] = training_args.wandb_project

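The new helper only exports environment variables that the `wandb` client reads later, at `wandb.init()` time. A quick sketch of the effect, assuming the open-r1 package is installed locally (`pip install -e .`) and using a stand-in object in place of the real parsed training arguments:

```python
import os
from types import SimpleNamespace

from open_r1.utils.wandb_logging import init_wandb_training  # added in this compare

# Stand-in for the training arguments; only the two new fields matter here.
# The entity/project values mirror the README example and are illustrative.
training_args = SimpleNamespace(wandb_entity="huggingface", wandb_project="open-r1")

init_wandb_training(training_args)
print(os.environ["WANDB_ENTITY"], os.environ["WANDB_PROJECT"])  # huggingface open-r1
```
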
@@ -5,6 +5,7 @@ from open_r1.rewards import (
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward,
+    len_reward,
     reasoning_steps_reward,
 )
 
@@ -110,6 +111,75 @@ class TestRewards(unittest.TestCase):
         rewards = format_reward(completion)
         self.assertEqual(rewards[0], 1.0)
 
+    def test_same_length_responses(self):
+        """Test len_reward when all responses have the same length."""
+        completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertEqual(rewards, [0.0, 0.0])
+
+    def test_different_lengths_correct_answers(self):
+        """Test len_reward with different length correct answers."""
+        completions = [
+            [{"content": r"\boxed{\frac{63}{400}}"}],  # shorter
+            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # longer
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should get higher reward
+        self.assertAlmostEqual(rewards[0], 0.5)  # shortest correct answer gets maximum reward
+
+    def test_different_lengths_incorrect_answers(self):
+        """Test len_reward with different length incorrect answers."""
+        completions = [
+            [{"content": r"\boxed{\frac{64}{400}}"}],  # shorter
+            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # longer
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertLessEqual(rewards[0], 0.0)  # incorrect answers should get non-positive rewards
+        self.assertLessEqual(rewards[1], 0.0)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still be penalized less
+
+    def test_mixed_correctness(self):
+        """Test len_reward with mix of correct and incorrect answers of different lengths."""
+        completions = [
+            [{"content": r"\boxed{\frac{63}{400}}"}],  # correct, shorter
+            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # correct, longer
+            [{"content": r"\boxed{\frac{64}{400}}"}],  # incorrect, shorter
+            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # incorrect, longer
+        ]
+        solutions = [r"\frac{63}{400}"] * 4
+
+        rewards = len_reward(completions, solutions)
+
+        # Shortest correct answer should get positive reward
+        self.assertGreater(rewards[0], 0.0)
+
+        # Longer correct answer might get negative reward:
+        self.assertGreater(rewards[2], rewards[1])
+        self.assertGreaterEqual(rewards[1], rewards[3])
+
+        # Incorrect answers should get non-positive rewards
+        self.assertLessEqual(rewards[2], 0.0)
+        self.assertLessEqual(rewards[3], 0.0)
+
+        # Shorter answers should get better rewards within their correctness category
+        self.assertGreater(rewards[0], rewards[1])  # correct answers
+        self.assertGreater(rewards[2], rewards[3])  # incorrect answers
+
+    def test_unparseable_solution(self):
+        """Test len_reward with unparseable solution."""
+        completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
+        solutions = ["unparseable_latex", "unparseable_latex"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still get better reward
+        self.assertAlmostEqual(rewards[0], 0.5)  # treated as correct, shortest gets maximum reward
+
+
 class TestRepetitionPenaltyReward(unittest.TestCase):
     def test_positive_max_penalty_raises_value_error(self):