From 419b33a0eff3931fd41954e5e07ac7b523391983 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Sat, 27 Jul 2024 11:03:24 +0000 Subject: [PATCH] First prototype, let's jump padding free --- convert_hf_nanotron.ipynb | 764 ++++++++++++++++++++ examples/config_llama_sft.yaml | 97 +++ run_train.py | 30 +- src/nanotron/config/config.py | 18 +- src/nanotron/data/chat_dataset.py | 139 ++++ src/nanotron/data/chat_tokenizer.py | 81 +++ src/nanotron/data/collator.py | 89 ++- src/nanotron/data/dataloader_builder.py | 35 +- src/nanotron/models/llama_sft.py | 888 ++++++++++++++++++++++++ src/nanotron/trainer.py | 11 +- 10 files changed, 2141 insertions(+), 11 deletions(-) create mode 100644 convert_hf_nanotron.ipynb create mode 100644 examples/config_llama_sft.yaml create mode 100644 src/nanotron/data/chat_dataset.py create mode 100644 src/nanotron/data/chat_tokenizer.py create mode 100644 src/nanotron/models/llama_sft.py diff --git a/convert_hf_nanotron.ipynb b/convert_hf_nanotron.ipynb new file mode 100644 index 00000000..943b1af9 --- /dev/null +++ b/convert_hf_nanotron.ipynb @@ -0,0 +1,764 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.testing import assert_close\n", + "\n", + "import os\n", + "\n", + "dtype = torch.bfloat16\n", + "device = torch.device(\"cuda\")\n", + "\n", + "os.environ[\"WORLD_SIZE\"] = \"1\"\n", + "os.environ[\"RANK\"] = \"0\"\n", + "os.environ[\"MASTER_ADDR\"] = \"0.0.0.0\"\n", + "os.environ[\"MASTER_PORT\"] = \"6000\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/solergib/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n", + "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 13.70it/s]\n" + ] + } + ], + "source": [ + "from transformers import AutoModelForCausalLM\n", + "PATH_TO_LLAMA = \"/mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct\"\n", + "hf_model = AutoModelForCausalLM.from_pretrained(PATH_TO_LLAMA, torch_dtype=dtype, attn_implementation=\"flash_attention_2\").to(device)\n", + "# print(hf_model)\n", + "# print(hf_model.config)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LlamaConfig {\n", + " \"architectures\": [\n", + " \"LlamaForCausalLM\"\n", + " ],\n", + " \"attention_bias\": false,\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 128000,\n", + " \"eos_token_id\": 128001,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 14336,\n", + " \"max_position_embeddings\": 8192,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"llama\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 32,\n", + " \"num_key_value_heads\": 8,\n", + " \"pretraining_tp\": 1,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_scaling\": null,\n", + " \"rope_theta\": 500000.0,\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.44.0.dev0\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 128256\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "from transformers import LlamaConfig\n", + "hf_config = LlamaConfig.from_pretrained(PATH_TO_LLAMA)\n", + "print(hf_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from nanotron.config import ParallelismArgs\n", + "from nanotron.parallel import ParallelContext\n", + "from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine\n", + "from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode\n", + "\n", + "DP = 1\n", + "PP = 1\n", + "TP = 1\n", + "\n", + "parallel_config = ParallelismArgs(\n", + " dp=DP,\n", + " pp=PP,\n", + " tp=TP,\n", + " pp_engine=AllForwardAllBackwardPipelineEngine(),\n", + " tp_mode=TensorParallelLinearMode.ALL_REDUCE,\n", + " tp_linear_async_communication=False,\n", + ")\n", + "assert (\n", + " parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE\n", + " and parallel_config.tp_linear_async_communication is False\n", + ")\n", + "\n", + "parallel_context = ParallelContext(\n", + " data_parallel_size=parallel_config.dp,\n", + " pipeline_parallel_size=parallel_config.pp,\n", + " tensor_parallel_size=parallel_config.tp,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron\n", + "\n", + "nanotron_config = LlamaConfigNanotron(\n", + " bos_token_id=hf_config.bos_token_id,\n", + " eos_token_id=hf_config.eos_token_id,\n", + " hidden_act=hf_config.hidden_act,\n", + " hidden_size=hf_config.hidden_size,\n", + " initializer_range=hf_config.initializer_range,\n", + " intermediate_size=hf_config.intermediate_size,\n", + " is_llama_config=True,\n", + " max_position_embeddings=hf_config.max_position_embeddings,\n", + " num_attention_heads=hf_config.num_attention_heads,\n", + " num_hidden_layers=hf_config.num_hidden_layers,\n", + " num_key_value_heads=hf_config.num_key_value_heads,\n", + " pad_token_id=None,\n", + " pretraining_tp=hf_config.pretraining_tp,\n", + " rms_norm_eps=hf_config.rms_norm_eps,\n", + " rope_scaling=hf_config.rope_scaling,\n", + " rope_theta=hf_config.rope_theta,\n", + " rope_interleaved=False,\n", + " tie_word_embeddings=hf_config.tie_word_embeddings,\n", + " use_cache=hf_config.use_cache,\n", + " vocab_size=hf_config.vocab_size,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from nanotron.models.llama_sft import LlamaForSFT\n", + "from nanotron.models import build_model\n", + "\n", + "nanotron_model = build_model(\n", + " model_builder=lambda: LlamaForSFT(\n", + " config=nanotron_config,\n", + " parallel_context=parallel_context,\n", + " parallel_config=parallel_config,\n", + " random_states=None,\n", + " ),\n", + " parallel_context=parallel_context,\n", + " dtype=dtype,\n", + " device=device,\n", + ")\n", + "# print(nanotron_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from nanotron.trainer import mark_tied_parameters\n", + "\n", + "mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ShardedInfo(global_ranks=(0,), local_global_slices_pairs=(SlicesPair(local_slices=(slice(None, None, None), slice(None, None, None)), global_slices=(slice(0, 128256, None), slice(None, None, None))),), unsharded_shape=(128256, 4096))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.get_sharded_info()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.is_tied" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Final script\n", + "# TODO Añadir variables de TP para splitear los parametros de las layers de HF\n", + "# TODO Cargar modelo HF en cpu y copiar desde ahi\n", + "\n", + "\n", + "# Token embeddings\n", + "assert nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape == hf_model.model.embed_tokens.weight.shape\n", + "\n", + "with torch.no_grad():\n", + " nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_(hf_model.model.embed_tokens.weight)# = hf_model.model.embed_tokens.weight.data\n", + "\n", + "# Decoder layers\n", + "for i in range(nanotron_config.num_hidden_layers):\n", + " # Input layer norm\n", + " assert hf_model.model.layers[i].input_layernorm.weight.shape == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape\n", + " with torch.no_grad():\n", + " nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_(hf_model.model.layers[i].input_layernorm.weight)# = hf_model.model.layers[i].input_layernorm.weight\n", + " # Self attn\n", + " ## QKV\n", + " tmp_qkv_proj = torch.cat([\n", + " hf_model.model.layers[i].self_attn.q_proj.weight,\n", + " hf_model.model.layers[i].self_attn.k_proj.weight,\n", + " hf_model.model.layers[i].self_attn.v_proj.weight\n", + " ], dim = 0) \n", + " assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape\n", + " with torch.no_grad():\n", + " nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj)# = tmp_qkv_proj # torch.nn.Parameter(tmp_qkv_proj)\n", + " \n", + " ## O\n", + " assert hf_model.model.layers[i].self_attn.o_proj.weight.shape == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape\n", + " with torch.no_grad():\n", + " nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_(hf_model.model.layers[i].self_attn.o_proj.weight)# = hf_model.model.layers[i].self_attn.o_proj.weight\n", + " # MLP\n", + " ## Gate Up Proj\n", + " tmp_gate_up_proj = torch.cat([\n", + " hf_model.model.layers[i].mlp.gate_proj.weight,\n", + " hf_model.model.layers[i].mlp.up_proj.weight,\n", + " ], dim = 0)\n", + "\n", + " assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape\n", + " with torch.no_grad():\n", + " nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj)# = tmp_gate_up_proj\n", + " ## Down Proj\n", + " assert hf_model.model.layers[i].mlp.down_proj.weight.shape == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape\n", + " with torch.no_grad():\n", + " nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_(hf_model.model.layers[i].mlp.down_proj.weight)# = hf_model.model.layers[i].mlp.down_proj.weight\n", + "\n", + "\n", + " # Post attn layer norm\n", + " assert hf_model.model.layers[i].post_attention_layernorm.weight.shape == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape\n", + " with torch.no_grad():\n", + " nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_(hf_model.model.layers[i].post_attention_layernorm.weight)# = hf_model.model.layers[i].post_attention_layernorm.weight\n", + " \n", + "# Last layer norm\n", + "assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape\n", + "with torch.no_grad():\n", + " nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight)# = hf_model.model.norm.weight\n", + "# LM_Head\n", + "assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape\n", + "with torch.no_grad():\n", + " nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight)# = hf_model.lm_head.weight" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from nanotron.data.chat_dataset import ChatDataset\n", + "from nanotron.data.dataloader_builder import build_chat_dataloader\n", + "\n", + "train_dataset = ChatDataset(\n", + " dataset_path=\"Open-Orca/SlimOrca\",\n", + " tokenizer_name_or_path=PATH_TO_LLAMA,\n", + " sequence_length=2048,\n", + " train_on_completions_only=True,\n", + " remove_cross_attention=True,\n", + " split=\"train\",\n", + " conversation_column_name=\"conversations\",\n", + " dp_rank=parallel_context.dp_pg.rank(),\n", + " dp_ranks_size=parallel_context.dp_pg.size(),\n", + ")\n", + "\n", + "# Prepare dataloader\n", + "train_dataloader = build_chat_dataloader(\n", + " dataset=train_dataset,\n", + " sequence_length=2048,\n", + " parallel_context=parallel_context,\n", + " input_pp_rank=0,\n", + " output_pp_rank=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(train_dataloader))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,\n", + " 128009, 128009, 128009, 128009, 128009, 128009]], dtype=torch.int32)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch[\"input_ids\"][:, -150:]" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[128000, 128006, 26380, ..., 13, 128009, 128001]],\n", + " dtype=torch.int32)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch[\"input_ids\"][:, :-150]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(128256, 4096)\n", + " (layers): ModuleList(\n", + " (0-31): 32 x LlamaDecoderLayer(\n", + " (self_attn): LlamaFlashAttention2(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm()\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n", + ")" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nanotron_model.eval()\n", + "hf_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " output_nanotron = nanotron_model.model(input_ids=batch[\"input_ids\"][:, :-150].cuda(), position_ids = batch[\"position_ids\"][:, :-150].cuda())" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n", + "PEPEPEPEPE\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " output_hf = hf_model(input_ids=batch[\"input_ids\"][:, :-150].cuda(), position_ids = batch[\"position_ids\"][:, :-150].cuda())" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "Tensor-likes are not close!\n\nMismatched elements: 243083431 / 243429888 (99.9%)\nGreatest absolute difference: 46.65625 at index (0, 1125, 22) (up to 1e-05 allowed)\nGreatest relative difference: 74448896.0 at index (0, 715, 31230) (up to 1.3e-06 allowed)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[38], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtesting\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m assert_close\n\u001b[0;32m----> 3\u001b[0m \u001b[43massert_close\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_hf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlogits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_nanotron\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/testing/_comparison.py:1520\u001b[0m, in \u001b[0;36massert_close\u001b[0;34m(actual, expected, allow_subclasses, rtol, atol, equal_nan, check_device, check_dtype, check_layout, check_stride, msg)\u001b[0m\n\u001b[1;32m 1498\u001b[0m error_metas \u001b[38;5;241m=\u001b[39m not_close_error_metas(\n\u001b[1;32m 1499\u001b[0m actual,\n\u001b[1;32m 1500\u001b[0m expected,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1515\u001b[0m msg\u001b[38;5;241m=\u001b[39mmsg,\n\u001b[1;32m 1516\u001b[0m )\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m error_metas:\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;66;03m# TODO: compose all metas into one AssertionError\u001b[39;00m\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error_metas[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mto_error(msg)\n", + "\u001b[0;31mAssertionError\u001b[0m: Tensor-likes are not close!\n\nMismatched elements: 243083431 / 243429888 (99.9%)\nGreatest absolute difference: 46.65625 at index (0, 1125, 22) (up to 1e-05 allowed)\nGreatest relative difference: 74448896.0 at index (0, 715, 31230) (up to 1.3e-06 allowed)" + ] + } + ], + "source": [ + "from torch.testing import assert_close\n", + "\n", + "assert_close(output_hf.logits, output_nanotron.transpose(0,1))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[HF Model] Next token: 11415, probability: 0.10412170737981796\n", + "[HF Model] Next token: 1523, probability: 0.04918361455202103\n", + "[HF Model] Next token: 47032, probability: 0.043404385447502136\n", + "[HF Model] Next token: 72514, probability: 0.03830423951148987\n", + "[HF Model] Next token: 3493, probability: 0.03830423951148987\n", + "[HF Model] Next token: 10477, probability: 0.03830423951148987\n", + "[HF Model] Next token: 16805, probability: 0.03175532445311546\n", + "[HF Model] Next token: 10552, probability: 0.026326090097427368\n", + "[HF Model] Next token: 7664, probability: 0.021825095638632774\n", + "[HF Model] Next token: 3041, probability: 0.018093638122081757\n" + ] + } + ], + "source": [ + "predicted_token = 34\n", + "\n", + "next_tokens_hf = torch.softmax(output_hf.logits[0, predicted_token, :], -1)\n", + "hf_topk_next_tokens= torch.topk(next_tokens_hf, 10)\n", + "\n", + "\n", + "print(*[f\"[HF Model] Next token: {idx.item()}, probability: {prob}\" for idx, prob in zip(hf_topk_next_tokens.indices, hf_topk_next_tokens.values)], sep=\"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Nanotron Model] Next token: 220, probability: 0.0804644376039505\n", + "[Nanotron Model] Next token: 994, probability: 0.029601214453577995\n", + "[Nanotron Model] Next token: 3639, probability: 0.02612297795712948\n", + "[Nanotron Model] Next token: 656, probability: 0.024540266022086143\n", + "[Nanotron Model] Next token: 279, probability: 0.024540266022086143\n", + "[Nanotron Model] Next token: 3277, probability: 0.021656708791851997\n", + "[Nanotron Model] Next token: 264, probability: 0.013982621021568775\n", + "[Nanotron Model] Next token: 1148, probability: 0.01022990420460701\n", + "[Nanotron Model] Next token: 507, probability: 0.01022990420460701\n", + "[Nanotron Model] Next token: 323, probability: 0.01022990420460701\n" + ] + } + ], + "source": [ + "next_tokens_nanotron = torch.softmax(output_nanotron.transpose(0,1)[0, predicted_token, :], -1)\n", + "nanotron_topk_next_tokens= torch.topk(next_tokens_nanotron, 10)\n", + "\n", + "\n", + "print(*[f\"[Nanotron Model] Next token: {idx.item()}, probability: {prob}\" for idx, prob in zip(nanotron_topk_next_tokens.indices, nanotron_topk_next_tokens.values)], sep=\"\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Save the Nanotron model" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "from nanotron.parallel.parameters import sanity_check\n", + "\n", + "sanity_check(root_module=nanotron_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving weights: 100%|██████████| 195/195 [00:41<00:00, 4.67it/s]\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "from nanotron.serialize import save_meta, save_weights, TrainingMetadata\n", + "from nanotron.serialize.metadata import DataStageMetadata\n", + "\n", + "out_path = \"/mloscratch/homes/solergib/converter/nanotron/n_c/first/\"\n", + "out_path = Path(out_path)\n", + "\n", + "save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=out_path)\n", + "\n", + "training_metadata = TrainingMetadata(last_train_step=0, consumed_train_samples=0, data_stages=[DataStageMetadata(name=\"Empty\", consumed_train_samples=0, start_training_step=0)])\n", + "\n", + "save_meta(root_folder=out_path, parallel_context=parallel_context, training_metadata=training_metadata)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving config ...\n", + "Saving model config ...\n" + ] + } + ], + "source": [ + "import json \n", + "import yaml\n", + "from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, Config\n", + "from nanotron.config.models_config import ExistingCheckpointInit\n", + "from dataclasses import asdict\n", + "\n", + "with open(out_path / \"config.yaml\", \"w\") as f:\n", + " config = Config(\n", + " general=GeneralArgs(project=\"conversion\", run=\"Llama3-8B\"),\n", + " parallelism=parallel_config,\n", + " model=ModelArgs(\n", + " init_method=ExistingCheckpointInit(out_path),\n", + " model_config=nanotron_config,\n", + " ),\n", + " tokenizer=TokenizerArgs(PATH_TO_LLAMA),\n", + " )\n", + " print(\"Saving config ...\")\n", + " yaml.dump(config.as_dict(), f)\n", + "\n", + "with open(out_path / \"model_config.json\", \"w\") as f:\n", + " print(\"Saving model config ...\")\n", + " json.dump(asdict(nanotron_config), f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mloscratch/homes/solergib/SFT/transformers/src/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[27, 22, 0, 97, 13, 49, 56, 35, 70, 91, 38, 30, 26, 94, 68, 46, 89, 32,\n", + " 70, 85, 50, 67, 70, 86, 66, 82, 18, 72, 27, 37, 91, 27, 60, 57, 23, 93,\n", + " 10, 80, 82, 26, 13, 50, 12, 68, 63, 85, 55, 1, 3, 61, 37, 70, 12, 97,\n", + " 1, 59, 90, 45, 74, 62, 66, 54, 94, 18, 54, 89, 49, 3, 66, 55]],\n", + " device='cuda:0'), 'position_ids': tensor([[0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2,\n", + " 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5,\n", + " 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]],\n", + " device='cuda:0')}\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append(\"/mloscratch/homes/solergib/SFT/transformers\")\n", + "\n", + "import torch\n", + "from t_tests.models.llama.test_modeling_llama import LlamaModelTester\n", + "\n", + "lmt = LlamaModelTester(parent=None)\n", + "\n", + "_, inputs_dict = lmt.prepare_config_and_inputs_for_common()\n", + "dummy_attention_mask = inputs_dict[\"attention_mask\"]\n", + "inputs_dict[\"input_ids\"][~dummy_attention_mask.bool()] = 0\n", + "\n", + "padfree_inputs_dict = {\n", + " k: v[dummy_attention_mask.bool()].unsqueeze(0)\n", + " for k, v in inputs_dict.items()\n", + " if not k == \"attention_mask\"\n", + "}\n", + "\n", + "padfree_inputs_dict[\"position_ids\"] = (\n", + " torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])\n", + " .long()\n", + " .unsqueeze(0)\n", + " .to(\"cuda\")\n", + ")\n", + "\n", + "print(padfree_inputs_dict)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/config_llama_sft.yaml b/examples/config_llama_sft.yaml new file mode 100644 index 00000000..d65f7683 --- /dev/null +++ b/examples/config_llama_sft.yaml @@ -0,0 +1,97 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: /mloscratch/homes/solergib/converter/nanotron/checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + hf_dataset: Open-Orca/SlimOrca + hf_dataset_split: train + conversation_column_name: conversations + train_on_completions_only: true + remove_cross_attention: true + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: Chat + run: Llama3-8B + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 128000 + eos_token_id: 128001 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 4096 + num_attention_heads: 32 + num_hidden_layers: 4 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + rope_theta: 500000.0 + tie_word_embeddings: false + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 100 + val_check_interval: -1 diff --git a/run_train.py b/run_train.py index 021d955d..60f01373 100644 --- a/run_train.py +++ b/run_train.py @@ -12,8 +12,9 @@ import numpy as np from nanotron import logging -from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs -from nanotron.data.dataloader_builder import build_nanoset_dataloader +from nanotron.config import ChatDatasetsArgs, DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.data.chat_dataset import ChatDataset +from nanotron.data.dataloader_builder import build_chat_dataloader, build_nanoset_dataloader from nanotron.dataloader import ( clm_process, dummy_infinite_data_generator, @@ -171,6 +172,31 @@ def get_dataloader_from_data_stage( dataloader_drop_last=True, ) + return train_dataloader + # Case 4: Chat Datasets + elif isinstance(data.dataset, ChatDatasetsArgs): + with main_rank_first(trainer.parallel_context.world_pg): + train_dataset = ChatDataset( + dataset_path=data.dataset.hf_dataset, + tokenizer_name_or_path=trainer.config.tokenizer.tokenizer_name_or_path, + sequence_length=trainer.sequence_length, + train_on_completions_only=data.dataset.train_on_completions_only, + remove_cross_attention=data.dataset.remove_cross_attention, + split=data.dataset.hf_dataset_split, + conversation_column_name=data.dataset.conversation_column_name, + dp_rank=trainer.parallel_context.dp_pg.rank(), + dp_ranks_size=trainer.parallel_context.dp_pg.size(), + ) + + # Prepare dataloader + train_dataloader = build_chat_dataloader( + dataset=train_dataset, + sequence_length=trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + ) + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}") diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 05b49955..96337e9a 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -107,11 +107,27 @@ def __post_init__(self): self.dataset_weights = list(tmp_dataset_folder.values()) +@dataclass +class ChatDatasetsArgs: + hf_dataset: str + hf_dataset_split: str + conversation_column_name: str + # Debug + train_on_completions_only: bool = True + remove_cross_attention: bool = True + + def __post_init__(self): + if self.hf_dataset_split is None: + self.hf_dataset_split = "train" + if self.conversation_column_name is None: + self.conversation_column_name = "conversations" + + @dataclass class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/chat_dataset.py b/src/nanotron/data/chat_dataset.py new file mode 100644 index 00000000..ac46ba42 --- /dev/null +++ b/src/nanotron/data/chat_dataset.py @@ -0,0 +1,139 @@ +from typing import List + +import numpy as np +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +from nanotron.data.chat_tokenizer import ChatTokenizer +from nanotron.data.collator import ( + build_labels, + build_labels_completions_only, + build_position_ids, + build_position_ids_dummy, +) +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer + + +class ChatDataset(IterableDataset): + """ + Chat Dataset for training models with: + 1. Packing + 2. No cross-contamination between packed samples + 3. Train on completitions only + + Args: + dataset_path (str): Path to the dataset in the file system. If provided, data will be loaded from this path instead of downloaded. + tokenizer_name_or_path (str): Path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub. + seq_len (int): max sequence length + train_on_completions_only (bool): Whether to just train on completitions or not. To be deleted + remove_cross_attention (bool): Whether to just attend to the tokens from the same sample or to all (Vanilla mechanism). To be deleted + split (str): Split of the dataset to train on + conversation_column_name (str): Column name of the dataset containing the conversations + dp_rank (int): rank of the current data parallel process + dp_ranks_size (int): number of data parallel processes participating in training + """ + + def __init__( + self, + dataset_path: str, + tokenizer_name_or_path, + sequence_length: int, + conversation_column_name: str, + train_on_completions_only: bool = True, + remove_cross_attention: bool = True, + split: str = "train", + dp_rank: int = 0, + dp_ranks_size: int = 1, + skip_num_samples: int = None, # TODO Delete, check later comment + seed: int = 1234, + ) -> None: + + # TODO: Support checkpointing for resuming training. We have to store the number of consumed samples from the dataset (Which is different from the number of steps) and the buffers. + # skip_num_samples will fail, as it's computed with the number of steps and as we are packing sequences we might have consumed MORE samples from the dataset + # TODO: Support interleaving datasets + + self.dataset_path = dataset_path + self.chat_tokenizer = ChatTokenizer(tokenizer_name_or_path) + self.sequence_length = sequence_length + self.conversation_column_name = conversation_column_name + self.skip_num_samples = skip_num_samples + self.seed = seed + + # Load, split and shuffle dataset. Also skip samples if resuming training. + self.dataset = load_dataset(dataset_path, split=split, streaming=True) + self.dataset = split_dataset_by_node(self.dataset, dp_rank, dp_ranks_size) + self.dataset = self.dataset.shuffle(seed=seed, buffer_size=10_000) + + # TODO delete, just 4 switching the training only on completitions setting + if train_on_completions_only: + self.create_labels = build_labels_completions_only + else: + self.create_labels = build_labels + + # TODO delete, just 4 switching the remove cross-attention setting + if remove_cross_attention: + self.create_position_ids = build_position_ids + else: + self.create_position_ids = build_position_ids_dummy + + # Todo delete (debug), just change the dict keys + self.debug_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) # TODO delete debug + self.debug_tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['from'] + '<|end_header_id|>\n\n'+ message['value'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>' }}{% endif %}" + + def __iter__(self): + max_buffer_token_len = 1 + self.sequence_length + buffer_tokens: List[int] = [] + buffer_is_completition: List[int] = [] + buffer_lengths: List[int] = [] + + while True: + for sample in iter(self.dataset): + tokens, is_completition = self.chat_tokenizer(sample[self.conversation_column_name]) + + # TODO assert that tokenized conversations are not longer than max_buffer_token_len? + + # TODO delete (debug). The [:-1] of tokens is because apply chat template doesn't adds eos (NOT eot) token + assert ( + self.debug_tokenizer.apply_chat_template(sample["conversations"]) == tokens[:-1] + ), f'{self.debug_tokenizer.apply_chat_template(sample["conversations"])}\n\n{tokens[:-1]}' + + buffer_tokens.extend(tokens) + buffer_is_completition.extend(is_completition) + buffer_lengths.append(len(tokens)) + + if len(buffer_tokens) > max_buffer_token_len: # Can't pack more samples, yield + # Pop last sample from buffers + sample_tokens = buffer_tokens[: -len(tokens)] + sample_completitions = buffer_is_completition[: -len(tokens)] + sample_lengths = buffer_lengths[:-1] + + # TODO delete (debug) + assert len(sample_tokens) == len(sample_completitions) == sum(sample_lengths) + + # Reset tokens buffers + buffer_tokens = tokens.copy() + buffer_is_completition = is_completition.copy() + buffer_lengths = [len(tokens)] + + # Pad to max_buffer_token_len. Pad token added in ChatTokenizer init if necessary + sample_tokens.extend( + [self.chat_tokenizer.tokenizer.pad_token_id] * (max_buffer_token_len - len(sample_tokens)) + ) + sample_completitions.extend([False] * (max_buffer_token_len - len(sample_completitions))) + + # TODO delete, just 4 switching the training only on completitions setting + labels = self.create_labels(sample_tokens, sample_completitions) + + # TODO delete, just 4 switching the remove cross-attention setting + position_ids = self.create_position_ids(sample_lengths, self.sequence_length) + + # TODO delete (debug) + assert len(sample_tokens) == max_buffer_token_len + + yield { + "input_ids": np.array(sample_tokens[:-1], dtype=np.int32), + "label_ids": labels, + "position_ids": position_ids, + } + + print("Consumed all samples, dataset is being re-looped.") diff --git a/src/nanotron/data/chat_tokenizer.py b/src/nanotron/data/chat_tokenizer.py new file mode 100644 index 00000000..847a365f --- /dev/null +++ b/src/nanotron/data/chat_tokenizer.py @@ -0,0 +1,81 @@ +from typing import List, Tuple + +from transformers import AutoTokenizer + + +class ChatTokenizer: + """ + The ChatTokenizer encodes a conversation applying the Llama3 Chat Template and returns the role (Either User or Assistant) of each token + + Args: + tokenizer_name_or_path (str): A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub. + """ + + def __init__(self, tokenizer_name_or_path: str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + + # Add pad token if necessary + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({"pad_token": "<|eot_id|>"}) + + def __call__(self, conversation: List[dict]) -> Tuple[List[int], List[bool]]: + """ + Applies the Llama3 chat template, encodes the conversation and returns the tokens along with a bool value for each token whether if the token belongs to the answer of the assistant or not to be able to just train on the assistant answers + Args: + conversation (List[dict]): List of dicts where each dict contains the "from" key to specify the emisor del mensaje and the "value" key with the message. + Same format as SlimOrca dataset with possible from values: "System", "human" and "gpt" + Example: + conversation: [ { "from": "system", "value": "You are an AI assistant that follows instruction extremely well. Help as much as you can."}, + { "from": "human", "value": "Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about.\nAnswer:"}, + { "from": "gpt", "value": "The information provided seems to refer to Rian Wallace, a former NFL player."} ] + + After applying chat template: + <|begin_of_text|><|start_header_id|>system<|end_header_id|> + + You are an AI assistant that follows instruction extremely well. Help as much as you can.<|eot_id|><|start_header_id|>human<|end_header_id|> + + Answer the following question: - number is 54 - debutteam is pittsburgh steelers - draftpick is 166 - birth date is 24 may 1982 - weight is 243 - nfl is wal475737 - debutyear is 2005 - finalteam is new york sentinels - statlabel is tackles sacks interceptions - heightin is 3 - statvalue is 9 0.0 1 - heightft is 6 - college is temple - birth place is pottstown , pennsylvania - draftyear is 2005 - position is linebacker - draftround is 5 - finalyear is 2009 Given the details above, guess who could this information be about. + Answer:<|eot_id|><|start_header_id|>gpt<|end_header_id|> + + The information provided seems to refer to Rian Wallace, a former NFL player.<|eot_id|> + returns: + tokens (List[int]): A list of tokens e.g. [128000, 128006, 9125, 128007, 271, 2675, 527, ..., 12873, 2851, 13, 128009, 128001] + is_completitions (List[bool]): A list of bools whether the tokens belong to the assistant answer or not e.g. [False, False, False, ..., False, True, True, True, True] + """ + tokens = [] + # Append <|begin_of_text|> + tokens.append(self.tokenizer.bos_token_id) + is_completitions = [False] * len(tokens) + + for message in conversation: + message_tokens, message_completitions = self.encode_message(message) + tokens.extend(message_tokens) + is_completitions.extend(message_completitions) + + # Append <|end_of_text|> token + tokens.extend(self.tokenizer.encode("<|end_of_text|>", add_special_tokens=False)) + is_completitions.append(True) + + return tokens, is_completitions + + def encode_message(self, message: dict) -> Tuple[List[int], List[int]]: + # TODO The "from", "value", "gpt" keys are form SlimOrca Dataset. Llama3 uses another ones. We should stick to a + # single format and document it properly rather than supporting multiple formats, as each one will need a different + # ChatTokenizer and the idea is that all Datasets share the same ChatTokenizer + + # Encode header + tokens = self.tokenizer.encode( + f"<|start_header_id|>{message['from']}<|end_header_id|>\n\n", add_special_tokens=False + ) + is_completitions = [False] * len(tokens) + + # Encode message + tokens.extend(self.tokenizer.encode(message["value"].strip(), add_special_tokens=False)) + + # Append <|eot_id|> token + tokens.extend(self.tokenizer.encode("<|eot_id|>", add_special_tokens=False)) + + # True if token belongs to assistant answer, False otherwise + is_completitions.extend([True if message["from"] == "gpt" else False] * (len(tokens) - len(is_completitions))) + + return tokens, is_completitions diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..b34a7369 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -1,4 +1,4 @@ -import dataclasses +from dataclasses import dataclass from typing import Dict, List, Union import numpy as np @@ -8,7 +8,7 @@ from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer -@dataclasses.dataclass +@dataclass class NanosetDataCollatorForCLM: """ Data collator used for causal language modeling with Nanosets dataset. @@ -78,3 +78,88 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +# TODO Find a more elegant way. e.g. extend instead of append. OK, so no extend +# We could compute position ids after tokenizing each sample but we will still miss the last length of the padding tokens +def build_position_ids(lengths, sequence_length) -> np.array: + position_ids = [list(range(length)) for length in lengths] # Create position ids list + position_ids.append([0] * (sequence_length - sum(lengths))) # Append position_ids of the padding tokens + return np.array([x for xs in position_ids for x in xs], dtype=np.int32) # Flatten list of position ids + + +# TODO delete, just 4 switching the remove cross-attention setting +def build_position_ids_dummy(lengths, sequence_length) -> np.array: + return np.array(list(range(sequence_length)), dtype=np.int32) # TODO numpy arange + + +# TODO delete, just 4 switching the training only on completitions setting. This will be in the __iter__ method instead of a function +def build_labels_completions_only(input_ids, is_completitions): + labels = np.where( + is_completitions, input_ids, -100 + ) # Mask tokens that don't belong to the completitions by the Assistant + return np.array(labels[1:], dtype=np.int32) + + +# TODO delete, just 4 switching the training only on completitions setting +def build_labels(input_ids, is_completitions): + return np.array(input_ids[1:], dtype=np.int32) + + +@dataclass +class NanoChatDataCollatorForSFT: # TODO(tj.solergibert) Find a better name + """ + Data collator used with Chat Dataset. + - sequence_length: Sequence length of each sample in the batch + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[int]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # TODO clean this, as we are flatting the batch there is no necessity for vstack but we need the batch dimension too + input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + label_ids = np.vstack([examples[i]["label_ids"] for i in range(len(examples))]) # (b, s) + position_ids = np.vstack([examples[i]["position_ids"] for i in range(len(examples))]) # (b, s) + + result: Dict[str, Union[np.ndarray, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + # Process inputs + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids + result["input_mask"] = np.ones((1, self.sequence_length), dtype=np.bool_) + result["position_ids"] = position_ids + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = label_ids + result["label_mask"] = np.ones((1, self.sequence_length), dtype=np.bool_) + + # Cast np.array to torch.Tensor + result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()} + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..f63237ad 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import NanoChatDataCollatorForSFT, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -62,3 +62,36 @@ def build_nanoset_dataloader( pin_memory=dataloader_pin_memory, worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank), ) + + +def build_chat_dataloader( + dataset, + sequence_length: int, + parallel_context: ParallelContext, + input_pp_rank: int, + output_pp_rank: int, + dataloader_pin_memory: bool = True, +) -> DataLoader: + + # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job + if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]: + dataset_length = 1_000_000 # len(dataset) TODO find a more elegant way to specify this dummy dataset + dataset = EmptyInfiniteDataset(length=dataset_length) + + data_collator = NanoChatDataCollatorForSFT( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + + dp_rank = parallel_context.dp_pg.rank() + + return DataLoader( + dataset, + batch_size=1, + collate_fn=data_collator, + num_workers=0, + pin_memory=dataloader_pin_memory, + worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank), + ) diff --git a/src/nanotron/models/llama_sft.py b/src/nanotron/models/llama_sft.py new file mode 100644 index 00000000..a7ccb9d2 --- /dev/null +++ b/src/nanotron/models/llama_sft.py @@ -0,0 +1,888 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LLaMa model.""" + +from typing import Dict, Optional, Union + +import torch +from flash_attn import flash_attn_varlen_func +from torch import nn + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import Config, LlamaConfig, ParallelismArgs +from nanotron.config.models_config import RandomInit, SpectralMupInit +from nanotron.generation.generate_store import AttachableStore +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.nn.activations import ACT2FN +from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator +from nanotron.utils import checkpoint_method + +logger = logging.get_logger(__name__) + + +####### +# NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/modeling_rope_utils.py#L29 +def _compute_default_rope_parameters( + config: LlamaConfig, +) -> torch.Tensor: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config (LlamaConfig): + The model configuration. + Returns: + inv_freq (torch.Tensor) + Contains the inverse frequencies for the RoPE embeddings + """ + + base = config.rope_theta # NOTE(tj.solergibert) 500000.0 + partial_rotary_factor = ( + config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + ) # NOTE(tj.solergibert) 1 + dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) # NOTE(tj.solergibert) 128 + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + return inv_freq + + +# NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/5f841c74b62754f186a8c06a684d491524b7bc03/src/transformers/models/llama/modeling_llama.py#L81 +# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids +# NOTE(tj.solergibert) This function is just called once per batch to compute the position_embeddings, the expensive operation +# is def apply_rotary_pos_emb +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + config: LlamaConfig, + ): + super().__init__() + self.config = config + + inv_freq = _compute_default_rope_parameters(self.config) # NOTE(tj.solergibert) shape: 64 , 1.0 + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + def forward(self, x, position_ids): + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (torch.Tensor): The query tensor. + k (torch.Tensor): The key tensor. + cos (torch.Tensor): The cosine part of the rotary embedding. + sin (torch.Tensor): The sine part of the rotary embedding. + unsqueeze_dim (int, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + tuple (torch.Tensor) comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) # NOTE(tj.solergibert) [1, 70, 128] --> [1, 1, 70, 128] + sin = sin.unsqueeze(unsqueeze_dim) # NOTE(tj.solergibert) + q_embed = (q * cos) + (rotate_half(q) * sin) # NOTE(tj.solergibert) [1, 32, 70, 128] + k_embed = (k * cos) + (rotate_half(k) * sin) # NOTE(tj.solergibert) [1, 8, 70, 128] + return q_embed, k_embed + + +def prepare_varlen_args(position_ids): + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seqlens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_seqlen_in_batch = position_ids.max() + 1 + + return cu_seqlens, max_seqlen_in_batch + + +####### + + +class GLUActivation(nn.Module): + def __init__(self, act_fn_name: str): + super().__init__() + self.act = ACT2FN[act_fn_name] + + def forward(self, merged_states: torch.Tensor): + gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) + return self.act(gate_states) * up_states + + +class MLP(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + gate_up_contiguous_chunks = ( + config.intermediate_size, # shape of gate_linear + config.intermediate_size, # shape of up_linear + ) + self.gate_up_proj = TensorParallelColumnLinear( + config.hidden_size, + 2 * config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=gate_up_contiguous_chunks, + ) + self.down_proj = TensorParallelRowLinear( + config.intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ) + # TODO @nouamane: why can't we torch.jit.script GLUActivation? + self.split_silu_mul = GLUActivation(config.hidden_act) + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(hidden_states) + hidden_states = self.down_proj(self.split_silu_mul(merged_states)) + return hidden_states + + +class CoreAttention(nn.Module): + def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int): + super().__init__() + # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentically a `d_qkv` + assert ( + config.hidden_size % config.num_attention_heads == 0 + ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}." + self.d_qk = config.hidden_size // config.num_attention_heads + self.d_v = config.hidden_size // config.num_attention_heads + self.is_using_mup = config.is_using_mup + + self.checkpoint_attention = False # Because flash_attn already does checkpointing + + @checkpoint_method(attr_name="checkpoint_attention") + def forward( + self, + query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + ): + from flash_attn.flash_attn_interface import flash_attn_func + + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None + # For now we are assuming that we use causual mask. No magic here + causal = True + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + return_attn_probs=False, + ) + + return attn_output + + +class CausalSelfAttention(nn.Module, AttachableStore): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + + super().__init__() + # Tensor parallel considerations: We split tensors along head dimension + assert ( + config.num_attention_heads % tp_pg.size() == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})." + try: + assert ( + config.num_key_value_heads % tp_pg.size() == 0 + ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})." + except AttributeError: + log_rank( + "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads", + logger=logger, + level=logging.WARNING, + rank=0, + ) + # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads + config.num_key_value_heads = config.num_attention_heads + assert ( + config.num_attention_heads % config.num_key_value_heads == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})." + self.n_local_q_heads = config.num_attention_heads // tp_pg.size() + self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size() + self.n_repeats = config.num_attention_heads // config.num_key_value_heads + self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not + self.d_qk = config.hidden_size // config.num_attention_heads + self.d_v = config.hidden_size // config.num_attention_heads + self.d_model = config.hidden_size + self.is_using_mup = config.is_using_mup + + # TODO @thomasw21: refactor so that we store that default in a single place. + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + # build the slice config for self.qkv for save/load + # shard are done within the contiguous chunk + qkv_contiguous_chunks = ( + config.num_attention_heads * self.d_qk, # shape of q + config.num_key_value_heads * self.d_qk, # shape of k + config.num_key_value_heads * self.d_qk, # shape of v + ) + self.qkv_proj = TensorParallelColumnLinear( + self.d_model, + config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=qkv_contiguous_chunks, + ) + + self.o_proj = TensorParallelRowLinear( + config.num_attention_heads * self.d_qk, + self.d_model, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + + # TODO(tj.solergibert) Deshacernos de este bloque POR DIOS!!! + self.attention = CoreAttention( + config, + parallel_config=parallel_config, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states, # [seq_length, batch_size, hidden_size] + position_ids, # [batch_size, seq_length] + cos, # [batch_size, seq_length, hidden_size//num_attention_heads] + sin, # [batch_size, seq_length, hidden_size//num_attention_heads] + ): + qkv_states = self.qkv_proj( + hidden_states + ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] + q_length, batch_size, _ = qkv_states.shape + + if self.is_gqa: + query_states, key_states, value_states = torch.split( + qkv_states, + [ + self.n_local_q_heads * self.d_qk, + self.n_local_kv_heads * self.d_qk, + self.n_local_kv_heads * self.d_qk, + ], + dim=-1, + ) + + query_states = ( + query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk) + ) + key_states = ( + key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) + ) + value_states = ( + value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) + ) + else: + query_states, key_states, value_states = ( + qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk) + .permute(2, 1, 0, 3, 4) + .contiguous() + ) # [3, batch_size, seq_length, n_local_q_heads, d_qk] + + # Training case OLD + # Apply rotary embeddings to query/key states + # NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk] + # Here it is, [batch_size, seq_length, num_heads, d_qk] + # [2, batch_size, seq_length, num_heads, d_qk] + # key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) + # [batch_size, seq_length, 2, num_heads, d_qk] + # key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() + # query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states) + # [batch_size, seq_length, num_heads, d_qk] + # key_states, value_states = torch.split(key_value_states, 1, dim=2) + + # TODO(tj.solergibert) ver si esto sirve de algo o no!!!!! + # kv_length = key_states.shape[1] + # key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk) + # value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v) + + # attention_output = self.attention( + # query_states=query_states, + # key_states=key_states, + # value_states=value_states, + # ) + + # TODO(tj.solergibert) Apply RoPE embeddings WITHOUT too many transpose... + query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2) + # Apply RoPE + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2) + + # Prepare varlen args + cu_seqlens, max_seqlen_in_batch = prepare_varlen_args(position_ids) + print(cu_seqlens) + print(max_seqlen_in_batch) + query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1)) + + attention_output = flash_attn_varlen_func( + query_states, # NOTE(tj.solergibert) Shape: [70, 32, 128] + key_states, # NOTE(tj.solergibert) Shape: [70, 8, 128] + value_states, # NOTE(tj.solergibert) Shape: [70, 8, 128] + cu_seqlens_q=cu_seqlens, # NOTE(tj.solergibert) Shape: Tensor, [14] + cu_seqlens_k=cu_seqlens, # NOTE(tj.solergibert) Shape: Tensor, [14] + max_seqlen_q=max_seqlen_in_batch, # NOTE(tj.solergibert) Shape: Tensor, [1] Just 1 element with the longer sequence in batch. In the HF Transformers dummy test is 7 + max_seqlen_k=max_seqlen_in_batch, # NOTE(tj.solergibert) Shape: Tensor, [1] Just 1 element with the longer sequence in batch. In the HF Transformers dummy test is 7 + causal=True, # NOTE(tj.solergibert) True + ) # NOTE(tj.solergibert) Returns out: (total, nheads, headdim). + + attention_output = ( + attention_output.contiguous() + .view(batch_size, q_length, self.n_local_q_heads * self.d_v) + .transpose(0, 1) # TODO(tj.solergibert) View is necessary, but contiguous? + ) + output = self.o_proj(attention_output) + + return output + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__() + self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + ) + + self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + cos: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + sin: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.attn(hidden_states=hidden_states, position_ids=position_ids, cos=cos, sin=sin) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states) + hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "position_ids": position_ids, + "cos": cos, + "sin": sin, + } + + +class Embedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + # NOTE(tj.solergibert) SFT + self.position_embedding = LlamaRotaryEmbedding(config=config) + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [batch_size, seq_length] + # TODO(tj.solergibert) Delete this store stuff ################ + store = self.get_local_store() + if store is not None: + if "past_length" in store: + store["past_length"] + else: + torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) + + # cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) + # Store new past_length in store + # store["past_length"] = past_length + cumsum_mask[:, -1] + ################################################################ + + # NOTE(tj.solergibert) We create the cos & sin and propagate them through the pipeline so we + # don't have to create the LlamaRotaryEmbedding layer in each and every decoder layer + # We will still send the position ids for the varlen, but we will try to delete it. Computing them from + # the position ids it's not very expensive AND we keep a tensor with constant shape + cos, sin = self.position_embedding( + input_ids, position_ids + ) # TODO(tj.solergibert) We just need from inputs_ids the device type + + # Format input in `[seq_length, batch_size]` to support high TP with low batch_size + input_ids = input_ids.transpose(0, 1) + input_embeds = self.token_embedding(input_ids) + return {"input_embeds": input_embeds, "position_ids": position_ids, "cos": cos, "sin": sin} + + +class LlamaModel(nn.Module): + """Build pipeline graph""" + + def __init__( + self, + config: LlamaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids", "position_ids"}, + module_output_keys={"input_embeds", "position_ids", "cos", "sin"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=LlamaDecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "position_ids", "cos", "sin"}, + module_output_keys={"hidden_states", "position_ids", "cos", "sin"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonRMSNorm, + module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) # TODO + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO @thomasw21: refactor so that we store that default in a single place. + "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + return self.forward_with_hidden_states(input_ids=input_ids, position_ids=position_ids)[0] + + def forward_with_hidden_states( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + hidden_encoder_states = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) + + # NOTE(tj.solergibert) Rename input_embeds --> hidden_states + hidden_encoder_states["hidden_states"] = hidden_encoder_states.pop("input_embeds") + + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_states + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + model_config = self.config + d_ff = model_config.intermediate_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + + 3 * d_ff * model_config.hidden_size, + # This is the last lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + world_size = self.parallel_context.world_pg.size() + try: + num_key_values_heads = self.config.num_key_value_heads + except AttributeError: + num_key_values_heads = self.config.num_attention_heads + + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_key_value_heads=num_key_values_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.intermediate_size, + seq_len=sequence_length, + batch_size=global_batch_size, + ) + + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +@torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [seq_length, batch_size, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. + # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + + loss = sharded_cross_entropy( + sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + ).transpose(0, 1) + # TODO @thomasw21: It's unclear what kind of normalization we want to do. + loss = masked_mean(loss, label_mask, dtype=torch.float) + # I think indexing causes a sync we don't actually want + # loss = loss[label_mask].sum() + return {"loss": loss} + + +class LlamaForSFT(NanotronModel): + def __init__( + self, + config: LlamaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + self.loss = PipelineBlock( + p2p=self.model.p2p, + module_builder=Loss, + module_kwargs={"tp_pg": parallel_context.tp_pg}, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + sharded_logits = self.model( + input_ids=input_ids, + position_ids=position_ids, + ) + loss = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + )["loss"] + return {"loss": loss} + + @torch.no_grad() + def init_model_randomly(self, config: Config): + """Initialize model parameters randomly. + Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config.model) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + model = self + initialized_parameters = set() + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + parametrizator.parametrize(param_name, module) + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + def get_embeddings_lm_head_tied_names(self): + """Get the names of the tied embeddings and lm_head weights""" + if self.config.tie_word_embeddings is True: + return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + else: + return [] + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + return self.model.get_block_compute_costs() + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) + + +def get_flops( + num_layers, + hidden_size, + num_heads, + num_key_value_heads, + vocab_size, + seq_len, + ffn_hidden_size, + batch_size=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + num_key_value_heads: number of key/value heads in the model + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + if num_key_value_heads is None: + num_key_value_heads = num_heads + hidden_size_per_head = hidden_size // num_heads + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention + ## qkv projection + decoder_qkv_proj_flops_fwd = ( + 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head + + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head + ) + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head + ## attn out + decoder_attn_out_flops_fwd = ( + 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size + ) + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + + decoder_flops_fwd = ( + decoder_qkv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd + + decoder_ffn_2_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO: This is a placeholder for now + + return model_flops, hardware_flops diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b6752f38..9984b881 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -56,7 +56,7 @@ ) from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad -from nanotron.models.llama import LlamaForTraining, RotaryEmbedding +from nanotron.models.llama import LlamaForTraining from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext @@ -750,11 +750,12 @@ def _init_model( model_builder=model_builder, ) + # TODO(tj.solergibert) Fix this RoPE init only used with LlamaModel for generation? # Initialize rotary embeddings - for module in model.modules(): - if not isinstance(module, RotaryEmbedding): - continue - module.init_rotary_embeddings() + # for module in model.modules(): + # if not isinstance(module, RotaryEmbedding): + # continue + # module.init_rotary_embeddings() # Mark some parameters as tied self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)