Skip to content

Commit

Permalink
update RLHF use case
Browse files Browse the repository at this point in the history
  • Loading branch information
jdf-prog committed Mar 2, 2024
1 parent f800d7b commit 445d273
Showing 1 changed file with 51 additions and 34 deletions.
85 changes: 51 additions & 34 deletions blender_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-03-01 15:09:16.289772: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-03-01 15:09:17.254293: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/cuda-11.8//lib64\n",
"2024-03-01 15:09:17.254403: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/cuda-11.8//lib64\n",
"2024-03-01 15:09:17.254409: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"WARNING:root:No ranker config provided, no ranker loaded, please load ranker first through load_ranker()\n",
"WARNING:root:No fuser config provided, no fuser loaded, please load fuser first through load_fuser()\n",
"/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/dataclasses_json/core.py:187: RuntimeWarning: 'NoneType' object value of non-optional type load_checkpoint detected when decoding RankerConfig.\n",
" warnings.warn(\n",
"/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/dataclasses_json/core.py:187: RuntimeWarning: 'NoneType' object value of non-optional type device detected when decoding RankerConfig.\n",
" warnings.warn(\n",
"/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/transformers/convert_slow_tokenizer.py:515: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
" warnings.warn(\n",
"/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully loaded ranker from /home/dongfu/data/.cache/huggingface/hub/llm-blender/PairRM\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
Expand Down Expand Up @@ -108,8 +79,7 @@
"metadata": {},
"outputs": [],
"source": [
"ranks = blender.rank(inputs, candidates_texts, instructions=insts, return_scores=False, batch_size=2)\n",
"ranks = blender.rank_with_ref(inputs, candidates_texts, return_scores=False, batch_size=2, mode=\"longest\")"
"ranks = blender.rank(inputs, candidates_texts, instructions=insts, return_scores=False, batch_size=2)"
]
},
{
Expand Down Expand Up @@ -259,8 +229,55 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use case 5: Use PairRM for RLHF tuning"
"## Use case 5: Use PairRM for RLHF tuning\n",
"\n",
"To get scalar rewards, you can use `blender.rank_with_ref` method (see the example below)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rewards = blender.rank_with_ref(inputs, candidates_texts, return_scores=True, batch_size=2, mode=\"longest\")\n",
"print(\"Rewards for input 1:\", rewards[0]) # rewards of candidates for input 1\n",
"\"\"\"\n",
"rewards is a List[List[float]] of shape (len(inputs), len(candidates_texts[0])).\n",
"representing the rewards of each candidate for each input.\n",
"By default, the rewards are calculated based on the the comparison with the longest generation as a reference.(mode=\"longest\").\n",
"other supported modes are \"shortest\" \"median_length\" \"first\" \"last\"\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also pass a list of references to compare with, instead of automatically selecting one from the candidates as the fixed reference.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ref_candidates = [_c[0] for _c in candidates_texts] # use the first candidate as the reference, same as mode=\"first\"\n",
"rewards = blender.rank_with_ref(inputs, candidates_texts, return_scores=True, batch_size=2, ref_candidates=ref_candidates) \n",
"\"\"\"\n",
"ref_candidates = [ref1, ref2, ref3, ...] # ref_candidates is a List[str], shape (len(inputs),)\n",
"this parameter will override the mode parameter, and use the ref_candidates as the reference for reward calculation.\n",
"rewards is a List[List[float]] of shape (len(inputs), len(candidates_texts[0])).\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 445d273

Please sign in to comment.