fix uv env path + details (#188)

* fix uv env path + details

* Update slurm/grpo.slurm

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Edward Beeching 2025-02-05 23:59:25 +01:00 committed by GitHub
parent 138df0ca44
commit 3fd56dc7b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 7 additions and 8 deletions

View File

@ -8,8 +8,7 @@
set -x -e
source ~/.bashrc
conda activate openr1
source openr1/bin/activate
TASK_NAME=$1
TASKS=$2
MODEL_ID=$3
@ -31,7 +30,7 @@ fi
LM_EVAL_REPO_ID="open-r1/open-r1-eval-leaderboard"
MODEL_NAME=$(echo $MODEL_ID | sed 's/\//_/g') # replaces / with _
DETAILS_REPO_ID="open-r1//details-$MODEL_NAME"
DETAILS_REPO_ID="open-r1/details-$MODEL_NAME"
OUTPUT_DIR="eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME"
# We need this flag since we run this script from training jobs that use DeepSpeed and the env vars get propagated which causes errors during evaluation
ACCELERATE_USE_DEEPSPEED=false

View File

@ -14,7 +14,7 @@
set -x -e
source ~/.bashrc
conda activate openr1
source openr1/bin/activate
module load cuda/12.1
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"

View File

@ -129,7 +129,7 @@ export LD_LIBRARY_PATH=.venv/lib/python3.11/site-packages/nvidia/nvjitlink/lib
echo "SLURM_JOB_ID: $SLURM_JOB_ID"
echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
source .venv/bin/activate
source openr1/bin/activate
# Getting the node names
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")

View File

@ -11,7 +11,7 @@
set -x -e
source ~/.bashrc
conda activate openr1
source openr1/bin/activate
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"

View File

@ -11,7 +11,7 @@
set -x -e
source ~/.bashrc
conda activate openr1
source openr1/bin/activate
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"

View File

@ -39,7 +39,7 @@ class ScriptArguments:
def main():
parser = HfArgumentParser(ScriptArguments)
args = parser.parse()
args = parser.parse_args_into_dataclasses()[0]
if all(file.endswith(".json") for file in args.data_files):
ds = load_dataset("json", data_files=args.data_files)