Reward verification and evaluation fixes (#55)
* bump up deps, fix aime24 evals, make grpo more strict
* minor fixes
* 🤨 fmt
* Update src/open_r1/grpo.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---------
Co-authored-by: Hynek Kydlicek <kydlicek.hynek@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
parent 15df4fb134
commit 90b0947382
setup.py
@@ -54,7 +54,7 @@ _deps = [
     "isort>=5.12.0",
     "liger_kernel==0.5.2",
     "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
-    "math-verify", # Used for math verification in grpo
+    "math-verify>=0.3.2", # Used for math verification in grpo
     "packaging>=23.0",
     "parameterized>=0.9.0",
     "pytest",
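For context on the new version floor: grpo.py leans on math-verify's parse/verify pair, and the normalization options used further down (NormalizationConfig, boxed_match_priority) are presumably what the >=0.3.2 pin tracks. A minimal sketch of the API with illustrative inputs, mirroring the call order used in grpo.py:

    from math_verify import parse, verify

    # parse() extracts a candidate answer from free-form text;
    # verify() checks mathematical equivalence, not string equality.
    gold = parse(r"$\frac{1}{2}$")
    answer = parse(r"So the final answer is \boxed{0.5}.")
    print(verify(answer, gold))  # should print True: 1/2 and 0.5 are equivalent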
@@ -43,6 +43,7 @@ lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
     --custom-tasks src/open_r1/evaluate.py \
     --use-chat-template \
     --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
+    --save-details \
     --output-dir $OUTPUT_DIR
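The added --save-details flag makes lighteval write per-sample details (predictions, golds, metric values) alongside the aggregate scores, which is exactly what you want when debugging extraction mismatches like the ones fixed here. Note the trailing backslash: the flag sits mid-command, before --output-dir.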
src/open_r1/evaluate.py
@@ -24,7 +24,7 @@ from lighteval.tasks.requests import Doc
 from lighteval.utils.language import Language
 
 
-metric = multilingual_extractive_match_metric(
+latex_gold_metric = multilingual_extractive_match_metric(
     language=Language.ENGLISH,
     fallback_mode="first_match",
     precision=5,
@@ -33,6 +33,15 @@ metric = multilingual_extractive_match_metric(
     aggregation_function=max,
 )
 
+expr_gold_metric = multilingual_extractive_match_metric(
+    language=Language.ENGLISH,
+    fallback_mode="first_match",
+    precision=5,
+    gold_extraction_target=(ExprExtractionConfig(),),
+    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
+    aggregation_function=max,
+)
+
 
 def prompt_fn(line, task_name: str = None):
     """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
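The two metrics differ only in how the gold answer is extracted: AIME golds are bare integers, which LaTeX-only extraction can miss, while MATH-500 golds are LaTeX. A hedged sketch of the distinction using math-verify's extraction configs directly (strings are illustrative):

    from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse

    # AIME-style gold: a bare number, handled by expression extraction.
    print(parse("204", extraction_config=[ExprExtractionConfig()]))

    # MATH-500-style gold: a LaTeX expression.
    print(parse(r"$\frac{1}{2}$", extraction_config=[LatexExtractionConfig()]))

On the prediction side expr_gold_metric combines both configs, so either a bare 204 or a \boxed{204} in the model output can match.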
@@ -44,11 +53,20 @@ def prompt_fn(line, task_name: str = None):
     )
 
 
+def aime_prompt_fn(line, task_name: str = None):
+    return Doc(
+        task_name=task_name,
+        query=line["problem"],
+        choices=[line["answer"]],
+        gold_index=0,
+    )
+
+
 # Define tasks
 aime24 = LightevalTaskConfig(
     name="aime24",
     suite=["custom"],
-    prompt_function=prompt_fn,
+    prompt_function=aime_prompt_fn,
     hf_repo="HuggingFaceH4/aime_2024",
     hf_subset="default",
     hf_avail_splits=["train"],
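The new aime_prompt_fn grades against the dataset's short answer field rather than a full worked solution (choices=[line["answer"]], with gold_index=0 pointing at that single choice), which pairs naturally with the expression-based gold extraction configured above.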
@@ -56,7 +74,7 @@ aime24 = LightevalTaskConfig(
     few_shots_split=None,
     few_shots_select=None,
     generation_size=32768,
-    metric=[metric],
+    metric=[expr_gold_metric],
     version=1,
 )
 math_500 = LightevalTaskConfig(
@@ -70,7 +88,7 @@ math_500 = LightevalTaskConfig(
     few_shots_split=None,
     few_shots_select=None,
     generation_size=32768,
-    metric=[metric],
+    metric=[latex_gold_metric],
     version=1,
 )
src/open_r1/grpo.py
@@ -17,7 +17,8 @@ from dataclasses import dataclass, field
 
 from datasets import load_dataset
 
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
 from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
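latex2sympy2_extended is the normalization layer math-verify builds on, so importing NormalizationConfig directly lets the reward function below tighten individual normalization steps (e.g. rejecting malformed operators) instead of accepting the defaults.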
@@ -42,13 +43,36 @@ def accuracy_reward(completions, solution, **kwargs):
     contents = [completion[0]["content"] for completion in completions]
     rewards = []
     for content, sol in zip(contents, solution):
-        try:
-            answer = parse(content)
-            reward = float(verify(answer, parse(sol)))
-        except Exception: # if it fails for any reason, return 0.0
-            reward = 0.0
+        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+        if len(gold_parsed) != 0:
+            # We require the answer to be provided in correct latex (no malformed operators)
+            answer_parsed = parse(
+                content,
+                extraction_config=[
+                    LatexExtractionConfig(
+                        normalization_config=NormalizationConfig(
+                            nits=False,
+                            malformed_operators=False,
+                            basic_latex=True,
+                            equations=True,
+                            boxed=True,
+                            units=True,
+                        ),
+                        # Ensures that boxed is tried first
+                        boxed_match_priority=0,
+                        try_extract_without_anchor=False,
+                    )
+                ],
+                extraction_mode="first_match",
+            )
+            # Reward 1 if the content is the same as the ground truth, 0 otherwise
+            reward = float(verify(answer_parsed, gold_parsed))
+        else:
+            # If the gold solution is not parseable, we reward 1 to skip this example
+            reward = 1.0
+            print("Failed to parse gold solution: ", sol)
         rewards.append(reward)
-    # Reward 1 if the content is the same as the ground truth, 0 otherwise
+
     return rewards
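Two things worth noting about the stricter reward. When the gold solution fails to parse, every completion receives the same 1.0; under GRPO's group-relative baseline a constant reward produces zero advantage, so the example is effectively skipped rather than training the model on an unverifiable label. And a hedged usage sketch (sample data is illustrative, and assumes the function is importable from the installed open_r1 package):

    from open_r1.grpo import accuracy_reward

    completions = [[{"content": r"Adding the halves gives \boxed{\frac{3}{4}}."}]]
    solution = [r"$\frac{3}{4}$"]
    print(accuracy_reward(completions, solution))  # expected: [1.0]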