diff --git a/ochat/config/__init__.py b/ochat/config/__init__.py
index d221042..9d4bbce 100644
--- a/ochat/config/__init__.py
+++ b/ochat/config/__init__.py
@@ -16,11 +16,31 @@
 }
 
 
+_GEMMA_IT_PREFIXES = {
+    "user": "user",
+    "assistant": "model"
+}
+
+
 def _v3_2_role_prefix(from_role, condition):
     return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()
 
 
 MODEL_CONFIG_MAP = {
+    # OpenChat V3.6 (MoE)
+    "openchat_3.6": ModelConfig(
+        # Model
+        model_max_context=8192,
+        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
+        model_create_for_training=lambda: None,  # NOTE(one): MoE trainer decoupled from the codebase
+
+        # Conversation Template
+        conversation_template=partial(ConversationTemplate,
+                                      role_prefix=_v3_2_role_prefix,
+                                      eot="<|eot_id|>",
+                                      inference_condition="GPT4 Correct")
+    ),
+
     # OpenChat V3.2
     "openchat_v3.2": ModelConfig(
         # Model
@@ -54,6 +74,23 @@ def _v3_2_role_prefix(from_role, condition):
                                       inference_condition="GPT4 Correct")
     ),
 
+    "openchat_v3.2_gemma_new": ModelConfig(
+        serving_aliases=("openchat_3.5_gemma_new", ),
+
+        # Model
+        model_max_context=8192,
+        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
+        model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
+                                          low_cpu_mem_usage=True,
+                                          torch_dtype=torch.bfloat16),
+
+        # Conversation Template
+        conversation_template=partial(ConversationTemplate,
+                                      role_prefix=_v3_2_role_prefix,
+                                      eot="<end_of_turn>",
+                                      inference_condition="GPT4 Correct")
+    ),
+
     ### Other models
     "chatml_mistral": ModelConfig(
         # Model
@@ -83,4 +120,18 @@ def _v3_2_role_prefix(from_role, condition):
                                       eot="</s>",
                                       inference_condition="")
     ),
+    "gemma_it": ModelConfig(
+        # Model
+        model_max_context=8192,
+        model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
+        model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
+                                          low_cpu_mem_usage=True,
+                                          torch_dtype=torch.bfloat16),
+
+        # Conversation Template
+        conversation_template=partial(ConversationTemplate,
+                                      role_prefix=lambda from_role, condition: f"<start_of_turn>{_GEMMA_IT_PREFIXES[from_role]}\n",
+                                      eot="<end_of_turn>",
+                                      inference_condition="")
+    ),
 }
diff --git a/ochat/data/generate_dataset.py b/ochat/data/generate_dataset.py
index f1ccf39..b0604d0 100644
--- a/ochat/data/generate_dataset.py
+++ b/ochat/data/generate_dataset.py
@@ -4,9 +4,9 @@
 Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.jsonl --tokenizer-name HF_REPO_NAME --out-dir .
""" -from typing import Optional import argparse import os +import gc import random import ray @@ -77,6 +77,9 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc print ("Tokenizing ...") tokens_list, weights_list = conv_template.tokenize_conversations(batch, inference=False, seq_level_weight=per_sequence_loss) + del batch + gc.collect() + # Generate data print ("Generating ...") max_context = model_config.model_max_context @@ -92,12 +95,20 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc # Add to results add_single_conv(outputs, tokens, weights) - print ("Chunk finish") + del tokens_list, weights_list + gc.collect() - return pyarrow.Table.from_pydict(outputs, schema=schema) + print ("To table ...") + table = pyarrow.Table.from_pydict(outputs, schema=schema) + del outputs + gc.collect() + + print ("Chunk finish") + return table -def generate_split(model_type: str, model_path: str, conversations: list, split_name: str, out_prefix: str, per_sequence_loss: bool): + +def generate_epoch(seed: int, model_type: str, model_path: str, in_filename: str, out_filename: str, per_sequence_loss: bool): # schema metadata = { "model_type": model_type @@ -115,40 +126,52 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_ schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)}) - # launch remote workers - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True, num_cpus=os.cpu_count()) + # Load data + with open(in_filename, "rb") as f: + batches = f.readlines() + + random.seed(seed) # Randomized load balancing + random.shuffle(batches) + + batches = _split(batches, int(ray.available_resources()["CPU"])) + # launch remote workers handles = [convert_conversation_batch.remote( model_type=model_type, # type: ignore model_path=model_path, batch=batch, schema=schema, per_sequence_loss=per_sequence_loss - ) for batch in _split(conversations, int(ray.available_resources()["CPU"]))] + ) for batch in batches] # write - parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), f"{out_prefix}.{split_name}.parquet") + parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), out_filename) -def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_loss, seed, eval_ratio): - # Load conversations - conversations = [] - for filename in in_files: - with open(filename, "rt") as f: - conversations.extend(f.readlines()) +def generate_dataset(model_type, model_path, in_prefix, out_prefix, per_sequence_loss, seed): + # Initialize Ray + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, num_cpus=os.cpu_count()) - # Train-test split - random.seed(seed) - random.shuffle(conversations) - eval_num = int(eval_ratio * len(conversations)) + # Load epochs and tokenize + epoch = 0 + while True: + in_filename = f"{in_prefix}.{epoch}.jsonl" + if not os.path.exists(in_filename): + break - train_conversations = conversations[eval_num:] - eval_conversations = conversations[:eval_num] + out_filename = f"{out_prefix}.{epoch}.parquet" + generate_epoch( + seed=seed + epoch, + model_type=model_type, + model_path=model_path, + in_filename=in_filename, + out_filename=out_filename, + per_sequence_loss=per_sequence_loss + ) + gc.collect() - generate_split(model_type, model_path, train_conversations, "train", out_prefix, per_sequence_loss) - if eval_num > 0: - generate_split(model_type, model_path, eval_conversations, "eval", 
out_prefix, per_sequence_loss) + epoch += 1 if __name__ == "__main__": @@ -156,12 +179,11 @@ def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_ parser.add_argument("--model-type", type=str, required=True) parser.add_argument("--model-path", type=str, required=True) - parser.add_argument("--in-files", type=str, nargs="+", required=True) + parser.add_argument("--in-prefix", type=str, required=True) parser.add_argument("--out-prefix", type=str, required=True) parser.add_argument("--per-sequence-loss", action="store_true") parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--eval-ratio", type=float, default=0.005) args = parser.parse_args() generate_dataset(**vars(args)) diff --git a/ochat/evaluation/match_answer.py b/ochat/evaluation/match_answer.py index 00ddef7..2699ea1 100644 --- a/ochat/evaluation/match_answer.py +++ b/ochat/evaluation/match_answer.py @@ -109,7 +109,7 @@ def fs_cothub_bbh_match_answer(task_data, response): return False, ans else: # Free form, direct return - if ans[-1] == '.': + if len(ans) and ans[-1] == '.': ans = ans[:-1] return True, ans @@ -155,12 +155,11 @@ def _function_exists(code, func_name): return False def _try_match(content, prefix, entrypoint): - for block in content.split("```"): - # Sanitize block - block = block.strip() - if block.startswith("python"): - block = block[len("python"):] + # All markdown code blocks, as well as raw + code_blocks = [m[1] for m in re.findall(r"(\`{3}.*?\n+)([\s\S]*?)(\n+\`{3})", content)] \ + + [content] + for block in code_blocks: # Check syntax try: code_completion = prefix + block diff --git a/ochat/evaluation/run_eval.py b/ochat/evaluation/run_eval.py index 633b90c..4e57c34 100644 --- a/ochat/evaluation/run_eval.py +++ b/ochat/evaluation/run_eval.py @@ -17,6 +17,12 @@ from ochat.config import MODEL_CONFIG_MAP +def _strip_first_space(s: str): + if len(s) and s[0] == " ": + return s[1:] + return s + + @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(20), retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError, ))) async def _chat_completion_with_backoff(**kwargs): return await openai.ChatCompletion.acreate(**kwargs) @@ -122,7 +128,8 @@ def get_model_answers( questions: list, condition: str, system_msg: str, - model_type: str + model_type: str, + tensor_parallel_size: int ): # Load model config if model_type is None: @@ -136,9 +143,10 @@ def get_model_answers( # Init vLLM engine engine = LLM(model, max_num_batched_tokens=model_config.model_max_context, - max_model_len=model_config.model_max_context) + max_model_len=model_config.model_max_context, + tensor_parallel_size=tensor_parallel_size) sampling_params = SamplingParams(temperature=0, - max_tokens=model_config.model_max_context, + max_tokens=None, stop_token_ids=conv_template.eot_tokens_, # Override stop tokens ignore_eos=True) @@ -149,8 +157,7 @@ def get_model_answers( # calculate & fill in responses responses = engine.generate(prompt_token_ids=prompts, sampling_params=sampling_params) for idx, resp in zip(prompt_indices, responses): - questions[idx]["response"] = resp.outputs[0].text - + questions[idx]["response"] = _strip_first_space(resp.outputs[0].text) return questions @@ -167,7 +174,8 @@ async def run_eval( continue_from: Optional[str], output_file: str, - parallel: int + parallel: int, + tensor_parallel_size: int ): print (f"Evaluating ({model_type})...\n\nCondition: {condition}\nSystem Prompt: {system_msg}\n") @@ -201,7 +209,7 @@ async def run_eval( if 
model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): questions = await get_openai_answers(model, questions, parallel) else: - questions = get_model_answers(model, questions, condition, system_msg, model_type) + questions = get_model_answers(model, questions, condition, system_msg, model_type, tensor_parallel_size) # Calculate accuracy for q in questions: @@ -235,6 +243,7 @@ async def main(): parser.add_argument("--continue_from", type=str, default=None) parser.add_argument("--output_file", type=str, default=None) parser.add_argument("--parallel", type=int, default=16) + parser.add_argument("--tensor-parallel-size", type=int, default=1) args = parser.parse_args() diff --git a/ochat/models/__init__.py b/ochat/models/__init__.py index 7307e83..973bdf5 100644 --- a/ochat/models/__init__.py +++ b/ochat/models/__init__.py @@ -1,2 +1,3 @@ from ochat.models.unpadded_llama import LlamaForCausalLM from ochat.models.unpadded_mistral import MistralForCausalLM +from ochat.models.unpadded_gemma import GemmaForCausalLM diff --git a/ochat/models/unpadded_gemma.py b/ochat/models/unpadded_gemma.py new file mode 100644 index 0000000..e1e7a80 --- /dev/null +++ b/ochat/models/unpadded_gemma.py @@ -0,0 +1,379 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Unpadded & Fused Gemma model. Compatible with HF. """ + +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.gemma.configuration_gemma import GemmaConfig + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn.bert_padding import pad_input +except ImportError: + print ("FlashAttention not found. 
Install it if you need to train models.") + + +def rotate_half(x: torch.Tensor): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor): + # q, k: [nnz, num_heads, head_dim] + # position_ids: [nnz] + # cos, sin: [max_seq_len, head_dim] + cos = cos[position_ids].unsqueeze(-2) # [nnz, 1, head_dim] + sin = sin[position_ids].unsqueeze(-2) # [nnz, 1, head_dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +@torch.jit.script +def lm_head_with_loss(embed_weights: torch.Tensor, hidden_states: torch.Tensor, nz_shifted_label_ids: torch.Tensor, nz_shifted_loss_weights: torch.Tensor): + logits = F.linear(hidden_states, embed_weights) + + loss = (nz_shifted_loss_weights * torch.nn.functional.cross_entropy(logits, nz_shifted_label_ids, reduction="none")).sum() + token_accuracy = (nz_shifted_loss_weights * (torch.argmax(logits.detach(), dim=-1) == nz_shifted_label_ids)).sum() + return loss, token_accuracy + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Gemma +RMS_NORM_TRACED = None + + +def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: torch.Tensor): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return (1 + weight) * hidden_states.to(input_dtype) + + +class UnpaddedGemmaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps): + """ + UnpaddedGemmaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = torch.tensor(eps, dtype=torch.get_default_dtype()) + + global RMS_NORM_TRACED + if RMS_NORM_TRACED is None: + RMS_NORM_TRACED = torch.jit.trace(rms_norm, (torch.ones(hidden_size), torch.ones(hidden_size), self.variance_epsilon)) + + def forward(self, hidden_states): + global RMS_NORM_TRACED + return RMS_NORM_TRACED(hidden_states, self.weight, self.variance_epsilon) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Gemma +class UnpaddedGemmaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base, device=None): + super().__init__() + + # RoPE + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device) + freqs = torch.outer(t, inv_freq) + + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self): + return self.cos_cached, self.sin_cached + + +class UnpaddedGemmaMLP(nn.Module): + def __init__(self, config: GemmaConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, 
bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class UnpaddedGemmaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GemmaConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + def forward( + self, + cos_sin: Tuple[torch.Tensor, torch.Tensor], + # Unpadded inputs + nz_hidden_states: torch.Tensor, + nz_position_ids: torch.LongTensor, + cu_seqlens: torch.Tensor, + max_seqlen: int + ) -> torch.Tensor: + # nz_hidden_states: [nnz, num_heads, head_dim] + # nz_position_ids: [nnz] + # cu_seqlens: [bs + 1] + + query_states = self.q_proj(nz_hidden_states).view(-1, self.num_heads, self.head_dim) + key_states = self.k_proj(nz_hidden_states).view(-1, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(nz_hidden_states).view(-1, self.num_key_value_heads, self.head_dim) + + # RoPE + cos, sin = cos_sin + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, nz_position_ids) + + # flash attn + attn_output = flash_attn_varlen_func( + q=query_states, k=key_states, v=value_states, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + + dropout_p=0.0, causal=True) + + # attn_output: [total_nnz, num_heads, head_dim] + attn_output = attn_output.view(-1, self.num_heads * self.head_dim) # type: ignore + return self.o_proj(attn_output) + + +class UnpaddedGemmaDecoderLayer(nn.Module): + def __init__(self, config: GemmaConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.self_attn = UnpaddedGemmaAttention(config=config) + self.mlp = UnpaddedGemmaMLP(config=config) + self.input_layernorm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + cos_sin: Tuple[torch.Tensor, torch.Tensor], + # Unpadded inputs + nz_hidden_states: torch.Tensor, + nz_position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int + ) -> torch.Tensor: + # Self Attention + residual = nz_hidden_states + + nz_hidden_states = self.input_layernorm(nz_hidden_states) + nz_hidden_states = self.self_attn( + cos_sin=cos_sin, + + nz_hidden_states=nz_hidden_states, + nz_position_ids=nz_position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen + ) + nz_hidden_states = residual + nz_hidden_states + + # Fully Connected + residual = nz_hidden_states + + nz_hidden_states = self.post_attention_layernorm(nz_hidden_states) + nz_hidden_states = 
self.mlp(nz_hidden_states) + nz_hidden_states = residual + nz_hidden_states + + return nz_hidden_states + + +class UnpaddedGemmaPreTrainedModel(PreTrainedModel): + config_class = GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["UnpaddedGemmaDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class UnpaddedGemmaModel(UnpaddedGemmaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`UnpaddedGemmaDecoderLayer`] + + Args: + config: GemmaConfig + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.normalization_factor = config.hidden_size ** 0.5 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.rotary_emb = UnpaddedGemmaRotaryEmbedding(config.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta) + + self.layers = nn.ModuleList([UnpaddedGemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + # Unpadded inputs + nz_input_ids: torch.Tensor, + nz_position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + nz_hidden_states = self.embed_tokens(nz_input_ids) * self.normalization_factor # Normalized + cos_sin = self.rotary_emb() + + # decoder layers + for decoder_layer in self.layers: + if self.gradient_checkpointing and self.training: + nz_hidden_states = self._gradient_checkpointing_func( + decoder_layer.__call__, + + cos_sin, + nz_hidden_states, + nz_position_ids, + cu_seqlens, + max_seqlen + ) + else: + nz_hidden_states = decoder_layer( + cos_sin, + + nz_hidden_states, + nz_position_ids, + cu_seqlens, + max_seqlen + ) + + nz_hidden_states = self.norm(nz_hidden_states) + + return nz_hidden_states + + +class GemmaForCausalLM(UnpaddedGemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = UnpaddedGemmaModel(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.model.embed_tokens + + def set_output_embeddings(self, new_embeddings): + self.model.embed_tokens = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + # Unpadded inputs + nz_input_ids: torch.Tensor, + nz_position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + # Unpadded labels + nz_shifted_label_ids: Optional[torch.Tensor] = None, + nz_shifted_loss_weights: Optional[torch.Tensor] = None + ) -> 
CausalLMOutputWithPast: + # Model logits + hidden_states = self.model( + nz_input_ids=nz_input_ids, + nz_position_ids=nz_position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen + ) + + # Loss + loss = lm_head_with_loss( + self.model.embed_tokens.weight, # Tied embeddings + hidden_states, + nz_shifted_label_ids, + nz_shifted_loss_weights + ) + + return CausalLMOutputWithPast( + loss=loss # type: ignore + ) diff --git a/ochat/scripts/modify_eos_embedding.py b/ochat/scripts/modify_eos_embedding.py new file mode 100644 index 0000000..406906c --- /dev/null +++ b/ochat/scripts/modify_eos_embedding.py @@ -0,0 +1,38 @@ +import argparse + +import transformers +import torch + + +def modify_eos_embeddings(model_path, output_dir): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16) + + eos_token_id = tokenizer.eos_token_id + + print (f"EOS Token {tokenizer.convert_ids_to_tokens(eos_token_id)} ID {eos_token_id}") + with torch.no_grad(): + model.model.embed_tokens.weight[eos_token_id] = torch.mean(model.model.embed_tokens.weight, dim=0) + model.lm_head.weight[eos_token_id] = torch.mean(model.lm_head.weight, dim=0) + + # Save + tokenizer.save_pretrained(output_dir) + model.save_pretrained(output_dir) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", + help="Location of model, or HuggingFace repo ID", + ) + parser.add_argument( + "--output-dir", + help="Location to write resulting model and tokenizer", + ) + + modify_eos_embeddings(**vars(parser.parse_args())) + + +if __name__ == "__main__": + main() diff --git a/ochat/training_deepspeed/train.py b/ochat/training_deepspeed/train.py index c0caa45..ee8ad5c 100644 --- a/ochat/training_deepspeed/train.py +++ b/ochat/training_deepspeed/train.py @@ -55,12 +55,12 @@ def parse_args(): return args -def create_dataset_and_dataloader(args, split_name): +def create_dataset_and_dataloader(args, epoch: int): # Find data - filename = f"{args.data_prefix}.{split_name}.parquet" + filename = f"{args.data_prefix}.{epoch}.parquet" # Create dataset and dataloader - print(f"Loading {split_name} data from {filename}...") + print(f"Loading epoch {epoch} data from {filename}...") dataset = OpenchatDataset( dataset_filename=filename, @@ -76,8 +76,7 @@ def create_dataset_and_dataloader(args, split_name): num_workers=1, prefetch_factor=8, - pin_memory=True, - persistent_workers=True + pin_memory=True ) return dataset, dataloader @@ -158,6 +157,8 @@ def calculate_auto_lr(lr, batch_max_len, model_type, train_dataset): base_bs = 4_000_000 if "mistral" in model_type.lower(): base_lr /= 6.0 + elif "gemma" in model_type.lower(): + base_lr /= 5.5 # NOTE(one): Maybe MLP and Attn layers are using different lr? 
     loss_weights = np.concatenate(train_dataset.dataset["nz_shifted_loss_weights"])
     supervised_ratio = np.sum(loss_weights != 0) / len(loss_weights)
@@ -192,7 +193,7 @@ def train():
     args = parse_args()
 
     # Dataset
-    train_dataset, train_loader = create_dataset_and_dataloader(args, "train")
+    train_dataset, train_loader = create_dataset_and_dataloader(args, 0)
 
     if train_dataset is None:
         raise RuntimeError("Training data not found.")
@@ -224,6 +225,12 @@ def train():
     for epoch in range(args.epochs):
         print (f"[rank {RANK} of {WORLD_SIZE}]: Epoch {epoch}")
 
+        ############ Load Dataset
+        if epoch != 0:
+            del train_dataset, train_loader
+
+            train_dataset, train_loader = create_dataset_and_dataloader(args, epoch)
+
         ############ Train Epoch
         model_engine.train()
         for (batch_tensor, batch_info), all_numseq, cur_numseq in train_loader:
diff --git a/pyproject.toml b/pyproject.toml
index b6a3119..e084f94 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,14 +26,14 @@ dependencies = [
     "torch",
     "ray",
    "sentencepiece",
-    "transformers>=4.35.0",
+    "transformers>=4.38.2",
     "accelerate",
     "protobuf",
     "fastapi",
     "pydantic",
     "shortuuid",
     "uvicorn",
-    "vllm",
+    "vllm>=0.3.3",
     "pytest"
 ]