diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index ea1da66840..4c1b0463a7 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -30,7 +30,12 @@ jobs: - cuda: "121" cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.2.1 + pytorch: 2.2.2 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "121" + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 07a271769c..d0d0289824 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,12 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.2.1 + pytorch: 2.2.2 + axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -84,7 +89,12 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.2.1 + pytorch: 2.2.2 + axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -115,3 +125,45 @@ jobs: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} labels: ${{ steps.metadata.outputs.labels }} + + build-axolotl-cloud-no-tmux: + needs: build-axolotl + if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'OpenAccess-AI-Collective' }} + # this job needs to be run on self-hosted GPU runners... + strategy: + matrix: + include: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.0 + axolotl_extras: + runs-on: axolotl-gpu-runner + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Docker metadata + id: metadata + uses: docker/metadata-action@v5 + with: + images: winglian/axolotl-cloud-term + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 + - name: Build + uses: docker/build-push-action@v5 + with: + context: .
+ build-args: | + BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + CUDA=${{ matrix.cuda }} + file: ./docker/Dockerfile-cloud-no-tmux + push: ${{ github.event_name != 'pull_request' }} + tags: | + ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} + ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} + labels: ${{ steps.metadata.outputs.labels }} diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index cc263a887a..f668e5f65b 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -27,7 +27,12 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.2.1 + pytorch: 2.2.2 + axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -84,7 +89,12 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.2.1 + pytorch: 2.2.2 + axolotl_extras: + - cuda: 121 + cuda_version: 12.1.0 + python_version: "3.11" + pytorch: 2.3.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a53640e0b0..8f25eddc31 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -82,7 +82,7 @@ jobs: - cuda: 121 cuda_version: 12.1.0 python_version: "3.11" - pytorch: 2.2.1 + pytorch: 2.2.2 num_gpus: 1 steps: - name: Checkout diff --git a/README.md b/README.md index 4f055213cc..30a399d16f 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ Features: - [Mac](#mac) - [Google Colab](#google-colab) - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) + - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack) - [Dataset](#dataset) - [Config](#config) - [Train](#train) @@ -292,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN HF_TOKEN=xx BUCKET= sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET ``` +#### Launching on public clouds via dstack +To launch on GPU instances (both on-demand and spot) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/). + +Write a job description in YAML, as shown below: + +```yaml +# dstack.yaml +type: task + +image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2 + +env: + - HUGGING_FACE_HUB_TOKEN + - WANDB_API_KEY + +commands: + - accelerate launch -m axolotl.cli.train config.yaml + +ports: + - 6006 + +resources: + gpu: + memory: 24GB.. + count: 2 +``` + +Then simply run the job with the `dstack run` command. Append the `--spot` option if you want a spot instance. The `dstack run` command will show you the cheapest instance across multiple cloud services: + +```bash +pip install dstack +HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot +``` + +For more fine-grained use cases, please refer to the official [dstack documentation](https://dstack.ai/docs/) and the detailed description of the [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) in the official repository. + ### Dataset Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
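For example, each line of a JSONL for the `alpaca` type used throughout these example configs is a standalone JSON object with `instruction`, `input`, and `output` fields; an illustrative record (hypothetical values) might look like:

```jsonl
{"instruction": "Summarize the following text.", "input": "Axolotl is a tool designed to streamline the fine-tuning of AI models.", "output": "Axolotl streamlines LLM fine-tuning."}
```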
diff --git a/docker/Dockerfile b/docker/Dockerfile index fefa041c01..9ba29d9a38 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,7 @@ ARG PYTORCH_VERSION="2.1.2" ENV PYTORCH_VERSION=$PYTORCH_VERSION RUN apt-get update && \ - apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev apt-transport-https ca-certificates gnupg + apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs apt-transport-https ca-certificates gnupg RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \ curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \ diff --git a/docker/Dockerfile-cloud-no-tmux b/docker/Dockerfile-cloud-no-tmux new file mode 100644 index 0000000000..efeffef8e6 --- /dev/null +++ b/docker/Dockerfile-cloud-no-tmux @@ -0,0 +1,27 @@ +ARG BASE_TAG=main +FROM winglian/axolotl:$BASE_TAG + +ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" +ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" +ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub" +ENV HF_HOME="/workspace/data/huggingface-cache/hub" +ENV HF_HUB_ENABLE_HF_TRANSFER="1" + +EXPOSE 8888 +EXPOSE 22 + +COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh +COPY scripts/motd /etc/motd + +RUN pip install jupyterlab notebook ipywidgets && \ + jupyter lab clean +RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \ + pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \ + mkdir -p ~/.ssh && \ + chmod 700 ~/.ssh && \ + printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \ + chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \ + chmod +x /root/cloud-entrypoint.sh + +ENTRYPOINT ["/root/cloud-entrypoint.sh"] +CMD ["sleep", "infinity"] diff --git a/docs/config.qmd b/docs/config.qmd index dadc5c487c..1c87386a6d 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -186,6 +186,11 @@ eval_sample_packing: # The trainer will provide recommended values for these values. sample_packing_eff_est: total_num_tokens: +# Increasing the following values helps with packing, but usually only slightly (<1%). +# The number of samples to group together for packing. +sample_packing_group_size: 100000 +# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. +sample_packing_bin_size: 200 # Passed through to transformers when loading the model when launched without accelerate # Use `sequential` when training w/ model parallelism to limit memory @@ -227,6 +232,12 @@ lora_modules_to_save: lora_fan_in_fan_out: false +# LoRA+ hyperparameters +# For more details about the following options, see: +# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/trainer_builder.py` +loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4. +loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
+ peft: # Configuration options for loftq initialization for LoRA # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization @@ -279,7 +290,7 @@ lr_quadratic_warmup: logging_steps: eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps -save_strategy: # Set to `no` to skip checkpoint saves +save_strategy: # Set to `"no"` to skip checkpoint saves save_steps: # Leave empty to save at each epoch saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps save_total_limit: # Checkpoints saved at a time diff --git a/examples/cerebras/btlm-ft.yml b/examples/cerebras/btlm-ft.yml index 18dd86e6b4..ba4e65daae 100644 --- a/examples/cerebras/btlm-ft.yml +++ b/examples/cerebras/btlm-ft.yml @@ -38,7 +38,7 @@ wandb_watch: wandb_name: wandb_log_model: -output_dir: btlm-out +output_dir: ./outputs/btlm-out gradient_accumulation_steps: 1 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml index c4f44326c2..285607a4c8 100644 --- a/examples/cerebras/qlora.yml +++ b/examples/cerebras/qlora.yml @@ -25,7 +25,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out batch_size: 4 micro_batch_size: 4 num_epochs: 2 diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index ce5a892d08..0ba96cfaa7 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/code-llama/13b/qlora.yml b/examples/code-llama/13b/qlora.yml index d822e68470..787862d010 100644 --- a/examples/code-llama/13b/qlora.yml +++ b/examples/code-llama/13b/qlora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index dfef2538b0..92d4c544a3 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/code-llama/34b/qlora.yml b/examples/code-llama/34b/qlora.yml index 77f821e1c8..93a6de8777 100644 --- a/examples/code-llama/34b/qlora.yml +++ b/examples/code-llama/34b/qlora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index 3e6c7fe620..d13f505325 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/code-llama/7b/qlora.yml b/examples/code-llama/7b/qlora.yml index e817b113cc..a1026a982d 100644 --- a/examples/code-llama/7b/qlora.yml +++ b/examples/code-llama/7b/qlora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca 
dataset_prepared_path: val_set_size: 0.05 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 9adbe00047..fc3b761949 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -84,7 +84,7 @@ " type: alpaca\n", "dataset_prepared_path:\n", "val_set_size: 0.05\n", - "output_dir: ./qlora-out\n", + "output_dir: ./outputs/qlora-out\n", "\n", "adapter: qlora\n", "lora_model_dir:\n", diff --git a/examples/dbrx/16bit-lora.yaml b/examples/dbrx/16bit-lora.yaml index e5e3ea9216..32b625ac69 100644 --- a/examples/dbrx/16bit-lora.yaml +++ b/examples/dbrx/16bit-lora.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.0 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 512 sample_packing: false diff --git a/examples/dbrx/8bit-lora.yaml b/examples/dbrx/8bit-lora.yaml index 89e24db058..50ee0a0164 100644 --- a/examples/dbrx/8bit-lora.yaml +++ b/examples/dbrx/8bit-lora.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.0 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 512 sample_packing: false diff --git a/examples/dbrx/fft-ds-zero3.yaml b/examples/dbrx/fft-ds-zero3.yaml index 68292707a4..60dc201eee 100644 --- a/examples/dbrx/fft-ds-zero3.yaml +++ b/examples/dbrx/fft-ds-zero3.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.0 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 512 sample_packing: false diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 5be9c64253..029ca40e09 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -28,7 +28,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./falcon-7b +output_dir: ./outputs/falcon-7b batch_size: 2 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index eb1cdfcdba..4e34144ed6 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -42,7 +42,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out # QLoRA paper Table 9 # - 16 for 7b & 13b diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 1dd46a93ff..36264f063e 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -28,7 +28,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./falcon-7b +output_dir: ./outputs/falcon-7b batch_size: 2 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/gemma/qlora.yml b/examples/gemma/qlora.yml index 619a401291..e08facfc5d 100644 --- a/examples/gemma/qlora.yml +++ b/examples/gemma/qlora.yml @@ -12,7 +12,7 @@ datasets: - path: mhenrichsen/alpaca_2k_test type: alpaca val_set_size: 0.1 -output_dir: ./out +output_dir: ./outputs/out adapter: qlora lora_r: 32 diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml index cd3f2e2ad7..f801729fac 100644 --- a/examples/gptj/qlora.yml +++ b/examples/gptj/qlora.yml @@ -23,7 +23,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out gradient_accumulation_steps: 2 micro_batch_size: 2 num_epochs: 2 diff --git 
a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml index 41a3854fe1..3d6f69e793 100644 --- a/examples/jamba/qlora.yaml +++ b/examples/jamba/qlora.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.0 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 4096 sample_packing: false diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml index ef04fb53fe..43a76c00b1 100644 --- a/examples/jamba/qlora_deepspeed.yaml +++ b/examples/jamba/qlora_deepspeed.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.0 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 4096 sample_packing: false diff --git a/examples/jeopardy-bot/config.yml b/examples/jeopardy-bot/config.yml index a672c7b94f..088629c084 100644 --- a/examples/jeopardy-bot/config.yml +++ b/examples/jeopardy-bot/config.yml @@ -21,7 +21,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./jeopardy-bot-7b +output_dir: ./outputs/jeopardy-bot-7b gradient_accumulation_steps: 1 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index 74edc95e6b..3d94b04b8b 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 4096 sample_packing: true diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index 68ca9ed31c..2a706265bd 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -33,7 +33,7 @@ wandb_project: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./model-out +output_dir: ./outputs/model-out gradient_accumulation_steps: 1 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml index e692c7ac1e..7012d1f613 100644 --- a/examples/llama-2/lisa.yml +++ b/examples/llama-2/lisa.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./lisa-out +output_dir: ./outputs/lisa-out sequence_len: 4096 sample_packing: true diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml index 4529a912dc..68d9ac0142 100644 --- a/examples/llama-2/loftq.yml +++ b/examples/llama-2/loftq.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index a7793dce4c..95bfae6920 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index 93b3b2a60a..88029f92d5 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 834dbfb33a..dda32170bd 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca 
dataset_prepared_path: val_set_size: 0.05 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index 9fd19953c6..93247ce068 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./relora-out +output_dir: ./outputs/relora-out adapter: qlora lora_model_dir: diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index 8c9ba90bfe..a36fd740e4 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 8192 sample_packing: true diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml index d60f8a3035..6b0ebaed86 100644 --- a/examples/llama-3/lora-8b.yml +++ b/examples/llama-3/lora-8b.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml index 8d8785bfd5..9b74f6b4de 100644 --- a/examples/llama-3/qlora-fsdp-70b.yaml +++ b/examples/llama-3/qlora-fsdp-70b.yaml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./out/qlora-llama3-70b +output_dir: ./outputs/out/qlora-llama3-70b adapter: qlora lora_model_dir: diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml index 9cedee8eec..44120d9385 100644 --- a/examples/llama-3/qlora.yml +++ b/examples/llama-3/qlora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml index 0a5223bcac..f88f5138d9 100644 --- a/examples/mamba/config.yml +++ b/examples/mamba/config.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.0 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 2048 sample_packing: false diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml index cc0a44b2a4..e993e44a78 100644 --- a/examples/mistral/bigstral-ds-zero3.yaml +++ b/examples/mistral/bigstral-ds-zero3.yaml @@ -23,7 +23,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 2048 sample_packing: true diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index c909c63e22..a70937c4fd 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 8192 sample_packing: true diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml index 31b0d527e2..03c74bb59b 100644 --- a/examples/mistral/lora-mps.yml +++ b/examples/mistral/lora-mps.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0 -output_dir: ./lora-out +output_dir: ./outputs/lora-out eval_sample_packing: false adapter: lora diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml index ac9ac0dd98..0d5dc9edd7 100644 --- 
a/examples/mistral/lora.yml +++ b/examples/mistral/lora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.1 -output_dir: ./lora-out +output_dir: ./outputs/lora-out adapter: lora lora_model_dir: diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml index 71ac1e701f..e6b07c594b 100644 --- a/examples/mistral/mistral-qlora-fsdp.yml +++ b/examples/mistral/mistral-qlora-fsdp.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.02 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out model_config: output_router_logits: true diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml index 7727fd7485..2549ef018c 100644 --- a/examples/mistral/mistral-qlora-orpo.yml +++ b/examples/mistral/mistral-qlora-orpo.yml @@ -16,7 +16,7 @@ datasets: type: chat_template.argilla dataset_prepared_path: last_run_prepared val_set_size: 0.1 -output_dir: ./mistral-qlora-orpo-out +output_dir: ./outputs/mistral-qlora-orpo-out adapter: qlora lora_model_dir: diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml index ac80a2a756..fe68b28172 100644 --- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml +++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.02 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out model_config: output_router_logits: true diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml index b6a07ae51c..c095970402 100644 --- a/examples/mistral/mixtral-qlora-fsdp.yml +++ b/examples/mistral/mixtral-qlora-fsdp.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.02 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out model_config: output_router_logits: true diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 5ee3da9d65..13fbe92ab8 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.0 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out ## You can optionally freeze the entire model and unfreeze a subset of parameters unfrozen_parameters: diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml index 9abb6f407a..9a1e86386c 100644 --- a/examples/mistral/mixtral_22.yml +++ b/examples/mistral/mixtral_22.yml @@ -21,7 +21,7 @@ model_config: datasets: - path: yahma/alpaca-cleaned type: alpaca -output_dir: ./out +output_dir: ./outputs/out sequence_len: 8000 sample_packing: true diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 6fbbb96183..c7bdb155c0 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.1 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml index 45e31266f1..530415de17 100644 --- a/examples/mpt-7b/config.yml +++ b/examples/mpt-7b/config.yml @@ -23,7 +23,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./mpt-alpaca-7b +output_dir: ./outputs/mpt-alpaca-7b gradient_accumulation_steps: 1 micro_batch_size: 1 num_epochs: 4 diff --git 
a/examples/openllama-3b/config.yml b/examples/openllama-3b/config.yml index 0a404c79d8..a0473213c0 100644 --- a/examples/openllama-3b/config.yml +++ b/examples/openllama-3b/config.yml @@ -25,7 +25,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./openllama-out +output_dir: ./outputs/openllama-out gradient_accumulation_steps: 1 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml index b83b2db4e4..2b67849159 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/openllama-3b/lora.yml @@ -31,7 +31,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./lora-out +output_dir: ./outputs/lora-out gradient_accumulation_steps: 1 micro_batch_size: 2 num_epochs: 4 diff --git a/examples/openllama-3b/qlora.yml b/examples/openllama-3b/qlora.yml index 3d6218b308..8d4dc05ca7 100644 --- a/examples/openllama-3b/qlora.yml +++ b/examples/openllama-3b/qlora.yml @@ -25,7 +25,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out gradient_accumulation_steps: 1 micro_batch_size: 2 num_epochs: 4 diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index b21386f707..0dabadc7a4 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -12,7 +12,7 @@ datasets: dataset_prepared_path: val_set_size: 0.05 -output_dir: ./phi-sft-out +output_dir: ./outputs/phi-sft-out sequence_len: 2048 sample_packing: true diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index d2b5d661c9..7c181a3c15 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -12,7 +12,7 @@ datasets: dataset_prepared_path: val_set_size: 0.05 -output_dir: ./phi-sft-out +output_dir: ./outputs/phi-sft-out sequence_len: 2048 sample_packing: true diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index 7a2d05d018..27815550b4 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -12,7 +12,7 @@ datasets: dataset_prepared_path: val_set_size: 0.05 -output_dir: ./phi-sft-out +output_dir: ./outputs/phi-sft-out sequence_len: 2048 sample_packing: true diff --git a/examples/pythia-12b/config.yml b/examples/pythia-12b/config.yml index e44bba7451..18e6beaafd 100644 --- a/examples/pythia-12b/config.yml +++ b/examples/pythia-12b/config.yml @@ -26,7 +26,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./pythia-12b +output_dir: ./outputs/pythia-12b gradient_accumulation_steps: 1 micro_batch_size: 1 num_epochs: 5 diff --git a/examples/pythia/lora.yml b/examples/pythia/lora.yml index 7cb07fe258..0aa650f67e 100644 --- a/examples/pythia/lora.yml +++ b/examples/pythia/lora.yml @@ -20,7 +20,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./lora-alpaca-pythia +output_dir: ./outputs/lora-alpaca-pythia gradient_accumulation_steps: 1 micro_batch_size: 4 num_epochs: 4 diff --git a/examples/qwen/lora.yml b/examples/qwen/lora.yml index da4d784e0a..dd8dc1e4f4 100644 --- a/examples/qwen/lora.yml +++ b/examples/qwen/lora.yml @@ -13,7 +13,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 2048 # supports up to 8192 sample_packing: false diff --git a/examples/qwen/qlora.yml b/examples/qwen/qlora.yml index 501a866b2d..01c0c0ab86 100644 --- a/examples/qwen/qlora.yml +++ b/examples/qwen/qlora.yml @@ -13,7 +13,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out 
+output_dir: ./outputs/lora-out sequence_len: 2048 # supports up to 8192 sample_packing: false diff --git a/examples/qwen/qwen2-moe-lora.yaml b/examples/qwen/qwen2-moe-lora.yaml index c59b282d0a..452335e38f 100644 --- a/examples/qwen/qwen2-moe-lora.yaml +++ b/examples/qwen/qwen2-moe-lora.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 1024 # supports up to 32k sample_packing: false diff --git a/examples/qwen/qwen2-moe-qlora.yaml b/examples/qwen/qwen2-moe-qlora.yaml index d6a835a0a3..bc11007c78 100644 --- a/examples/qwen/qwen2-moe-qlora.yaml +++ b/examples/qwen/qwen2-moe-qlora.yaml @@ -10,7 +10,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 1024 # supports up to 32k sample_packing: false diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml index 5a42e2a952..ff395a863d 100644 --- a/examples/redpajama/config-3b.yml +++ b/examples/redpajama/config-3b.yml @@ -24,7 +24,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./redpajama-alpaca-3b +output_dir: ./outputs/redpajama-alpaca-3b batch_size: 4 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/replit-3b/config-lora.yml b/examples/replit-3b/config-lora.yml index bdfe1bd854..9fee099d47 100644 --- a/examples/replit-3b/config-lora.yml +++ b/examples/replit-3b/config-lora.yml @@ -23,7 +23,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./lora-replit +output_dir: ./outputs/lora-replit batch_size: 8 micro_batch_size: 1 num_epochs: 4 diff --git a/examples/stablelm-2/1.6b/fft.yml b/examples/stablelm-2/1.6b/fft.yml index f3fc16f867..777262a7ee 100644 --- a/examples/stablelm-2/1.6b/fft.yml +++ b/examples/stablelm-2/1.6b/fft.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: last_run_prepared val_set_size: 0.05 -output_dir: ./out +output_dir: ./outputs/out sequence_len: 4096 sample_packing: true diff --git a/examples/stablelm-2/1.6b/lora.yml b/examples/stablelm-2/1.6b/lora.yml index c5051fab6e..c65b9e4cd0 100644 --- a/examples/stablelm-2/1.6b/lora.yml +++ b/examples/stablelm-2/1.6b/lora.yml @@ -12,7 +12,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/starcoder2/qlora.yml b/examples/starcoder2/qlora.yml index 1efdfbc8e0..83fc0d89f7 100644 --- a/examples/starcoder2/qlora.yml +++ b/examples/starcoder2/qlora.yml @@ -11,7 +11,7 @@ datasets: dataset_prepared_path: val_set_size: 0.2 -output_dir: ./qlora +output_dir: ./outputs/qlora adapter: qlora lora_model_dir: diff --git a/examples/tiny-llama/lora-mps.yml b/examples/tiny-llama/lora-mps.yml index fd7b02caca..c08be82d3b 100644 --- a/examples/tiny-llama/lora-mps.yml +++ b/examples/tiny-llama/lora-mps.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/tiny-llama/lora.yml b/examples/tiny-llama/lora.yml index 4a16f14b92..c5ff0437e8 100644 --- a/examples/tiny-llama/lora.yml +++ b/examples/tiny-llama/lora.yml @@ -11,7 +11,7 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./lora-out +output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true diff --git a/examples/tiny-llama/pretrain.yml 
b/examples/tiny-llama/pretrain.yml index 3b68a7f547..e501dcb8e5 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -14,7 +14,7 @@ pretraining_dataset: type: pretrain dataset_prepared_path: val_set_size: 0.0 -output_dir: ./model-out +output_dir: ./outputs/model-out sequence_len: 2048 sample_packing: true diff --git a/examples/tiny-llama/qlora.yml b/examples/tiny-llama/qlora.yml index 3ea313c838..384f3315c0 100644 --- a/examples/tiny-llama/qlora.yml +++ b/examples/tiny-llama/qlora.yml @@ -11,13 +11,14 @@ datasets: type: alpaca dataset_prepared_path: val_set_size: 0.05 -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: sequence_len: 4096 sample_packing: true +eval_sample_packing: false pad_to_sequence_len: true lora_r: 32 diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/xgen-7b/xgen-7b-8k-qlora.yml index e3faa01bdb..7e3f83cbd7 100644 --- a/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/xgen-7b/xgen-7b-8k-qlora.yml @@ -40,7 +40,7 @@ wandb_entity: wandb_watch: wandb_name: wandb_log_model: -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out # QLoRA paper Table 9 # - 16 for 7b & 13b diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml index dc8c37d187..7fe322d63d 100644 --- a/examples/yi-34B-chat/qlora.yml +++ b/examples/yi-34B-chat/qlora.yml @@ -33,7 +33,7 @@ eval_sample_packing: false eval_batch_size: 1 # LoRA -output_dir: ./qlora-out +output_dir: ./outputs/qlora-out adapter: qlora lora_model_dir: lora_r: 32 diff --git a/requirements.txt b/requirements.txt index e44b27a595..c4d4a56eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,22 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.10.0 -transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce -tokenizers==0.15.0 -bitsandbytes==0.43.0 -accelerate==0.28.0 -deepspeed==0.13.1 +transformers==4.40.2 +tokenizers==0.19.1 +bitsandbytes==0.43.1 +accelerate==0.30.1 +deepspeed==0.14.2 pydantic==2.6.3 addict fire PyYAML>=6.0 requests -datasets==2.15.0 +datasets==2.19.1 flash-attn==2.4.3.post1 sentencepiece wandb einops -xformers==0.0.22 +xformers==0.0.23.post1 optimum==1.16.2 hf_transfer colorama @@ -28,7 +28,7 @@ scipy scikit-learn==1.2.2 pynvml art -fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8 +fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe gradio==3.50.2 tensorboard diff --git a/scripts/cloud-entrypoint-term.sh b/scripts/cloud-entrypoint-term.sh new file mode 100755 index 0000000000..94511ec7c6 --- /dev/null +++ b/scripts/cloud-entrypoint-term.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Export specific ENV variables to /etc/rp_environment +echo "Exporting environment variables..." 
+printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment +conda init +# this needs to come after conda init +echo 'source /etc/rp_environment' >> ~/.bashrc + +add_keys_to_authorized() { + local key_value=$1 + + # Create the ~/.ssh directory and set permissions + mkdir -p ~/.ssh + chmod 700 ~/.ssh + + # Create the authorized_keys file if it doesn't exist + touch ~/.ssh/authorized_keys + + # Initialize an empty key variable + local key="" + + # Read the key variable word by word + for word in $key_value; do + # Check if the word looks like the start of a key + if [[ $word == ssh-* ]]; then + # If there's a key being built, add it to the authorized_keys file + if [[ -n $key ]]; then + echo $key >> ~/.ssh/authorized_keys + fi + # Start a new key + key=$word + else + # Append the word to the current key + key="$key $word" + fi + done + + # Add the last key to the authorized_keys file + if [[ -n $key ]]; then + echo $key >> ~/.ssh/authorized_keys + fi + + # Set the correct permissions + chmod 600 ~/.ssh/authorized_keys + chmod 700 -R ~/.ssh +} + +if [[ $PUBLIC_KEY ]]; then + # runpod + add_keys_to_authorized "$PUBLIC_KEY" + # Start the SSH service in the background + service ssh start +elif [[ $SSH_KEY ]]; then + # latitude.sh + add_keys_to_authorized "$SSH_KEY" + # Start the SSH service in the background + service ssh start +else + echo "No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon" +fi + +# Check if JUPYTER_PASSWORD is set and not empty +if [ -n "$JUPYTER_PASSWORD" ]; then + # Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD + export JUPYTER_TOKEN="$JUPYTER_PASSWORD" +fi + +if [ "$JUPYTER_DISABLE" != "1" ]; then + # Run Jupyter Lab in the background + jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* & +fi + +if [ ! -d "/workspace/data/axolotl-artifacts" ]; then + mkdir -p /workspace/data/axolotl-artifacts +fi +if [ ! -L "/workspace/axolotl/outputs" ]; then + ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs +fi + +# Execute the passed arguments (CMD) +exec "$@" diff --git a/scripts/cloud-entrypoint.sh b/scripts/cloud-entrypoint.sh index c7b9ca3e0f..5b0337f2b2 100755 --- a/scripts/cloud-entrypoint.sh +++ b/scripts/cloud-entrypoint.sh @@ -5,20 +5,53 @@ echo "Exporting environment variables..." 
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment echo 'source /etc/rp_environment' >> ~/.bashrc -if [[ $PUBLIC_KEY ]]; then - # runpod +add_keys_to_authorized() { + local key_value=$1 + + # Create the ~/.ssh directory and set permissions mkdir -p ~/.ssh chmod 700 ~/.ssh - echo $PUBLIC_KEY >> ~/.ssh/authorized_keys + + # Create the authorized_keys file if it doesn't exist + touch ~/.ssh/authorized_keys + + # Initialize an empty key variable + local key="" + + # Read the key variable word by word + for word in $key_value; do + # Check if the word looks like the start of a key + if [[ $word == ssh-* ]]; then + # If there's a key being built, add it to the authorized_keys file + if [[ -n $key ]]; then + echo $key >> ~/.ssh/authorized_keys + fi + # Start a new key + key=$word + else + # Append the word to the current key + key="$key $word" + fi + done + + # Add the last key to the authorized_keys file + if [[ -n $key ]]; then + echo $key >> ~/.ssh/authorized_keys + fi + + # Set the correct permissions + chmod 600 ~/.ssh/authorized_keys chmod 700 -R ~/.ssh +} + +if [[ $PUBLIC_KEY ]]; then + # runpod + add_keys_to_authorized "$PUBLIC_KEY" # Start the SSH service in the background service ssh start -elif [ -n "$SSH_KEY" ]; then +elif [[ $SSH_KEY ]]; then # latitude.sh - mkdir -p ~/.ssh - chmod 700 ~/.ssh - echo $SSH_KEY >> ~/.ssh/authorized_keys - chmod 700 -R ~/.ssh + add_keys_to_authorized "$SSH_KEY" # Start the SSH service in the background service ssh start else @@ -33,7 +66,14 @@ fi if [ "$JUPYTER_DISABLE" != "1" ]; then # Run Jupyter Lab in the background - jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace & + jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* & +fi + +if [ ! -d "/workspace/data/axolotl-artifacts" ]; then + mkdir -p /workspace/data/axolotl-artifacts +fi +if [ ! 
-L "/workspace/axolotl/outputs" ]; then + ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs fi # Execute the passed arguments (CMD) diff --git a/setup.py b/setup.py index 905df16193..88bcb74b16 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def parse_requirements(): try: if "Darwin" in platform.system(): - _install_requires.pop(_install_requires.index("xformers==0.0.22")) + _install_requires.pop(_install_requires.index("xformers==0.0.23.post1")) else: torch_version = version("torch") _install_requires.append(f"torch=={torch_version}") @@ -45,9 +45,12 @@ def parse_requirements(): else: raise ValueError("Invalid version format") - if (major, minor) >= (2, 1): - _install_requires.pop(_install_requires.index("xformers==0.0.22")) - _install_requires.append("xformers>=0.0.23") + if (major, minor) >= (2, 3): + _install_requires.pop(_install_requires.index("xformers==0.0.23.post1")) + _install_requires.append("xformers>=0.0.26.post1") + elif (major, minor) >= (2, 2): + _install_requires.pop(_install_requires.index("xformers==0.0.23.post1")) + _install_requires.append("xformers>=0.0.25.post1") except PackageNotFoundError: pass @@ -71,10 +74,10 @@ def parse_requirements(): "flash-attn==2.4.3.post1", ], "fused-dense-lib": [ - "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", + "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib", ], "deepspeed": [ - "deepspeed==0.13.1", + "deepspeed==0.14.2", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 81d20802cd..42ce74c0c8 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -265,8 +265,8 @@ def generate(instruction): with torch.no_grad(): generation_config = GenerationConfig( repetition_penalty=1.1, - max_new_tokens=1024, - temperature=0.9, + max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), + temperature=cfg.get("gradio_temperature", 0.9), top_p=0.95, top_k=40, bos_token_id=tokenizer.bos_token_id, @@ -301,7 +301,13 @@ def generate(instruction): outputs="text", title=cfg.get("gradio_title", "Axolotl Gradio Interface"), ) - demo.queue().launch(show_api=False, share=True) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) def choose_config(path: Path): diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 611881ab1a..8db3fa9897 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs): load_in_8bit=False, load_in_4bit=False, flash_attention=False, + deepspeed=None, + fsdp=None, **kwargs, ) @@ -40,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cfg.flash_attention = False parsed_cfg.deepspeed = None parsed_cfg.fsdp = None + parsed_cfg.fsdp_config = None do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index fa71d67934..e7b3596a4f 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -19,7 +19,10 @@ ) from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.prompt_strategies.sharegpt import register_chatml_template +from axolotl.prompt_strategies.sharegpt import ( + 
register_chatml_template, + register_llama3_template, +) LOG = logging.getLogger("axolotl.cli.preprocess") @@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): return_remaining_strings=True ) - if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message: - LOG.info( - f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" - ) - register_chatml_template(parsed_cfg.default_system_message) - else: - register_chatml_template() + if parsed_cfg.chat_template == "chatml": + if parsed_cfg.default_system_message: + LOG.info( + f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" + ) + register_chatml_template(parsed_cfg.default_system_message) + else: + register_chatml_template() + elif parsed_cfg.chat_template == "llama3": + if parsed_cfg.default_system_message: + LOG.info( + f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}" + ) + register_llama3_template(parsed_cfg.default_system_message) + else: + register_llama3_template() if not parsed_cfg.dataset_prepared_path: msg = ( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 0cebe5a52b..7bb4a51844 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -19,7 +19,10 @@ print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs -from axolotl.prompt_strategies.sharegpt import register_chatml_template +from axolotl.prompt_strategies.sharegpt import ( + register_chatml_template, + register_llama3_template, +) from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") @@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: register_chatml_template() + if cfg.chat_template == "llama3" and cfg.default_system_message: + LOG.info( + f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}" + ) + register_llama3_template(cfg.default_system_message) + else: + register_llama3_template() + if cfg.rl: # and cfg.rl != "orpo": dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) else: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py old mode 100644 new mode 100755 index aaa3420f74..f9138fff2e --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,7 +30,7 @@ ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOTrainer, ORPOConfig, ORPOTrainer +from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer from trl.trainer.utils import pad_to_length from axolotl.loraplus import create_loraplus_optimizer @@ -43,6 +43,7 @@ LossWatchDogCallback, SaveAxolotlConfigtoWandBCallback, SaveBetterTransformerModelCallback, + SaveModelCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, log_prediction_callback_factory, @@ -124,14 +125,22 @@ class AxolotlTrainingArguments(TrainingArguments): default=1.0, metadata={"help": "Sample packing efficiency for calculating batch length."}, ) + sample_packing_bin_size: int = field( + default=200, + metadata={ + "help": "The max number of samples that a packed sample can contain after packing. Increase for better packing." + }, + ) + sample_packing_group_size: int = field( + default=100000, + metadata={ + "help": "The number of samples to group together for packing. Increase for better packing."
+ }, + ) max_seq_length: int = field( default=2048, metadata={"help": "The maximum sequence length the model can handle"}, ) - sample_packing_seq_len_multiplier: int = field( - default=1, - metadata={"help": "the multiplier for the max len for packed sequences"}, - ) relora_steps: Optional[int] = field( default=None, metadata={"help": "how often to reset for ReLoRA"}, @@ -345,11 +354,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: ) return MultipackBatchSampler( RandomSampler(self.train_dataset), - batch_size=batch_size, - drop_last=True, - batch_max_len=batch_max_len, lengths=get_dataset_lengths(self.train_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_max_len=batch_max_len, + batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, ) if self.args.curriculum_sampling: return SequentialSampler(self.train_dataset) @@ -369,11 +378,11 @@ def _get_eval_sampler( ) return MultipackBatchSampler( SequentialSampler(eval_dataset), - batch_size=batch_size, - drop_last=True, + lengths=get_dataset_lengths(self.eval_dataset), batch_max_len=batch_max_len, - lengths=get_dataset_lengths(eval_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, ) return super()._get_eval_sampler(eval_dataset) @@ -797,6 +806,40 @@ class AxolotlDPOTrainer(DPOTrainer): tag_names = ["axolotl", "dpo"] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = None + + def create_optimizer(self): + if self.args.loraplus_lr_ratio is None: + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + if loraplus_lr_ratio: + print("Using lora+") + loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + optimizer_kwargs, + loraplus_lr_ratio, + loraplus_lr_embedding, + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + @wraps(DPOTrainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ @@ -825,6 +868,14 @@ class AxolotlORPOTrainer(ORPOTrainer): tag_names = ["axolotl", "orpo"] +class AxolotlKTOTrainer(KTOTrainer): + """ + Extend the base KTOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "kto"] + + class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -888,6 +939,14 @@ def get_callbacks(self) -> List[TrainerCallback]: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_mlflow and is_mlflow_available(): + from axolotl.utils.callbacks.mlflow_ import ( + SaveAxolotlConfigtoMlflowCallback, + ) + + callbacks.append( + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + ) if self.cfg.vessl_credential_path: from axolotl.utils.callbacks.vessl_ import VesslLogMetricsCallback @@ -938,18 +997,11 @@ def get_callbacks(self): ): 
callbacks.append(SaveBetterTransformerModelCallback()) - if self.cfg.use_mlflow and is_mlflow_available(): - from axolotl.utils.callbacks.mlflow_ import ( - SaveAxolotlConfigtoMlflowCallback, - ) - - callbacks.append( - SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) - ) - if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) + callbacks.append(SaveModelCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -1074,11 +1126,6 @@ def build(self, total_num_steps): if self.cfg.save_safetensors is not None: training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors - if self.cfg.sample_packing_eff_est: - training_arguments_kwargs[ - "sample_packing_efficiency" - ] = self.cfg.sample_packing_eff_est - if self.cfg.dataloader_pin_memory is not None: training_arguments_kwargs[ "dataloader_pin_memory" @@ -1126,6 +1173,8 @@ def build(self, total_num_steps): # default to saving each epoch if not defined training_arguments_kwargs["save_strategy"] = "epoch" + training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model + if self.cfg.do_bench_eval: training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval if self.cfg.bench_dataset: @@ -1205,11 +1254,14 @@ def build(self, total_num_steps): ) training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling - report_to = None + report_to = [] if self.cfg.use_wandb: - report_to = "wandb" + report_to.append("wandb") if self.cfg.use_mlflow: - report_to = "mlflow" + report_to.append("mlflow") + if self.cfg.use_tensorboard: + report_to.append("tensorboard") + training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["run_name"] = ( self.cfg.wandb_name if self.cfg.use_wandb else None @@ -1249,20 +1301,27 @@ def build(self, total_num_steps): training_arguments_kwargs["weight_decay"] = ( self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 ) - training_arguments_kwargs["sample_packing"] = ( - self.cfg.sample_packing if self.cfg.sample_packing else False - ) - training_arguments_kwargs["multipack_real_batches"] = ( - self.cfg.flash_attention is not True - ) - training_arguments_kwargs["eval_sample_packing"] = ( - self.cfg.sample_packing - if self.cfg.eval_sample_packing is not False - else False - ) + + training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs[ - "sample_packing_seq_len_multiplier" - ] = self.cfg.micro_batch_size + "multipack_real_batches" + ] = not self.cfg.flash_attention + training_arguments_kwargs["eval_sample_packing"] = bool( + self.cfg.eval_sample_packing + ) + if self.cfg.sample_packing_bin_size is not None: + training_arguments_kwargs[ + "sample_packing_bin_size" + ] = self.cfg.sample_packing_bin_size + if self.cfg.sample_packing_group_size is not None: + training_arguments_kwargs[ + "sample_packing_group_size" + ] = self.cfg.sample_packing_group_size + if self.cfg.sample_packing_eff_est: + training_arguments_kwargs[ + "sample_packing_efficiency" + ] = self.cfg.sample_packing_eff_est + if self.cfg.relora_steps: training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs[ @@ -1432,6 +1491,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() + callbacks.append(SaveModelCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -1467,9 
+1528,12 @@ def build_training_arguments(self, total_num_steps): training_args_kwargs["eval_steps"] = self.cfg.eval_steps else: training_args_kwargs["evaluation_strategy"] = "no" + if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True + training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio + training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding training_args_kwargs["lr_scheduler_type"] = ( self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" ) @@ -1522,9 +1586,29 @@ def build_training_arguments(self, total_num_steps): # trl does some odd mapping of alpha to beta to reuse the beta parameter ??? training_args_kwargs["beta"] = self.cfg.orpo_alpha - training_args_cls = TrainingArguments + training_args_cls = AxolotlTrainingArguments if self.cfg.rl == "orpo": training_args_cls = ORPOConfig + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + training_args_kwargs["max_length"] = self.cfg.sequence_len + if self.cfg.max_prompt_len: + training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + + if self.cfg.rl == "kto": + training_args_cls = KTOConfig + + training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1 + training_args_kwargs["desirable_weight"] = ( + self.cfg.kto_desirable_weight or 1.0 + ) + training_args_kwargs["undesirable_weight"] = ( + self.cfg.kto_undesirable_weight or 1.0 + ) + + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + training_args_kwargs["max_length"] = self.cfg.sequence_len + if self.cfg.max_prompt_len: + training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, @@ -1561,7 +1645,7 @@ def build(self, total_num_steps): ] = self.cfg.precompute_ref_log_probs if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: trainer_cls = AxolotlDPOTrainer - dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 + dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1 trainer_cls_args = [self.model, self.model_ref] # these aren't used for the ORPO trainer @@ -1569,9 +1653,14 @@ def build(self, total_num_steps): dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["generate_during_eval"] = True + if self.cfg.rl == "dpo": + dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] + elif self.cfg.rl == "kto": + trainer_cls = AxolotlKTOTrainer + trainer_cls_args = [self.model] else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") dpo_trainer = trainer_cls( diff --git a/src/axolotl/monkeypatch/fastchat_conversation_turns.py b/src/axolotl/monkeypatch/fastchat_conversation_turns.py index 7ab07d4854..a09bfddb4b 100644 --- a/src/axolotl/monkeypatch/fastchat_conversation_turns.py +++ b/src/axolotl/monkeypatch/fastchat_conversation_turns.py @@ -123,6 +123,17 @@ def get_turns( # pylint: disable=too-many-return-statements else: yield role, "" return + if self.sep_style == SeparatorStyle.LLAMA3: + if self.system_message: + # For llama3, the system message is NOT incorporated into the first human instruction + # All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|> + yield "", system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>" + else: + yield 
f"<|start_header_id|>{role}<|end_header_id|>\n\n", "" + return if self.sep_style == SeparatorStyle.GEMMA: if self.system_message: raise ValueError("Gemma chat template does not support system messages") diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py new file mode 100644 index 0000000000..de8260414e --- /dev/null +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -0,0 +1,267 @@ +"""module for patching with unsloth optimizations""" + +import inspect +import logging +import re +import types +from typing import Tuple + +from peft import PeftModelForCausalLM +from transformers.models.llama.modeling_llama import ( + LlamaFlashAttention2, + LlamaForCausalLM, +) + +LOG = logging.getLogger("axolotl.monkeypatch.unsloth") + +ORIGINAL_CEL_CODE = """ if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) +""" + +PATCHED_CEL_CODE = """ if labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = fast_cross_entropy_loss( + logits = shift_logits, + labels = shift_labels, + ) +""" + +ORIGINAL_QKV_CODE = """ + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) +""".lstrip( + "\n" +) + +PATCHED_QKV_CODE = """ + query_states, key_states, value_states = self.apply_qkv(self, hidden_states) +""".lstrip( + "\n" +) + +ORIGINAL_O_CODE = """ + attn_output = self.o_proj(attn_output) +""".lstrip( + "\n" +) + +PATCHED_O_CODE = """ + attn_output = self.apply_o(self, attn_output) +""".lstrip( + "\n" +) + + +def original_apply_qkv(self, hidden_states): + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + return query_states, key_states, value_states + + +def original_apply_o(self, hidden_states): + attn_output = self.o_proj(hidden_states) + return attn_output + + +def get_forward_code() -> str: + forward = inspect.getsource(LlamaForCausalLM.forward) + return forward + + +def test_cel_is_patchable() -> bool: + forward = get_forward_code() + return ORIGINAL_CEL_CODE in forward + + +def get_self_attn_code() -> str: + forward = inspect.getsource(LlamaFlashAttention2.forward) + return forward + + +def test_self_attn_is_patchable() -> bool: + qkv = get_self_attn_code() + return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv + + +def integrate_cross_entropy_loss_patch(): + forward = get_forward_code() + LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access + forward, _ = detab_code(forward) + assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" + + forward = forward.replace( + "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" + ) + forward = forward.replace( + "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", + "", + ) + forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) + forward = forward.replace( + "def forward(", + "def fast_cross_entropy_loss_forward(", + 1, + ) + + # load imports necessary + import transformers.models.llama.modeling_llama + + items_to_import = [] 
+ for item in dir(transformers.models.llama.modeling_llama): + if item in forward: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", + globals(), + ) + + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.models.llama.modeling_llama import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(forward, globals()) # pylint: disable=exec-used # nosec B102 + print("patching unsloth fast_cross_entropy_loss") + LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 + + +def detab_code(code: str) -> Tuple[str, str]: + spaces = re.match(r"([\s\t]{1,})", code).group(0) + code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) + return code, spaces + + +def patch_self_attn_lora(): + self_attn_forward = get_self_attn_code() + LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access + self_attn_forward + ) + self_attn_forward, _ = detab_code(self_attn_forward) + assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found" + assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found" + + self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE) + self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) + self_attn_forward = self_attn_forward.replace( + "def forward(", + "def unsloth_attn_forward(", + 1, + ) + + # load imports necessary + import transformers.models.llama.modeling_llama + + items_to_import = [] + for item in dir(transformers.models.llama.modeling_llama): + if item in self_attn_forward: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.models.llama.modeling_llama import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + print("patching unsloth attn lora") + LlamaFlashAttention2.forward = ( + unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 + ) + + +def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): + if peft_model.base_model.config.model_type in ["llama", "mistral"]: + from unsloth.kernels import apply_lora_mlp_swiglu + + apply_lora_mlp = apply_lora_mlp_swiglu + elif peft_model.base_model.config.model_type == "gemma": + from unsloth.kernels import apply_lora_mlp_geglu_approx + + apply_lora_mlp = apply_lora_mlp_geglu_approx + else: + raise NotImplementedError( + f"Model type {peft_model.base_model.config.model_type} not supported" + ) + + for idx, layer in enumerate(peft_model.model.model.layers): + layer_modules = [ + getattr(layer.mlp, linear_proj) + for linear_proj in ["gate_proj", "up_proj", "down_proj"] + ] + is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules) + mlp_no_bias = all( + getattr(module, "base_layer", module).bias is None + for module in layer_modules + ) + mlp_not_dora = all( + getattr(module, "lora_magnitude_vector", None) is None + for module in layer_modules + ) + + if is_mlp_lora and mlp_no_bias and mlp_not_dora: + layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) + else: + logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx) + + +def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): + from unsloth.kernels import apply_lora_o, apply_lora_qkv + + for idx, layer in enumerate(peft_model.model.model.layers): + if 
cfg.unsloth_lora_qkv: + layer_modules = [ + getattr(layer.self_attn, linear_proj) + for linear_proj in ["q_proj", "k_proj", "v_proj"] + ] + is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules) + qkv_no_bias = all( + getattr(module, "base_layer", module).bias is None + for module in layer_modules + ) + qkv_not_dora = all( + getattr(module, "lora_magnitude_vector", None) is None + for module in layer_modules + ) + + if is_qkv_lora and qkv_no_bias and qkv_not_dora: + layer.self_attn.apply_qkv = apply_lora_qkv + else: + layer.self_attn.apply_qkv = original_apply_qkv + logging.warning( + "unable to apply unsloth lora qkv patch to layer %d", idx + ) + if cfg.unsloth_lora_o: + layer_modules = [ + getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] + ] + is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules) + o_no_bias = all( + getattr(module, "base_layer", module).bias is None + for module in layer_modules + ) + o_not_dora = all( + getattr(module, "lora_magnitude_vector", None) is None + for module in layer_modules + ) + + if is_o_lora and o_no_bias and o_not_dora: + layer.self_attn.apply_o = apply_lora_o + else: + layer.self_attn.apply_o = original_apply_o + logging.warning( + "unable to apply unsloth lora o_proj patch to layer %d", idx + ) diff --git a/src/axolotl/prompt_strategies/dpo/llama3.py b/src/axolotl/prompt_strategies/dpo/llama3.py new file mode 100644 index 0000000000..cb394cc228 --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/llama3.py @@ -0,0 +1,133 @@ +""" +DPO strategies for llama-3 chat template +""" + + +def argilla( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>" + sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>" + return sample + + return transform_fn + + +def argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/dpo-mix-7k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" + return sample + + return transform_fn + + +def icr( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + chatml transforms for datasets with system, input, chosen, rejected + ex. 
https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample['chosen']}<|eot_id|>" + sample["rejected"] = f"{sample['rejected']}<|eot_id|>" + return sample + + return transform_fn + + +def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + For Intel Orca DPO Pairs + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample['chosen']}<|eot_id|>" + sample["rejected"] = f"{sample['rejected']}<|eot_id|>" + return sample + + return transform_fn + + +def prompt_pairs( + cfg, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample['chosen']}<|eot_id|>" + sample["rejected"] = f"{sample['rejected']}<|eot_id|>" + return sample + + return transform_fn + + +def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + for ultrafeedback binarized conversations + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/kto/__init__.py b/src/axolotl/prompt_strategies/kto/__init__.py new file mode 100644 index 0000000000..9af6300eb3 --- /dev/null +++ b/src/axolotl/prompt_strategies/kto/__init__.py @@ -0,0 +1,9 @@ +""" +module for KTO style dataset transform strategies +""" + +from functools import partial + +from ..base import load as load_base + +load = partial(load_base, module_base="axolotl.prompt_strategies.kto") diff --git a/src/axolotl/prompt_strategies/kto/chatml.py b/src/axolotl/prompt_strategies/kto/chatml.py new file mode 100644 index 
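Each factory in the llama-3 DPO module above returns a `transform_fn` that maps a raw dataset row into `prompt`/`chosen`/`rejected` strings. A usage sketch against a made-up row (the import path follows this diff; the sample values are invented):

```python
from axolotl.prompt_strategies.dpo.llama3 import intel

row = {
    "system": "Be terse.",  # invented sample values
    "question": "What is 2 + 2?",
    "chosen": "4",
    "rejected": "Numbers are a social construct.",
}

transform = intel(cfg=None)
out = transform(dict(row))

assert out["prompt"].startswith("<|start_header_id|>system<|end_header_id|>")
assert out["chosen"] == "4<|eot_id|>"
assert out["rejected"].endswith("<|eot_id|>")
```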
0000000000..46c305f831
--- /dev/null
+++ b/src/axolotl/prompt_strategies/kto/chatml.py
@@ -0,0 +1,105 @@
+"""
+KTO strategies for chatml
+"""
+# pylint: disable=duplicate-code
+
+
+def argilla(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/kto-mix-15k conversations
+    """
+
+    def transform_fn(sample):
+        # read both turns from the same "completion" column, matching the
+        # llama3 counterpart of this strategy
+        sample[
+            "prompt"
+        ] = f"<|im_start|>user\n{sample['completion'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    For Intel Orca KTO
+    ex: argilla/distilabel-intel-orca-kto
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def prompt_pairs(
+    cfg, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for ultrafeedback binarized conversations
+    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
diff --git a/src/axolotl/prompt_strategies/kto/llama3.py b/src/axolotl/prompt_strategies/kto/llama3.py
new file mode 100644
index 0000000000..795d343fe3
--- /dev/null
+++ b/src/axolotl/prompt_strategies/kto/llama3.py
@@ -0,0 +1,105 @@
+"""
+KTO strategies for llama-3 chat template
+"""
+# pylint: disable=duplicate-code
+
+
+def argilla(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn + + +def argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/kto-mix-15k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>" + return sample + + return transform_fn + + +def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + For Intel Orca KTO + ex: argilla/distilabel-intel-orca-kto + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn + + +def prompt_pairs( + cfg, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn + + +def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument + """ + for ultrafeedback binarized conversations + ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + sample[ + "prompt" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["completion"] = f"{sample['completion']}<|eot_id|>" + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/kto/user_defined.py b/src/axolotl/prompt_strategies/kto/user_defined.py new file mode 100644 index 0000000000..7e5458bb70 --- /dev/null +++ b/src/axolotl/prompt_strategies/kto/user_defined.py @@ -0,0 +1,39 @@ +""" +User-defined KTO strategies +""" +# pylint: disable=duplicate-code + + +def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument + 
ds_cfg = cfg["datasets"][dataset_idx]["type"]
+    if not isinstance(ds_cfg, dict):
+        raise ValueError(
+            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
+        )
+    field_prompt = ds_cfg.get("field_prompt", "prompt")
+    field_system = ds_cfg.get("field_system", "system")
+    field_completion = ds_cfg.get("field_completion", "completion")
+    field_label = ds_cfg.get("field_label", "label")
+    prompt_format = ds_cfg.get("prompt_format")
+    if not prompt_format:
+        prompt_format = "{" + field_prompt + "}"
+    completion_format = ds_cfg.get("completion_format")
+    if not completion_format:
+        completion_format = "{" + field_completion + "}"
+
+    def transform_fn(sample):
+        if (
+            "{" + field_system + "}" in prompt_format
+            and field_system in sample
+            and sample[field_system]
+        ):
+            sample["prompt"] = prompt_format.format(
+                **{
+                    field_system: sample[field_system],
+                    field_prompt: sample[field_prompt],
+                }
+            )
+        else:
+            sample["prompt"] = prompt_format.format(
+                **{field_prompt: sample[field_prompt]}
+            )
+        sample["completion"] = completion_format.format(
+            **{field_completion: sample[field_completion]}
+        )
+        sample["label"] = sample[field_label]
+        return sample
+
+    return transform_fn
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 55bdd37b4f..8b452ae199 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,7 +1,7 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Type

 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
             name="chatml",
             system_template="<|im_start|>system\n{system_message}",
             system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            roles=("<|im_start|>user", "<|im_start|>assistant"),
             sep_style=SeparatorStyle.CHATML,
             sep="<|im_end|>",
         )
@@ -32,83 +32,65 @@ def register_chatml_template(system_message=None):
             name="chatml_glaive",
             system_template="<|im_start|>system\n{system_message}",
             system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
+            roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
             sep_style=SeparatorStyle.CHATML,
             sep="<|im_end|>",
         )
     )


-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
-    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
-    strategy = SimpleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-            role_key_model=field_model,
-            role_key_human=field_human,
-            roles=roles,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
-
-
-def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    strategy = UltrachatShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
-
-
-def
load_role(tokenizer, cfg): - return SimpleRoleShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, +def register_llama3_template(system_message=None): + system_message = system_message or "You are a helpful assistant." + register_conv_template( + Conversation( + name="llama3", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + system_message=system_message, + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str="<|eot_id|>", + stop_token_ids=[128001, 128009], + ) ) -def load_guanaco(tokenizer, cfg): - return GuanacoShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - +def build_loader( + tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"], + prompter_cls: Type["ShareGPTPrompterV2"], + default_conversation: Optional[str] = None, +): + def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] + if ds_cfg and "conversation" in ds_cfg + else default_conversation + ) + field_human = ( + ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None + ) + field_model = ( + ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None + ) + roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None + strategy = tokenization_strategy_cls( + prompter_cls( + conversation=conversation, + role_key_model=field_model, + role_key_human=field_human, + roles=roles, + ), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"): + strategy.strict = ds_cfg["strict"] + if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"): + strategy.messages = ds_cfg["field_messages"] + return strategy -def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - conversation = ( - ds_cfg["conversation"] - if ds_cfg and "conversation" in ds_cfg - else "chatml_glaive" - ) - return GlaiveShareGPTPromptTokenizingStrategy( - ShareGPTPrompterV2(conversation=conversation), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) + return _load class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): @@ -117,6 +99,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): """ _strict = False + _messages = "conversations" @property def strict(self): @@ -126,8 +109,16 @@ def strict(self): def strict(self, strict): self._strict = strict + @property + def messages(self): + return self._messages + + @messages.setter + def messages(self, messages): + self._messages = messages + def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] + conversations = prompt[self.messages] if self.strict: return conversations role_key = "from" @@ -158,7 +149,9 @@ def get_conversation_thread(self, prompt): return turns -class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): +class SimpleRoleShareGPTPromptTokenizingStrategy( + SimpleShareGPTPromptTokenizingStrategy +): """ basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from """ @@ -209,3 +202,16 @@ def get_conversation_thread(self, prompt): conversation = merge_consecutive_messages(conversation) return conversation + + +load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2) +load_role = 
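With `build_loader` above, each exported loader becomes a one-line specialization. A sketch of composing a hypothetical extra loader the same way (`load_llama3_sharegpt` is an invented name; it assumes `register_llama3_template()` has been called so the `llama3` conversation exists):

```python
from axolotl.prompt_strategies.sharegpt import (
    SimpleShareGPTPromptTokenizingStrategy,
    build_loader,
    register_llama3_template,
)
from axolotl.prompters import ShareGPTPrompterV2

register_llama3_template()  # makes the "llama3" conversation available

# Invented name: a loader pinned to the llama3 conversation template.
load_llama3_sharegpt = build_loader(
    SimpleShareGPTPromptTokenizingStrategy,
    ShareGPTPrompterV2,
    default_conversation="llama3",
)
```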
build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2) +load_ultrachat = build_loader( + UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2 +) +load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2) +load_glaive = build_loader( + GlaiveShareGPTPromptTokenizingStrategy, + ShareGPTPrompterV2, + default_conversation="chatml_glaive", +) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 2b6b4f8577..60ea5c99f9 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -263,6 +263,7 @@ def __repr__(self) -> str: "chatml": "<|im_start|>{ROLE}", "zephyr": "<|{ROLE}|>", "vicuna_v1.1": "{ROLE}", + "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>", } @@ -348,7 +349,10 @@ def _build_result(self, source): ) if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): - LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") + if ( + role != "assistant" + ): # back to back assistant calls may be okay for tool calls + LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") conv.append_message(role, sentence["value"]) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 01e07640f9..32bcbc1d0a 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -3,6 +3,7 @@ import os import signal import sys +import weakref from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple, Union @@ -127,14 +128,20 @@ def train( # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: - def terminate_handler(_, __, model): - if cfg.flash_optimum and BetterTransformer: - model = BetterTransformer.reverse(model) - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + def terminate_handler(_, __, model_weakref): + if model_weakref() is not None: + _model = model_weakref() + if cfg.flash_optimum and BetterTransformer: + _model = BetterTransformer.reverse(_model) + _model.save_pretrained( + cfg.output_dir, safe_serialization=safe_serialization + ) sys.exit(0) + _model_weakref = weakref.ref(model) signal.signal( - signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) + signal.SIGINT, + lambda signum, frame: terminate_handler(signum, frame, _model_weakref), ) badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" @@ -205,6 +212,10 @@ def terminate_handler(_, __, model): if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) + if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: + trainer.model.save_pretrained( + cfg.output_dir, safe_serialization=safe_serialization + ) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if not cfg.hub_model_id: diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index fbc1dcfad8..c21ef0ad7a 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import math import os from shutil import copyfile from tempfile import NamedTemporaryFile @@ -773,3 +774,31 @@ def on_train_begin( except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") return control + + +class SaveModelCallback(TrainerCallback): + """Callback to save model on train end""" + + def on_step_end( # pylint: disable=unused-argument + self, + args: 
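The SIGINT handler rework above passes the model through a `weakref` so the signal closure does not keep the model alive after training tears it down; the handler saves only if the referent still exists. A stand-alone sketch of the pattern with a toy object in place of the model:

```python
import weakref

class ToyModel:
    def save(self):
        print("saved")

model = ToyModel()
model_ref = weakref.ref(model)

def on_interrupt():
    if model_ref() is not None:  # only save while the model is alive
        model_ref().save()

on_interrupt()  # -> saved
del model       # drop the last strong reference
on_interrupt()  # no-op: the weakref is now dead
```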
TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Save + if state.global_step >= state.max_steps: + control.should_save = True + elif ( + args.save_strategy == IntervalStrategy.STEPS + and state.save_steps < 1.0 + and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0 + ): + # workaround to save model on fractional save_steps + control.should_save = True + + def on_train_end( # pylint: disable=unused-argument + self, args, state, control, **kwargs + ): + control.should_save = True + return control diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index c1dde8c0f3..1fe888aa80 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -24,6 +24,7 @@ def chat_templates(user_choice: str): "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", } if user_choice in templates: diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index f0a1fb1261..26c7fa9f3c 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -229,9 +229,8 @@ def __call__(self, features, return_tensors=None): if feature == "attention_mask": if self.multipack_attn: arrays = [ - (i + 1) * np.array(item[feature]) + (i + 1) * np.array(item) for i, item in enumerate(features[feature]) - if feature in item ] else: arrays = [(1) * np.array(item) for item in features[feature]] diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index e33774972c..067baf12ff 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -24,6 +24,7 @@ class DeprecatedParameters(BaseModel): max_packed_sequence_len: Optional[int] = None rope_scaling: Optional[Any] = None noisy_embedding_alpha: Optional[float] = None + dpo_beta: Optional[float] = None @field_validator("max_packed_sequence_len") @classmethod @@ -48,6 +49,13 @@ def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha): LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") return noisy_embedding_alpha + @field_validator("dpo_beta") + @classmethod + def validate_dpo_beta(cls, dpo_beta): + if dpo_beta is not None: + LOG.warning("dpo_beta is deprecated, use rl_beta instead") + return dpo_beta + class RemappedParameters(BaseModel): """parameters that have been remapped to other names""" @@ -101,6 +109,7 @@ class SFTDataset(BaseModel): field: Optional[str] = None field_human: Optional[str] = None field_model: Optional[str] = None + field_messages: Optional[str] = None roles: Optional[Dict[str, List[str]]] = None @@ -126,6 +135,26 @@ class DPODataset(BaseModel): data_files: Optional[List[str]] = None +class UserDefinedKTOType(BaseModel): + """User defined typing for KTO""" + + field_system: Optional[str] = None + field_prompt: Optional[str] = None + field_completion: Optional[str] = None + field_label: Optional[bool] = None + prompt_format: Optional[str] = None + completion_format: Optional[str] = 
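The `llama3` Jinja template added above can be exercised directly through a tokenizer. A sketch, assuming any tokenizer that carries the Llama-3 special tokens (the model id is just one example):

```python
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.chat_template = chat_templates("llama3")

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
# Expected shape:
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n
```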
None + + +class KTODataset(BaseModel): + """KTO configuration subset""" + + path: Optional[str] = None + split: Optional[str] = None + type: Optional[Union[UserDefinedKTOType, str]] = None + data_files: Optional[List[str]] = None + + class RLType(str, Enum): """RL trainer type configuration subset""" @@ -133,6 +162,7 @@ class RLType(str, Enum): ipo = "ipo" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name + kto = "kto" # pylint: disable=invalid-name class ChatTemplate(str, Enum): @@ -143,6 +173,7 @@ class ChatTemplate(str, Enum): inst = "inst" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name + llama3 = "llama3" # pylint: disable=invalid-name class LoftQConfig(BaseModel): @@ -182,7 +213,7 @@ class LoraConfig(BaseModel): lora_target_modules: Optional[List[str]] = None lora_target_linear: Optional[bool] = None lora_modules_to_save: Optional[List[str]] = None - lora_dropout: Optional[float] = None + lora_dropout: Optional[float] = 0.0 peft_layers_to_transform: Optional[List[int]] = None peft: Optional[PeftConfig] = None peft_use_dora: Optional[bool] = None @@ -409,6 +440,17 @@ def check_wandb_run(cls, data): return data +class GradioConfig(BaseModel): + """Gradio configuration subset""" + + gradio_title: Optional[str] = None + gradio_share: Optional[bool] = None + gradio_server_name: Optional[str] = None + gradio_server_port: Optional[int] = None + gradio_max_new_tokens: Optional[int] = None + gradio_temperature: Optional[float] = None + + class VesslConfig(BaseModel): """Vessl AI configuration subset""" @@ -426,6 +468,7 @@ class AxolotlInputConfig( MLFlowConfig, VesslConfig, LISAConfig, + GradioConfig, RemappedParameters, DeprecatedParameters, BaseModel, @@ -444,8 +487,8 @@ class Config: rl: Optional[RLType] = None - datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore - test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore + datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore + test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore shuffle_merged_datasets: Optional[bool] = True dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None @@ -511,7 +554,12 @@ class Config: sequence_len: int = Field(default=512) min_sample_len: Optional[int] = None + max_prompt_len: int = Field( + default=512, metadata={"help": "maximum prompt length for RL training"} + ) sample_packing: Optional[bool] = None + sample_packing_group_size: Optional[int] = 100_000 + sample_packing_bin_size: Optional[int] = 200 eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None curriculum_sampling: Optional[bool] = None @@ -540,6 +588,11 @@ class Config: flash_attn_fuse_mlp: Optional[bool] = None flash_optimum: Optional[bool] = None + unsloth_cross_entropy_loss: Optional[bool] = None + unsloth_lora_mlp: Optional[bool] = None + unsloth_lora_qkv: Optional[bool] = None + unsloth_lora_o: Optional[bool] = None + deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None fsdp_config: Optional[Dict[str, Any]] = None @@ -565,11 +618,17 @@ class Config: logging_steps: Optional[int] = None early_stopping_patience: Optional[int] = None load_best_model_at_end: Optional[bool] = False + save_only_model: 
Optional[bool] = False + use_tensorboard: Optional[bool] = None neftune_noise_alpha: Optional[float] = None orpo_alpha: Optional[float] = None + kto_desirable_weight: Optional[float] = None + kto_undesirable_weight: Optional[float] = None + rl_beta: Optional[float] = None + max_memory: Optional[ Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] ] = None @@ -869,6 +928,13 @@ def validate_neftune_noise_alpha(cls, neftune_noise_alpha): raise ValueError("neftune_noise_alpha must be > 0.0") return neftune_noise_alpha + @model_validator(mode="after") + def check(self): + if self.dpo_beta and not self.rl_beta: + self.rl_beta = self.dpo_beta + del self.dpo_beta + return self + @model_validator(mode="before") @classmethod def check_frozen(cls, data): diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index 544ed13162..e056c7f509 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -150,6 +150,8 @@ def wrap_pretraining_dataset( max_seq_length=max_tokens, batch_size=batch_size, multipack_attn=cfg.pretrain_multipack_attn, + group_size=cfg.sample_packing_group_size, + bin_size=cfg.sample_packing_bin_size, ) # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 @@ -189,6 +191,8 @@ def encode_packed_pretraining( max_seq_length: int = 2048, batch_size: int = 4, multipack_attn: Optional[bool] = False, + group_size: int = 100000, + bin_size: int = 200, ) -> Dict[str, List]: # pylint: disable=duplicate-code # tokenize all the examples @@ -202,11 +206,13 @@ def encode_packed_pretraining( ) sampler = MultipackBatchSampler( - RandomSampler(train_dataset), + sampler=RandomSampler(train_dataset), + lengths=get_dataset_lengths(train_dataset), batch_size=1, - drop_last=True, batch_max_len=batch_size * max_seq_length, - lengths=get_dataset_lengths(train_dataset), + group_size=group_size, + bin_size=bin_size, + drop_last=True, ) chunked_data = defaultdict(list) diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index ff5ca87ddf..7416ca28bb 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -10,6 +10,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.prompt_strategies.dpo import load as load_dpo +from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo from axolotl.utils.data.utils import md5 from axolotl.utils.dict import DictDefault @@ -55,6 +56,22 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset): dataset.save_to_disk(str(prepared_ds_path)) +def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): + sig = inspect.signature(ds_transform_fn) + if "tokenizer" in sig.parameters: + if not tokenizer: + tokenizer = load_tokenizer(cfg) + ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) + + data_set = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) + if isinstance(data_set, DatasetDict): + data_set = data_set["train"] + return data_set + + def load_prepare_dpo_datasets(cfg): def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] @@ -76,6 +93,7 @@ def load_split(dataset_cfgs, _cfg): split_datasets.insert(i, ds) tokenizer = None + for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] if _type: @@ -83,21 +101,19 @@ def load_split(dataset_cfgs, _cfg): _type = "user_defined.default" if _cfg.rl == "orpo": ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) + elif 
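The `check` validator above migrates the deprecated `dpo_beta` into `rl_beta`. Its effect, sketched on a plain dict rather than the pydantic model:

```python
cfg = {"dpo_beta": 0.2, "rl_beta": None}  # hypothetical parsed config

if cfg.get("dpo_beta") is not None and not cfg.get("rl_beta"):
    cfg["rl_beta"] = cfg.pop("dpo_beta")

assert cfg == {"rl_beta": 0.2}
```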
_cfg.rl == "kto": + ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) else: ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) - sig = inspect.signature(ds_transform_fn) - if "tokenizer" in sig.parameters: - if not tokenizer: - tokenizer = load_tokenizer(_cfg) - ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) - - data_set = data_set.map( - ds_transform_fn, - desc="Mapping RL Dataset", + + split_datasets[i] = map_dataset( + cfg, data_set, ds_transform_fn, tokenizer + ) + elif _cfg.rl == "kto": + ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) + split_datasets[i] = map_dataset( + cfg, data_set, ds_transform_fn, tokenizer ) - if isinstance(data_set, DatasetDict): - data_set = data_set["train"] - split_datasets[i] = data_set else: # If no `type` is provided, assume the dataset is already in the expected format with # "prompt", "chosen" and "rejected" already preprocessed diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8537b7e754..a8df4bbad7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,4 +1,5 @@ """Module for models and model loading""" + # pylint: disable=too-many-lines import logging @@ -389,6 +390,16 @@ def load_model( "Shifted-sparse attention not currently implemented without flash attention." ) + if cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch + + integrate_cross_entropy_loss_patch() + + if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() + # Modify mistral derived models if ( cfg.model_config_type == "mistral" @@ -504,6 +515,9 @@ def load_model( bnb_config = { "load_in_8bit": True, } + # Exclude mamba blocks from int8 quantization for jamba + if cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) @@ -789,7 +803,11 @@ def load_model( if not reference_model or cfg.lora_model_dir: # if we're not loading the reference model, then we're loading the model for training # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: + if ( + cfg.adapter + and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"] + and not cfg.merge_lora + ): _, lora_config = load_lora(model, cfg, inference=False, config_only=True) else: model, lora_config = load_adapter(model, cfg, cfg.adapter) @@ -824,6 +842,15 @@ def load_model( if cfg.adapter is not None: log_gpu_memory_usage(LOG, "after adapters", model.device) + if cfg.unsloth_lora_mlp: + from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + + integrate_lora_mlp_patch(model) + if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + + integrate_lora_patch(model, cfg) + # TODO resume_from_checkpoint handling return model, lora_config diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index cf47d9639b..07fd056826 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -1,105 +1,64 @@ -# pylint: skip-file """ Multipack Batch Sampler """ import logging -import math -import os -from typing import Any, Iterable, List, Union +from concurrent.futures import ProcessPoolExecutor +from multiprocessing import cpu_count import numba import numpy as np -from 
torch.utils.data import BatchSampler, Sampler +from torch.utils.data import BatchSampler LOG = logging.getLogger("axolotl.utils.samplers.multipack") +# First-fit-decreasing bin packing. @numba.njit -def ffd_check(a: np.ndarray, c: int, n: int): - # First-fit-decreasing bin packing - # Check if a[] could fit in n bins with capacity c - # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - - a = np.sort(a)[::-1] - bins = np.full((n,), c, dtype=a.dtype) - for size in a: - not_found = True - for idx in range(n): - if bins[idx] >= size: - bins[idx] -= size - not_found = False +def pack_group(items, group_offset, bin_capacity, max_items_per_bin): + idxs = np.argsort(items)[::-1] + sorted_items = items[idxs] + num_bins = len(items) + bins = np.full(num_bins, bin_capacity, dtype=np.int32) + bin_counts = np.zeros(num_bins, dtype=np.int32) + group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32) + + for idx, item in enumerate(sorted_items): + global_idx = idxs[idx] + group_offset + + placed = False + for i in range(num_bins): + if bins[i] >= item and bin_counts[i] < max_items_per_bin: + bins[i] -= item + group_packing[i, bin_counts[i]] = global_idx + bin_counts[i] += 1 + placed = True break - if not_found: - return False + if not placed: + raise ValueError( + f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})." + ) - return True + return group_packing -@numba.njit -def ffd_with_result(a: np.ndarray, c: int, start_index: int): - # First-fit-decreasing bin packing (with result return) - - indices = np.argsort(a)[::-1] - a = a[indices] - - bins: List[Any] = [] - bins_result: List[Any] = [] - for a_id, size in enumerate(a): - add_new = True - for idx in range(len(bins)): - if bins[idx] >= size: - bins[idx] -= size - bins_result[idx].append(indices[a_id] + start_index) - add_new = False - break - - if add_new: - bins.append(c - size) - bins_result.append([indices[a_id] + start_index]) - - return bins_result - - -@numba.njit -def allocate( - lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int -): - # Dynamic batch allocator, similar to Multifit - # https://en.wikipedia.org/wiki/Multifit_algorithm - # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) - - s = 0 - start_index = 0 - result = [] - - while True: - # binary search [l, r) - left = 1 - right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") - - while right - left > 1: - mid = (left + right) // 2 - if ffd_check(lengths[start_index : start_index + mid], c, n): - left = mid - else: - right = mid - - # use length l - batch = ffd_with_result( - lengths[start_index : start_index + left], c, start_index - ) - assert len(batch) <= n - if len(batch) < n: - break - - start_index += left - s = lengths_cumsum[start_index - 1] +def pack(items, bin_capacity, group_size, max_items_per_bin): + num_items = len(items) + num_processes = max(1, min(num_items // group_size, cpu_count())) + tasks = [ + (items[i : i + group_size], i, bin_capacity, max_items_per_bin) + for i in range(0, num_items, group_size) + ] - # add local rank - result.append(batch[rank]) + packed_bins = [] + with ProcessPoolExecutor(max_workers=num_processes) as executor: + for group_packing in executor.map(pack_group, *zip(*tasks)): + for bin_pack in group_packing: + filtered_pack = bin_pack[bin_pack != -1] + if filtered_pack.size > 0: + packed_bins.append(filtered_pack.tolist()) - return result, s, len(result) * c * n + return packed_bins class 
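For intuition, a pure-Python reference of the first-fit-decreasing packing that `pack_group` above implements with numba. Unlike `pack_group`, which preallocates `num_bins` slots and raises when an item cannot be placed, this sketch opens bins on demand and never raises:

```python
import numpy as np

def ffd(items, bin_capacity, max_items_per_bin):
    order = np.argsort(items)[::-1]  # largest items first
    rooms, contents = [], []         # per-bin remaining capacity / packed indices
    for idx in order:
        size = int(items[idx])
        for b, room in enumerate(rooms):
            if room >= size and len(contents[b]) < max_items_per_bin:
                rooms[b] -= size
                contents[b].append(int(idx))
                break
        else:  # no existing bin fits: open a new one
            rooms.append(bin_capacity - size)
            contents.append([int(idx)])
    return contents

print(ffd(np.array([7, 5, 4, 3, 1]), bin_capacity=8, max_items_per_bin=3))
# -> [[0, 4], [1, 3], [2]]  (bins holding 7+1, 5+3, and 4 tokens)
```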
MultipackBatchSampler(BatchSampler): @@ -109,94 +68,63 @@ class MultipackBatchSampler(BatchSampler): def __init__( self, - sampler: Union[Sampler[int], Iterable[int]], - batch_size: int, - drop_last: bool, - batch_max_len: int, - lengths: np.ndarray, - packing_efficiency_estimate: float = 1.0, + sampler, + lengths, + batch_max_len, + batch_size, + group_size=100_000, + bin_size=200, + drop_last=False, ): - super().__init__(sampler, batch_size, drop_last) - self.batch_size = batch_size + self.sampler = sampler + self.lengths = np.array(lengths, dtype=np.int32) self.batch_max_len = batch_max_len - self.lengths: np.ndarray = lengths - self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 - - assert isinstance(self.lengths, np.ndarray) - - self.epoch = 0 - - # statistics - self.eff_total_used = 0 - self.eff_total_slots = 0 - - def set_epoch(self, epoch: int): - self.epoch = epoch - - def generate_batches(self, set_stats=False): - indices = [idx for idx in self.sampler] + self.batch_size = batch_size + self.group_size = group_size + self.bin_size = bin_size + self.drop_last = drop_last - lengths = self.lengths[indices] - lengths_cumsum = np.cumsum(lengths) + self._efficiency = None + self._batches = None - batches, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=0, - c=self.batch_max_len, - n=1, + def efficiency(self): + if self._efficiency is None: + self._batches = self._pack_batches() + return self._efficiency + + def _pack_batches(self): + # Get possibly shuffled indices from sampler. + sample_idxs = np.arange(len(self.sampler)) + lengths = self.lengths[sample_idxs] + + pack_idxs = pack( + lengths, + self.batch_max_len, + self.group_size, + self.bin_size, ) - batches = [ - [ - [indices[b_idx] for b_idx in batch] - for batch in batches[i : i + self.batch_size] - ] - for i in range(0, len(batches), self.batch_size) + used_tokens = self.lengths.sum() + available_tokens = len(pack_idxs) * self.batch_max_len + self._efficiency = used_tokens / available_tokens + + # Wrap packs into batches. + batch_idxs = [ + pack_idxs[i : i + self.batch_size] + for i in range(0, len(pack_idxs), self.batch_size) ] - # statistics - if set_stats: - self.eff_total_used += total_used - self.eff_total_slots += total_slots + # Drop last batch if needed. 
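The efficiency figure computed in `_pack_batches` above is simply used tokens divided by provisioned token slots. With made-up numbers:

```python
total_tokens = 9_500    # sum of all sample lengths (hypothetical)
num_packs = 5           # bins produced by pack()
batch_max_len = 2_048   # token capacity per pack

efficiency = total_tokens / (num_packs * batch_max_len)
print(f"{efficiency:.1%}")  # -> 92.8%
```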
+ if self.drop_last and len(batch_idxs[-1]) < self.batch_size: + batch_idxs = batch_idxs[:-1] - return batches + return batch_idxs def __iter__(self): - batches = self.generate_batches(set_stats=True) - return iter(batches) - - def num_batches(self): - batches = self.generate_batches(set_stats=True) - return len(batches) - - def efficiency(self): - return self.eff_total_used / self.eff_total_slots + self._batches = self._pack_batches() + return iter(self._batches) def __len__(self): - self.num_batches() - return self._len_est() - - def _len_est(self): - world_size = int(os.getenv("WORLD_SIZE", "1")) - lengths_sum = np.sum(self.lengths) - lengths_sum_per_device = lengths_sum // world_size - LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " - f"total_num_tokens per device: {lengths_sum_per_device}" - ) - - # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler - return max( - 0, - ( - world_size - * math.floor( - 0.99 - * lengths_sum_per_device - / self.packing_efficiency_estimate - // (self.batch_max_len * self.batch_size) - ) - - 1 - ), - ) + if self._batches is None: + self._batches = self._pack_batches() + return len(self._batches) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2e3728cc8a..6760dc4882 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -330,7 +330,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): / cfg.sample_packing_eff_est / cfg.sequence_len // cfg.batch_size - // int(os.environ.get("WORLD_SIZE", 1)) ) - 1 ) @@ -342,42 +341,37 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): ) else: if cfg.flash_attention: - batch_size = 1 + sampler_batch_size = 1 batch_max_len = cfg.micro_batch_size * cfg.sequence_len else: - batch_size = cfg.micro_batch_size + sampler_batch_size = cfg.micro_batch_size batch_max_len = cfg.sequence_len sampler = MultipackBatchSampler( sampler=RandomSampler(train_dataset), - batch_size=batch_size, - drop_last=True, - batch_max_len=batch_max_len, lengths=get_dataset_lengths(train_dataset), + batch_size=sampler_batch_size, + batch_max_len=batch_max_len, + group_size=cfg.sample_packing_group_size, + bin_size=cfg.sample_packing_bin_size, + drop_last=True, ) data_loader = DataLoader( train_dataset.remove_columns(["length"]), batch_sampler=sampler, ) - data_loader_len = len(data_loader) // cfg.batch_size - actual_eff = sampler.efficiency() + data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) # FIXME: is there a bug here somewhere? 
the total num steps depends # on the agreed on value for sample_packing_eff_est - total_num_steps = int( - math.floor( - data_loader_len - * cfg.num_epochs - / int(os.environ.get("WORLD_SIZE", 1)) - ) - ) + total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) def calc_sample_packing_eff_est(estimates: List[float]): LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") return max(estimates) sample_packing_actual_eff_all = reduce_and_broadcast( - lambda: actual_eff, + lambda: sampler.efficiency(), # pylint: disable=unnecessary-lambda calc_sample_packing_eff_est, ) sample_packing_eff_est = ( @@ -391,12 +385,7 @@ def calc_sample_packing_eff_est(estimates: List[float]): ) else: total_num_steps = int( - math.ceil( - len(train_dataset) - * cfg.num_epochs - / int(os.environ.get("WORLD_SIZE", 1)) - / cfg.batch_size - ) + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) return total_num_steps @@ -438,7 +427,7 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]: + if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 9596b1873f..ddd63d8271 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -205,3 +205,66 @@ def test_orpo_lora(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() + + @with_temp_dir + def test_kto_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "kto", + "rl_beta": 0.5, + "kto_desirable_weight": 1.0, + "kto_undesirable_weight": 1.0, + "remove_unused_columns": False, + "datasets": [ + # { + # "path": "argilla/kto-mix-15k", + # "type": "chatml.argilla_chat", + # "split": "train", + # }, + { + "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto", + "type": "chatml.ultra", + "split": "train", + }, + # { + # "path": "argilla/kto-mix-15k", + # "type": "llama3.argilla_chat", + # "split": "train", + # }, + { + "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto", + "type": "llama3.ultra", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py new file mode 100644 index 0000000000..1076c6a3bf --- /dev/null +++ 
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 9596b1873f..ddd63d8271 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -205,3 +205,66 @@ def test_orpo_lora(self, temp_dir):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_kto_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "kto",
+                "rl_beta": 0.5,
+                "kto_desirable_weight": 1.0,
+                "kto_undesirable_weight": 1.0,
+                "remove_unused_columns": False,
+                "datasets": [
+                    # {
+                    #     "path": "argilla/kto-mix-15k",
+                    #     "type": "chatml.argilla_chat",
+                    #     "split": "train",
+                    # },
+                    {
+                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
+                        "type": "chatml.ultra",
+                        "split": "train",
+                    },
+                    # {
+                    #     "path": "argilla/kto-mix-15k",
+                    #     "type": "llama3.argilla_chat",
+                    #     "split": "train",
+                    # },
+                    {
+                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
+                        "type": "llama3.ultra",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py
new file mode 100644
index 0000000000..1076c6a3bf
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_templates.py
@@ -0,0 +1,85 @@
+"""
+tests for chat_template prompt strategy
+"""
+import unittest
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from axolotl.prompt_strategies.chat_template import (
+    ChatTemplatePrompter,
+    ChatTemplateStrategy,
+)
+from axolotl.utils.chat_templates import chat_templates
+
+
+@pytest.fixture(name="sharegpt_dataset")
+def fixture_sharegpt_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "conversations": [
+                    {
+                        "from": "human",
+                        "value": "hello",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "hello",
+                    },
+                    {
+                        "from": "human",
+                        "value": "goodbye",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "goodbye",
+                    },
+                ]
+            }
+        ]
+    )
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
+class TestSharegptChatTemplateLlama3:
+    """
+    Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
+    """
+
+    def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
+        # pylint: disable=duplicate-code
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            llama3_tokenizer,
+            False,
+            512,
+        )
+        res = strategy.tokenize_prompt(sharegpt_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
+if __name__ == "__main__":
+    unittest.main()
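A quick way to sanity-check the expected token ids in the new test is to decode them with the same tokenizer. The sketch below assumes Hub access and that the special ids carry their usual llama-3 meanings (the ids themselves are copied from the assertion above):

```python
from transformers import AutoTokenizer

# Same tokenizer the test uses; assumes the model is reachable on the HF Hub.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

# First user/assistant exchange of the expected sequence, from the test.
ids = [
    128000,               # presumably <|begin_of_text|>
    128006, 882, 128007,  # presumably <|start_header_id|>user<|end_header_id|>
    271, 15339, 128009,   # presumably "\n\nhello" followed by <|eot_id|>
]
# Should print the llama-3 turn structure the chat template is expected to emit.
print(tokenizer.decode(ids))
```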
diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py
index 3ff0eab053..6e69098340 100644
--- a/tests/prompt_strategies/test_sharegpt.py
+++ b/tests/prompt_strategies/test_sharegpt.py
@@ -12,10 +12,12 @@
     GlaiveShareGPTPromptTokenizingStrategy,
     SimpleShareGPTPromptTokenizingStrategy,
     register_chatml_template,
+    register_llama3_template,
 )
 from axolotl.prompters import ShareGPTPrompterV2

 register_chatml_template()
+register_llama3_template()


 @pytest.fixture(name="sharegpt_dataset")
@@ -115,7 +117,53 @@ def fixture_tokenizer():
     return tokenizer


-class TestSharegpt:
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
+class TestSharegptLlama3:
+    """Test class for ShareGPT style datasets with llama-3 prompts"""
+
+    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="llama3",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            llama3_tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 9125, 128007,  # system header
+            271, 31724, 128009,  # sys prompt, eot
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
+class TestSharegptChatML:
     """
     Test class for sharegpt prompter
     """
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index 50f39d60f5..ceff11df94 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -62,12 +62,14 @@ def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
             dataset,
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
+        lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(
             sampler=RandomSampler(train_dataset),
+            lengths=lengths,
             batch_size=batch_size,
-            drop_last=True,
             batch_max_len=max_seq_length,
-            lengths=get_dataset_lengths(train_dataset),
+            group_size=100000,
+            bin_size=200,
         )

         loader = DataLoader(
@@ -81,19 +83,15 @@ def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
             ),
             num_workers=num_workers,
         )
-        inputs = next(iter(loader))

-        assert inputs["input_ids"].shape == (batch_size, max_seq_length)
-        assert inputs["labels"].shape == (batch_size, max_seq_length)
-        assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
+        batch_idxs = []
+        for batch in batch_sampler:
+            for pack in batch:
+                batch_idxs.extend(pack)

-        assert inputs["input_ids"].tolist()[0][0] == 2
-        assert inputs["labels"].tolist()[0][0] == -100
-        assert inputs["attention_mask"].tolist()[0][0] == 0
-        assert inputs["attention_mask"].tolist()[0][-1] > 1
+        for batch in loader:
+            assert len(batch["input_ids"]) <= batch_size * max_seq_length
+            assert batch["input_ids"].shape[1] == max_seq_length

-        if batch_size >= 2:
-            assert inputs["input_ids"].tolist()[1][0] == 2
-            assert inputs["labels"].tolist()[1][0] == -100
-            assert inputs["attention_mask"].tolist()[1][0] == 0
-            assert inputs["attention_mask"].tolist()[1][-1] > 1
+        original_idxs = set(range(len(train_dataset)))
+        assert original_idxs == set(batch_idxs)
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index 528f9c8074..fb623a43dc 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -42,6 +42,8 @@ def test_packing_stream_dataset(self):
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "micro_batch_size": 2,
+                "sample_packing_group_size": 100000,
+                "sample_packing_bin_size": 200,
             }
         )
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 27824f2887..35d0e265e7 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -1117,6 +1117,15 @@ def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
             validate_config(cfg)
             assert len(self._caplog.records) == 0

+    def test_dpo_beta_deprecation(self, minimal_cfg):
+        cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
+
+        with self._caplog.at_level(logging.WARNING):
+            new_cfg = validate_config(cfg)
+            assert new_cfg["rl_beta"] == 0.2
+            assert new_cfg["dpo_beta"] is None
+            assert len(self._caplog.records) == 1
+

 class TestValidationCheckModelConfig(BaseValidation):
     """
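The new `test_dpo_beta_deprecation` pins down the migration behavior: a config carrying the old `dpo_beta` key is remapped to `rl_beta`, the old key is nulled out, and exactly one warning is logged. A minimal sketch of such a shim, using plain dicts in place of axolotl's config objects (the real `validate_config` is built on pydantic models and may implement this differently):

```python
import logging

LOG = logging.getLogger(__name__)


def remap_deprecated_keys(cfg: dict) -> dict:
    """Illustrative dpo_beta -> rl_beta migration matching the test's assertions."""
    if cfg.get("dpo_beta") is not None:
        LOG.warning("dpo_beta is deprecated, setting rl_beta instead")
        if cfg.get("rl_beta") is None:
            cfg["rl_beta"] = cfg["dpo_beta"]
        cfg["dpo_beta"] = None
    return cfg


if __name__ == "__main__":
    cfg = remap_deprecated_keys({"dpo_beta": 0.2})
    assert cfg["rl_beta"] == 0.2 and cfg["dpo_beta"] is None
```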