Skip to content

Commit

Permalink
use torch.float32 in alignment chapter
Browse files Browse the repository at this point in the history
  • Loading branch information
burtenshaw committed Dec 13, 2024
1 parent 1a66727 commit b34d7ed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
6 changes: 3 additions & 3 deletions 2_preference_alignment/notebooks/dpo_finetuning_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
"# Model to fine-tune\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" pretrained_model_name_or_path=model_name,\n",
" torch_dtype=torch.float16,\n",
" torch_dtype=torch.float32,\n",
").to(device)\n",
"model.config.use_cache = False\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
Expand Down Expand Up @@ -343,7 +343,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "py310",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand All @@ -357,7 +357,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
"version": "3.11.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"\n",
"# Authenticate to Hugging Face\n",
"from huggingface_hub import login\n",
"\n",
"login()"
]
},
Expand Down Expand Up @@ -257,15 +258,13 @@
"device = (\n",
" \"cuda\"\n",
" if torch.cuda.is_available()\n",
" else \"mps\"\n",
" if torch.backends.mps.is_available()\n",
" else \"cpu\"\n",
" else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
")\n",
"\n",
"# Model to fine-tune\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" pretrained_model_name_or_path=model_name,\n",
" torch_dtype=torch.float16,\n",
" torch_dtype=torch.float32,\n",
").to(device)\n",
"model.config.use_cache = False\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
Expand Down

0 comments on commit b34d7ed

Please sign in to comment.