diff --git a/.github/ISSUE_TEMPLATE/bug-report.yaml b/.github/ISSUE_TEMPLATE/bug-report.yaml
index cca243fdf2..5cf332587a 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yaml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yaml
@@ -53,6 +53,13 @@ body:
validations:
required: true
+ - type: textarea
+ id: config
+ attributes:
+ label: Config yaml
+ description: |
+ Please attach the config yaml!
+
- type: textarea
id: possible-solution
attributes:
diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index d755b63263..5f08854842 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -25,6 +25,11 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
+ - cuda: "118"
+ cuda_version: 11.8.0
+ python_version: "3.10"
+ pytorch: 2.1.0
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index a5b4d30379..f84f7f7a9a 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -23,6 +23,11 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
+ - cuda: 118
+ cuda_version: 11.8.0
+ python_version: "3.10"
+ pytorch: 2.1.0
+ axolotl_extras:
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
@@ -46,6 +51,7 @@ jobs:
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
+ PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
@@ -68,6 +74,11 @@ jobs:
pytorch: 2.0.1
axolotl_extras:
is_latest: true
+ - cuda: 118
+ cuda_version: 11.8.0
+ python_version: "3.10"
+ pytorch: 2.1.0
+ axolotl_extras:
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 18ff575c10..04fe53c00a 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -4,7 +4,13 @@ on:
push:
branches:
- "main"
+ paths:
+ - '**.py'
+ - 'requirements.txt'
pull_request:
+ paths:
+ - '**.py'
+ - 'requirements.txt'
workflow_dispatch:
jobs:
@@ -40,7 +46,7 @@ jobs:
- name: Install dependencies
run: |
- pip3 install -e .
+ pip3 install -U -e .
pip3 install -r requirements-tests.txt
- name: Run tests
@@ -65,8 +71,8 @@ jobs:
- name: Install dependencies
run: |
- pip3 install -e .
- pip3 install flash-attn
+ pip3 uninstall -y transformers accelerate
+ pip3 install -U -e .[flash-attn]
pip3 install -r requirements-tests.txt
- name: Run e2e tests
diff --git a/.gitignore b/.gitignore
index 614a6676b1..f46b6808f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,3 +161,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
+
+# WandB
+# wandb creates a folder to store logs for training runs
+wandb
diff --git a/.isort.cfg b/.isort.cfg
index b9fb3f3e8c..79067a7c91 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,3 @@
[settings]
profile=black
+known_third_party=wandb
diff --git a/README.md b/README.md
index 6e7b625c21..bd0426b4ca 100644
--- a/README.md
+++ b/README.md
@@ -23,14 +23,16 @@ Features:
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- - [Docker Installation](#environment)
- - [Conda/Pip venv Installation](#condapip-venv)
- - [LambdaLabs Installation](#lambdalabs)
+ - [Docker](#docker)
+ - [Conda/Pip venv](#condapip-venv)
+ - [LambdaLabs](#lambdalabs)
+ - [Windows](#windows)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
+ - [Training w/ Deepspeed](#training-with-deepspeed)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
@@ -49,7 +51,7 @@ Features:
Axolotl provides a unified repository for fine-tuning
a variety of AI models with ease
- Go ahead and axolotl questions!!
+ Go ahead and Axolotl questions!!
@@ -86,7 +88,7 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
-pip3 install -e .[flash-attn]
+pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -U git+https://github.com/huggingface/peft.git
# finetune lora
@@ -101,7 +103,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
### Environment
-- Docker
+#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
```
@@ -113,18 +115,23 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
docker compose up -d
```
-- Conda/Pip venv
+#### Conda/Pip venv
1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/
- 3. Install axolotl along with python dependencies
+ 3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
- pip3 install -e .[flash-attn]
+ pip3 install -e '.[flash-attn,deepspeed]'
```
+  4. (Optional) Log in to Huggingface to use gated models/datasets.
+ ```bash
+ huggingface-cli login
+ ```
+ Get the token at huggingface.co/settings/tokens
-- LambdaLabs
+#### LambdaLabs
Click to Expand
@@ -157,7 +164,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
cd axolotl
pip3 install packaging
- pip3 install -e .[flash-attn]
+ pip3 install -e '.[flash-attn,deepspeed]'
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
```
@@ -168,7 +175,8 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
```
-- Windows: Please use WSL or Docker!
+#### Windows
+Please use WSL or Docker!
### Dataset
@@ -179,7 +187,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
-- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
+- `sharegpt`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -244,6 +252,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"article": "...", "question": "...", "answer": "..."}
```
+- `context_qa.load_v2`: in context question answering (alternate)
+ ```json
+ {"context": "...", "question": "...", "answer": "..."}
+ ```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
```json
{"article": "...", "unanswerable_question": "..."}
@@ -268,11 +280,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"prompt": "...", "generation": "..."}
```
-- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
+- `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
-- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
+- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -285,29 +297,28 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts
-Using yaml. Example:
+For a dataset that is preprocessed for instruction purposes:
+
+```json
+{"instruction": "...", "output": "..."}
+```
+
+You can use this example in your YAML config:
+
```yaml
datasets:
- path: repo
type:
system_prompt: ""
- no_input_format: |-
- User: {instruction}<|end_of_turn|>
- Assistant:
- format: |-
- User: {instruction}
- {input}<|end_of_turn|>
- Assistant:
+ field_system: system
+ format: "[INST] {instruction} [/INST]"
+ no_input_format: "[INST] {instruction} [/INST]"
```
-Using file:
-1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-2. Use your custom file name as the dataset type `.load_`.
-
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
-- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
+- The dataset columns must be exactly `input_ids`, `attention_mask`, `labels`
### Config
@@ -350,6 +361,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
+
+ # dataset with splits, but no train split
+  datasets:
+ - path: knowrohit07/know_sql
+ type: context_qa.load_v2
+ train_on_split: validation
```
- loading
@@ -380,15 +397,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
All yaml options
```yaml
-# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
-# this can also be a relative path to a model on disk
+# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
+# This can also be a relative path to a model on disk
base_model: ./llama-7b-hf
-# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
+# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns:
-# if the base_model repo on hf hub doesn't include configuration .json files,
-# you can set that here, or leave this empty to default to base_model
+# If the base_model repo on hf hub doesn't include configuration .json files,
+# you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
-# you can specify to choose a specific model revision from huggingface hub
+# You can specify a specific model revision from huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model
@@ -403,18 +420,24 @@ trust_remote_code:
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
-# resize the model embeddings when new tokens are added to multiples of 32
-# this is reported to improve training speed on some models
+# Resize the model embeddings when new tokens are added to multiples of 32
+# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
-# whether you are training a 4-bit GPTQ quantized model
+# Used to identify which model family the base model is derived from
+is_falcon_derived_model:
+is_llama_derived_model:
+# Please note that if you set this to true, `padding_side` will be set to "left" by default
+is_mistral_derived_model:
+
+# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
-# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
+# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
-# use bitsandbytes 4 bit
+# Use bitsandbytes 4 bit
load_in_4bit:
# Use CUDA bf16
@@ -428,9 +451,9 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere
float16: true
-# a list of one or more datasets to finetune the model with
+# A list of one or more datasets to finetune the model with
datasets:
- # hf dataset repo | "json" for local dataset, make sure to fill data_files
+ # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format: (chat/instruct) | .load_
@@ -438,18 +461,20 @@ datasets:
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
+ conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
- # custom user prompt
+ # Custom user prompt
- path: repo
type:
- # the below are defaults. only set what's needed.
+      # The below are defaults. Only set what's needed.
system_prompt: ""
+ system_format: "{system}"
field_system: system
field_instruction: instruction
- field_output: input
+ field_input: input
+ field_output: output
- # customizable to be single line or multi-line
- system_format: "{system}"
+ # Customizable to be single line or multi-line
# 'format' can include {input}
format: |-
User: {instruction} {input}
@@ -457,21 +482,24 @@ datasets:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
- # for completions datsets, uses the provided field if not `text`
+  # For `completion` datasets only, uses the provided field instead of the `text` column
field:
-# axolotl attempts to save the dataset as an arrow after packing the data together so
+# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
-# push prepared dataset to hub
+# Push prepared dataset to hub
push_dataset_to_hub: # repo path
+# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
+# if not set.
+dataset_processes:
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
-# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
-# required to be true when used in combination with `push_dataset_to_hub`
+# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
+# Required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04
@@ -480,28 +508,34 @@ dataset_shard_num:
# Index of shard to use for whole dataset
dataset_shard_idx:
-# the maximum length of an input to train with, this should typically be less than 2048
+# The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
-# pad inputs so each step uses constant sized buffers
-# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
+# Pad inputs so each step uses constant sized buffers
+# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
-# max sequence length to concatenate training samples together up to
-# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
+# Max sequence length to concatenate training samples together up to
+# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024
-# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
+# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommended to set to 'true'
sample_packing:
-# you can set these packing optimizations AFTER starting a training at least once.
+# Set to 'false' if getting errors during eval with sample_packing on.
+eval_sample_packing:
+# You can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
-# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
+# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
-# if you already have a lora model trained that you want to load, put that here
-# lora hyperparameters
+# If you already have a lora model trained that you want to load, put that here.
+# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
peft_model_dir:
+
+# LoRA hyperparameters
+# For more details about the following options, see:
+# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
@@ -513,10 +547,19 @@ lora_target_modules:
# - gate_proj
# - down_proj
# - up_proj
-lora_target_linear: # if true, will target all linear layers
+lora_target_linear: # If true, will target all linear layers
+
+# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
+# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
+# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
+# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save:
# - embed_tokens
# - lm_head
+
+# Once you complete training, the model will be saved to the following directory.
+# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
+# Make sure `peft_model_dir` points to this directory if you want to use the trained model.
lora_out_dir:
lora_fan_in_fan_out: false
ia3_target_modules: # target modules for IA3, for llama, k, v, and down projections
@@ -524,73 +567,79 @@ ia3_feedforward_modules: # ffn modules for IA3, for llama down projection
ia3_fan_in_fan_out:
# ReLoRA configuration
-# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
-relora_steps: # number of steps per ReLoRA restart
-relora_warmup_steps: # number of per-restart warmup steps
-relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
+# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
+relora_steps: # Number of steps per ReLoRA restart
+relora_warmup_steps: # Number of per-restart warmup steps
+relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
-wandb_project: # your wandb project name
-wandb_entity: # a wandb Team name if using a Team
+wandb_project: # Your wandb project name
+wandb_entity: # A wandb Team name if using a Team
wandb_watch:
-wandb_run_id: # set the name of your wandb run
+wandb_run_id: # Set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
-# where to save the finished model to
+# Where to save the full-finetuned model to
output_dir: ./completed-model
-# whether to use torch.compile and which backend to use
+# Whether to use torch.compile and which backend to use
torch_compile: # bool
torch_compile_backend: # Optional[str]
-# training hyperparameters
+# Training hyperparameters
+
+# If greater than 1, the optimizer update is skipped and gradients are accumulated over the given number of micro-batches before the weights are updated.
gradient_accumulation_steps: 1
+# The number of samples to include in each batch. This is the number of samples sent to each GPU.
micro_batch_size: 2
-eval_batch_size: 2
+eval_batch_size:
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
-save_strategy: # set to `no` to skip checkpoint saves
-save_steps: # leave empty to save at each epoch
-eval_steps: # leave empty to eval at each epoch
-save_total_limit: # checkpoints saved at a time
+save_strategy: # Set to `no` to skip checkpoint saves
+save_steps: # Leave empty to save at each epoch
+eval_steps: # Leave empty to eval at each epoch
+save_total_limit: # Maximum number of checkpoints to keep at a time
+# Maximum number of iterations to train for. It takes precedence over num_epochs: if both are set,
+# training stops at max_steps and the full num_epochs is not guaranteed.
+# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:
-eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
-eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
+eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
+eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
-# save model as safetensors (require safetensors package)
+# Save model as safetensors (requires the safetensors package)
save_safetensors:
-# whether to mask out or include the human's prompt from the training labels
+# Whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
-# group similarly sized data to minimize padding
-# may be slower to start, as it must download and sort the entire dataset
-# note that training loss may have an oscillating pattern with this enabled
+# Group similarly sized data to minimize padding.
+# May be slower to start, as it must download and sort the entire dataset.
+# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
-# stop training after this many evaluation losses have increased in a row
+# Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
-# specify a scheduler and kwargs to use with the optimizer
+# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
-# for one_cycle optim
-lr_div_factor: # learning rate div factor
+# For one_cycle optim
+lr_div_factor: # Learning rate div factor
-# for log_sweep optim
+# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:
-# specify optimizer
+# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
#
@@ -616,7 +665,7 @@ log_sweep_max_lr:
# - paged_lion_32bit
# - paged_lion_8bit
optimizer:
-# specify weight decay
+# Specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
@@ -625,47 +674,56 @@ adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
-# whether to bettertransformers
+# Augmentation techniques
+# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
+# Currently only supported for LLaMA and Mistral
+noisy_embedding_alpha:
+
+# Whether to use BetterTransformers
flash_optimum:
-# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
+# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
-# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
+# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
-# whether to use scaled-dot-product attention
+flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
+flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
+# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Landmark attention (only llama)
landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
-# llama only
+# LLaMA only
xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
-# resume from a specific checkpoint dir
+# Resume from a specific checkpoint dir
resume_from_checkpoint:
-# if resume_from_checkpoint isn't set and you simply want it to start where it left off
-# be careful with this being turned on between different models
+# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
+# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false
-# don't mess with this, it's here for accelerate and torchrun
+# Don't mess with this, it's here for accelerate and torchrun
local_rank:
-# add or change special tokens
+# Add or change special tokens.
+# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens:
#   bos_token: "<s>"
#   eos_token: "</s>"
#   unk_token: "<unk>"
-# add extra tokens
+
+# Add extra tokens.
tokens:
# FSDP
fsdp:
fsdp_config:
-# Deepspeed config path
+# Deepspeed config path. e.g., deepspeed/zero3.json
deepspeed:
# Advanced DDP Arguments
@@ -691,6 +749,66 @@ strict:
+
+#### Understanding of batch size and gradient accumulation steps
+
+Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
+
+This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
+
+1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
+
+2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
+
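+Concretely, the update loop looks roughly like the sketch below. This is a minimal PyTorch illustration of the technique, with a toy `nn.Linear` model and random data standing in for the real LLM and dataset; it is not axolotl's trainer code, and `accumulation_steps` simply plays the role of `gradient_accumulation_steps`.
+
+```python
+import torch
+from torch import nn
+
+# Toy stand-ins for the real model, optimizer, and dataloader.
+model = nn.Linear(16, 4)
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+loss_fn = nn.MSELoss()
+micro_batches = [(torch.randn(3, 16), torch.randn(3, 4)) for _ in range(8)]
+
+accumulation_steps = 2  # plays the role of `gradient_accumulation_steps`
+
+optimizer.zero_grad()
+for step, (inputs, targets) in enumerate(micro_batches):
+    loss = loss_fn(model(inputs), targets)
+    # Scale the loss so the accumulated gradient averages over the effective batch.
+    (loss / accumulation_steps).backward()
+    if (step + 1) % accumulation_steps == 0:
+        optimizer.step()       # weights are updated only once per accumulation window
+        optimizer.zero_grad()  # reset gradients for the next window
+```
+
+The worked examples below show the same idea spread across multiple GPUs.
+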
+**Example 1:**
+Micro batch size: 3
+Gradient accumulation steps: 2
+Number of GPUs: 3
+Total batch size = 3 * 2 * 3 = 18
+
+```
+| GPU 1 | GPU 2 | GPU 3 |
+|----------------|----------------|----------------|
+| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
+| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
+|----------------|----------------|----------------|
+| → (accumulate) | → (accumulate) | → (accumulate) |
+|----------------|----------------|----------------|
+| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
+| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
+|----------------|----------------|----------------|
+| → (apply) | → (apply) | → (apply) |
+
+Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
+Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
+
+Weight update for w1:
+w1_new = w1_old - learning rate × (Total gradient for w1 / 18)
+```
+
+**Example 2:**
+Micro batch size: 2
+Gradient accumulation steps: 1
+Number of GPUs: 3
+Total batch size = 2 * 1 * 3 = 6
+
+```
+| GPU 1 | GPU 2 | GPU 3 |
+|-----------|-----------|-----------|
+| S1, S2 | S3, S4 | S5, S6 |
+| e1, e2 | e3, e4 | e5, e6 |
+|-----------|-----------|-----------|
+| → (apply) | → (apply) | → (apply) |
+
+Accumulated gradient for the weight w1 (considering all GPUs):
+Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
+
+Weight update for w1:
+w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
+```
+
+
+
### Train
Run
@@ -718,11 +836,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
-- llama Deepspeed
-```yaml
-deepspeed: deepspeed/zero3.json
-```
-
##### Weights & Biases Logging
- wandb options
@@ -735,6 +848,24 @@ wandb_run_id:
wandb_log_model:
```
+### Training with Deepspeed
+
+Deepspeed is an optimization suite for multi-GPU systems that allows you to train much larger models than would
+typically fit into your GPUs' VRAM. More information about the various optimization types
+for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
+
+We provide several default deepspeed JSON configurations for ZeRO stages 1, 2, and 3.
+
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml --deepspeed deepspeed/zero1.json
+```
+
+or
+
+```yaml
+deepspeed: deepspeed/zero1.json
+```
+
### Inference
Pass the appropriate flag to the train command:
@@ -753,6 +884,10 @@ Pass the appropriate flag to the train command:
--base_model="./completed-model" --prompter=None --load_in_8bit=True
```
+Please use `--sample_packing False` if you have it enabled and receive an error similar to the one below:
+
+> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
+
### Merge LORA to base
Add below flag to train command above
diff --git a/deepspeed/zero1.json b/deepspeed/zero1.json
index 39d648b2b5..85bc586091 100644
--- a/deepspeed/zero1.json
+++ b/deepspeed/zero1.json
@@ -1,39 +1,41 @@
{
- "zero_optimization": {
- "stage": 1,
- "overlap_comm": true
- },
- "bf16": {
- "enabled": "auto"
- },
- "fp16": {
- "enabled": "auto",
- "auto_cast": false,
- "loss_scale": 0,
- "initial_scale_power": 32,
- "loss_scale_window": 1000,
- "hysteresis": 2,
- "min_loss_scale": 1
- },
- "optimizer": {
- "type": "AdamW",
- "params": {
- "lr": "auto",
- "betas": "auto",
- "eps": "auto",
- "weight_decay": "auto"
- }
- },
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "train_batch_size": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "wall_clock_breakdown": false
+ "zero_optimization": {
+ "stage": 1,
+ "overlap_comm": true
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "fp16": {
+ "enabled": "auto",
+ "auto_cast": false,
+ "loss_scale": 0,
+ "initial_scale_power": 32,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupDecayLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto",
+ "warmup_type": "linear",
+ "total_num_steps": "auto"
+ }
+ },
+ "gradient_accumulation_steps": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
}
diff --git a/deepspeed/zero2.json b/deepspeed/zero2.json
index 16f2710f69..18ec4bd0e3 100644
--- a/deepspeed/zero2.json
+++ b/deepspeed/zero2.json
@@ -1,43 +1,45 @@
{
- "zero_optimization": {
- "stage": 2,
- "offload_optimizer": {
- "device": "cpu"
- },
- "contiguous_gradients": true,
- "overlap_comm": true
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "cpu"
},
- "bf16": {
- "enabled": "auto"
- },
- "fp16": {
- "enabled": "auto",
- "auto_cast": false,
- "loss_scale": 0,
- "initial_scale_power": 32,
- "loss_scale_window": 1000,
- "hysteresis": 2,
- "min_loss_scale": 1
- },
- "optimizer": {
- "type": "AdamW",
- "params": {
- "lr": "auto",
- "betas": "auto",
- "eps": "auto",
- "weight_decay": "auto"
- }
- },
- "scheduler": {
- "type": "WarmupDecayLR",
- "params": {
- "warmup_min_lr": "auto",
- "warmup_max_lr": "auto",
- "warmup_num_steps": "auto",
- "total_num_steps": "auto"
- }
- },
- "train_batch_size": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "wall_clock_breakdown": false
+ "contiguous_gradients": true,
+ "overlap_comm": true
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "fp16": {
+ "enabled": "auto",
+ "auto_cast": false,
+ "loss_scale": 0,
+ "initial_scale_power": 32,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupDecayLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto",
+ "warmup_type": "linear",
+ "total_num_steps": "auto"
+ }
+ },
+ "gradient_accumulation_steps": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
}
diff --git a/deepspeed/zero3.json b/deepspeed/zero3.json
index 3f5dff3302..5da9c07e28 100644
--- a/deepspeed/zero3.json
+++ b/deepspeed/zero3.json
@@ -45,9 +45,11 @@
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
- "warmup_num_steps": "auto"
+ "warmup_num_steps": "auto",
+ "warmup_type": "linear"
}
},
+ "gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 8608e2348b..ff47548bc4 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -5,6 +5,9 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
+ARG PYTORCH_VERSION="2.0.1"
+
+ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y vim curl
@@ -12,17 +15,19 @@ RUN apt-get update && \
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
+
+WORKDIR /workspace/axolotl
+
# If AXOLOTL_EXTRAS is set, append it in brackets
-RUN cd axolotl && \
- if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
+RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
+RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[flash-attn]; \
fi
# fix so that git fetch/pull from remote works
-RUN cd axolotl && \
- git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
+RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 643a3bfb62..8f33ff2295 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -13,16 +13,14 @@ ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION
-RUN apt-get update
-RUN apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
-
-RUN wget \
+RUN apt-get update \
+ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
+ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
- && rm -f Miniconda3-latest-Linux-x86_64.sh
-
-RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
+ && rm -f Miniconda3-latest-Linux-x86_64.sh \
+ && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
@@ -46,6 +44,8 @@ FROM base-builder AS bnb-builder
WORKDIR /workspace
ARG CUDA="118"
ENV CUDA=$CUDA
+ARG MAX_JOBS="-1"
+ENV MAX_JOBS=$MAX_JOBS
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
cd bitsandbytes && \
@@ -57,12 +57,6 @@ FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
-# recompile apex
-RUN python3 -m pip uninstall -y apex
-RUN git clone https://github.com/NVIDIA/apex
-# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
-RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
-
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
diff --git a/docs/multi-node.md b/docs/multi-node.md
new file mode 100644
index 0000000000..6806159690
--- /dev/null
+++ b/docs/multi-node.md
@@ -0,0 +1,45 @@
+# Multi Node
+
+You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
+
+~/.cache/huggingface/accelerate/default_config.yaml
+```yaml
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
+main_process_ip: 10.0.0.4 # Set to main machine's IP
+main_process_port: 5000
+main_training_function: main
+mixed_precision: bf16
+num_machines: 2 # Change to the number of machines
+num_processes: 4 # The total number of GPUs across all machines (for example: 2 machines with 4 GPUs each => 8)
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+Configure your model to use FSDP, for example:
+```yaml
+fsdp:
+ - full_shard
+ - auto_wrap
+fsdp_config:
+ fsdp_offload_params: true
+ fsdp_state_dict_type: FULL_STATE_DICT
+ fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+```
+
+## Machine configuration
+
+On each machine you need a copy of Axolotl; we suggest using the same commit on every machine to ensure compatibility.
+
+You will also need to have the same configuration file for your model on each machine.
+
+On the main machine only, make sure the port you set as `main_process_port` is open for TCP and reachable by the other machines.
+
+All you have to do now is launch accelerate on each machine as you normally would; the processes will start once accelerate has been launched on every machine.
diff --git a/docs/multipack.md b/docs/multipack.md
new file mode 100644
index 0000000000..2a55148b2a
--- /dev/null
+++ b/docs/multipack.md
@@ -0,0 +1,51 @@
+# Multipack
+
+4k context, bsz = 4,
+each character represents 256 tokens
+X represents a padding token
+
+```
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+[[ A A A A A A A A A A A ]
+ B B B B B B ]
+ C C C C C C C ]
+ D D D D ]]
+
+[[ E E E E E E E E ]
+ [ F F F F ]
+ [ G G G ]
+ [ H H H H ]]
+
+[[ I I I ]
+ [ J J J ]
+ [ K K K K K]
+ [ L L L ]]
+```
+
+After padding to the longest input in each step:
+```
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+[[ A A A A A A A A A A A ]
+ B B B B B B X X X X X X ]
+ C C C C C C C X X X X ]
+ D D D D X X X X X X X ]]
+
+[[ E E E E E E E E ]
+ [ F F F F X X X X ]
+ [ G G G X X X X X ]
+ [ H H H H X X X X ]]
+
+[[ I I I X X ]
+ [ J J J X X ]
+ [ K K K K K ]
+ [ L L L X X ]]
+```
+
+With packing (note it's the same effective number of tokens per step, but a true bsz of 1):
+```
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+[[ A A A A A A A A A A A B B B B B
+ B C C C C C C C D D D D E E E E
+ E E E E F F F F F G G G H H H H
+ I I I J J J J K K K K K L L L X ]]
+```
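+
+The packing itself is essentially bin packing: sample lengths are greedily placed into rows of the context length. Below is a toy sketch of that bookkeeping, using roughly the sample lengths from the diagrams above. It is an illustration only, not axolotl's actual multipack sampler.
+
+```python
+# Greedily pack variable-length samples into rows of CONTEXT_LEN tokens.
+CONTEXT_LEN = 16                       # 16 "characters" above = 4k tokens at 256 tokens each
+sample_lens = [11, 6, 7, 4, 8, 4, 3, 4, 3, 3, 5, 3]  # lengths of samples A..L
+
+bins = []                              # each bin becomes one packed training row
+for length in sample_lens:
+    for packed in bins:
+        if sum(packed) + length <= CONTEXT_LEN:
+            packed.append(length)      # fits: append to an existing row
+            break
+    else:
+        bins.append([length])          # otherwise start a new row
+
+used = sum(sum(packed) for packed in bins)
+print(f"{len(bins)} packed rows, {used} real tokens, "
+      f"{len(bins) * CONTEXT_LEN - used} padding tokens")
+```
+
+For these lengths the sketch packs all twelve samples into 4 rows of 16 with only 3 padding tokens, compared to the 12 padded rows in the unpacked layout above.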
diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml
index 0e8a54bdaf..bab3d04d51 100644
--- a/examples/cerebras/qlora.yml
+++ b/examples/cerebras/qlora.yml
@@ -7,7 +7,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml
index 3faa70f41a..74aab34f97 100644
--- a/examples/code-llama/13b/lora.yml
+++ b/examples/code-llama/13b/lora.yml
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
diff --git a/examples/code-llama/13b/qlora.yml b/examples/code-llama/13b/qlora.yml
index f70ac8a628..a6b941089f 100644
--- a/examples/code-llama/13b/qlora.yml
+++ b/examples/code-llama/13b/qlora.yml
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out
diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml
index 4a62297d09..9623c9c2cb 100644
--- a/examples/code-llama/34b/lora.yml
+++ b/examples/code-llama/34b/lora.yml
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
diff --git a/examples/code-llama/34b/qlora.yml b/examples/code-llama/34b/qlora.yml
index 14c24a9708..6a980e4076 100644
--- a/examples/code-llama/34b/qlora.yml
+++ b/examples/code-llama/34b/qlora.yml
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out
diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml
index 39bc5d1b23..4546d737b5 100644
--- a/examples/code-llama/7b/lora.yml
+++ b/examples/code-llama/7b/lora.yml
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
diff --git a/examples/code-llama/7b/qlora.yml b/examples/code-llama/7b/qlora.yml
index bb2d2a8976..0b6d1c02dd 100644
--- a/examples/code-llama/7b/qlora.yml
+++ b/examples/code-llama/7b/qlora.yml
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out
diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml
index 0743e9145c..79b7b62a35 100644
--- a/examples/falcon/config-7b-lora.yml
+++ b/examples/falcon/config-7b-lora.yml
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false
@@ -11,7 +12,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
adapter: lora
peft_model_dir:
diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml
index 7ae69a95dc..8fa914820d 100644
--- a/examples/falcon/config-7b-qlora.yml
+++ b/examples/falcon/config-7b-qlora.yml
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
@@ -17,7 +18,7 @@ datasets:
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml
index 06af7c0978..adaea40eda 100644
--- a/examples/falcon/config-7b.yml
+++ b/examples/falcon/config-7b.yml
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false
@@ -11,7 +12,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
adapter:
peft_model_dir:
diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml
index 1887a3f279..b1b575155f 100644
--- a/examples/gptj/qlora.yml
+++ b/examples/gptj/qlora.yml
@@ -7,7 +7,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
diff --git a/examples/jeopardy-bot/config.yml b/examples/jeopardy-bot/config.yml
index 3b9e926798..946308d363 100644
--- a/examples/jeopardy-bot/config.yml
+++ b/examples/jeopardy-bot/config.yml
@@ -6,7 +6,7 @@ load_in_8bit: false
datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index c1fba92ed0..48ac43a96e 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -15,7 +15,7 @@ hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
adapter: lora
peft_model_dir:
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index b1981ab933..e196183ca9 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -1,5 +1,5 @@
-base_model: meta-llama/Llama-2-7b-hf
-base_model_config: meta-llama/Llama-2-7b-hf
+base_model: NousResearch/Llama-2-7b-hf
+base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
@@ -56,7 +56,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
-eval_table_size: 5
+eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index 0c1f89d469..bdf40e1a8c 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -1,5 +1,5 @@
-base_model: meta-llama/Llama-2-7b-hf
-base_model_config: meta-llama/Llama-2-7b-hf
+base_model: NousResearch/Llama-2-7b-hf
+base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./qlora-out
@@ -58,7 +58,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
-eval_table_size: 5
+eval_table_size:
save_steps:
debug:
deepspeed:
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index 84c0dc0390..be3b38efc5 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -1,5 +1,5 @@
-base_model: meta-llama/Llama-2-7b-hf
-base_model_config: meta-llama/Llama-2-7b-hf
+base_model: NousResearch/Llama-2-7b-hf
+base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -11,7 +11,7 @@ strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./relora-out
diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml
index db7ec5e8d5..3526333b10 100644
--- a/examples/llama-2/tiny-llama.yml
+++ b/examples/llama-2/tiny-llama.yml
@@ -12,7 +12,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
output_dir: ./lora-out
@@ -56,7 +56,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
-eval_table_size: 5
+eval_table_size:
save_steps:
debug:
deepspeed:
diff --git a/examples/mistral/README.md b/examples/mistral/README.md
new file mode 100644
index 0000000000..d1efb2cab3
--- /dev/null
+++ b/examples/mistral/README.md
@@ -0,0 +1,12 @@
+**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing notable performance across a variety of benchmarks.
+
+Fine-tune:
+```shell
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml
+```
+
+If you run into CUDA OOM, use deepspeed with config zero2.json:
+```shell
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
+```
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
new file mode 100644
index 0000000000..2a4498a119
--- /dev/null
+++ b/examples/mistral/config.yml
@@ -0,0 +1,62 @@
+base_model: mistralai/Mistral-7B-v0.1
+base_model_config: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+is_mistral_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path:
+val_set_size: 0.01
+output_dir: ./out
+
+sequence_len: 8192
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.000005
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
new file mode 100644
index 0000000000..09639d006a
--- /dev/null
+++ b/examples/mistral/qlora.yml
@@ -0,0 +1,79 @@
+base_model: mistralai/Mistral-7B-v0.1
+base_model_config: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+is_mistral_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./qlora-out
+
+adapter: qlora
+peft_model_dir:
+
+sequence_len: 8192
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_target_modules:
+ - gate_proj
+ - down_proj
+ - up_proj
+ - q_proj
+ - v_proj
+ - k_proj
+ - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml
index 23442fada8..b57a83cd6c 100644
--- a/examples/mpt-7b/config.yml
+++ b/examples/mpt-7b/config.yml
@@ -6,7 +6,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
diff --git a/examples/openllama-3b/config.yml b/examples/openllama-3b/config.yml
index acc9041c2f..1f4338ca3b 100644
--- a/examples/openllama-3b/config.yml
+++ b/examples/openllama-3b/config.yml
@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml
index 64a7f572b5..a59b5883f8 100644
--- a/examples/openllama-3b/lora.yml
+++ b/examples/openllama-3b/lora.yml
@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.02
adapter: lora
peft_model_dir:
diff --git a/examples/openllama-3b/qlora.yml b/examples/openllama-3b/qlora.yml
index c95a76c8a3..9ebd774fbf 100644
--- a/examples/openllama-3b/qlora.yml
+++ b/examples/openllama-3b/qlora.yml
@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml
index d64ce2f945..30052252b3 100644
--- a/examples/phi/phi-ft.yml
+++ b/examples/phi/phi-ft.yml
@@ -13,7 +13,7 @@ datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml
index 9f47fe6598..d3116e2079 100644
--- a/examples/phi/phi-qlora.yml
+++ b/examples/phi/phi-qlora.yml
@@ -13,7 +13,7 @@ datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
diff --git a/examples/pythia-12b/config.yml b/examples/pythia-12b/config.yml
index 5da6af3f4a..30040efb0d 100644
--- a/examples/pythia-12b/config.yml
+++ b/examples/pythia-12b/config.yml
@@ -10,7 +10,7 @@ device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.05
adapter:
peft_model_dir:
diff --git a/examples/pythia/lora.yml b/examples/pythia/lora.yml
index ea34655ea3..fe20da5638 100644
--- a/examples/pythia/lora.yml
+++ b/examples/pythia/lora.yml
@@ -4,7 +4,7 @@ load_in_8bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.05
adapter: lora
peft_model_dir:
@@ -28,8 +28,8 @@ num_epochs: 3
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
-bf16: True
-tf32: True
+bf16: true
+tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml
index e918c8d75d..3dc3a16744 100644
--- a/examples/redpajama/config-3b.yml
+++ b/examples/redpajama/config-3b.yml
@@ -7,7 +7,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
diff --git a/examples/replit-3b/config-lora.yml b/examples/replit-3b/config-lora.yml
index 426fea9c3b..09174e6ba4 100644
--- a/examples/replit-3b/config-lora.yml
+++ b/examples/replit-3b/config-lora.yml
@@ -5,7 +5,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.05
adapter: lora
peft_model_dir:
diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/xgen-7b/xgen-7b-8k-qlora.yml
index 5ccee8171c..1fb61118d0 100644
--- a/examples/xgen-7b/xgen-7b-8k-qlora.yml
+++ b/examples/xgen-7b/xgen-7b-8k-qlora.yml
@@ -16,7 +16,7 @@ datasets:
data_files:
- openassistant_best_replies_train.jsonl
type: "completion"
-dataset_prepared_path: last_run_prepared
+dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
diff --git a/image/sticker_fixed.png b/image/sticker_fixed.png
new file mode 100644
index 0000000000..0f5c67a192
Binary files /dev/null and b/image/sticker_fixed.png differ
diff --git a/requirements.txt b/requirements.txt
index 4ef9f5fd2d..1092ebddfb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,19 +4,19 @@ torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
+deepspeed
addict
-evaluate
fire
PyYAML>=6.0
datasets
-flash-attn>=2.2.1
+flash-attn>=2.3.0
sentencepiece
wandb
einops
-xformers
+xformers>=0.0.22
optimum
hf_transfer
colorama
@@ -30,4 +30,4 @@ scipy
scikit-learn==1.2.2
pynvml
art
-wandb
+fschat==0.2.29
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 7b6751e31c..118a97b844 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -7,6 +7,7 @@
from axolotl.cli import (
check_accelerate_default_config,
+ check_user_token,
do_inference,
do_merge_lora,
load_cfg,
@@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
+ check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
diff --git a/setup.py b/setup.py
index fca5088da1..e3ee54350b 100644
--- a/setup.py
+++ b/setup.py
@@ -13,9 +13,22 @@ def parse_requirements():
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
- elif "flash-attn" not in line and line and line[0] != "#":
+ elif (
+ "flash-attn" not in line
+ and "deepspeed" not in line
+ and line
+ and line[0] != "#"
+ ):
# Handle standard packages
_install_requires.append(line)
+
+ # TODO(wing) remove once xformers release supports torch 2.1.0
+ if "torch==2.1.0" in _install_requires:
+ _install_requires.pop(_install_requires.index("xformers>=0.0.22"))
+ _install_requires.append(
+ "xformers @ git+https://github.com/facebookresearch/xformers.git@main"
+ )
+
return _install_requires, _dependency_links
@@ -33,9 +46,9 @@ def parse_requirements():
dependency_links=dependency_links,
extras_require={
"flash-attn": [
- "flash-attn>=2.2.1",
+ "flash-attn>=2.3.0",
],
- "extras": [
+ "deepspeed": [
"deepspeed",
],
},
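For context (not part of the patch itself): the filter above keeps `flash-attn` and `deepspeed` out of `install_requires` so they are only pulled in via the matching extras. A simplified, self-contained sketch of that filtering idea, using made-up input lines:

```python
# Illustrative sketch only, not the project's setup.py code.
def filter_requirements(lines):
    install_requires = []
    for line in (raw.strip() for raw in lines):
        if line and line[0] != "#" and "flash-attn" not in line and "deepspeed" not in line:
            install_requires.append(line)
    return install_requires

print(filter_requirements(["torch==2.0.1", "flash-attn>=2.3.0", "deepspeed", "# comment"]))
# ['torch==2.0.1']
```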
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 90e1d508b0..07a6209e46 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -14,6 +14,8 @@
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
+from huggingface_hub import HfApi
+from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
@@ -49,7 +51,7 @@ def print_axolotl_text_art(suffix=None):
def get_multi_line_input() -> Optional[str]:
- print("Give me an instruction (Ctrl + D to finish): ")
+ print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
@@ -192,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
+ cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
@@ -247,3 +250,16 @@ def check_accelerate_default_config():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
+
+
+def check_user_token():
+ # Verify if token is valid
+ api = HfApi()
+ try:
+ user_info = api.whoami()
+ return bool(user_info)
+ except LocalTokenNotFoundError:
+ LOG.warning(
+ "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
+ )
+ return False
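For reference, a small usage sketch (assuming `huggingface_hub` is installed and its stored-token behaviour is unchanged) of the `whoami()` probe that the new `check_user_token()` relies on:

```python
# Standalone sketch mirroring check_user_token() above.
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError

def has_valid_hf_token() -> bool:
    try:
        # whoami() uses the locally stored token and raises if none is configured
        return bool(HfApi().whoami())
    except LocalTokenNotFoundError:
        return False

if not has_valid_hf_token():
    print("Run `huggingface-cli login` before using gated models or datasets.")
```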
diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py
index f3daac83dd..91405d8c66 100644
--- a/src/axolotl/cli/inference.py
+++ b/src/axolotl/cli/inference.py
@@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
+ parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 72a9250c8d..b49cbc6b60 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -1,30 +1,45 @@
"""
CLI to run training on a model
"""
+import logging
from pathlib import Path
import fire
import transformers
+from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
+ check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
+from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train
+LOG = logging.getLogger("axolotl.cli.train")
+
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
+ check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
+ if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
+ msg = (
+ Fore.RED
+ + "--prepare_ds_only called without dataset_prepared_path set."
+ + Fore.RESET
+ )
+ LOG.warning(msg)
+ parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
diff --git a/src/axolotl/common/const.py b/src/axolotl/common/const.py
new file mode 100644
index 0000000000..fd34ad4694
--- /dev/null
+++ b/src/axolotl/common/const.py
@@ -0,0 +1,5 @@
+"""
+Various shared constants
+"""
+
+DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
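A short illustrative sketch (the helper name is a placeholder, not project code) of how this constant backs the `--prepare_ds_only` fallback wired up in `cli/train.py` above:

```python
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH

def resolve_prepared_path(configured_path, prepare_ds_only):
    # mirrors the CLI behaviour: warn and fall back when no path is configured
    if prepare_ds_only and not configured_path:
        return DEFAULT_DATASET_PREPARED_PATH  # "last_run_prepared"
    return configured_path

print(resolve_prepared_path(None, prepare_ds_only=True))  # last_run_prepared
```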
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 75d8432da8..32b2e0cc2d 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
- prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
+ prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
"""
@@ -38,10 +38,15 @@ def __init__( # pylint: disable=super-init-not-called
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
+ map_kwargs = {}
+ if self.prompt_tokenizer.supports_batched:
+ map_kwargs["batched"] = True
+ map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
+ **map_kwargs,
)
@@ -50,7 +55,7 @@ class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
- tokenizer (Tokenizer): The processor used for proccessing the data.
+ tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""
diff --git a/src/axolotl/models/phi/modeling_mixformer_sequential.py b/src/axolotl/models/phi/modeling_mixformer_sequential.py
index 27bede5f01..fd2ec054c5 100644
--- a/src/axolotl/models/phi/modeling_mixformer_sequential.py
+++ b/src/axolotl/models/phi/modeling_mixformer_sequential.py
@@ -711,12 +711,8 @@ def __init__(
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.block_idx = block_idx
- self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
- mlp_cls = mlp.pop("mlp_cls")
- if mlp_cls == "fused_mlp":
- self.mlp = FusedMLP(config=config, **mlp)
- else:
- self.mlp = MLP(config=config, **mlp)
+ self.mixer = MHA(config, layer_idx=block_idx)
+ self.mlp = MLP(config)
def forward(
self,
diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
index be5a705595..1275906804 100644
--- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
@@ -7,6 +7,7 @@
from typing import Optional, Tuple
import torch
+from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
- AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+ with init_empty_weights():
+ AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)
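The `init_empty_weights()` wrapper means the remote module gets imported and registered without materialising the 3B parameters. A tiny standalone sketch of the effect (a toy `nn.Linear` stands in for the BTLM model):

```python
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.Linear(4096, 4096)  # parameters land on the meta device

print(layer.weight.device)  # meta -> no real memory was allocated
```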
diff --git a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py b/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
deleted file mode 100644
index ed11c55234..0000000000
--- a/src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""
-Flash Attention monkey patch for Falcon
-
-copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
-"""
-
-from typing import Optional, Tuple
-
-import torch
-import transformers
-from flash_attn import flash_attn_func
-
-
-def forward(
- self,
- hidden_states: torch.Tensor,
- alibi: Optional[torch.Tensor],
- attention_mask: torch.Tensor, # pylint: disable=unused-argument
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
- use_cache: bool = False,
- output_attentions: bool = False, # pylint: disable=unused-argument
-):
- fused_qkv = self.query_key_value(
- hidden_states
- ) # [batch_size, seq_length, 3 x hidden_size]
- num_kv_heads = (
- self.num_heads if self.new_decoder_architecture else self.num_kv_heads
- )
- # 3 x [batch_size, seq_length, num_heads, head_dim]
- (
- query_layer,
- key_layer,
- value_layer,
- ) = self._split_heads( # pylint: disable=protected-access
- fused_qkv
- )
-
- batch_size, query_length, _, _ = query_layer.shape
-
- query_layer = query_layer.transpose(1, 2).reshape(
- batch_size * self.num_heads, query_length, self.head_dim
- )
- key_layer = key_layer.transpose(1, 2).reshape(
- batch_size * num_kv_heads,
- query_length,
- self.head_dim,
- )
- value_layer = value_layer.transpose(1, 2).reshape(
- batch_size * num_kv_heads, query_length, self.head_dim
- )
-
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
-
- if layer_past is not None:
- past_key, past_value = layer_past
- # concatenate along seq_length dimension:
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
- # - value: [batch_size * self.num_heads, kv_length, head_dim]
- key_layer = torch.cat((past_key, key_layer), dim=1)
- value_layer = torch.cat((past_value, value_layer), dim=1)
-
- # unused
- # _, kv_length, _ = key_layer.shape
- if use_cache:
- present = (key_layer, value_layer)
- else:
- present = None
- # unused
- # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
- query_layer_ = (
- query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
- .transpose(1, 2)
- .to(torch.bfloat16)
- )
- key_layer_ = (
- key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
- .transpose(1, 2)
- .to(torch.bfloat16)
- )
- value_layer_ = (
- value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
- .transpose(1, 2)
- .to(torch.bfloat16)
- )
-
- if alibi is not None:
- raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
-
- # below output will have shape (batch_size, seqlen, nheads, headdim)
- attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
- attn_output = attn_output.reshape(
- batch_size, query_length, self.num_heads * self.head_dim
- )
- output_tensor = self.dense(attn_output)
- return output_tensor, present
-
-
-def replace_falcon_attn_with_flash_attn():
- transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
new file mode 100644
index 0000000000..19313fb7e2
--- /dev/null
+++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -0,0 +1,174 @@
+"""
+monkeypatch to add a get_turns method
+"""
+
+import logging
+from typing import Generator, Tuple
+
+from fastchat.conversation import SeparatorStyle
+
+LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
+
+
+def get_prompt(self) -> str:
+ ret = ""
+ for role, msg in self.get_turns():
+ ret += role + msg
+ return ret
+
+
+def get_turns( # pylint: disable=too-many-return-statements
+ self,
+) -> Generator[Tuple[str, str], None, None]:
+ """Get the prompt for generation."""
+ system_prompt = self.system_template.format(system_message=self.system_message)
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+ yield "", system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ yield role + ": ", message + self.sep
+ else:
+ yield role + ":", ""
+ return
+ if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+ seps = [self.sep, self.sep2]
+ yield "", system_prompt + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ yield role + ": ", message + seps[i % 2]
+ else:
+ yield role + ":", ""
+ return
+ if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+ yield "", system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ yield role + ": ", message + self.sep
+ else:
+ yield role + ": ", "" # must be end with a space
+ return
+ if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+ yield "", "" if system_prompt == "" else system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ yield role + "\n", message + self.sep
+ else:
+ yield role + "\n", ""
+ return
+ if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+ yield "", system_prompt
+ for role, message in self.messages:
+ if message:
+ yield role, message + self.sep
+ else:
+ yield role, ""
+ return
+ if self.sep_style == SeparatorStyle.NO_COLON_TWO:
+ seps = [self.sep, self.sep2]
+ yield "", system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ yield role, message + seps[i % 2]
+ else:
+ yield role, ""
+ return
+ if self.sep_style == SeparatorStyle.RWKV:
+ yield "", system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ yield role + ": ", message.replace("\r\n", "\n").replace(
+ "\n\n", "\n"
+ ) + "\n\n"
+ else:
+ yield role + ":", ""
+ return
+ if self.sep_style == SeparatorStyle.LLAMA2:
+ seps = [self.sep, self.sep2]
+ if self.system_message:
+ yield "", system_prompt
+ else:
+ yield "", "[INST] "
+ for i, (role, message) in enumerate(self.messages[1:]):
+ if message:
+ yield role + " ", message + seps[i % 2]
+ else:
+ yield role, ""
+ return
+ if self.sep_style == SeparatorStyle.CHATGLM:
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+ round_add_n = 1 if self.name == "chatglm2" else 0
+ if system_prompt:
+ yield "", system_prompt + self.sep
+
+ for i, (role, message) in enumerate(self.messages):
+ if i % 2 == 0:
+ yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
+
+ if message:
+ yield f"{role}:", f"{message}{self.sep}"
+ else:
+ yield f"{role}:", ""
+ return
+ if self.sep_style == SeparatorStyle.CHATML:
+ yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
+ for role, message in self.messages:
+ if message:
+ yield role + "\n", message + self.sep + "\n"
+ else:
+ yield role + "\n", ""
+ return
+ if self.sep_style == SeparatorStyle.CHATINTERN:
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+ seps = [self.sep, self.sep2]
+ yield "", system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ prefix = "" if i % 2 == 0 else ""
+ if message:
+ yield prefix + role + ":", message + seps[i % 2] + "\n"
+ else:
+ yield role + ":", ""
+ return
+ if self.sep_style == SeparatorStyle.DOLLY:
+ seps = [self.sep, self.sep2]
+ yield "", system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ suffix = "\n\n" if i % 2 == 1 else ""
+ yield role + ":\n", message + seps[i % 2] + suffix
+ else:
+ yield role + ":\n", ""
+ return
+ if self.sep_style == SeparatorStyle.PHOENIX:
+ yield "", system_prompt
+ for role, message in self.messages:
+ if message:
+ yield role + ": ", "" + message + ""
+ else:
+ yield role + ": " + "", ""
+ return
+ if self.sep_style == SeparatorStyle.ROBIN:
+ yield "", system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ yield role + ":\n", message + self.sep
+ else:
+ yield role + ":\n", ""
+ return
+ if self.sep_style == SeparatorStyle.FALCON_CHAT:
+ if self.system_message:
+ yield "", system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ yield role + ": ", message + self.sep
+ else:
+ yield role + ":", ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+
+def add_get_turns_to_conversation():
+ import fastchat.conversation
+
+ fastchat.conversation.Conversation.get_turns = get_turns
+ fastchat.conversation.Conversation.get_prompt = get_prompt
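A hedged usage sketch of this monkeypatch (assumes `fschat` is installed and that the `vicuna_v1.1` template name still exists upstream):

```python
from fastchat.conversation import get_conv_template

from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,
)

add_get_turns_to_conversation()

conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated

# get_turns() yields (role_prefix, message) pairs instead of one concatenated string
for role, message in conv.get_turns():
    print(repr(role), repr(message))
```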
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 0f2fbbd54f..262c10d90a 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -38,7 +38,11 @@
LOG = logging.getLogger("axolotl")
-def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
+def replace_llama_attn_with_flash_attn(
+ packed: Optional[bool] = False,
+ cross_entropy: Optional[bool] = False,
+ rms_norm: Optional[bool] = False,
+):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
@@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
llama_model_forward
)
- try:
- from flash_attn.losses.cross_entropy import CrossEntropyLoss
+ # only patch the fused CrossEntropyLoss when explicitly enabled
+ if cross_entropy:
+ try:
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
- LOG.info("patching with flash_attn.losses.cross_entropy")
- transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
- CrossEntropyLoss, inplace_backward=True
- )
- except ImportError:
- LOG.info(
- "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
- )
+ LOG.info("patching with flash_attn.losses.cross_entropy")
+ transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+ CrossEntropyLoss, inplace_backward=True
+ )
+ except ImportError:
+ LOG.info(
+ "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+ )
- try:
- from flash_attn.ops.rms_norm import RMSNorm
+ # only patch the fused RMSNorm when explicitly enabled
+ if rms_norm:
+ try:
+ from flash_attn.ops.rms_norm import RMSNorm
- class LlamaRMSNorm(RMSNorm):
- """Patched LLamaRMSNorm"""
+ class LlamaRMSNorm(RMSNorm):
+ """Patched LLamaRMSNorm"""
- def __init__(self, hidden_size, eps=1e-6):
- super().__init__(hidden_size, eps=eps)
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__(hidden_size, eps=eps)
- LOG.info("patching with flash_attn.ops.rms_norm")
- transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
- except ImportError:
- LOG.info(
- "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
- )
+ LOG.info("patching with flash_attn.ops.rms_norm")
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+ except ImportError:
+ LOG.info(
+ "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+ )
# Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -99,6 +107,7 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -489,6 +498,13 @@ def llama_model_forward(
dtype=torch.bool,
device=inputs_embeds.device,
)
+ padding_mask = None
+ else:
+ if 0 in attention_mask:
+ padding_mask = attention_mask
+ else:
+ padding_mask = None
+
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
@@ -524,7 +540,9 @@ def llama_model_forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
- return module(*inputs)
+ return module(
+ *inputs,
+ )
return custom_forward
@@ -533,9 +551,10 @@ def custom_forward(*inputs):
hidden_states,
attention_mask,
position_ids,
- None,
+ past_key_value,
output_attentions,
None,
+ padding_mask,
cu_seqlens,
max_seqlen,
)
@@ -547,6 +566,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
+ padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
@@ -597,6 +617,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
@@ -629,6 +650,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
+ padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
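A usage sketch of the widened entry point; the new flags are opt-in, and the patch has to run before the model is instantiated so the replaced classes are picked up (flag semantics as read from the diff above):

```python
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn(
    packed=True,         # sample-packing aware forward
    cross_entropy=True,  # fused flash-attn CrossEntropyLoss, if installed
    rms_norm=True,       # fused flash-attn RMSNorm, if installed
)
```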
diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
new file mode 100644
index 0000000000..654ca3ba82
--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
+import torch
+import transformers.models.llama.modeling_llama
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+ # pylint: disable=duplicate-code
+ def noised_embed(orig_embed, noise_alpha, model):
+ def new_func(input_ids):
+ # during training, we add noise to the embedding
+ # during generation, we don't add noise to the embedding
+ if model.training:
+ embed_init = orig_embed(input_ids)
+ dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+ mag_norm = noise_alpha / torch.sqrt(dims)
+ return embed_init + torch.zeros_like(embed_init).uniform_(
+ -mag_norm, mag_norm
+ )
+ return orig_embed(input_ids)
+
+ return new_func
+
+ def post_init(orig_post_init):
+ def new_func(self):
+ orig_post_init(self)
+ self.embed_tokens.forward = noised_embed(
+ self.embed_tokens.forward, noise_alpha, self
+ )
+
+ return new_func
+
+ transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
+ transformers.models.llama.modeling_llama.LlamaModel.post_init
+ )
diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
new file mode 100644
index 0000000000..26b511d06e
--- /dev/null
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -0,0 +1,640 @@
+"""Flash attention monkey patch for mistral model"""
+# pylint: disable=duplicate-code
+
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+import transformers
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
+ flash_attn_kvpacked_func,
+ flash_attn_varlen_kvpacked_func,
+ flash_attn_varlen_qkvpacked_func,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.mistral.modeling_mistral import (
+ MistralAttention as OriginalMistralAttention,
+)
+from transformers.models.mistral.modeling_mistral import (
+ MistralDecoderLayer as OriginalMistralDecoderLayer,
+)
+from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
+LOG = logging.getLogger("axolotl.monkeypatch.mistral")
+
+
+def replace_mistral_attn_with_flash_attn(
+ packed: Optional[bool] = False,
+):
+ transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
+ flashattn_forward
+ )
+ if packed:
+ transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
+ MistralDecoderLayer
+ )
+ transformers.models.mistral.modeling_mistral.MistralModel.forward = (
+ mistral_model_forward
+ )
+
+
+@torch.jit.script
+def _make_sliding_window_causal_mask(
+ bsz: int,
+ tgt_len: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: int = 4096,
+):
+ """
+ Make causal mask used for sliding window attention
+ """
+ tensor = torch.full(
+ (tgt_len, tgt_len),
+ fill_value=1,
+ device=device,
+ )
+ mask = torch.tril(tensor, diagonal=0)
+ # make the mask banded to account for sliding window
+ # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
+ mask = torch.triu(mask, diagonal=-sliding_window + 1)
+ mask = torch.log(mask).to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat(
+ [
+ torch.zeros(
+ tgt_len, past_key_values_length, dtype=dtype, device=device
+ ),
+ mask,
+ ],
+ dim=-1,
+ )
+ return mask[None, None, :, :].expand(
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
+ )
+
+
+# Disable the transformation of the attention mask in MistralModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self,
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window,
+): # pylint: disable=unused-argument
+ # [bsz, seq_len]
+ if attention_mask is None:
+ return attention_mask
+
+ # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
+ # If attention_mask.shape[0] != 1, an error triggers after the eval loss step, but only when wandb is enabled.
+ if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
+ sliding_window_mask = _make_sliding_window_causal_mask(
+ bsz=input_shape[0],
+ tgt_len=input_shape[1],
+ dtype=inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=sliding_window,
+ )
+ attention_mask = attention_mask + sliding_window_mask
+ else:
+ LOG.info("skipping sliding window mask, not broadcastable with attention mask")
+
+ return attention_mask
+
+
+def flashattn_forward(
+ self: OriginalMistralAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ use_sliding_windows = (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if use_sliding_windows:
+ window_size = (self.config.sliding_window, self.config.sliding_window)
+ else:
+ window_size = (-1, -1)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ if (
+ hasattr(self.config, "sliding_window")
+ and kv_seq_len > self.config.sliding_window
+ ):
+ slicing_tokens = kv_seq_len - self.config.sliding_window
+
+ past_key = past_key_value[0]
+ past_value = past_key_value[1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ past_key_value = (past_key, past_value) if use_cache else None
+
+ if past_key_value is not None:
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if self.training:
+ # during training q,k,v always have same seqlen
+ assert key_states.shape == query_states.shape
+ is_causal = True
+ else:
+ # turn off FA causal mask after first inference autoregressive iteration
+ # only on first autoregressive step q,k,v have same seqlen
+ is_causal = key_states.shape == query_states.shape
+
+ if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
+ # special handling using sample packing
+ qkv = torch.stack(
+ [query_states, key_states, value_states], dim=2
+ ) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
+
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv,
+ cu_seqlens,
+ max_seqlen,
+ 0.0,
+ softmax_scale=None,
+ causal=True,
+ window_size=window_size,
+ )
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+ elif query_states.shape == key_states.shape:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
+ query_states,
+ key_states,
+ value_states,
+ qkvpacked=True,
+ # We have disabled _prepare_decoder_attention_mask in MistralModel
+ # the attention_mask should be the same as the key_padding_mask
+ key_padding_mask=attention_mask,
+ query_padding_mask=attention_mask[:, -query_states.size(1) :]
+ if attention_mask is not None
+ else None,
+ )
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ qkv_unpad,
+ cu_seqlens_q,
+ max_seqlen_q,
+ 0.0,
+ softmax_scale=None,
+ causal=is_causal,
+ window_size=window_size,
+ )
+ output = output_pad_fn(output_unpad)
+ else:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ if attention_mask is None or attention_mask.all().item():
+ output = flash_attn_kvpacked_func(
+ query_states,
+ torch.stack([key_states, value_states], 2),
+ causal=is_causal,
+ window_size=window_size,
+ )
+ else:
+ ( # pylint: disable=unbalanced-tuple-unpacking
+ q_unpad,
+ kv_unpad,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ _,
+ _,
+ output_pad_fn,
+ ) = generate_qkv(
+ query_states,
+ key_states,
+ value_states,
+ kvpacked=True,
+ key_padding_mask=attention_mask,
+ query_padding_mask=attention_mask[:, -query_states.size(1) :]
+ if attention_mask is not None
+ else None,
+ )
+ if q_unpad.dtype != kv_unpad.dtype:
+ kv_unpad = kv_unpad.to(q_unpad.dtype)
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q_unpad,
+ kv_unpad,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ 0.0,
+ softmax_scale=None,
+ causal=is_causal,
+ window_size=window_size,
+ )
+ output = output_pad_fn(output_unpad)
+
+ attn_output = output
+ if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+ attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
+def generate_qkv(
+ q,
+ k,
+ v,
+ query_padding_mask=None,
+ key_padding_mask=None,
+ kvpacked=False,
+ qkvpacked=False,
+): # pylint: disable=invalid-name,unnecessary-lambda-assignment
+ """
+ Arguments:
+ q: (batch_size, seqlen_q, nheads, d)
+ k: (batch_size, seqlen_k, nheads_k, d)
+ v: (batch_size, seqlen_k, nheads_k, d)
+ query_padding_mask: (batch_size, seqlen), bool
+ key_padding_mask: (batch_size, seqlen), bool
+ """
+ assert not (kvpacked and qkvpacked)
+ batch_size, seqlen_q, nheads, d = q.shape
+ _, seqlen_k, nheads_k, _ = k.shape
+ assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+ assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+
+ if query_padding_mask is not None:
+ q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
+ q, query_padding_mask
+ )
+
+ output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
+ output_unpad, indices_q, batch_size, seqlen_q
+ )
+
+ else:
+ q_unpad = rearrange(q, "b s h d -> (b s) h d")
+ cu_seqlens_q = torch.arange(
+ 0,
+ (batch_size + 1) * seqlen_q,
+ step=seqlen_q,
+ dtype=torch.int32,
+ device=q_unpad.device,
+ )
+ max_seqlen_q = seqlen_q
+
+ output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
+ output_unpad, "(b s) h d -> b s h d", b=batch_size
+ )
+
+ if key_padding_mask is not None:
+ k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
+ v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+ else:
+ k_unpad = rearrange(k, "b s h d -> (b s) h d")
+ v_unpad = rearrange(v, "b s h d -> (b s) h d")
+ cu_seqlens_k = torch.arange(
+ 0,
+ (batch_size + 1) * seqlen_k,
+ step=seqlen_k,
+ dtype=torch.int32,
+ device=k_unpad.device,
+ )
+ max_seqlen_k = seqlen_k
+
+ if qkvpacked:
+ assert nheads == nheads_k
+ qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+ qkv = torch.stack([q, k, v], dim=2)
+ return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
+
+ if kvpacked:
+ kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+ kv = torch.stack([k, v], dim=2)
+ return (
+ q_unpad,
+ kv_unpad,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ q,
+ kv,
+ output_pad_fn,
+ )
+
+ return (
+ q_unpad,
+ k_unpad,
+ v_unpad,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ q,
+ k,
+ v,
+ output_pad_fn,
+ )
+
+
+def mistral_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+ )
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError(
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ cu_seqlens = None
+ max_seqlen = None
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+ cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
+ cu_seqlens = cu_seqlens.squeeze()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = (
+ self._prepare_decoder_attention_mask( # pylint: disable=protected-access
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ LOG.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ None,
+ cu_seqlens,
+ max_seqlen,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class MistralDecoderLayer(OriginalMistralDecoderLayer):
+ """
+ patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[torch.Tensor] = None,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
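A standalone sketch of the banded mask built by `_make_sliding_window_causal_mask` above (here with `tgt_len=5`, `sliding_window=3`): 0.0 where attention is allowed, `-inf` elsewhere.

```python
import torch

tgt_len, sliding_window = 5, 3
mask = torch.tril(torch.full((tgt_len, tgt_len), 1.0))  # causal lower triangle
mask = torch.triu(mask, diagonal=-sliding_window + 1)   # keep only the last 3 keys
print(torch.log(mask))
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [-inf, 0., 0., 0., -inf],
#         [-inf, -inf, 0., 0., 0.]])
```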
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
new file mode 100644
index 0000000000..ed5f259650
--- /dev/null
+++ b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
+import torch
+import transformers.models.mistral.modeling_mistral
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+ # pylint: disable=duplicate-code
+ def noised_embed(orig_embed, noise_alpha, model):
+ def new_func(input_ids):
+ # during training, we add noise to the embedding
+ # during generation, we don't add noise to the embedding
+ if model.training:
+ embed_init = orig_embed(input_ids)
+ dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+ mag_norm = noise_alpha / torch.sqrt(dims)
+ return embed_init + torch.zeros_like(embed_init).uniform_(
+ -mag_norm, mag_norm
+ )
+ return orig_embed(input_ids)
+
+ return new_func
+
+ def post_init(orig_post_init):
+ def new_func(self):
+ orig_post_init(self)
+ self.embed_tokens.forward = noised_embed(
+ self.embed_tokens.forward, noise_alpha, self
+ )
+
+ return new_func
+
+ transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
+ transformers.models.mistral.modeling_mistral.MistralModel.post_init
+ )
diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
new file mode 100644
index 0000000000..0269f90157
--- /dev/null
+++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
@@ -0,0 +1,415 @@
+# coding=utf-8
+# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This code is based off the following work:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+""" PyTorch StableLM Epoch model. """
+import importlib
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from accelerate import init_empty_weights
+from einops import rearrange
+from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
+ flash_attn_varlen_qkvpacked_func,
+)
+from torch import nn
+from transformers import AutoConfig, AutoModelForCausalLM
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils import logging
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
+logger = logging.get_logger(__name__)
+
+
+def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
+ # this is a wonky hack to get the remotely loaded module
+ model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+ # we need to load the model here in order for modeling_stablelm_epoch to be available
+ with init_empty_weights():
+ AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+ module_name = model_config.__class__.__module__.replace(
+ ".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
+ )
+ modeling_stablelm = importlib.import_module(module_name)
+ modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
+ flashattn_attn
+ )
+ modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
+ stablelm_model_forward
+ )
+ modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
+ decoder_layer_forward
+ )
+
+
+def rotate_half(x: torch.Tensor):
+ """Rotates half the hidden dims of the input."""
+ # pylint: disable=invalid-name
+ x1, x2 = torch.chunk(x, 2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ # pylint: disable=invalid-name
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def flashattn_attn(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ position_ids: torch.LongTensor,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
+ use_cache: Optional[bool] = False,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+
+ query_rot = query_states[..., : self.rotary_ndims]
+ query_pass = query_states[..., self.rotary_ndims :]
+ key_rot = key_states[..., : self.rotary_ndims]
+ key_pass = key_states[..., self.rotary_ndims :]
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_rot, key_rot, cos, sin, position_ids
+ )
+
+ # [batch_size, num_heads, seq_len, head_dim]
+ query_states = torch.cat((query_states, query_pass), dim=-1)
+ key_states = torch.cat((key_states, key_pass), dim=-1)
+
+ if past_key_value is not None:
+ # Reuse k, v, self_attention
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # Repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
+ # special handling using sample packing
+ qkv = torch.stack(
+ [query_states, key_states, value_states], dim=2
+ ) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
+ softmax_scale = None
+
+ output = flash_attn_varlen_qkvpacked_func(
+ qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
+ )
+
+ attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+ attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
+ else:
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # Upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ # Merge heads
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ # Final linear projection
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+def decoder_layer_forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[torch.Tensor] = None,
+) -> Union[
+ Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
+]:
+ # pylint: disable=duplicate-code
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+def stablelm_model_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ # pylint: disable=duplicate-code
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # Retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+ )
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError(
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+ )
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ cu_seqlens = None
+ max_seqlen = None
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+ cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
+ cu_seqlens = cu_seqlens.squeeze()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # Embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = (
+ self._prepare_decoder_attention_mask( # pylint: disable=protected-access
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # Decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ None,
+ cu_seqlens,
+ max_seqlen,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # Add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 17fe69be7b..975fee889e 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -1,6 +1,6 @@
-"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
+"""Module for Alpaca prompt strategy classes"""
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
-def load(tokenizer, cfg):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ prompt_style = PromptStyle.CHAT.value
+ if ds_cfg and "conversation" in ds_cfg:
+ prompt_style = ds_cfg["conversation"]
+
return AlpacaPromptTokenizingStrategy(
- AlpacaPrompter(PromptStyle.CHAT.value),
+ AlpacaPrompter(prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py
index ee5b4cb3e8..3285e667cb 100644
--- a/src/axolotl/prompt_strategies/completion.py
+++ b/src/axolotl/prompt_strategies/completion.py
@@ -1,10 +1,81 @@
"""
Basic completion text
"""
-from typing import Any, Dict, Optional
+from collections import defaultdict
+from typing import Any, Dict, Generator, Optional, Tuple
-from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
-from axolotl.prompters import CompletionPrompter
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+
+
+class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+ """
+ Tokenizing strategy for Completion prompts.
+ """
+
+ _field: str = "text"
+
+ def __init__(self, *args, max_length=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ if max_length is not None:
+ self.max_length = max_length
+
+ @property
+ def supports_batched(self):
+ return True
+
+ @property
+ def field(self) -> str:
+ return self._field
+
+ @field.setter
+ def field(self, new_field: str):
+ self._field = new_field
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+ return (
+ prompt[self.field],
+ "",
+ "",
+ )
+
+ def tokenize_prompt(self, prompt):
+ res = defaultdict(lambda: [])
+ feature_names = list(prompt.keys())
+ for row in zip(*prompt.values()):
+ prompt_row = dict(zip(feature_names, row))
+ (
+ instruction,
+ _,
+ _,
+ ) = self.parse_instruction_fields(prompt_row)
+
+ full_prompt = self._build_full_prompt(instruction, None, None)
+ tokenized_full_prompt = self._tokenize(full_prompt)
+
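+ # chunk the tokenized result into slices of at most sequence_len tokens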
+ for key, val in tokenized_full_prompt.items():
+ for i in range(0, len(val), self.sequence_len):
+ res[key].append(val[i : i + self.sequence_len])
+
+ return dict(res)
+
+ def _build_full_prompt(
+ self, instruction, input, response
+ ): # pylint: disable=redefined-builtin
+ return next(iter(self.prompter.build_prompt(instruction, input, response)))
+
+
+class CompletionPrompter:
+ """
+ Prompter for completion
+ """
+
+ def build_prompt(
+ self,
+ instruction: str,
+ input=None, # pylint: disable=redefined-builtin, unused-argument
+ output=None, # pylint: disable=unused-argument
+ ) -> Generator[str, None, None]:
+ yield instruction
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
+ max_length=cfg.sequence_len * 64,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]
diff --git a/src/axolotl/prompt_strategies/context_qa.py b/src/axolotl/prompt_strategies/context_qa.py
index f7027c7e24..f87dd8b5cd 100644
--- a/src/axolotl/prompt_strategies/context_qa.py
+++ b/src/axolotl/prompt_strategies/context_qa.py
@@ -24,6 +24,15 @@ def load(tokenizer, cfg):
)
+def load_v2(tokenizer, cfg):
+ return ContextQaV2PromptTokenizingStrategy(
+ ContextV2Prompter(),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
+ )
+
+
class AlpacaContextPrompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
@@ -50,6 +59,38 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
)
+class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+ """
+ Tokenizing strategy that combines an in-context article with a question and answer
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+ return (
+ "Context: "
+ + prompt["context"]
+ + "\nQuestion: "
+ + prompt["question"]
+ + "\n",
+ "",
+ "Answer: " + prompt["answer"],
+ )
+
+
+class ContextV2Prompter(AlpacaPrompter):
+ """
+ Customized system prompt for concise QA
+ """
+
+ system_prompt = ""
+ system_no_input_prompt = ""
+
+ def match_prompt_style(self):
+ # pylint: disable=duplicate-code
+ self.turn_format = "{instruction}\n{input}"
+ self.turn_no_input_format = "{instruction}"
+ self.system_format = "{system}"
+
+
class AlpacaMissingInfoContextPromptTokenizingStrategy(
InstructionPromptTokenizingStrategy
):
diff --git a/src/axolotl/prompt_strategies/sharegpt_simple.py b/src/axolotl/prompt_strategies/sharegpt.py
similarity index 63%
rename from src/axolotl/prompt_strategies/sharegpt_simple.py
rename to src/axolotl/prompt_strategies/sharegpt.py
index bfe0d164bc..da36e778e0 100644
--- a/src/axolotl/prompt_strategies/sharegpt_simple.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,12 +1,35 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+from typing import Any, Dict, Optional
+
+from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import PromptStyle, ShareGPTPrompter
+from axolotl.prompters import ShareGPTPrompterV2
+
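+# register a ChatML conversation template with fastchat so datasets can select it via `conversation: chatml`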
+register_conv_template(
+ Conversation(
+ name="chatml",
+ system_template="<|im_start|>system\n{system_message}",
+ system_message="You are a helpful assistant.",
+ roles=["<|im_start|>user", "<|im_start|>assistant"],
+ sep_style=SeparatorStyle.CHATML,
+ sep="<|im_end|>\n",
+ )
+)
-def load(tokenizer, cfg):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+ conversation = (
+ ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+ )
+ field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+ field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return SimpleShareGPTPromptTokenizingStrategy(
- ShareGPTPrompter(PromptStyle.CHAT.value),
+ ShareGPTPrompterV2(
+ conversation=conversation,
+ role_key_model=field_model,
+ role_key_human=field_human,
+ ),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -15,7 +38,7 @@ def load(tokenizer, cfg):
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
- ShareGPTPrompter(PromptStyle.CHAT.value),
+ ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -24,7 +47,7 @@ def load_role(tokenizer, cfg):
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
- ShareGPTPrompter(PromptStyle.CHAT.value),
+ ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
diff --git a/src/axolotl/prompt_strategies/sharegpt_jokes.py b/src/axolotl/prompt_strategies/sharegpt_jokes.py
index ac424bf7c7..404302c81e 100644
--- a/src/axolotl/prompt_strategies/sharegpt_jokes.py
+++ b/src/axolotl/prompt_strategies/sharegpt_jokes.py
@@ -1,11 +1,11 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import PromptStyle, ShareGPTPrompter
+from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
- ShareGPTPrompter(PromptStyle.CHAT.value),
+ ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index f30d0e3832..4e30b81a71 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -2,12 +2,15 @@
import abc
import copy
-import functools
import logging
from typing import Dict, List, Tuple, Union
+from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer
+from axolotl.monkeypatch.fastchat_conversation_turns import (
+ add_get_turns_to_conversation,
+)
from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
@@ -18,6 +21,8 @@
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
+add_get_turns_to_conversation()
+
class InvalidDataException(Exception):
"""
@@ -41,43 +46,28 @@ def __init__(
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
self.sequence_len = sequence_len
+ self.max_length = sequence_len
@abc.abstractmethod
def tokenize_prompt(self, prompt):
pass
- @functools.lru_cache(maxsize=128)
- def _get_user_token(self):
- try:
- id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
- if isinstance(id_or_ids, (int,)):
- return id_or_ids
- except KeyError:
- pass
- return False
-
- @functools.lru_cache(maxsize=128)
- def _get_assistant_token(self):
- try:
- id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
- if isinstance(id_or_ids, (int,)):
- return id_or_ids
- except KeyError:
- pass
+ @property
+ def supports_batched(self):
return False
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
result: BatchEncoding
- if not prompt.strip():
+ if not prompt:
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
- max_length=self.sequence_len,
+ max_length=self.max_length,
padding=False,
return_tensors=None,
)
@@ -86,7 +76,7 @@ def _tokenize(
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
- and len(result["input_ids"]) < self.sequence_len
+ and len(result["input_ids"]) < self.max_length
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
@@ -247,46 +237,6 @@ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
)
-class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- """
- Tokenizing strategy for Completion prompts.
- """
-
- _field: str = "text"
-
- @property
- def field(self) -> str:
- return self._field
-
- @field.setter
- def field(self, new_field: str):
- self._field = new_field
-
- def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
- return (
- prompt[self.field],
- "",
- "",
- )
-
- def tokenize_prompt(self, prompt):
- (
- instruction,
- _,
- _,
- ) = self.parse_instruction_fields(prompt)
-
- full_prompt = self._build_full_prompt(instruction, None, None)
- tokenized_full_prompt = self._tokenize(full_prompt)
-
- return tokenized_full_prompt
-
- def _build_full_prompt(
- self, instruction, input, response
- ): # pylint: disable=redefined-builtin
- return next(iter(self.prompter.build_prompt(instruction, input, response)))
-
-
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for Reflection prompts.
@@ -385,51 +335,82 @@ def get_conversation_thread(self, prompt):
def tokenize_prompt(self, prompt):
result, current_len = tokenize_prompt_default()
- user_token = self._get_user_token()
- assistant_token = self._get_assistant_token()
+ conversation: Conversation = (
+ self.prompter._conversation.copy() # pylint: disable=protected-access
+ )
+
+ # support for custom roles from the dataset, only useful for vicuna style prompts/roles
+ role_remap = []
+ if (
+ conversation.name == "vicuna_v1.1"
+ and "roles" in prompt
+ and len(prompt["roles"]) >= 2
+ ):
+ role_remap = [
+ {"from": conversation.roles[0], "to": prompt["roles"][0]},
+ {"from": conversation.roles[1], "to": prompt["roles"][1]},
+ ]
+
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
- if part[0] == "USER:":
- part = part[0] + part[1] if not user_token else part[1]
+ if conversation.roles[0] in part[0]:
+ role = (
+ part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
+ if role_remap
+ else part[0]
+ )
+ turn = role + part[1]
# this is still the user query
+ if not part[1].strip():
+ LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
- part.strip(),
+ turn,
add_eos_token=False,
strip_bos_token=True,
)
- if user_token:
- res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
- elif part[0] == "ASSISTANT:":
+ elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
- part = part[0] + part[1] if not assistant_token else part[1]
- # this should be the assistent response, should end with an eos token
+ role = (
+ part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
+ if role_remap
+ else part[0]
+ )
+ turn = role + part[1]
+ # this should be the assistant response, should end with an eos token
+ if not part[1].strip():
+ LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
- part.strip(),
+ turn,
add_eos_token=True,
strip_bos_token=True,
)
- if assistant_token:
- res["input_ids"] = [
- assistant_token,
- *res["input_ids"],
- ]
+ role_res = self._tokenize(
+ role.rstrip(),
+ add_eos_token=False,
+ strip_bos_token=True,
+ )
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
- elif part[0] == "SYSTEM:":
- part = part[1] # Ignore the system role from preamble
+ len_role = len(role_res["input_ids"])
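+ # mask the tokenized assistant role prefix so loss is only computed on the response text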
+ labels[:len_role] = [IGNORE_TOKEN_ID] * min(
+ len_role, len(labels)
+ )
+ elif part[0] == "":
+ turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
- part.strip(), add_eos_token=False, strip_bos_token=False
+ turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
+ continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
@@ -444,22 +425,31 @@ def tokenize_prompt(self, prompt):
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
- result = self.tokenizer(
- prompt,
- truncation=True,
- max_length=self.sequence_len,
- padding=False,
- return_tensors=None,
- )
+ if not prompt.strip():
+ LOG.warning("Empty text requested for tokenization.")
+ result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+ else:
+ result = self.tokenizer(
+ prompt,
+ truncation=True,
+ max_length=self.sequence_len,
+ padding=False,
+ return_tensors=None,
+ )
if (
- result["input_ids"][-1] != self.tokenizer.eos_token_id
+ len(result["input_ids"]) > 0
+ and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
- if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+ if (
+ len(result["input_ids"]) > 0
+ and result["input_ids"][0] == self.tokenizer.bos_token_id
+ and strip_bos_token
+ ):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 5322a10182..7cd89886aa 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -1,9 +1,10 @@
"""Module containing prompters"""
-import dataclasses
import logging
-from enum import Enum, auto
-from typing import Generator, List, Optional, Tuple, Union
+from enum import Enum
+from typing import Generator, Optional, Union
+
+from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -135,20 +136,6 @@ def match_prompt_style(self):
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
-class CompletionPrompter:
- """
- Prompter for completion
- """
-
- def build_prompt(
- self,
- instruction: str,
- input=None, # pylint: disable=redefined-builtin, unused-argument
- output=None, # pylint: disable=unused-argument
- ) -> Generator[str, None, None]:
- yield instruction
-
-
class GPTeacherPrompter(AlpacaPrompter):
"""
Prompter for GPTeacher
@@ -228,53 +215,6 @@ def build_prompt(
yield res
-class SeparatorStyle(Enum):
- """Different separator style."""
-
- SINGLE = auto()
- TWO = auto()
- DOLLY = auto()
-
-
-# TODO clean this 💩 up
-@dataclasses.dataclass
-class Conversation:
- """A class that keeps all conversation history."""
-
- system: str
- roles: List[str]
- messages: List[List[str]]
- offset: int
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
- sep: str = "###"
- sep2: Optional[str] = None
-
- def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
- # seps = [self.sep, self.sep2]
- preamble = self.system + self.sep
- yield ("SYSTEM:", preamble)
- for _, (role, message) in enumerate(self.messages):
- if message:
- yield (role + ":", " " + message)
- else:
- LOG.warning(f"role with empty message: {role}")
- yield (role + ":", "")
-
- def copy(self):
- return Conversation(
- system=self.system,
- roles=self.roles,
- messages=[[x, y] for x, y in self.messages],
- offset=self.offset,
- sep_style=self.sep_style,
- sep=self.sep,
- sep2=self.sep2,
- )
-
- def append_message(self, role, message):
- self.messages.append([role, message])
-
-
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
@@ -285,28 +225,27 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
A prompter that generates prompts for the ShareGPT
"""
- def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
- if prompt_style != PromptStyle.CHAT.value:
- raise ValueError(
- f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
- )
- system: str = (
- system_prompt
- if system_prompt
- else (
- "A chat between a curious user and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the user's questions."
- )
- )
- self._conversation = Conversation(
- system=system,
- roles=["USER", "ASSISTANT"],
- messages=[],
- offset=0,
- sep_style=SeparatorStyle.TWO,
- sep=" ",
- sep2=" ",
- )
+ role_key_human = "human"
+ role_key_model = "gpt"
+
+ def __init__(
+ self,
+ prompt_style=None, # pylint: disable=unused-argument
+ conversation: Optional[Union[str, Conversation]] = None,
+ role_key_human: Optional[str] = None,
+ role_key_model: Optional[str] = None,
+ ):
+ if conversation:
+ if isinstance(conversation, Conversation):
+ self._conversation = conversation
+ else:
+ self._conversation = get_conv_template(conversation)
+ else:
+ self._conversation = get_conv_template("vicuna_v1.1")
+ if role_key_human:
+ self.role_key_human = role_key_human
+ if role_key_model:
+ self.role_key_model = role_key_model
def build_prompt(self, source) -> Generator[str, None, None]:
if len(source) < 2:
@@ -320,17 +259,14 @@ def build_prompt(self, source) -> Generator[str, None, None]:
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
- conv.system = source[0]["value"]
+ conv.set_system_message(source[0]["value"])
source.pop(0)
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+ roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
try:
# Apply prompt templates
- if (
- source[0]["from"] not in roles
- or roles[source[0]["from"]] != conv.roles[0]
- ):
+ if source[0]["from"] not in roles:
# Skip the first one if it is not from human
source = source[1:]
except IndexError as err:
@@ -338,10 +274,33 @@ def build_prompt(self, source) -> Generator[str, None, None]:
raise err
conv.messages = []
- for j, sentence in enumerate(source):
+ for _, sentence in enumerate(source):
role = roles[sentence["from"]]
- assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
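+ # warn instead of failing when consecutive turns share a role or the role is unknown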
+ if len(conv.messages) > 0 and (
+ (role == conv.messages[-1][0]) or (role not in conv.roles)
+ ):
+ LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
- for part in conv.get_prompt():
+ for part in conv.get_turns():
+ if part[0] and not part[1]:
+ LOG.warning(f"role with empty message: {part[0]}")
yield part
+
+
+class ShareGPTPrompterV2(ShareGPTPrompter):
+ """
+ A V2 prompter that generates prompts for the ShareGPT
+ """
+
+ def __init__(
+ self,
+ conversation: Optional[Union[str, Conversation]] = None,
+ role_key_human: Optional[str] = None,
+ role_key_model: Optional[str] = None,
+ ):
+ super().__init__(
+ conversation=conversation,
+ role_key_human=role_key_human,
+ role_key_model=role_key_model,
+ )
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 5ed5837f21..da98600a45 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -9,8 +9,7 @@
from typing import Optional
import torch
-
-# add src to the pythonpath so we don't need to pip install this
+import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
@@ -103,6 +102,9 @@ def terminate_handler(_, __, model):
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
+ badge_markdown = """[](https://github.com/OpenAccess-AI-Collective/axolotl)"""
+ transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
@@ -138,4 +140,7 @@ def terminate_handler(_, __, model):
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+ if not cfg.hub_model_id:
+ trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
+
return model, tokenizer
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index b460b2ba7c..40be0d9ac8 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -1,14 +1,44 @@
"""Benchmarking and measurement utilities"""
+import functools
import pynvml
import torch
from pynvml.nvml import NVMLError
+def check_cuda_device(default_value):
+ """
+ wraps a function and returns the default value instead of running the
+ wrapped function if cuda isn't available, the device is "auto", or the device is on cpu
+ :param default_value:
+ :return:
+ """
+
+ def deco(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
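+ # the device may be passed as the first positional argument or as the "device" keyword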
+ device = kwargs.get("device", args[0] if args else None)
+
+ if (
+ not torch.cuda.is_available()
+ or device == "auto"
+ or torch.device(device).type == "cpu"
+ ):
+ return default_value
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return deco
+
+
+@check_cuda_device(0.0)
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
+@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
@@ -16,6 +46,7 @@ def gpu_memory_usage_all(device=0):
return usage, reserved - usage, max(0, smi - reserved)
+@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
@@ -31,9 +62,6 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
- if not torch.cuda.is_available() or device == "auto":
- return (0, 0, 0)
-
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index bf8c4145bd..458e537c63 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -43,26 +43,26 @@
IGNORE_INDEX = -100
-class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
- """Callback to save the PEFT adapter"""
+class EvalFirstStepCallback(
+ TrainerCallback
+): # pylint: disable=too-few-public-methods disable=unused-argument
+ """
+ Callback to trigger evals on the first step
+ """
- def on_save(
+ def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
- checkpoint_folder = os.path.join(
- args.output_dir,
- f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
- )
-
- peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
- kwargs["model"].save_pretrained(
- peft_model_path, save_safetensors=args.save_safetensors
- )
-
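+ # force an evaluation after the very first step when eval_steps is given as a fraction of total steps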
+ if (
+ args.evaluation_strategy == IntervalStrategy.STEPS
+ and args.eval_steps < 1.0
+ and state.global_step == 1
+ ):
+ control.should_evaluate = True
return control
@@ -514,3 +514,27 @@ def log_table_from_dataloader(name: str, table_dataloader):
return control
return LogPredictionCallback
+
+
+class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
+ """Callback to save axolotl config to wandb"""
+
+ def __init__(self, axolotl_config_path):
+ self.axolotl_config_path = axolotl_config_path
+
+ def on_train_begin(
+ self,
+ args: AxolotlTrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState, # pylint: disable=unused-argument
+ control: TrainerControl,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ if is_main_process():
+ try:
+ artifact = wandb.Artifact(name="axolotl-config", type="config")
+ artifact.add_file(local_path=self.axolotl_config_path)
+ wandb.run.log_artifact(artifact)
+ LOG.info("Axolotl config has been saved to WandB as an artifact.")
+ except (FileNotFoundError, ConnectionError) as err:
+ LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
+ return control
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 8fdfaafc16..1c0a15d675 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -49,6 +49,8 @@ def normalize_config(cfg):
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
+ if cfg.eval_batch_size is None:
+ cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
@@ -75,6 +77,8 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
+ cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
+
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
@@ -82,10 +86,39 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
- or "llama" in cfg.base_model
+ or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
+ # figure out if the model is falcon
+ cfg.is_falcon_derived_model = (
+ (
+ hasattr(model_config, "model_type")
+ and model_config.model_type
+ in [
+ "falcon",
+ "RefinedWebModel",
+ "RefinedWeb",
+ ]
+ )
+ or cfg.is_falcon_derived_model
+ or "falcon" in cfg.base_model.lower()
+ or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+ )
+
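+ # figure out if the model is mistral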
+ cfg.is_mistral_derived_model = (
+ (
+ hasattr(model_config, "model_type")
+ and model_config.model_type
+ in [
+ "mistral",
+ ]
+ )
+ or cfg.is_mistral_derived_model
+ or "mistral" in cfg.base_model.lower()
+ or (cfg.model_type and "mistral" in cfg.model_type.lower())
+ )
+
log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -94,7 +127,7 @@ def validate_config(cfg):
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
- if cfg.bf16 or cfg.bfloat16:
+ if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
@@ -126,6 +159,11 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
+ if cfg.eval_batch_size != cfg.micro_batch_size:
+ LOG.warning(
+ "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
+ )
+
if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated")
@@ -265,6 +303,45 @@ def validate_config(cfg):
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
)
+ if cfg.datasets:
+ for idx, ds_cfg in enumerate(cfg.datasets):
+ if not ds_cfg.type:
+ continue
+ if ds_cfg.type == "sharegpt:chat":
+ LOG.warning(
+ PendingDeprecationWarning(
+ "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
+ )
+ )
+ cfg.datasets[idx].type = "sharegpt"
+ if "sharegpt_simple" in ds_cfg.type:
+ LOG.warning(
+ PendingDeprecationWarning(
+ "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
+ )
+ )
+ cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
+ "sharegpt_simple", "sharegpt"
+ )
+ if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
+ raise ValueError(
+ "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+ )
+
+ if (
+ cfg.evaluation_strategy
+ and cfg.eval_steps
+ and cfg.evaluation_strategy != "steps"
+ ):
+ raise ValueError(
+ "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
+ )
+
+ if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
+ raise ValueError(
+ "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
+ )
+
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 2dc2d82b28..c944dd27ba 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -16,6 +16,7 @@
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
+from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
@@ -25,7 +26,6 @@
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
- ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
@@ -35,7 +35,6 @@
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter,
- ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
@@ -46,7 +45,6 @@
)
LOG = logging.getLogger("axolotl")
-DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -76,7 +74,7 @@ def prepare_dataset(cfg, tokenizer):
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
- cfg, train_dataset, eval_dataset
+ cfg, train_dataset, eval_dataset, tokenizer
)
if cfg.max_steps:
total_num_steps = min(
@@ -116,7 +114,7 @@ def load_tokenized_prepared_datasets(
if cfg.push_dataset_to_hub:
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
- use_auth_token=use_auth_token,
+ token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
@@ -124,7 +122,7 @@ def load_tokenized_prepared_datasets(
if dataset:
...
- elif any(prepared_ds_path.glob("*")):
+ elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...")
@@ -157,7 +155,7 @@ def for_d_in_datasets(dataset_configs):
d.path,
name=d.name,
streaming=True,
- use_auth_token=use_auth_token,
+ token=use_auth_token,
)
ds_from_hub = True
except FileNotFoundError:
@@ -204,14 +202,29 @@ def for_d_in_datasets(dataset_configs):
name=d.name,
streaming=False,
data_files=d.data_files,
- use_auth_token=use_auth_token,
+ token=use_auth_token,
)
else:
- fp = hf_hub_download(
- repo_id=d.path,
- repo_type="dataset",
- filename=d.data_files,
- )
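+ # data_files may be a single filename or a list of files to download from the hub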
+ if isinstance(d.data_files, str):
+ fp = hf_hub_download(
+ repo_id=d.path,
+ repo_type="dataset",
+ filename=d.data_files,
+ )
+ elif isinstance(d.data_files, list):
+ fp = []
+ for file in d.data_files:
+ fp.append(
+ hf_hub_download(
+ repo_id=d.path,
+ repo_type="dataset",
+ filename=file,
+ )
+ )
+ else:
+ raise ValueError(
+ "data_files must be either a string or list of strings"
+ )
ds = load_dataset(
"json", name=d.name, data_files=fp, streaming=False, split=None
)
@@ -234,6 +247,16 @@ def for_d_in_datasets(dataset_configs):
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
+ elif (
+ isinstance(ds, DatasetDict)
+ and d.train_on_split
+ and d.train_on_split in ds
+ ):
+ ds = ds[d.train_on_split]
+ elif isinstance(ds, DatasetDict):
+ raise ValueError(
+ f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
+ )
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
@@ -320,15 +343,6 @@ def for_d_in_datasets(dataset_configs):
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
- elif d_base_type == "sharegpt":
- ds_strategy = ShareGPTPromptTokenizingStrategy(
- ShareGPTPrompter(d_prompt_style),
- tokenizer,
- cfg.train_on_inputs,
- cfg.sequence_len,
- )
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
- datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
@@ -403,7 +417,7 @@ def load_prepare_datasets(
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
- use_auth_token=use_auth_token,
+ token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
@@ -411,7 +425,7 @@ def load_prepare_datasets(
if dataset:
...
- elif any(prepared_ds_path.glob("*")):
+ elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 10f06538ab..9a1c689fb7 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
value_scalar = fn()
if not is_distributed():
return [value_scalar]
- value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+ value_tensor = torch.tensor(
+ value_scalar, device=torch.cuda.current_device()
+ ).float()
if not is_main_process():
dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
"""
if is_main_process():
value_scalar = fn()
- value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+ value_tensor = torch.tensor(
+ value_scalar, device=torch.cuda.current_device()
+ ).float()
else:
- value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
+ value_tensor = torch.tensor(
+ 0.0, device=torch.cuda.current_device()
+ ) # Placeholder tensor
# Broadcast the tensor to all processes.
barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
- value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+ value_tensor = torch.tensor(
+ value_scalar, device=torch.cuda.current_device()
+ ).float()
# Placeholder tensor for gathering results
if is_main_process():
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f7160c477a..1ca6b9ff70 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1,5 +1,4 @@
"""Module for models and model loading"""
-import importlib
import logging
import math
import os
@@ -10,7 +9,9 @@
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
+from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
+ AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -80,11 +81,22 @@ def load_tokenizer(cfg):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ # Mistral's official FA implementation requires left padding
+ if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
+ tokenizer.padding_side = "left"
+
if cfg.special_tokens:
for k, val in cfg.special_tokens.items():
- tokenizer.add_special_tokens({k: val})
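+ # use AddedToken so the special token text is not normalized and adjacent whitespace is preserved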
+ tokenizer.add_special_tokens(
+ {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
+ )
if cfg.tokens:
- tokenizer.add_tokens(list(cfg.tokens))
+ tokenizer.add_tokens(
+ [
+ AddedToken(token, rstrip=False, lstrip=False, normalized=False)
+ for token in cfg.tokens
+ ]
+ )
return tokenizer
@@ -113,26 +125,29 @@ def load_model(
replace_btlm_attn_with_flash_attn(cfg.base_model)
- if hasattr(model_config, "model_type") and model_config.model_type in [
- "falcon",
- "RefinedWebModel",
- "RefinedWeb",
- ]:
- if cfg.flash_attention:
- from axolotl.monkeypatch.falcon_attn_hijack_flash import (
- replace_falcon_attn_with_flash_attn,
+ if (
+ hasattr(model_config, "model_type")
+ and model_config.model_type == "stablelm_epoch"
+ ):
+ if cfg.flash_attention and cfg.sample_packing:
+ from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
+ replace_stablelm_attn_with_flash_attn,
)
- replace_falcon_attn_with_flash_attn()
+ replace_stablelm_attn_with_flash_attn(cfg.base_model)
- if cfg.is_llama_derived_model and cfg.flash_attention:
+ if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
- LOG.info("patching with flash attention")
- replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
+ LOG.info("patching with flash attention for sample packing")
+ replace_llama_attn_with_flash_attn(
+ packed=cfg.sample_packing,
+ cross_entropy=cfg.flash_attn_cross_entropy,
+ rms_norm=cfg.flash_attn_rms_norm,
+ )
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
@@ -157,6 +172,34 @@ def load_model(
# Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
+ if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
+ from axolotl.monkeypatch.mistral_attn_hijack_flash import (
+ replace_mistral_attn_with_flash_attn,
+ )
+
+ LOG.info("patching with flash attention")
+ replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
+
+ if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+ from axolotl.monkeypatch.llama_embeddings_hijack import (
+ replace_llama_embeddings_with_uniform_distribution,
+ )
+
+ LOG.info("patching with noisy embeddings")
+ replace_llama_embeddings_with_uniform_distribution(
+ noise_alpha=cfg.noisy_embedding_alpha
+ )
+
+ if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+ from axolotl.monkeypatch.mistral_embeddings_hijack import (
+ replace_mistral_embeddings_with_uniform_distribution,
+ )
+
+ LOG.info("patching with noisy embeddings")
+ replace_mistral_embeddings_with_uniform_distribution(
+ noise_alpha=cfg.noisy_embedding_alpha
+ )
+
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
@@ -175,21 +218,11 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
- # special handling b/c remote MixFormers code doesn't have _no_split_modules set
- if (
- "MixFormerSequentialConfig" in model_config.__class__.__name__
- and cfg.model_type == "AutoModelForCausalLM"
- ):
- module_name = model_config.__class__.__module__.replace(
- ".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
- )
- modeling_phi = importlib.import_module(module_name)
- # pylint:disable=protected-access
- modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
- "ParallelBlock"
- ]
-
model_kwargs = {}
+
+ model_kwargs["device_map"] = cfg.device_map
+ model_kwargs["torch_dtype"] = cfg.torch_dtype
+
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
@@ -212,6 +245,15 @@ def load_model(
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
+ # sample packing uses custom FA2 patch
+ if cfg.flash_attention and not cfg.sample_packing:
+ if (
+ cfg.is_llama_derived_model
+ or cfg.is_falcon_derived_model
+ or cfg.is_mistral_derived_model
+ ):
+ model_kwargs["use_flash_attention_2"] = True
+
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
@@ -226,10 +268,8 @@ def load_model(
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
- device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
- torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -263,28 +303,22 @@ def load_model(
model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
- device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
- torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
- device_map=cfg.device_map,
- torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
- device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
- torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -309,16 +343,22 @@ def load_model(
):
config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
- model = AutoModelForCausalLM.from_pretrained(
- base_model,
- config=config,
- device_map=cfg.device_map,
- load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
- load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
- torch_dtype=cfg.torch_dtype,
- trust_remote_code=cfg.trust_remote_code or False,
- **model_kwargs,
- )
+ if cfg.gptq:
+ model = AutoModelForCausalLM.from_pretrained(
+ base_model,
+ config=config,
+ trust_remote_code=cfg.trust_remote_code or False,
+ **model_kwargs,
+ )
+ else:
+ model = AutoModelForCausalLM.from_pretrained(
+ base_model,
+ config=config,
+ load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+ load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+ trust_remote_code=cfg.trust_remote_code or False,
+ **model_kwargs,
+ )
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -326,10 +366,8 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
- device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
- torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -468,10 +506,10 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
- cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
+ cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
- if isinstance(module, cls):
+ if isinstance(module, cls) or "Linear" in module.__class__.__name__:
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 4746ceddef..deb9bb00ec 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -31,7 +31,9 @@ def check_example_labels(example, tokenizer, text_only=False):
)
colored_tokens.append(colored_token)
- LOG.info(" ".join(colored_tokens))
+ delimiter = "" if text_only else " "
+ LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n")
+ print(" ".join(colored_tokens))
return " ".join(colored_tokens)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2067a90069..820202b80b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -28,9 +28,10 @@
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
+ EvalFirstStepCallback,
GPUStatsCallback,
+ SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
- SavePeftModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
@@ -397,23 +398,38 @@ def disable_datasets_caching():
set_caching_enabled(True)
-def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
- train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
+ train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
if eval_dataset:
- eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
+ eval_dataset = eval_dataset.filter(
+ drop_long, num_proc=cfg.dataset_processes
+ )
if cfg.group_by_length:
- train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
+ train_dataset = train_dataset.map(
+ add_length, num_proc=cfg.dataset_processes
+ )
if cfg.sample_packing:
- train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
+ train_dataset = train_dataset.map(
+ add_position_ids, num_proc=cfg.dataset_processes
+ )
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
- add_position_ids, num_proc=os.cpu_count()
+ add_position_ids, num_proc=cfg.dataset_processes
)
+
+ # Phi (CodeGenTokenizer) and Mistral with flash attention don't use the attention_mask feature when training
+ if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
+ cfg.is_mistral_derived_model and cfg.flash_attention
+ ):
+ train_dataset = train_dataset.remove_columns("attention_mask")
+ if eval_dataset:
+ eval_dataset = eval_dataset.remove_columns("attention_mask")
+
return train_dataset, eval_dataset
@@ -597,26 +613,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
- if cfg.eval_steps and cfg.evaluation_strategy:
- # assume if the user set both, they know what they're doing
- training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
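+ # eval_steps implies a steps-based evaluation strategy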
+ if cfg.eval_steps:
+ training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
+ elif cfg.evaluation_strategy:
+ training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
- elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
- # if explicitly set for epoch, just set, and eval steps don't matter
- training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
- elif cfg.eval_steps:
- # steps isn't used w/ epochs
- training_arguments_kwargs["evaluation_strategy"] = "steps"
- training_arguments_kwargs["eval_steps"] = cfg.eval_steps
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_steps:
- # save_steps implies save_strategy of steps
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
@@ -662,9 +671,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
- per_device_eval_batch_size=cfg.eval_batch_size
- if cfg.eval_batch_size is not None
- else cfg.micro_batch_size,
+ per_device_eval_batch_size=cfg.eval_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
@@ -675,6 +682,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
+ and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,
@@ -704,16 +712,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
+ callbacks.append(EvalFirstStepCallback)
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
- if cfg.local_rank == 0 and cfg.adapter in [
- "lora",
- "qlora",
- ]: # only save in rank 0
- callbacks.append(SavePeftModelCallback)
-
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
@@ -775,6 +778,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))
+ if cfg.use_wandb:
+ trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
+
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index fbca33633e..7d4b75cceb 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -6,6 +6,7 @@
import os
import tempfile
import unittest
+from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
def test_lora(self):
# pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ def test_lora(self):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
- "output_dir": tempfile.mkdtemp(),
+ "output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -62,9 +64,11 @@ def test_lora(self):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self):
# pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ def test_lora_packing(self):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
- "output_dir": tempfile.mkdtemp(),
+ "output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -105,3 +109,53 @@ def test_lora_packing(self):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "adapter_model.bin").exists()
+
+ def test_lora_gptq(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+ "base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+ "model_type": "AutoModelForCausalLM",
+ "tokenizer_type": "LlamaTokenizer",
+ "sequence_len": 1024,
+ "sample_packing": True,
+ "flash_attention": True,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "gptq": True,
+ "gptq_disable_exllama": True,
+ "lora_r": 32,
+ "lora_alpha": 64,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "save_steps": 0.5,
+ "micro_batch_size": 8,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "adapter_model.bin").exists()
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
new file mode 100644
index 0000000000..f3098f0583
--- /dev/null
+++ b/tests/e2e/test_mistral.py
@@ -0,0 +1,116 @@
+"""
+E2E tests for mistral
+"""
+
+import logging
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMistral(unittest.TestCase):
+ """
+ Test case for Mistral models
+ """
+
+ def test_lora(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "openaccess-ai-collective/tiny-mistral",
+ "base_model_config": "openaccess-ai-collective/tiny-mistral",
+ "flash_attention": True,
+ "sequence_len": 1024,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "lora_r": 32,
+ "lora_alpha": 64,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "adapter_model.bin").exists()
+
+ def test_ft(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "openaccess-ai-collective/tiny-mistral",
+ "base_model_config": "openaccess-ai-collective/tiny-mistral",
+ "flash_attention": True,
+ "sequence_len": 1024,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ }
+ )
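+ # train in bf16 when the GPU supports it, otherwise fall back to fp16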
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "pytorch_model.bin").exists()
diff --git a/tests/e2e/test_mistral_samplepack.py b/tests/e2e/test_mistral_samplepack.py
new file mode 100644
index 0000000000..623d20b0c2
--- /dev/null
+++ b/tests/e2e/test_mistral_samplepack.py
@@ -0,0 +1,118 @@
+"""
+E2E tests for mistral with sample packing
+"""
+
+import logging
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMistral(unittest.TestCase):
+ """
+ Test case for Mistral models with sample packing
+ """
+
+ def test_lora_packing(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "openaccess-ai-collective/tiny-mistral",
+ "base_model_config": "openaccess-ai-collective/tiny-mistral",
+ "flash_attention": True,
+ "sample_packing": True,
+ "sequence_len": 1024,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "lora_r": 32,
+ "lora_alpha": 64,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "adapter_model.bin").exists()
+
+ def test_ft_packing(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "openaccess-ai-collective/tiny-mistral",
+ "base_model_config": "openaccess-ai-collective/tiny-mistral",
+ "flash_attention": True,
+ "sample_packing": True,
+ "sequence_len": 1024,
+ "val_set_size": 0.1,
+ "special_tokens": {
+ "unk_token": "",
+ "bos_token": "",
+ "eos_token": "",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ }
+ )
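+ # same precision fallback as above: bf16 if available, else fp16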
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "pytorch_model.bin").exists()
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index fb8aa5d875..a84ef0778c 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -31,9 +31,9 @@ def test_ft(self):
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
- "sequence_len": 2048,
+ "sequence_len": 512,
"sample_packing": False,
- "load_in_8bit": True,
+ "load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
@@ -55,8 +55,9 @@ def test_ft(self):
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
- "optimizer": "adamw_torch",
+ "optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
+ "bf16": True,
}
)
normalize_config(cfg)
@@ -74,9 +75,9 @@ def test_ft_packed(self):
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
- "sequence_len": 2048,
+ "sequence_len": 512,
"sample_packing": True,
- "load_in_8bit": True,
+ "load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
@@ -98,8 +99,9 @@ def test_ft_packed(self):
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
- "optimizer": "adamw_torch",
+ "optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
+ "bf16": True,
}
)
normalize_config(cfg)
diff --git a/tests/fixtures/conversation.missingturns.json b/tests/fixtures/conversation.missingturns.json
new file mode 100644
index 0000000000..b5b962d859
--- /dev/null
+++ b/tests/fixtures/conversation.missingturns.json
@@ -0,0 +1,52 @@
+{
+ "conversations": [
+ {
+ "from": "human",
+ "value": "how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?"
+ },
+ {
+ "from": "gpt",
+ "value": ""
+ },
+ {
+ "from": "human",
+ "value": "could you give me the gist of how it could though?"
+ },
+ {
+ "from": "gpt",
+ "value": "Sure. The basic idea would be to use algebraic topology to analyze the structure of the Minesweeper game board and determine the locations of mines based on certain topological invariants. This could involve representing the game board as a topological space and using topological invariants such as homotopy groups and Betti numbers to infer information about the location of mines on the board. However, as I mentioned before, implementing this would require a deep understanding of both algebraic topology and the game of Minesweeper, so it's not something that can be explained in a simple answer."
+ },
+ {
+ "from": "human",
+ "value": "give me a specific example"
+ },
+ {
+ "from": "gpt",
+ "value": "As a specific example, suppose we have a Minesweeper game board with the following configuration:\n```\n 1 1 1 1 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 ? ? ? 1\n 1 1 1 1 1\n```\nIn this case, the game board can be represented as a topological space, where each square on the board corresponds to a point in the space. The mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space. Using algebraic topology, we can analyze the structure of this space and determine the locations of the mines based on certain topological invariants.\n\nFor instance, one way to do this is to compute the homotopy groups of the space. Homotopy groups are a way of describing the holes in a space, and they can be used to infer information about the locations of the mines on the Minesweeper board. In the case of the board above, we can compute the first homotopy group of the space, which describes the number and \"shape\" of the holes in the space. In this case, the first homotopy group would tell us that there are four holes in the space, arranged in a square shape.\n\nBased on this information, we can infer that the mines in the game are located in the four corners of the board, as shown below:\n```\n 1 1 1 1 1\n 1 M ? ? 1\n 1 ? ? ? 1\n 1 ? ? M 1\n 1 1 1 1 1\n```\nOf course, this is just one example of how algebraic topology could be used to solve Minesweeper. There are many other possible approaches, and the specific details of how to implement them would depend on the specific game board and the topological invariants being used. As I mentioned before, implementing a solution like this would require a deep understanding of both algebraic topology and the game of Minesweeper."
+ },
+ {
+ "from": "human",
+ "value": "assuming i was someone with a deep understanding of both, could you explain the same concept to me with that assumption?"
+ },
+ {
+ "from": "gpt",
+ "value": "Sure. Assuming you have a deep understanding of both algebraic topology and the game of Minesweeper, here is how you might use algebraic topology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Use topological invariants to infer information about the locations of the mines on the board. There are many possible topological invariants that could be used for this, such as homotopy groups, Betti numbers, or persistent homology.\n3. Use the information inferred from the topological invariants to determine the locations of the mines on the board. For example, if the first homotopy group of the space indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using algebraic topology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the topological invariants being used. However, this should give you a general idea of how algebraic topology can be used to solve Minesweeper."
+ },
+ {
+ "from": "human",
+ "value": "how would i do it with persistent homology or any of the other topological invariants?"
+ },
+ {
+ "from": "gpt",
+ "value": "To use persistent homology or any of the other topological invariants to solve a Minesweeper game board, you would follow the same basic steps as described above, but with some differences in the specific details.\n\nFor instance, here is how you might use persistent homology to solve a Minesweeper game board:\n\n1. Represent the Minesweeper game board as a topological space. Each square on the board corresponds to a point in the space, and the mines on the board can be thought of as \"holes\" in the space, which are points that are not part of the space.\n2. Compute the persistent homology of the space. Persistent homology is a way of describing the holes in a space over time, and it can be used to infer information about the locations of the mines on the Minesweeper board.\n3. Use the information from the persistent homology to determine the locations of the mines on the board. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then we can infer that the mines are located in the four corners of the board.\n4. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you can use this information to avoid them and clear the board.\n\nOf course, this is just one possible approach to using persistent homology to solve Minesweeper. There may be other ways to do it, depending on the specific game board and the specific details of the persistent homology computation. However, this should give you a general idea of how persistent homology can be used to solve Minesweeper."
+ },
+ {
+ "from": "human",
+ "value": "how could i implement this in python?"
+ },
+ {
+ "from": "gpt",
+ "value": "To implement a Minesweeper solver that uses algebraic topology in Python, you would need to do the following:\n\n1. Install a Python package for algebraic topology. There are several options available, such as Gudhi, Dionysus, or scikit-tda. These packages provide algorithms and data structures for computing topological invariants, such as homotopy groups, Betti numbers, or persistent homology.\n2. Represent the Minesweeper game board as a topological space. This could be done using a data structure provided by the algebraic topology package, such as a simplicial complex or a Vietoris-Rips complex. Each square on the board would correspond to a point in the space, and the mines on the board would be represented as \"holes\" in the space.\n3. Compute the topological invariants of the space. This could be done using algorithms provided by the algebraic topology package. For example, you could compute the persistent homology of the space using the persistent homology algorithm.\n4. Use the information from the topological invariants to determine the locations of the mines on the board. This could be done using a combination of mathematical reasoning and programming. For example, if the persistent homology indicates that there are four holes in the space, arranged in a square shape, then you could use a piece of code to infer that the mines are located in the four corners of the board.\n5. Use this information to solve the Minesweeper game. Once you know the locations of the mines, you could use a piece of code to automatically click on the safe squares and clear the board.\n\nOf course, this is just one possible approach to implementing a Minesweeper solver that uses algebraic topology in Python. There may be other ways to do it, depending on the specific details of the implementation. However, this should give you a general idea of how to get started with such a project."
+ }
+ ]
+}
diff --git a/tests/fixtures/conversation.tokenized.json b/tests/fixtures/conversation.tokenized.json
index 0ac93713bc..4b7334cc15 100644
--- a/tests/fixtures/conversation.tokenized.json
+++ b/tests/fixtures/conversation.tokenized.json
@@ -1 +1 @@
-{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 
3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 
29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 
310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 
9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 
28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 
26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 
679, 4687, 411, 1316, 263, 2060, 29889, 2]}
+{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 29871, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 29871, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 29871, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 29871, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 
937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 29871, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 29871, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 
408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 29871, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 
8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 21106, 29879, 29958, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 739, 29915, 29879, 1950, 304, 671, 21531, 
20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 
1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 
7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 21106, 29879, 29958, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 
445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 21106, 29879, 29958, 2]}
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 1dd511f6b8..0635bd718b 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -3,7 +3,9 @@
import logging
import unittest
from pathlib import Path
+from typing import Optional
+import pytest
from transformers import AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
@@ -19,7 +21,7 @@
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
-from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
+from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
LOG = logging.getLogger("axolotl")
@@ -29,6 +31,12 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
+ _caplog: Optional[pytest.LogCaptureFixture] = None
+
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, caplog):
+ self._caplog = caplog
+
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -52,7 +60,7 @@ def test_sharegpt_integration(self):
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
- prompter = ShareGPTPrompter("chat")
+ prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
@@ -64,6 +72,91 @@ def test_sharegpt_integration(self):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
+ def test_sharegpt_warnings_integration(self):
+ with open(
+ Path(__file__).parent / "fixtures/conversation.missingturns.json",
+ encoding="utf-8",
+ ) as fin:
+ data = fin.read()
+ conversation = json.loads(data)
+ prompter = ShareGPTPrompterV2()
+ strat = ShareGPTPromptTokenizingStrategy(
+ prompter,
+ self.tokenizer,
+ False,
+ 2048,
+ )
+ with self._caplog.at_level(logging.WARNING):
+ strat.tokenize_prompt(conversation)
+ assert "assistant turn has empty text" in self._caplog.records[1].message
+
+ def test_sharegpt_warnings_turns(self):
+ conversation = {
+ "conversations": [
+ {"from": "system", "value": "lorem"},
+ {"from": "gpt", "value": "ipsum"},
+ {"from": "human", "value": "dolor"},
+ {"from": "human", "value": "dolor"},
+ {"from": "gpt", "value": "sit"},
+ ]
+ }
+ prompter = ShareGPTPrompterV2()
+ strat = ShareGPTPromptTokenizingStrategy(
+ prompter,
+ self.tokenizer,
+ False,
+ 2048,
+ )
+ with self._caplog.at_level(logging.WARNING):
+ strat.tokenize_prompt(conversation)
+ assert (
+ "Role did not alternate between turns (gpt and human)"
+ in self._caplog.records[0].message
+ )
+
+ def test_sharegpt_changes_roles(self):
+ conversation = {
+ "roles": ["USER", "CHARACTER"],
+ "conversations": [
+ {"from": "system", "value": "lorem"},
+ {"from": "gpt", "value": "ipsum"},
+ {"from": "human", "value": "dolor"},
+ {"from": "gpt", "value": "sit"},
+ ],
+ }
+ prompter = ShareGPTPrompterV2()
+ strat = ShareGPTPromptTokenizingStrategy(
+ prompter,
+ self.tokenizer,
+ False,
+ 2048,
+ )
+ with self._caplog.at_level(logging.WARNING):
+ res = strat.tokenize_prompt(conversation)
+ assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
+
+ def test_sharegpt_assistant_label_ignore(self):
+ conversation = {
+ "roles": ["user", "assistant"],
+ "conversations": [
+ {"from": "system", "value": "lorem"},
+ {"from": "gpt", "value": "ipsum"},
+ {"from": "human", "value": "dolor"},
+ {"from": "gpt", "value": "sit"},
+ ],
+ }
+ prompter = ShareGPTPrompterV2()
+ strat = ShareGPTPromptTokenizingStrategy(
+ prompter,
+ self.tokenizer,
+ False,
+ 2048,
+ )
+ with self._caplog.at_level(logging.WARNING):
+ res = strat.tokenize_prompt(conversation)
+ idx = res["input_ids"].index(20255) # assistant token
+ assert res["labels"][idx] == -100
+
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts
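The `inject_fixtures` hook added above is the standard way to reach pytest's `caplog` fixture from inside a `unittest.TestCase`: unittest-style test methods cannot accept fixtures as arguments, but an autouse fixture still runs per test and can stash the capture object on the instance. A minimal standalone sketch of the pattern (the logger name and test class below are illustrative, not part of this patch):

```python
import logging
import unittest
from typing import Optional

import pytest

LOG = logging.getLogger("example")


class CaplogInUnittestExample(unittest.TestCase):
    """Illustrative only: capture warnings emitted by code under test."""

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        # autouse fixtures run for unittest-style tests too, so the per-test
        # caplog instance can be stored on the TestCase for later assertions
        self._caplog = caplog

    def test_warning_is_captured(self):
        with self._caplog.at_level(logging.WARNING):
            LOG.warning("turn order looks wrong")
        assert any("turn order" in record.message for record in self._caplog.records)
```

The new ShareGPT tests use the same mechanism to assert on warnings such as the non-alternating-turn and empty-assistant-turn messages.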
diff --git a/tests/test_validation.py b/tests/test_validation.py
index f250e5cb47..35d90a2cb4 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -351,3 +351,217 @@ def test_packing(self):
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
+
+ def test_merge_lora_no_bf16_fail(self):
+ """
+ This is assumed to be run on a CPU machine, so bf16 is not supported.
+ """
+
+ cfg = DictDefault(
+ {
+ "bf16": True,
+ }
+ )
+
+ with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "bf16": True,
+ "merge_lora": True,
+ }
+ )
+
+ validate_config(cfg)
+
+ def test_sharegpt_deprecation(self):
+ cfg = DictDefault(
+ {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
+ )
+ with self._caplog.at_level(logging.WARNING):
+ validate_config(cfg)
+ assert any(
+ "`type: sharegpt:chat` will soon be deprecated." in record.message
+ for record in self._caplog.records
+ )
+ assert cfg.datasets[0].type == "sharegpt"
+
+ cfg = DictDefault(
+ {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
+ )
+ with self._caplog.at_level(logging.WARNING):
+ validate_config(cfg)
+ assert any(
+ "`type: sharegpt_simple` will soon be deprecated." in record.message
+ for record in self._caplog.records
+ )
+ assert cfg.datasets[0].type == "sharegpt:load_role"
+
+ def test_no_conflict_save_strategy(self):
+ cfg = DictDefault(
+ {
+ "save_strategy": "epoch",
+ "save_steps": 10,
+ }
+ )
+
+ with pytest.raises(
+ ValueError, match=r".*save_strategy and save_steps mismatch.*"
+ ):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "save_strategy": "no",
+ "save_steps": 10,
+ }
+ )
+
+ with pytest.raises(
+ ValueError, match=r".*save_strategy and save_steps mismatch.*"
+ ):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "save_strategy": "steps",
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "save_strategy": "steps",
+ "save_steps": 10,
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "save_steps": 10,
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "save_strategy": "no",
+ }
+ )
+
+ validate_config(cfg)
+
+ def test_no_conflict_eval_strategy(self):
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "epoch",
+ "eval_steps": 10,
+ }
+ )
+
+ with pytest.raises(
+ ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+ ):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "no",
+ "eval_steps": 10,
+ }
+ )
+
+ with pytest.raises(
+ ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+ ):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "steps",
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "steps",
+ "eval_steps": 10,
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "eval_steps": 10,
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "no",
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "epoch",
+ "val_set_size": 0,
+ }
+ )
+
+ with pytest.raises(
+ ValueError,
+ match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+ ):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "eval_steps": 10,
+ "val_set_size": 0,
+ }
+ )
+
+ with pytest.raises(
+ ValueError,
+ match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+ ):
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "val_set_size": 0,
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "eval_steps": 10,
+ "val_set_size": 0.01,
+ }
+ )
+
+ validate_config(cfg)
+
+ cfg = DictDefault(
+ {
+ "evaluation_strategy": "epoch",
+ "val_set_size": 0.01,
+ }
+ )
+
+ validate_config(cfg)
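Taken together, the new save/eval cases pin down fairly simple consistency rules: a `save_strategy`/`evaluation_strategy` other than `"steps"` must not be combined with `save_steps`/`eval_steps`, and evaluation settings are rejected when `val_set_size` is explicitly 0. Below is a rough sketch of checks with that shape, written against a plain dict for self-containment; it illustrates what the tests assert and is not the project's actual `validate_config` implementation:

```python
def check_save_and_eval_settings(cfg: dict) -> None:
    """Sketch of the consistency rules exercised by the tests above."""
    save_strategy = cfg.get("save_strategy")
    if save_strategy and save_strategy != "steps" and cfg.get("save_steps"):
        # e.g. save_strategy: epoch (or no) combined with save_steps: 10
        raise ValueError("save_strategy and save_steps mismatch.")

    eval_strategy = cfg.get("evaluation_strategy")
    if eval_strategy and eval_strategy != "steps" and cfg.get("eval_steps"):
        # e.g. evaluation_strategy: epoch (or no) combined with eval_steps: 10
        raise ValueError("evaluation_strategy and eval_steps mismatch.")

    if cfg.get("val_set_size") == 0 and (eval_strategy or cfg.get("eval_steps")):
        # no validation split, so there is nothing to evaluate on
        raise ValueError(
            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
        )
```

Note that the last check only fires when `val_set_size` is explicitly set to 0; leaving it unset, as several of the passing cases above do, does not trigger it.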