From 84d44764fcfd9e0dee96cf78735847ffb429757f Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Sun, 3 Sep 2023 05:36:45 +0000 Subject: [PATCH 01/22] WIP Add training callback to send predictions to WandB table --- src/axolotl/utils/callbacks.py | 154 +++++++++++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 4 + 2 files changed, 158 insertions(+) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index ee5acfd555..40a84450fa 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -15,6 +15,7 @@ from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( + Trainer, TrainerCallback, TrainerControl, TrainerState, @@ -22,6 +23,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +import wandb from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.distributed import ( barrier, @@ -313,3 +315,155 @@ def on_evaluate( trainer.log(results) return BenchEvalCallback + + +def log_prediction_callback_factory(trainer: Trainer, tokenizer): + class LogPredictionCallback(TrainerCallback): + """Callback to log prediction values during each evaluation""" + + def __init__(self, cfg): + self.cfg = cfg + self.logged = False + + def on_evaluate( + self, + args: AxolotlTrainingArguments, + state: TrainerState, + control: TrainerControl, + # model, + # tokenizer, + eval_dataloader, + **kwargs, + ): + LOG.info("logging predictions") + + # Initialize an empty wandb.Table + table = wandb.Table(columns=["Prediction", "Ground Truth"]) + + # Iterate over the evaluation data + # for batch in eval_dataloader: + # inputs, labels = batch + # predictions = model(inputs) + + # # Convert the predictions and labels to a readable format + # predictions = tokenizer.decode(predictions) + # labels = tokenizer.decode(labels) + + # # Add the data to the wandb.Table + # table.add_data(predictions, labels) + + # Generate fake data for the table + # for _ in range(10): + # fake_prediction = "Fake Prediction " + str(_) + # fake_ground_truth = "Fake Ground Truth " + str(_) + # table.add_data(fake_prediction, fake_ground_truth) + + print(dir(eval_dataloader)) + + # eval_loop = trainer.prediction_loop if trainer.args.use_legacy_prediction_loop else trainer.evaluation_loop + # output = eval_loop( + # eval_dataloader, + # description="Evaluation", + # # No point gathering the predictions if there are no metrics, otherwise we defer to + # # self.args.prediction_loss_only + # # prediction_loss_only=True if trainer.compute_metrics is None else None, + # prediction_loss_only=False, + # # ignore_keys=ignore_keys, + # # metric_key_prefix=metric_key_prefix, + # ) + + # print(type(output)) + # print(dir(output)) + # print(output.predictions) + # print(output.label_ids) + # print(output.metrics) + + # # Extract the predictions and labels from the output + # predictions = output.predictions + # labels = output.label_ids + # # Convert the predictions and labels to a readable format + # predictions = [tokenizer.decode(p) for p in predictions] + # labels = [tokenizer.decode(l) for l in labels] + + # # Add the data to the wandb.Table + # for prediction, label in zip(predictions, labels): + # table.add_data(prediction, label) + + trainer.model.eval() + # preds, refs = [], [] + # loss_bench = 0 + predictions = [] + for batch in tqdm(eval_dataloader, total=len(eval_dataloader)): + (loss, logits, labels) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + + 
print("logits", logits) + print("labels", labels) + + pred_tokens = [] + for i, logit in enumerate(logits): + print(dir(logit)) + print(logit) + print(logit.shape) + # # Convert the logits to probabilities using softmax + # probabilities = torch.softmax(logit, dim=-1) + + # # Get the predicted token id (the one with the highest probability) + # predicted_token_id = torch.argmax(probabilities).item() + + # # Decode the predicted token id to get the plaintext + # predicted_token = tokenizer.decode([predicted_token_id]) + + # # Append the predicted token to the preds list + # pred_tokens.append(predicted_token) + + # Convert the logits to probabilities using softmax + probabilities = torch.softmax(logit, dim=-1) + + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + + # Decode the predicted token ids to get the plaintext + predicted_tokens = tokenizer.batch_decode(predicted_token_ids) + + # Append the predicted tokens to the preds list + pred_tokens.extend(predicted_tokens) + + # add prediction + # convert pred_tokens to a single string + pred_string = " ".join(pred_tokens) + predictions.append(pred_string) + + # # Convert the predictions and labels to a readable format + # # predictions = [tokenizer.decode(p) for p in logits] + # # labels = [tokenizer.decode(l) for l in labels] + + # # Add the data to the wandb.Table + # for prediction, label in zip(predictions, labels): + # table.add_data(prediction, label) + + # using trainer.model generate prediction tokens for each input in eval_dataloader + # predictions = [] + # for batch in eval_dataloader: + # inputs, _ = batch + # print(inputs) + # with torch.no_grad(): + # outputs = trainer.model(inputs) + # print(outputs) + # next_pred = [tokenizer.decode(p) for p in outputs.logits.argmax(dim=-1).tolist()] + # print(next_pred) + # predictions.extend(next_pred) + + # add the predictions to the table + for prediction in predictions: + table.add_data(prediction, "Ground Truth") + + # Log the wandb.Table + wandb.log({"Predictions vs Ground Truth": table}) + + return control + + return LogPredictionCallback diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f91f4e318e..0ed12c7b6d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -30,6 +30,7 @@ SaveBetterTransformerModelCallback, SavePeftModelCallback, bench_eval_callback_factory, + log_prediction_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader @@ -719,6 +720,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) + LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) + trainer.add_callback(LogPredictionCallback(cfg)) + if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) From 0c743e31f6ec99d44c45b5f380ef10a40ce535d9 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 5 Sep 2023 06:32:21 +0000 Subject: [PATCH 02/22] WIP improve wandb table reporting callback --- src/axolotl/utils/callbacks.py | 133 +++++++++++++++++++++++++++------ src/axolotl/utils/trainer.py | 5 +- 2 files changed, 114 insertions(+), 24 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index da0458fdc7..dc97777e0d 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -342,7 +342,7 @@ def on_evaluate( 
LOG.info("logging predictions") # Initialize an empty wandb.Table - table = wandb.Table(columns=["Prediction", "Ground Truth"]) + table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion"]) # Iterate over the evaluation data # for batch in eval_dataloader: @@ -362,7 +362,7 @@ def on_evaluate( # fake_ground_truth = "Fake Ground Truth " + str(_) # table.add_data(fake_prediction, fake_ground_truth) - print(dir(eval_dataloader)) + # print(dir(eval_dataloader)) # eval_loop = trainer.prediction_loop if trainer.args.use_legacy_prediction_loop else trainer.evaluation_loop # output = eval_loop( @@ -394,24 +394,62 @@ def on_evaluate( # table.add_data(prediction, label) trainer.model.eval() + + def logits_to_tokens(logits) -> str: + + probabilities = torch.softmax(logits, dim=-1) + + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + + # Decode the predicted token ids to get the plaintext + # predicted_tokens = tokenizer.batch_decode(predicted_token_ids) + # predicted_tokens = tokenizer.decode(predicted_token_ids) + # return predicted_tokens + + return predicted_token_ids + + # preds, refs = [], [] # loss_bench = 0 - predictions = [] + # predictions = [] + id = 0 for batch in tqdm(eval_dataloader, total=len(eval_dataloader)): + + # batch.data['labels'].shape + # torch.Size([2, 320]) + # values at front with -100 are supposed to be prompt tokens + # values after are completion tokens + + # batch.data['input_ids'].shape + # torch.Size([2, 320]) + + # # Extract prompt and completion tokens from input_ids based on labels + # prompt_token_ids = batch.data['input_ids'][batch.data['labels'] == IGNORE_INDEX] + # completion_token_ids = batch.data['input_ids'][batch.data['labels'] != IGNORE_INDEX] + + # # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) + # prompt_texts = tokenizer.batch_decode(prompt_token_ids) + # completion_texts = tokenizer.batch_decode(completion_token_ids) + (loss, logits, labels) = trainer.prediction_step( trainer.model, batch, prediction_loss_only=False, ) - print("logits", logits) - print("labels", labels) + # prompt_completion_pairs = zip(prompt_texts, logits) - pred_tokens = [] - for i, logit in enumerate(logits): - print(dir(logit)) - print(logit) - print(logit.shape) + # print("logits", logits) + # print("labels", labels) + + # pred_tokens = [] + # for i, logit in enumerate(logits): + for i, (logit, labels_i) in enumerate(zip(logits, labels)): + # for i, (prompt_text, logit) in enumerate(prompt_completion_pairs): + # print(dir(logit)) + # print(logit) + # print(logit.shape) # # Convert the logits to probabilities using softmax # probabilities = torch.softmax(logit, dim=-1) @@ -424,22 +462,67 @@ def on_evaluate( # # Append the predicted token to the preds list # pred_tokens.append(predicted_token) - # Convert the logits to probabilities using softmax - probabilities = torch.softmax(logit, dim=-1) + # # Convert the logits to probabilities using softmax + # probabilities = torch.softmax(logit, dim=-1) + + # # Get the predicted token ids (the ones with the highest probability) + # predicted_token_ids = torch.argmax(probabilities, dim=-1) + + # # Decode the predicted token ids to get the plaintext + # predicted_tokens = tokenizer.batch_decode(predicted_token_ids) + + # + # label_non_zero_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? 
+ + prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - # Get the predicted token ids (the ones with the highest probability) - predicted_token_ids = torch.argmax(probabilities, dim=-1) + # Extract prompt and completion tokens from input_ids based on labels + # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] + # completion_token_ids = batch['input_ids'][batch['labels'] != IGNORE_INDEX] - # Decode the predicted token ids to get the plaintext - predicted_tokens = tokenizer.batch_decode(predicted_token_ids) + # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] + # prompt_token_ids = batch['input_ids'][label_non_zero_indices] + # prompt_token_ids = batch['input_ids'][i][label_non_zero_indices] + # prompt_token_ids = batch['input_ids'][i] + + prompt_token_ids = batch['input_ids'][i][prompt_token_indices] + completion_token_ids = batch['input_ids'][i][completion_token_indices] + + # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) + # prompt_texts = tokenizer.batch_decode(prompt_token_ids) + prompt_text = tokenizer.decode(prompt_token_ids) + completion_text = tokenizer.decode(completion_token_ids) + + completion_logit = logit[completion_token_indices] + # predicted_tokens = logits_to_tokens(logit) + predicted_tokens = logits_to_tokens(completion_logit) # Append the predicted tokens to the preds list - pred_tokens.extend(predicted_tokens) + # pred_tokens.extend(predicted_tokens) + # pred_string = " ".join(predicted_tokens) # FIXME: missing spaces + prediction_text = tokenizer.decode(predicted_tokens) + + # print("=" * 80) + # print("Prompt:") + # print(prompt_text) + # print("=" * 80) + # print("Expected Completion:") + # print(completion_text) + # print("=" * 80) + # print("Predicted Completion:") + # print(prediction_text) + # print("=" * 80) + + table.add_data(id, prompt_text, completion_text, prediction_text) + id += 1 # add prediction # convert pred_tokens to a single string - pred_string = " ".join(pred_tokens) - predictions.append(pred_string) + # pred_string = " ".join(pred_tokens) + # predictions.append(pred_string) + + # table.add_data(prompt_text, pred_string, "Ground Truth") # # Convert the predictions and labels to a readable format # # predictions = [tokenizer.decode(p) for p in logits] @@ -462,11 +545,17 @@ def on_evaluate( # predictions.extend(next_pred) # add the predictions to the table - for prediction in predictions: - table.add_data(prediction, "Ground Truth") + # for prediction in predictions: + # table.add_data(prediction, "Ground Truth") + + # print table size + # print("Table size:", len(table.data)) + + # print first entry in table + # print("First entry in table:", table.data[0]) # Log the wandb.Table - wandb.log({"Predictions vs Ground Truth": table}) + wandb.run.log({"Predictions vs Ground Truth": table}) return control diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 0ed12c7b6d..57d503d397 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -720,8 +720,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) - LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) - trainer.add_callback(LogPredictionCallback(cfg)) + if cfg.use_wandb: + LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) + 
trainer.add_callback(LogPredictionCallback(cfg)) if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) From 5a7f301d548c76147ea04d1bc127c1ad4a036c5a Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 5 Sep 2023 08:03:03 +0000 Subject: [PATCH 03/22] WIP improve wandb table reporting callback (cont) --- src/axolotl/utils/callbacks.py | 374 +++++++++++++++------------------ src/axolotl/utils/models.py | 4 +- 2 files changed, 176 insertions(+), 202 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index dc97777e0d..6e635f02bc 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -336,226 +336,200 @@ def on_evaluate( control: TrainerControl, # model, # tokenizer, + train_dataloader, eval_dataloader, **kwargs, ): LOG.info("logging predictions") - # Initialize an empty wandb.Table - table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion"]) - - # Iterate over the evaluation data - # for batch in eval_dataloader: - # inputs, labels = batch - # predictions = model(inputs) - - # # Convert the predictions and labels to a readable format - # predictions = tokenizer.decode(predictions) - # labels = tokenizer.decode(labels) - - # # Add the data to the wandb.Table - # table.add_data(predictions, labels) - - # Generate fake data for the table - # for _ in range(10): - # fake_prediction = "Fake Prediction " + str(_) - # fake_ground_truth = "Fake Ground Truth " + str(_) - # table.add_data(fake_prediction, fake_ground_truth) - - # print(dir(eval_dataloader)) - - # eval_loop = trainer.prediction_loop if trainer.args.use_legacy_prediction_loop else trainer.evaluation_loop - # output = eval_loop( - # eval_dataloader, - # description="Evaluation", - # # No point gathering the predictions if there are no metrics, otherwise we defer to - # # self.args.prediction_loss_only - # # prediction_loss_only=True if trainer.compute_metrics is None else None, - # prediction_loss_only=False, - # # ignore_keys=ignore_keys, - # # metric_key_prefix=metric_key_prefix, - # ) - - # print(type(output)) - # print(dir(output)) - # print(output.predictions) - # print(output.label_ids) - # print(output.metrics) - - # # Extract the predictions and labels from the output - # predictions = output.predictions - # labels = output.label_ids - # # Convert the predictions and labels to a readable format - # predictions = [tokenizer.decode(p) for p in predictions] - # labels = [tokenizer.decode(l) for l in labels] - - # # Add the data to the wandb.Table - # for prediction, label in zip(predictions, labels): - # table.add_data(prediction, label) trainer.model.eval() - def logits_to_tokens(logits) -> str: + def logits_to_tokens(logits) -> str: probabilities = torch.softmax(logits, dim=-1) - # Get the predicted token ids (the ones with the highest probability) predicted_token_ids = torch.argmax(probabilities, dim=-1) - - # Decode the predicted token ids to get the plaintext - # predicted_tokens = tokenizer.batch_decode(predicted_token_ids) - # predicted_tokens = tokenizer.decode(predicted_token_ids) - # return predicted_tokens - return predicted_token_ids + def log_table_from_dataloader(name: str, table_dataloader): - # preds, refs = [], [] - # loss_bench = 0 - # predictions = [] - id = 0 - for batch in tqdm(eval_dataloader, total=len(eval_dataloader)): - - # batch.data['labels'].shape - # torch.Size([2, 320]) - # values at front with -100 are supposed to be prompt tokens - # values after are completion 
tokens - - # batch.data['input_ids'].shape - # torch.Size([2, 320]) - - # # Extract prompt and completion tokens from input_ids based on labels - # prompt_token_ids = batch.data['input_ids'][batch.data['labels'] == IGNORE_INDEX] - # completion_token_ids = batch.data['input_ids'][batch.data['labels'] != IGNORE_INDEX] - - # # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) - # prompt_texts = tokenizer.batch_decode(prompt_token_ids) - # completion_texts = tokenizer.batch_decode(completion_token_ids) - - (loss, logits, labels) = trainer.prediction_step( - trainer.model, - batch, - prediction_loss_only=False, - ) - - # prompt_completion_pairs = zip(prompt_texts, logits) - - # print("logits", logits) - # print("labels", labels) - - # pred_tokens = [] - # for i, logit in enumerate(logits): - for i, (logit, labels_i) in enumerate(zip(logits, labels)): - # for i, (prompt_text, logit) in enumerate(prompt_completion_pairs): - # print(dir(logit)) - # print(logit) - # print(logit.shape) - # # Convert the logits to probabilities using softmax - # probabilities = torch.softmax(logit, dim=-1) - - # # Get the predicted token id (the one with the highest probability) - # predicted_token_id = torch.argmax(probabilities).item() - - # # Decode the predicted token id to get the plaintext - # predicted_token = tokenizer.decode([predicted_token_id]) - - # # Append the predicted token to the preds list - # pred_tokens.append(predicted_token) - - # # Convert the logits to probabilities using softmax - # probabilities = torch.softmax(logit, dim=-1) - - # # Get the predicted token ids (the ones with the highest probability) - # predicted_token_ids = torch.argmax(probabilities, dim=-1) - - # # Decode the predicted token ids to get the plaintext - # predicted_tokens = tokenizer.batch_decode(predicted_token_ids) - - # - # label_non_zero_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - - prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? 
+ # Initialize an empty wandb.Table + table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion 1", "Predicted Completion 2"]) - # Extract prompt and completion tokens from input_ids based on labels - # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] - # completion_token_ids = batch['input_ids'][batch['labels'] != IGNORE_INDEX] + # preds, refs = [], [] + # loss_bench = 0 + # predictions = [] + id = 0 + for batch in tqdm(table_dataloader, total=len(table_dataloader)): + # max_examples = 100 + # for batch in tqdm(table_dataloader, total=min(max_examples, len(table_dataloader))): - # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] - # prompt_token_ids = batch['input_ids'][label_non_zero_indices] - # prompt_token_ids = batch['input_ids'][i][label_non_zero_indices] - # prompt_token_ids = batch['input_ids'][i] + # batch.data['labels'].shape + # torch.Size([2, 320]) + # values at front with -100 are supposed to be prompt tokens + # values after are completion tokens - prompt_token_ids = batch['input_ids'][i][prompt_token_indices] - completion_token_ids = batch['input_ids'][i][completion_token_indices] + # batch.data['input_ids'].shape + # torch.Size([2, 320]) + + # # Extract prompt and completion tokens from input_ids based on labels + # prompt_token_ids = batch.data['input_ids'][batch.data['labels'] == IGNORE_INDEX] + # completion_token_ids = batch.data['input_ids'][batch.data['labels'] != IGNORE_INDEX] - # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) + # # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) # prompt_texts = tokenizer.batch_decode(prompt_token_ids) - prompt_text = tokenizer.decode(prompt_token_ids) - completion_text = tokenizer.decode(completion_token_ids) - - completion_logit = logit[completion_token_indices] - # predicted_tokens = logits_to_tokens(logit) - predicted_tokens = logits_to_tokens(completion_logit) - - # Append the predicted tokens to the preds list - # pred_tokens.extend(predicted_tokens) - # pred_string = " ".join(predicted_tokens) # FIXME: missing spaces - prediction_text = tokenizer.decode(predicted_tokens) - - # print("=" * 80) - # print("Prompt:") - # print(prompt_text) - # print("=" * 80) - # print("Expected Completion:") - # print(completion_text) - # print("=" * 80) - # print("Predicted Completion:") - # print(prediction_text) - # print("=" * 80) - - table.add_data(id, prompt_text, completion_text, prediction_text) - id += 1 - - # add prediction - # convert pred_tokens to a single string - # pred_string = " ".join(pred_tokens) - # predictions.append(pred_string) - - # table.add_data(prompt_text, pred_string, "Ground Truth") - - # # Convert the predictions and labels to a readable format - # # predictions = [tokenizer.decode(p) for p in logits] - # # labels = [tokenizer.decode(l) for l in labels] - - # # Add the data to the wandb.Table - # for prediction, label in zip(predictions, labels): - # table.add_data(prediction, label) - - # using trainer.model generate prediction tokens for each input in eval_dataloader - # predictions = [] - # for batch in eval_dataloader: - # inputs, _ = batch - # print(inputs) - # with torch.no_grad(): - # outputs = trainer.model(inputs) - # print(outputs) - # next_pred = [tokenizer.decode(p) for p in outputs.logits.argmax(dim=-1).tolist()] - # print(next_pred) - # predictions.extend(next_pred) - - # add the predictions to the table - # for prediction in predictions: - # table.add_data(prediction, "Ground Truth") - - # print 
table size - # print("Table size:", len(table.data)) - - # print first entry in table - # print("First entry in table:", table.data[0]) - - # Log the wandb.Table - wandb.run.log({"Predictions vs Ground Truth": table}) + # completion_texts = tokenizer.batch_decode(completion_token_ids) + + (loss, logits, labels) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + + # prompt_completion_pairs = zip(prompt_texts, logits) + + # print("logits", logits) + # print("labels", labels) + + # pred_tokens = [] + # for i, logit in enumerate(logits): + for i, (logit, labels_i) in enumerate(zip(logits, labels)): + # for i, (prompt_text, logit) in enumerate(prompt_completion_pairs): + # print(dir(logit)) + # print(logit) + # print(logit.shape) + # # Convert the logits to probabilities using softmax + # probabilities = torch.softmax(logit, dim=-1) + + # # Get the predicted token id (the one with the highest probability) + # predicted_token_id = torch.argmax(probabilities).item() + + # # Decode the predicted token id to get the plaintext + # predicted_token = tokenizer.decode([predicted_token_id]) + + # # Append the predicted token to the preds list + # pred_tokens.append(predicted_token) + + # # Convert the logits to probabilities using softmax + # probabilities = torch.softmax(logit, dim=-1) + + # # Get the predicted token ids (the ones with the highest probability) + # predicted_token_ids = torch.argmax(probabilities, dim=-1) + + # # Decode the predicted token ids to get the plaintext + # predicted_tokens = tokenizer.batch_decode(predicted_token_ids) + + # + # label_non_zero_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + + prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? 
+ + # Extract prompt and completion tokens from input_ids based on labels + # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] + # completion_token_ids = batch['input_ids'][batch['labels'] != IGNORE_INDEX] + + # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] + # prompt_token_ids = batch['input_ids'][label_non_zero_indices] + # prompt_token_ids = batch['input_ids'][i][label_non_zero_indices] + # prompt_token_ids = batch['input_ids'][i] + + prompt_token_ids = batch['input_ids'][i][prompt_token_indices] + completion_token_ids = batch['input_ids'][i][completion_token_indices] + + # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) + # prompt_texts = tokenizer.batch_decode(prompt_token_ids) + prompt_text = tokenizer.decode(prompt_token_ids) + completion_text = tokenizer.decode(completion_token_ids) + + completion_logit = logit[completion_token_indices] + # predicted_tokens = logits_to_tokens(logit) + predicted_tokens = logits_to_tokens(completion_logit) + + # Append the predicted tokens to the preds list + # pred_tokens.extend(predicted_tokens) + # pred_string = " ".join(predicted_tokens) # FIXME: missing spaces + prediction_text = tokenizer.decode(predicted_tokens) + + # generate new prediction with trainer.model which is a transformer model + # Generate new prediction with trainer.model which is a transformer model + with torch.no_grad(): + # new_prediction = trainer.model(batch['input_ids'][i].unsqueeze(0)) + new_prediction = trainer.model(prompt_token_ids.unsqueeze(0)) + + # Convert the logits to probabilities using softmax + new_probabilities = torch.softmax(new_prediction.logits, dim=-1) + + # Get the predicted token ids (the ones with the highest probability) + new_predicted_token_ids = torch.argmax(new_probabilities, dim=-1) + + # Decode the predicted token ids to get the plaintext + new_predicted_tokens = tokenizer.decode(new_predicted_token_ids[0]) + + # print("=" * 80) + # print("Prompt:") + # print(prompt_text) + # print("=" * 80) + # print("Expected Completion:") + # print(completion_text) + # print("=" * 80) + # print("Predicted Completion:") + # print(prediction_text) + # print("=" * 80) + + table.add_data(id, prompt_text, completion_text, prediction_text, new_predicted_tokens) + id += 1 + + # add prediction + # convert pred_tokens to a single string + # pred_string = " ".join(pred_tokens) + # predictions.append(pred_string) + + # table.add_data(prompt_text, pred_string, "Ground Truth") + + # # Convert the predictions and labels to a readable format + # # predictions = [tokenizer.decode(p) for p in logits] + # # labels = [tokenizer.decode(l) for l in labels] + + # # Add the data to the wandb.Table + # for prediction, label in zip(predictions, labels): + # table.add_data(prediction, label) + + # using trainer.model generate prediction tokens for each input in eval_dataloader + # predictions = [] + # for batch in eval_dataloader: + # inputs, _ = batch + # print(inputs) + # with torch.no_grad(): + # outputs = trainer.model(inputs) + # print(outputs) + # next_pred = [tokenizer.decode(p) for p in outputs.logits.argmax(dim=-1).tolist()] + # print(next_pred) + # predictions.extend(next_pred) + + # add the predictions to the table + # for prediction in predictions: + # table.add_data(prediction, "Ground Truth") + + # print table size + # print("Table size:", len(table.data)) + + # print first entry in table + # print("First entry in table:", table.data[0]) + + # Log the wandb.Table + wandb.run.log({ f"{name} - Predictions vs 
Ground Truth": table }) + + # log_table_from_dataloader("Train", train_dataloader) + # log_table_from_dataloader("Train", train_dataloader) + + # # Get first 10 records from train_dataloader as a new dataloader + # train_data_subset = [next(iter(train_dataloader)) for _ in range(10)] + # train_dataloader_subset = torch.utils.data.DataLoader(train_data_subset, batch_size=train_dataloader.batch_size, shuffle=False) + # log_table_from_dataloader("Train Subset", train_dataloader_subset) + + log_table_from_dataloader("Eval", eval_dataloader) return control diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9f0795af76..e4ff7b5d42 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -340,10 +340,10 @@ def load_model( if ( hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings - and cfg.sequence_len >= model.config.max_position_embeddings + and cfg.sequence_len > model.config.max_position_embeddings ): LOG.warning( - f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" + f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}" ) model.config.max_position_embeddings = cfg.sequence_len From 8c7b7c599279259fae10a474020181bbb72a0b2d Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 5 Sep 2023 08:03:21 +0000 Subject: [PATCH 04/22] Add VSCode launching for debugging --- .vscode/launch.json | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..8af2261000 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,21 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "train", + "type": "python", + "request": "launch", + "module": "accelerate.commands.launch", + "args": [ + "${workspaceFolder}/scripts/finetune.py", + // "${file}", + "${workspaceFolder}/examples/llama-2/tiny-random.yml", + ], // other args comes after train.py + "console": "integratedTerminal", + // "env": {"CUDA_LAUNCH_BLOCKING": "1"} + }, + ] +} \ No newline at end of file From 88c31f14d34a71a91309a8818f562c0df98f4265 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 5 Sep 2023 08:03:38 +0000 Subject: [PATCH 05/22] Add tiny llama example --- examples/llama-2/tiny-random.yml | 88 ++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 examples/llama-2/tiny-random.yml diff --git a/examples/llama-2/tiny-random.yml b/examples/llama-2/tiny-random.yml new file mode 100644 index 0000000000..6ddd9ac8f9 --- /dev/null +++ b/examples/llama-2/tiny-random.yml @@ -0,0 +1,88 @@ +# anushehchaudry/llama-2-tiny-random +# base_model: anushehchaudry/llama-2-tiny-random +# base_model_config: anushehchaudry/llama-2-tiny-random + +base_model: JackFram/llama-68m +base_model_config: JackFram/llama-68m + +# base_model: PY007/TinyLlama-1.1B-step-50K-105b +# base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + # - path: mhenrichsen/alpaca_2k_test + # type: alpaca + - path: teknium/GPT4-LLM-Cleaned + type: alpaca + # - path: Glavin001/startup-interviews + # type: alpaca +dataset_prepared_path: last_run_prepared +# val_set_size: 0.01 +val_set_size: 0.1 +# output_dir: ./lora-out +output_dir: ./lora-2-out + +# sequence_len: 4096 +# sequence_len: 2048 +# sequence_len: 256 +sequence_len: 512 +# sample_packing: true +sample_packing: false + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: test-issue-490 +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +# num_epochs: 3 +num_epochs: 0.001 +# num_epochs: 5 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +# eval_steps: 20 +eval_steps: 2 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" From 06a44dec87415c17ee25534a11c66231ee57ef56 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Thu, 7 Sep 2023 05:17:59 +0000 Subject: [PATCH 06/22] WIP attempt to improve post-eval prediction generation for table --- docker-compose.yaml | 9 ++--- examples/llama-2/tiny-random.yml | 33 +++++++++-------- src/axolotl/utils/callbacks.py | 61 ++++++++++++++++++++++++-------- 3 files changed, 71 insertions(+), 32 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index d40422f94f..187d567f84 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,9 +1,10 @@ # version: '3.8' services: axolotl: - build: - context: . - dockerfile: ./docker/Dockerfile + # build: + # context: . 
+ # dockerfile: ./docker/Dockerfile + image: winglian/axolotl:main-py3.10-cu118-2.0.1 volumes: - .:/workspace/axolotl - ~/.cache/huggingface/:/root/.cache/huggingface/ @@ -15,6 +16,6 @@ services: reservations: devices: - driver: nvidia - # count: 1 + count: 1 capabilities: [gpu] command: tail -f /dev/null diff --git a/examples/llama-2/tiny-random.yml b/examples/llama-2/tiny-random.yml index 6ddd9ac8f9..d5e88c7fa4 100644 --- a/examples/llama-2/tiny-random.yml +++ b/examples/llama-2/tiny-random.yml @@ -2,11 +2,11 @@ # base_model: anushehchaudry/llama-2-tiny-random # base_model_config: anushehchaudry/llama-2-tiny-random -base_model: JackFram/llama-68m -base_model_config: JackFram/llama-68m +# base_model: JackFram/llama-68m +# base_model_config: JackFram/llama-68m -# base_model: PY007/TinyLlama-1.1B-step-50K-105b -# base_model_config: PY007/TinyLlama-1.1B-step-50K-105b +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b model_type: LlamaForCausalLM tokenizer_type: LlamaTokenizer @@ -19,20 +19,21 @@ strict: false datasets: # - path: mhenrichsen/alpaca_2k_test # type: alpaca - - path: teknium/GPT4-LLM-Cleaned - type: alpaca - # - path: Glavin001/startup-interviews + # - path: teknium/GPT4-LLM-Cleaned # type: alpaca + - path: Glavin001/startup-interviews + type: alpaca dataset_prepared_path: last_run_prepared # val_set_size: 0.01 val_set_size: 0.1 # output_dir: ./lora-out -output_dir: ./lora-2-out +# output_dir: ./lora-2-out +output_dir: ./lora-5-out # sequence_len: 4096 -# sequence_len: 2048 +sequence_len: 2048 # sequence_len: 256 -sequence_len: 512 +# sequence_len: 512 # sample_packing: true sample_packing: false @@ -51,10 +52,13 @@ wandb_run_id: wandb_log_model: gradient_accumulation_steps: 4 -micro_batch_size: 2 +# micro_batch_size: 2 +micro_batch_size: 16 # num_epochs: 3 -num_epochs: 0.001 -# num_epochs: 5 +# num_epochs: 0.001 +# num_epochs: 0.01 +# num_epochs: 1 +num_epochs: 5 optimizer: adamw_bnb_8bit lr_scheduler: cosine learning_rate: 0.0002 @@ -74,8 +78,9 @@ xformers_attention: flash_attention: true warmup_steps: 10 +eval_steps: 10 # eval_steps: 20 -eval_steps: 2 +# eval_steps: 2 save_steps: debug: deepspeed: diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 6e635f02bc..f1ab97b012 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -20,6 +20,7 @@ TrainerControl, TrainerState, TrainingArguments, + GenerationConfig, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy @@ -322,6 +323,8 @@ def on_evaluate( def log_prediction_callback_factory(trainer: Trainer, tokenizer): + LOG.info("log_prediction_callback_factory") + class LogPredictionCallback(TrainerCallback): """Callback to log prediction values during each evaluation""" @@ -334,18 +337,17 @@ def on_evaluate( args: AxolotlTrainingArguments, state: TrainerState, control: TrainerControl, - # model, + model, # tokenizer, train_dataloader, eval_dataloader, **kwargs, ): + LOG.info("=" * 80) LOG.info("logging predictions") - trainer.model.eval() - def logits_to_tokens(logits) -> str: probabilities = torch.softmax(logits, dim=-1) # Get the predicted token ids (the ones with the highest probability) @@ -456,17 +458,48 @@ def log_table_from_dataloader(name: str, table_dataloader): # Generate new prediction with trainer.model which is a transformer model with torch.no_grad(): # new_prediction = trainer.model(batch['input_ids'][i].unsqueeze(0)) - new_prediction = 
trainer.model(prompt_token_ids.unsqueeze(0)) - - # Convert the logits to probabilities using softmax - new_probabilities = torch.softmax(new_prediction.logits, dim=-1) - - # Get the predicted token ids (the ones with the highest probability) - new_predicted_token_ids = torch.argmax(new_probabilities, dim=-1) - - # Decode the predicted token ids to get the plaintext - new_predicted_tokens = tokenizer.decode(new_predicted_token_ids[0]) - + # new_prediction = trainer.model(prompt_token_ids.unsqueeze(0)) + # new_prediction = trainer.model(prompt_token_ids.unsqueeze(0)) + + generation_config = GenerationConfig( + repetition_penalty=1.1, + # max_new_tokens=1024, + # max_new_tokens=256, + max_new_tokens=128, + temperature=0.9, + # top_p=0.95, + # top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + # do_sample=True, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + # streamer = TextStreamer(tokenizer) + new_prediction = trainer.model.generate( + # inputs=batch["input_ids"].to(cfg.device), + inputs=prompt_token_ids.unsqueeze(0), + generation_config=generation_config, + # streamer=streamer, + ) + + # # Convert the logits to probabilities using softmax + # new_probabilities = torch.softmax(new_prediction.logits, dim=-1) + + # # Get the predicted token ids (the ones with the highest probability) + # new_predicted_token_ids = torch.argmax(new_probabilities, dim=-1) + + # # Decode the predicted token ids to get the plaintext + # new_predicted_tokens = tokenizer.decode(new_predicted_token_ids[0]) + + new_predicted_tokens = tokenizer.decode(new_prediction["sequences"].cpu().tolist()[0]) + + # print("=" * 80) # print("Prompt:") # print(prompt_text) From ab3cffa14193a1951269a121e9ab9f41a2651fa6 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Fri, 8 Sep 2023 07:40:21 +0000 Subject: [PATCH 07/22] WIP attempt to improve post-eval prediction generation for table - part 2 --- src/axolotl/utils/callbacks.py | 49 +++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index f1ab97b012..81195aa640 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -396,7 +396,7 @@ def log_table_from_dataloader(name: str, table_dataloader): # pred_tokens = [] # for i, logit in enumerate(logits): - for i, (logit, labels_i) in enumerate(zip(logits, labels)): + for i, (input_ids, logit, labels_i) in enumerate(zip(batch['input_ids'], logits, labels)): # for i, (prompt_text, logit) in enumerate(prompt_completion_pairs): # print(dir(logit)) # print(logit) @@ -424,9 +424,32 @@ def log_table_from_dataloader(name: str, table_dataloader): # # label_non_zero_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + # labels[15].tolist()[-3:] - prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + # prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + # completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? 
+ + # prompt_token_indices = labels_i[labels_i == IGNORE_INDEX] + + # Prompt tokens are all tokens up to eos_token, the excluding pad_token + # input_ids = batch['input_ids'][i] + + # prompt_token_ids = input_ids[0:] + # examples = group_sublists_by(input_ids.tolist(), tokenizer.eos_token_id) + # clean_examples = [example for example in examples if example == tokenizer.pad_token_id] + + # prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + # completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + + # print(tokenizer.decode(input_ids[1][(batch['labels'][1] == IGNORE_INDEX)])) + tokens_without_loss = (labels_i == IGNORE_INDEX) + tokens_with_loss = (labels_i != IGNORE_INDEX) + tokens_exclude_padding = (input_ids != tokenizer.pad_token_id) + # prompt_token_includes = (labels_i == IGNORE_INDEX) & (input_ids[i] != tokenizer.pad_token_id) + prompt_token_includes = tokens_without_loss & tokens_exclude_padding + + prompt_token_ids = input_ids[prompt_token_includes] + # print(tokenizer.decode(prompt_token_ids)) # Extract prompt and completion tokens from input_ids based on labels # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] @@ -437,15 +460,18 @@ def log_table_from_dataloader(name: str, table_dataloader): # prompt_token_ids = batch['input_ids'][i][label_non_zero_indices] # prompt_token_ids = batch['input_ids'][i] - prompt_token_ids = batch['input_ids'][i][prompt_token_indices] - completion_token_ids = batch['input_ids'][i][completion_token_indices] + # prompt_token_ids = batch['input_ids'][i][prompt_token_indices] + # completion_token_ids = batch['input_ids'][i][completion_token_indices] + # completion_token_ids = batch['input_ids'][i][tokens_with_loss] + completion_token_ids = input_ids[tokens_with_loss] # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) # prompt_texts = tokenizer.batch_decode(prompt_token_ids) prompt_text = tokenizer.decode(prompt_token_ids) completion_text = tokenizer.decode(completion_token_ids) - completion_logit = logit[completion_token_indices] + # completion_logit = logit[completion_token_indices] + completion_logit = logit[tokens_with_loss] # predicted_tokens = logits_to_tokens(logit) predicted_tokens = logits_to_tokens(completion_logit) @@ -465,8 +491,10 @@ def log_table_from_dataloader(name: str, table_dataloader): repetition_penalty=1.1, # max_new_tokens=1024, # max_new_tokens=256, - max_new_tokens=128, - temperature=0.9, + # max_new_tokens=128, + # max_new_tokens=64, + max_new_tokens=32, + # temperature=0.9, # top_p=0.95, # top_k=40, bos_token_id=tokenizer.bos_token_id, @@ -497,8 +525,11 @@ def log_table_from_dataloader(name: str, table_dataloader): # # Decode the predicted token ids to get the plaintext # new_predicted_tokens = tokenizer.decode(new_predicted_token_ids[0]) - new_predicted_tokens = tokenizer.decode(new_prediction["sequences"].cpu().tolist()[0]) + new_prediction_all_tokens = new_prediction["sequences"].cpu().tolist()[0] + new_prediction_completion_only_tokens = new_prediction_all_tokens[len(prompt_token_ids):] + # new_predicted_tokens = tokenizer.decode(new_prediction["sequences"].cpu().tolist()[0]) + new_predicted_tokens = tokenizer.decode(new_prediction_completion_only_tokens) # print("=" * 80) # print("Prompt:") From b22d1c6d86a632c276a6a4a6ad50f0ccf6b1d9af Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Fri, 8 Sep 2023 08:52:03 +0000 Subject: [PATCH 08/22] WIP batch generation --- 
src/axolotl/utils/callbacks.py | 255 +++++++-------------------------- 1 file changed, 50 insertions(+), 205 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 81195aa640..41c1853417 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -359,230 +359,75 @@ def log_table_from_dataloader(name: str, table_dataloader): # Initialize an empty wandb.Table table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion 1", "Predicted Completion 2"]) - # preds, refs = [], [] - # loss_bench = 0 - # predictions = [] - id = 0 + batch_index = 0 for batch in tqdm(table_dataloader, total=len(table_dataloader)): - # max_examples = 100 - # for batch in tqdm(table_dataloader, total=min(max_examples, len(table_dataloader))): - - # batch.data['labels'].shape - # torch.Size([2, 320]) - # values at front with -100 are supposed to be prompt tokens - # values after are completion tokens - - # batch.data['input_ids'].shape - # torch.Size([2, 320]) - - # # Extract prompt and completion tokens from input_ids based on labels - # prompt_token_ids = batch.data['input_ids'][batch.data['labels'] == IGNORE_INDEX] - # completion_token_ids = batch.data['input_ids'][batch.data['labels'] != IGNORE_INDEX] - - # # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) - # prompt_texts = tokenizer.batch_decode(prompt_token_ids) - # completion_texts = tokenizer.batch_decode(completion_token_ids) - - (loss, logits, labels) = trainer.prediction_step( + # For each batch I want prompt, completion, 2x predictions + + # (loss, logits, labels) = trainer.prediction_step( + (batch_loss, batch_logits, batch_labels) = trainer.prediction_step( trainer.model, batch, prediction_loss_only=False, ) - # prompt_completion_pairs = zip(prompt_texts, logits) - - # print("logits", logits) - # print("labels", labels) - - # pred_tokens = [] - # for i, logit in enumerate(logits): - for i, (input_ids, logit, labels_i) in enumerate(zip(batch['input_ids'], logits, labels)): - # for i, (prompt_text, logit) in enumerate(prompt_completion_pairs): - # print(dir(logit)) - # print(logit) - # print(logit.shape) - # # Convert the logits to probabilities using softmax - # probabilities = torch.softmax(logit, dim=-1) - - # # Get the predicted token id (the one with the highest probability) - # predicted_token_id = torch.argmax(probabilities).item() - - # # Decode the predicted token id to get the plaintext - # predicted_token = tokenizer.decode([predicted_token_id]) - - # # Append the predicted token to the preds list - # pred_tokens.append(predicted_token) - - # # Convert the logits to probabilities using softmax - # probabilities = torch.softmax(logit, dim=-1) + prompt_token_ids_list = [] + completion_texts = [] + prediction_texts = [] - # # Get the predicted token ids (the ones with the highest probability) - # predicted_token_ids = torch.argmax(probabilities, dim=-1) - - # # Decode the predicted token ids to get the plaintext - # predicted_tokens = tokenizer.batch_decode(predicted_token_ids) - - # - # label_non_zero_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - # labels[15].tolist()[-3:] - - # prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - # completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? 
- - # prompt_token_indices = labels_i[labels_i == IGNORE_INDEX] + # for input_ids in batch['input_ids']: + # for batch_item_idx, (input_ids, labels) in enumerate(zip(batch['input_ids'], logits, labels)): + for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'], batch_logits, batch_labels)): + tokens_without_loss = (labels == IGNORE_INDEX) + tokens_with_loss = (labels != IGNORE_INDEX) + tokens_exclude_padding = (input_ids != tokenizer.pad_token_id) - # Prompt tokens are all tokens up to eos_token, the excluding pad_token - # input_ids = batch['input_ids'][i] + prompt_token_includes = tokens_without_loss & tokens_exclude_padding - # prompt_token_ids = input_ids[0:] - # examples = group_sublists_by(input_ids.tolist(), tokenizer.eos_token_id) - # clean_examples = [example for example in examples if example == tokenizer.pad_token_id] + prompt_token_ids = input_ids[prompt_token_includes] + prompt_token_ids_list.append(prompt_token_ids) - # prompt_token_indices = (batch["labels"][i] == IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? - # completion_token_indices = (batch["labels"][i] != IGNORE_INDEX).nonzero().transpose(0, 1)[0] # FIXME: clean up? + completion_token_ids = input_ids[tokens_with_loss] + completion_text = tokenizer.decode(completion_token_ids) + completion_texts.append(completion_text) - # print(tokenizer.decode(input_ids[1][(batch['labels'][1] == IGNORE_INDEX)])) - tokens_without_loss = (labels_i == IGNORE_INDEX) - tokens_with_loss = (labels_i != IGNORE_INDEX) - tokens_exclude_padding = (input_ids != tokenizer.pad_token_id) - # prompt_token_includes = (labels_i == IGNORE_INDEX) & (input_ids[i] != tokenizer.pad_token_id) - prompt_token_includes = tokens_without_loss & tokens_exclude_padding + completion_logit = logits[tokens_with_loss] + predicted_tokens = logits_to_tokens(completion_logit) + prediction_text = tokenizer.decode(predicted_tokens) + prediction_texts.append(prediction_text) - prompt_token_ids = input_ids[prompt_token_includes] - # print(tokenizer.decode(prompt_token_ids)) + prompt_texts = tokenizer.batch_decode(prompt_token_ids_list, skip_special_tokens=True) - # Extract prompt and completion tokens from input_ids based on labels - # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] - # completion_token_ids = batch['input_ids'][batch['labels'] != IGNORE_INDEX] + with torch.no_grad(): + generation_config = GenerationConfig( + repetition_penalty=1.1, + max_new_tokens=32, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) - # prompt_token_ids = batch['input_ids'][batch['labels'] == IGNORE_INDEX] - # prompt_token_ids = batch['input_ids'][label_non_zero_indices] - # prompt_token_ids = batch['input_ids'][i][label_non_zero_indices] - # prompt_token_ids = batch['input_ids'][i] + encoding = tokenizer(prompt_texts, padding=True, return_tensors='pt').to(self.cfg.device) + new_predictions = trainer.model.generate(**encoding, generation_config=generation_config) - # prompt_token_ids = batch['input_ids'][i][prompt_token_indices] - # completion_token_ids = batch['input_ids'][i][completion_token_indices] - # completion_token_ids = batch['input_ids'][i][tokens_with_loss] - completion_token_ids = input_ids[tokens_with_loss] + new_prediction_all_tokens = new_predictions["sequences"].cpu().tolist() + 
new_prediction_without_prompt_tokens_list = [] + for prompt_token_ids, new_prediction_tokens in zip(prompt_token_ids_list, new_prediction_all_tokens): + new_prediction_without_prompt_tokens = new_prediction_tokens[len(prompt_token_ids):] + new_prediction_without_prompt_tokens_list.append(new_prediction_without_prompt_tokens) - # prompt_texts = tokenizer.batch_decode(batch.data['input_ids']) - # prompt_texts = tokenizer.batch_decode(prompt_token_ids) - prompt_text = tokenizer.decode(prompt_token_ids) - completion_text = tokenizer.decode(completion_token_ids) + new_predicted_texts = tokenizer.batch_decode(new_prediction_without_prompt_tokens_list, skip_special_tokens=True) - # completion_logit = logit[completion_token_indices] - completion_logit = logit[tokens_with_loss] - # predicted_tokens = logits_to_tokens(logit) - predicted_tokens = logits_to_tokens(completion_logit) + for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)): + table.add_data(i, prompt_text, completion_text, prediction_text, new_predicted_text) - # Append the predicted tokens to the preds list - # pred_tokens.extend(predicted_tokens) - # pred_string = " ".join(predicted_tokens) # FIXME: missing spaces - prediction_text = tokenizer.decode(predicted_tokens) + batch_index += 1 - # generate new prediction with trainer.model which is a transformer model - # Generate new prediction with trainer.model which is a transformer model - with torch.no_grad(): - # new_prediction = trainer.model(batch['input_ids'][i].unsqueeze(0)) - # new_prediction = trainer.model(prompt_token_ids.unsqueeze(0)) - # new_prediction = trainer.model(prompt_token_ids.unsqueeze(0)) - - generation_config = GenerationConfig( - repetition_penalty=1.1, - # max_new_tokens=1024, - # max_new_tokens=256, - # max_new_tokens=128, - # max_new_tokens=64, - max_new_tokens=32, - # temperature=0.9, - # top_p=0.95, - # top_k=40, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - # do_sample=True, - do_sample=False, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - # streamer = TextStreamer(tokenizer) - new_prediction = trainer.model.generate( - # inputs=batch["input_ids"].to(cfg.device), - inputs=prompt_token_ids.unsqueeze(0), - generation_config=generation_config, - # streamer=streamer, - ) - - # # Convert the logits to probabilities using softmax - # new_probabilities = torch.softmax(new_prediction.logits, dim=-1) - - # # Get the predicted token ids (the ones with the highest probability) - # new_predicted_token_ids = torch.argmax(new_probabilities, dim=-1) - - # # Decode the predicted token ids to get the plaintext - # new_predicted_tokens = tokenizer.decode(new_predicted_token_ids[0]) - - new_prediction_all_tokens = new_prediction["sequences"].cpu().tolist()[0] - new_prediction_completion_only_tokens = new_prediction_all_tokens[len(prompt_token_ids):] - - # new_predicted_tokens = tokenizer.decode(new_prediction["sequences"].cpu().tolist()[0]) - new_predicted_tokens = tokenizer.decode(new_prediction_completion_only_tokens) - - # print("=" * 80) - # print("Prompt:") - # print(prompt_text) - # print("=" * 80) - # print("Expected Completion:") - # print(completion_text) - # print("=" * 80) - # print("Predicted Completion:") - # print(prediction_text) - # print("=" * 80) - - table.add_data(id, prompt_text, 
completion_text, prediction_text, new_predicted_tokens) - id += 1 - - # add prediction - # convert pred_tokens to a single string - # pred_string = " ".join(pred_tokens) - # predictions.append(pred_string) - - # table.add_data(prompt_text, pred_string, "Ground Truth") - - # # Convert the predictions and labels to a readable format - # # predictions = [tokenizer.decode(p) for p in logits] - # # labels = [tokenizer.decode(l) for l in labels] - - # # Add the data to the wandb.Table - # for prediction, label in zip(predictions, labels): - # table.add_data(prediction, label) - - # using trainer.model generate prediction tokens for each input in eval_dataloader - # predictions = [] - # for batch in eval_dataloader: - # inputs, _ = batch - # print(inputs) - # with torch.no_grad(): - # outputs = trainer.model(inputs) - # print(outputs) - # next_pred = [tokenizer.decode(p) for p in outputs.logits.argmax(dim=-1).tolist()] - # print(next_pred) - # predictions.extend(next_pred) - - # add the predictions to the table - # for prediction in predictions: - # table.add_data(prediction, "Ground Truth") - - # print table size - # print("Table size:", len(table.data)) - - # print first entry in table - # print("First entry in table:", table.data[0]) - - # Log the wandb.Table wandb.run.log({ f"{name} - Predictions vs Ground Truth": table }) # log_table_from_dataloader("Train", train_dataloader) From 6f3216eb39b348f2ea445fa4c89a317868e2cc9d Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Fri, 8 Sep 2023 10:19:22 +0000 Subject: [PATCH 09/22] WIP attempt to handle sample_packing using position_ids for wandb prediction table --- src/axolotl/utils/callbacks.py | 86 +++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 23 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 41c1853417..657cb6d7f5 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -347,6 +347,7 @@ def on_evaluate( LOG.info("logging predictions") trainer.model.eval() + device = torch.device(self.cfg.device) def logits_to_tokens(logits) -> str: probabilities = torch.softmax(logits, dim=-1) @@ -354,6 +355,16 @@ def logits_to_tokens(logits) -> str: predicted_token_ids = torch.argmax(probabilities, dim=-1) return predicted_token_ids + def find_ranges(lst): + ranges = [] + start = 0 + for i in range(1, len(lst)): + if lst[i] == 0: + ranges.append((start, i-1)) + start = i + ranges.append((start, len(lst)-1)) # for the last range + return ranges + def log_table_from_dataloader(name: str, table_dataloader): # Initialize an empty wandb.Table @@ -364,38 +375,65 @@ def log_table_from_dataloader(name: str, table_dataloader): # For each batch I want prompt, completion, 2x predictions # (loss, logits, labels) = trainer.prediction_step( - (batch_loss, batch_logits, batch_labels) = trainer.prediction_step( - trainer.model, - batch, - prediction_loss_only=False, - ) + # (batch_loss, batch_logits, batch_labels) = trainer.prediction_step( + # trainer.model, + # batch, + # prediction_loss_only=False, + # ) + + batch_labels = batch['labels'].to(device) + batch_input_ids = batch['input_ids'].to(device) + + if 'position_ids' in batch: + batch_pos_ids = batch['position_ids'].tolist() + else: + batch_pos_ids = [None] * len(batch['input_ids']) prompt_token_ids_list = [] - completion_texts = [] - prediction_texts = [] + completion_token_ids_list = [] + # completion_texts = [] + # prediction_texts = [] # for input_ids in batch['input_ids']: # for batch_item_idx, (input_ids, labels) 
in enumerate(zip(batch['input_ids'], logits, labels)): - for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'], batch_logits, batch_labels)): - tokens_without_loss = (labels == IGNORE_INDEX) - tokens_with_loss = (labels != IGNORE_INDEX) - tokens_exclude_padding = (input_ids != tokenizer.pad_token_id) + # for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'].to(device), batch_logits, batch_labels)): + for batch_item_idx, (input_ids_all, labels_all, pos_ids) in enumerate(zip(batch_input_ids, batch_labels, batch_pos_ids)): + + if pos_ids is None: + pos_ranges = [(0, len(input_ids_all)-1)] + else: + pos_ranges = find_ranges(pos_ids) + + for pos_range in pos_ranges: + start, end = pos_range + if start == end: + continue + + input_ids = input_ids_all[start:end+1] + labels = labels_all[start:end+1] + # input_ids[start:end] = tokenizer.pad_token_id + + tokens_without_loss = (labels == IGNORE_INDEX) + tokens_with_loss = (labels != IGNORE_INDEX) + tokens_exclude_padding = (input_ids != tokenizer.pad_token_id) - prompt_token_includes = tokens_without_loss & tokens_exclude_padding + prompt_token_includes = tokens_without_loss & tokens_exclude_padding - prompt_token_ids = input_ids[prompt_token_includes] - prompt_token_ids_list.append(prompt_token_ids) + prompt_token_ids = input_ids[prompt_token_includes] + prompt_token_ids_list.append(prompt_token_ids) - completion_token_ids = input_ids[tokens_with_loss] - completion_text = tokenizer.decode(completion_token_ids) - completion_texts.append(completion_text) + completion_token_ids = input_ids[tokens_with_loss] + completion_token_ids_list.append(completion_token_ids) + # completion_text = tokenizer.decode(completion_token_ids) + # completion_texts.append(completion_text) - completion_logit = logits[tokens_with_loss] - predicted_tokens = logits_to_tokens(completion_logit) - prediction_text = tokenizer.decode(predicted_tokens) - prediction_texts.append(prediction_text) + # completion_logit = logits[tokens_with_loss] + # predicted_tokens = logits_to_tokens(completion_logit) + # prediction_text = tokenizer.decode(predicted_tokens) + # prediction_texts.append(prediction_text) prompt_texts = tokenizer.batch_decode(prompt_token_ids_list, skip_special_tokens=True) + completion_texts = tokenizer.batch_decode(completion_token_ids_list, skip_special_tokens=True) with torch.no_grad(): generation_config = GenerationConfig( @@ -413,7 +451,7 @@ def log_table_from_dataloader(name: str, table_dataloader): ) encoding = tokenizer(prompt_texts, padding=True, return_tensors='pt').to(self.cfg.device) - new_predictions = trainer.model.generate(**encoding, generation_config=generation_config) + new_predictions = trainer.model.generate(**encoding, generation_config=generation_config) # FIXME: when sample_packing=True then error: "TypeError: varlen_fwd(): incompatible function arguments." 
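As a quick sanity check on the `find_ranges` splitting used in this hunk, here is a small standalone sketch; the `position_ids` values are invented for illustration. Each packed sample restarts its position ids at 0, so every zero after the first element marks the start of a new sub-sequence.

```python
# Standalone check of the position_ids-based splitting used above.
# The position_ids values are illustrative, not taken from a real batch.
def find_ranges(lst):
    ranges = []
    start = 0
    for i in range(1, len(lst)):
        if lst[i] == 0:  # a zero position id marks the start of a new packed sample
            ranges.append((start, i - 1))
            start = i
    ranges.append((start, len(lst) - 1))  # close the final range
    return ranges

# Two samples of lengths 4 and 3 packed into one row:
print(find_ranges([0, 1, 2, 3, 0, 1, 2]))  # [(0, 3), (4, 6)]

# An unpacked row (no internal resets) yields one range covering everything:
print(find_ranges([0, 1, 2, 3, 4]))        # [(0, 4)]
```

The loop above skips ranges where `start == end`, presumably to avoid emitting degenerate single-token segments.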
new_prediction_all_tokens = new_predictions["sequences"].cpu().tolist() new_prediction_without_prompt_tokens_list = [] @@ -423,7 +461,9 @@ def log_table_from_dataloader(name: str, table_dataloader): new_predicted_texts = tokenizer.batch_decode(new_prediction_without_prompt_tokens_list, skip_special_tokens=True) - for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)): + # for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)): + for i, (prompt_text, completion_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, new_predicted_texts)): + prediction_text = "" table.add_data(i, prompt_text, completion_text, prediction_text, new_predicted_text) batch_index += 1 From e9eae77ba77e6dd549d4ec5237d4b725e4780a56 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Fri, 8 Sep 2023 10:21:16 +0000 Subject: [PATCH 10/22] WIP add code for debugging --- .vscode/launch.json | 15 +++++++++++++++ docker-compose.yaml | 2 ++ examples/llama-2/tiny-random.yml | 18 +++++++++++------- scripts/finetune.py | 5 +++++ src/axolotl/utils/callbacks.py | 12 ++++++++++++ 5 files changed, 45 insertions(+), 7 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 8af2261000..e116653768 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,21 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Python: Remote Attach", + "type": "python", + "request": "attach", + "connect": { + "host": "0.0.0.0", + "port": 5678 + }, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "/workspace/axolotl/" + } + ] + }, { "name": "train", "type": "python", diff --git a/docker-compose.yaml b/docker-compose.yaml index 187d567f84..143d5f6c91 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -11,6 +11,8 @@ services: # set environment variables environment: - WANDB_API_KEY=${WANDB_API_KEY} + ports: + - "5678:5678" deploy: resources: reservations: diff --git a/examples/llama-2/tiny-random.yml b/examples/llama-2/tiny-random.yml index d5e88c7fa4..4a841f4609 100644 --- a/examples/llama-2/tiny-random.yml +++ b/examples/llama-2/tiny-random.yml @@ -19,13 +19,14 @@ strict: false datasets: # - path: mhenrichsen/alpaca_2k_test # type: alpaca - # - path: teknium/GPT4-LLM-Cleaned - # type: alpaca - - path: Glavin001/startup-interviews + - path: teknium/GPT4-LLM-Cleaned type: alpaca + # - path: Glavin001/startup-interviews + # type: alpaca dataset_prepared_path: last_run_prepared # val_set_size: 0.01 -val_set_size: 0.1 +val_set_size: 0.001 +# val_set_size: 0.1 # output_dir: ./lora-out # output_dir: ./lora-2-out output_dir: ./lora-5-out @@ -35,7 +36,7 @@ sequence_len: 2048 # sequence_len: 256 # sequence_len: 512 # sample_packing: true -sample_packing: false +sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py adapter: lora lora_model_dir: @@ -54,11 +55,13 @@ wandb_log_model: gradient_accumulation_steps: 4 # micro_batch_size: 2 micro_batch_size: 16 +# micro_batch_size: 24 +# micro_batch_size: 24 # num_epochs: 3 # num_epochs: 0.001 # num_epochs: 0.01 -# num_epochs: 1 -num_epochs: 5 +num_epochs: 1 +# num_epochs: 5 optimizer: adamw_bnb_8bit lr_scheduler: cosine learning_rate: 0.0002 @@ -81,6 +84,7 @@ warmup_steps: 10 eval_steps: 10 # eval_steps: 20 # 
eval_steps: 2 +# eval_steps: 1 save_steps: debug: deepspeed: diff --git a/scripts/finetune.py b/scripts/finetune.py index b998edc798..91e6d5d844 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -28,6 +28,11 @@ from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.wandb import setup_wandb_env_vars +# import debugpy +# debugpy.listen(('0.0.0.0', 5678)) +# debugpy.wait_for_client() +# debugpy.breakpoint() + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 657cb6d7f5..6ee0243324 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -6,6 +6,7 @@ import os from typing import TYPE_CHECKING, Dict, List +import itertools import evaluate import numpy as np import pandas as pd @@ -483,3 +484,14 @@ def log_table_from_dataloader(name: str, table_dataloader): return control return LogPredictionCallback + + +def group_sublists_by(lst: List[int], value: int) -> List[List[int]]: + """ + Group sublists of lst by value + """ + grouped = [] + for key, group in itertools.groupby(lst, lambda x: x == value): + if key: + grouped.append(list(group)) + return grouped From 83e6b29fe896c4af54fe7c3572f7ed6c15874f53 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Sat, 9 Sep 2023 06:48:41 +0000 Subject: [PATCH 11/22] Fix sample_packing support for wandb prediction table --- .vscode/launch.json | 3 ++- examples/llama-2/tiny-random.yml | 26 +++++++++++-------- .../monkeypatch/llama_attn_hijack_flash.py | 6 ++++- src/axolotl/utils/callbacks.py | 5 ++-- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index e116653768..e264f9d69f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -17,7 +17,8 @@ "localRoot": "${workspaceFolder}", "remoteRoot": "/workspace/axolotl/" } - ] + ], + "justMyCode": false }, { "name": "train", diff --git a/examples/llama-2/tiny-random.yml b/examples/llama-2/tiny-random.yml index 4a841f4609..138aab2963 100644 --- a/examples/llama-2/tiny-random.yml +++ b/examples/llama-2/tiny-random.yml @@ -19,24 +19,27 @@ strict: false datasets: # - path: mhenrichsen/alpaca_2k_test # type: alpaca - - path: teknium/GPT4-LLM-Cleaned - type: alpaca - # - path: Glavin001/startup-interviews + # - path: teknium/GPT4-LLM-Cleaned # type: alpaca + - path: Glavin001/startup-interviews + type: alpaca dataset_prepared_path: last_run_prepared # val_set_size: 0.01 -val_set_size: 0.001 +val_set_size: 0.02 +# val_set_size: 0.05 +# val_set_size: 0.001 # val_set_size: 0.1 # output_dir: ./lora-out # output_dir: ./lora-2-out -output_dir: ./lora-5-out +output_dir: ./lora-6-out # sequence_len: 4096 -sequence_len: 2048 +# sequence_len: 2048 # sequence_len: 256 # sequence_len: 512 -# sample_packing: true -sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py +sequence_len: 1024 +sample_packing: true +# sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py adapter: lora lora_model_dir: @@ -60,8 +63,9 @@ micro_batch_size: 16 # num_epochs: 3 # num_epochs: 0.001 # num_epochs: 0.01 -num_epochs: 1 +# num_epochs: 1 # num_epochs: 5 +num_epochs: 10 optimizer: adamw_bnb_8bit lr_scheduler: cosine learning_rate: 0.0002 @@ -81,9 +85,9 @@ xformers_attention: flash_attention: true warmup_steps: 10 -eval_steps: 10 +# eval_steps: 10 # eval_steps: 20 -# eval_steps: 2 +eval_steps: 2 # 
eval_steps: 1 save_steps: debug: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 39cfb5c173..d90d5e5497 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -194,7 +194,11 @@ def flashattn_forward( # only on first autoregressive step q,k,v have same seqlen is_causal = key_states.shape == query_states.shape - if cu_seqlens is not None and max_seqlen is not None: + # if cu_seqlens is not None and max_seqlen is not None: + # if cu_seqlens is not None and max_seqlen is not None and self.training: + # if cu_seqlens is not None and max_seqlen is not None and query_states.shape == key_states.shape: + # if cu_seqlens is not None and max_seqlen is not None and len(cu_seqlens[0]) > 2: + if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: # special handling using sample packing qkv = torch.stack( [query_states, key_states, value_states], dim=2 diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 6ee0243324..8ab19707bb 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -438,8 +438,9 @@ def log_table_from_dataloader(name: str, table_dataloader): with torch.no_grad(): generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=32, + # repetition_penalty=1.1, + max_new_tokens=128, + # max_new_tokens=32, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, From aaf4d1e795b38e2520e26f244be9bc7dc50e099c Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Sat, 9 Sep 2023 07:41:46 +0000 Subject: [PATCH 12/22] Clean up code for PR review --- .vscode/launch.json | 2 +- .../monkeypatch/llama_attn_hijack_flash.py | 4 - src/axolotl/utils/callbacks.py | 162 +++++++----------- 3 files changed, 67 insertions(+), 101 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index e264f9d69f..3f60f05d36 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -34,4 +34,4 @@ // "env": {"CUDA_LAUNCH_BLOCKING": "1"} }, ] -} \ No newline at end of file +} diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index d90d5e5497..6b3fd4355e 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -194,10 +194,6 @@ def flashattn_forward( # only on first autoregressive step q,k,v have same seqlen is_causal = key_states.shape == query_states.shape - # if cu_seqlens is not None and max_seqlen is not None: - # if cu_seqlens is not None and max_seqlen is not None and self.training: - # if cu_seqlens is not None and max_seqlen is not None and query_states.shape == key_states.shape: - # if cu_seqlens is not None and max_seqlen is not None and len(cu_seqlens[0]) > 2: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: # special handling using sample packing qkv = torch.stack( diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 8ab19707bb..2272ce0a2b 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -6,7 +6,6 @@ import os from typing import TYPE_CHECKING, Dict, List -import itertools import evaluate import numpy as np import pandas as pd @@ -16,12 +15,12 @@ from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( + GenerationConfig, Trainer, 
TrainerCallback, TrainerControl, TrainerState, TrainingArguments, - GenerationConfig, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy @@ -324,8 +323,6 @@ def on_evaluate( def log_prediction_callback_factory(trainer: Trainer, tokenizer): - LOG.info("log_prediction_callback_factory") - class LogPredictionCallback(TrainerCallback): """Callback to log prediction values during each evaluation""" @@ -338,70 +335,53 @@ def on_evaluate( args: AxolotlTrainingArguments, state: TrainerState, control: TrainerControl, - model, - # tokenizer, train_dataloader, eval_dataloader, **kwargs, ): - LOG.info("=" * 80) - LOG.info("logging predictions") - trainer.model.eval() device = torch.device(self.cfg.device) - def logits_to_tokens(logits) -> str: - probabilities = torch.softmax(logits, dim=-1) - # Get the predicted token ids (the ones with the highest probability) - predicted_token_ids = torch.argmax(probabilities, dim=-1) - return predicted_token_ids - def find_ranges(lst): ranges = [] start = 0 for i in range(1, len(lst)): if lst[i] == 0: - ranges.append((start, i-1)) + ranges.append((start, i - 1)) start = i - ranges.append((start, len(lst)-1)) # for the last range + end = len(lst) - 1 + ranges.append((start, end)) return ranges def log_table_from_dataloader(name: str, table_dataloader): + table = wandb.Table( + columns=[ + "id", + "Prompt", + "Correct Completion", + "Predicted Completion", + ] + ) + row_index = 0 + max_new_tokens = 128 - # Initialize an empty wandb.Table - table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion 1", "Predicted Completion 2"]) - - batch_index = 0 for batch in tqdm(table_dataloader, total=len(table_dataloader)): - # For each batch I want prompt, completion, 2x predictions - - # (loss, logits, labels) = trainer.prediction_step( - # (batch_loss, batch_logits, batch_labels) = trainer.prediction_step( - # trainer.model, - # batch, - # prediction_loss_only=False, - # ) - - batch_labels = batch['labels'].to(device) - batch_input_ids = batch['input_ids'].to(device) + batch_labels = batch["labels"].to(device) + batch_input_ids = batch["input_ids"].to(device) - if 'position_ids' in batch: - batch_pos_ids = batch['position_ids'].tolist() + if "position_ids" in batch: + batch_pos_ids = batch["position_ids"].tolist() else: - batch_pos_ids = [None] * len(batch['input_ids']) + batch_pos_ids = [None] * len(batch["input_ids"]) prompt_token_ids_list = [] completion_token_ids_list = [] - # completion_texts = [] - # prediction_texts = [] - - # for input_ids in batch['input_ids']: - # for batch_item_idx, (input_ids, labels) in enumerate(zip(batch['input_ids'], logits, labels)): - # for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'].to(device), batch_logits, batch_labels)): - for batch_item_idx, (input_ids_all, labels_all, pos_ids) in enumerate(zip(batch_input_ids, batch_labels, batch_pos_ids)): + for input_ids_all, labels_all, pos_ids in zip( + batch_input_ids, batch_labels, batch_pos_ids + ): if pos_ids is None: - pos_ranges = [(0, len(input_ids_all)-1)] + pos_ranges = [(0, len(input_ids_all) - 1)] else: pos_ranges = find_ranges(pos_ids) @@ -410,37 +390,33 @@ def log_table_from_dataloader(name: str, table_dataloader): if start == end: continue - input_ids = input_ids_all[start:end+1] - labels = labels_all[start:end+1] - # input_ids[start:end] = tokenizer.pad_token_id + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] - tokens_without_loss = (labels == 
IGNORE_INDEX) - tokens_with_loss = (labels != IGNORE_INDEX) - tokens_exclude_padding = (input_ids != tokenizer.pad_token_id) + tokens_without_loss = labels == IGNORE_INDEX + tokens_with_loss = labels != IGNORE_INDEX + tokens_exclude_padding = input_ids != tokenizer.pad_token_id - prompt_token_includes = tokens_without_loss & tokens_exclude_padding + prompt_token_includes = ( + tokens_without_loss & tokens_exclude_padding + ) prompt_token_ids = input_ids[prompt_token_includes] prompt_token_ids_list.append(prompt_token_ids) completion_token_ids = input_ids[tokens_with_loss] completion_token_ids_list.append(completion_token_ids) - # completion_text = tokenizer.decode(completion_token_ids) - # completion_texts.append(completion_text) - - # completion_logit = logits[tokens_with_loss] - # predicted_tokens = logits_to_tokens(completion_logit) - # prediction_text = tokenizer.decode(predicted_tokens) - # prediction_texts.append(prediction_text) - prompt_texts = tokenizer.batch_decode(prompt_token_ids_list, skip_special_tokens=True) - completion_texts = tokenizer.batch_decode(completion_token_ids_list, skip_special_tokens=True) + prompt_texts = tokenizer.batch_decode( + prompt_token_ids_list, skip_special_tokens=True + ) + completion_texts = tokenizer.batch_decode( + completion_token_ids_list, skip_special_tokens=True + ) with torch.no_grad(): generation_config = GenerationConfig( - # repetition_penalty=1.1, - max_new_tokens=128, - # max_new_tokens=32, + max_new_tokens=max_new_tokens, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, @@ -452,47 +428,41 @@ def log_table_from_dataloader(name: str, table_dataloader): output_scores=False, ) - encoding = tokenizer(prompt_texts, padding=True, return_tensors='pt').to(self.cfg.device) - new_predictions = trainer.model.generate(**encoding, generation_config=generation_config) # FIXME: when sample_packing=True then error: "TypeError: varlen_fwd(): incompatible function arguments." 
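For readers following the masking a few lines above: `labels == IGNORE_INDEX` selects tokens that carry no loss (prompt and padding), intersecting that with `input_ids != tokenizer.pad_token_id` leaves only the prompt, and `labels != IGNORE_INDEX` selects the ground-truth completion. A tiny illustrative example with invented token ids and the conventional `IGNORE_INDEX = -100`:

```python
# Illustrative only: token ids, labels and pad id are invented to show the masking.
import torch

IGNORE_INDEX = -100
pad_token_id = 0

input_ids = torch.tensor([1, 15, 27, 42, 9, 2, 0, 0])             # prompt | completion | padding
labels = torch.tensor([-100, -100, -100, 42, 9, 2, -100, -100])

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != pad_token_id

prompt_token_ids = input_ids[tokens_without_loss & tokens_exclude_padding]
completion_token_ids = input_ids[tokens_with_loss]

print(prompt_token_ids.tolist())      # [1, 15, 27] -> decodes to the prompt text
print(completion_token_ids.tolist())  # [42, 9, 2]  -> decodes to the expected completion
```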
- - new_prediction_all_tokens = new_predictions["sequences"].cpu().tolist() - new_prediction_without_prompt_tokens_list = [] - for prompt_token_ids, new_prediction_tokens in zip(prompt_token_ids_list, new_prediction_all_tokens): - new_prediction_without_prompt_tokens = new_prediction_tokens[len(prompt_token_ids):] - new_prediction_without_prompt_tokens_list.append(new_prediction_without_prompt_tokens) - - new_predicted_texts = tokenizer.batch_decode(new_prediction_without_prompt_tokens_list, skip_special_tokens=True) + prompt_encoding = tokenizer( + prompt_texts, padding=True, return_tensors="pt" + ).to(self.cfg.device) + predictions = trainer.model.generate( + **prompt_encoding, generation_config=generation_config + ) - # for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)): - for i, (prompt_text, completion_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, new_predicted_texts)): - prediction_text = "" - table.add_data(i, prompt_text, completion_text, prediction_text, new_predicted_text) + prediction_all_tokens = predictions["sequences"].cpu().tolist() + prediction_without_prompt_tokens_list = [] + for prompt_token_ids, prediction_tokens in zip( + prompt_token_ids_list, prediction_all_tokens + ): + prediction_without_prompt_tokens = prediction_tokens[ + len(prompt_token_ids) : + ] + prediction_without_prompt_tokens_list.append( + prediction_without_prompt_tokens + ) - batch_index += 1 + predicted_texts = tokenizer.batch_decode( + prediction_without_prompt_tokens_list, skip_special_tokens=True + ) - wandb.run.log({ f"{name} - Predictions vs Ground Truth": table }) + for prompt_text, completion_text, prediction_text in zip( + prompt_texts, completion_texts, predicted_texts + ): + table.add_data( + row_index, prompt_text, completion_text, prediction_text + ) + row_index += 1 - # log_table_from_dataloader("Train", train_dataloader) - # log_table_from_dataloader("Train", train_dataloader) + wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) - # # Get first 10 records from train_dataloader as a new dataloader - # train_data_subset = [next(iter(train_dataloader)) for _ in range(10)] - # train_dataloader_subset = torch.utils.data.DataLoader(train_data_subset, batch_size=train_dataloader.batch_size, shuffle=False) - # log_table_from_dataloader("Train Subset", train_dataloader_subset) - log_table_from_dataloader("Eval", eval_dataloader) return control return LogPredictionCallback - - -def group_sublists_by(lst: List[int], value: int) -> List[List[int]]: - """ - Group sublists of lst by value - """ - grouped = [] - for key, group in itertools.groupby(lst, lambda x: x == value): - if key: - grouped.append(list(group)) - return grouped From e4c1a2e16f4cc19e650d72c1befec43428dedbc4 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Sat, 9 Sep 2023 21:29:47 +0000 Subject: [PATCH 13/22] WIP Add AutoGPTQ quantization script --- scripts/finetune.py | 21 ++- scripts/quantize.py | 281 ++++++++++++++++++++++++++++++++++++ src/axolotl/utils/config.py | 5 + 3 files changed, 302 insertions(+), 5 deletions(-) create mode 100644 scripts/quantize.py diff --git a/scripts/finetune.py b/scripts/finetune.py index b998edc798..1ea18b98b8 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -57,27 +57,38 @@ def get_multi_line_input() -> Optional[str]: # instruction = pathlib.Path("/proc/self/fd/0").read_text() return instruction +def get_merged_out_dir(cfg: 
DictDefault): + return Path(cfg.output_dir) / "merged" -def do_merge_lora( +def do_merge_lora_model_and_tokenizer( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + model, + tokenizer, ): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) safe_serialization = cfg.save_safetensors is True LOG.info("running merge of LoRA with base model") model = model.merge_and_unload() model.to(dtype=torch.float16) + merged_out_dir = str(get_merged_out_dir(cfg)) + if cfg.local_rank == 0: LOG.info("saving merged model") model.save_pretrained( - str(Path(cfg.output_dir) / "merged"), + merged_out_dir, safe_serialization=safe_serialization, ) - tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) + tokenizer.save_pretrained(merged_out_dir) +def do_merge_lora( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, +): + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) def shard( *, diff --git a/scripts/quantize.py b/scripts/quantize.py new file mode 100644 index 0000000000..56106cc655 --- /dev/null +++ b/scripts/quantize.py @@ -0,0 +1,281 @@ +# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ + +# import debugpy +# debugpy.listen(('0.0.0.0', 5678)) +# debugpy.wait_for_client() +# debugpy.breakpoint() + +import json +import random +import time +from pathlib import Path +import logging + +import torch +from datasets import load_dataset, Dataset +from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from axolotl.prompters import AlpacaPrompter +from axolotl.utils.models import load_model, load_tokenizer +from axolotl.common.cli import TrainerCliArgs +from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +# from scripts.finetune import load_cfg +from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer, load_datasets + +configure_logging() +LOG = logging.getLogger("axolotl") + +# logging.basicConfig( +# format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S" +# ) + +# LOG.setLevel(logging.DEBUG) +# handler = logging.StreamHandler() +# formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') +# handler.setFormatter(formatter) +# LOG.addHandler(handler) + +print("Done importing...") + +## CHANGE BELOW ## +config_path: Path = Path("./examples/llama-2/lora.yml") + +# pretrained_model_dir = "facebook/opt-125m" +# quantized_model_dir = "opt-125m-4bit" +dataset_name = "teknium/GPT4-LLM-Cleaned" +# huggingface_username = "CHANGE_ME" +## CHANGE ABOVE + +quantize_config = BaseQuantizeConfig( + bits=4, # quantize model to 4-bit + group_size=128, # it is recommended to set the value to 128 + desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad +) + +# TEMPLATE = "<|prompt|>{instruction}<|answer|>" +prompter = AlpacaPrompter() + +# def load_data(data_path, tokenizer, n_samples, template=TEMPLATE): +def load_data(data_path, tokenizer, n_samples): + # Load dataset + dataset = load_dataset(data_path) + + if "train" in dataset: + raw_data = dataset["train"] + else: + raw_data = dataset + + # Sample from the dataset if n_samples is provided and less than the dataset size + if n_samples is not None and n_samples < len(raw_data): + raw_data = 
raw_data.shuffle(seed=42).select(range(n_samples)) + + def tokenize(examples): + instructions = examples["instruction"] + outputs = examples["output"] + + prompts = [] + texts = [] + input_ids = [] + attention_mask = [] + for input_text, output_text in zip(instructions, outputs): + # prompt = template.format(instruction=input_text) + # prompt = next(prompter.build_prompt(instruction=input_text, output=output_text)) + prompt = next(prompter.build_prompt(instruction=input_text)) + text = prompt + output_text + + if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length: + continue + + tokenized_data = tokenizer(text) + + input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length]) + attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length]) + prompts.append(prompt) + texts.append(text) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prompt": prompts, + "text": texts, + } + + raw_data = raw_data.map( + tokenize, + batched=True, + batch_size=len(raw_data), + num_proc=1, + keep_in_memory=True, + load_from_cache_file=False, + # remove_columns=["instruction", "input"] + ) + + # Convert to PyTorch tensors + raw_data.set_format(type='torch', columns=['input_ids', 'attention_mask']) + + # for sample in dataset: + # sample["input_ids"] = torch.LongTensor(sample["input_ids"]) + # sample["attention_mask"] = torch.LongTensor(sample["attention_mask"]) + + return raw_data + + +# def get_tokenizer(): +# print("Loading tokenizer...") +# # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) +# tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) +# return tokenizer + +# def get_model(): +def load_merged_model(cfg: DictDefault): + print("Loading model...") + + merged_out_dir = get_merged_out_dir(cfg) + + # Check if the merged model exists + if not merged_out_dir.exists(): + # If not, merge the model + print("Merged model not found. 
Merging...") + # model, tokenizer = load_model(cfg, inference=True) + # do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) + raise NotImplementedError("Merging model is not implemented yet.") + + # load un-quantized model, by default, the model will always be loaded into CPU memory + model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config) + # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + print("Model loaded.") + return model + +def get_quantized_model(cfg: DictDefault): + print("Loading quantized model...") + quantized_model_dir = get_quantized_model_dir(cfg) + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_safetensors=True) + print("Model loaded.") + return model + +def quantize_and_save(cfg: DictDefault, model, tokenizer, examples_for_quant): + print("Quantize...") + start = time.time() + # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" + model.quantize( + examples_for_quant, + batch_size=1, + # batch_size=args.quant_batch_size, + # use_triton=args.use_triton, + # autotune_warmup_after_quantized=args.use_triton + ) + end = time.time() + print(f"quantization took: {end - start: .4f}s") + + # save quantized model + print("Saving quantized model...") + # model.save_quantized(quantized_model_dir) + quantized_model_dir = get_quantized_model_dir(cfg) + model.save_quantized(quantized_model_dir, use_safetensors=True) + print("Saving tokenizer...") + tokenizer.save_pretrained(quantized_model_dir) + print("Saved.") + + return model + +def push_model(cfg: DictDefault, model, tokenizer): +# def push_model(model): + # push quantized model to Hugging Face Hub. + # to use use_auth_token=True, Login first via huggingface-cli login. 
+ # or pass explcit token with: use_auth_token="hf_xxxxxxx" + # (uncomment the following three lines to enable this feature) + # repo_id = f"YourUserName/{quantized_model_dir}" + print("Pushing to Huggingface hub...") + # repo_id = f"{huggingface_username}/{quantized_model_dir}" + repo_id = get_quantized_model_id(cfg) + pretrained_model_dir = cfg['base_model'] + commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}" + # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True, safe_serialization=True) + # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, safe_serialization=True) + model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True) + tokenizer.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True) + print("Pushed.") + +# def push_tokenizer(tokenizer): + +def get_quantized_model_id(cfg: DictDefault): +# def get_quantized_model_id(cfg: DictDefault, quantize_config): + # return f"{cfg.hub_model_id}-{quantize_config.bits}bits-gr{quantize_config.group_size}-desc_act{quantize_config.desc_act}" + if not cfg.hub_model_id: + raise ValueError("Missing hub_model_id in the configuration.") + return f"{cfg.hub_model_id}-GPTQ" + +def get_quantized_model_dir(cfg: DictDefault): +# def get_quantized_model_dir(cfg: DictDefault, quantize_config): + if not cfg.output_dir: + raise ValueError("Missing output_dir in the configuration.") + return f"{cfg.output_dir.lstrip('./')}-GPTQ" + +def main(): + print("Starting...") + # return + # prompt = "<|prompt|>How can entrepreneurs start building their own communities even before launching their product?<|answer|>" + + should_quantize = False + # tokenizer = get_tokenizer() + + cfg = load_cfg(config_path) + + cfg['lora_model_dir'] = cfg['output_dir'] + + LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + if should_quantize: + print("Quantizing...") + + print("Loading dataset...") + datasets = load_datasets(cfg=cfg, cli_args=TrainerCliArgs()) + train_dataset = datasets.train_dataset + n_samples = 128 + # n_samples = 2 + examples = train_dataset.shuffle(seed=42).select( + [ + random.randrange(0, len(train_dataset) - 1) # nosec + for _ in range(n_samples) + ] + ) + + LOG.info("loading model and (optionally) peft_config...") + # model, peft_config = load_model(cfg, tokenizer, inference=True) + model = load_merged_model(cfg) + # model = get_model() + + # examples = load_data(dataset_name, tokenizer, n_samples) + + # print(examples) + examples_for_quant = [ + {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} + for example in examples + ] + # print(examples_for_quant) + + modelq = quantize_and_save(cfg, model, tokenizer, examples_for_quant) + else: + print("Loading quantized model...") + modelq = get_quantized_model(cfg) + + push_model(cfg, modelq, tokenizer) + +main() + + +# Load configure +# Load dataset +# Load tokenizer +# Prepare database +# Load previous model, final checkpoint + + +# --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False +# accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False +# CUDA_VISIBLE_DEVICES="1" accelerate launch ./scripts/finetune.py ./examples/llama-2/lora.yml --merge_lora --lora_model_dir="./lora-out" --load_in_8bit=False --load_in_4bit=False + +# HUB_MODEL_ID="Glavin001/llama-2-7b-alpaca_2k_test" accelerate launch ./scripts/quantize.py + diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 6de807eab9..4916fe2f6b 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -82,6 +82,11 @@ def normalize_config(cfg): log_gpu_memory_usage(LOG, "baseline", cfg.device) + if os.environ.get("WANDB_PROJECT") and len(os.environ.get("WANDB_PROJECT", "")) > 0: + cfg.wandb_project = os.environ.get("WANDB_PROJECT") + + if os.environ.get("HUB_MODEL_ID") and len(os.environ.get("HUB_MODEL_ID", "")) > 0: + cfg.hub_model_id = os.environ.get("HUB_MODEL_ID") def validate_config(cfg): if cfg.max_packed_sequence_len and cfg.sample_packing: From 19a30cfedec0c7ceb7711a33b0d42f40bf29658e Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Sun, 10 Sep 2023 22:34:37 +0000 Subject: [PATCH 14/22] WIP Integrate quantization into finetune script --- examples/llama-2/lora-short.yml | 70 +++++++++++ scripts/finetune.py | 45 +++++++- scripts/quantize.py | 198 +++----------------------------- src/axolotl/utils/quantize.py | 132 +++++++++++++++++++++ 4 files changed, 263 insertions(+), 182 deletions(-) create mode 100644 examples/llama-2/lora-short.yml create mode 100644 src/axolotl/utils/quantize.py diff --git a/examples/llama-2/lora-short.yml b/examples/llama-2/lora-short.yml new file mode 100644 index 0000000000..bd2b51b962 --- /dev/null +++ b/examples/llama-2/lora-short.yml @@ -0,0 +1,70 @@ +base_model: meta-llama/Llama-2-7b-hf +base_model_config: meta-llama/Llama-2-7b-hf +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +# val_set_size: 0.01 +val_set_size: 0.001 +output_dir: ./lora-out + 
+sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +# num_epochs: 3 +# num_epochs: 1 +num_epochs: 0.1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/scripts/finetune.py b/scripts/finetune.py index 1ea18b98b8..ecef38bba0 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -27,6 +27,7 @@ from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.wandb import setup_wandb_env_vars +from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, quantize_and_save project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -87,8 +88,14 @@ def do_merge_lora( cfg: DictDefault, cli_args: TrainerCliArgs, ): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) + new_cfg = DictDefault({ + **cfg, + 'lora_model_dir': cfg['output_dir'], + 'load_in_8bit': False, + 'load_in_4bit': False, + }) + model, tokenizer = load_model_and_tokenizer(cfg=new_cfg, cli_args=cli_args) + do_merge_lora_model_and_tokenizer(cfg=new_cfg, model=model, tokenizer=tokenizer) def shard( *, @@ -282,7 +289,39 @@ def do_cli(config: Path = Path("examples/"), **kwargs): dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.prepare_ds_only: return - train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + # tokenizer = None + should_quantize = True + + if should_quantize: + # Merge model + # do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + # do_merge_lora_model_and_tokenizer(cfg=parsed_cfg, model=model, tokenizer=tokenizer) + # new_cfg = parsed_cfg.copy() + # new_cfg['lora_model_dir'] = new_cfg['output_dir'] + # new_cfg['load_in_8bit'] = False + # new_cfg['load_in_4bit'] = False + + # new_cfg = DictDefault({ + # **parsed_cfg, + # 'lora_model_dir': parsed_cfg['output_dir'], + # 'load_in_8bit': False, + # 'load_in_4bit': False, + # }) + # lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False + # do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args) + + # TODO: release old model from GPU memory + do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + + # Load merged model with AutoGPTQ + merged_model = load_merged_model(parsed_cfg) + + # Quantize & save + n_samples = 128 + examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples) + quantize_and_save(parsed_cfg, merged_model, tokenizer, examples) + if __name__ == "__main__": diff --git a/scripts/quantize.py b/scripts/quantize.py index 56106cc655..ff382b7da1 100644 --- a/scripts/quantize.py +++ b/scripts/quantize.py @@ -23,6 
+23,8 @@ # from scripts.finetune import load_cfg from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer, load_datasets +from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization + configure_logging() LOG = logging.getLogger("axolotl") @@ -39,7 +41,8 @@ print("Done importing...") ## CHANGE BELOW ## -config_path: Path = Path("./examples/llama-2/lora.yml") +# config_path: Path = Path("./examples/llama-2/lora.yml") +config_path: Path = Path("./examples/llama-2/lora-short.yml") # pretrained_model_dir = "facebook/opt-125m" # quantized_model_dir = "opt-125m-4bit" @@ -47,177 +50,12 @@ # huggingface_username = "CHANGE_ME" ## CHANGE ABOVE -quantize_config = BaseQuantizeConfig( - bits=4, # quantize model to 4-bit - group_size=128, # it is recommended to set the value to 128 - desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad -) - -# TEMPLATE = "<|prompt|>{instruction}<|answer|>" -prompter = AlpacaPrompter() - -# def load_data(data_path, tokenizer, n_samples, template=TEMPLATE): -def load_data(data_path, tokenizer, n_samples): - # Load dataset - dataset = load_dataset(data_path) - - if "train" in dataset: - raw_data = dataset["train"] - else: - raw_data = dataset - - # Sample from the dataset if n_samples is provided and less than the dataset size - if n_samples is not None and n_samples < len(raw_data): - raw_data = raw_data.shuffle(seed=42).select(range(n_samples)) - - def tokenize(examples): - instructions = examples["instruction"] - outputs = examples["output"] - - prompts = [] - texts = [] - input_ids = [] - attention_mask = [] - for input_text, output_text in zip(instructions, outputs): - # prompt = template.format(instruction=input_text) - # prompt = next(prompter.build_prompt(instruction=input_text, output=output_text)) - prompt = next(prompter.build_prompt(instruction=input_text)) - text = prompt + output_text - - if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length: - continue - - tokenized_data = tokenizer(text) - - input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length]) - attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length]) - prompts.append(prompt) - texts.append(text) - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "prompt": prompts, - "text": texts, - } - - raw_data = raw_data.map( - tokenize, - batched=True, - batch_size=len(raw_data), - num_proc=1, - keep_in_memory=True, - load_from_cache_file=False, - # remove_columns=["instruction", "input"] - ) - - # Convert to PyTorch tensors - raw_data.set_format(type='torch', columns=['input_ids', 'attention_mask']) - - # for sample in dataset: - # sample["input_ids"] = torch.LongTensor(sample["input_ids"]) - # sample["attention_mask"] = torch.LongTensor(sample["attention_mask"]) - - return raw_data - - -# def get_tokenizer(): -# print("Loading tokenizer...") -# # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) -# tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) -# return tokenizer - -# def get_model(): -def load_merged_model(cfg: DictDefault): - print("Loading model...") - - merged_out_dir = get_merged_out_dir(cfg) - - # Check if the merged model exists - if not merged_out_dir.exists(): - # If not, merge the model - print("Merged model not found. 
Merging...") - # model, tokenizer = load_model(cfg, inference=True) - # do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) - raise NotImplementedError("Merging model is not implemented yet.") - - # load un-quantized model, by default, the model will always be loaded into CPU memory - model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config) - # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) - print("Model loaded.") - return model - -def get_quantized_model(cfg: DictDefault): - print("Loading quantized model...") - quantized_model_dir = get_quantized_model_dir(cfg) - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_safetensors=True) - print("Model loaded.") - return model - -def quantize_and_save(cfg: DictDefault, model, tokenizer, examples_for_quant): - print("Quantize...") - start = time.time() - # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" - model.quantize( - examples_for_quant, - batch_size=1, - # batch_size=args.quant_batch_size, - # use_triton=args.use_triton, - # autotune_warmup_after_quantized=args.use_triton - ) - end = time.time() - print(f"quantization took: {end - start: .4f}s") - - # save quantized model - print("Saving quantized model...") - # model.save_quantized(quantized_model_dir) - quantized_model_dir = get_quantized_model_dir(cfg) - model.save_quantized(quantized_model_dir, use_safetensors=True) - print("Saving tokenizer...") - tokenizer.save_pretrained(quantized_model_dir) - print("Saved.") - - return model - -def push_model(cfg: DictDefault, model, tokenizer): -# def push_model(model): - # push quantized model to Hugging Face Hub. - # to use use_auth_token=True, Login first via huggingface-cli login. 
- # or pass explcit token with: use_auth_token="hf_xxxxxxx" - # (uncomment the following three lines to enable this feature) - # repo_id = f"YourUserName/{quantized_model_dir}" - print("Pushing to Huggingface hub...") - # repo_id = f"{huggingface_username}/{quantized_model_dir}" - repo_id = get_quantized_model_id(cfg) - pretrained_model_dir = cfg['base_model'] - commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}" - # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True, safe_serialization=True) - # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, safe_serialization=True) - model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True) - tokenizer.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True) - print("Pushed.") - -# def push_tokenizer(tokenizer): - -def get_quantized_model_id(cfg: DictDefault): -# def get_quantized_model_id(cfg: DictDefault, quantize_config): - # return f"{cfg.hub_model_id}-{quantize_config.bits}bits-gr{quantize_config.group_size}-desc_act{quantize_config.desc_act}" - if not cfg.hub_model_id: - raise ValueError("Missing hub_model_id in the configuration.") - return f"{cfg.hub_model_id}-GPTQ" - -def get_quantized_model_dir(cfg: DictDefault): -# def get_quantized_model_dir(cfg: DictDefault, quantize_config): - if not cfg.output_dir: - raise ValueError("Missing output_dir in the configuration.") - return f"{cfg.output_dir.lstrip('./')}-GPTQ" - def main(): print("Starting...") # return # prompt = "<|prompt|>How can entrepreneurs start building their own communities even before launching their product?<|answer|>" - should_quantize = False + should_quantize = True # tokenizer = get_tokenizer() cfg = load_cfg(config_path) @@ -234,13 +72,13 @@ def main(): datasets = load_datasets(cfg=cfg, cli_args=TrainerCliArgs()) train_dataset = datasets.train_dataset n_samples = 128 - # n_samples = 2 - examples = train_dataset.shuffle(seed=42).select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(n_samples) - ] - ) + # # n_samples = 2 + # examples = train_dataset.shuffle(seed=42).select( + # [ + # random.randrange(0, len(train_dataset) - 1) # nosec + # for _ in range(n_samples) + # ] + # ) LOG.info("loading model and (optionally) peft_config...") # model, peft_config = load_model(cfg, tokenizer, inference=True) @@ -250,11 +88,12 @@ def main(): # examples = load_data(dataset_name, tokenizer, n_samples) # print(examples) - examples_for_quant = [ - {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} - for example in examples - ] + # examples_for_quant = [ + # {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} + # for example in examples + # ] # print(examples_for_quant) + examples_for_quant = get_examples_for_quantization(train_dataset, n_samples) modelq = quantize_and_save(cfg, model, tokenizer, examples_for_quant) else: @@ -263,7 +102,8 @@ def main(): push_model(cfg, modelq, tokenizer) -main() +if __name__ == "__main__": + main() # Load configure diff --git a/src/axolotl/utils/quantize.py b/src/axolotl/utils/quantize.py new file mode 100644 index 0000000000..f8214bda2a --- /dev/null +++ b/src/axolotl/utils/quantize.py @@ -0,0 +1,132 @@ +# pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ + +# import debugpy +# 
debugpy.listen(('0.0.0.0', 5678)) +# debugpy.wait_for_client() +# debugpy.breakpoint() + +import json +import random +import time +from pathlib import Path +import logging + +# import torch +# from datasets import load_dataset, Dataset +# from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from axolotl.prompters import AlpacaPrompter +from axolotl.utils.models import load_model, load_tokenizer +from axolotl.common.cli import TrainerCliArgs +from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +# from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer + +# configure_logging() +# LOG = logging.getLogger("axolotl") + +quantize_config = BaseQuantizeConfig( + bits=4, # quantize model to 4-bit + group_size=128, # it is recommended to set the value to 128 + desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad +) + +def get_merged_out_dir(cfg: DictDefault): + return Path(cfg.output_dir) / "merged" + +def load_merged_model(cfg: DictDefault): + print("Loading merged model...") + + merged_out_dir = get_merged_out_dir(cfg) + + # Check if the merged model exists + if not merged_out_dir.exists(): + # If not, merge the model + print("Merged model not found. Merging...") + # model, tokenizer = load_model(cfg, inference=True) + # do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) + raise NotImplementedError("Merging model is not implemented yet.") + + # load un-quantized model, by default, the model will always be loaded into CPU memory + model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config) + # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + print("Model loaded.") + return model + +def get_quantized_model(cfg: DictDefault): + print("Loading quantized model...") + quantized_model_dir = get_quantized_model_dir(cfg) + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_safetensors=True) + print("Model loaded.") + return model + +def quantize_and_save(cfg: DictDefault, model, tokenizer, examples_for_quant): + print("Quantize...") + start = time.time() + # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" + model.quantize( + examples_for_quant, + batch_size=1, + # batch_size=args.quant_batch_size, + # use_triton=args.use_triton, + # autotune_warmup_after_quantized=args.use_triton + ) + end = time.time() + print(f"quantization took: {end - start: .4f}s") + + # save quantized model + print("Saving quantized model...") + # model.save_quantized(quantized_model_dir) + quantized_model_dir = get_quantized_model_dir(cfg) + model.save_quantized(quantized_model_dir, use_safetensors=True) + print("Saving tokenizer...") + tokenizer.save_pretrained(quantized_model_dir) + print("Saved.") + + return model + +def push_model(cfg: DictDefault, model, tokenizer): +# def push_model(model): + # push quantized model to Hugging Face Hub. + # to use use_auth_token=True, Login first via huggingface-cli login. 
+ # or pass explcit token with: use_auth_token="hf_xxxxxxx" + # (uncomment the following three lines to enable this feature) + # repo_id = f"YourUserName/{quantized_model_dir}" + print("Pushing to Huggingface hub...") + # repo_id = f"{huggingface_username}/{quantized_model_dir}" + repo_id = get_quantized_model_id(cfg) + pretrained_model_dir = cfg['base_model'] + commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}" + # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True, safe_serialization=True) + # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, safe_serialization=True) + model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True, use_safetensors=True) + tokenizer.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True) + print("Pushed.") + +def get_quantized_model_id(cfg: DictDefault): +# def get_quantized_model_id(cfg: DictDefault, quantize_config): + # return f"{cfg.hub_model_id}-{quantize_config.bits}bits-gr{quantize_config.group_size}-desc_act{quantize_config.desc_act}" + if not cfg.hub_model_id: + raise ValueError("Missing hub_model_id in the configuration.") + return f"{cfg.hub_model_id}-GPTQ" + +def get_quantized_model_dir(cfg: DictDefault): +# def get_quantized_model_dir(cfg: DictDefault, quantize_config): + if not cfg.output_dir: + raise ValueError("Missing output_dir in the configuration.") + return f"{cfg.output_dir.lstrip('./')}-GPTQ" + +def get_examples_for_quantization(dataset, n_samples): + print("Loading dataset...") + examples = dataset.shuffle(seed=42).select( + [ + random.randrange(0, len(dataset) - 1) # nosec + for _ in range(n_samples) + ] + ) + + examples_for_quant = [ + {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]} + for example in examples + ] + return examples_for_quant From 894a4befcbda1adc232f7f8238509a04a762bc52 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Mon, 11 Sep 2023 07:49:36 +0000 Subject: [PATCH 15/22] Add --quantize option to finetune script, fix auto_gptq logging --- README.md | 2 +- scripts/finetune.py | 35 +++++++++++++++++++++++++++++++++-- scripts/quantize.py | 28 +++++++++++++++++++++++++++- src/axolotl/common/cli.py | 1 + src/axolotl/logging_config.py | 5 +++++ src/axolotl/utils/quantize.py | 8 ++++---- 6 files changed, 71 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 775592efe6..4189165577 100644 --- a/README.md +++ b/README.md @@ -703,7 +703,7 @@ Pass the appropriate flag to the train command: Add below flag to train command above ```bash ---merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False +--merge_lora --lora_model_dir="./completed-model" ``` If you run out of CUDA memory, you can try to merge in system RAM with diff --git a/scripts/finetune.py b/scripts/finetune.py index ecef38bba0..b02cd54912 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,5 +1,7 @@ """Prepare and train a model on a dataset. 
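As an aside, once `quantize_and_save` has written the artifact, it can be reloaded for a quick smoke test. The sketch below is not part of the patch; the directory path and prompt are illustrative, and the `from_quantized`/`use_safetensors` arguments mirror the `get_quantized_model` helper above.

```python
# Hedged sketch: reload the quantized model and generate once to verify it works.
# quantized_dir is illustrative; in the patch it comes from get_quantized_model_dir(cfg).
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextGenerationPipeline

quantized_dir = "lora-out-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(quantized_dir, use_fast=True)
model = AutoGPTQForCausalLM.from_quantized(
    quantized_dir, device="cuda:0", use_safetensors=True
)

pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
prompt = "### Instruction:\nSay hello.\n\n### Response:\n"  # Alpaca-style prompt, for illustration
print(pipeline(prompt, max_new_tokens=32)[0]["generated_text"])
```

If the reload succeeds and generation produces text, the saved weights and tokenizer in the quantized directory are consistent.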
Can also infer from a model or merge lora""" +import gc + import importlib import logging import os @@ -90,7 +92,7 @@ def do_merge_lora( ): new_cfg = DictDefault({ **cfg, - 'lora_model_dir': cfg['output_dir'], + 'lora_model_dir': cfg.get('lora_model_dir', cfg['output_dir']), 'load_in_8bit': False, 'load_in_4bit': False, }) @@ -285,11 +287,24 @@ def do_cli(config: Path = Path("examples/"), **kwargs): do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) elif parsed_cli_args.shard: shard(cfg=parsed_cfg, cli_args=parsed_cli_args) + elif parsed_cli_args.quantize: + dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + + tokenizer = load_tokenizer(parsed_cfg) + # Load merged model with AutoGPTQ + merged_model = load_merged_model(parsed_cfg) + + # Quantize & save + n_samples = 128 + examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples) + quantize_and_save(parsed_cfg, merged_model, tokenizer, examples) + else: dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.prepare_ds_only: return - model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + # model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) # tokenizer = None should_quantize = True @@ -310,8 +325,24 @@ def do_cli(config: Path = Path("examples/"), **kwargs): # }) # lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False # do_merge_lora(cfg=new_cfg, cli_args=parsed_cli_args) + + def log_gpu_memory(): + print("GPU Memory:", torch.cuda.memory_allocated()) + log_gpu_memory() + print(len(gc.get_referrers(model))) + print(sys.getrefcount(model)) + # TODO: release old model from GPU memory + print(gc.collect()) + del model + # del tokenizer + print(gc.collect()) + torch.cuda.empty_cache() + print(gc.collect()) + + log_gpu_memory() + do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) # Load merged model with AutoGPTQ diff --git a/scripts/quantize.py b/scripts/quantize.py index ff382b7da1..10bddf6a7c 100644 --- a/scripts/quantize.py +++ b/scripts/quantize.py @@ -10,6 +10,7 @@ import time from pathlib import Path import logging +import re import torch from datasets import load_dataset, Dataset @@ -26,7 +27,32 @@ from axolotl.utils.quantize import load_merged_model, get_quantized_model, quantize_and_save, push_model, get_quantized_model_id, get_quantized_model_dir, get_examples_for_quantization configure_logging() -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger("axolotl.quantize") + +import debugpy +debugpy.listen(('0.0.0.0', 5678)) +debugpy.wait_for_client() +debugpy.breakpoint() + +class ProgressExtractingHandler(logging.Handler): + def emit(self, record): + log_entry = self.format(record) + progress_info = self.extract_progress(log_entry) + if progress_info: + print(f"Progress: {progress_info}") + + @staticmethod + def extract_progress(log_entry): +# [2023-09-11 07:20:37,502] [INFO] [auto_gptq.modeling._base.quantize:364] [PID:3962] [RANK:0] Quantizing self_attn.k_proj in layer 4/32... 
+ match = re.search(r'layer (\d+/\d+)', log_entry) + return match.group(1) if match else None + # [2023-09-11 07:27:52,208] [INFO] [auto_gptq.modeling._utils.pack_model:129] [PID:3962] [RANK:0] model.layers.15.self_attn.o_proj + +handler = ProgressExtractingHandler() +# logging.getLogger('auto_gptq.modeling._base.quantize').addHandler(handler) +logger = logging.getLogger('auto_gptq.modeling._base.quantize') +logger.setLevel(logging.DEBUG) +logger.addHandler(handler) # logging.basicConfig( # format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.DEBUG, datefmt="%Y-%m-%d %H:%M:%S" diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 62f2b1061a..d6bf59f74f 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -25,6 +25,7 @@ class TrainerCliArgs: debug_num_examples: int = field(default=5) inference: bool = field(default=False) merge_lora: bool = field(default=False) + quantize: bool = field(default=False) prepare_ds_only: bool = field(default=False) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 8f473aa240..41a8e41103 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -61,6 +61,11 @@ def format(self, record): "level": "DEBUG", "propagate": False, }, + "auto_gptq": { + "handlers": ["color_console"], + "level": "DEBUG", + "propagate": False, + }, }, } diff --git a/src/axolotl/utils/quantize.py b/src/axolotl/utils/quantize.py index f8214bda2a..89da2add0e 100644 --- a/src/axolotl/utils/quantize.py +++ b/src/axolotl/utils/quantize.py @@ -42,10 +42,9 @@ def load_merged_model(cfg: DictDefault): # Check if the merged model exists if not merged_out_dir.exists(): # If not, merge the model - print("Merged model not found. Merging...") - # model, tokenizer = load_model(cfg, inference=True) + raise FileNotFoundError("Merged model not found. 
Please ensure the model has been merged.") # do_merge_lora_model_and_tokenizer(cfg=cfg, model=model, tokenizer=tokenizer) - raise NotImplementedError("Merging model is not implemented yet.") + # raise NotImplementedError("Merging model is not implemented yet.") # load un-quantized model, by default, the model will always be loaded into CPU memory model = AutoGPTQForCausalLM.from_pretrained(merged_out_dir, quantize_config) @@ -114,7 +113,8 @@ def get_quantized_model_dir(cfg: DictDefault): # def get_quantized_model_dir(cfg: DictDefault, quantize_config): if not cfg.output_dir: raise ValueError("Missing output_dir in the configuration.") - return f"{cfg.output_dir.lstrip('./')}-GPTQ" + p = Path(cfg.output_dir) / "quantized" + return str(p).lstrip('./') def get_examples_for_quantization(dataset, n_samples): print("Loading dataset...") From 24c048348d624c2f6d8adaa2d963bb634fd26983 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Mon, 11 Sep 2023 08:04:33 +0000 Subject: [PATCH 16/22] Disable quantizing directly after fine tuning --- scripts/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index b02cd54912..925a29d0a2 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -306,7 +306,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs): # model, tokenizer = train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) # tokenizer = None - should_quantize = True + should_quantize = False if should_quantize: # Merge model From 14d26e15acab90327b1e67e80f27df21aefa2a61 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 06:51:26 +0000 Subject: [PATCH 17/22] Add eval_table_size, eval_table_max_new_tokens configs & clean up code --- docker-compose.yaml | 11 +-- examples/llama-2/lora.yml | 2 + .../monkeypatch/llama_attn_hijack_flash.py | 2 + src/axolotl/utils/callbacks.py | 74 +++++++++++++------ src/axolotl/utils/config.py | 2 + 5 files changed, 61 insertions(+), 30 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 143d5f6c91..d40422f94f 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,23 +1,20 @@ # version: '3.8' services: axolotl: - # build: - # context: . - # dockerfile: ./docker/Dockerfile - image: winglian/axolotl:main-py3.10-cu118-2.0.1 + build: + context: . 
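+      # build the image locally from the repo Dockerfile instead of pulling a prebuilt tag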
+ dockerfile: ./docker/Dockerfile volumes: - .:/workspace/axolotl - ~/.cache/huggingface/:/root/.cache/huggingface/ # set environment variables environment: - WANDB_API_KEY=${WANDB_API_KEY} - ports: - - "5678:5678" deploy: resources: reservations: devices: - driver: nvidia - count: 1 + # count: 1 capabilities: [gpu] command: tail -f /dev/null diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 2a0af130be..fc2307faed 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -55,6 +55,8 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 +eval_table_max_new_tokens: 128 save_steps: debug: deepspeed: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 6b3fd4355e..f6adcdac25 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -262,6 +262,8 @@ def flashattn_forward( if attention_mask is not None else None, ) + if q_unpad.dtype != kv_unpad.dtype: + kv_unpad = kv_unpad.to(q_unpad.dtype) output_unpad = flash_attn_varlen_kvpacked_func( q_unpad, kv_unpad, diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 2272ce0a2b..be81503cbf 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -339,9 +339,33 @@ def on_evaluate( eval_dataloader, **kwargs, ): + eval_table_size = self.cfg.eval_table_size + + if eval_table_size <= 0: + return control + trainer.model.eval() device = torch.device(self.cfg.device) + generation_config = GenerationConfig( + max_new_tokens=self.cfg.eval_table_max_new_tokens, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + + def logits_to_tokens(logits) -> str: + probabilities = torch.softmax(logits, dim=-1) + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + return predicted_token_ids + def find_ranges(lst): ranges = [] start = 0 @@ -359,13 +383,17 @@ def log_table_from_dataloader(name: str, table_dataloader): "id", "Prompt", "Correct Completion", - "Predicted Completion", + "Predicted Completion (model.generate)", + "Predicted Completion (trainer.prediction_step)", ] ) row_index = 0 - max_new_tokens = 128 - for batch in tqdm(table_dataloader, total=len(table_dataloader)): + for batch in tqdm(table_dataloader): + + if row_index > eval_table_size: + break + batch_labels = batch["labels"].to(device) batch_input_ids = batch["input_ids"].to(device) @@ -374,11 +402,18 @@ def log_table_from_dataloader(name: str, table_dataloader): else: batch_pos_ids = [None] * len(batch["input_ids"]) + (_, batch_logits, _) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + prompt_token_ids_list = [] + pred_step_token_ids_list = [] completion_token_ids_list = [] - for input_ids_all, labels_all, pos_ids in zip( - batch_input_ids, batch_labels, batch_pos_ids + for input_ids_all, labels_all, pos_ids, logits in zip( + batch_input_ids, batch_labels, batch_pos_ids, batch_logits, ): if pos_ids is None: pos_ranges = [(0, len(input_ids_all) - 1)] @@ -396,7 +431,6 @@ def log_table_from_dataloader(name: str, table_dataloader): tokens_without_loss = labels == IGNORE_INDEX tokens_with_loss = labels != IGNORE_INDEX 
tokens_exclude_padding = input_ids != tokenizer.pad_token_id - prompt_token_includes = ( tokens_without_loss & tokens_exclude_padding ) @@ -407,27 +441,20 @@ def log_table_from_dataloader(name: str, table_dataloader): completion_token_ids = input_ids[tokens_with_loss] completion_token_ids_list.append(completion_token_ids) + pred_step_token_ids = logits_to_tokens(logits[start : end + 1])[tokens_with_loss] + pred_step_token_ids_list.append(pred_step_token_ids) + prompt_texts = tokenizer.batch_decode( prompt_token_ids_list, skip_special_tokens=True ) completion_texts = tokenizer.batch_decode( completion_token_ids_list, skip_special_tokens=True ) + pred_step_texts = tokenizer.batch_decode( + pred_step_token_ids_list, skip_special_tokens=True + ) with torch.no_grad(): - generation_config = GenerationConfig( - max_new_tokens=max_new_tokens, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=False, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - prompt_encoding = tokenizer( prompt_texts, padding=True, return_tensors="pt" ).to(self.cfg.device) @@ -451,17 +478,18 @@ def log_table_from_dataloader(name: str, table_dataloader): prediction_without_prompt_tokens_list, skip_special_tokens=True ) - for prompt_text, completion_text, prediction_text in zip( - prompt_texts, completion_texts, predicted_texts + for prompt_text, completion_text, prediction_text, pred_step_text in zip( + prompt_texts, completion_texts, predicted_texts, pred_step_texts ): table.add_data( - row_index, prompt_text, completion_text, prediction_text + row_index, prompt_text, completion_text, prediction_text, pred_step_text ) row_index += 1 wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) - log_table_from_dataloader("Eval", eval_dataloader) + if is_main_process(): + log_table_from_dataloader("Eval", eval_dataloader) return control diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 93a23f7738..ce1f2255b8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -48,6 +48,8 @@ def normalize_config(cfg): ) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + cfg.eval_table_size = cfg.eval_table_size or 0 + cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 choose_device(cfg) cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.ddp: From c6c54eeeaaa5cc939b2b9821be9f42c40dcaf78d Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 06:58:53 +0000 Subject: [PATCH 18/22] Clean up PR, delete VSCode config, add tiny-llama example --- .vscode/launch.json | 37 ----------- examples/llama-2/qlora.yml | 1 + examples/llama-2/tiny-llama.yml | 69 +++++++++++++++++++++ examples/llama-2/tiny-random.yml | 101 ------------------------------- scripts/finetune.py | 5 -- src/axolotl/utils/trainer.py | 2 +- 6 files changed, 71 insertions(+), 144 deletions(-) delete mode 100644 .vscode/launch.json create mode 100644 examples/llama-2/tiny-llama.yml delete mode 100644 examples/llama-2/tiny-random.yml diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 3f60f05d36..0000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. 
- // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Remote Attach", - "type": "python", - "request": "attach", - "connect": { - "host": "0.0.0.0", - "port": 5678 - }, - "pathMappings": [ - { - "localRoot": "${workspaceFolder}", - "remoteRoot": "/workspace/axolotl/" - } - ], - "justMyCode": false - }, - { - "name": "train", - "type": "python", - "request": "launch", - "module": "accelerate.commands.launch", - "args": [ - "${workspaceFolder}/scripts/finetune.py", - // "${file}", - "${workspaceFolder}/examples/llama-2/tiny-random.yml", - ], // other args comes after train.py - "console": "integratedTerminal", - // "env": {"CUDA_LAUNCH_BLOCKING": "1"} - }, - ] -} diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 3ad2a7e4fd..872b8935d4 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -57,6 +57,7 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 save_steps: debug: deepspeed: diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml new file mode 100644 index 0000000000..a53c9c831b --- /dev/null +++ b/examples/llama-2/tiny-llama.yml @@ -0,0 +1,69 @@ +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-out + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/examples/llama-2/tiny-random.yml b/examples/llama-2/tiny-random.yml deleted file mode 100644 index 138aab2963..0000000000 --- a/examples/llama-2/tiny-random.yml +++ /dev/null @@ -1,101 +0,0 @@ -# anushehchaudry/llama-2-tiny-random -# base_model: anushehchaudry/llama-2-tiny-random -# base_model_config: anushehchaudry/llama-2-tiny-random - -# base_model: JackFram/llama-68m -# base_model_config: JackFram/llama-68m - -base_model: PY007/TinyLlama-1.1B-step-50K-105b -base_model_config: PY007/TinyLlama-1.1B-step-50K-105b - -model_type: LlamaForCausalLM -tokenizer_type: LlamaTokenizer -is_llama_derived_model: true - -load_in_8bit: true -load_in_4bit: false -strict: false - -datasets: - # - path: mhenrichsen/alpaca_2k_test - # type: alpaca - # - path: teknium/GPT4-LLM-Cleaned - # type: alpaca - - path: Glavin001/startup-interviews - type: alpaca -dataset_prepared_path: last_run_prepared -# val_set_size: 0.01 -val_set_size: 0.02 -# val_set_size: 0.05 -# val_set_size: 0.001 -# val_set_size: 0.1 -# output_dir: ./lora-out -# output_dir: 
./lora-2-out -output_dir: ./lora-6-out - -# sequence_len: 4096 -# sequence_len: 2048 -# sequence_len: 256 -# sequence_len: 512 -sequence_len: 1024 -sample_packing: true -# sample_packing: false # FIXME: disabled until we can fix the bug in callbacks.py - -adapter: lora -lora_model_dir: -lora_r: 32 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_linear: true -lora_fan_in_fan_out: - -wandb_project: test-issue-490 -wandb_entity: -wandb_watch: -wandb_run_id: -wandb_log_model: - -gradient_accumulation_steps: 4 -# micro_batch_size: 2 -micro_batch_size: 16 -# micro_batch_size: 24 -# micro_batch_size: 24 -# num_epochs: 3 -# num_epochs: 0.001 -# num_epochs: 0.01 -# num_epochs: 1 -# num_epochs: 5 -num_epochs: 10 -optimizer: adamw_bnb_8bit -lr_scheduler: cosine -learning_rate: 0.0002 - -train_on_inputs: false -group_by_length: false -bf16: true -fp16: false -tf32: false - -gradient_checkpointing: true -early_stopping_patience: -resume_from_checkpoint: -local_rank: -logging_steps: 1 -xformers_attention: -flash_attention: true - -warmup_steps: 10 -# eval_steps: 10 -# eval_steps: 20 -eval_steps: 2 -# eval_steps: 1 -save_steps: -debug: -deepspeed: -weight_decay: 0.0 -fsdp: -fsdp_config: -special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/scripts/finetune.py b/scripts/finetune.py index 91e6d5d844..b998edc798 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -28,11 +28,6 @@ from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.wandb import setup_wandb_env_vars -# import debugpy -# debugpy.listen(('0.0.0.0', 5678)) -# debugpy.wait_for_client() -# debugpy.breakpoint() - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 57d503d397..f7d0b4329a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -720,7 +720,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) - if cfg.use_wandb: + if cfg.use_wandb and cfg.eval_table_size > 0: LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) trainer.add_callback(LogPredictionCallback(cfg)) From dee3d54b943cc3e5c1b0f70c1cc92804f332f4e9 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 06:51:26 +0000 Subject: [PATCH 19/22] Add eval_table_size, eval_table_max_new_tokens configs & clean up code --- examples/llama-2/lora.yml | 2 + .../monkeypatch/llama_attn_hijack_flash.py | 2 + src/axolotl/utils/callbacks.py | 177 ++++++++++++++++++ src/axolotl/utils/config.py | 2 + 4 files changed, 183 insertions(+) diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index a54799b408..4fdcb04092 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -56,6 +56,8 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 +eval_table_max_new_tokens: 128 save_steps: debug: deepspeed: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index ef048082c1..909f2300b1 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -261,6 +261,8 @@ def flashattn_forward( if attention_mask is not None else None, ) + if q_unpad.dtype != kv_unpad.dtype: + kv_unpad = kv_unpad.to(q_unpad.dtype) output_unpad = flash_attn_varlen_kvpacked_func( q_unpad, kv_unpad, diff --git 
a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 3f776537a5..3925c20d8c 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -15,6 +15,8 @@ from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( + GenerationConfig, + Trainer, TrainerCallback, TrainerControl, TrainerState, @@ -22,6 +24,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +import wandb from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.distributed import ( barrier, @@ -323,3 +326,177 @@ def on_evaluate( metrics[key] = val return BenchEvalCallback + + +def log_prediction_callback_factory(trainer: Trainer, tokenizer): + class LogPredictionCallback(TrainerCallback): + """Callback to log prediction values during each evaluation""" + + def __init__(self, cfg): + self.cfg = cfg + self.logged = False + + def on_evaluate( + self, + args: AxolotlTrainingArguments, + state: TrainerState, + control: TrainerControl, + train_dataloader, + eval_dataloader, + **kwargs, + ): + eval_table_size = self.cfg.eval_table_size + + if eval_table_size <= 0: + return control + + trainer.model.eval() + device = torch.device(self.cfg.device) + + generation_config = GenerationConfig( + max_new_tokens=self.cfg.eval_table_max_new_tokens, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + + def logits_to_tokens(logits) -> str: + probabilities = torch.softmax(logits, dim=-1) + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + return predicted_token_ids + + def find_ranges(lst): + ranges = [] + start = 0 + for i in range(1, len(lst)): + if lst[i] == 0: + ranges.append((start, i - 1)) + start = i + end = len(lst) - 1 + ranges.append((start, end)) + return ranges + + def log_table_from_dataloader(name: str, table_dataloader): + table = wandb.Table( + columns=[ + "id", + "Prompt", + "Correct Completion", + "Predicted Completion (model.generate)", + "Predicted Completion (trainer.prediction_step)", + ] + ) + row_index = 0 + + for batch in tqdm(table_dataloader): + + if row_index > eval_table_size: + break + + batch_labels = batch["labels"].to(device) + batch_input_ids = batch["input_ids"].to(device) + + if "position_ids" in batch: + batch_pos_ids = batch["position_ids"].tolist() + else: + batch_pos_ids = [None] * len(batch["input_ids"]) + + (_, batch_logits, _) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + + prompt_token_ids_list = [] + pred_step_token_ids_list = [] + completion_token_ids_list = [] + + for input_ids_all, labels_all, pos_ids, logits in zip( + batch_input_ids, batch_labels, batch_pos_ids, batch_logits, + ): + if pos_ids is None: + pos_ranges = [(0, len(input_ids_all) - 1)] + else: + pos_ranges = find_ranges(pos_ids) + + for pos_range in pos_ranges: + start, end = pos_range + if start == end: + continue + + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] + + tokens_without_loss = labels == IGNORE_INDEX + tokens_with_loss = labels != IGNORE_INDEX + tokens_exclude_padding = input_ids != tokenizer.pad_token_id + prompt_token_includes = ( + tokens_without_loss & tokens_exclude_padding + ) + + prompt_token_ids = 
input_ids[prompt_token_includes] + prompt_token_ids_list.append(prompt_token_ids) + + completion_token_ids = input_ids[tokens_with_loss] + completion_token_ids_list.append(completion_token_ids) + + pred_step_token_ids = logits_to_tokens(logits[start : end + 1])[tokens_with_loss] + pred_step_token_ids_list.append(pred_step_token_ids) + + prompt_texts = tokenizer.batch_decode( + prompt_token_ids_list, skip_special_tokens=True + ) + completion_texts = tokenizer.batch_decode( + completion_token_ids_list, skip_special_tokens=True + ) + pred_step_texts = tokenizer.batch_decode( + pred_step_token_ids_list, skip_special_tokens=True + ) + + with torch.no_grad(): + prompt_encoding = tokenizer( + prompt_texts, padding=True, return_tensors="pt" + ).to(self.cfg.device) + predictions = trainer.model.generate( + **prompt_encoding, generation_config=generation_config + ) + + prediction_all_tokens = predictions["sequences"].cpu().tolist() + prediction_without_prompt_tokens_list = [] + for prompt_token_ids, prediction_tokens in zip( + prompt_token_ids_list, prediction_all_tokens + ): + prediction_without_prompt_tokens = prediction_tokens[ + len(prompt_token_ids) : + ] + prediction_without_prompt_tokens_list.append( + prediction_without_prompt_tokens + ) + + predicted_texts = tokenizer.batch_decode( + prediction_without_prompt_tokens_list, skip_special_tokens=True + ) + + for prompt_text, completion_text, prediction_text, pred_step_text in zip( + prompt_texts, completion_texts, predicted_texts, pred_step_texts + ): + table.add_data( + row_index, prompt_text, completion_text, prediction_text, pred_step_text + ) + row_index += 1 + + wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) + + if is_main_process(): + log_table_from_dataloader("Eval", eval_dataloader) + + return control + + return LogPredictionCallback diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 4916fe2f6b..a4e68869d8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -48,6 +48,8 @@ def normalize_config(cfg): ) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + cfg.eval_table_size = cfg.eval_table_size or 0 + cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 choose_device(cfg) cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.ddp: From 09b16d8b975c8909669583cbbbc24e055c19823e Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 06:58:53 +0000 Subject: [PATCH 20/22] Clean up PR, delete VSCode config, add tiny-llama example --- examples/llama-2/qlora.yml | 1 + examples/llama-2/tiny-llama.yml | 69 +++++++++++++++++++++++++++++++++ src/axolotl/utils/trainer.py | 5 +++ 3 files changed, 75 insertions(+) create mode 100644 examples/llama-2/tiny-llama.yml diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index dd029859ed..ef20d9fbe3 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -58,6 +58,7 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 save_steps: debug: deepspeed: diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml new file mode 100644 index 0000000000..a53c9c831b --- /dev/null +++ b/examples/llama-2/tiny-llama.yml @@ -0,0 +1,69 @@ +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false 
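+# the LoRA adapter configured below is trained on top of the 8-bit base weights loaded above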
+strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-out + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ece1bd9b69..d959c896f1 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -30,6 +30,7 @@ SaveBetterTransformerModelCallback, SavePeftModelCallback, bench_eval_callback_factory, + log_prediction_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader @@ -703,6 +704,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) + if cfg.use_wandb and cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) + trainer.add_callback(LogPredictionCallback(cfg)) + if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) From cf239987c47dde0185c9f911ee707dd326eb5f0d Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 09:21:32 +0000 Subject: [PATCH 21/22] WIP quantize model & push model --- docker-compose.yaml | 9 ++++++--- scripts/finetune.py | 9 +++++++-- src/axolotl/common/cli.py | 6 +++++- src/axolotl/utils/quantize.py | 14 +++++++++----- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index a16be726cf..6708dbf6a0 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,9 +1,10 @@ # version: '3.8' services: axolotl: - build: - context: . - dockerfile: ./docker/Dockerfile + # build: + # context: . 
+ # dockerfile: ./docker/Dockerfile + image: winglian/axolotl:main-py3.10-cu118-2.0.1 volumes: - .:/workspace/axolotl - ~/.cache/huggingface/:/root/.cache/huggingface/ @@ -15,6 +16,8 @@ services: - GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME} - GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL} - WANDB_API_KEY=${WANDB_API_KEY} + ports: + - "5678:5678" deploy: resources: reservations: diff --git a/scripts/finetune.py b/scripts/finetune.py index 925a29d0a2..bcf828202c 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -29,7 +29,7 @@ from axolotl.utils.models import load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.wandb import setup_wandb_env_vars -from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, quantize_and_save +from axolotl.utils.quantize import get_examples_for_quantization, load_merged_model, push_model, quantize_and_save project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") @@ -155,6 +155,9 @@ def do_inference( batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) print("=" * 40) + print(prompt) + print("=" * 20) + model.eval() with torch.no_grad(): generation_config = GenerationConfig( @@ -298,7 +301,9 @@ def do_cli(config: Path = Path("examples/"), **kwargs): n_samples = 128 examples = get_examples_for_quantization(dataset_meta.train_dataset, n_samples) quantize_and_save(parsed_cfg, merged_model, tokenizer, examples) - + elif parsed_cli_args.push: + model, tokenizer = load_model_and_tokenizer(cfg=parsed_cfg, cli_args=parsed_cli_args) + push_model(cfg=parsed_cfg, model=model, tokenizer=tokenizer) else: dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.prepare_ds_only: diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index d6bf59f74f..ef439d3cf0 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -9,6 +9,7 @@ from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.quantize import get_quantized_model configure_logging() LOG = logging.getLogger("axolotl.common.cli") @@ -26,6 +27,7 @@ class TrainerCliArgs: inference: bool = field(default=False) merge_lora: bool = field(default=False) quantize: bool = field(default=False) + push: bool = field(default=False) prepare_ds_only: bool = field(default=False) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) @@ -39,6 +41,8 @@ def load_model_and_tokenizer( LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) LOG.info("loading model and (optionally) peft_config...") - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + # model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + # TEMP + model = get_quantized_model(cfg) return model, tokenizer diff --git a/src/axolotl/utils/quantize.py b/src/axolotl/utils/quantize.py index 89da2add0e..89c46d7072 100644 --- a/src/axolotl/utils/quantize.py +++ b/src/axolotl/utils/quantize.py @@ -15,10 +15,10 @@ # from datasets import load_dataset, Dataset # from transformers import AutoTokenizer, LlamaTokenizer, TextGenerationPipeline from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig -from axolotl.prompters import AlpacaPrompter -from axolotl.utils.models import load_model, load_tokenizer -from axolotl.common.cli import TrainerCliArgs -from axolotl.logging_config import configure_logging +# from axolotl.prompters import AlpacaPrompter +# from axolotl.utils.models import load_model, load_tokenizer +# from axolotl.common.cli import TrainerCliArgs +# from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault # from finetune import load_cfg, get_merged_out_dir, do_merge_lora_model_and_tokenizer @@ -55,7 +55,11 @@ def load_merged_model(cfg: DictDefault): def get_quantized_model(cfg: DictDefault): print("Loading quantized model...") quantized_model_dir = get_quantized_model_dir(cfg) - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_safetensors=True) + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device="cuda:0", + use_safetensors=True, + inject_fused_attention=False, # WORKAROUND for https://github.com/PanQiWei/AutoGPTQ/issues/210 + ) print("Model loaded.") return model From 8a26ab325a0406ba1e15ebeaed215b553e17ca2c Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 10:50:31 +0000 Subject: [PATCH 22/22] WIP --- .vscode/launch.json | 37 ++++++++++++++ examples/llama-2/llama-68.yml | 70 +++++++++++++++++++++++++ examples/llama-2/lora.yml | 2 +- examples/llama-2/tiny-puffed-llama.yml | 71 ++++++++++++++++++++++++++ src/axolotl/utils/quantize.py | 3 ++ 5 files changed, 182 insertions(+), 1 deletion(-) create mode 100644 .vscode/launch.json create mode 100644 examples/llama-2/llama-68.yml create mode 100644 examples/llama-2/tiny-puffed-llama.yml diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..3f60f05d36 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,37 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Remote Attach", + "type": "python", + "request": "attach", + "connect": { + "host": "0.0.0.0", + "port": 5678 + }, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "/workspace/axolotl/" + } + ], + "justMyCode": false + }, + { + "name": "train", + "type": "python", + "request": "launch", + "module": "accelerate.commands.launch", + "args": [ + "${workspaceFolder}/scripts/finetune.py", + // "${file}", + "${workspaceFolder}/examples/llama-2/tiny-random.yml", + ], // other args comes after train.py + "console": "integratedTerminal", + // "env": {"CUDA_LAUNCH_BLOCKING": "1"} + }, + ] +} diff --git a/examples/llama-2/llama-68.yml b/examples/llama-2/llama-68.yml new file mode 100644 index 0000000000..85616aba90 --- /dev/null +++ b/examples/llama-2/llama-68.yml @@ -0,0 +1,70 @@ +base_model: JackFram/llama-68m +base_model_config: JackFram/llama-68m + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-out + +# sequence_len: 4096 +sequence_len: 2048 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 4fdcb04092..2438b0d884 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -27,7 +27,7 @@ lora_dropout: 0.05 lora_target_linear: true lora_fan_in_fan_out: -wandb_project: +wandb_project: test-issue-490-7b-2 wandb_entity: wandb_watch: wandb_run_id: diff --git a/examples/llama-2/tiny-puffed-llama.yml b/examples/llama-2/tiny-puffed-llama.yml new file mode 100644 index 0000000000..ac02b7b27b --- /dev/null +++ b/examples/llama-2/tiny-puffed-llama.yml @@ -0,0 +1,71 @@ +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + # - path: mhenrichsen/alpaca_2k_test + # type: alpaca + - path: LDJnr/Puffin + type: sharegpt:chat +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-tiny-puffed-out + +sequence_len: 2048 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 2 +optimizer: 
adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 10 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/src/axolotl/utils/quantize.py b/src/axolotl/utils/quantize.py index 89c46d7072..39318836a9 100644 --- a/src/axolotl/utils/quantize.py +++ b/src/axolotl/utils/quantize.py @@ -86,6 +86,9 @@ def quantize_and_save(cfg: DictDefault, model, tokenizer, examples_for_quant): tokenizer.save_pretrained(quantized_model_dir) print("Saved.") + # FIXME: Add fix to config.json + # "error": "handler: 'pad_token_id' \ntraceback: Traceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/runpod/serverless/modules/job.py\", line 141, in run_job_generator\n for output_partial in job_output:\n File \"/data/handler.py\", line 107, in inference\n generator, default_settings = load_model()\n File \"/data/handler.py\", line 45, in load_model\n config = ExLlamaConfig(model_config_path) # create config from config.json\n File \"/data/exllama/model.py\", line 52, in __init__\n self.pad_token_id = read_config[\"pad_token_id\"]\nKeyError: 'pad_token_id'\n" + return model def push_model(cfg: DictDefault, model, tokenizer):
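A minimal sketch of one way to address the FIXME above, assuming the quantized artifacts are written to get_quantized_model_dir(cfg) and that the tokenizer exposes the usual special-token ids (the helper name below is illustrative, not an existing function in the codebase):

import json
from pathlib import Path

def backfill_special_token_ids(quantized_model_dir: str, tokenizer) -> None:
    """Add token ids that some inference backends (e.g. ExLlama) expect to find in config.json."""
    config_path = Path(quantized_model_dir) / "config.json"
    config = json.loads(config_path.read_text())

    # Fall back to the eos token when the tokenizer defines no explicit pad token.
    config.setdefault("pad_token_id", tokenizer.pad_token_id or tokenizer.eos_token_id)
    config.setdefault("bos_token_id", tokenizer.bos_token_id)
    config.setdefault("eos_token_id", tokenizer.eos_token_id)

    config_path.write_text(json.dumps(config, indent=2))

Calling this right after tokenizer.save_pretrained(quantized_model_dir) in quantize_and_save would keep downstream loaders from failing with KeyError: 'pad_token_id'.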