Skip to content

Commit

Permalink
Merge pull request #30 from aws-samples/dev
Browse files (browse the repository at this point in the history)
fix rlhf hyperparams type error
  • Loading branch information
xiehust authored Oct 25, 2024
2 parents be3d870 + a5435b0 commit 0b4d482
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions backend/training/training_job.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -144,14 +144,14 @@ def create_training_yaml(self,

#如果是dpo或者kto, 暂时固定值
if stage == 'dpo':
doc['pref_beta'] = job_payload.get("pref_beta",0.1)
doc['pref_beta'] = float(job_payload.get("pref_beta",0.1))
doc['pref_loss'] = job_payload.get("pref_loss",'sigmoid')
doc['pref_ftx'] = job_payload.get("pref_ftx",0)
doc['pref_ftx'] = float(job_payload.get("pref_ftx",0))
doc['stage'] = 'dpo'
elif stage == 'kto':
doc['pref_beta'] = job_payload.get("pref_beta",0.1)
doc['pref_loss'] = job_payload.get("pref_loss",'sigmoid')
doc['pref_ftx'] = job_payload.get("pref_ftx",0)
doc['pref_beta'] = float(job_payload.get("pref_beta",0.1))
doc['pref_loss'] = float(job_payload.get("pref_loss",'sigmoid'))
doc['pref_ftx'] = float(job_payload.get("pref_ftx",0))
doc['stage'] = 'kto'

doc['model_name_or_path'] = model_id
Expand Down

0 comments on commit 0b4d482

Please sign in to comment.