diff --git a/.circleci/config.yml b/.circleci/config.yml
index 75413af8bf5254..9c414901c4f5ac 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -58,14 +58,14 @@ jobs:
name: "Prepare pipeline parameters"
command: |
python utils/process_test_artifacts.py
-
+
# To avoid too long generated_config.yaml on the continuation orb, we pass the links to the artifacts as parameters.
# Otherwise the list of tests was just too big. Explicit is good but for that it was a limitation.
# We used:
# https://circleci.com/docs/api/v2/index.html#operation/getJobArtifacts : to get the job artifacts
# We could not pass a nested dict, which is why we create the test_file_... parameters for every single job
-
+
- store_artifacts:
path: test_preparation/transformed_artifacts.json
- store_artifacts:
diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py
index 71c75dac2ff053..84c0f65166baef 100644
--- a/.circleci/create_circleci_config.py
+++ b/.circleci/create_circleci_config.py
@@ -32,7 +32,7 @@
"RUN_PT_FLAX_CROSS_TESTS": False,
}
# Disable the use of {"s": None} as the output is way too long, causing the navigation on CircleCI impractical
-COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "dist": "loadfile", "vvv": None, "rsfE":None}
+COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "dist": "loadfile", "vvv": None, "rsf":None}
DEFAULT_DOCKER_IMAGE = [{"image": "cimg/python:3.8.12"}]
@@ -40,23 +40,9 @@ class EmptyJob:
job_name = "empty"
def to_dict(self):
- steps = [{"run": 'ls -la'}]
- if self.job_name == "collection_job":
- steps.extend(
- [
- "checkout",
- {"run": "pip install requests || true"},
- {"run": """while [[ $(curl --location --request GET "https://circleci.com/api/v2/workflow/$CIRCLE_WORKFLOW_ID/job" --header "Circle-Token: $CCI_TOKEN"| jq -r '.items[]|select(.name != "collection_job")|.status' | grep -c "running") -gt 0 ]]; do sleep 5; done || true"""},
- {"run": 'python utils/process_circleci_workflow_test_reports.py --workflow_id $CIRCLE_WORKFLOW_ID || true'},
- {"store_artifacts": {"path": "outputs"}},
- {"run": 'echo "All required jobs have now completed"'},
- ]
- )
-
return {
"docker": copy.deepcopy(DEFAULT_DOCKER_IMAGE),
- "resource_class": "small",
- "steps": steps,
+ "steps":["checkout"],
}
@@ -68,9 +54,9 @@ class CircleCIJob:
install_steps: List[str] = None
marker: Optional[str] = None
parallelism: Optional[int] = 0
- pytest_num_workers: int = 8
+ pytest_num_workers: int = 12
pytest_options: Dict[str, Any] = None
- resource_class: Optional[str] = "xlarge"
+ resource_class: Optional[str] = "2xlarge"
tests_to_run: Optional[List[str]] = None
num_test_files_per_worker: Optional[int] = 10
# This should be only used for doctest job!
@@ -199,6 +185,7 @@ def job_name(self):
docker_image=[{"image": "huggingface/transformers-torch-light"}],
marker="not generate",
parallelism=6,
+ pytest_num_workers=8
)
generate_job = CircleCIJob(
@@ -206,24 +193,28 @@ def job_name(self):
docker_image=[{"image": "huggingface/transformers-torch-light"}],
marker="generate",
parallelism=6,
+ pytest_num_workers=8
)
tokenization_job = CircleCIJob(
"tokenization",
docker_image=[{"image": "huggingface/transformers-torch-light"}],
parallelism=8,
+ pytest_num_workers=16
)
processor_job = CircleCIJob(
"processors",
docker_image=[{"image": "huggingface/transformers-torch-light"}],
parallelism=8,
+ pytest_num_workers=6
)
tf_job = CircleCIJob(
"tf",
docker_image=[{"image":"huggingface/transformers-tf-light"}],
parallelism=6,
+ pytest_num_workers=16,
)
@@ -231,8 +222,7 @@ def job_name(self):
"flax",
docker_image=[{"image":"huggingface/transformers-jax-light"}],
parallelism=6,
- pytest_num_workers=16,
- resource_class="2xlarge",
+ pytest_num_workers=16
)
@@ -241,7 +231,7 @@ def job_name(self):
additional_env={"RUN_PIPELINE_TESTS": True},
docker_image=[{"image":"huggingface/transformers-torch-light"}],
marker="is_pipeline_test",
- parallelism=4,
+ parallelism=4
)
@@ -250,7 +240,7 @@ def job_name(self):
additional_env={"RUN_PIPELINE_TESTS": True},
docker_image=[{"image":"huggingface/transformers-tf-light"}],
marker="is_pipeline_test",
- parallelism=4,
+ parallelism=4
)
@@ -267,6 +257,7 @@ def job_name(self):
docker_image=[{"image":"huggingface/transformers-examples-torch"}],
# TODO @ArthurZucker remove this once docker is easier to build
install_steps=["uv venv && uv pip install . && uv pip install -r examples/pytorch/_tests_requirements.txt"],
+ pytest_num_workers=8,
)
@@ -274,6 +265,7 @@ def job_name(self):
"examples_tensorflow",
additional_env={"OMP_NUM_THREADS": 8},
docker_image=[{"image":"huggingface/transformers-examples-tf"}],
+ pytest_num_workers=16,
)
@@ -288,7 +280,6 @@ def job_name(self):
],
marker="is_staging_test",
pytest_num_workers=2,
- resource_class="medium",
)
@@ -301,13 +292,13 @@ def job_name(self):
],
pytest_options={"k onnx": None},
pytest_num_workers=1,
- resource_class="small",
)
exotic_models_job = CircleCIJob(
"exotic_models",
docker_image=[{"image":"huggingface/transformers-exotic-models"}],
+ pytest_num_workers=12,
parallelism=4,
pytest_options={"durations": 100},
)
@@ -326,6 +317,7 @@ def job_name(self):
docker_image=[{"image": "huggingface/transformers-torch-light"}],
marker="not generate",
parallelism=6,
+ pytest_num_workers=8,
)
@@ -360,7 +352,6 @@ def job_name(self):
DOC_TESTS = [doc_test_job]
ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] # fmt: skip
-
def create_circleci_config(folder=None):
if folder is None:
folder = os.getcwd()
@@ -370,13 +361,7 @@ def create_circleci_config(folder=None):
if len(jobs) == 0:
jobs = [EmptyJob()]
- else:
- print("Full list of job name inputs", {j.job_name + "_test_list":{"type":"string", "default":''} for j in jobs})
- # Add a job waiting all the test jobs and aggregate their test summary files at the end
- collection_job = EmptyJob()
- collection_job.job_name = "collection_job"
- jobs = [collection_job] + jobs
-
+ print("Full list of job name inputs", {j.job_name + "_test_list":{"type":"string", "default":''} for j in jobs})
config = {
"version": "2.1",
"parameters": {
@@ -386,7 +371,7 @@ def create_circleci_config(folder=None):
**{j.job_name + "_test_list":{"type":"string", "default":''} for j in jobs},
**{j.job_name + "_parallelism":{"type":"integer", "default":1} for j in jobs},
},
- "jobs": {j.job_name: j.to_dict() for j in jobs}
+ "jobs" : {j.job_name: j.to_dict() for j in jobs}
}
if "CIRCLE_TOKEN" in os.environ:
# For private forked repo. (e.g. new model addition)
diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 1bbd1c1e94d08c..eaa4b3b2f82456 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -63,7 +63,7 @@ jobs:
commit_id=$GITHUB_SHA
fi
commit_msg=$(git show -s --format=%s | cut -c1-70)
- python3 benchmark/benchmarks_entrypoint.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
+ python3 benchmark/llama.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
env:
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
# Enable this to see debug logs
diff --git a/.github/workflows/push-important-models.yml b/.github/workflows/push-important-models.yml
index 7294777655e183..1887af0f4c5bac 100644
--- a/.github/workflows/push-important-models.yml
+++ b/.github/workflows/push-important-models.yml
@@ -134,3 +134,10 @@ jobs:
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
waitForSSH: true
+
+ benchmark:
+ name: Benchmark workflow
+ needs: get_modified_models
+ if: ${{ needs.get_modified_models.outputs.matrix != '[]' && needs.get_modified_models.outputs.matrix != '' && fromJson(needs.get_modified_models.outputs.matrix)[0] != null }}
+ uses: ./.github/workflows/benchmark.yml
+ secrets: inherit
diff --git a/.github/workflows/self-comment-ci.yml b/.github/workflows/self-comment-ci.yml
deleted file mode 100644
index b344ecfd59527d..00000000000000
--- a/.github/workflows/self-comment-ci.yml
+++ /dev/null
@@ -1,253 +0,0 @@
-name: PR comment GitHub CI
-
-on:
- issue_comment:
- types:
- - created
- branches-ignore:
- - main
-concurrency:
- group: ${{ github.workflow }}-${{ github.event.issue.number }}-${{ startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow') }}
- cancel-in-progress: true
-
-jobs:
- get-pr-number:
- runs-on: ubuntu-22.04
- name: Get PR number
- # For security: only allow team members to run
- if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
- outputs:
- PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
- steps:
- - name: Get PR number
- shell: bash
- run: |
- if [[ "${{ github.event.issue.number }}" != "" && "${{ github.event.issue.pull_request }}" != "" ]]; then
- echo "PR_NUMBER=${{ github.event.issue.number }}" >> $GITHUB_ENV
- else
- echo "PR_NUMBER=" >> $GITHUB_ENV
- fi
-
- - name: Check PR number
- shell: bash
- run: |
- echo "${{ env.PR_NUMBER }}"
-
- - name: Set PR number
- id: set_pr_number
- run: echo "PR_NUMBER=${{ env.PR_NUMBER }}" >> "$GITHUB_OUTPUT"
-
- get-sha:
- runs-on: ubuntu-22.04
- needs: get-pr-number
- if: ${{ needs.get-pr-number.outputs.PR_NUMBER != ''}}
- outputs:
- PR_HEAD_SHA: ${{ steps.get_sha.outputs.PR_HEAD_SHA }}
- steps:
- - uses: actions/checkout@v4
- with:
- fetch-depth: "0"
- ref: "refs/pull/${{needs.get-pr-number.outputs.PR_NUMBER}}/merge"
-
- - name: Get SHA
- id: get_sha
- env:
- PR_NUMBER: ${{needs.get-pr-number.outputs.PR_NUMBER}}
- run: |
- git fetch origin refs/pull/$PR_NUMBER/head:refs/remotes/pull/$PR_NUMBER/head
- git checkout refs/remotes/pull/$PR_NUMBER/head
- echo "PR_HEAD_SHA: $(git log -1 --format=%H)"
- echo "PR_HEAD_SHA=$(git log -1 --format=%H)" >> "$GITHUB_OUTPUT"
-
- # use a python script to handle this complex logic
- # case 1: `run-slow` (auto. infer with limited number of models, but in particular, new model)
- # case 2: `run-slow model_1, model_2`
- get-tests:
- runs-on: ubuntu-22.04
- needs: get-pr-number
- if: ${{ needs.get-pr-number.outputs.PR_NUMBER != ''}}
- permissions: write-all
- outputs:
- models: ${{ steps.models_to_run.outputs.models }}
- steps:
- - uses: actions/checkout@v4
- with:
- fetch-depth: "0"
- ref: "refs/pull/${{needs.get-pr-number.outputs.PR_NUMBER}}/merge"
-
- - name: Get models to test
- env:
- PR_COMMENT: ${{ github.event.comment.body }}
- run: |
- python -m pip install GitPython
- python utils/pr_slow_ci_models.py --message "$PR_COMMENT" | tee output.txt
- echo "models=$(tail -n 1 output.txt)" >> $GITHUB_ENV
-
- - name: Show models to test
- id: models_to_run
- run: |
- echo "${{ env.models }}"
- echo "models=${{ env.models }}" >> $GITHUB_ENV
- echo "models=${{ env.models }}" >> $GITHUB_OUTPUT
-
- - name: Reply to the comment
- if: ${{ env.models != '[]' }}
- env:
- GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: |
- gh api \
- --method POST \
- -H "Accept: application/vnd.github+json" \
- -H "X-GitHub-Api-Version: 2022-11-28" \
- repos/${{ github.repository }}/issues/${{ needs.get-pr-number.outputs.PR_NUMBER }}/comments \
- -f "body=This comment contains run-slow, running the specified jobs: ${{ env.models }} ..."
-
- create_run:
- name: Create run
- if: ${{ needs.get-tests.outputs.models != '[]' }}
- needs: [get-sha, get-tests]
- permissions: write-all
- runs-on: ubuntu-22.04
- steps:
- - name: Create Run
- id: create_run
- env:
- GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- # Create a commit status (pending) for a run of this workflow. The status has to be updated later in `update_run_status`.
- # See https://docs.github.com/en/rest/commits/statuses?apiVersion=2022-11-28#create-a-commit-status
- GITHUB_RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
- run: |
- gh api \
- --method POST \
- -H "Accept: application/vnd.github+json" \
- -H "X-GitHub-Api-Version: 2022-11-28" \
- repos/${{ github.repository }}/statuses/${{ needs.get-sha.outputs.PR_HEAD_SHA }} \
- -f "target_url=$GITHUB_RUN_URL" -f "state=pending" -f "description=Slow CI job" -f "context=pytest/custom-tests"
-
- run_models_gpu:
- name: Run all tests for the model
- if: ${{ needs.get-tests.outputs.models != '[]' }}
- needs: [get-pr-number, get-tests, create_run]
- strategy:
- fail-fast: false
- matrix:
- folders: ${{ fromJson(needs.get-tests.outputs.models) }}
- machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
- runs-on:
- group: '${{ matrix.machine_type }}'
- container:
- image: huggingface/transformers-all-latest-gpu
- options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
- steps:
- - name: Echo input and matrix info
- shell: bash
- run: |
- echo "${{ matrix.folders }}"
-
- - name: Echo folder ${{ matrix.folders }}
- shell: bash
- # For folders like `models/bert`, set an env. var. (`matrix_folders`) to `models_bert`, which will be used to
- # set the artifact folder names (because the character `/` is not allowed).
- run: |
- echo "${{ matrix.folders }}"
- matrix_folders=${{ matrix.folders }}
- matrix_folders=${matrix_folders/'models/'/'models_'}
- echo "$matrix_folders"
- echo "matrix_folders=$matrix_folders" >> $GITHUB_ENV
-
- - name: Checkout to PR merge commit
- working-directory: /transformers
- run: |
- git fetch origin refs/pull/${{ needs.get-pr-number.outputs.PR_NUMBER }}/merge:refs/remotes/pull/${{ needs.get-pr-number.outputs.PR_NUMBER }}/merge
- git checkout refs/remotes/pull/${{ needs.get-pr-number.outputs.PR_NUMBER }}/merge
- git log -1 --format=%H
-
- - name: Reinstall transformers in edit mode (remove the one installed during docker image build)
- working-directory: /transformers
- run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
-
- - name: NVIDIA-SMI
- run: |
- nvidia-smi
-
- - name: Set `machine_type` for report and artifact names
- working-directory: /transformers
- shell: bash
- run: |
- echo "${{ matrix.machine_type }}"
- if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
- machine_type=single-gpu
- elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
- machine_type=multi-gpu
- else
- machine_type=${{ matrix.machine_type }}
- fi
- echo "$machine_type"
- echo "machine_type=$machine_type" >> $GITHUB_ENV
-
- - name: Environment
- working-directory: /transformers
- run: |
- python3 utils/print_env.py
-
- - name: Show installed libraries and their versions
- working-directory: /transformers
- run: pip freeze
-
- - name: Run all tests on GPU
- working-directory: /transformers
- run: |
- export CUDA_VISIBLE_DEVICES="$(python3 utils/set_cuda_devices_for_ci.py --test_folder ${{ matrix.folders }})"
- echo $CUDA_VISIBLE_DEVICES
- python3 -m pytest -v -rsfE --make-reports=${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }}
-
- - name: Failure short reports
- if: ${{ failure() }}
- continue-on-error: true
- run: cat /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports/failures_short.txt
-
- - name: Make sure report directory exists
- shell: bash
- run: |
- mkdir -p /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports
- echo "hello" > /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports/hello.txt
- echo "${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports"
-
- - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_models_gpu_${{ env.matrix_folders }}_test_reports"
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: ${{ env.machine_type }}_run_models_gpu_${{ env.matrix_folders }}_test_reports
- path: /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports
-
- update_run_status:
- name: Update Check Run Status
- needs: [get-sha, create_run, run_models_gpu]
- permissions: write-all
- if: ${{ always() && needs.create_run.result == 'success' }}
- runs-on: ubuntu-22.04
- env:
- GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- GITHUB_RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
- steps:
- - name: Get `run_models_gpu` job status
- run: |
- echo "${{ needs.run_models_gpu.result }}"
- if [ "${{ needs.run_models_gpu.result }}" = "cancelled" ]; then
- echo "STATUS=failure" >> $GITHUB_ENV
- elif [ "${{ needs.run_models_gpu.result }}" = "skipped" ]; then
- echo "STATUS=success" >> $GITHUB_ENV
- else
- echo "STATUS=${{ needs.run_models_gpu.result }}" >> $GITHUB_ENV
- fi
-
- - name: Update PR commit statuses
- run: |
- echo "${{ needs.run_models_gpu.result }}"
- echo "${{ env.STATUS }}"
- gh api \
- --method POST \
- -H "Accept: application/vnd.github+json" \
- -H "X-GitHub-Api-Version: 2022-11-28" \
- repos/${{ github.repository }}/statuses/${{ needs.get-sha.outputs.PR_HEAD_SHA }} \
- -f "target_url=$GITHUB_RUN_URL" -f "state=${{ env.STATUS }}" -f "description=Slow CI job" -f "context=pytest/custom-tests"
diff --git a/.github/workflows/self-nightly-past-ci-caller.yml b/.github/workflows/self-nightly-past-ci-caller.yml
index 46d811d4a43394..142399a6366ce6 100644
--- a/.github/workflows/self-nightly-past-ci-caller.yml
+++ b/.github/workflows/self-nightly-past-ci-caller.yml
@@ -21,6 +21,39 @@ jobs:
echo "$(python3 -c 'print(int(${{ github.run_number }}) % 10)')"
echo "run_number=$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" >> $GITHUB_OUTPUT
+ run_past_ci_pytorch_1-13:
+ name: PyTorch 1.13
+ needs: get_number
+ if: needs.get_number.outputs.run_number == 0 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
+ uses: ./.github/workflows/self-past-caller.yml
+ with:
+ framework: pytorch
+ version: "1.13"
+ sha: ${{ github.sha }}
+ secrets: inherit
+
+ run_past_ci_pytorch_1-12:
+ name: PyTorch 1.12
+ needs: get_number
+ if: needs.get_number.outputs.run_number == 1 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
+ uses: ./.github/workflows/self-past-caller.yml
+ with:
+ framework: pytorch
+ version: "1.12"
+ sha: ${{ github.sha }}
+ secrets: inherit
+
+ run_past_ci_pytorch_1-11:
+ name: PyTorch 1.11
+ needs: get_number
+ if: needs.get_number.outputs.run_number == 2 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
+ uses: ./.github/workflows/self-past-caller.yml
+ with:
+ framework: pytorch
+ version: "1.11"
+ sha: ${{ github.sha }}
+ secrets: inherit
+
run_past_ci_tensorflow_2-11:
name: TensorFlow 2.11
needs: get_number
diff --git a/.github/workflows/self-pr-slow-ci.yml b/.github/workflows/self-pr-slow-ci.yml
new file mode 100644
index 00000000000000..43fcecd8def21e
--- /dev/null
+++ b/.github/workflows/self-pr-slow-ci.yml
@@ -0,0 +1,151 @@
+name: PR slow CI
+
+on:
+ pull_request:
+ paths:
+ - "src/transformers/models/*/modeling_*.py"
+ - "tests/**/test_*.py"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ HF_HOME: /mnt/cache
+ TRANSFORMERS_IS_CI: yes
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ RUN_SLOW: yes
+ # For gated repositories, we still need to agree to share information on the Hub repo. page in order to get access.
+ # This token is created under the bot `hf-transformers-bot`.
+ HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+ SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}
+ TF_FORCE_GPU_ALLOW_GROWTH: true
+ RUN_PT_TF_CROSS_TESTS: 1
+ CUDA_VISIBLE_DEVICES: 0,1
+
+jobs:
+ find_models_to_run:
+ runs-on: ubuntu-22.04
+ name: Find models to run slow tests
+ # Triggered only if the required label `run-slow` is added
+ if: ${{ contains(github.event.pull_request.labels.*.name, 'run-slow') }}
+ outputs:
+ models: ${{ steps.models_to_run.outputs.models }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: "0"
+ ref: ${{ github.event.pull_request.head.sha }}
+
+ - name: Get commit message
+ run: |
+ echo "commit_message=$(git show -s --format=%s)" >> $GITHUB_ENV
+
+ - name: Get models to run slow tests
+ run: |
+ echo "${{ env.commit_message }}"
+ python -m pip install GitPython
+ python utils/pr_slow_ci_models.py --commit_message "${{ env.commit_message }}" | tee output.txt
+ echo "models=$(tail -n 1 output.txt)" >> $GITHUB_ENV
+
+ - name: Models to run slow tests
+ id: models_to_run
+ run: |
+ echo "${{ env.models }}"
+ echo "models=${{ env.models }}" >> $GITHUB_OUTPUT
+
+ run_models_gpu:
+ name: Run all tests for the model
+ # Triggered only `find_models_to_run` is triggered (label `run-slow` is added) which gives the models to run
+ # (either a new model PR or via a commit message)
+ if: ${{ needs.find_models_to_run.outputs.models != '[]' }}
+ needs: find_models_to_run
+ strategy:
+ fail-fast: false
+ matrix:
+ folders: ${{ fromJson(needs.find_models_to_run.outputs.models) }}
+ machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
+ runs-on:
+ group: '${{ matrix.machine_type }}'
+ container:
+ image: huggingface/transformers-all-latest-gpu
+ options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
+ steps:
+ - name: Echo input and matrix info
+ shell: bash
+ run: |
+ echo "${{ matrix.folders }}"
+
+ - name: Echo folder ${{ matrix.folders }}
+ shell: bash
+ # For folders like `models/bert`, set an env. var. (`matrix_folders`) to `models_bert`, which will be used to
+ # set the artifact folder names (because the character `/` is not allowed).
+ run: |
+ echo "${{ matrix.folders }}"
+ matrix_folders=${{ matrix.folders }}
+ matrix_folders=${matrix_folders/'models/'/'models_'}
+ echo "$matrix_folders"
+ echo "matrix_folders=$matrix_folders" >> $GITHUB_ENV
+
+ - name: Update clone
+ working-directory: /transformers
+ run: git fetch && git fetch origin pull/${{ github.event.pull_request.number }}/head:pull/${{ github.event.pull_request.number }}/merge && git checkout pull/${{ github.event.pull_request.number }}/merge
+
+ - name: Reinstall transformers in edit mode (remove the one installed during docker image build)
+ working-directory: /transformers
+ run: python3 -m pip uninstall -y transformers && python3 -m pip install -e . && python3 -m pip install --upgrade torch torchaudio torchvision
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+
+ - name: Set `machine_type` for report and artifact names
+ working-directory: /transformers
+ shell: bash
+ run: |
+ echo "${{ matrix.machine_type }}"
+ if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
+ machine_type=single-gpu
+ elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
+ machine_type=multi-gpu
+ else
+ machine_type=${{ matrix.machine_type }}
+ fi
+ echo "$machine_type"
+ echo "machine_type=$machine_type" >> $GITHUB_ENV
+
+ - name: Environment
+ working-directory: /transformers
+ run: |
+ python3 utils/print_env.py
+
+ - name: Show installed libraries and their versions
+ working-directory: /transformers
+ run: pip freeze
+
+ - name: Run all tests on GPU
+ working-directory: /transformers
+ run: |
+ export CUDA_VISIBLE_DEVICES="$(python3 utils/set_cuda_devices_for_ci.py --test_folder ${{ matrix.folders }})"
+ echo $CUDA_VISIBLE_DEVICES
+ python3 -m pytest -v -rsfE --make-reports=${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }}
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ continue-on-error: true
+ run: cat /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports/failures_short.txt
+
+ - name: Make sure report directory exists
+ shell: bash
+ run: |
+ mkdir -p /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports
+ echo "hello" > /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports/hello.txt
+ echo "${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports"
+
+ - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_models_gpu_${{ env.matrix_folders }}_test_reports"
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: ${{ env.machine_type }}_run_models_gpu_${{ env.matrix_folders }}_test_reports
+ path: /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports
diff --git a/.github/workflows/self-push-amd-mi210-caller.yml b/.github/workflows/self-push-amd-mi210-caller.yml
index 45b325f7b357bf..a401e40ee7f164 100644
--- a/.github/workflows/self-push-amd-mi210-caller.yml
+++ b/.github/workflows/self-push-amd-mi210-caller.yml
@@ -1,25 +1,25 @@
-name: Self-hosted runner (AMD mi210 CI caller)
-
-on:
- #workflow_run:
- # workflows: ["Self-hosted runner (push-caller)"]
- # branches: ["main"]
- # types: [completed]
- push:
- branches:
- - run_amd_push_ci_caller*
- paths:
- - "src/**"
- - "tests/**"
- - ".github/**"
- - "templates/**"
- - "utils/**"
-
-jobs:
- run_amd_ci:
- name: AMD mi210
- if: (cancelled() != true) && ((github.event_name == 'workflow_run') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_amd_push_ci_caller')))
- uses: ./.github/workflows/self-push-amd.yml
- with:
- gpu_flavor: mi210
- secrets: inherit
+name: Self-hosted runner (AMD mi210 CI caller)
+
+on:
+ workflow_run:
+ workflows: ["Self-hosted runner (push-caller)"]
+ branches: ["main"]
+ types: [completed]
+ push:
+ branches:
+ - run_amd_push_ci_caller*
+ paths:
+ - "src/**"
+ - "tests/**"
+ - ".github/**"
+ - "templates/**"
+ - "utils/**"
+
+jobs:
+ run_amd_ci:
+ name: AMD mi210
+ if: (cancelled() != true) && ((github.event_name == 'workflow_run') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_amd_push_ci_caller')))
+ uses: ./.github/workflows/self-push-amd.yml
+ with:
+ gpu_flavor: mi210
+ secrets: inherit
diff --git a/.github/workflows/self-push-amd-mi250-caller.yml b/.github/workflows/self-push-amd-mi250-caller.yml
index 91b978b593d0b5..fef532703170cb 100644
--- a/.github/workflows/self-push-amd-mi250-caller.yml
+++ b/.github/workflows/self-push-amd-mi250-caller.yml
@@ -1,25 +1,25 @@
-name: Self-hosted runner (AMD mi250 CI caller)
-
-on:
- #workflow_run:
- # workflows: ["Self-hosted runner (push-caller)"]
- # branches: ["main"]
- # types: [completed]
- push:
- branches:
- - run_amd_push_ci_caller*
- paths:
- - "src/**"
- - "tests/**"
- - ".github/**"
- - "templates/**"
- - "utils/**"
-
-jobs:
- run_amd_ci:
- name: AMD mi250
- if: (cancelled() != true) && ((github.event_name == 'workflow_run') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_amd_push_ci_caller')))
- uses: ./.github/workflows/self-push-amd.yml
- with:
- gpu_flavor: mi250
- secrets: inherit
+name: Self-hosted runner (AMD mi250 CI caller)
+
+on:
+ workflow_run:
+ workflows: ["Self-hosted runner (push-caller)"]
+ branches: ["main"]
+ types: [completed]
+ push:
+ branches:
+ - run_amd_push_ci_caller*
+ paths:
+ - "src/**"
+ - "tests/**"
+ - ".github/**"
+ - "templates/**"
+ - "utils/**"
+
+jobs:
+ run_amd_ci:
+ name: AMD mi250
+ if: (cancelled() != true) && ((github.event_name == 'workflow_run') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_amd_push_ci_caller')))
+ uses: ./.github/workflows/self-push-amd.yml
+ with:
+ gpu_flavor: mi250
+ secrets: inherit
diff --git a/.github/workflows/self-push-amd-mi300-caller.yml b/.github/workflows/self-push-amd-mi300-caller.yml
index 797916125a24fb..a8ee4e540ecf3f 100644
--- a/.github/workflows/self-push-amd-mi300-caller.yml
+++ b/.github/workflows/self-push-amd-mi300-caller.yml
@@ -1,10 +1,10 @@
name: Self-hosted runner (AMD mi300 CI caller)
on:
- #workflow_run:
- # workflows: ["Self-hosted runner (push-caller)"]
- # branches: ["main"]
- # types: [completed]
+ workflow_run:
+ workflows: ["Self-hosted runner (push-caller)"]
+ branches: ["main"]
+ types: [completed]
push:
branches:
- run_amd_push_ci_caller*
diff --git a/README.md b/README.md
index 42403f84b885da..c748e675066202 100644
--- a/README.md
+++ b/README.md
@@ -249,7 +249,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta
### With pip
-This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, and TensorFlow 2.6+.
+This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, and TensorFlow 2.6+.
You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
diff --git a/benchmark/README.md b/benchmark/README.md
deleted file mode 100644
index a827da444f0801..00000000000000
--- a/benchmark/README.md
+++ /dev/null
@@ -1,49 +0,0 @@
-# Benchmarks
-
-You might want to add new benchmarks.
-
-You will need to define a python function named `run_benchmark` in your python file and the file must be located in this `benchmark/` directory.
-
-The expected function signature is the following:
-
-```py
-def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
-```
-
-## Writing metrics to the database
-
-`MetricRecorder` is thread-safe, in the sense of the python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread). This means you can start a background thread to do the readings on the device measurements while not blocking the main thread to execute the model measurements.
-
-cf [`llama.py`](./llama.py) to see an example of this in practice.
-
-```py
-from benchmarks_entrypoint import MetricsRecorder
-import psycopg2
-
-def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
- metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), logger, branch, commit_id, commit_msg)
- benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
- # To collect device measurements
- metrics_recorder.collect_device_measurements(
- benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
- )
- # To collect your model measurements
- metrics_recorder.collect_model_measurements(
- benchmark_id,
- {
- "model_load_time": model_load_time,
- "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
- "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
- "first_eager_generate_time_secs": first_eager_generate_time,
- "second_eager_generate_time_secs": second_eager_generate_time,
- "time_to_first_token_secs": time_to_first_token,
- "time_to_second_token_secs": time_to_second_token,
- "time_to_third_token_secs": time_to_third_token,
- "time_to_next_token_mean_secs": mean_time_to_next_token,
- "first_compile_generate_time_secs": first_compile_generate_time,
- "second_compile_generate_time_secs": second_compile_generate_time,
- "third_compile_generate_time_secs": third_compile_generate_time,
- "fourth_compile_generate_time_secs": fourth_compile_generate_time,
- },
- )
-```
diff --git a/benchmark/benchmarks_entrypoint.py b/benchmark/benchmarks_entrypoint.py
deleted file mode 100644
index 7925e2902834f7..00000000000000
--- a/benchmark/benchmarks_entrypoint.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import argparse
-import importlib.util
-import logging
-import os
-from typing import Dict
-import psycopg2
-import sys
-
-from psycopg2.extras import Json
-from psycopg2.extensions import register_adapter
-
-
-register_adapter(dict, Json)
-
-
-class ImportModuleException(Exception):
- pass
-
-
-class MetricsRecorder:
- def __init__(self, connection, logger: logging.Logger, branch: str, commit_id: str, commit_msg: str):
- self.conn = connection
- self.conn.autocommit = True
- self.logger = logger
- self.branch = branch
- self.commit_id = commit_id
- self.commit_msg = commit_msg
-
- def initialise_benchmark(self, metadata: Dict[str, str]) -> int:
- """
- Creates a new benchmark, returns the benchmark id
- """
- # gpu_name: str, model_id: str
- with self.conn.cursor() as cur:
- cur.execute(
- "INSERT INTO benchmarks (branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
- (self.branch, self.commit_id, self.commit_msg, metadata),
- )
- benchmark_id = cur.fetchone()[0]
- logger.debug(f"initialised benchmark #{benchmark_id}")
- return benchmark_id
-
- def collect_device_measurements(self, benchmark_id: int, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes):
- """
- Collect device metrics, such as CPU & GPU usage. These are "static", as in you cannot pass arbitrary arguments to the function.
- """
- with self.conn.cursor() as cur:
- cur.execute(
- "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
- (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
- )
- self.logger.debug(
- f"inserted device measurements for benchmark #{benchmark_id} [CPU util: {cpu_util}, mem MBs: {mem_megabytes}, GPU util: {gpu_util}, GPU mem MBs: {gpu_mem_megabytes}]"
- )
-
- def collect_model_measurements(self, benchmark_id: int, measurements: Dict[str, float]):
- with self.conn.cursor() as cur:
- cur.execute(
- """
- INSERT INTO model_measurements (
- benchmark_id,
- measurements
- ) VALUES (%s, %s)
- """,
- (
- benchmark_id,
- measurements,
- ),
- )
- self.logger.debug(f"inserted model measurements for benchmark #{benchmark_id}: {measurements}")
-
- def close(self):
- self.conn.close()
-
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-handler = logging.StreamHandler(sys.stdout)
-handler.setLevel(logging.INFO)
-formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s")
-handler.setFormatter(formatter)
-logger.addHandler(handler)
-
-
-def parse_arguments():
- """
- Parse command line arguments for the benchmarking CLI.
- """
- parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.")
-
- parser.add_argument(
- "branch",
- type=str,
- help="The branch name on which the benchmarking is performed.",
- )
-
- parser.add_argument(
- "commit_id",
- type=str,
- help="The commit hash on which the benchmarking is performed.",
- )
-
- parser.add_argument(
- "commit_msg",
- type=str,
- help="The commit message associated with the commit, truncated to 70 characters.",
- )
-
- args = parser.parse_args()
-
- return args.branch, args.commit_id, args.commit_msg
-
-
-def import_from_path(module_name, file_path):
- try:
- spec = importlib.util.spec_from_file_location(module_name, file_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[module_name] = module
- spec.loader.exec_module(module)
- return module
- except Exception as e:
- raise ImportModuleException(f"failed to load python module: {e}")
-
-
-if __name__ == "__main__":
- benchmarks_folder_path = os.path.dirname(os.path.realpath(__file__))
-
- branch, commit_id, commit_msg = parse_arguments()
-
- for entry in os.scandir(benchmarks_folder_path):
- try:
- if not entry.name.endswith(".py"):
- continue
- if entry.path == __file__:
- continue
- logger.debug(f"loading: {entry.name}")
- module = import_from_path(entry.name.split(".")[0], entry.path)
- logger.info(f"runnning benchmarks in: {entry.name}")
- module.run_benchmark(logger, branch, commit_id, commit_msg)
- except ImportModuleException as e:
- logger.error(e)
- except Exception as e:
- logger.error(f"error running benchmarks for {entry.name}: {e}")
diff --git a/benchmark/default.yml b/benchmark/default.yml
deleted file mode 100644
index f3f02cab34d1bd..00000000000000
--- a/benchmark/default.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-apiVersion: 1
-
-providers:
- - name: 'Transformers Benchmarks'
- orgId: 1
- type: file
- updateIntervalSeconds: 10
- allowUiUpdates: true
- options:
- path: /etc/grafana/dashboards
diff --git a/benchmark/grafana_dashboard.json b/benchmark/grafana_dashboard.json
index caaec78a522303..3d579f7b368711 100644
--- a/benchmark/grafana_dashboard.json
+++ b/benchmark/grafana_dashboard.json
@@ -30,7 +30,7 @@
"title": "Go to data",
"tooltip": "Go to data",
"type": "link",
- "url": "http://transformers-benchmarks.hf.co/d/fdz33iyzln9c0a/transformers-benchmarks?orgId=1&from=${StartTime}&to=${EndTime}"
+ "url": "http://transformers-benchmarks.huggingface.co/d/fdz33iyzln9c0a/transformers-benchmarks?orgId=1&from=${StartTime}&to=${EndTime}"
}
],
"liveNow": true,
@@ -77,7 +77,7 @@
"properties": [
{
"id": "custom.width",
- "value": 202
+ "value": 196
}
]
},
@@ -101,7 +101,7 @@
"properties": [
{
"id": "custom.width",
- "value": 524
+ "value": 581
}
]
},
@@ -113,19 +113,7 @@
"properties": [
{
"id": "custom.width",
- "value": 353
- }
- ]
- },
- {
- "matcher": {
- "id": "byName",
- "options": "model_id"
- },
- "properties": [
- {
- "id": "custom.width",
- "value": 216
+ "value": 379
}
]
}
@@ -155,14 +143,12 @@
"targets": [
{
"datasource": {
- "default": true,
- "type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "type": "grafana-postgresql-datasource"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT commit_id, commit_message, metadata->>'gpu_name' as gpu_name, metadata->>'model_id' as model_id, created_at AS date FROM benchmarks WHERE branch = '${branch}' AND metadata->>'gpu_name' = '${gpu_name}' ORDER BY benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT commit_id as commit_id, commit_message, gpu_name, created_at AS date FROM benchmarks WHERE branch = '${branch}' ORDER BY benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -320,14 +306,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'first_eager_forward_pass_time_secs' AS double precision) AS first_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'first_eager_forward_pass_time_secs' AS double precision) AS first_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -446,14 +431,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'second_eager_forward_pass_time_secs' AS double precision) AS second_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'second_eager_forward_pass_time_secs' AS double precision) AS second_eager_forward_pass_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -581,14 +565,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'time_to_first_token_secs' AS double precision) AS time_to_first_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'time_to_first_token_secs' AS double precision) AS time_to_first_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -703,14 +686,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'time_to_second_token_secs' AS double precision) AS time_to_second_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'time_to_second_token_secs' AS double precision) AS time_to_second_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -825,14 +807,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'time_to_third_token_secs' AS double precision) AS time_to_third_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'time_to_third_token_secs' AS double precision) AS time_to_third_token_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -947,14 +928,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'time_to_next_token_mean_secs' AS double precision) AS time_to_next_token_mean_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'time_to_next_token_mean_secs' AS double precision) AS time_to_next_token_mean_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1082,14 +1062,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'first_compile_generate_time_secs' AS double precision) AS first_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'first_compile_generate_time_secs' AS double precision) AS first_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1204,14 +1183,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'second_compile_generate_time_secs' AS double precision) AS second_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'second_compile_generate_time_secs' AS double precision) AS second_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1326,14 +1304,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'third_compile_generate_time_secs' AS double precision) AS third_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'third_compile_generate_time_secs' AS double precision) AS third_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1448,14 +1425,13 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
"rawQuery": true,
- "rawSql": "SELECT CAST(m.measurements->'fourth_compile_generate_time_secs' AS double precision) AS fourth_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND b.metadata->>'gpu_name' = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
+ "rawSql": "SELECT CAST(m.measurements->'fourth_compile_generate_time_secs' AS double precision) AS fourth_compile_generate_time_secs, left(b.commit_id, 7), m.time FROM benchmarks as b JOIN model_measurements AS m ON b.benchmark_id = m.benchmark_id WHERE b.branch = '${branch}' AND gpu_name = '${gpu_name}' ORDER BY b.benchmark_id DESC LIMIT ${last_n_commits};",
"refId": "A",
"sql": {
"columns": [
@@ -1504,7 +1480,11 @@
"id": 15,
"panels": [
{
- "datasource": {},
+ "datasource": {
+ "default": true,
+ "type": "grafana-postgresql-datasource",
+ "uid": "be28nkzirtb0gd"
+ },
"fieldConfig": {
"defaults": {
"color": {
@@ -1548,7 +1528,8 @@
"mode": "absolute",
"steps": [
{
- "color": "green"
+ "color": "green",
+ "value": null
},
{
"color": "red",
@@ -1582,9 +1563,8 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
@@ -1685,7 +1665,11 @@
"type": "timeseries"
},
{
- "datasource": {},
+ "datasource": {
+ "default": true,
+ "type": "grafana-postgresql-datasource",
+ "uid": "be28nkzirtb0gd"
+ },
"fieldConfig": {
"defaults": {
"color": {
@@ -1729,7 +1713,8 @@
"mode": "absolute",
"steps": [
{
- "color": "green"
+ "color": "green",
+ "value": null
},
{
"color": "red",
@@ -1763,9 +1748,8 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
@@ -1866,7 +1850,11 @@
"type": "timeseries"
},
{
- "datasource": {},
+ "datasource": {
+ "default": true,
+ "type": "grafana-postgresql-datasource",
+ "uid": "be28nkzirtb0gd"
+ },
"fieldConfig": {
"defaults": {
"color": {
@@ -1910,7 +1898,8 @@
"mode": "absolute",
"steps": [
{
- "color": "green"
+ "color": "green",
+ "value": null
},
{
"color": "red",
@@ -1944,9 +1933,8 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
@@ -2047,7 +2035,11 @@
"type": "timeseries"
},
{
- "datasource": {},
+ "datasource": {
+ "default": true,
+ "type": "grafana-postgresql-datasource",
+ "uid": "be28nkzirtb0gd"
+ },
"fieldConfig": {
"defaults": {
"color": {
@@ -2091,7 +2083,8 @@
"mode": "absolute",
"steps": [
{
- "color": "green"
+ "color": "green",
+ "value": null
},
{
"color": "red",
@@ -2125,9 +2118,8 @@
"targets": [
{
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
- "uid": "be28nkzirtb0gd"
+ "uid": "bdz2yss7sxo1sc"
},
"editorMode": "code",
"format": "table",
@@ -2232,6 +2224,7 @@
"type": "row"
}
],
+ "refresh": "",
"schemaVersion": 39,
"tags": [],
"templating": {
@@ -2243,7 +2236,6 @@
"value": "main"
},
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
"uid": "be28nkzirtb0gd"
},
@@ -2256,7 +2248,7 @@
"name": "branch",
"options": [],
"query": "SELECT DISTINCT branch FROM benchmarks;",
- "refresh": 1,
+ "refresh": 2,
"regex": "",
"skipUrlSync": false,
"sort": 0,
@@ -2269,7 +2261,6 @@
"value": "1729701492845"
},
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
"uid": "be28nkzirtb0gd"
},
@@ -2290,11 +2281,10 @@
{
"current": {
"selected": false,
- "text": "1730393397577",
- "value": "1730393397577"
+ "text": "1730120430069",
+ "value": "1730120430069"
},
"datasource": {
- "default": true,
"type": "grafana-postgresql-datasource",
"uid": "be28nkzirtb0gd"
},
@@ -2322,16 +2312,15 @@
"type": "grafana-postgresql-datasource",
"uid": "be28nkzirtb0gd"
},
- "definition": "SELECT DISTINCT metadata->>'gpu_name' FROM benchmarks;",
- "description": "",
+ "definition": "SELECT DISTINCT gpu_name FROM benchmarks;",
"hide": 0,
"includeAll": false,
"label": "GPU",
"multi": false,
"name": "gpu_name",
"options": [],
- "query": "SELECT DISTINCT metadata->>'gpu_name' FROM benchmarks;",
- "refresh": 1,
+ "query": "SELECT DISTINCT gpu_name FROM benchmarks;",
+ "refresh": 2,
"regex": "",
"skipUrlSync": false,
"sort": 0,
@@ -2339,7 +2328,7 @@
},
{
"current": {
- "selected": true,
+ "selected": false,
"text": "10",
"value": "10"
},
@@ -2370,6 +2359,6 @@
"timezone": "browser",
"title": "Transformers benchmarks",
"uid": "fdz33iyzln9c0a",
- "version": 10,
+ "version": 4,
"weekStart": ""
}
diff --git a/benchmark/grafana_datasource.yaml b/benchmark/grafana_datasource.yaml
deleted file mode 100644
index 25f36254104ab5..00000000000000
--- a/benchmark/grafana_datasource.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-apiVersion: 1
-datasources:
- - name: grafana-postgresql-datasource
- uid: be28nkzirtb0gd
- type: postgres
- url: $GRAFANA_POSTGRES_DATASOURCE_URL
- user: $GRAFANA_POSTGRES_DATASOURCE_USER
- secureJsonData:
- password: $GRAFANA_POSTGRES_DATASOURCE_PWD
- jsonData:
- database: metrics
- maxOpenConns: 100
- maxIdleConns: 100
- maxIdleConnsAuto: true
- connMaxLifetime: 14400
- postgresVersion: 1000
- timescaledb: false
diff --git a/benchmark/init_db.sql b/benchmark/init_db.sql
index a7864c4af183b6..573cc11518e857 100644
--- a/benchmark/init_db.sql
+++ b/benchmark/init_db.sql
@@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS benchmarks (
branch VARCHAR(255),
commit_id VARCHAR(72),
commit_message VARCHAR(70),
- metadata jsonb,
+ gpu_name VARCHAR(255),
created_at timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
);
diff --git a/benchmark/llama.py b/benchmark/llama.py
index bbe1afefd5ef1b..4a2c57422e6ffb 100644
--- a/benchmark/llama.py
+++ b/benchmark/llama.py
@@ -1,25 +1,71 @@
-from logging import Logger
+import argparse
+import json
+import logging
import os
+import sys
+from statistics import mean
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional
-from benchmarks_entrypoint import MetricsRecorder
import gpustat
import psutil
import psycopg2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache
+from psycopg2.extras import Json
+from psycopg2.extensions import register_adapter
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.INFO)
+formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s")
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")
+register_adapter(dict, Json)
+
+
+def parse_arguments():
+ """
+ Parse command line arguments for the benchmarking CLI.
+ """
+ parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.")
+
+ parser.add_argument(
+ "branch",
+ type=str,
+ help="The branch name on which the benchmarking is performed.",
+ )
+
+ parser.add_argument(
+ "commit_id",
+ type=str,
+ help="The commit hash on which the benchmarking is performed.",
+ )
+ parser.add_argument(
+ "commit_msg",
+ type=str,
+ help="The commit message associated with the commit, truncated to 70 characters.",
+ )
-def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
+ args = parser.parse_args()
+
+ return args.branch, args.commit_id, args.commit_msg
+
+
+def collect_metrics(benchmark_id, continue_metric_collection):
p = psutil.Process(os.getpid())
+ conn = psycopg2.connect("dbname=metrics")
+ cur = conn.cursor()
while not continue_metric_collection.is_set():
with p.oneshot():
cpu_util = p.cpu_percent()
@@ -27,41 +73,47 @@ def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
gpu_stats = gpustat.GPUStatCollection.new_query()
gpu_util = gpu_stats[0]["utilization.gpu"]
gpu_mem_megabytes = gpu_stats[0]["memory.used"]
- metrics_recorder.collect_device_measurements(
- benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
+ cur.execute(
+ "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
+ (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
)
sleep(0.01)
+ conn.commit()
+ conn.close()
-def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
+def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
continue_metric_collection = Event()
metrics_thread = None
- model_id = "meta-llama/Llama-2-7b-hf"
- metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), logger, branch, commit_id, commit_msg)
try:
gpu_stats = gpustat.GPUStatCollection.new_query()
gpu_name = gpu_stats[0]["name"]
- benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
- logger.info(f"running benchmark #{benchmark_id} on {gpu_name} for {model_id}")
- metrics_thread = Thread(
- target=collect_metrics,
- args=[benchmark_id, continue_metric_collection, metrics_recorder],
+ conn = psycopg2.connect("dbname=metrics")
+ cur = conn.cursor()
+ cur.execute(
+ "INSERT INTO benchmarks (branch, commit_id, commit_message, gpu_name) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
+ (branch, commit_id, commit_msg, gpu_name),
)
+ conn.commit()
+ benchmark_id = cur.fetchone()[0]
+ logger.info(f"running benchmark #{benchmark_id} on {gpu_name}")
+ metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection])
metrics_thread.start()
logger.info("started background thread to fetch device metrics")
os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
device = "cuda"
+ ckpt = "meta-llama/Llama-2-7b-hf"
logger.info("downloading weights")
# This is to avoid counting download in model load time measurement
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
+ model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
logger.info("loading model")
start = perf_counter()
model = AutoModelForCausalLM.from_pretrained(
- model_id, torch_dtype=torch.float16, generation_config=gen_config
+ ckpt, torch_dtype=torch.float16, generation_config=gen_config
).eval()
model.to(device)
torch.cuda.synchronize()
@@ -69,7 +121,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
model_load_time = end - start
logger.info(f"loaded model in: {model_load_time}s")
- tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer = AutoTokenizer.from_pretrained(ckpt)
prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
@@ -316,27 +368,41 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):
logger.info(f"completed second compile generation in: {fourth_compile_generate_time}s")
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
- metrics_recorder.collect_model_measurements(
- benchmark_id,
- {
- "model_load_time": model_load_time,
- "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
- "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
- "first_eager_generate_time_secs": first_eager_generate_time,
- "second_eager_generate_time_secs": second_eager_generate_time,
- "time_to_first_token_secs": time_to_first_token,
- "time_to_second_token_secs": time_to_second_token,
- "time_to_third_token_secs": time_to_third_token,
- "time_to_next_token_mean_secs": mean_time_to_next_token,
- "first_compile_generate_time_secs": first_compile_generate_time,
- "second_compile_generate_time_secs": second_compile_generate_time,
- "third_compile_generate_time_secs": third_compile_generate_time,
- "fourth_compile_generate_time_secs": fourth_compile_generate_time,
- },
+ cur.execute(
+ """
+ INSERT INTO model_measurements (
+ benchmark_id,
+ measurements
+ ) VALUES (%s, %s)
+ """,
+ (
+ benchmark_id,
+ {
+ "model_load_time": model_load_time,
+ "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
+ "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
+ "first_eager_generate_time_secs": first_eager_generate_time,
+ "second_eager_generate_time_secs": second_eager_generate_time,
+ "time_to_first_token_secs": time_to_first_token,
+ "time_to_second_token_secs": time_to_second_token,
+ "time_to_third_token_secs": time_to_third_token,
+ "time_to_next_token_mean_secs": mean_time_to_next_token,
+ "first_compile_generate_time_secs": first_compile_generate_time,
+ "second_compile_generate_time_secs": second_compile_generate_time,
+ "third_compile_generate_time_secs": third_compile_generate_time,
+ "fourth_compile_generate_time_secs": fourth_compile_generate_time,
+ },
+ ),
)
+ conn.commit()
+ conn.close()
except Exception as e:
logger.error(f"Caught exception: {e}")
continue_metric_collection.set()
if metrics_thread is not None:
metrics_thread.join()
- metrics_recorder.close()
+
+
+if __name__ == "__main__":
+ branch, commit_id, commit_msg = parse_arguments()
+ run_benchmark(branch, commit_id, commit_msg, num_tokens_to_generate=20)
diff --git a/docker/transformers-pytorch-amd-gpu/Dockerfile b/docker/transformers-pytorch-amd-gpu/Dockerfile
index 83f8565c8f467e..da91906d621429 100644
--- a/docker/transformers-pytorch-amd-gpu/Dockerfile
+++ b/docker/transformers-pytorch-amd-gpu/Dockerfile
@@ -1,4 +1,4 @@
-FROM rocm/dev-ubuntu-22.04:6.1
+FROM rocm/dev-ubuntu-22.04:6.0.2
# rocm/pytorch has no version with 2.1.0
LABEL maintainer="Hugging Face"
@@ -11,7 +11,7 @@ RUN apt update && \
RUN python3 -m pip install --no-cache-dir --upgrade pip numpy
-RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
+RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0
RUN python3 -m pip install --no-cache-dir --upgrade importlib-metadata setuptools ninja git+https://github.com/facebookresearch/detectron2.git pytesseract "itsdangerous<2.1.0"
@@ -30,5 +30,5 @@ RUN python3 -m pip uninstall -y tensorflow flax
# this line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop
-# Remove nvml and nvidia-ml-py as it is not compatible with ROCm. apex is not tested on NVIDIA either.
-RUN python3 -m pip uninstall py3nvml pynvml nvidia-ml-py apex -y
+# Remove nvml as it is not compatible with ROCm. apex is not tested on NVIDIA either.
+RUN python3 -m pip uninstall py3nvml pynvml apex -y
diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile
index 3cb2acdc53bb1a..089be4a4460101 100755
--- a/docker/transformers-quantization-latest-gpu/Dockerfile
+++ b/docker/transformers-quantization-latest-gpu/Dockerfile
@@ -50,9 +50,6 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/pef
# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
-# Add vptq for quantization testing
-RUN python3 -m pip install --no-cache-dir vptq
-
# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq
diff --git a/docs/source/ar/_toctree.yml b/docs/source/ar/_toctree.yml
index 287f4dffbb384e..1208153c22df68 100644
--- a/docs/source/ar/_toctree.yml
+++ b/docs/source/ar/_toctree.yml
@@ -133,18 +133,12 @@
title: المعايير
- local: notebooks
title: دفاتر الملاحظات مع الأمثلة
- - local: community
- title: موارد المجتمع
+# - local: community
+# title: موارد المجتمع
- local: troubleshooting
title: استكشاف الأخطاء وإصلاحها
- local: gguf
title: التوافق مع ملفات GGUF
- - local: tiktoken
- title: التوافق مع ملفات TikToken
- - local: modular_transformers
- title: الوحدات النمطية في `transformers`
- - local: how_to_hack_models
- title: اختراق النموذج (الكتابة فوق فئة لاستخدامك)
title: أدلة المطورين
# - sections:
# - local: quantization/overview
@@ -157,8 +151,6 @@
# title: AWQ
# - local: quantization/aqlm
# title: AQLM
-# - local: quantization/vptq
-# title: VPTQ
# - local: quantization/quanto
# title: Quanto
# - local: quantization/eetq
diff --git a/docs/source/ar/community.md b/docs/source/ar/community.md
deleted file mode 100644
index 5a1c31de0aaa3f..00000000000000
--- a/docs/source/ar/community.md
+++ /dev/null
@@ -1,66 +0,0 @@
-# مجتمع المطورين
-
-هذه الصفحة تجمع الموارد حول 🤗 Transformers التي طورها المجتمع.
-
-## موارد المجتمع:
-
-| المصدر | الوصف | المؤلف |
-|:----------|:-------------|------:|
-| [Hugging Face Transformers Glossary Flashcards](https://www.darigovresearch.com/huggingface-transformers-glossary-flashcards) | مجموعة من البطاقات التعليمية القائمة على [Transformers Docs Glossary](glossary) والتي تم وضعها في شكل يمكن تعلمه/مراجعته بسهولة باستخدام [Anki](https://apps.ankiweb.net/) وهو تطبيق مفتوح المصدر متعدد المنصات مصمم خصيصًا للاحتفاظ بالمعرفة على المدى الطويل. شاهد هذا [فيديو تمهيدي حول كيفية استخدام البطاقات التعليمية](https://www.youtube.com/watch?v=Dji_7PILrw). | [Darigov Research](https://www.darigovresearch.com/) |
-
-## دفاتر ملاحظات المجتمع:
-
-| الدفتر | الوصف | المؤلف | |
-|:----------|:-------------|:-------------|------:|
-| [Fine-tune a pre-trained Transformer to generate lyrics](https://github.com/AlekseyKorshuk/huggingartists) | كيفية توليد كلمات الأغاني على غرار فنانك المفضل من خلال ضبط نموذج GPT-2 | [Aleksey Korshuk](https://github.com/AlekseyKorshuk) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb) |
-| [Train T5 in Tensorflow 2](https://github.com/snapthat/TF-T5-text-to-text) | كيفية تدريب T5 لأي مهمة باستخدام Tensorflow 2. يوضح هذا الدفتر مهمة السؤال والجواب المنفذة في Tensorflow 2 باستخدام SQUAD | [Muhammad Harris](https://github.com/HarrisDePerceptron) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb) |
-| [Train T5 on TPU](https://github.com/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb) | كيفية تدريب T5 على SQUAD مع Transformers و Nlp | [Suraj Patil](https://github.com/patil-suraj) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil) |
-| [Fine-tune T5 for Classification and Multiple Choice](https://github.com/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) | كيفية ضبط نموذج T5 للتصنيف والمهام متعددة الخيارات باستخدام تنسيق النص إلى نص مع PyTorch Lightning | [Suraj Patil](https://github.com/patil-suraj) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) |
-| [Fine-tune DialoGPT on New Datasets and Languages](https://github.com/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) | كيفية ضبط نموذج DialoGPT على مجموعة بيانات جديدة لروبوتات الدردشة المحادثية المفتوحة | [Nathan Cooper](https://github.com/ncoop57) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) |
-| [Long Sequence Modeling with Reformer](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) | كيفية التدريب على تسلسلات طويلة تصل إلى 500,000 رمز باستخدام Reformer | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) |
-| [Fine-tune BART for Summarization](https://github.com/ohmeow/ohmeow_website/blob/master/posts/2021-05-25-mbart-sequence-classification-with-blurr.ipynb) | كيفية ضبط نموذج BART للتلخيص باستخدام fastai باستخدام blurr | [Wayde Gilliam](https://ohmeow.com/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ohmeow/ohmeow_website/blob/master/posts/2021-05-25-mbart-sequence-classification-with-blurr.ipynb) |
-| [Fine-tune a pre-trained Transformer on anyone's tweets](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) | كيفية توليد تغريدات على غرار حساب Twitter المفضل لديك من خلال ضبط نموذج GPT-2 | [Boris Dayma](https://github.com/borisdayma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) |
-| [Optimize 🤗 Hugging Face models with Weights & Biases](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/huggingface/Optimize_Hugging_Face_models_with_Weights_%26_Biases.ipynb) | دليل كامل لعرض تكامل W&B مع Hugging Face | [Boris Dayma](https://github.com/borisdayma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/huggingface/Optimize_Hugging_Face_models_with_Weights_%26_Biases.ipynb) |
-| [Pretrain Longformer](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) | كيفية بناء نسخة "طويلة" من النماذج المسبقة التدريب الموجودة | [Iz Beltagy](https://beltagy.net) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) |
-| [Fine-tune Longformer for QA](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) | كيفية ضبط نموذج Longformer لمهمة QA | [Suraj Patil](https://github.com/patil-suraj) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) |
-| [Evaluate Model with 🤗nlp](https://github.com/patrickvonplaten/notebooks/blob/master/How_to_evaluate_Longformer_on_TriviaQA_using_NLP.ipynb) | كيفية تقييم نموذج Longformer على TriviaQA مع `nlp` | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m7eTGlPmLRgoPkkA7rkhQdZ9ydpmsdLE?usp=sharing) |
-| [Fine-tune T5 for Sentiment Span Extraction](https://github.com/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) | كيفية ضبط نموذج T5 لاستخراج المشاعر باستخدام تنسيق النص إلى نص مع PyTorch Lightning | [Lorenzo Ampil](https://github.com/enzoampil) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) |
-| [Fine-tune DistilBert for Multiclass Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb) | كيفية ضبط نموذج DistilBert للتصنيف متعدد الفئات باستخدام PyTorch | [Abhishek Kumar Mishra](https://github.com/abhimishra91) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb)|
-|[Fine-tune BERT for Multi-label Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|كيفية ضبط نموذج BERT للتصنيف متعدد التصنيفات باستخدام PyTorch|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|
-|[Fine-tune T5 for Summarization](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|كيفية ضبط نموذج T5 للتلخيص في PyTorch وتتبع التجارب باستخدام WandB|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|
-|[Speed up Fine-Tuning in Transformers with Dynamic Padding / Bucketing](https://github.com/ELS-RD/transformers-notebook/blob/master/Divide_Hugging_Face_Transformers_training_time_by_2_or_more.ipynb)|كيفية تسريع الضبط الدقيق بعامل 2 باستخدام الضبط الديناميكي/التقسيم|[Michael Benesty](https://github.com/pommedeterresautee) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CBfRU1zbfu7-ijiOqAAQUA-RJaxfcJoO?usp=sharing)|
-|[Pretrain Reformer for Masked Language Modeling](https://github.com/patrickvonplaten/notebooks/blob/master/Reformer_For_Masked_LM.ipynb)| كيفية تدريب نموذج Reformer مع طبقات الانتباه ثنائية الاتجاه | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tzzh0i8PgDQGV3SMFUGxM7_gGae3K-uW?usp=sharing)|
-|[Expand and Fine Tune Sci-BERT](https://github.com/lordtt13/word-embeddings/blob/master/COVID-19%20Research%20Data/COVID-SciBERT.ipynb)| كيفية زيادة مفردات نموذج SciBERT المسبق التدريب من AllenAI على مجموعة بيانات CORD وإنشاء خط أنابيب لها. | [Tanmay Thakur](https://github.com/lordtt13) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rqAR40goxbAfez1xvF3hBJphSCsvXmh8)|
-|[Fine Tune BlenderBotSmall for Summarization using the Trainer API](https://github.com/lordtt13/transformers-experiments/blob/master/Custom%20Tasks/fine-tune-blenderbot_small-for-summarization.ipynb)| كيفية ضبط نموذج BlenderBotSmall للتلخيص على مجموعة بيانات مخصصة، باستخدام واجهة برمجة التطبيقات Trainer. | [Tanmay Thakur](https://github.com/lordtt13) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/19Wmupuls7mykSGyRN_Qo6lPQhgp56ymq?usp=sharing)|
-|[Fine-tune Electra and interpret with Integrated Gradients](https://github.com/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb) | كيفية ضبط نموذج Electra للتحليل العاطفي وتفسير التنبؤات باستخدام Captum Integrated Gradients | [Eliza Szczechla](https://elsanns.github.io) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb)|
-|[fine-tune a non-English GPT-2 Model with Trainer class](https://github.com/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb) | كيفية ضبط نموذج GPT-2 غير الإنجليزي باستخدام فئة Trainer | [Philipp Schmid](https://www.philschmid.de) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb)|
-|[Fine-tune a DistilBERT Model for Multi Label Classification task](https://github.com/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb) | كيفية ضبط نموذج DistilBERT لمهمة التصنيف متعدد التصنيفات | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb)|
-|[Fine-tune ALBERT for sentence-pair classification](https://github.com/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb) | كيفية ضبط نموذج ALBERT أو أي نموذج آخر قائم على BERT لمهمة التصنيف المزدوج للجمل | [Nadir El Manouzi](https://github.com/NadirEM) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb)|
-|[Fine-tune Roberta for sentiment analysis](https://github.com/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb) | كيفية ضبط نموذج Roberta للتحليل العاطفي | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb)|
-|[Evaluating Question Generation Models](https://github.com/flexudy-pipe/qugeev) | ما مدى دقة الإجابات على الأسئلة التي يولدها نموذجك التحويلي seq2seq؟ | [Pascal Zoleko](https://github.com/zolekode) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1bpsSqCQU-iw_5nNoRm_crPq6FRuJthq_?usp=sharing)|
-|[Classify text with DistilBERT and Tensorflow](https://github.com/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb) | كيفية ضبط نموذج DistilBERT للتصنيف النصي في TensorFlow | [Peter Bayerle](https://github.com/peterbayerle) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb)|
-|[Leverage BERT for Encoder-Decoder Summarization on CNN/Dailymail](https://github.com/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb) | كيفية البدء السريع لنموذج *EncoderDecoderModel* مع نقطة تفتيش *google-bert/bert-base-uncased* للتلخيص على CNN/Dailymail | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb)|
-|[Leverage RoBERTa for Encoder-Decoder Summarization on BBC XSum](https://github.com/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb) | كيفية البدء السريع لنموذج *EncoderDecoderModel* المشترك مع نقطة تفتيش *FacebookAI/roberta-base* للتلخيص على BBC/XSum | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb)|
-|[Fine-tune TAPAS on Sequential Question Answering (SQA)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb) | كيفية ضبط نموذج *TapasForQuestionAnswering* مع نقطة تفتيش *tapas-base* على مجموعة بيانات Sequential Question Answering (SQA) | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb)|
-|[Evaluate TAPAS on Table Fact Checking (TabFact)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb) | كيفية تقييم نموذج *TapasForSequenceClassification* المضبوط مسبقًا مع نقطة تفتيش *tapas-base-finetuned-tabfact* باستخدام مزيج من مكتبتي 🤗 datasets و 🤗 transformers | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb)|
-|[Fine-tuning mBART for translation](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb) | كيفية ضبط نموذج mBART باستخدام Seq2SeqTrainer للترجمة من الهندية إلى الإنجليزية | [Vasudev Gupta](https://github.com/vasudevgupta7) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb)|
-|[Fine-tune LayoutLM on FUNSD (a form understanding dataset)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb) | كيفية ضبط نموذج *LayoutLMForTokenClassification* على مجموعة بيانات FUNSD لاستخراج المعلومات من المستندات الممسوحة ضوئيًا | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb)|
-|[Fine-Tune DistilGPT2 and Generate Text](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb) | كيفية ضبط نموذج DistilGPT2 وتوليد النص | [Aakash Tripathi](https://github.com/tripathiaakash) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tripathiaakash/DistilGPT2-Tutorial/blob/main/distilgpt2_fine_tuning.ipynb)|
-|[Fine-Tune LED on up to 8K tokens](https://github.com/patrickvonplaten/notebooks/blob/master/Fine_tune_Longformer_Encoder_Decoder_(LED)_for_Summarization_on_pubmed.ipynb) | كيفية ضبط نموذج LED على pubmed للتلخيص طويل المدى | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_tune_Longformer_Encoder_Decoder_(LED)_for_Summarization_on_pubmed.ipynb)|
-|[Evaluate LED on Arxiv](https://github.com/patrickvonplaten/notebooks/blob/master/LED_on_Arxiv.ipynb) | كيفية تقييم نموذج LED للتلخيص طويل المدى بشكل فعال | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/LED_on_Arxiv.ipynb)|
-|[Fine-tune LayoutLM on RVL-CDIP (a document image classification dataset)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForSequenceClassification_on_RVL_CDIP.ipynb) | كيفية ضبط نموذج *LayoutLMForSequenceClassification* على مجموعة بيانات RVL-CDIP لتصنيف المستندات الممسوحة ضوئيًا | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForSequenceClassification_on_RVL_CDIP.ipynb)|
-|[Wav2Vec2 CTC decoding with GPT2 adjustment](https://github.com/voidful/huggingface_notebook/blob/main/xlsr_gpt.ipynb) | كيفية فك تشفير تسلسل CTC مع تعديل نموذج اللغة | [Eric Lam](https://github.com/voidful) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1e_zQHYbO2YKEaUgzb1ww1WwiAyydAj?usp=sharing)|
-|[Fine-tune BART for summarization in two languages with Trainer class](https://github.com/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb) | كيفية ضبط نموذج BART للتلخيص بلغتين باستخدام فئة Trainer | [Eliza Szczechla](https://github.com/elsanns) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb)|
-|[Evaluate Big Bird on Trivia QA](https://github.com/patrickvonplaten/notebooks/blob/master/Evaluating_Big_Bird_on_TriviaQA.ipynb) | كيفية تقييم نموذج BigBird للأسئلة والأجوبة على وثائق طويلة على Trivia QA | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Evaluating_Big_Bird_on_TriviaQA.ipynb)|
-| [Create video captions using Wav2Vec2](https://github.com/Muennighoff/ytclipcc/blob/main/wav2vec_youtube_captions.ipynb) | كيفية إنشاء تعليقات توضيحية على YouTube من أي فيديو من خلال تفريغ الصوت باستخدام Wav2Vec | [Niklas Muennighoff](https://github.com/Muennighoff) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Muennighoff/ytclipcc/blob/main/wav2vec_youtube_captions.ipynb) |
-| [Fine-tune the Vision Transformer on CIFAR-10 using PyTorch Lightning](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb) | كيفية ضبط نموذج Vision Transformer (ViT) على CIFAR-10 باستخدام مكتبات HuggingFace Transformers و Datasets و PyTorch Lightning | [Niels Rogge](https://github.com/nielsrogge) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb) |
-| [Fine-tune the Vision Transformer on CIFAR-10 using the 🤗 Trainer](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb) | كيفية ضبط نموذج Vision Transformer (ViT) على CIFAR-10 باستخدام مكتبات HuggingFace Transformers و Datasets و 🤗 Trainer | [Niels Rogge](https://github.com/nielsrogge) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb) |
-| [Evaluate LUKE on Open Entity, an entity typing dataset](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb) | كيفية تقييم نموذج *LukeForEntityClassification* على مجموعة بيانات Open Entity | [Ikuya Yamada](https://github.com/ikuyamada) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_open_entity.ipynb) |
-| [Evaluate LUKE on TACRED, a relation extraction dataset](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_tacred.ipynb) | كيفية تقييم نموذج *LukeForEntityPairClassification* على مجموعة بيانات TACRED | [Ikuya Yamada](https://github.com/ikuyamada) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_tacred.ipynb) |
-| [Evaluate LUKE on CoNLL-2003, an important NER benchmark](https://github.com/studio-ousia/luke/blob/master/notebooks/huggingface_conll_2003.ipynb) | كيفية تقييم نموذج *LukeForEntitySpanClassification* على مجموعة بيانات CoNLL-2003 | [Ikuya Yamada](https://github.com/ikuyamada) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/studio-ousia/luke/blob/master/notebooks/huggingface_conll_2003.ipynb) |
-| [Evaluate BigBird-Pegasus on PubMed dataset](https://github.com/vasudevgupta7/bigbird/blob/main/notebooks/bigbird_pegasus_evaluation.ipynb) | كيفية تقييم نموذج *BigBirdPegasusForConditionalGeneration* على مجموعة بيانات PubMed | [Vasudev Gupta](https://github.com/vasudevgupta7) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vasudevgupta7/bigbird/blob/main/notebooks/bigbird_pegasus_evaluation.ipynb) |
-| [Speech Emotion Classification with Wav2Vec2](https://github.com/m3hrdadfi/soxan/blob/main/notebooks/Emotion_recognition_in_Greek_speech_using_Wav2Vec2.ipynb) | كيفية استخدام نموذج Wav2Vec2 المسبق التدريب لتصنيف المشاعر على مجموعة بيانات MEGA | [Mehrdad Farahani](https://github.com/m3hrdadfi) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/m3hrdadfi/soxan/blob/main/notebooks/Emotion_recognition_in_Greek_speech_using_Wav2Vec2.ipynb) |
-| [Detect objects in an image with DETR](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_minimal_example_(with_DetrFeatureExtractor).ipynb) | كيفية استخدام نموذج *DetrForObjectDetection* المدرب للكشف عن الأجسام في صورة وتصوير الانتباه | [Niels Rogge](https://github.com/NielsRogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_minimal_example_(with_DetrFeatureExtractor).ipynb) |
-| [Fine-tune DETR on a custom object detection dataset](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DETR/Fine_tuning_DetrForObjectDetection_on_custom_dataset_(balloon).ipynb) | كيفية ضبط نموذج *DetrForObjectDetection* على مجموعة بيانات الكشف عن الأجسام المخصصة | [Niels Rogge](https://github.com/NielsRogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/Fine_tuning_DetrForObjectDetection_on_custom_dataset_(balloon).ipynb) |
-| [Finetune T5 for Named Entity Recognition](https://github.com/ToluClassics/Notebooks/blob/main/T5_Ner_Finetuning.ipynb) | كيفية ضبط نموذج *T5* على مهمة التعرف على الكيانات المسماة | [Ogundepo Odunayo](https://github.com/ToluClassics) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1obr78FY_cBmWY5ODViCmzdY6O1KB65Vc?usp=sharing) |
-| [Fine-Tuning Open-Source LLM using QLoRA with MLflow and PEFT](https://github.com/mlflow/mlflow/blob/master/docs/source/llms/transformers/tutorials/fine-tuning/transformers-peft.ipynb) | كيفية استخدام [QLoRA](https://github.com/artidoro/qlora) و [PEFT](https://huggingface.co/docs/peft/en/index) لضبط نموذج LLM بطريقة فعالة من حيث الذاكرة، مع استخدام [MLflow](https://mlflow.org/docs/latest/llms/transformers/index.html) لإدارة تتبع التجارب | [Yuki Watanabe](https://github.com/B-Step62) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlflow/mlflow/blob/master/docs/source/llms/transformers/tutorials/fine-tuning/transformers-peft.ipynb) |
diff --git a/docs/source/ar/how_to_hack_models.md b/docs/source/ar/how_to_hack_models.md
deleted file mode 100644
index 8ce3589732f06a..00000000000000
--- a/docs/source/ar/how_to_hack_models.md
+++ /dev/null
@@ -1,163 +0,0 @@
-# كيفية تعديل أي نموذج من نماذج Transformers
-
-توفر مكتبة [🤗 Transformers](https://github.com/huggingface/transformers) مجموعة من النماذج المسبقة التدريب والأدوات لمعالجة اللغات الطبيعية، والرؤية، وما إلى ذلك. على الرغم من أن هذه النماذج تغطي مجموعة واسعة من التطبيقات، فقد تواجه حالات استخدام لا تدعمها المكتبة بشكل افتراضي. يُمكن للتخصيص أن يفتح إمكانيات جديدة، مثل إضافة طبقات جديدة، أو تعديل البنية المعمارية، أو تحسين آليات الانتباه. سيُوضح لك هذا الدليل كيفية تعديل نماذج Transformers الموجودة لتلبية احتياجاتك المحددة. الشيء الرائع هو أنك لست بحاجة إلى الخروج من إطار عمل Transformers لإجراء هذه التغييرات. ي يمكنك تعديل النماذج مباشرةً في Transformers والاستفادة من الميزات مثل [واجهة برمجة التطبيقات Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer)، و [PreTrainedModel](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel)، والضبط الدقيق الفعال باستخدام أدوات مثل [PEFT](https://huggingface.co/docs/peft/index).
-
-سنرشدك في هذا الدليل لكيفية تخصيص نماذج Transformers الموجودة لتلبية متطلباتك، دون فقدان مزايا الإطار. ستتعلم كيفية:
-
-- تعديل بنية نموذج ما من خلال تغيير آلية الانتباه الخاصة به.
-- تطبيق تقنيات مثل Low-Rank Adaptation (LoRA) على مكونات نموذج محددة.
-
-نحن نشجعك على المساهمة باختراقاتك الخاصة ومشاركتها هنا مع المجتمع1
-
-## مثال: تعديل آلية الانتباه في نموذج Segment Anything (SAM)
-
-نموذج **Segment Anything (SAM)** هو نموذج رائد في مجال تجزئة الصور. في تنفيذه الافتراضي، يستخدم SAM إسقاطًا مجمعًا للاستعلام والمفتاح والقيمة (`qkv`) في آلية الانتباه الخاصة به. ومع ذلك، قد ترغب في ضبط مكونات محددة فقط من آلية الانتباه، مثل إسقاطات الاستعلام (`q`) والقيمة (`v`)، لتقليل عدد المعلمات القابلة للتدريب والموارد الحسابية المطلوبة.
-
-### الدافع
-
-من خلال تقسيم الإسقاط المجمع `qkv` إلى إسقاطات منفصلة `q` و `k` و `v`، يمكنك تطبيق تقنيات مثل **LoRA** (Low-Rank Adaptation) على إسقاطي `q` و `v` فقط. يسمح لك هذا بما يلي:
-
-- ضبط عدد أقل من المعلمات، مما يقلل من العبء الحسابي.
-- تحقيق أداء أفضل من خلال التركيز على مكونات محددة.
-- تجربة استراتيجيات تعديل مختلفة في آلية الانتباه.
-
-### التنفيذ
-
-#### **الخطوة 1: إنشاء فئة اهتمام مخصصة**
-
-بعد ذلك، قم بإنشاء فئة فرعية من فئة `SamVisionAttention` الأصلية وعدلها لتضم إسقاطات `q` و `k` و `v` منفصلة.
-
-```python
-import torch
-import torch.nn as nn
-from transformers.models.sam.modeling_sam import SamVisionAttention
-
-class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
- def __init__(self, config, window_size):
- super().__init__(config, window_size)
- del self.qkv
- # إسقاطات منفصلة q و k و v
- self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
- self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
- self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
- self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
-
- def split_q_k_v_load_hook(self, state_dict, prefix, *args):
- keys_to_delete = []
- for key in list(state_dict.keys()):
- if "qkv." in key:
- # تقسيم q و k و v من الإسقاط المجمع
- q, k, v = state_dict[key].chunk(3, dim=0)
- # استبدال الإسقاطات الفردية q و k و v
- state_dict[key.replace("qkv.", "q.")] = q
- state_dict[key.replace("qkv.", "k.")] = k
- state_dict[key.replace("qkv.", "v.")] = v
- # وضع علامة على مفتاح qkv القديم للحذف
- keys_to_delete.append(key)
-
- # حذف مفاتيح qkv القديمة
- for key in keys_to_delete:
- del state_dict[key]
-
- def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
- batch_size, height, width, _ = hidden_states.shape
- qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
- query = self.q(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
- key = self.k(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
- value = self.v(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
-
- attn_weights = (query * self.scale) @ key.transpose(-2, -1)
-
- if self.use_rel_pos:
- attn_weights = self.add_decomposed_rel_pos(
- attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
- )
-
- attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
- attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
- attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
- attn_output = self.proj(attn_output)
-
- if output_attentions:
- outputs = (attn_output, attn_weights)
- else:
- outputs = (attn_output, None)
- return outputs
-```
-
-**الشرح:**
-
-- **الإسقاطات المنفصلة:** يتم إزالة الإسقاط المُجمع `qkv`، وإنشاء إسقاطات خطية منفصلة `q` و `k` و `v`.
-- **دالة استدعاء تحميل الأوزان:** تقوم طريقة `_split_qkv_load_hook` بتقسيم أوزان `qkv` المسبقة التدريب إلى أوزان `q` و `k` و `v` منفصلة عند تحميل النموذج. يضمن هذا التوافق مع أي نموذج مسبق التدريب.
-- **التنفيذ الأمامي:** يتم حساب الاستعلامات والمفاتيح والقيم بشكل منفصل، وتستمر آلية الانتباه كالمعتاد.
-
-#### **الخطوة 2: استبدال فئة الانتباه الأصلية**
-
-استبدل فئة `SamVisionAttention` الأصلية بفئتك المخصصة بحيث يستخدم النموذج آلية الانتباه المعدلة.
-
-```python
-from transformers import SamModel
-from transformers.models.sam import modeling_sam
-
-# استبدال فئة الاهتمام في وحدة نمطية modeling_sam
-modeling_sam.SamVisionAttention = SamVisionAttentionSplit
-
-# تحميل نموذج SAM المسبق التدريب
-model = SamModel.from_pretrained("facebook/sam-vit-base")
-```
-
-**الشرح:**
-
-- **استبدال الفئة:** من خلال تعيين فئتك المخصصة إلى `modeling_sam.SamVisionAttention`، فإن أي حالات من فئة `SamVisionAttention` في النموذج ستستخدم النسخة المعدلة. وبالتالي، عند استدعاء `SamModel`، سيتم استخدام `SamVisionAttentionSplit` المحددة حديثًا.
-- **تحميل النموذج:** يتم تحميل النموذج باستخدام `from_pretrained`، ويتم دمج آلية الانتباه المخصصة.
-
-#### **الخطوة 3: تطبيق LoRA على إسقاطات محددة**
-
-مع وجود إسقاطات `q` و `k` و `v` منفصلة، يمكنك الآن تطبيق LoRA على مكونات محددة، مثل إسقاطات `q` و `v`.
-
-```python
-from peft import LoraConfig, get_peft_model
-
-config = LoraConfig(
- r=16,
- lora_alpha=32,
- target_modules=["q", "v"], # تطبيق LoRA على إسقاطات q و v
- lora_dropout=0.1,
- task_type="mask-generation"
-)
-
-# تطبيق LoRA على النموذج
-model = get_peft_model(model, config)
-```
-
-**الشرح:**
-
-- **تكوين LoRA:** تحدد `LoraConfig` المرتبة `r`، وعامل القياس `lora_alpha`، والوحدات المستهدفة (`"q"` و `"v"`)، ومعدل التخلي، ونوع المهمة.
-- **تطبيق LoRA:** تقوم دالة `get_peft_model` بتطبيق LoRA على الوحدات المحددة في النموذج.
-- **تقليل المعلمات:** من خلال التركيز على `q` و `v`، فإنك تقلل عدد المعلمات القابلة للتدريب، مما يؤدي إلى تسريع التدريب وتقليل استخدام الذاكرة.
-
-#### **الخطوة 4: التحقق من عدد المعلمات القابلة للتدريب**
-
-من السهل التحقق من عدد المعلمات القابلة للتدريب ومعرفة تأثير تعديلك.
-
-```python
-model.print_trainable_parameters()
-```
-
-**الناتج المتوقع:**
-
-```
-عدد المعلمات القابلة للتدريب: 608,256 || جميع المعلمات: 94,343,728 || نسبة المعلمات القابلة للتدريب: 0.6447
-عدد المعلمات القابلة للتدريب: 912,384 || جميع المعلمات: 94,647,856 || نسبة المعلمات القابلة للتدريب: 0.9640 # مع k
-```
-
-## المساهمة بابداعاتك الخاصة
-
-يمكن لتعديل النماذج المسبقة التدريب أن يفتح آفاقًا جديدة للبحث والتطبيق. من خلال فهم وتعديل الآليات الداخلية للنماذج مثل SAM، يمكنك تخصيصها لتلبية احتياجاتك المحددة، وتحسين الأداء، وتجربة أفكار جديدة.
-
-إذا قمت بتطوير تعديﻻتك الخاصة لنماذج Transformers وترغب في مشاركتها، ففكر في المساهمة في هذه الوثيقة.
-
-- **إنشاء طلب سحب (Pull Request):** شارك تغييراتك وتحسيناتك في التعليمات البرمجية مباشرة في المستودع.
-- **كتابة التوثيق:** قدم تفسيرات وأمثلة واضحة لتعديلاتك.
-- **التفاعل مع المجتمع:** ناقش أفكارك واحصل على تعليقات من المطورين والباحثين الآخرين من خلال فتح مشكلة.
diff --git a/docs/source/ar/modular_transformers.md b/docs/source/ar/modular_transformers.md
deleted file mode 100644
index b500fec1c92d25..00000000000000
--- a/docs/source/ar/modular_transformers.md
+++ /dev/null
@@ -1,184 +0,0 @@
-# المحولات النمطية
-
-مكتبة `transformers` هي إطار عمل ذو فلسفة محدد؛ يتم تعريف فلسفتنا في [الدليل المفاهيمي](./philosophy).
-
-جوهر هذه الفلسفة يتمثل في مبدأ [نموذج واحد، ملف واحد](https://huggingface.co/blog/transformers-design-philosophy)
-في المكتبة. الجانب السلبي لهذا المكون هو تقييده لوراثة واستيراد مكونات الملفات.
-
-نتيجة لذلك، تتكرر مكونات النموذج عبر العديد من الملفات. يحتوي `transformers` على عدد كبير من طبقات الانتباه، يقارب عدد النماذج، والكثير منها متطابق. يتسبب هذا في تباعد عمليات التنفيذ المستقلة مع تطبيق الإصلاحات والتغييرات.
-على أجزاء محددة من التعليمات البرمجية.
-
-ولمعالجة ذلك، اعتمدنا مفهوم "النسخ" في المكتبة. فبإضافة تعليق يُشير إلى أن التعليمات البرمجية هي نسخة من أخرى، نضمن من خلال أنظمة CI والأوامر المحلية عدم تباعد النسخ. لكن هذه العملية، رغم بساطتها، تُسبب إرهاقاً. كما أنها تزيد العبء على المساهمين، وهو ما نهدف إلى تجاوزه.
-
-غالباً ما تتطلب مساهمات النماذج إضافة تعليمات برمجية (حوالي 1000 سطر)، ومعالج (حوالي 500 سطر)، واختبارات، ووثائق، إلخ. ونادراً ما تقل مساهمات النماذج عن 3000-5000 سطر من التعليمات البرمجية، معظمها أكواد نمطية. هذا يرفع مستوى المساهمات،
-
-ونهدف مع المحولات النمطية إلى خفض هذا المستوى إلى حدّ مقبول.
-
-## ما هو؟
-
-تقدم المحولات النمطية مفهوم ملف "نمطي" لمجلد نموذج. يقبل هذا الملف النمطي تعليمات برمجية
-غير مقبولة عادة في ملفات النمذجة/المعالجة، حيث يسمح بالاستيراد من نماذج مجاورة وكذلك
-الوراثة من الفئات إلى فئات أخرى.
-
-يعرّف هذا الملف النمطي النماذج والمعالجات وفئة التكوين التي سيتم تعريفها في وحداتهم
-المتعلقة.
-
-وأخيرًا، يقدم هذا الميزة أداة `linter` جديدة والتي ستعمل على "تفكيك" الملف النمطي إلى بنية "نموذج واحد، ملف واحد"
-هيكل الدليل. سيتم إنشاء هذه الملفات تلقائيًا في كل مرة يتم فيها تشغيل البرنامج النصي؛ مما يقلل من المساهمات المطلوبة
-إلى الملف النمطي، وبالتالي فقط إلى التغييرات بين النموذج المساهم والنماذج الأخرى.
-
-سيقوم مستخدمو النموذج في النهاية باستيراد واستخدام واجهة الملف الواحد، لذا لا يتوقع حدوث أي تغيير هنا. من خلال القيام بذلك،
-نأمل في الجمع بين أفضل ما في العالمين: تمكين المساهمات البسيطة مع الالتزام بفلسفتنا.
-
-لذلك، هذا بديل لعلامات `# Copied from`، ويمكن توقع انتقال النماذج المساهمة سابقًا إلى
-تنسيق المحولات النمطية الجديد في الأشهر المقبلة.
-
-### التفاصيل
-
-تُبسط أداة "linter" الوراثة، مُنشئةً جميع الملفات المفردة من الملف النمطي، مع الحفاظ على شفافيتها أمام مستخدمي Python. حاليًا، تُبسط الأداة مستوىً واحدًا من الوراثة
-
-على سبيل المثال:
-- إذا ورثت فئة التكوين من فئة أخرى وأضافت/حذفت معامل، فسيتم إما الإشارة إلى الملف المولد مباشرةً
- (في حالة الإضافة) أو إزالته تمامًا (في حالة الحذف).
-- إذا ورثت فئة من فئة أخرى، على سبيل المثال: `class GemmaModel(LlamaModel):`، تُستنتج التبعيات تلقائيًا
- سيتم استنتاج جميع الوحدات الفرعية تلقائيًا من الفئة الأصلية.
-- إذا قمت بتعريف وظائف جديدة في الملف `modular` واستخدمتها داخل الفئات، فستستنتج أداة linter ذلك تلقائيًا
-
-يجب أن تكون قادرًا على كتابة كل شيء (المجزىء اللغوي، ومُعالِج الصور، والنموذج، والتكوين) في الملف `modular`، وسيتم إنشاء الملفات المُقابلة تلقائيًا.
-
-### التطبيق
-
-[TODO] نقدم اختبارًا جديدًا، للتأكد من أن المحتوى المولد يتطابق مع ما هو موجود في `modular_xxxx.py`
-
-### الأمثلة
-
-هنا مثال سريع باستخدام BERT و RoBERTa. النموذجان مرتبطان ارتباطًا وثيقًا: يختلف تنفيذهما النموذجي في طبقة تضمين.
-
-بدلاً من إعادة تعريف النموذج بالكامل، إليك كيف يبدو ملف `modular_roberta.py` لفئات النمذجة والتكوين (لأغراض المثال، يتم تجاهل المجزىء اللغوي في هذا الوقت حيث أنه مختلف جدًا).
-
-```python
-from torch import nn
-from ..bert.configuration_bert import BertConfig
-from ..bert.modeling_bert import (
- BertModel,
- BertEmbeddings,
- BertForMaskedLM
-)
-
-# تكوين RoBERTa مطابق لتكوين BERT
-class RobertaConfig(BertConfig):
- model_type = 'roberta'
-
-# نعيد تعريف الإضافات هنا لتسليط الضوء على اختلاف معرف الحشو، ونعيد تعريف الإضافات الموضعية
-class RobertaEmbeddings(BertEmbeddings):
- def __init__(self, config):
- super().__init__(config())
-
- self.padding_idx = config.pad_token_id
- self.position_embeddings = nn.Embedding(
- config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
- )
-
-# نموذج RoBERTa مطابق لنموذج BERT، باستثناء طبقة الإضافات.
-# نعيد تعريف الإضافات أعلاه، لذا هنا لا توجد حاجة لعمل إضافي
-class RobertaModel(BertModel):
- def __init__(self, config):
- super().__init__(config)
- self.embeddings = RobertaEmbeddings(config)
-
-
-# الرؤوس الآن تحتاج فقط إلى إعادة تعريف النموذج داخل `RobertaModel` الصحيح
-class RobertaForMaskedLM(BertForMaskedLM):
- def __init__(self, config):
- super().__init__(config)
- self.model = RobertaModel(config)
-```
-
-لاحظ أنه إذا لم تستخدم الاعتماد الذي حددته، فستحصل على الخطأ التالي:
-
-```bash
-ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used
- when you define `BertModel`, as it is one of it's direct dependencies. Make sure
- you use it in the `__init__` function.
-```
-
-بالإضافة إلى ذلك، قد تجد قائمة بالأمثلة هنا:
-
-## ما هو ليس كذلك
-
-ليس بديلاً لتعليمات برمجة النمذجة (بعد؟)، وإذا لم يكن نموذجك يعتمد على أي شيء آخر موجود من قبل، فيمكنك إضافة ملف `نمذجة` كالعادة.
-
-
-## الاستخدام المتقدم
-
-### إزالة السمات والوظائف
-لإزالة السمات التي لا تستخدم في نموذجك النمطي، والتي لا تريد رؤيتها في النمذجة المفككة:
-
-```python
-class GemmaModel(LlamaModel): | class GemmaModel(PreTrainedModel):
- def __init__(self, config): | def __init__(self, config):
- super().__init__(self, eos_token) | super().__init__(config)
- del self.embed_tokens | self.padding_idx = config.pad_token_id
- | self.vocab_size = config.vocab_size
- |
- | self.layers = nn.ModuleList(
- | [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- | )
- | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- | self.rotary_emb = LlamaRotaryEmbedding(config=config)
- | self.gradient_checkpointing = False
- |
- | # Initialize weights and apply final processing
- | self.post_init()
-```
-إذا قمت بالتحقق من `LlamaModel` الأصلي، فستجد `embed_tokens` الذي تمت إزالته هنا (كما هو متوقع!)
-
-إزالة وظيفة مشابهة، تحتاج فقط إلى كتابتها مع `raise ValueError("")` لمحاكاة السلوك الذي تريده فعليًا عند إزالة وظيفة أصلية في بايثون.
-
-```python
-class GemmaTokenizer(LlamaTokenizer):
- ...
-
- def get_spm_processor(self):
- raise AttributeError("Not needed for Gemma")
-
- def unk_token_length(self):
- raise AttributeError("Not needed for Gemma")
-```
-
-### تعريف وظائف جديدة
-
-إذا قمت بتعريف وظيفة جديدة في الملف `modular` لاستخدامها داخل فئة، على سبيل المثال
-
-```python
-def my_new_function(*args, **kwargs):
- # Do something here
- pass
-
-class GemmaModel(LlamaModel):
- def forward(*args, **kwargs):
- # Call the function
- example = my_new_function(*args, **kwargs)
- # continue here
-```
-
-سيتم نسخ وظيفة `my_new_function` (وبشكل متكرر، أي وظائف أخرى جديدة يتم استدعاؤها في جسمها) تلقائيًا
-في الملف الذي يتم استخدامه.
-
-### استدعاء `super()`
-قمنا مؤخرًا بشحن بعض الميزات التي تسمح لك بالانتقال من:
-```python
-class GemmaTokenizer(LlamaTokenizer, PretrainedTokenizerFast): | class GemmaModel(nn.Module):
- def __init__(self, eos_token=""): | def __init__(self):
- eos_token = AddedToken(eos_token) | eos_token = AddedToken(eos_token)
- PretrainedTokenizerFast.__init__(self, eos_token) | super().__init__(eos_token)
-```
-هذا مفيد عندما لا تريد تفكيك استدعاء `super()`، وتريد التمييز بين أي استدعاء super init تقوم به!
-
-### التسمية الخاصة
-ندعم الآن أيضًا حالات خاصة مثل
-```python
-class GemmaVisionModel(CLIPModel):
- pass
-```
-حيث اسم فئة `GemmaVision` الخاصة بك ليس هو نفسه `Gemma` النمطي. هذا مفيد للغاية للنماذج المركبة.
diff --git a/docs/source/ar/quicktour.md b/docs/source/ar/quicktour.md
index 1795c3a5d74fcc..9a99c28287d622 100644
--- a/docs/source/ar/quicktour.md
+++ b/docs/source/ar/quicktour.md
@@ -347,8 +347,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -356,8 +356,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/ar/tiktoken.md b/docs/source/ar/tiktoken.md
deleted file mode 100644
index 6f3755d8670cdc..00000000000000
--- a/docs/source/ar/tiktoken.md
+++ /dev/null
@@ -1,41 +0,0 @@
-# Tiktoken والتفاعل مع Transformers
-
-يتم دمج دعم ملفات نموذج tiktoken بسلاسة في 🤗 transformers عند تحميل النماذج
-`from_pretrained` مع ملف `tokenizer.model` tiktoken على Hub، والذي يتم تحويله تلقائيًا إلى [المحلل اللغوي السريع](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizerFast).
-
-### النماذج المعروفة التي تم إصدارها مع `tiktoken.model`:
- - gpt2
- - llama3
-
-## مثال على الاستخدام
-
-من أجل تحميل ملفات `tiktoken` في `transformers`، تأكد من أن ملف `tokenizer.model` هو ملف tiktoken وسيتم تحميله تلقائيًا عند التحميل `from_pretrained`. إليك كيفية تحميل مجزىء لغوي ونموذج، والذي
-يمكن تحميله من نفس الملف بالضبط:
-
-```py
-from transformers import AutoTokenizer
-
-model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="original")
-```
-## إنشاء مجزىء لغوي tiktoken
-
-لا يحتوي ملف `tokenizer.model` على أي معلومات حول الرموز أو الأنماط الإضافية. إذا كانت هذه الأمور مهمة، قم بتحويل المحلل اللغوي إلى `tokenizer.json`، وهو التنسيق المناسب لـ [`PreTrainedTokenizerFast`].
-
-قم بتوليد ملف `tokenizer.model` باستخدام [tiktoken.get_encoding](https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/registry.py#L63) ثم قم بتحويله إلى `tokenizer.json` باستخدام [`convert_tiktoken_to_fast`].
-
-```py
-
-from transformers.integrations.tiktoken import convert_tiktoken_to_fast
-from tiktoken import get_encoding
-
-# يمكنك تحميل ترميزك المخصص أو الترميز الذي توفره OpenAI
-encoding = get_encoding("gpt2")
-convert_tiktoken_to_fast(encoding, "config/save/dir")
-```
-
-يتم حفظ ملف `tokenizer.json` الناتج في الدليل المحدد ويمكن تحميله باستخدام [`PreTrainedTokenizerFast`].
-
-```py
-tokenizer = PreTrainedTokenizerFast.from_pretrained("config/save/dir")
-```
diff --git a/docs/source/de/quicktour.md b/docs/source/de/quicktour.md
index c01609207fec2a..01cd7200750c4d 100644
--- a/docs/source/de/quicktour.md
+++ b/docs/source/de/quicktour.md
@@ -109,7 +109,7 @@ label: NEGATIVE, with score: 0.5309
Die [`pipeline`] kann auch über einen ganzen Datensatz iterieren. Starten wir mit der Installation der [🤗 Datasets](https://huggingface.co/docs/datasets/) Bibliothek:
```bash
-pip install datasets
+pip install datasets
```
Erstellen wir eine [`pipeline`] mit der Aufgabe die wir lösen und dem Modell welches wir nutzen möchten.
@@ -191,7 +191,7 @@ Wenn Sie kein Modell für Ihren Anwendungsfall finden können, müssen Sie ein v
-Unter der Haube arbeiten die Klassen [`AutoModelForSequenceClassification`] und [`AutoTokenizer`] zusammen, um die [`pipeline`] zu betreiben. Eine [`AutoClass`](./model_doc/auto) ist eine Abkürzung, die automatisch die Architektur eines trainierten Modells aus dessen Namen oder Pfad abruft. Sie müssen nur die passende `AutoClass` für Ihre Aufgabe und den zugehörigen Tokenizer mit [`AutoTokenizer`] auswählen.
+Unter der Haube arbeiten die Klassen [`AutoModelForSequenceClassification`] und [`AutoTokenizer`] zusammen, um die [`pipeline`] zu betreiben. Eine [`AutoClass`](./model_doc/auto) ist eine Abkürzung, die automatisch die Architektur eines trainierten Modells aus dessen Namen oder Pfad abruft. Sie müssen nur die passende `AutoClass` für Ihre Aufgabe und den zugehörigen Tokenizer mit [`AutoTokenizer`] auswählen.
Kehren wir zu unserem Beispiel zurück und sehen wir uns an, wie Sie die `AutoClass` verwenden können, um die Ergebnisse der [`pipeline`] zu replizieren.
@@ -281,7 +281,7 @@ Jetzt können Sie Ihren vorverarbeiteten Stapel von Eingaben direkt an das Model
```
Das Modell gibt die endgültigen Aktivierungen in dem Attribut "logits" aus. Wenden Sie die Softmax-Funktion auf die "logits" an, um die Wahrscheinlichkeiten zu erhalten:
-
+
```py
>>> from torch import nn
@@ -308,7 +308,7 @@ In der [Aufgabenzusammenfassung](./task_summary) steht, welche [AutoModel]-Klass
Jetzt können Sie Ihren vorverarbeiteten Stapel von Eingaben direkt an das Modell übergeben, indem Sie die Wörterbuchschlüssel direkt an die Tensoren übergeben:
-
+
```py
>>> tf_outputs = tf_model(tf_batch)
```
@@ -383,8 +383,8 @@ Ein besonders cooles 🤗 Transformers-Feature ist die Möglichkeit, ein Modell
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -392,8 +392,8 @@ Ein besonders cooles 🤗 Transformers-Feature ist die Möglichkeit, ein Modell
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index e5d1f95a9da0d1..3efa1bb317fb06 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -167,8 +167,6 @@
title: AWQ
- local: quantization/aqlm
title: AQLM
- - local: quantization/vptq
- title: VPTQ
- local: quantization/quanto
title: Quanto
- local: quantization/eetq
@@ -324,8 +322,6 @@
sections:
- local: model_doc/albert
title: ALBERT
- - local: model_doc/bamba
- title: Bamba
- local: model_doc/bart
title: BART
- local: model_doc/barthez
@@ -366,8 +362,6 @@
title: CodeLlama
- local: model_doc/cohere
title: Cohere
- - local: model_doc/cohere2
- title: Cohere2
- local: model_doc/convbert
title: ConvBERT
- local: model_doc/cpm
@@ -402,8 +396,6 @@
title: ESM
- local: model_doc/falcon
title: Falcon
- - local: model_doc/falcon3
- title: Falcon3
- local: model_doc/falcon_mamba
title: FalconMamba
- local: model_doc/fastspeech2_conformer
@@ -502,8 +494,6 @@
title: mLUKE
- local: model_doc/mobilebert
title: MobileBERT
- - local: model_doc/modernbert
- title: ModernBert
- local: model_doc/mpnet
title: MPNet
- local: model_doc/mpt
@@ -717,8 +707,6 @@
title: Swin2SR
- local: model_doc/table-transformer
title: Table Transformer
- - local: model_doc/timm_wrapper
- title: Timm Wrapper
- local: model_doc/upernet
title: UperNet
- local: model_doc/van
@@ -844,8 +832,6 @@
title: CLIPSeg
- local: model_doc/clvp
title: CLVP
- - local: model_doc/colpali
- title: ColPali
- local: model_doc/data2vec
title: Data2Vec
- local: model_doc/deplot
diff --git a/docs/source/en/add_new_pipeline.md b/docs/source/en/add_new_pipeline.md
index e8234c565b26c8..1e5b95e9b48cfc 100644
--- a/docs/source/en/add_new_pipeline.md
+++ b/docs/source/en/add_new_pipeline.md
@@ -184,7 +184,7 @@ class PairClassificationPipeline(Pipeline):
```
The implementation is framework agnostic, and will work for PyTorch and TensorFlow models. If we have saved this in
-a file named `pair_classification.py`, we can then import it and register it like this.
+a file named `pair_classification.py`, we can then import it and register it like this:
```py
from pair_classification import PairClassificationPipeline
@@ -199,22 +199,6 @@ PIPELINE_REGISTRY.register_pipeline(
)
```
-The [register_pipeline](https://github.com/huggingface/transformers/blob/9feae5fb0164e89d4998e5776897c16f7330d3df/src/transformers/pipelines/base.py#L1387) function registers the pipeline details (task type, pipeline class, supported backends) to a models `config.json` file.
-
-```json
- "custom_pipelines": {
- "pair-classification": {
- "impl": "pair_classification.PairClassificationPipeline",
- "pt": [
- "AutoModelForSequenceClassification"
- ],
- "tf": [
- "TFAutoModelForSequenceClassification"
- ],
- }
- },
-```
-
Once this is done, we can use it with a pretrained model. For instance `sgugger/finetuned-bert-mrpc` has been
fine-tuned on the MRPC dataset, which classifies pairs of sentences as paraphrases or not.
diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md
index 0108cb48e95cee..1bdf05a26c8d08 100644
--- a/docs/source/en/chat_templating.md
+++ b/docs/source/en/chat_templating.md
@@ -683,7 +683,7 @@ one is a little simplified from the actual one!
```
{%- for message in messages %}
- {{- '<|' + message['role'] + '|>\n' }}
+ {{- '<|' + message['role'] + |>\n' }}
{{- message['content'] + eos_token }}
{%- endfor %}
{%- if add_generation_prompt %}
@@ -1116,4 +1116,4 @@ name to be included in the tool response, then rendering it can be as simple as:
```
Again, remember that the actual formatting and special tokens are model-specific - you should take a lot of care
-to ensure that tokens, whitespace and everything else exactly match the format your model was trained with!
+to ensure that tokens, whitespace and everything else exactly match the format your model was trained with!
\ No newline at end of file
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index d9c14ead608c8a..ad4d6db0dc3152 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -66,7 +66,6 @@ Flax), PyTorch, and/or TensorFlow.
| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ |
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
| [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ |
-| [Bamba](model_doc/bamba) | ✅ | ❌ | ❌ |
| [Bark](model_doc/bark) | ✅ | ❌ | ❌ |
| [BART](model_doc/bart) | ✅ | ✅ | ✅ |
| [BARThez](model_doc/barthez) | ✅ | ✅ | ✅ |
@@ -100,8 +99,6 @@ Flax), PyTorch, and/or TensorFlow.
| [CodeGen](model_doc/codegen) | ✅ | ❌ | ❌ |
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ |
| [Cohere](model_doc/cohere) | ✅ | ❌ | ❌ |
-| [Cohere2](model_doc/cohere2) | ✅ | ❌ | ❌ |
-| [ColPali](model_doc/colpali) | ✅ | ❌ | ❌ |
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |
@@ -143,7 +140,6 @@ Flax), PyTorch, and/or TensorFlow.
| [ESM](model_doc/esm) | ✅ | ✅ | ❌ |
| [FairSeq Machine-Translation](model_doc/fsmt) | ✅ | ❌ | ❌ |
| [Falcon](model_doc/falcon) | ✅ | ❌ | ❌ |
-| [Falcon3](model_doc/falcon3) | ✅ | ❌ | ✅ |
| [FalconMamba](model_doc/falcon_mamba) | ✅ | ❌ | ❌ |
| [FastSpeech2Conformer](model_doc/fastspeech2_conformer) | ✅ | ❌ | ❌ |
| [FLAN-T5](model_doc/flan-t5) | ✅ | ✅ | ✅ |
@@ -233,7 +229,6 @@ Flax), PyTorch, and/or TensorFlow.
| [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ |
| [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ |
| [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ |
-| [ModernBERT](model_doc/modernbert) | ✅ | ❌ | ❌ |
| [Moshi](model_doc/moshi) | ✅ | ❌ | ❌ |
| [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ |
| [MPT](model_doc/mpt) | ✅ | ❌ | ❌ |
@@ -327,7 +322,6 @@ Flax), PyTorch, and/or TensorFlow.
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
-| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
| [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ |
| [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ |
| [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index d8931342ee45f8..a54ac432006a84 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -352,8 +352,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] TextIteratorStreamer
-[[autodoc]] AsyncTextIteratorStreamer
-
## Caches
[[autodoc]] Cache
diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 17ebb841de7a39..e97ace8a625050 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -473,7 +473,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable
Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you're constrained by your GPUs memory. If you aren't limited by your GPU, you don't necessarily need to quantize your model because it can incur a small latency cost (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.
> [!TIP]
-> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.
+> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.
Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it costs to load [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
diff --git a/docs/source/en/main_classes/image_processor.md b/docs/source/en/main_classes/image_processor.md
index cbf6ae95577f70..320916f1ce9421 100644
--- a/docs/source/en/main_classes/image_processor.md
+++ b/docs/source/en/main_classes/image_processor.md
@@ -27,7 +27,6 @@ from transformers import AutoImageProcessor
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", use_fast=True)
```
-Note that `use_fast` will be set to `True` by default in a future release.
When using a fast image processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise.
@@ -43,17 +42,21 @@ images_processed = processor(images, return_tensors="pt", device="cuda")
Here are some speed comparisons between the base and fast image processors for the `DETR` and `RT-DETR` models, and how they impact overall inference time:
-
-
-
-
+
+
+
+
+
+
-
-
-
-
+
+
+
+
+
+
These benchmarks were run on an [AWS EC2 g5.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g5/), utilizing an NVIDIA A10G Tensor Core GPU.
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index 9b500b69374c88..3f44569697777b 100755
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -34,10 +34,6 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] AqlmConfig
-## VptqConfig
-
-[[autodoc]] VptqConfig
-
## AwqConfig
[[autodoc]] AwqConfig
diff --git a/docs/source/en/model_doc/bamba.md b/docs/source/en/model_doc/bamba.md
deleted file mode 100644
index 4ea8475edb885a..00000000000000
--- a/docs/source/en/model_doc/bamba.md
+++ /dev/null
@@ -1,64 +0,0 @@
-
-
-# Bamba
-
-
-## Overview
-
-Bamba-9B is a decoder-only language model based on the [Mamba-2](https://github.com/state-spaces/mamba) architecture and is designed to handle a wide range of text generation tasks. It is trained from scratch using a two-stage training approach. In the first stage, the model is trained on 2 trillion tokens from the Dolma v1.7 dataset. In the second stage, it undergoes additional training on 200 billion tokens, leveraging a carefully curated blend of high-quality data to further refine its performance and enhance output quality.
-
-Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-model-stack/bamba).
-
-## BambaConfig
-
-| Model | Params | # Layers | Hidden Dim. | Attention Heads | GQA | KV Heads | Context Length | Tied Embeddings |
-|-------------------|--------------|----------|-------------|-----------------|-----|----------|----------------|------------------|
-| Bamba | 9B (9.78B) | 32 | 4096 | 32 | Yes | 8 | 4096 | True |
-
-[[autodoc]] BambaConfig
-
-
-
-## BambaForCausalLM
-
-```python
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
-tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")
-
-message = ["Mamba is a snake with following properties "]
-inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
-response = model.generate(**inputs, max_new_tokens=64)
-print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
-```
-
-[[autodoc]] BambaForCausalLM
- - forward
-
-This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
diff --git a/docs/source/en/model_doc/beit.md b/docs/source/en/model_doc/beit.md
index 25b0eafb26a039..f7605ebcdf90d4 100644
--- a/docs/source/en/model_doc/beit.md
+++ b/docs/source/en/model_doc/beit.md
@@ -71,43 +71,6 @@ alt="drawing" width="600"/>
BEiT pre-training. Taken from the original paper.
-### Using Scaled Dot Product Attention (SDPA)
-
-PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
-encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
-[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
-or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
-page for more information.
-
-SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
-`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
-
-```
-from transformers import BeitForImageClassification
-model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
-...
-```
-
-For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
-
-On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.5.1, OS Ubuntu 20.04) with `float16` and
-`microsoft/beit-base-patch16-224` model, we saw the following improvements during training and inference:
-
-#### Training
-
-| num_training_steps | batch_size | image_size | is_cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
-|--------------------|------------|--------------|---------|----------------------------|---------------------------|-------------|----------------------|--------------------|----------------|
-| 50 | 2 | (1048, 640) | True | 0.984 | 0.746 | 31.975 | 6738.915 | 4319.886 | 55.998 |
-
-#### Inference
-
-| Image batch size | Eager (s/iter) | Eager CI, % | Eager memory (MB) | SDPA (s/iter) | SDPA CI, % | SDPA memory (MB) | SDPA speedup | SDPA memory saved (%) |
-|-------------------:|-----------------:|:--------------|--------------------:|----------------:|:-------------|-------------------:|---------------:|----------------------:|
-| 1 | 0.012 | ±0.3% | 3.76657e+08 | 0.011 | ±0.5% | 3.75739e+08 | 1.05 | 0.244 |
-| 4 | 0.013 | ±0.1% | 4.03147e+08 | 0.011 | ±0.2% | 3.90554e+08 | 1.178 | 3.225 |
-| 16 | 0.045 | ±0.1% | 4.96697e+08 | 0.035 | ±0.1% | 4.51232e+08 | 1.304 | 10.076 |
-| 32 | 0.088 | ±0.1% | 6.24417e+08 | 0.066 | ±0.1% | 5.33488e+08 | 1.325 | 17.044 |
-
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BEiT.
diff --git a/docs/source/en/model_doc/cohere2.md b/docs/source/en/model_doc/cohere2.md
deleted file mode 100644
index 33e67d48fb0e8b..00000000000000
--- a/docs/source/en/model_doc/cohere2.md
+++ /dev/null
@@ -1,51 +0,0 @@
-# Cohere
-
-## Overview
-[C4AI Command R7B](https://cohere.com/blog/command-r7b) is an open weights research release of a 7B billion parameter model developed by Cohere and Cohere For AI. It has advanced capabilities optimized for various use cases, including reasoning, summarization, question answering, and code. The model is trained to perform sophisticated tasks including Retrieval Augmented Generation (RAG) and tool use. The model also has powerful agentic capabilities that can use and combine multiple tools over multiple steps to accomplish more difficult tasks. It obtains top performance on enterprise-relevant code use cases. C4AI Command R7B is a multilingual model trained on 23 languages.
-
-The model features three layers with sliding window attention (window size 4096) and ROPE for efficient local context modeling and relative positional encoding. A fourth layer uses global attention without positional embeddings, enabling unrestricted token interactions across the entire sequence.
-
-The model has been trained on 23 languages: English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Arabic, Chinese, Russian, Polish, Turkish, Vietnamese, Dutch, Czech, Indonesian, Ukrainian, Romanian, Greek, Hindi, Hebrew, and Persian.
-
-## Usage tips
-The model and tokenizer can be loaded via:
-
-```python
-# pip install transformers
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-model_id = "CohereForAI/c4ai-command-r7b-12-2024"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id)
-
-# Format message with the command-r chat template
-messages = [{"role": "user", "content": "Hello, how are you?"}]
-input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
-
-gen_tokens = model.generate(
- input_ids,
- max_new_tokens=100,
- do_sample=True,
- temperature=0.3,
-)
-
-gen_text = tokenizer.decode(gen_tokens[0])
-print(gen_text)
-```
-
-## Cohere2Config
-
-[[autodoc]] Cohere2Config
-
-## Cohere2Model
-
-[[autodoc]] Cohere2Model
- - forward
-
-
-## Cohere2ForCausalLM
-
-[[autodoc]] Cohere2ForCausalLM
- - forward
-
-
diff --git a/docs/source/en/model_doc/colpali.md b/docs/source/en/model_doc/colpali.md
deleted file mode 100644
index 3f6b0cbc6613a9..00000000000000
--- a/docs/source/en/model_doc/colpali.md
+++ /dev/null
@@ -1,90 +0,0 @@
-
-
-# ColPali
-
-## Overview
-
-The *ColPali* model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://doi.org/10.48550/arXiv.2407.01449) by **Manuel Faysse***, **Hugues Sibille***, **Tony Wu***, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (* denotes equal contribution). Work lead by ILLUIN Technology.
-
-In our proposed *ColPali* approach, we leverage VLMs to construct efficient multi-vector embeddings directly from document images (“screenshots”) for document retrieval. We train the model to maximize the similarity between these document embeddings and the corresponding query embeddings, using the late interaction method introduced in ColBERT.
-
-Using *ColPali* removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account both the textual and visual content (layout, charts, etc.) of a document.
-
-## Resources
-
-- The *ColPali* arXiv paper can be found [here](https://doi.org/10.48550/arXiv.2407.01449). 📄
-- The official blog post detailing ColPali can be found [here](https://huggingface.co/blog/manu/colpali). 📝
-- The original model implementation code for the ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎
-- Cookbooks for learning to use the transformers-native version of *ColPali*, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚
-
-This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) and [@yonigozlan](https://huggingface.co/yonigozlan).
-
-## Usage
-
-This example demonstrates how to use *ColPali* to embed both queries and images, calculate their similarity scores, and identify the most relevant matches. For a specific query, you can retrieve the top-k most similar images by selecting the ones with the highest similarity scores.
-
-```python
-import torch
-from PIL import Image
-
-from transformers import ColPaliForRetrieval, ColPaliProcessor
-
-model_name = "vidore/colpali-v1.2-hf"
-
-model = ColPaliForRetrieval.from_pretrained(
- model_name,
- torch_dtype=torch.bfloat16,
- device_map="cuda:0", # or "mps" if on Apple Silicon
-).eval()
-
-processor = ColPaliProcessor.from_pretrained(model_name)
-
-# Your inputs (replace dummy images with screenshots of your documents)
-images = [
- Image.new("RGB", (32, 32), color="white"),
- Image.new("RGB", (16, 16), color="black"),
-]
-queries = [
- "What is the organizational structure for our R&D department?",
- "Can you provide a breakdown of last year’s financial performance?",
-]
-
-# Process the inputs
-batch_images = processor(images=images).to(model.device)
-batch_queries = processor(text=queries).to(model.device)
-
-# Forward pass
-with torch.no_grad():
- image_embeddings = model(**batch_images).embeddings
- query_embeddings = model(**batch_queries).embeddings
-
-# Score the queries against the images
-scores = processor.score_retrieval(query_embeddings, image_embeddings)
-```
-
-## ColPaliConfig
-
-[[autodoc]] ColPaliConfig
-
-## ColPaliProcessor
-
-[[autodoc]] ColPaliProcessor
-
-## ColPaliForRetrieval
-
-[[autodoc]] ColPaliForRetrieval
- - forward
diff --git a/docs/source/en/model_doc/data2vec.md b/docs/source/en/model_doc/data2vec.md
index cb1dc675caa55e..517a51ce46a3a4 100644
--- a/docs/source/en/model_doc/data2vec.md
+++ b/docs/source/en/model_doc/data2vec.md
@@ -48,46 +48,6 @@ The original code for vision can be found [here](https://github.com/facebookrese
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.
-### Using Scaled Dot Product Attention (SDPA)
-
-PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
-encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
-[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
-or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
-page for more information.
-
-SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
-`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
-
-The SDPA implementation is currently available for the Data2VecAudio and Data2VecVision models.
-
-```
-from transformers import Data2VecVisionForImageClassification
-model = Data2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base", attn_implementation="sdpa", torch_dtype=torch.float16)
-...
-```
-
-For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
-
-For the Data2VecVision model, on a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.5.1, OS Ubuntu 20.04)
-with `float16` and `facebook/data2vec-vision-base` model, we saw the following improvements during training and
-inference:
-
-#### Training
-
-| num_training_steps | batch_size | image_size | is_cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
-|--------------------|------------|--------------|---------|----------------------------|---------------------------|-------------|----------------------|--------------------|----------------|
-| 50 | 2 | (1048, 640) | True | 0.996 | 0.754 | 32.147 | 6722.198 | 4264.653 | 57.626 |
-
-#### Inference
-
-| Image batch size | Eager (s/iter) | Eager CI, % | Eager memory (MB) | SDPA (s/iter) | SDPA CI, % | SDPA memory (MB) | SDPA speedup | SDPA memory saved |
-|-------------------:|-----------------:|:--------------|--------------------:|----------------:|:-------------|-------------------:|---------------:|--------------------:|
-| 1 | 0.011 | ±0.3% | 3.76143e+08 | 0.01 | ±0.3% | 3.74397e+08 | 1.101 | 0.466 |
-| 4 | 0.014 | ±0.1% | 4.02756e+08 | 0.012 | ±0.2% | 3.91373e+08 | 1.219 | 2.909 |
-| 16 | 0.046 | ±0.3% | 4.96482e+08 | 0.035 | ±0.2% | 4.51017e+08 | 1.314 | 10.081 |
-| 32 | 0.088 | ±0.1% | 6.23903e+08 | 0.067 | ±0.1% | 5.32974e+08 | 1.33 | 17.061 |
-
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Data2Vec.
diff --git a/docs/source/en/model_doc/falcon3.md b/docs/source/en/model_doc/falcon3.md
deleted file mode 100644
index 813533dd7f4d0a..00000000000000
--- a/docs/source/en/model_doc/falcon3.md
+++ /dev/null
@@ -1,29 +0,0 @@
-
-
-# Falcon3
-
-## Overview
-
-Falcon3 represents a natural evolution from previous releases, emphasizing expanding the models' science, math, and code capabilities. This iteration includes five base models: Falcon3-1B-Base, Falcon3-3B-Base, Falcon3-Mamba-7B-Base, Falcon3-7B-Base, and Falcon3-10B-Base. In developing these models, we incorporated several key innovations aimed at improving the models' performances while reducing training costs:
-
-One pre-training: We conducted a single large-scale pretraining run on the 7B model, using 2048 H100 GPU chips, leveraging 14 trillion tokens featuring web, code, STEM, and curated high-quality and multilingual data.
-Depth up-scaling for improved reasoning: Building on recent studies on the effects of model depth, we upscaled the 7B model to a 10B parameters model by duplicating the redundant layers and continuing pre-training with 2TT of high-quality data. This yielded Falcon3-10B-Base which achieves state-of-the-art zero-shot and few-shot performance for models under 13B parameters.
-Knowledge distillation for better tiny models: To provide compact and efficient alternatives, we developed Falcon3-1B-Base and Falcon3-3B-Base by leveraging pruning and knowledge distillation techniques, using less than 100GT of curated high-quality data, thereby redefining pre-training efficiency.
-
-## Resources
-- [Blog post](https://huggingface.co/blog/falcon3)
-- [Models on Huggingface](https://huggingface.co/collections/tiiuae/falcon3-67605ae03578be86e4e87026)
diff --git a/docs/source/en/model_doc/idefics2.md b/docs/source/en/model_doc/idefics2.md
index b9b51082f29e5b..5ad56b7b5c525d 100644
--- a/docs/source/en/model_doc/idefics2.md
+++ b/docs/source/en/model_doc/idefics2.md
@@ -141,7 +141,7 @@ Do note that when training Idefics2 on multi-turn conversations between a user a
## Model optimizations: Flash Attention
-The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
+The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md
index cb2afd25e20bca..9a0cd368a8188f 100644
--- a/docs/source/en/model_doc/ijepa.md
+++ b/docs/source/en/model_doc/ijepa.md
@@ -18,18 +18,13 @@ rendered properly in your Markdown viewer.
## Overview
-The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/abs/2301.08243) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
+The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
I-JEPA is a self-supervised learning method that predicts the representations of one part of an image based on other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which can bias specific tasks, or on filling in pixel-level details, which often leads to less meaningful representations.
The abstract from the paper is the following:
This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image- based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample tar- get blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transform- ers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.
-
-
- I-JEPA architecture. Taken from the original paper.
-
This model was contributed by [jmtzt](https://huggingface.co/jmtzt).
The original code can be found [here](https://github.com/facebookresearch/ijepa).
@@ -50,7 +45,7 @@ url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
image_1 = Image.open(requests.get(url_1, stream=True).raw)
image_2 = Image.open(requests.get(url_2, stream=True).raw)
-model_id = "facebook/ijepa_vith14_1k"
+model_id = "jmtzt/ijepa_vith14_1k"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
@@ -68,15 +63,6 @@ similarity = cosine_similarity(embed_1, embed_2)
print(similarity)
```
-## Resources
-
-A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with I-JEPA.
-
-
-
-- [`IJepaForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
-- See also: [Image classification task guide](../tasks/image_classification)
-
## IJepaConfig
[[autodoc]] IJepaConfig
@@ -89,4 +75,4 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
## IJepaForImageClassification
[[autodoc]] IJepaForImageClassification
- - forward
\ No newline at end of file
+ - forward
diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md
index e883572995e924..dec19ca5ef45db 100644
--- a/docs/source/en/model_doc/llava.md
+++ b/docs/source/en/model_doc/llava.md
@@ -131,7 +131,7 @@ prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=T
prompts = [prompt_1, prompt_2]
# We can simply feed images in the order they have to be used in the text prompt
-inputs = processor(images=[image_stop, image_cats], text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)
+inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)
# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md
index cc3a61aae6c736..f8a149f12b6779 100644
--- a/docs/source/en/model_doc/llava_next_video.md
+++ b/docs/source/en/model_doc/llava_next_video.md
@@ -240,7 +240,7 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-N
### Flash-Attention 2 to speed-up generation
-Additionally, we can greatly speed-up model inference by using [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
+Additionally, we can greatly speed-up model inference by using [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2:
diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md
index cfa2af3678137a..2be657109a8d46 100644
--- a/docs/source/en/model_doc/mistral.md
+++ b/docs/source/en/model_doc/mistral.md
@@ -91,7 +91,7 @@ As can be seen, the instruction-tuned model requires a [chat template](../chat_t
## Speeding up Mistral by using Flash Attention
-The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
+The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
diff --git a/docs/source/en/model_doc/mixtral.md b/docs/source/en/model_doc/mixtral.md
index b5451702e44a16..7afcaa798ecac4 100644
--- a/docs/source/en/model_doc/mixtral.md
+++ b/docs/source/en/model_doc/mixtral.md
@@ -93,7 +93,7 @@ As can be seen, the instruction-tuned model requires a [chat template](../chat_t
## Speeding up Mixtral by using Flash Attention
-The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
+The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
diff --git a/docs/source/en/model_doc/modernbert.md b/docs/source/en/model_doc/modernbert.md
deleted file mode 100644
index b641d7f3f58199..00000000000000
--- a/docs/source/en/model_doc/modernbert.md
+++ /dev/null
@@ -1,95 +0,0 @@
-
-
-# ModernBert
-
-
-
-## Overview
-
-The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.
-
-It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
-
-It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as:
-- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens.
-- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences.
-- [GeGLU](https://arxiv.org/abs/2002.05202) Replacing the original MLP layers with GeGLU layers, shown to improve performance.
-- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers.
-- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing.
-- A model designed following recent [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs.
-- Modern training data scales (2 trillion tokens) and mixtures (including code ande math data)
-
-The abstract from the paper is the following:
-
-*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.*
-
-The original code can be found [here](https://github.com/answerdotai/modernbert).
-
-## Resources
-
-A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ModernBert.
-
-
-
-- A notebook on how to [finetune for General Language Understanding Evaluation (GLUE) with Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb), also available as a Google Colab [notebook](https://colab.research.google.com/github/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb). 🌎
-
-
-
-- A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎
-- A script on how to [finetune for information retrieval with PyLate](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_pylate.py). 🌎
-
-
-
-- [Masked language modeling task guide](../tasks/masked_language_modeling)
-
-
-## ModernBertConfig
-
-[[autodoc]] ModernBertConfig
-
-
-
-
-## ModernBertModel
-
-[[autodoc]] ModernBertModel
- - forward
-
-## ModernBertForMaskedLM
-
-[[autodoc]] ModernBertForMaskedLM
- - forward
-
-## ModernBertForSequenceClassification
-
-[[autodoc]] ModernBertForSequenceClassification
- - forward
-
-## ModernBertForTokenClassification
-
-[[autodoc]] ModernBertForTokenClassification
- - forward
-
-
-
diff --git a/docs/source/en/model_doc/timm_wrapper.md b/docs/source/en/model_doc/timm_wrapper.md
deleted file mode 100644
index 5af3d51746c325..00000000000000
--- a/docs/source/en/model_doc/timm_wrapper.md
+++ /dev/null
@@ -1,67 +0,0 @@
-
-
-# TimmWrapper
-
-## Overview
-
-Helper class to enable loading timm models to be used with the transformers library and its autoclasses.
-
-```python
->>> import torch
->>> from PIL import Image
->>> from urllib.request import urlopen
->>> from transformers import AutoModelForImageClassification, AutoImageProcessor
-
->>> # Load image
->>> image = Image.open(urlopen(
-... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
-... ))
-
->>> # Load model and image processor
->>> checkpoint = "timm/resnet50.a1_in1k"
->>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
->>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
-
->>> # Preprocess image
->>> inputs = image_processor(image)
-
->>> # Forward pass
->>> with torch.no_grad():
-... logits = model(**inputs).logits
-
->>> # Get top 5 predictions
->>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
-```
-
-## TimmWrapperConfig
-
-[[autodoc]] TimmWrapperConfig
-
-## TimmWrapperImageProcessor
-
-[[autodoc]] TimmWrapperImageProcessor
- - preprocess
-
-## TimmWrapperModel
-
-[[autodoc]] TimmWrapperModel
- - forward
-
-## TimmWrapperForImageClassification
-
-[[autodoc]] TimmWrapperForImageClassification
- - forward
diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md
index a3ba1258ecfa06..105307196effd0 100644
--- a/docs/source/en/model_doc/video_llava.md
+++ b/docs/source/en/model_doc/video_llava.md
@@ -174,7 +174,7 @@ model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-L
### Flash-Attention 2 to speed-up generation
-Additionally, we can greatly speed-up model inference by using [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
+Additionally, we can greatly speed-up model inference by using [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
First, make sure to install the latest version of Flash Attention 2:
diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md
index 8eebbf347c11c3..1516233ec4d6e1 100644
--- a/docs/source/en/modular_transformers.md
+++ b/docs/source/en/modular_transformers.md
@@ -22,9 +22,6 @@ etc. Model contribution PRs rarely add less than 3-5k lines of code, with much o
This raises the bar for contributions, and with Modular Transformers, we're aiming to lower the bar to a much more
acceptable point.
-If you plan to add a model to `transformers` make sure you read [How to add a model to 🤗 Transformers?](https://huggingface.co/docs/transformers/add_new_model).
-For any kind of contributions, see [CONTRIBUTING.md](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).
-
## What is it?
Modular Transformers introduces the concept of a "modular" file to a model folder. This modular file accepts code
@@ -46,12 +43,6 @@ be moved to the new Modular Transformers format in the coming months.
### Details
-To generate a single file from the modular file, run the following command.
-
-```bash
-python utils/modular_model_converter.py --files-to-parse src/transformers/models//modular_.py
-```
-
The "linter", which unravels the inheritance and creates all single-files from the modular file, will flatten the
inheritance while trying to be invisible to Python users. At this time, the linter flattens a **single** level of
inheritance.
@@ -68,11 +59,7 @@ file, and the corresponding files will be created for you.
### Enforcement
-Run the command below to ensure the generated content matches `modular_.py`
-
-```bash
-python utils/check_modular_conversion.py --files src/transformers/models//modular_.py
-```
+[TODO] We are introducing a new test, that makes sure the generated content matches what is present in the `modular_xxxx.py`
### Examples
@@ -207,4 +194,4 @@ We now also support special cases like
class GemmaVisionModel(CLIPModel):
pass
```
-where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models.
+where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models.
\ No newline at end of file
diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md
index ea9421747c13df..9975094411527a 100644
--- a/docs/source/en/perf_infer_gpu_multi.md
+++ b/docs/source/en/perf_infer_gpu_multi.md
@@ -64,5 +64,5 @@ You can benefit from considerable speedups for inference, especially for inputs
For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:
-
+
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index a3a3a7c7bc4232..692d7110272b75 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -39,12 +39,10 @@ FlashAttention-2 is experimental and may change considerably in future versions.
FlashAttention-2 is currently supported for the following architectures:
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
-* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
-* [Cohere2](https://huggingface.co/docs/transformers/model_doc/cohere2#transformers.Cohere2Model)
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
@@ -75,7 +73,6 @@ FlashAttention-2 is currently supported for the following architectures:
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
-* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert)
* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
@@ -223,9 +220,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
-* [Bamba](https://huggingface.co/docs/transformers/model_doc/bamba#transformers.BambaModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
-* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [BioGpt](https://huggingface.co/docs/transformers/model_doc/biogpt#transformers.BioGptModel)
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
@@ -233,9 +228,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
-* [Cohere2](https://huggingface.co/docs/transformers/model_doc/cohere2#transformers.Cohere2Model)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
-* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
@@ -268,7 +261,6 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
-* [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert)
* [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index f3508aed0674f6..0fb72d26058e55 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -58,7 +58,6 @@ Use the table below to help you decide which quantization method to use.
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | 🔴 | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
-| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
@@ -72,4 +71,4 @@ We value your feedback to help identify bugs before the full release! Check out
\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships.
-
\ No newline at end of file
+
diff --git a/docs/source/en/quantization/vptq.md b/docs/source/en/quantization/vptq.md
deleted file mode 100644
index b86e82f0a3503d..00000000000000
--- a/docs/source/en/quantization/vptq.md
+++ /dev/null
@@ -1,111 +0,0 @@
-
-
-# VPTQ
-
-> [!TIP]
-> Try VPTQ on [Hugging Face](https://huggingface.co/spaces/microsoft/VPTQ)!
-> Try VPTQ on [Google Colab](https://colab.research.google.com/github/microsoft/VPTQ/blob/main/notebooks/vptq_example.ipynb)!
-> Know more about VPTQ on [ArXiv](https://arxiv.org/pdf/2409.17066)!
-
-Vector Post-Training Quantization ([VPTQ](https://github.com/microsoft/VPTQ)) is a novel Post-Training Quantization method that leverages Vector Quantization to high accuracy on LLMs at an extremely low bit-width (<2-bit). VPTQ can compress 70B, even the 405B model, to 1-2 bits without retraining and maintain high accuracy.
-
-- Better Accuracy on 1-2 bits, (405B @ <2bit, 70B @ 2bit)
-- Lightweight Quantization Algorithm: only cost ~17 hours to quantize 405B Llama-3.1
-- Agile Quantization Inference: low decode overhead, best throughput, and TTFT
-
-Inference support for VPTQ is released in the `vptq` library. Make sure to install it to run the models:
-```bash
-pip install vptq
-```
-
-The library provides efficient kernels for NVIDIA/AMD GPU inference.
-
-To run VPTQ models simply load a model that has been quantized with VPTQ:
-
-## Inference example
-**Run Llama 3.1 70b on RTX4090 (24G @ ~2bits) in real time**
-![Llama3 1-70b-prompt](https://github.com/user-attachments/assets/d8729aca-4e1d-4fe1-ac71-c14da4bdd97f)
-
-
-```python
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-quantized_model = AutoModelForCausalLM.from_pretrained(
- "VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft",
- torch_dtype="auto",
- device_map="auto"
-)
-tokenizer = AutoTokenizer.from_pretrained("VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft")
-input_ids = tokenizer("hello, it's me", return_tensors="pt").to("cuda")
-out = model.generate(**input_ids, max_new_tokens=32, do_sample=False)
-```
-
-## Quantize your own model
-VPTQ algorithm early-released at [VPTQ ](https://github.com/microsoft/VPTQ/tree/algorithm),
-and checkout the [tutorial](https://github.com/microsoft/VPTQ/blob/algorithm/algorithm.md).
-
-## Early Results from Tech Report
-VPTQ achieves better accuracy and higher throughput with lower quantization overhead across models of different sizes. The following experimental results are for reference only; VPTQ can achieve better outcomes under reasonable parameters, especially in terms of model accuracy and inference speed.
-
-
-| Model | bitwidth | W2↓ | C4↓ | AvgQA↑ | tok/s↑ | mem(GB) | cost/h↓ |
-| ----------- | -------- | ---- | ---- | ------ | ------ | ------- | ------- |
-| LLaMA-2 7B | 2.02 | 6.13 | 8.07 | 58.2 | 39.9 | 2.28 | 2 |
-| | 2.26 | 5.95 | 7.87 | 59.4 | 35.7 | 2.48 | 3.1 |
-| LLaMA-2 13B | 2.02 | 5.32 | 7.15 | 62.4 | 26.9 | 4.03 | 3.2 |
-| | 2.18 | 5.28 | 7.04 | 63.1 | 18.5 | 4.31 | 3.6 |
-| LLaMA-2 70B | 2.07 | 3.93 | 5.72 | 68.6 | 9.7 | 19.54 | 19 |
-| | 2.11 | 3.92 | 5.71 | 68.7 | 9.7 | 20.01 | 19 |
-
-
-
-## More Models in [VPTQ-community](https://huggingface.co/VPTQ-community)
-
-⚠️ The repository only provides a method of model quantization algorithm.
-
-⚠️ The open-source community VPTQ-community provides models based on the technical report and quantization algorithm.
-
-
-
-**Quick Estimation of Model Bitwidth (Excluding Codebook Overhead)**:
-
-- **Model Naming Convention**: The model's name includes the **vector length** $v$, **codebook (lookup table) size**, and **residual codebook size**. For example, "Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft" is "Meta-Llama-3.1-70B-Instruct", where:
- - **Vector Length**: 8
- - **Number of Centroids**: 65536 (2^16)
- - **Number of Residual Centroids**: 256 (2^8)
-- **Equivalent Bitwidth Calculation**:
- - **Index**: log2(65536) = 16 / 8 = 2 bits
- - **Residual Index**: log2(256) = 8 / 8 = 1 bit
- - **Total Bitwidth**: 2 + 1 = 3 bits
-- **Model Size Estimation**: 70B * 3 bits / 8 bits per Byte = 26.25 GB
-
-- **Note**: This estimate does not include the size of the codebook (lookup table), other parameter overheads, and the padding overhead for storing indices. For the detailed calculation method, please refer to **Tech Report Appendix C.2**.
-
-
-| Model Series | Collections | (Estimated) Bit per weight |
-| :--------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| Llama 3.1 Nemotron 70B Instruct HF | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-nemotron-70b-instruct-hf-without-finetune-671730b96f16208d0b3fe942) | [4 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v8-k65536-0-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-16384-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-1024-woft) [1.5 bits](https://huggingface.co/VPTQ-community/Llama-3.1-Nemotron-70B-Instruct-HF-v16-k65536-256-woft) |
-| Llama 3.1 8B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-8b-instruct-without-finetune-66f2b70b1d002ceedef02d2e) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-65536-woft) [3.5 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-4096-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft) [2.3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft) |
-| Llama 3.1 70B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-70b-instruct-without-finetune-66f2bf454d3dd78dfee2ff11) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-256-woft) [2.25 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-4-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k65536-0-woft) [1.93 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v16-k65536-32768-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k32768-0-woft) [1.75 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k16384-0-woft) |
-| Llama 3.1 405B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-llama-31-405b-instruct-without-finetune-66f4413f9ba55e1a9e52cfb0) | [4 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k65536-256-woft) [2 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-65536-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k32768-32768-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-1024-woft) [1.5 bits (1)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v8-k4096-0-woft) [1.5 bits (2)](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-256-woft) [1.43 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-128-woft) [1.375 bits](https://huggingface.co/VPTQ-community/Meta-Llama-3.1-405B-Instruct-v16-k65536-64-woft) |
-| Mistral Large Instruct 2407 (123B) | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-mistral-large-instruct-2407-without-finetune-6711ebfb7faf85eed9cceb16) | [4 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v8-k65536-0-woft) [1.875 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-16384-woft) [1.75 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-4096-woft) [1.625 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-1024-woft) [1.5 bits](https://huggingface.co/VPTQ-community/Mistral-Large-Instruct-2407-v16-k65536-256-woft) |
-| Qwen 2.5 7B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-7b-instruct-without-finetune-66f3e9866d3167cc05ce954a) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k256-256-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-7B-Instruct-v16-k65536-65536-woft) |
-| Qwen 2.5 14B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-14b-instruct-without-finetune-66f827f83c7ffa7931b8376c) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k256-256-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-14B-Instruct-v16-k65536-65536-woft) |
-| Qwen 2.5 32B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-32b-instruct-without-finetune-66fe77173bf7d64139f0f613) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-256-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v16-k65536-65536-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k65536-0-woft) [2 bits (3)](https://huggingface.co/VPTQ-community/Qwen2.5-32B-Instruct-v8-k256-256-woft) |
-| Qwen 2.5 72B Instruct | [HF 🤗](https://huggingface.co/collections/VPTQ-community/vptq-qwen-25-72b-instruct-without-finetune-66f3bf1b3757dfa1ecb481c0) | [4 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-65536-woft) [3 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-256-woft) [2.38 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k1024-512-woft) [2.25 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k512-512-woft) [2.25 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-4-woft) [2 bits (1)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v8-k65536-0-woft) [2 bits (2)](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-65536-woft) [1.94 bits](https://huggingface.co/VPTQ-community/Qwen2.5-72B-Instruct-v16-k65536-32768-woft) |
-| Reproduced from the tech report | [HF 🤗](https://huggingface.co/collections/VPTQ-community/reproduced-vptq-tech-report-baseline-66fbf1dffe741cc9e93ecf04) | Results from the open source community for reference only, please use them responsibly. |
-| Hessian and Inverse Hessian Matrix | [HF 🤗](https://huggingface.co/collections/VPTQ-community/hessian-and-invhessian-checkpoints-66fd249a104850d17b23fd8b) | Collected from RedPajama-Data-1T-Sample, following [Quip#](https://github.com/Cornell-RelaxML/quip-sharp/blob/main/quantize_llama/hessian_offline_llama.py)
\ No newline at end of file
diff --git a/docs/source/en/tasks/asr.md b/docs/source/en/tasks/asr.md
index e8884d327b565b..f3e068444ca556 100644
--- a/docs/source/en/tasks/asr.md
+++ b/docs/source/en/tasks/asr.md
@@ -20,12 +20,12 @@ rendered properly in your Markdown viewer.
-Automatic speech recognition (ASR) converts a speech signal to text, mapping a sequence of audio inputs to text outputs. Virtual assistants like Siri and Alexa use ASR models to help users every day, and there are many other useful user-facing applications like live captioning and note-taking during meetings.
+Automatic speech recognition (ASR) converts a speech signal to text, mapping a sequence of audio inputs to text outputs. Virtual assistants like Siri and Alexa use ASR models to help users everyday, and there are many other useful user-facing applications like live captioning and note-taking during meetings.
This guide will show you how to:
-1. Fine-tune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset to transcribe audio to text.
-2. Use your fine-tuned model for inference.
+1. Finetune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset to transcribe audio to text.
+2. Use your finetuned model for inference.
@@ -49,7 +49,7 @@ We encourage you to login to your Hugging Face account so you can upload and sha
## Load MInDS-14 dataset
-Start by loading a smaller subset of the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset from the 🤗 Datasets library. This will give you a chance to experiment and make sure everything works before spending more time training on the full dataset.
+Start by loading a smaller subset of the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset from the 🤗 Datasets library. This'll give you a chance to experiment and make sure everything works before spending more time training on the full dataset.
```py
>>> from datasets import load_dataset, Audio
@@ -79,13 +79,13 @@ DatasetDict({
})
```
-While the dataset contains a lot of useful information, like `lang_id` and `english_transcription`, this guide focuses on the `audio` and `transcription`. Remove the other columns with the [`~datasets.Dataset.remove_columns`] method:
+While the dataset contains a lot of useful information, like `lang_id` and `english_transcription`, you'll focus on the `audio` and `transcription` in this guide. Remove the other columns with the [`~datasets.Dataset.remove_columns`] method:
```py
>>> minds = minds.remove_columns(["english_transcription", "intent_class", "lang_id"])
```
-Review the example again:
+Take a look at the example again:
```py
>>> minds["train"][0]
@@ -112,7 +112,7 @@ The next step is to load a Wav2Vec2 processor to process the audio signal:
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
```
-The MInDS-14 dataset has a sampling rate of 8000Hz (you can find this information in its [dataset card](https://huggingface.co/datasets/PolyAI/minds14)), which means you'll need to resample the dataset to 16000Hz to use the pretrained Wav2Vec2 model:
+The MInDS-14 dataset has a sampling rate of 8000kHz (you can find this information in its [dataset card](https://huggingface.co/datasets/PolyAI/minds14)), which means you'll need to resample the dataset to 16000kHz to use the pretrained Wav2Vec2 model:
```py
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
@@ -125,7 +125,7 @@ The MInDS-14 dataset has a sampling rate of 8000Hz (you can find this informatio
'transcription': "hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing"}
```
-As you can see in the `transcription` above, the text contains a mix of uppercase and lowercase characters. The Wav2Vec2 tokenizer is only trained on uppercase characters so you'll need to make sure the text matches the tokenizer's vocabulary:
+As you can see in the `transcription` above, the text contains a mix of upper and lowercase characters. The Wav2Vec2 tokenizer is only trained on uppercase characters so you'll need to make sure the text matches the tokenizer's vocabulary:
```py
>>> def uppercase(example):
@@ -196,7 +196,7 @@ Now instantiate your `DataCollatorForCTCWithPadding`:
## Evaluate
-Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [word error rate](https://huggingface.co/spaces/evaluate-metric/wer) (WER) metric (refer to the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about loading and computing metrics):
+Including a metric during training is often helpful for evaluating your model's performance. You can quickly load an evaluation method with the 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) library. For this task, load the [word error rate](https://huggingface.co/spaces/evaluate-metric/wer) (WER) metric (see the 🤗 Evaluate [quick tour](https://huggingface.co/docs/evaluate/a_quick_tour) to learn more about how to load and compute a metric):
```py
>>> import evaluate
@@ -236,7 +236,7 @@ If you aren't familiar with finetuning a model with the [`Trainer`], take a look
-You are now ready to start training your model! Load Wav2Vec2 with [`AutoModelForCTC`]. Specify the reduction to apply with the `ctc_loss_reduction` parameter. It is often better to use the average instead of the default summation:
+You're ready to start training your model now! Load Wav2Vec2 with [`AutoModelForCTC`]. Specify the reduction to apply with the `ctc_loss_reduction` parameter. It is often better to use the average instead of the default summation:
```py
>>> from transformers import AutoModelForCTC, TrainingArguments, Trainer
@@ -252,7 +252,7 @@ At this point, only three steps remain:
1. Define your training hyperparameters in [`TrainingArguments`]. The only required parameter is `output_dir` which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model). At the end of each epoch, the [`Trainer`] will evaluate the WER and save the training checkpoint.
2. Pass the training arguments to [`Trainer`] along with the model, dataset, tokenizer, data collator, and `compute_metrics` function.
-3. Call [`~Trainer.train`] to fine-tune your model.
+3. Call [`~Trainer.train`] to finetune your model.
```py
>>> training_args = TrainingArguments(
@@ -289,7 +289,7 @@ At this point, only three steps remain:
>>> trainer.train()
```
-Once training is completed, share your model to the Hub with the [`~transformers.Trainer.push_to_hub`] method so it can be accessible to everyone:
+Once training is completed, share your model to the Hub with the [`~transformers.Trainer.push_to_hub`] method so everyone can use your model:
```py
>>> trainer.push_to_hub()
@@ -299,13 +299,13 @@ Once training is completed, share your model to the Hub with the [`~transformers
-For a more in-depth example of how to fine-tune a model for automatic speech recognition, take a look at this blog [post](https://huggingface.co/blog/fine-tune-wav2vec2-english) for English ASR and this [post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) for multilingual ASR.
+For a more in-depth example of how to finetune a model for automatic speech recognition, take a look at this blog [post](https://huggingface.co/blog/fine-tune-wav2vec2-english) for English ASR and this [post](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2) for multilingual ASR.
## Inference
-Great, now that you've fine-tuned a model, you can use it for inference!
+Great, now that you've finetuned a model, you can use it for inference!
Load an audio file you'd like to run inference on. Remember to resample the sampling rate of the audio file to match the sampling rate of the model if you need to!
@@ -318,7 +318,7 @@ Load an audio file you'd like to run inference on. Remember to resample the samp
>>> audio_file = dataset[0]["audio"]["path"]
```
-The simplest way to try out your fine-tuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for automatic speech recognition with your model, and pass your audio file to it:
+The simplest way to try out your finetuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for automatic speech recognition with your model, and pass your audio file to it:
```py
>>> from transformers import pipeline
diff --git a/docs/source/en/tasks/audio_classification.md b/docs/source/en/tasks/audio_classification.md
index 973f95e1e9555d..59d6a175da82ba 100644
--- a/docs/source/en/tasks/audio_classification.md
+++ b/docs/source/en/tasks/audio_classification.md
@@ -9,7 +9,7 @@ Unless required by applicable law or agreed to in writing, software distributed
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
@@ -20,12 +20,12 @@ rendered properly in your Markdown viewer.
-Audio classification - just like with text - assigns a class label as output from the input data. The only difference is instead of text inputs, you have raw audio waveforms. Some practical applications of audio classification include identifying speaker intent, language classification, and even animal species by their sounds.
+Audio classification - just like with text - assigns a class label output from the input data. The only difference is instead of text inputs, you have raw audio waveforms. Some practical applications of audio classification include identifying speaker intent, language classification, and even animal species by their sounds.
This guide will show you how to:
-1. Fine-tune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset to classify speaker intent.
-2. Use your fine-tuned model for inference.
+1. Finetune [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) on the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset to classify speaker intent.
+2. Use your finetuned model for inference.
@@ -57,7 +57,7 @@ Start by loading the MInDS-14 dataset from the 🤗 Datasets library:
>>> minds = load_dataset("PolyAI/minds14", name="en-US", split="train")
```
-Split the dataset's `train` split into a smaller train and test set with the [`~datasets.Dataset.train_test_split`] method. This will give you a chance to experiment and make sure everything works before spending more time on the full dataset.
+Split the dataset's `train` split into a smaller train and test set with the [`~datasets.Dataset.train_test_split`] method. This'll give you a chance to experiment and make sure everything works before spending more time on the full dataset.
```py
>>> minds = minds.train_test_split(test_size=0.2)
@@ -79,13 +79,13 @@ DatasetDict({
})
```
-While the dataset contains a lot of useful information, like `lang_id` and `english_transcription`, you will focus on the `audio` and `intent_class` in this guide. Remove the other columns with the [`~datasets.Dataset.remove_columns`] method:
+While the dataset contains a lot of useful information, like `lang_id` and `english_transcription`, you'll focus on the `audio` and `intent_class` in this guide. Remove the other columns with the [`~datasets.Dataset.remove_columns`] method:
```py
>>> minds = minds.remove_columns(["path", "transcription", "english_transcription", "lang_id"])
```
-Here's an example:
+Take a look at an example now:
```py
>>> minds["train"][0]
@@ -128,7 +128,7 @@ The next step is to load a Wav2Vec2 feature extractor to process the audio signa
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
```
-The MInDS-14 dataset has a sampling rate of 8kHz (you can find this information in its [dataset card](https://huggingface.co/datasets/PolyAI/minds14)), which means you'll need to resample the dataset to 16kHz to use the pretrained Wav2Vec2 model:
+The MInDS-14 dataset has a sampling rate of 8000khz (you can find this information in it's [dataset card](https://huggingface.co/datasets/PolyAI/minds14)), which means you'll need to resample the dataset to 16000kHz to use the pretrained Wav2Vec2 model:
```py
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
@@ -155,7 +155,7 @@ Now create a preprocessing function that:
... return inputs
```
-To apply the preprocessing function over the entire dataset, use 🤗 Datasets [`~datasets.Dataset.map`] function. You can speed up `map` by setting `batched=True` to process multiple elements of the dataset at once. Remove unnecessary columns and rename `intent_class` to `label`, as required by the model:
+To apply the preprocessing function over the entire dataset, use 🤗 Datasets [`~datasets.Dataset.map`] function. You can speed up `map` by setting `batched=True` to process multiple elements of the dataset at once. Remove the columns you don't need, and rename `intent_class` to `label` because that's the name the model expects:
```py
>>> encoded_minds = minds.map(preprocess_function, remove_columns="audio", batched=True)
@@ -208,9 +208,9 @@ You're ready to start training your model now! Load Wav2Vec2 with [`AutoModelFor
At this point, only three steps remain:
-1. Define your training hyperparameters in [`TrainingArguments`]. The only required parameter is `output_dir`, which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model). At the end of each epoch, the [`Trainer`] will evaluate the accuracy and save the training checkpoint.
+1. Define your training hyperparameters in [`TrainingArguments`]. The only required parameter is `output_dir` which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model). At the end of each epoch, the [`Trainer`] will evaluate the accuracy and save the training checkpoint.
2. Pass the training arguments to [`Trainer`] along with the model, dataset, tokenizer, data collator, and `compute_metrics` function.
-3. Call [`~Trainer.train`] to fine-tune your model.
+3. Call [`~Trainer.train`] to finetune your model.
```py
@@ -252,15 +252,15 @@ Once training is completed, share your model to the Hub with the [`~transformers
-For a more in-depth example of how to fine-tune a model for audio classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/audio_classification.ipynb).
+For a more in-depth example of how to finetune a model for audio classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/audio_classification.ipynb).
## Inference
-Great, now that you've fine-tuned a model, you can use it for inference!
+Great, now that you've finetuned a model, you can use it for inference!
-Load an audio file for inference. Remember to resample the sampling rate of the audio file to match the model's sampling rate, if necessary.
+Load an audio file you'd like to run inference on. Remember to resample the sampling rate of the audio file to match the sampling rate of the model if you need to!
```py
>>> from datasets import load_dataset, Audio
@@ -271,7 +271,7 @@ Load an audio file for inference. Remember to resample the sampling rate of the
>>> audio_file = dataset[0]["audio"]["path"]
```
-The simplest way to try out your fine-tuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for audio classification with your model, and pass your audio file to it:
+The simplest way to try out your finetuned model for inference is to use it in a [`pipeline`]. Instantiate a `pipeline` for audio classification with your model, and pass your audio file to it:
```py
>>> from transformers import pipeline
diff --git a/docs/source/en/tasks/multiple_choice.md b/docs/source/en/tasks/multiple_choice.md
index 18b12f2166637e..06eb45eda99150 100644
--- a/docs/source/en/tasks/multiple_choice.md
+++ b/docs/source/en/tasks/multiple_choice.md
@@ -419,7 +419,7 @@ Get the class with the highest probability:
```py
>>> predicted_class = logits.argmax().item()
>>> predicted_class
-0
+'0'
```
@@ -448,7 +448,7 @@ Get the class with the highest probability:
```py
>>> predicted_class = int(tf.math.argmax(logits, axis=-1)[0])
>>> predicted_class
-0
+'0'
```
diff --git a/docs/source/en/tasks/question_answering.md b/docs/source/en/tasks/question_answering.md
index 41d7fd48cf816e..998010e67ca95f 100644
--- a/docs/source/en/tasks/question_answering.md
+++ b/docs/source/en/tasks/question_answering.md
@@ -325,7 +325,7 @@ or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/no
Evaluation for question answering requires a significant amount of postprocessing. To avoid taking up too much of your time, this guide skips the evaluation step. The [`Trainer`] still calculates the evaluation loss during training so you're not completely in the dark about your model's performance.
-If you have more time and you're interested in how to evaluate your model for question answering, take a look at the [Question answering](https://huggingface.co/course/chapter7/7?fw=pt#post-processing) chapter from the 🤗 Hugging Face Course!
+If have more time and you're interested in how to evaluate your model for question answering, take a look at the [Question answering](https://huggingface.co/course/chapter7/7?fw=pt#post-processing) chapter from the 🤗 Hugging Face Course!
## Inference
@@ -397,7 +397,7 @@ Tokenize the text and return TensorFlow tensors:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("my_awesome_qa_model")
->>> inputs = tokenizer(question, context, return_tensors="tf")
+>>> inputs = tokenizer(question, text, return_tensors="tf")
```
Pass your inputs to the model and return the `logits`:
diff --git a/docs/source/en/tasks/summarization.md b/docs/source/en/tasks/summarization.md
index e16dd17dfe1fc8..7d7ecf1fbab6db 100644
--- a/docs/source/en/tasks/summarization.md
+++ b/docs/source/en/tasks/summarization.md
@@ -283,7 +283,7 @@ Pass your `compute_metrics` function to [`~transformers.KerasMetricCallback`]:
```py
>>> from transformers.keras_callbacks import KerasMetricCallback
->>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_test_set)
+>>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_validation_set)
```
Specify where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:
diff --git a/docs/source/en/tasks/translation.md b/docs/source/en/tasks/translation.md
index 922cdc7241176a..426ba1c340fb81 100644
--- a/docs/source/en/tasks/translation.md
+++ b/docs/source/en/tasks/translation.md
@@ -290,7 +290,7 @@ Pass your `compute_metrics` function to [`~transformers.KerasMetricCallback`]:
```py
>>> from transformers.keras_callbacks import KerasMetricCallback
->>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_test_set)
+>>> metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_validation_set)
```
Specify where to push your model and tokenizer in the [`~transformers.PushToHubCallback`]:
diff --git a/docs/source/es/quicktour.md b/docs/source/es/quicktour.md
index c4babab09f023d..ad2549ef450bb2 100644
--- a/docs/source/es/quicktour.md
+++ b/docs/source/es/quicktour.md
@@ -385,8 +385,8 @@ Una característica particularmente interesante de 🤗 Transformers es la habil
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -394,8 +394,8 @@ Una característica particularmente interesante de 🤗 Transformers es la habil
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/fr/quicktour.md b/docs/source/fr/quicktour.md
index dcf21562316d5d..3cc2a8c5faac76 100644
--- a/docs/source/fr/quicktour.md
+++ b/docs/source/fr/quicktour.md
@@ -354,8 +354,8 @@ Une fonctionnalité particulièrement cool 🤗 Transformers est la possibilité
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -363,8 +363,8 @@ Une fonctionnalité particulièrement cool 🤗 Transformers est la possibilité
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/it/quicktour.md b/docs/source/it/quicktour.md
index f0291a6167715a..07e7a2974a1fbc 100644
--- a/docs/source/it/quicktour.md
+++ b/docs/source/it/quicktour.md
@@ -111,7 +111,7 @@ etichetta: negative, con punteggio: 0.9998
La [`pipeline`] può anche iterare su un dataset intero. Inizia installando la libreria [🤗 Datasets](https://huggingface.co/docs/datasets/):
```bash
-pip install datasets
+pip install datasets
```
Crea una [`pipeline`] con il compito che vuoi risolvere e con il modello che vuoi utilizzare.
@@ -385,8 +385,8 @@ Una caratteristica particolarmente interessante di 🤗 Transformers è la sua a
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -394,8 +394,8 @@ Una caratteristica particolarmente interessante di 🤗 Transformers è la sua a
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/ja/quicktour.md b/docs/source/ja/quicktour.md
index 0eb00cf220b54a..e03dea33cbd189 100644
--- a/docs/source/ja/quicktour.md
+++ b/docs/source/ja/quicktour.md
@@ -386,8 +386,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -396,8 +396,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/ja/tasks/audio_classification.md b/docs/source/ja/tasks/audio_classification.md
index 3b33d1b6043d78..aa38d12d4ef0cf 100644
--- a/docs/source/ja/tasks/audio_classification.md
+++ b/docs/source/ja/tasks/audio_classification.md
@@ -128,7 +128,7 @@ DatasetDict({
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
```
-MInDS-14 データセットのサンプリング レートは 8khz です (この情報は [データセット カード](https://huggingface.co/datasets/PolyAI/minds14) で確認できます)。つまり、データセットを再サンプリングする必要があります。事前トレーニングされた Wav2Vec2 モデルを使用するには、16kHz に設定します。
+MInDS-14 データセットのサンプリング レートは 8000khz です (この情報は [データセット カード](https://huggingface.co/datasets/PolyAI/minds14) で確認できます)。つまり、データセットを再サンプリングする必要があります。事前トレーニングされた Wav2Vec2 モデルを使用するには、16000kHz に設定します。
```py
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml
index 54740610ee1148..7e9567769cca1a 100644
--- a/docs/source/ko/_toctree.yml
+++ b/docs/source/ko/_toctree.yml
@@ -151,8 +151,6 @@
title: AWQ
- local: in_translation
title: (번역중) AQLM
- - local: in_translation
- title: (번역중) VPTQ
- local: in_translation
title: (번역중) Quanto
- local: in_translation
@@ -175,8 +173,6 @@
title: (번역중) AWQ
- local: in_translation
title: (번역중) AQLM
- - local: in_translation
- title: (번역중) VPTQ
- local: quantization/quanto
title: Quanto
- local: quantization/eetq
diff --git a/docs/source/ko/llm_optims.md b/docs/source/ko/llm_optims.md
index 99eabc19ce860a..656ed53584c226 100644
--- a/docs/source/ko/llm_optims.md
+++ b/docs/source/ko/llm_optims.md
@@ -375,7 +375,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable
양자화는 LLM 가중치를 더 낮은 정밀도로 저장하여 크기를 줄입니다. 이는 메모리 사용량을 줄이며 GPU 메모리에 제약이 있는 경우 추론을 위해 LLM을 로드하는 것을 더 용이하게 합니다. GPU가 충분하다면, 모델을 양자화할 필요는 없습니다. 추가적인 양자화 및 양자화 해제 단계로 인해 약간의 지연이 발생할 수 있기 때문입니다(AWQ 및 융합 AWQ 모듈 제외).
> [!TIP]
-> 다양한 양자화 라이브러리(자세한 내용은 [Quantization](./quantization) 가이드를 참조하십시오)가 있습니다. 여기에는 Quanto, AQLM, VPTQ, AWQ 및 AutoGPTQ가 포함됩니다. 사용 사례에 가장 잘 맞는 라이브러리를 사용해 보십시오. 또한 AutoGPTQ와 bitsandbytes를 비교하는 [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) 블로그 게시물을 읽어보는 것을 추천합니다.
+> 다양한 양자화 라이브러리(자세한 내용은 [Quantization](./quantization) 가이드를 참조하십시오)가 있습니다. 여기에는 Quanto, AQLM, AWQ 및 AutoGPTQ가 포함됩니다. 사용 사례에 가장 잘 맞는 라이브러리를 사용해 보십시오. 또한 AutoGPTQ와 bitsandbytes를 비교하는 [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) 블로그 게시물을 읽어보는 것을 추천합니다.
아래의 모델 메모리 계산기를 사용하여 모델을 로드하는 데 필요한 메모리를 추정하고 비교해 보십시오. 예를 들어 [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)를 로드하는 데 필요한 메모리를 추정해 보십시오.
diff --git a/docs/source/ko/main_classes/quantization.md b/docs/source/ko/main_classes/quantization.md
index 6f793f22107417..b1d1730d28d00b 100644
--- a/docs/source/ko/main_classes/quantization.md
+++ b/docs/source/ko/main_classes/quantization.md
@@ -35,10 +35,6 @@ Transformers에서 지원되지 않는 양자화 기법들은 [`HfQuantizer`]
[[autodoc]] AqlmConfig
-## VptqConfig[[transformers.VptqConfig]]
-
-[[autodoc]] VptqConfig
-
## AwqConfig[[transformers.AwqConfig]]
[[autodoc]] AwqConfig
diff --git a/docs/source/ko/quicktour.md b/docs/source/ko/quicktour.md
index 4c3b137aa00ff9..06f44e6fd2970c 100644
--- a/docs/source/ko/quicktour.md
+++ b/docs/source/ko/quicktour.md
@@ -361,8 +361,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -370,8 +370,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/ko/tasks/audio_classification.md b/docs/source/ko/tasks/audio_classification.md
index 2defa691edef75..936b4eb1989827 100644
--- a/docs/source/ko/tasks/audio_classification.md
+++ b/docs/source/ko/tasks/audio_classification.md
@@ -128,7 +128,7 @@ DatasetDict({
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
```
-MinDS-14 데이터 세트의 샘플링 속도는 8khz이므로(이 정보는 [데이터세트 카드](https://huggingface.co/datasets/PolyAI/minds14)에서 확인할 수 있습니다), 사전 훈련된 Wav2Vec2 모델을 사용하려면 데이터 세트를 16kHz로 리샘플링해야 합니다:
+MinDS-14 데이터 세트의 샘플링 속도는 8000khz이므로(이 정보는 [데이터세트 카드](https://huggingface.co/datasets/PolyAI/minds14)에서 확인할 수 있습니다), 사전 훈련된 Wav2Vec2 모델을 사용하려면 데이터 세트를 16000kHz로 리샘플링해야 합니다:
```py
>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
diff --git a/docs/source/pt/quicktour.md b/docs/source/pt/quicktour.md
index cc583697b9a658..d34480ee23a880 100644
--- a/docs/source/pt/quicktour.md
+++ b/docs/source/pt/quicktour.md
@@ -37,7 +37,7 @@ A [`pipeline`] apoia diversas tarefas fora da caixa:
**Texto**:
* Análise sentimental: classifica a polaridade de um texto.
* Geração de texto (em Inglês): gera texto a partir de uma entrada.
-* Reconhecimento de entidade mencionada: legenda cada palavra com uma classe que a representa (pessoa, data, local, etc...)
+* Reconhecimento de entidade mencionada: legenda cada palavra com uma classe que a representa (pessoa, data, local, etc...)
* Respostas: extrai uma resposta dado algum contexto e uma questão
* Máscara de preenchimento: preenche o espaço, dado um texto com máscaras de palavras.
* Sumarização: gera o resumo de um texto longo ou documento.
@@ -87,7 +87,7 @@ Importe [`pipeline`] e especifique a tarefa que deseja completar:
>>> classifier = pipeline("sentiment-analysis")
```
-A pipeline baixa and armazena um [modelo pré-treinado](https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english) padrão e tokenizer para análise sentimental. Agora você pode usar `classifier` no texto alvo:
+A pipeline baixa and armazena um [modelo pré-treinado](https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english) padrão e tokenizer para análise sentimental. Agora você pode usar `classifier` no texto alvo:
```py
>>> classifier("We are very happy to show you the 🤗 Transformers library.")
@@ -107,7 +107,7 @@ label: NEGATIVE, with score: 0.5309
A [`pipeline`] também pode iterar sobre um Dataset inteiro. Comece instalando a biblioteca de [🤗 Datasets](https://huggingface.co/docs/datasets/):
```bash
-pip install datasets
+pip install datasets
```
Crie uma [`pipeline`] com a tarefa que deseja resolver e o modelo que deseja usar.
@@ -133,7 +133,7 @@ Precisamos garantir que a taxa de amostragem do conjunto de dados corresponda à
>>> dataset = dataset.cast_column("audio", Audio(sampling_rate=speech_recognizer.feature_extractor.sampling_rate))
```
-Os arquivos de áudio são carregados e re-amostrados automaticamente ao chamar a coluna `"audio"`.
+Os arquivos de áudio são carregados e re-amostrados automaticamente ao chamar a coluna `"audio"`.
Vamos extrair as arrays de formas de onda originais das primeiras 4 amostras e passá-las como uma lista para o pipeline:
```py
@@ -176,7 +176,7 @@ Use o [`TFAutoModelForSequenceClassification`] and [`AutoTokenizer`] para carreg
-Então você pode especificar o modelo e o tokenizador na [`pipeline`] e aplicar o `classifier` no seu texto alvo:
+Então você pode especificar o modelo e o tokenizador na [`pipeline`] e aplicar o `classifier` no seu texto alvo:
```py
>>> classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
@@ -190,7 +190,7 @@ Se você não conseguir achar um modelo para o seu caso de uso, precisará usar
-Por baixo dos panos, as classes [`AutoModelForSequenceClassification`] e [`AutoTokenizer`] trabalham juntas para fortificar o [`pipeline`]. Um [AutoClass](./model_doc/auto) é um atalho que automaticamente recupera a arquitetura de um modelo pré-treinado a partir de seu nome ou caminho. Basta selecionar a `AutoClass` apropriada para sua tarefa e seu tokenizer associado com [`AutoTokenizer`].
+Por baixo dos panos, as classes [`AutoModelForSequenceClassification`] e [`AutoTokenizer`] trabalham juntas para fortificar o [`pipeline`]. Um [AutoClass](./model_doc/auto) é um atalho que automaticamente recupera a arquitetura de um modelo pré-treinado a partir de seu nome ou caminho. Basta selecionar a `AutoClass` apropriada para sua tarefa e seu tokenizer associado com [`AutoTokenizer`].
Vamos voltar ao nosso exemplo e ver como você pode usar a `AutoClass` para replicar os resultados do [`pipeline`].
@@ -383,8 +383,8 @@ Um recurso particularmente interessante dos 🤗 Transformers é a capacidade de
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -392,8 +392,8 @@ Um recurso particularmente interessante dos 🤗 Transformers é a capacidade de
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
\ No newline at end of file
diff --git a/docs/source/te/quicktour.md b/docs/source/te/quicktour.md
index 6045b673d2d3d0..67e530f35f3294 100644
--- a/docs/source/te/quicktour.md
+++ b/docs/source/te/quicktour.md
@@ -366,8 +366,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import AutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
->>> pt_model = AutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
+>>> pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
```
@@ -375,8 +375,8 @@ tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],
```py
>>> from transformers import TFAutoModel
->>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
->>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)
+>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
+>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained(pt_save_directory, from_pt=True)
```
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index 2cce86b6592484..d4863efde710ea 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -23,10 +23,8 @@
title: 使用🤗 PEFT加载和训练adapters
- local: model_sharing
title: 分享您的模型
- - local: agents
- title: 智能体和工具
- - local: agents_advanced
- title: 智能体,超强版 - 多智能体、外部工具等
+ - local: transformers_agents
+ title: agents教程
- local: llm_tutorial
title: 使用LLMs进行生成
title: 教程
@@ -52,8 +50,6 @@
title: 导出为 TFLite
- local: torchscript
title: 导出为 TorchScript
- - local: benchmarks
- title: 对模型进行基准测试
- local: gguf
title: 与 GGUF 格式的互操作性
- local: tiktoken
@@ -69,10 +65,6 @@
title: 完全分片数据并行
- local: perf_train_special
title: 在 Apple silicon 芯片上进行 PyTorch 训练
- - local: perf_infer_gpu_multi
- title: 多GPU推理
- - local: perf_train_cpu
- title: 在CPU上进行高效训练
- local: perf_hardware
title: 用于训练的定制硬件
- local: hpo_train
@@ -108,7 +100,7 @@
- sections:
- sections:
- local: main_classes/agent
- title: 智能体和工具
+ title: Agents和工具
- local: main_classes/callback
title: Callbacks
- local: main_classes/configuration
diff --git a/docs/source/zh/agents.md b/docs/source/zh/agents.md
deleted file mode 100644
index 00fa74e6545025..00000000000000
--- a/docs/source/zh/agents.md
+++ /dev/null
@@ -1,427 +0,0 @@
-
-# 智能体和工具
-
-[[在colab里打开]]
-
-### 什么是智能体 (Agent)?
-
-大型语言模型(LLM)经过 [因果语言建模训练](./tasks/language_modeling) 可以应对各种任务,但在一些基本任务(如逻辑推理、计算和搜索)上常常表现不佳。当它们被用在自己不擅长的领域时,往往无法生成我们期望的答案。
-
-为了解决这个问题,可以创建**智能体**.
-
-智能体是一个系统,它使用 LLM 作为引擎,并且能够访问称为**工具**的功能。
-
-这些**工具**是执行任务的函数,包含所有必要的描述信息,帮助智能体正确使用它们。
-
-智能体可以被编程为:
-- 一次性设计一系列工具并同时执行它们,像 [`CodeAgent`]
-- 一次执行一个工具,并等待每个工具的结果后再启动下一个,像 [`ReactJsonAgent`]
-
-### 智能体类型
-
-#### 代码智能体
-
-此智能体包含一个规划步骤,然后生成 Python 代码一次性执行所有任务。它原生支持处理不同输入和输出类型,因此推荐用于多模态任务。
-
-#### 推理智能体
-
-这是解决推理任务的首选代理,因为 ReAct 框架 ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) 使其在基于之前观察进行推理时非常高效。
-
-我们实现了两种版本的 ReactJsonAgent:
-- [`ReactJsonAgent`] 将工具调用作为 JSON 格式输出。
-- [`ReactCodeAgent`] 是 ReactJsonAgent 的一种新型,生成工具调用的代码块,对于具备强大编程能力的 LLM 非常适用。
-
-> [TIP]
-> 阅读 [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) 博文,了解更多关于推理智能体的信息。
-
-
-
-
-
-
-![推理智能体的框架](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/open-source-llms-as-agents/ReAct.png)
-
-以下是一个推理代码智能体如何处理以下问题的示例:
-
-```py3
->>> agent.run(
-... "How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?",
-... )
-=====New task=====
-How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?
-====Agent is executing the code below:
-bert_blocks = search(query="number of blocks in BERT base encoder")
-print("BERT blocks:", bert_blocks)
-====
-Print outputs:
-BERT blocks: twelve encoder blocks
-
-====Agent is executing the code below:
-attention_layer = search(query="number of layers in Attention is All You Need")
-print("Attention layers:", attention_layer)
-====
-Print outputs:
-Attention layers: Encoder: The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position- 2 Page 3 Figure 1: The Transformer - model architecture.
-
-====Agent is executing the code below:
-bert_blocks = 12
-attention_layers = 6
-diff = bert_blocks - attention_layers
-print("Difference in blocks:", diff)
-final_answer(diff)
-====
-
-Print outputs:
-Difference in blocks: 6
-
-Final answer: 6
-```
-
-### 如何构建智能体?
-
-要初始化一个智能体,您需要以下参数:
-
-- **一个 LLM** 来驱动智能体——智能体本身并不是 LLM,而是一个使用 LLM 作为引擎的程序。
-- **一个系统提示**:告诉 LLM 引擎应该如何生成输出。
-- **一个工具箱**,智能体可以从中选择工具执行。
-- **一个解析器**,从 LLM 输出中提取出哪些工具需要调用,以及使用哪些参数。
-
-在智能体系统初始化时,工具属性将生成工具描述,并嵌入到智能体的系统提示中,告知智能体可以使用哪些工具,并且为什么使用它们。
-
-**安装依赖**
-
-首先,您需要安装**智能体**所需的额外依赖:
-
-```bash
-pip install transformers[agents]
-```
-**创建LLM引擎**
-
-定义一个 `llm_engine` 方法,该方法接受一系列[消息](./chat_templating)并返回文本。该 `callable` 还需要接受一个 `stop` 参数,用于指示何时停止生成输出。
-
-```python
-from huggingface_hub import login, InferenceClient
-
-login("")
-
-client = InferenceClient(model="meta-llama/Meta-Llama-3-70B-Instruct")
-
-def llm_engine(messages, stop_sequences=["Task"]) -> str:
- response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1000)
- answer = response.choices[0].message.content
- return answer
-```
-
-您可以使用任何符合以下要求的 `llm_engine` 方法:
-1. [输入格式](./chat_templating)为 (`List[Dict[str, str]]`),并且返回一个字符串。
-2. 它在 `stop_sequences` 参数传递的序列处停止生成输出。
-
-此外,`llm_engine` 还可以接受一个 `grammar` 参数。如果在智能体初始化时指定了 `grammar`,则该参数将传递给 `llm_engine` 的调用,以允许[受限生成](https://huggingface.co/docs/text-generation-inference/conceptual/guidance),以强制生成格式正确的智能体输出。
-
-您还需要一个 `tools` 参数,它接受一个 `Tools` 列表 —— 可以是空列表。您也可以通过定义可选参数 `add_base_tools=True` 来将默认工具箱添加到工具列表中。
-
-现在,您可以创建一个智能体,例如 [`CodeAgent`],并运行它。您还可以创建一个 [`TransformersEngine`],使用 `transformers` 在本地机器上运行预初始化的推理管道。 为了方便起见,由于智能体行为通常需要更强大的模型,例如 `Llama-3.1-70B-Instruct`,它们目前较难在本地运行,我们还提供了 [`HfApiEngine`] 类,它在底层初始化了一个 `huggingface_hub.InferenceClient`。
-
-```python
-from transformers import CodeAgent, HfApiEngine
-
-llm_engine = HfApiEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
-agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
-
-agent.run(
- "Could you translate this sentence from French, say it out loud and return the audio.",
- sentence="Où est la boulangerie la plus proche?",
-)
-```
-
-当你急需某个东西时这将会很有用!
-您甚至可以将 `llm_engine` 参数留空,默认情况下会创建一个 [`HfApiEngine`]。
-
-```python
-from transformers import CodeAgent
-
-agent = CodeAgent(tools=[], add_base_tools=True)
-
-agent.run(
- "Could you translate this sentence from French, say it out loud and give me the audio.",
- sentence="Où est la boulangerie la plus proche?",
-)
-```
-
-请注意,我们使用了额外的 `sentence` 参数:您可以将文本作为附加参数传递给模型。
-
-您还可以使用这个来指定本地或远程文件的路径供模型使用:
-
-```py
-from transformers import ReactCodeAgent
-
-agent = ReactCodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
-
-agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
-```
-
-系统提示和输出解析器会自动定义,但您可以通过调用智能体的 `system_prompt_template` 来轻松查看它们。
-
-```python
-print(agent.system_prompt_template)
-```
-
-尽可能清楚地解释您要执行的任务非常重要。 每次 [`~Agent.run`] 操作都是独立的,并且由于智能体是由 LLM 驱动的,提示中的细微变化可能会导致完全不同的结果。
-您还可以连续运行多个任务,每次都会重新初始化智能体的 `agent.task` 和 `agent.logs` 属性。
-
-
-#### 代码执行
-
-Python 解释器在一组输入和工具上执行代码。 这应该是安全的,因为只能调用您提供的工具(特别是 Hugging Face 的工具)和 print 函数,因此您已经限制了可以执行的操作。
-
-Python 解释器默认不允许导入不在安全列表中的模块,因此大多数明显的攻击问题应该不成问题。 您仍然可以通过在 [`ReactCodeAgent`] 或 [`CodeAgent`] 初始化时通过 `additional_authorized_imports` 参数传递一个授权的模块列表来授权额外的导入:
-
-```py
->>> from transformers import ReactCodeAgent
-
->>> agent = ReactCodeAgent(tools=[], additional_authorized_imports=['requests', 'bs4'])
->>> agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
-
-(...)
-'Hugging Face – Blog'
-```
-
-如果有任何代码尝试执行非法操作,或者生成的代码出现常规 Python 错误,执行将停止。
-
-> [!WARNING]
-> 在使用大语言模型(LLM)生成代码时,生成的代码会被执行,避免导入或使用任何不安全的库或模块。
-
-### 系统提示
-
-智能体,或者说驱动智能体的 LLM,根据系统提示生成输出。系统提示可以定制并根据目标任务进行调整。例如,检查 [`ReactCodeAgent`] 的系统提示(以下版本经过简化)。
-
-```text
-You will be given a task to solve as best you can.
-You have access to the following tools:
-<>
-
-To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
-
-At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use.
-Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '/End code' sequence.
-During each intermediate step, you can use 'print()' to save whatever important information you will then need.
-These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step.
-
-In the end you have to return a final answer using the `final_answer` tool.
-
-Here are a few examples using notional tools:
----
-{examples}
-
-Above example were using notional tools that might not exist for you. You only have acces to those tools:
-<>
-You also can perform computations in the python code you generate.
-
-Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```' sequence. You MUST provide at least the 'Code:' sequence to move forward.
-
-Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks.
-Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result.
-
-Remember to make sure that variables you use are all defined.
-
-Now Begin!
-```
-
-系统提示包括:
-- 解释智能体应该如何工作以及工具的**介绍**。
-- 所有工具的描述由 `<>` 标记在运行时动态替换,这样智能体就知道可以使用哪些工具及其用途。
- - 工具的描述来自工具的属性,`name`、`description`、`inputs` 和 `output_type`,以及一个简单的 `jinja2` 模板,您可以根据需要进行调整。
-- 期望的输出格式。
-
-您可以通过向 `system_prompt` 参数传递自定义提示来最大程度地提高灵活性,从而覆盖整个系统提示模板。
-
-```python
-from transformers import ReactJsonAgent
-from transformers.agents import PythonInterpreterTool
-
-agent = ReactJsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_custom_prompt}")
-```
-
-> [WARNING]
-> 必须在`template`中定义 `<>` 这个变量,以便智能体能够正确地识别并使用可用的工具
-
-
-### 检查智能体的运行
-
-以下是检查运行后发生了什么的一些有用属性:
-- `agent.logs` 存储了智能体的详细日志。每一步的所有内容都会存储在一个字典中,然后附加到 `agent.logs`。
-- 运行 `agent.write_inner_memory_from_logs()` 会从日志中创建智能体的内存,以便 LLM 查看,作为一系列聊天消息。此方法会遍历日志的每个步骤,只保存其感兴趣的消息:例如,它会单独保存系统提示和任务,然后为每个步骤保存 LLM 输出的消息,以及工具调用输出的消息。如果您想要更高层次的查看发生了什么,可以使用此方法 —— 但并不是每个日志都会被此方法转录。
-
-## 工具
-
-工具是智能体使用的基本功能。
-
-例如,您可以检查 [`PythonInterpreterTool`]:它有一个名称、描述、输入描述、输出类型和 `__call__` 方法来执行该操作。
-
-当智能体初始化时,工具属性会用来生成工具描述,然后将其嵌入到智能体的系统提示中,这让智能体知道可以使用哪些工具以及为什么使用它们。
-
-### 默认工具箱
-
-Transformers 提供了一个默认工具箱,用于增强智能体,您可以在初始化时通过 `add_base_tools=True` 参数将其添加到智能体中:
-
-- **文档问答**:给定一个文档(如图像格式的 PDF),回答关于该文档的问题([Donut](./model_doc/donut))
-- **图像问答**:给定一张图片,回答关于该图像的问题([VILT](./model_doc/vilt))
-- **语音转文本**:给定一个人讲述的音频录音,将其转录为文本(Whisper)
-- **文本转语音**:将文本转换为语音([SpeechT5](./model_doc/speecht5))
-- **翻译**:将给定的句子从源语言翻译为目标语言
-- **DuckDuckGo 搜索**:使用 `DuckDuckGo` 浏览器进行网络搜索
-- **Python 代码解释器**:在安全环境中运行 LLM 生成的 Python 代码。只有在初始化 [`ReactJsonAgent`] 时将 `add_base_tools=True` 时,代码智能体才会添加此工具,因为基于代码的智能体已经能够原生执行 Python 代码
-
-
-您可以通过调用 [`load_tool`] 函数来手动使用某个工具并执行任务。
-
-
-```python
-from transformers import load_tool
-
-tool = load_tool("text-to-speech")
-audio = tool("This is a text to speech tool")
-```
-
-
-### 创建新工具
-
-您可以为 `Hugging Face` 默认工具无法涵盖的用例创建自己的工具。
-例如,假设我们要创建一个返回在 `Hugging Face Hub` 上某个任务中下载次数最多的模型的工具。
-
-您将从以下代码开始:
-
-```python
-from huggingface_hub import list_models
-
-task = "text-classification"
-
-model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
-print(model.id)
-```
-
-这段代码可以很快转换为工具,只需将其包装成一个函数,并添加 `tool` 装饰器:
-
-
-```py
-from transformers import tool
-
-@tool
-def model_download_tool(task: str) -> str:
- """
- This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
- It returns the name of the checkpoint.
-
- Args:
- task: The task for which
- """
- model = next(iter(list_models(filter="text-classification", sort="downloads", direction=-1)))
- return model.id
-```
-
-该函数需要:
-- 一个清晰的名称。名称通常描述工具的功能。由于代码返回某个任务中下载次数最多的模型,因此我们将其命名为 `model_download_tool`。
-- 对输入和输出进行类型提示
-- 描述,其中包括 "`Args`:" 部分,描述每个参数(这次不需要类型指示,它会从类型提示中获取)。
-
-所有这些将自动嵌入到智能体的系统提示中,因此请尽量使它们尽可能清晰!
-
-> [TIP]
-> 这个定义格式与 apply_chat_template 中使用的工具模式相同,唯一的区别是添加了 tool 装饰器:可以在我们的工具使用 API 中[了解更多](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).
-
-然后,您可以直接初始化您的智能体:
-```py
-from transformers import CodeAgent
-agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
-agent.run(
- "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
-)
-```
-
-您将得到以下输出:
-```text
-======== New task ========
-Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?
-==== Agent is executing the code below:
-most_downloaded_model = model_download_tool(task="text-to-video")
-print(f"The most downloaded model for the 'text-to-video' task is {most_downloaded_model}.")
-====
-```
-
-输出:
-`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
-
-### 管理智能体的工具箱
-
-如果您已经初始化了一个智能体,但想添加一个新的工具,重新初始化智能体会很麻烦。借助 Transformers,您可以通过添加或替换工具来管理智能体的工具箱。
-
-让我们将 `model_download_tool` 添加到一个仅初始化了默认工具箱的现有智能体中。
-
-```python
-from transformers import CodeAgent
-
-agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
-agent.toolbox.add_tool(model_download_tool)
-```
-现在,我们可以同时使用新工具和之前的文本到语音工具:
-
-```python
-agent.run(
- "Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub and return the audio?"
-)
-```
-
-
-| **Audio** |
-|------------------------------------------------------------------------------------------------------------------------------------------------------|
-|
token (`str`, `optional`):
- Whether to use authentication token to load the remote folder. Useful to load private repositories
+ Whether to use authentication token to load the remote folder. Userful to load private repositories
that are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to
cache it.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
deleted file mode 100644
index 38701690bf7c2a..00000000000000
--- a/src/transformers/integrations/sdpa_attention.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from typing import Optional, Tuple
-
-import torch
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def sdpa_attention_forward(
- module: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- dropout: float = 0.0,
- scaling: Optional[float] = None,
- is_causal: Optional[bool] = None,
- **kwargs,
-) -> Tuple[torch.Tensor, None]:
- if hasattr(module, "num_key_value_groups"):
- key = repeat_kv(key, module.num_key_value_groups)
- value = repeat_kv(value, module.num_key_value_groups)
-
- causal_mask = attention_mask
- if attention_mask is not None:
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
- # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- query = query.contiguous()
- key = key.contiguous()
- value = value.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- if is_causal is None:
- is_causal = causal_mask is None and query.shape[2] > 1
-
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query,
- key,
- value,
- attn_mask=causal_mask,
- dropout_p=dropout,
- scale=scaling,
- is_causal=is_causal,
- )
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, None
diff --git a/src/transformers/integrations/vptq.py b/src/transformers/integrations/vptq.py
deleted file mode 100644
index aa435517e81ebe..00000000000000
--- a/src/transformers/integrations/vptq.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"VPTQ (Vector Post-Training Quantization) integration file"
-
-import torch.nn as nn
-from accelerate import init_empty_weights
-from vptq import VQuantLinear
-
-
-def replace_with_vptq_linear(
- model,
- quantization_config=None,
- modules_to_not_convert=None,
- current_key_name=None,
- has_been_replaced=False,
-):
- """
- Public method that recursively replaces the Linear layers of the given model with VPTQ quantized layers.
- `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
- conversion has been successfull or not.
-
- Args:
- model (`torch.nn.Module`):
- The model to convert, can be any `torch.nn.Module` instance.
- quantization_config (`VptqConfig`):
- The quantization config object that contains the quantization parameters.
- modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
- Names of the modules to not convert in `VQuantLinear`. In practice we keep the `lm_head` in full precision
- for numerical stability reasons.
- current_key_name (`list`, *optional*):
- A list that contains the current key name. This is used for recursion and should not be passed by the user.
- has_been_replaced (`bool`, *optional*):
- A boolean that indicates if the conversion has been successful or not. This is used for recursion and
- should not be passed by the user.
- """
-
- modules_to_not_convert = ["lm_head"] if not modules_to_not_convert else modules_to_not_convert
-
- for name, module in model.named_children():
- if current_key_name is None:
- current_key_name = []
- current_key_name.append(name)
- layer_name = ".".join(current_key_name)
- shared_layer_config = quantization_config.shared_layer_config
- config_for_layers = quantization_config.config_for_layers
-
- if (
- isinstance(module, nn.Linear)
- and layer_name not in modules_to_not_convert
- and ((layer_name in config_for_layers) or (current_key_name[-1] in shared_layer_config))
- ):
- layer_params = config_for_layers.get(layer_name, None) or shared_layer_config.get(
- current_key_name[-1], None
- )
-
- with init_empty_weights():
- in_features = module.in_features
- out_features = module.out_features
-
- model._modules[name] = VQuantLinear(
- in_features,
- out_features,
- vector_lens=layer_params["vector_lens"],
- num_centroids=layer_params["num_centroids"],
- num_res_centroids=layer_params["num_res_centroids"],
- group_num=layer_params["group_num"],
- group_size=layer_params["group_size"],
- outlier_size=layer_params["outlier_size"],
- indices_as_float=layer_params["indices_as_float"],
- enable_norm=layer_params["enable_norm"],
- enable_perm=layer_params["enable_perm"],
- is_indice_packed=True,
- enable_proxy_error=False,
- bias=module.bias is not None,
- )
- has_been_replaced = True
-
- # Force requires grad to False to avoid unexpected errors
- model._modules[name].requires_grad_(False)
- if len(list(module.children())) > 0:
- _, has_been_replaced = replace_with_vptq_linear(
- module,
- quantization_config=quantization_config,
- modules_to_not_convert=modules_to_not_convert,
- current_key_name=current_key_name,
- has_been_replaced=has_been_replaced,
- )
- # Remove the last key for recursion
- current_key_name.pop(-1)
- return model, has_been_replaced
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index 7f6aaaa44264ca..efa23d24e360b4 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -47,22 +47,6 @@ def ForCausalLMLoss(
return loss
-def ForMaskedLMLoss(
- logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
-):
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
-
- # Flatten the tokens
- logits = logits.view(-1, vocab_size)
- labels = labels.view(-1)
- # Enable model parallelism
-
- labels = labels.to(logits.device)
- loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
- return loss
-
-
def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
num_labels = config.num_labels
if config.problem_type is None:
@@ -117,7 +101,6 @@ def ForTokenClassification(logits, labels, config, **kwargs):
LOSS_MAPPING = {
"ForCausalLM": ForCausalLMLoss,
- "ForMaskedLM": ForMaskedLMLoss,
"ForQuestionAnswering": ForQuestionAnsweringLoss,
"ForSequenceClassification": ForSequenceClassificationLoss,
"ForTokenClassification": ForTokenClassification,
diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 09fc77e46b07ed..4319c021cb2bc3 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -169,10 +169,6 @@ def _make_causal_mask(
diagonal = past_key_values_length - sliding_window - 1
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
- # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
- # See https://github.com/pytorch/pytorch/issues/127571
- if is_torchdynamo_compiling():
- mask = mask.clone()
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 6adda0036cc096..ec03ba1eb5fd83 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -247,7 +247,6 @@ def _flash_attention_forward(
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
target_dtype: Optional[torch.dtype] = None,
- **kwargs,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -277,7 +276,7 @@ def _flash_attention_forward(
if not use_top_left_mask:
causal = is_causal
else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
causal = is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index 8fbba8a1651364..8bbd8587b683f4 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -63,6 +63,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
else:
try:
import torch # noqa: F401
+
+ from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
@@ -71,7 +73,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
)
raise
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
@@ -244,11 +246,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
import torch
+ from .pytorch_utils import is_torch_greater_or_equal_than_1_13
+
# Load the index
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 00c080fbea81c7..7562649be753bb 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -307,7 +307,7 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
ffn_norm_name = "ffn_norm"
qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name)
use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors)
- parsed_parameters["config"]["use_qkv_bias"] = qkv_bias
+ parsed_parameters["config"]["qkv_bias"] = qkv_bias
parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual
model_size = ""
diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index 8ec24d6e1872ef..7f1367481ade62 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -180,6 +180,8 @@ def load_pytorch_checkpoint_in_tf2_model(
import tensorflow as tf # noqa: F401
import torch # noqa: F401
from safetensors.torch import load_file as safe_load_file # noqa: F401
+
+ from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
@@ -199,7 +201,7 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
pt_state_dict.update(state_dict)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a6d4a1cc5b54ed..dae29111c8dcc0 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,7 +29,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
-from threading import Thread
+from multiprocessing import Process
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from zipfile import is_zipfile
@@ -45,15 +45,13 @@
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
-from .integrations.flash_attention import flash_attention_forward
-from .integrations.flex_attention import flex_attention_forward
-from .integrations.sdpa_attention import sdpa_attention_forward
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
id_tensor_storage,
+ is_torch_greater_or_equal_than_1_13,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
@@ -173,8 +171,10 @@ def is_local_dist_rank_0():
if is_peft_available():
from .utils import find_adapter_config_file
+
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
+
TORCH_INIT_FUNCTIONS = {
"uniform_": nn.init.uniform_,
"normal_": nn.init.normal_,
@@ -475,7 +475,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
for shard_file in shard_files:
@@ -503,7 +503,7 @@ def load_state_dict(
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
- if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+ if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
@@ -531,7 +531,7 @@ def load_state_dict(
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
- weights_only_kwarg = {"weights_only": weights_only}
+ weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
return torch.load(
checkpoint_file,
map_location=map_location,
@@ -652,6 +652,36 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ old_keys = []
+ new_keys = []
+ renamed_keys = {}
+ renamed_gamma = {}
+ renamed_beta = {}
+ warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
+ for key in state_dict.keys():
+ new_key = None
+ if "gamma" in key:
+ # We add only the first key as an example
+ new_key = key.replace("gamma", "weight")
+ renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
+ if "beta" in key:
+ # We add only the first key as an example
+ new_key = key.replace("beta", "bias")
+ renamed_beta[key] = new_key if not renamed_beta else renamed_beta
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ renamed_keys = {**renamed_gamma, **renamed_beta}
+ if renamed_keys:
+ warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+ for old_key, new_key in renamed_keys.items():
+ warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+ warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+ logger.info_once(warning_msg)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
+
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
@@ -782,7 +812,46 @@ def _load_state_dict_into_meta_model(
error_msgs = []
+ old_keys = []
+ new_keys = []
+ renamed_gamma = {}
+ renamed_beta = {}
is_quantized = hf_quantizer is not None
+ warning_msg = f"This model {type(model)}"
+ for key in state_dict.keys():
+ new_key = None
+ if "gamma" in key:
+ # We add only the first key as an example
+ new_key = key.replace("gamma", "weight")
+ renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
+ if "beta" in key:
+ # We add only the first key as an example
+ new_key = key.replace("beta", "bias")
+ renamed_beta[key] = new_key if not renamed_beta else renamed_beta
+
+ # To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary.
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ if "weight_g" in key:
+ new_key = key.replace("weight_g", "parametrizations.weight.original0")
+ if "weight_v" in key:
+ new_key = key.replace("weight_v", "parametrizations.weight.original1")
+ else:
+ if "parametrizations.weight.original0" in key:
+ new_key = key.replace("parametrizations.weight.original0", "weight_g")
+ if "parametrizations.weight.original1" in key:
+ new_key = key.replace("parametrizations.weight.original1", "weight_v")
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ renamed_keys = {**renamed_gamma, **renamed_beta}
+ if renamed_keys:
+ warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+ for old_key, new_key in renamed_keys.items():
+ warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+ warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+ logger.info_once(warning_msg)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
@@ -1473,8 +1542,11 @@ def _autoset_attn_implementation(
)
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
- "eager"
- ] + list(ALL_ATTENTION_FUNCTIONS.keys()):
+ "eager",
+ "sdpa",
+ "flash_attention_2",
+ "flex_attention",
+ ]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1536,8 +1608,6 @@ def _autoset_attn_implementation(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
- elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
- config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else:
@@ -2818,11 +2888,6 @@ def save_pretrained(
for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in state_dict.keys():
del state_dict[ignore_key]
-
- # Rename state_dict keys before saving to file. Do nothing unless overriden in a particular model.
- # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
- state_dict = self._fix_state_dict_keys_on_save(state_dict)
-
if safe_serialization:
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
@@ -3596,12 +3661,7 @@ def from_pretrained(
)
else:
config.quantization_config = quantization_config
-
- hf_quantizer = AutoHfQuantizer.from_config(
- config.quantization_config,
- pre_quantized=pre_quantized,
- )
-
+ hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
else:
hf_quantizer = None
@@ -3829,11 +3889,11 @@ def from_pretrained(
**has_file_kwargs,
}
if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
- Thread(
+ Process(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
- name="Thread-auto_conversion",
+ name="Process-auto_conversion",
).start()
else:
# Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
@@ -3950,10 +4010,7 @@ def from_pretrained(
with safe_open(resolved_archive_file, framework="pt") as f:
metadata = f.metadata()
- if metadata is None:
- # Assume it's a pytorch checkpoint (introduced for timm checkpoints)
- pass
- elif metadata.get("format") == "pt":
+ if metadata.get("format") == "pt":
pass
elif metadata.get("format") == "tf":
from_tf = True
@@ -4021,11 +4078,8 @@ def from_pretrained(
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = list(state_dict.keys())
- if (
- gguf_path is None
- and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()))
- and pretrained_model_name_or_path is not None
- ):
+
+ if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
# In case some weights need to be kept in float32 and accelerate is not installed,
# we later on want to take the path where state_dict is not None, that is the one
# that do not require accelerate.
@@ -4288,7 +4342,7 @@ def from_pretrained(
dispatch_model(model, **device_map_kwargs)
if hf_quantizer is not None:
- hf_quantizer.postprocess_model(model, config=config)
+ hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if _adapter_model_path is not None:
@@ -4321,72 +4375,6 @@ def from_pretrained(
return model
- @staticmethod
- def _fix_state_dict_key_on_load(key):
- """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
-
- if "beta" in key:
- return key.replace("beta", "bias")
- if "gamma" in key:
- return key.replace("gamma", "weight")
-
- # to avoid logging parametrized weight norm renaming
- if hasattr(nn.utils.parametrizations, "weight_norm"):
- if "weight_g" in key:
- return key.replace("weight_g", "parametrizations.weight.original0")
- if "weight_v" in key:
- return key.replace("weight_v", "parametrizations.weight.original1")
- else:
- if "parametrizations.weight.original0" in key:
- return key.replace("parametrizations.weight.original0", "weight_g")
- if "parametrizations.weight.original1" in key:
- return key.replace("parametrizations.weight.original1", "weight_v")
- return key
-
- @classmethod
- def _fix_state_dict_keys_on_load(cls, state_dict):
- """Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
- Logs if any parameters have been renamed.
- """
-
- renamed_keys = {}
- state_dict_keys = list(state_dict.keys())
- for key in state_dict_keys:
- new_key = cls._fix_state_dict_key_on_load(key)
- if new_key != key:
- state_dict[new_key] = state_dict.pop(key)
-
- # add it once for logging
- if "gamma" in key and "gamma" not in renamed_keys:
- renamed_keys["gamma"] = (key, new_key)
- if "beta" in key and "beta" not in renamed_keys:
- renamed_keys["beta"] = (key, new_key)
-
- if renamed_keys:
- warning_msg = f"A pretrained model of type `{cls.__name__}` "
- warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
- for old_key, new_key in renamed_keys.values():
- warning_msg += f"* `{old_key}` -> `{new_key}`\n"
- warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
- logger.info_once(warning_msg)
-
- return state_dict
-
- @staticmethod
- def _fix_state_dict_key_on_save(key):
- """
- Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
- Do nothing by default, but can be overriden in particular models.
- """
- return key
-
- def _fix_state_dict_keys_on_save(self, state_dict):
- """
- Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
- Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
- """
- return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
-
@classmethod
def _load_pretrained_model(
cls,
@@ -4442,8 +4430,27 @@ def _load_pretrained_model(
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
+ def _fix_key(key):
+ if "beta" in key:
+ return key.replace("beta", "bias")
+ if "gamma" in key:
+ return key.replace("gamma", "weight")
+
+ # to avoid logging parametrized weight norm renaming
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ if "weight_g" in key:
+ return key.replace("weight_g", "parametrizations.weight.original0")
+ if "weight_v" in key:
+ return key.replace("weight_v", "parametrizations.weight.original1")
+ else:
+ if "parametrizations.weight.original0" in key:
+ return key.replace("parametrizations.weight.original0", "weight_g")
+ if "parametrizations.weight.original1" in key:
+ return key.replace("parametrizations.weight.original1", "weight_v")
+ return key
+
original_loaded_keys = loaded_keys
- loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
+ loaded_keys = [_fix_key(key) for key in loaded_keys]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
@@ -4608,23 +4615,23 @@ def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
- original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
- for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys):
+ for checkpoint_key in loaded_keys:
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
+ model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
- model_key = f"{prefix}.{model_key}"
+ model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
- model_key = ".".join(model_key.split(".")[1:])
+ model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
@@ -4673,7 +4680,6 @@ def _find_mismatched_keys(
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
- loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
@@ -4681,11 +4687,10 @@ def _find_mismatched_keys(
)
# For GGUF models `state_dict` is never set to None as the state dict is always small
- if gguf_path or low_cpu_mem_usage:
- fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+ if gguf_path:
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
- fixed_state_dict,
+ state_dict,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4704,9 +4709,8 @@ def _find_mismatched_keys(
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
- fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs = _load_state_dict_into_model(
- model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+ model_to_load, state_dict, start_prefix, assign_to_params_buffers
)
else:
@@ -4757,7 +4761,6 @@ def _find_mismatched_keys(
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
- loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
@@ -4771,10 +4774,9 @@ def _find_mismatched_keys(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
- fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
- fixed_state_dict,
+ state_dict,
start_prefix,
expected_keys,
device_map=device_map,
@@ -4795,9 +4797,8 @@ def _find_mismatched_keys(
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
- fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs += _load_state_dict_into_model(
- model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers
+ model_to_load, state_dict, start_prefix, assign_to_params_buffers
)
# force memory release
@@ -4929,10 +4930,9 @@ def _load_pretrained_model_low_mem(
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
- fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
error_msgs = _load_state_dict_into_meta_model(
model,
- fixed_state_dict,
+ state_dict,
start_prefix,
expected_keys=expected_keys,
hf_quantizer=hf_quantizer,
@@ -5633,14 +5633,3 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
files_content[filename].append(device_map[weight_name])
return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
-
-
-ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {}
-
-ALL_ATTENTION_FUNCTIONS.update(
- {
- "flash_attention_2": flash_attention_forward,
- "flex_attention": flex_attention_forward,
- "sdpa": sdpa_attention_forward,
- }
-)
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index d06c680672dd6e..7f4085b5c8d0b8 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -20,7 +20,6 @@
audio_spectrogram_transformer,
auto,
autoformer,
- bamba,
bark,
bart,
barthez,
@@ -53,8 +52,6 @@
code_llama,
codegen,
cohere,
- cohere2,
- colpali,
conditional_detr,
convbert,
convnext,
@@ -168,7 +165,6 @@
mobilenet_v2,
mobilevit,
mobilevitv2,
- modernbert,
moshi,
mpnet,
mpt,
@@ -254,7 +250,6 @@
time_series_transformer,
timesformer,
timm_backbone,
- timm_wrapper,
trocr,
tvp,
udop,
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 6481d6f3c434c7..1b4e4087b1a49d 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -18,22 +18,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -430,6 +432,93 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return output + shared_expert_output
+class AriaTextRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[AriaTextConfig] = None,
+ ):
+ super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`AriaTextRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -476,75 +565,167 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class AriaTextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: AriaTextConfig, layer_idx: int):
+ def __init__(self, config: AriaTextConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = AriaTextRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class AriaTextFlashAttention2(AriaTextAttention):
+ """
+ AriaText flash attention module. This module inherits from `AriaTextAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -552,30 +733,168 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (AriaTextRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
- attn_output, attn_weights = attention_interface(
- self,
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
**kwargs,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class AriaTextSdpaAttention(AriaTextAttention):
+ """
+ AriaText attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `AriaTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from AriaTextAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "AriaTextModel is using AriaTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+ARIA_TEXT_ATTENTION_CLASSES = {
+ "eager": AriaTextAttention,
+ "flash_attention_2": AriaTextFlashAttention2,
+ "sdpa": AriaTextSdpaAttention,
+}
class AriaTextDecoderLayer(nn.Module):
@@ -595,7 +914,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx)
+ self.self_attn = ARIA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = AriaTextMoELayer(config)
self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -609,15 +928,37 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -637,9 +978,13 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
return outputs
@@ -722,71 +1067,6 @@ def _init_weights(self, module):
nn.init.trunc_normal_(module.query, std=std)
-class AriaTextRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: AriaTextConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
ARIA_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -886,6 +1166,8 @@ def __init__(self, config: AriaTextConfig):
self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = AriaTextRotaryEmbedding(config=config)
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -902,7 +1184,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -930,22 +1212,31 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
-
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -954,6 +1245,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -986,6 +1278,9 @@ def forward(
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -995,13 +1290,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
diff --git a/src/transformers/models/audio_spectrogram_transformer/__init__.py b/src/transformers/models/audio_spectrogram_transformer/__init__.py
index 3fe10d60c03a92..9f1d65e1aac839 100644
--- a/src/transformers/models/audio_spectrogram_transformer/__init__.py
+++ b/src/transformers/models/audio_spectrogram_transformer/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,47 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_audio_spectrogram_transformer": ["ASTConfig"],
+ "feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_audio_spectrogram_transformer"] = [
+ "ASTForAudioClassification",
+ "ASTModel",
+ "ASTPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_audio_spectrogram_transformer import *
- from .convert_audio_spectrogram_transformer_original_to_pytorch import *
- from .feature_extraction_audio_spectrogram_transformer import *
- from .modeling_audio_spectrogram_transformer import *
+ from .configuration_audio_spectrogram_transformer import (
+ ASTConfig,
+ )
+ from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_audio_spectrogram_transformer import (
+ ASTForAudioClassification,
+ ASTModel,
+ ASTPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
index 77bec930236f60..7980667a68d7c5 100644
--- a/src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
+++ b/src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
@@ -126,6 +126,3 @@ def __init__(
# generative parameters deprecation cycle, overwriting this function prevents this from happening.
def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
return {}
-
-
-__all__ = ["ASTConfig"]
diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
index b181afe19e9ef8..2bd122b4098c36 100644
--- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
+++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
@@ -234,6 +234,3 @@ def __call__(
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
-
-
-__all__ = ["ASTFeatureExtractor"]
diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
index a9fe0d75f5c380..491c6ce164611a 100644
--- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
+++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
@@ -670,6 +670,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = ["ASTForAudioClassification", "ASTModel", "ASTPreTrainedModel"]
diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py
index 1f626d8c24f42a..2ee0541a1a71b8 100644
--- a/src/transformers/models/auto/__init__.py
+++ b/src/transformers/models/auto/__init__.py
@@ -74,7 +74,6 @@
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
- "MODEL_FOR_RETRIEVAL_MAPPING",
"MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
@@ -253,7 +252,6 @@
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
- MODEL_FOR_RETRIEVAL_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index e6efbe80e4cc7b..fdd5b2e473d306 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -39,7 +39,6 @@
("aria_text", "AriaTextConfig"),
("audio-spectrogram-transformer", "ASTConfig"),
("autoformer", "AutoformerConfig"),
- ("bamba", "BambaConfig"),
("bark", "BarkConfig"),
("bart", "BartConfig"),
("beit", "BeitConfig"),
@@ -70,8 +69,6 @@
("code_llama", "LlamaConfig"),
("codegen", "CodeGenConfig"),
("cohere", "CohereConfig"),
- ("cohere2", "Cohere2Config"),
- ("colpali", "ColPaliConfig"),
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
@@ -188,7 +185,6 @@
("mobilenet_v2", "MobileNetV2Config"),
("mobilevit", "MobileViTConfig"),
("mobilevitv2", "MobileViTV2Config"),
- ("modernbert", "ModernBertConfig"),
("moshi", "MoshiConfig"),
("mpnet", "MPNetConfig"),
("mpt", "MptConfig"),
@@ -281,7 +277,6 @@
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
- ("timm_wrapper", "TimmWrapperConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
@@ -340,7 +335,6 @@
("aria_text", "AriaText"),
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
("autoformer", "Autoformer"),
- ("bamba", "Bamba"),
("bark", "Bark"),
("bart", "BART"),
("barthez", "BARThez"),
@@ -377,8 +371,6 @@
("code_llama", "CodeLlama"),
("codegen", "CodeGen"),
("cohere", "Cohere"),
- ("cohere2", "Cohere2"),
- ("colpali", "ColPali"),
("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),
@@ -420,7 +412,6 @@
("ernie_m", "ErnieM"),
("esm", "ESM"),
("falcon", "Falcon"),
- ("falcon3", "Falcon3"),
("falcon_mamba", "FalconMamba"),
("fastspeech2_conformer", "FastSpeech2Conformer"),
("flan-t5", "FLAN-T5"),
@@ -513,7 +504,6 @@
("mobilenet_v2", "MobileNetV2"),
("mobilevit", "MobileViT"),
("mobilevitv2", "MobileViTV2"),
- ("modernbert", "ModernBERT"),
("moshi", "Moshi"),
("mpnet", "MPNet"),
("mpt", "MPT"),
@@ -611,7 +601,6 @@
("time_series_transformer", "Time Series Transformer"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
- ("timm_wrapper", "TimmWrapperModel"),
("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index db25591eaa3544..a699314f858928 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -30,8 +30,6 @@
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
get_file_from_repo,
- is_timm_config_dict,
- is_timm_local_checkpoint,
is_torchvision_available,
is_vision_available,
logging,
@@ -139,7 +137,6 @@
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
("table-transformer", ("DetrImageProcessor",)),
("timesformer", ("VideoMAEImageProcessor",)),
- ("timm_wrapper", ("TimmWrapperImageProcessor",)),
("tvlt", ("TvltImageProcessor",)),
("tvp", ("TvpImageProcessor",)),
("udop", ("LayoutLMv3ImageProcessor",)),
@@ -175,7 +172,7 @@
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
-def get_image_processor_class_from_name(class_name: str):
+def image_processor_class_from_name(class_name: str):
if class_name == "BaseImageProcessorFast":
return BaseImageProcessorFast
@@ -368,7 +365,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
identifier allowed by git.
use_fast (`bool`, *optional*, defaults to `False`):
Use a fast torchvision-base image processor if it is supported for a given model.
- If a fast image processor is not available for a given model, a normal numpy-based image processor
+ If a fast tokenizer is not available for a given model, a normal numpy-based image processor
is returned instead.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
If `False`, then this function returns just the final image processor object. If `True`, then this
@@ -379,8 +376,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
execute code present on the Hub on your local machine.
- image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
- The name of the file in the model directory to use for the image processor config.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are image processor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
@@ -416,59 +411,28 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
kwargs["token"] = use_auth_token
config = kwargs.pop("config", None)
- # TODO: @yoni, change in v4.48 (use_fast set to True by default)
use_fast = kwargs.pop("use_fast", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
- # Resolve the image processor config filename
- if "image_processor_filename" in kwargs:
- image_processor_filename = kwargs.pop("image_processor_filename")
- elif is_timm_local_checkpoint(pretrained_model_name_or_path):
- image_processor_filename = CONFIG_NAME
- else:
- image_processor_filename = IMAGE_PROCESSOR_NAME
-
- # Load the image processor config
- try:
- # Main path for all transformers models and local TimmWrapper checkpoints
- config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
- pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
- )
- except Exception as initial_exception:
- # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
- # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
- # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
- # load `config.json` and if it fails with some error, we raise the initial exception.
- try:
- config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
- pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
- )
- except Exception:
- raise initial_exception
-
- # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
- # because only timm models have image processing in `config.json`.
- if not is_timm_config_dict(config_dict):
- raise initial_exception
-
- image_processor_type = config_dict.get("image_processor_type", None)
+ config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+ image_processor_class = config_dict.get("image_processor_type", None)
image_processor_auto_map = None
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
# If we still don't have the image processor class, check if we're loading from a previous feature extractor config
# and if so, infer the image processor class from there.
- if image_processor_type is None and image_processor_auto_map is None:
+ if image_processor_class is None and image_processor_auto_map is None:
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
if feature_extractor_class is not None:
- image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
+ image_processor_class = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
# If we don't find the image processor class in the image processor config, let's try the model config.
- if image_processor_type is None and image_processor_auto_map is None:
+ if image_processor_class is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
@@ -476,47 +440,18 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
**kwargs,
)
# It could be in `config.image_processor_type``
- image_processor_type = getattr(config, "image_processor_type", None)
+ image_processor_class = getattr(config, "image_processor_type", None)
if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
image_processor_auto_map = config.auto_map["AutoImageProcessor"]
- image_processor_class = None
- # TODO: @yoni, change logic in v4.48 (when use_fast set to True by default)
- if image_processor_type is not None:
- # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
- if use_fast is None:
- use_fast = image_processor_type.endswith("Fast")
- if not use_fast:
- logger.warning_once(
- "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
- "`use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. "
- "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
- )
- # Update class name to reflect the use_fast option. If class is not found, we fall back to the slow version.
- if use_fast and not is_torchvision_available():
- logger.warning_once(
- "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
- )
- use_fast = False
- if use_fast:
- if not image_processor_type.endswith("Fast"):
- image_processor_type += "Fast"
- for _, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
- if image_processor_type in image_processors:
- break
- else:
- image_processor_type = image_processor_type[:-4]
- use_fast = False
- logger.warning_once(
- "`use_fast` is set to `True` but the image processor class does not have a fast version. "
- " Falling back to the slow version."
- )
- image_processor_class = get_image_processor_class_from_name(image_processor_type)
- else:
- image_processor_type = (
- image_processor_type[:-4] if image_processor_type.endswith("Fast") else image_processor_type
- )
- image_processor_class = get_image_processor_class_from_name(image_processor_type)
+ if image_processor_class is not None:
+ # Update class name to reflect the use_fast option. If class is not found, None is returned.
+ if use_fast is not None:
+ if use_fast and not image_processor_class.endswith("Fast"):
+ image_processor_class += "Fast"
+ elif not use_fast and image_processor_class.endswith("Fast"):
+ image_processor_class = image_processor_class[:-4]
+ image_processor_class = image_processor_class_from_name(image_processor_class)
has_remote_code = image_processor_auto_map is not None
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 3930796acf2d35..9400cc58d81139 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -39,7 +39,6 @@
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("autoformer", "AutoformerModel"),
- ("bamba", "BambaModel"),
("bark", "BarkModel"),
("bart", "BartModel"),
("beit", "BeitModel"),
@@ -70,7 +69,6 @@
("code_llama", "LlamaModel"),
("codegen", "CodeGenModel"),
("cohere", "CohereModel"),
- ("cohere2", "Cohere2Model"),
("conditional_detr", "ConditionalDetrModel"),
("convbert", "ConvBertModel"),
("convnext", "ConvNextModel"),
@@ -177,7 +175,6 @@
("mobilenet_v2", "MobileNetV2Model"),
("mobilevit", "MobileViTModel"),
("mobilevitv2", "MobileViTV2Model"),
- ("modernbert", "ModernBertModel"),
("moshi", "MoshiModel"),
("mpnet", "MPNetModel"),
("mpt", "MptModel"),
@@ -259,7 +256,6 @@
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
- ("timm_wrapper", "TimmWrapperModel"),
("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"),
("tvlt", "TvltModel"),
@@ -309,7 +305,6 @@
("big_bird", "BigBirdForPreTraining"),
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
- ("colpali", "ColPaliForRetrieval"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
@@ -474,7 +469,6 @@
[
# Model for Causal LM mapping
("aria_text", "AriaTextForCausalLM"),
- ("bamba", "BambaForCausalLM"),
("bart", "BartForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
@@ -488,7 +482,6 @@
("code_llama", "LlamaForCausalLM"),
("codegen", "CodeGenForCausalLM"),
("cohere", "CohereForCausalLM"),
- ("cohere2", "Cohere2ForCausalLM"),
("cpmant", "CpmAntForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
@@ -614,7 +607,6 @@
("table-transformer", "TableTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
- ("timm_wrapper", "TimmWrapperModel"),
("van", "VanModel"),
("videomae", "VideoMAEModel"),
("vit", "ViTModel"),
@@ -700,7 +692,6 @@
("swiftformer", "SwiftFormerForImageClassification"),
("swin", "SwinForImageClassification"),
("swinv2", "Swinv2ForImageClassification"),
- ("timm_wrapper", "TimmWrapperForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
("vit_hybrid", "ViTHybridForImageClassification"),
@@ -781,12 +772,6 @@
]
)
-MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
- [
- ("colpali", "ColPaliForRetrieval"),
- ]
-)
-
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
[
("aria", "AriaForConditionalGeneration"),
@@ -841,7 +826,6 @@
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
- ("modernbert", "ModernBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("mra", "MraForMaskedLM"),
("mvp", "MvpForConditionalGeneration"),
@@ -997,7 +981,6 @@
("mistral", "MistralForSequenceClassification"),
("mixtral", "MixtralForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
- ("modernbert", "ModernBertForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("mpt", "MptForSequenceClassification"),
("mra", "MraForSequenceClassification"),
@@ -1186,7 +1169,6 @@
("mistral", "MistralForTokenClassification"),
("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
- ("modernbert", "ModernBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"),
@@ -1491,7 +1473,6 @@
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
)
-MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 815e2ca755bee3..3e475b1be211fa 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -58,7 +58,6 @@
("clip", "CLIPProcessor"),
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
- ("colpali", "ColPaliProcessor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
("git", "GitProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 2f6624057e0fa2..3006e89ff0673c 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -147,8 +147,6 @@
),
("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
- ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
- ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"cpm",
@@ -320,7 +318,6 @@
("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
- ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
diff --git a/src/transformers/models/bamba/__init__.py b/src/transformers/models/bamba/__init__.py
deleted file mode 100644
index c3920da849a333..00000000000000
--- a/src/transformers/models/bamba/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING
-
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
- from .configuration_bamba import *
- from .modeling_bamba import *
- from .processing_bamba import *
-else:
- import sys
-
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/bamba/configuration_bamba.py b/src/transformers/models/bamba/configuration_bamba.py
deleted file mode 100644
index f84d63ec04a9c7..00000000000000
--- a/src/transformers/models/bamba/configuration_bamba.py
+++ /dev/null
@@ -1,206 +0,0 @@
-# coding=utf-8
-# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Bamba model configuration"""
-
-from ...configuration_utils import PretrainedConfig
-from ...utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class BambaConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`BambaModel`]. It is used to instantiate a
- BambaModel model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with defaults taken from [ibm-fms/Bamba-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/Bamba-9.8b-2.2T-hf).
-
- The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
- The checkpoints are jointly trained by IBM, Princeton, and UIUC.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vocab_size (`int`, *optional*, defaults to 128000):
- Vocabulary size of the Bamba model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`BambaModel`]
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
- model has a output word embedding layer.
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 14336):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*, defaults to 8):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
- Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
- integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
- logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
- sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
- significantly.
- pad_token_id (`int`, *optional*, defaults to 0):
- The id of the padding token.
- bos_token_id (`int`, *optional*, defaults to 1):
- The id of the "beginning-of-sequence" token.
- eos_token_id (`int`, *optional*, defaults to 2):
- The id of the "end-of-sequence" token.
- max_position_embeddings (`int`, *optional*, defaults to 262144):
- Max cached sequence length for the model
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- attn_layer_indices (`list`, *optional*):
- Specifies the layer indices that will have full attention. Must contain values at most num_hidden_layers.
- mamba_n_heads (`int`, *optional*, defaults to 128):
- The number of mamba heads used in the v2 implementation.
- mamba_d_head (`int`, *optional*, defaults to `"auto"`):
- Head embeddding dimension size
- mamba_n_groups (`int`, *optional*, defaults to 1):
- The number of the mamba groups used in the v2 implementation.
- mamba_d_state (`int`, *optional*, defaults to 256):
- The dimension the mamba state space latents
- mamba_d_conv (`int`, *optional*, defaults to 4):
- The size of the mamba convolution kernel
- mamba_expand (`int`, *optional*, defaults to 2):
- Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
- mamba_chunk_size (`int`, *optional*, defaults to 256):
- The chunks in which to break the sequence when doing prefill/training
- mamba_conv_bias (`bool`, *optional*, defaults to `True`):
- Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
- mamba_proj_bias (`bool`, *optional*, defaults to `False`):
- Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
-
- """
-
- model_type = "bamba"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- vocab_size=128000,
- tie_word_embeddings=False,
- hidden_size=4096,
- intermediate_size=14336,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=8,
- hidden_act="silu",
- initializer_range=0.02,
- rms_norm_eps=1e-5,
- use_cache=True,
- num_logits_to_keep=1,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- max_position_embeddings=262144,
- attention_dropout=0.0,
- attn_layer_indices=None,
- mamba_n_heads=128,
- mamba_d_head="auto",
- mamba_n_groups=1,
- mamba_d_state=256,
- mamba_d_conv=4,
- mamba_expand=2,
- mamba_chunk_size=256,
- mamba_conv_bias=True,
- mamba_proj_bias=False,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.tie_word_embeddings = tie_word_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.max_position_embeddings = max_position_embeddings
- self.attention_dropout = attention_dropout
- self.attention_bias = False
- self.mlp_bias = False
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
-
- self.use_cache = use_cache
- self.num_logits_to_keep = num_logits_to_keep
-
- self.attn_layer_indices = attn_layer_indices
- self.rope_theta = 10000.0
- self.rope_scaling = None
- self.partial_rotary_factor = 0.5
-
- mamba_intermediate = mamba_expand * hidden_size
-
- if mamba_intermediate % mamba_n_heads != 0:
- raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
-
- # for the mamba_v2, must satisfy the following
- if mamba_d_head == "auto":
- mamba_d_head = mamba_intermediate // mamba_n_heads
-
- if mamba_d_head * mamba_n_heads != mamba_intermediate:
- raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
-
- self.mamba_n_heads = mamba_n_heads
- self.mamba_d_head = mamba_d_head
- self.mamba_n_groups = mamba_n_groups
- self.mamba_d_state = mamba_d_state
- self.mamba_d_conv = mamba_d_conv
- self.mamba_expand = mamba_expand
- self.mamba_chunk_size = mamba_chunk_size
- self.mamba_conv_bias = mamba_conv_bias
- self.mamba_proj_bias = mamba_proj_bias
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
- @property
- def layers_block_type(self):
- return [
- "attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba"
- for i in range(self.num_hidden_layers)
- ]
-
-
-__all__ = ["BambaConfig"]
diff --git a/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py b/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py
deleted file mode 100644
index a7b8cfc782907b..00000000000000
--- a/src/transformers/models/bamba/convert_mamba_ssm_checkpoint.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# coding=utf-8
-# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
-
-import argparse
-import json
-import os
-import re
-from os import path
-from typing import Dict, Union
-
-import torch
-from huggingface_hub import split_torch_state_dict_into_shards
-from safetensors.torch import save_file
-
-from transformers import AutoTokenizer
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
-
-from .configuration_bamba import BambaConfig
-
-
-def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
- state_dict = {}
-
- for orig_k, param in original_sd.items():
- k = orig_k.replace("backbone", "model")
-
- # for embeddings
- k = k.replace("embedding", "embed_tokens")
-
- # for mixer
- k = k.replace("mixer", "mamba")
-
- # for final layernorm
- k = k.replace("norm_f", "final_layernorm")
-
- # for block layernorm
- k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
- k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
-
- # for mlp
- k = k.replace("mlp.fc2", "feed_forward.down_proj")
-
- if "mlp.fc1" in k:
- param, param2 = torch.chunk(param, 2, dim=0)
- k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
- state_dict[k2] = param2
- k = k.replace("mlp.fc1", "feed_forward.up_proj")
-
- if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
- "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
- ):
- # then this must be a mamba
- pass
- else:
- # for attn
- # - because mixer was replaced to mamba above
- k = k.replace("mamba.out_proj", "self_attn.o_proj")
- if "mamba.in_proj" in k:
- m, n = param.shape
- d = (m - n) // 2
- param, param2, param3 = torch.split(param, [n, d, d], dim=0)
- k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
- state_dict[k2] = param2
- k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
- state_dict[k2] = param3
- k = k.replace("mamba.in_proj", "self_attn.q_proj")
-
- state_dict[k] = param
-
- return state_dict
-
-
-# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
-def convert_ssm_config_to_hf_config(
- config_ssm: Dict,
- **kwargs,
-) -> BambaConfig:
- """Convert a config from mamba_ssm to a BambaConfig from here."""
- hf_config: BambaConfig = BambaConfig(**kwargs)
-
- hf_config.architectures = ["BambaForCausalLM"]
-
- # Set important values from config and recalculate other resulting entries
- hf_config.hidden_size = config_ssm["d_model"]
- hf_config.intermediate_size = config_ssm["d_intermediate"]
- hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
- hf_config.num_hidden_layers = config_ssm["n_layer"]
- hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
-
- # currently this script assumes config_ssm belongs to v2
- if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
- raise ValueError("Conversion script only supports Mamba2")
-
- # Set attention values
- attn_cfg = config_ssm.get("attn_cfg")
- if attn_cfg:
- assert attn_cfg["causal"], "Only support non-causal attention."
- assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
- assert not attn_cfg["out_proj_bias"], "Only support no out bias."
- hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
- hf_config.num_attention_heads = attn_cfg["num_heads"]
- hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
-
- attention_layer_indices = config_ssm.get("attn_layer_idx")
- if attention_layer_indices:
- hf_config.attn_layer_indices = attention_layer_indices
-
- # Padded vocab size, mostly of 16 but 32 is also very common in different models
- vocab_size = config_ssm["vocab_size"]
- pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
- if (vocab_size % pad_vocab_size_multiple) != 0:
- vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
- hf_config.vocab_size = vocab_size
-
- return hf_config
-
-
-def save_single_safetensor(
- state_dict: Dict,
- save_directory: str,
- metadata: Dict,
-):
- save_file(
- state_dict,
- os.path.join(save_directory, SAFE_WEIGHTS_NAME),
- metadata,
- )
-
-
-def save_sharded_safetensors(
- state_dict: Dict,
- save_directory: str,
- metadata: Dict,
- max_shard_size: Union[int, str] = "5GB",
-):
- filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
- ".safetensors", "{suffix}.safetensors"
- )
- state_dict_split = split_torch_state_dict_into_shards(
- state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
- )
- index = {
- "metadata": state_dict_split.metadata,
- "weight_map": state_dict_split.tensor_to_filename,
- }
- # Save the index
- with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
- content = json.dumps(index, indent=2, sort_keys=True) + "\n"
- f.write(content)
-
- filename_to_tensors = state_dict_split.filename_to_tensors.items()
- for shard_file, tensors in filename_to_tensors:
- shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
- save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
-
-
-# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
-def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
- mamba_ssm_checkpoint_path: str,
- precision: str,
- output_dir: str,
- tokenizer_path: str = None,
- save_model: Union[bool, str] = True,
-) -> None:
- # load tokenizer if provided, this will be used to set the
- # token_ids in the config file
- token_ids = {}
- if tokenizer_path:
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
- for key in [
- "bos_token_id",
- "eos_token_id",
- "pad_token_id",
- ]:
- id = getattr(tokenizer, key, None)
- if id:
- token_ids[key] = id
-
- # there are some configs unsettable by mamba_ssn config, so
- # if there are changes from the defaults, have to pass them into
- # the function
- unsettables = {
- "mamba_d_head": 64,
- "mamba_d_state": 128,
- "mamba_n_groups": 1,
- "rms_norm_eps": 1e-5,
- }
-
- # Load and save config based on name
- config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
- with open(config_path, "r", encoding="utf-8") as json_file:
- config = json.load(json_file)
-
- # convert the config
- hf_config = convert_ssm_config_to_hf_config(
- config_ssm=config,
- **token_ids,
- **unsettables,
- )
- hf_config.save_pretrained(output_dir)
-
- # Load state dict of the original model and transfer to hf model
- state_dict = torch.load(
- path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
- map_location="cpu",
- weights_only=True,
- )
- # FIXME: allow other parameters to pass in
- state_dict = convert_state_dict_from_mamba_ssm(state_dict)
-
- # Save new model to pytorch_dump_path
- dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
-
- save_file_fn = None
- if isinstance(save_model, bool) and save_model:
- save_file_fn = save_single_safetensor
- elif isinstance(save_model, str) and save_model == "sharded":
- save_file_fn = save_sharded_safetensors
-
- if save_file_fn:
- save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-i",
- "--mamba_ssm_checkpoint_directory",
- type=str,
- required=True,
- help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
- )
- parser.add_argument(
- "-p",
- "--precision",
- type=str,
- default="fp16",
- const="fp16",
- required=True,
- choices=("fp32", "fp16", "bf16"),
- help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
- )
- parser.add_argument(
- "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
- )
- parser.add_argument(
- "-t",
- "--tokenizer_model_path",
- type=str,
- default=None,
- required=False,
- help="Path to a the tokenizer file.",
- )
- args = parser.parse_args()
-
- convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
- args.mamba2_checkpoint_directory,
- args.precision,
- args.output_dir,
- )
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
deleted file mode 100644
index c89d8d7853008d..00000000000000
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ /dev/null
@@ -1,1615 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/bamba/modular_bamba.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_bamba.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# coding=utf-8
-# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Callable, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-import transformers.models.jamba.modeling_jamba as modeling_jamba
-from transformers.activations import ACT2FN
-
-from ...cache_utils import Cache # we need __iter__ and __len__ of pkv
-from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from ...utils.import_utils import (
- is_causal_conv1d_available,
- is_mamba_2_ssm_available,
-)
-from .configuration_bamba import BambaConfig
-
-
-if is_mamba_2_ssm_available():
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
- from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
- selective_state_update = None
-
-if is_causal_conv1d_available():
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
- causal_conv1d_update, causal_conv1d_fn = None, None
-
-
-logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "BambaConfig"
-
-
-# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
-class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
- """
- A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
- (which has a constant shape regardless of seq_len).
-
- This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
- and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
- For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
- while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
- For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
- while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
- and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
- """
-
- def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
- super().__init__(config, batch_size, dtype, device)
- self.layers_block_type = config.layers_block_type
- self.has_previous_state = False # only used by mamba
- conv_kernel_size = config.mamba_d_conv
- ssm_state_size = config.mamba_d_state
-
- self.conv_states = []
- self.ssm_states = []
- self.transformer_layers = []
- for i in range(config.num_hidden_layers):
- if self.layers_block_type[i] == "mamba":
- self.conv_states += [
- torch.zeros(
- batch_size,
- (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size),
- conv_kernel_size,
- device=device,
- dtype=dtype,
- )
- ]
- self.ssm_states += [
- torch.zeros(
- batch_size,
- config.mamba_n_heads,
- config.mamba_d_head,
- ssm_state_size,
- device=device,
- dtype=dtype,
- )
- ]
- else:
- self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
- self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
- self.transformer_layers.append(i)
-
- self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
- self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
-
-
-class BambaRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: BambaConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
-# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Removes the interleaving of cos and sin from GLM
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
-
- # Keep half or full tensor for later concatenation
- rotary_dim = cos.shape[-1]
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
-
- # Apply rotary embeddings on the first half or full tensor
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
-
- # Concatenate back to full shape
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
- return q_embed, k_embed
-
-
-class BambaAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: BambaConfig, layer_idx: int):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
- self.attention_dropout = config.attention_dropout
- self.is_causal = True
-
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
-
-class BambaRMSNormGated(torch.nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states, gate=None):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
-
- if gate is not None:
- hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
- return self.weight * hidden_states.to(input_dtype)
-
-
-# Helper methods for segment sum computation
-
-
-def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
- """
- Padding x tensor with `pad_size` on the seq_len dim (dim=1)
-
- Assumes that we only have tensors of either size 4 or 3
- """
- pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
-
- return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
-
-
-def reshape_into_chunks(input_tensor, pad_size, chunk_size):
- """
- Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
- simultaneously splitting it into chunk sequences.
-
- Assumes that we only have tensors of either size 4 or 3
- """
- # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
- input_tensor = pad_tensor_by_size(input_tensor, pad_size)
-
- if len(input_tensor.shape) == 3:
- # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
- return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
- else:
- # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
- return input_tensor.reshape(
- input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
- )
-
-
-def segment_sum(input_tensor):
- """
- More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
- """
- chunk_size = input_tensor.size(-1)
- # 1. expand input tensor to have an additional dimension and repeat along that dimension
- # [..., chunk_size] -> [..., chunk_size, chunk_size]
- input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
- # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
- mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
- input_tensor = input_tensor.masked_fill(~mask, 0)
- # 3. compute actual cumsum
- tensor_segsum = torch.cumsum(input_tensor, dim=-2)
-
- # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
- mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
- tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
- return tensor_segsum
-
-
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
-def apply_mask_to_padding_states(hidden_states, attention_mask):
- """
- Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
- """
- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
- dtype = hidden_states.dtype
- hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
- return hidden_states
-
-
-# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
-class BambaMixer(nn.Module):
- """
- Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
- A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
- ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
- and is why Mamba is called **selective** state spaces)
-
- The are a few differences between this and Mamba2Mixer:
- - The variable use_precomputed_states is slightly different due to the HybridCache structure
- - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
- - Some extra variables that our layer doesn't need have been removed
- - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged
- """
-
- def __init__(self, config: BambaConfig, layer_idx: int):
- super().__init__()
- self.num_heads = config.mamba_n_heads
- self.hidden_size = config.hidden_size
- self.ssm_state_size = config.mamba_d_state
- self.conv_kernel_size = config.mamba_d_conv
- self.intermediate_size = int(config.mamba_expand * self.hidden_size)
- self.layer_idx = layer_idx
- self.use_conv_bias = config.mamba_conv_bias
- self.activation = config.hidden_act
- self.act = ACT2FN[config.hidden_act]
- self.use_bias = config.mamba_proj_bias
-
- self.layer_norm_epsilon = config.rms_norm_eps
-
- self.n_groups = config.mamba_n_groups
- self.head_dim = config.mamba_d_head
- self.chunk_size = config.mamba_chunk_size
-
- # FIXME:
- self.time_step_limit = (0.0, float("inf"))
- self.time_step_min = 0.001
- self.time_step_max = 0.1
-
- self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
- self.conv1d = nn.Conv1d(
- in_channels=self.conv_dim,
- out_channels=self.conv_dim,
- bias=config.mamba_conv_bias,
- kernel_size=self.conv_kernel_size,
- groups=self.conv_dim,
- padding=self.conv_kernel_size - 1,
- )
-
- # projection of the input hidden states
- projection_size = self.intermediate_size + self.conv_dim + self.num_heads
- self.in_proj = nn.Linear(
- self.hidden_size,
- projection_size,
- bias=self.use_bias,
- )
- # selective projection used to make dt, B and C input dependant
-
- # time step projection (discretization)
- # instantiate once and copy inv_dt in init_weights of PretrainedModel
- self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
-
- # S4D real initialization. These are not discretized!
- # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
- A = torch.arange(1, self.num_heads + 1)
- self.A_log = nn.Parameter(torch.log(A))
- self.A_log._no_weight_decay = True
- self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
- self.D = nn.Parameter(torch.ones(self.num_heads))
- self.D._no_weight_decay = True
-
- self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
-
- if not is_fast_path_available:
- logger.warning_once(
- "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
- " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
- " https://github.com/Dao-AILab/causal-conv1d"
- )
- else:
- logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")
-
- def cuda_kernels_forward(
- self,
- hidden_states: torch.Tensor,
- cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ):
- # 1. Gated MLP's linear projection
- hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
- projected_states = self.in_proj(hidden_states)
-
- # Set up dimensions for reshapes later
- batch_size, seq_len, _ = hidden_states.shape
- groups_time_state_size = self.n_groups * self.ssm_state_size
-
- use_precomputed_states = (
- cache_params is not None
- and cache_params.has_previous_state
- and seq_len == 1
- and cache_params.conv_states[self.layer_idx].shape[0]
- == cache_params.ssm_states[self.layer_idx].shape[0]
- == batch_size
- and cache_position is not None
- and cache_position[0] > 0
- )
-
- # getting projected states from cache if it exists
- if use_precomputed_states:
- gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
- [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
-
- # 2. Convolution sequence transformation
- hidden_states_B_C = causal_conv1d_update(
- hidden_states_B_C,
- cache_params.conv_states[self.layer_idx],
- self.conv1d.weight.squeeze(1),
- self.conv1d.bias,
- self.activation,
- )
-
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, groups_time_state_size, groups_time_state_size],
- dim=-1,
- )
-
- # 3. SSM transformation
- A = -torch.exp(self.A_log.float()) # (nheads,)
- A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
- dt = dt[:, :, None].expand(-1, -1, self.head_dim)
- dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
- D = self.D[:, None, ...].expand(-1, self.head_dim)
- B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
- C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
- hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
- hidden_states = selective_state_update(
- cache_params.ssm_states[self.layer_idx],
- hidden_states_reshaped,
- dt,
- A,
- B,
- C,
- D,
- z=None,
- dt_bias=dt_bias,
- dt_softplus=True,
- )
- hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
- hidden_states = self.norm(hidden_states, gate)
-
- # 4. Final linear projection
- out = self.out_proj(hidden_states)[:, None, ...]
- # Fused calculations or step by step if no initialized cache is found
- else:
- A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
- dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
-
- # 2-4. Fused kernel for conv1d, SSM, and the final projection
- if self.training and cache_params is None:
- out = mamba_split_conv1d_scan_combined(
- projected_states,
- self.conv1d.weight.squeeze(1),
- self.conv1d.bias,
- self.dt_bias,
- A,
- D=self.D,
- chunk_size=self.chunk_size,
- seq_idx=None, # was seq_idx
- activation=self.activation,
- rmsnorm_weight=self.norm.weight,
- rmsnorm_eps=self.norm.variance_epsilon,
- outproj_weight=self.out_proj.weight,
- outproj_bias=self.out_proj.bias,
- headdim=self.head_dim,
- ngroups=self.n_groups,
- norm_before_gate=False,
- return_final_states=False,
- **dt_limit_kwargs,
- )
-
- else:
- gate, hidden_states_B_C, dt = projected_states.split(
- [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
-
- # 2. Convolution sequence transformation
- # Init cache
- if cache_params is not None:
- # storing the states
- # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
- hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
- conv_states = nn.functional.pad(
- hidden_states_B_C_transposed,
- (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
- )
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
-
- if self.activation not in ["silu", "swish"]:
- hidden_states_B_C = self.act(
- self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
- )
- else:
- hidden_states_B_C = causal_conv1d_fn(
- x=hidden_states_B_C.transpose(1, 2),
- weight=self.conv1d.weight.squeeze(1),
- bias=self.conv1d.bias,
- activation=self.activation,
- ).transpose(1, 2)
-
- hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, groups_time_state_size, groups_time_state_size],
- dim=-1,
- )
-
- # 3. SSM transformation
- scan_output, ssm_state = mamba_chunk_scan_combined(
- hidden_states.view(batch_size, seq_len, -1, self.head_dim),
- dt,
- A,
- B.view(batch_size, seq_len, self.n_groups, -1),
- C.view(batch_size, seq_len, self.n_groups, -1),
- chunk_size=self.chunk_size,
- D=self.D,
- z=None,
- seq_idx=None,
- return_final_states=True,
- dt_bias=self.dt_bias,
- dt_softplus=True,
- **dt_limit_kwargs,
- )
-
- # Init cache
- if ssm_state is not None and cache_params is not None:
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
-
- scan_output = scan_output.view(batch_size, seq_len, -1)
- # Multiply "gate" branch and apply extra normalization layer
- scan_output = self.norm(scan_output, gate)
-
- # 4. Final linear projection
- out = self.out_proj(scan_output)
- return out
-
- # fmt: off
- def torch_forward(
- self,
- input_states,
- cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ):
- batch_size, seq_len, _ = input_states.shape
- dtype = input_states.dtype
-
- # 1. Gated MLP's linear projection
- input_states = apply_mask_to_padding_states(input_states, attention_mask)
- projected_states = self.in_proj(input_states)
- gate, hidden_states_B_C, dt = projected_states.split(
- [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
-
- use_precomputed_states = (
- cache_params is not None
- and cache_params.has_previous_state
- and seq_len == 1
- and cache_params.conv_states[self.layer_idx].shape[0]
- == cache_params.ssm_states[self.layer_idx].shape[0]
- == batch_size
- and cache_position is not None
- and cache_position[0] > 0
- )
-
- # 2. Convolution sequence transformation
- if use_precomputed_states:
- cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1)
- cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device)
-
- # We need to guarantee that anything regarding the cache is on the same device
- conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
-
- hidden_states_B_C = torch.sum(
- conv_states * self.conv1d.weight.squeeze(1), dim=-1
- )
- if self.use_conv_bias:
- hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
- hidden_states_B_C = self.act(hidden_states_B_C)
- else:
- # Init cache
- if cache_params is not None:
- hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
- conv_states = nn.functional.pad(
- hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
- )
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
-
- hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
-
- hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
- dim=-1
- )
-
- # 3. SSM transformation
- A = -torch.exp(self.A_log.float()) # [num_heads]
- if use_precomputed_states:
- # We need to guarantee that anything regarding the cache is on the same device
- cache_device = cache_params.ssm_states[self.layer_idx].device
-
- # Note: there is no need to pad parameter matrices here, as there is just one new token
- # for batched generation
- dt = dt[:, 0, :][:, None, ...]
- dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
- # [num_heads] -> [num_heads, head_dim]
- dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
-
- dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
- dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
- A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
- # [bsz, num_heads, head_dim, state_size]
- dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
-
- # Discretize B
- # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
- # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
- B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
- B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
- B = B.reshape(batch_size, -1, B.shape[-1])
- # [bsz, num_heads, head_dim, state_size]
- dB = dt[..., None] * B[..., None, :]
-
- # Discretize x into dB
- # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
- hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
- dBx = (dB * hidden_states[..., None]).to(device=cache_device)
-
- # State calculation
- cache_params.ssm_states[self.layer_idx].copy_(
- cache_params.ssm_states[self.layer_idx] * dA + dBx
- )
-
- # Subsequent output
- # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
- C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
- C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
- C = C.reshape(batch_size, -1, C.shape[-1])
- # [bsz, num_heads, head_dim]
-
- ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
- # Reshape ssm_states to merge the first two dimensions
- ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
- C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
- y = torch.bmm(ssm_states_reshaped, C_reshaped)
- y = y.view(batch_size, self.num_heads, self.head_dim)
-
- # D skip connection
- # [num_heads] -> [num_heads, head_dim]
- D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
- y = (y + hidden_states * D).to(y.dtype)
-
- # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
- y = y.reshape(batch_size, -1)[:, None, ...]
- else:
- # begin ssd naive implementation without einsums
- dt = nn.functional.softplus(dt + self.dt_bias)
- dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
- hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
- B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
- C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
- B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
- C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
- pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
-
- D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
-
- # Discretize x and A
- hidden_states = hidden_states * dt[..., None]
- A = A.to(hidden_states.dtype) * dt
-
- # Rearrange into blocks/chunks
- hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
-
- # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
- A = A.permute(0, 3, 1, 2)
- A_cumsum = torch.cumsum(A, dim=-1)
-
- # 1. Compute the output for each intra-chunk (diagonal blocks)
- # This is the analog of a causal mask
- L = torch.exp(segment_sum(A))
-
- # Contraction of C and B to get G (attention-weights like)
- G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
- G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
-
- # Compute M, equivalent to applying attention mask to weights
- M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
- M = M_intermediate.sum(dim=-1)
-
- # Compute Y_diag (apply to values)
- Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
-
- # 2. Compute the state for each intra-chunk
- # (right term of low-rank factorization of off-diagonal blocks; B terms)
- decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
- B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
- states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
-
- # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
- # (middle term of factorization of off-diag blocks; A terms)
- if use_precomputed_states:
- previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
- else:
- previous_states = torch.zeros_like(states[:, :1])
- states = torch.cat([previous_states, states], dim=1)
- decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
- decay_chunk = decay_chunk.transpose(1, 3)
- new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
- states, ssm_state = new_states[:, :-1], new_states[:, -1]
-
- # 4. Compute state -> output conversion per chunk
- # (left term of low-rank factorization of off-diagonal blocks; C terms)
- state_decay_out = torch.exp(A_cumsum)
- C_times_states = (C[..., None, :] * states[:, :, None, ...])
- state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
- Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
-
- # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
- y = Y_diag + Y_off
- # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
- y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
-
- y = y + D_residual
- # Cutting off padded chunks
- if pad_size > 0:
- y = y[:, :seq_len, :, :]
- y = y.reshape(batch_size, seq_len, -1)
-
- # Init cache
- if ssm_state is not None and cache_params is not None:
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
-
- scan_output = self.norm(y, gate)
-
- # end ssd naive
-
- # 4. Final linear projection
- contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
- return contextualized_states
- # fmt: on
-
- def forward(
- self,
- hidden_states,
- cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ):
- if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
- return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
- dtype = hidden_states.dtype
- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
- # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
- hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
- return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
-
-
-class BambaMLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
-
-
-class BambaRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- BambaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
-class BambaDecoderLayer(nn.Module):
- def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
- super().__init__()
-
- num_experts = 1
- ffn_layer_class = BambaMLP if num_experts == 1 else None
- self.feed_forward = ffn_layer_class(config)
- self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.layer_type = layer_type
- if layer_type == "mamba":
- self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
- elif layer_type == "attention":
- self.self_attn = BambaAttention(config, layer_idx)
- else:
- raise ValueError("Invalid layer_type")
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence.
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
- with `head_dim` being the embedding dimension of each attention head.
- kwargs (`dict`, *optional*):
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
- into the model
- """
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # this is a hybrid decoder layer
- if self.layer_type == "mamba":
- hidden_states = self.mamba(
- hidden_states=hidden_states,
- cache_params=past_key_value,
- cache_position=cache_position,
- attention_mask=attention_mask,
- )
- self_attn_weights = None
- elif self.layer_type == "attention":
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
-
- # residual connection after attention
- hidden_states = residual + hidden_states
-
- # feed-forward
- residual = hidden_states
- hidden_states = self.pre_ff_layernorm(hidden_states)
- hidden_states = self.feed_forward(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
-
-BAMBA_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`BambaConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare BambaModel outputting raw hidden-states without any specific head on top.",
- BAMBA_START_DOCSTRING,
-)
-class BambaPreTrainedModel(PreTrainedModel):
- config_class = BambaConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["BambaDecoderLayer"]
- _skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
- _is_stateful = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, (nn.Linear, nn.Conv1d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-BAMBA_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
- self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
- Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
- `(batch_size, d_inner, d_state)` respectively.
- See the `HybridMambaAttentionDynamicCache` class for more details.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- output_router_logits (`bool`, *optional*):
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
- should not be returned during inference.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
-
-
-@add_start_docstrings(
- "The bare Bamba Model outputting raw hidden-states without any specific head on top.",
- BAMBA_START_DOCSTRING,
-)
-# Adapted from transformers.models.jamba.modeling_jamba.JambaModel
-class BambaModel(BambaPreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambaDecoderLayer`]
-
- Args:
- config: BambaConfig
- """
-
- def __init__(self, config: BambaConfig):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- decoder_layers = []
- for i in range(config.num_hidden_layers):
- decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i]))
- self.layers = nn.ModuleList(decoder_layers)
-
- self._attn_implementation = config._attn_implementation
- self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = BambaRotaryEmbedding(config=config)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- hidden_states = inputs_embeds
-
- if use_cache and past_key_values is None:
- logger.warning_once(
- "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
- "provided, so no cache will be returned."
- )
-
- if cache_position is None:
- cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
- mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers:
- # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
- layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- layer_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=layer_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- if layer_outputs[1] is not None:
- # append attentions only of attention layers. Mamba layers return `None` as the attention weights
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.final_layernorm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if past_key_values and not past_key_values.has_previous_state:
- past_key_values.has_previous_state = True
-
- next_cache = None if not use_cache else past_key_values
-
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: HybridMambaAttentionDynamicCache,
- output_attentions: bool,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
-
- dtype, device = input_tensor.dtype, input_tensor.device
- sequence_length = input_tensor.shape[1]
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
-
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type == "cuda"
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
- return causal_mask
-
- @staticmethod
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to plcae the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[
- :, :, -sequence_length:, :
- ].to(dtype)
- padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
-
- return causal_mask
-
- def _update_mamba_mask(self, attention_mask, cache_position):
- """
- No need for zeroing states when
- 1. Cached forward
- 2. Attending to all inputs
- """
- mamba_mask = attention_mask
- if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
- mamba_mask = None
- return mamba_mask
-
-
-class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
-
- def __init__(self, config):
- super().__init__(config)
- self.model = BambaModel(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int` or `None`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
- `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
- can save memory, which becomes pretty significant for long sequences.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, BambaForCausalLM
-
- >>> model = BambaForCausalLM.from_pretrained("...")
- >>> tokenizer = AutoTokenizer.from_pretrained("...")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- **kwargs,
- )
-
- hidden_states = outputs[0]
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- cache_position=None,
- position_ids=None,
- use_cache=True,
- **kwargs,
- ):
- # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
-
- empty_past_kv = past_key_values is None
-
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- # Exception 1: when passing input_embeds, input_ids may be missing entries
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- if not empty_past_kv:
- if inputs_embeds is not None: # Exception 1
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
- input_ids = input_ids[:, cache_position]
- else:
- past_key_values = HybridMambaAttentionDynamicCache(
- self.config, input_ids.shape[0], self.dtype, device=self.device
- )
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if not empty_past_kv:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and empty_past_kv:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "attention_mask": attention_mask,
- "num_logits_to_keep": self.config.num_logits_to_keep,
- "cache_position": cache_position,
- }
- )
- return model_inputs
-
-
-__all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"]
diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py
deleted file mode 100644
index 7fb35f48fb3b76..00000000000000
--- a/src/transformers/models/bamba/modular_bamba.py
+++ /dev/null
@@ -1,1303 +0,0 @@
-# coding=utf-8
-# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Bamba model."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-import transformers.models.jamba.modeling_jamba as modeling_jamba
-from transformers.activations import ACT2FN
-from transformers.models.jamba.modeling_jamba import JambaAttentionDecoderLayer
-from transformers.models.llama.modeling_llama import (
- LlamaAttention,
- LlamaForCausalLM,
- LlamaMLP,
- LlamaRMSNorm,
- LlamaRotaryEmbedding,
- rotate_half,
-)
-from transformers.models.mamba2.modeling_mamba2 import (
- MambaRMSNormGated,
- pad_tensor_by_size,
- reshape_into_chunks,
- segment_sum,
-)
-
-from ...modeling_attn_mask_utils import (
- AttentionMaskConverter,
-)
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
-)
-from ...utils.import_utils import (
- is_causal_conv1d_available,
- is_flash_attn_2_available,
- is_mamba_2_ssm_available,
-)
-from .configuration_bamba import BambaConfig
-
-
-if is_flash_attn_2_available():
- pass
-
-if is_mamba_2_ssm_available():
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
- from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
-else:
- selective_state_update = None
-
-if is_causal_conv1d_available():
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
- causal_conv1d_update, causal_conv1d_fn = None, None
-
-is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
-
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "BambaConfig"
-
-
-# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
-class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
- """
- A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
- (which has a constant shape regardless of seq_len).
-
- This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
- and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
- For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
- while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
- For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
- while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
- and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
- """
-
- def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
- super().__init__(config, batch_size, dtype, device)
- self.layers_block_type = config.layers_block_type
- self.has_previous_state = False # only used by mamba
- conv_kernel_size = config.mamba_d_conv
- ssm_state_size = config.mamba_d_state
-
- self.conv_states = []
- self.ssm_states = []
- self.transformer_layers = []
- for i in range(config.num_hidden_layers):
- if self.layers_block_type[i] == "mamba":
- self.conv_states += [
- torch.zeros(
- batch_size,
- (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size),
- conv_kernel_size,
- device=device,
- dtype=dtype,
- )
- ]
- self.ssm_states += [
- torch.zeros(
- batch_size,
- config.mamba_n_heads,
- config.mamba_d_head,
- ssm_state_size,
- device=device,
- dtype=dtype,
- )
- ]
- else:
- self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
- self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
- self.transformer_layers.append(i)
-
- self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
- self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
-
-
-class BambaRotaryEmbedding(LlamaRotaryEmbedding):
- pass
-
-
-# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Removes the interleaving of cos and sin from GLM
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
-
- # Keep half or full tensor for later concatenation
- rotary_dim = cos.shape[-1]
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
-
- # Apply rotary embeddings on the first half or full tensor
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
-
- # Concatenate back to full shape
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
- return q_embed, k_embed
-
-
-class BambaAttention(LlamaAttention):
- pass
-
-
-class BambaRMSNormGated(MambaRMSNormGated):
- pass
-
-
-def apply_mask_to_padding_states(hidden_states, attention_mask):
- """
- Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
- """
- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
- dtype = hidden_states.dtype
- hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
- return hidden_states
-
-
-# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
-class BambaMixer(nn.Module):
- """
- Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
- A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
- ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
- and is why Mamba is called **selective** state spaces)
-
- The are a few differences between this and Mamba2Mixer:
- - The variable use_precomputed_states is slightly different due to the HybridCache structure
- - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
- - Some extra variables that our layer doesn't need have been removed
- - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged
- """
-
- def __init__(self, config: BambaConfig, layer_idx: int):
- super().__init__()
- self.num_heads = config.mamba_n_heads
- self.hidden_size = config.hidden_size
- self.ssm_state_size = config.mamba_d_state
- self.conv_kernel_size = config.mamba_d_conv
- self.intermediate_size = int(config.mamba_expand * self.hidden_size)
- self.layer_idx = layer_idx
- self.use_conv_bias = config.mamba_conv_bias
- self.activation = config.hidden_act
- self.act = ACT2FN[config.hidden_act]
- self.use_bias = config.mamba_proj_bias
-
- self.layer_norm_epsilon = config.rms_norm_eps
-
- self.n_groups = config.mamba_n_groups
- self.head_dim = config.mamba_d_head
- self.chunk_size = config.mamba_chunk_size
-
- # FIXME:
- self.time_step_limit = (0.0, float("inf"))
- self.time_step_min = 0.001
- self.time_step_max = 0.1
-
- self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
- self.conv1d = nn.Conv1d(
- in_channels=self.conv_dim,
- out_channels=self.conv_dim,
- bias=config.mamba_conv_bias,
- kernel_size=self.conv_kernel_size,
- groups=self.conv_dim,
- padding=self.conv_kernel_size - 1,
- )
-
- # projection of the input hidden states
- projection_size = self.intermediate_size + self.conv_dim + self.num_heads
- self.in_proj = nn.Linear(
- self.hidden_size,
- projection_size,
- bias=self.use_bias,
- )
- # selective projection used to make dt, B and C input dependant
-
- # time step projection (discretization)
- # instantiate once and copy inv_dt in init_weights of PretrainedModel
- self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
-
- # S4D real initialization. These are not discretized!
- # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
- A = torch.arange(1, self.num_heads + 1)
- self.A_log = nn.Parameter(torch.log(A))
- self.A_log._no_weight_decay = True
- self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
- self.D = nn.Parameter(torch.ones(self.num_heads))
- self.D._no_weight_decay = True
-
- self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
-
- if not is_fast_path_available:
- logger.warning_once(
- "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
- " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
- " https://github.com/Dao-AILab/causal-conv1d"
- )
- else:
- logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")
-
- def cuda_kernels_forward(
- self,
- hidden_states: torch.Tensor,
- cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ):
- # 1. Gated MLP's linear projection
- hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
- projected_states = self.in_proj(hidden_states)
-
- # Set up dimensions for reshapes later
- batch_size, seq_len, _ = hidden_states.shape
- groups_time_state_size = self.n_groups * self.ssm_state_size
-
- use_precomputed_states = (
- cache_params is not None
- and cache_params.has_previous_state
- and seq_len == 1
- and cache_params.conv_states[self.layer_idx].shape[0]
- == cache_params.ssm_states[self.layer_idx].shape[0]
- == batch_size
- and cache_position is not None
- and cache_position[0] > 0
- )
-
- # getting projected states from cache if it exists
- if use_precomputed_states:
- gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
- [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
-
- # 2. Convolution sequence transformation
- hidden_states_B_C = causal_conv1d_update(
- hidden_states_B_C,
- cache_params.conv_states[self.layer_idx],
- self.conv1d.weight.squeeze(1),
- self.conv1d.bias,
- self.activation,
- )
-
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, groups_time_state_size, groups_time_state_size],
- dim=-1,
- )
-
- # 3. SSM transformation
- A = -torch.exp(self.A_log.float()) # (nheads,)
- A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
- dt = dt[:, :, None].expand(-1, -1, self.head_dim)
- dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
- D = self.D[:, None, ...].expand(-1, self.head_dim)
- B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
- C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
- hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
- hidden_states = selective_state_update(
- cache_params.ssm_states[self.layer_idx],
- hidden_states_reshaped,
- dt,
- A,
- B,
- C,
- D,
- z=None,
- dt_bias=dt_bias,
- dt_softplus=True,
- )
- hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
- hidden_states = self.norm(hidden_states, gate)
-
- # 4. Final linear projection
- out = self.out_proj(hidden_states)[:, None, ...]
- # Fused calculations or step by step if no initialized cache is found
- else:
- A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
- dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
-
- # 2-4. Fused kernel for conv1d, SSM, and the final projection
- if self.training and cache_params is None:
- out = mamba_split_conv1d_scan_combined(
- projected_states,
- self.conv1d.weight.squeeze(1),
- self.conv1d.bias,
- self.dt_bias,
- A,
- D=self.D,
- chunk_size=self.chunk_size,
- seq_idx=None, # was seq_idx
- activation=self.activation,
- rmsnorm_weight=self.norm.weight,
- rmsnorm_eps=self.norm.variance_epsilon,
- outproj_weight=self.out_proj.weight,
- outproj_bias=self.out_proj.bias,
- headdim=self.head_dim,
- ngroups=self.n_groups,
- norm_before_gate=False,
- return_final_states=False,
- **dt_limit_kwargs,
- )
-
- else:
- gate, hidden_states_B_C, dt = projected_states.split(
- [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
-
- # 2. Convolution sequence transformation
- # Init cache
- if cache_params is not None:
- # storing the states
- # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
- # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
- hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
- conv_states = nn.functional.pad(
- hidden_states_B_C_transposed,
- (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
- )
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
-
- if self.activation not in ["silu", "swish"]:
- hidden_states_B_C = self.act(
- self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
- )
- else:
- hidden_states_B_C = causal_conv1d_fn(
- x=hidden_states_B_C.transpose(1, 2),
- weight=self.conv1d.weight.squeeze(1),
- bias=self.conv1d.bias,
- activation=self.activation,
- ).transpose(1, 2)
-
- hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, groups_time_state_size, groups_time_state_size],
- dim=-1,
- )
-
- # 3. SSM transformation
- scan_output, ssm_state = mamba_chunk_scan_combined(
- hidden_states.view(batch_size, seq_len, -1, self.head_dim),
- dt,
- A,
- B.view(batch_size, seq_len, self.n_groups, -1),
- C.view(batch_size, seq_len, self.n_groups, -1),
- chunk_size=self.chunk_size,
- D=self.D,
- z=None,
- seq_idx=None,
- return_final_states=True,
- dt_bias=self.dt_bias,
- dt_softplus=True,
- **dt_limit_kwargs,
- )
-
- # Init cache
- if ssm_state is not None and cache_params is not None:
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
-
- scan_output = scan_output.view(batch_size, seq_len, -1)
- # Multiply "gate" branch and apply extra normalization layer
- scan_output = self.norm(scan_output, gate)
-
- # 4. Final linear projection
- out = self.out_proj(scan_output)
- return out
-
- # fmt: off
- def torch_forward(
- self,
- input_states,
- cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ):
- batch_size, seq_len, _ = input_states.shape
- dtype = input_states.dtype
-
- # 1. Gated MLP's linear projection
- input_states = apply_mask_to_padding_states(input_states, attention_mask)
- projected_states = self.in_proj(input_states)
- gate, hidden_states_B_C, dt = projected_states.split(
- [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
-
- use_precomputed_states = (
- cache_params is not None
- and cache_params.has_previous_state
- and seq_len == 1
- and cache_params.conv_states[self.layer_idx].shape[0]
- == cache_params.ssm_states[self.layer_idx].shape[0]
- == batch_size
- and cache_position is not None
- and cache_position[0] > 0
- )
-
- # 2. Convolution sequence transformation
- if use_precomputed_states:
- cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1)
- cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device)
-
- # We need to guarantee that anything regarding the cache is on the same device
- conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
-
- hidden_states_B_C = torch.sum(
- conv_states * self.conv1d.weight.squeeze(1), dim=-1
- )
- if self.use_conv_bias:
- hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
- hidden_states_B_C = self.act(hidden_states_B_C)
- else:
- # Init cache
- if cache_params is not None:
- hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
- conv_states = nn.functional.pad(
- hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
- )
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
-
- hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
-
- hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
- dim=-1
- )
-
- # 3. SSM transformation
- A = -torch.exp(self.A_log.float()) # [num_heads]
- if use_precomputed_states:
- # We need to guarantee that anything regarding the cache is on the same device
- cache_device = cache_params.ssm_states[self.layer_idx].device
-
- # Note: there is no need to pad parameter matrices here, as there is just one new token
- # for batched generation
- dt = dt[:, 0, :][:, None, ...]
- dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
- # [num_heads] -> [num_heads, head_dim]
- dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
-
- dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
- dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
- A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
- # [bsz, num_heads, head_dim, state_size]
- dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
-
- # Discretize B
- # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
- # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
- B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
- B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
- B = B.reshape(batch_size, -1, B.shape[-1])
- # [bsz, num_heads, head_dim, state_size]
- dB = dt[..., None] * B[..., None, :]
-
- # Discretize x into dB
- # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
- hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
- dBx = (dB * hidden_states[..., None]).to(device=cache_device)
-
- # State calculation
- cache_params.ssm_states[self.layer_idx].copy_(
- cache_params.ssm_states[self.layer_idx] * dA + dBx
- )
-
- # Subsequent output
- # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
- C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
- C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
- C = C.reshape(batch_size, -1, C.shape[-1])
- # [bsz, num_heads, head_dim]
-
- ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
- # Reshape ssm_states to merge the first two dimensions
- ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
- C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
- y = torch.bmm(ssm_states_reshaped, C_reshaped)
- y = y.view(batch_size, self.num_heads, self.head_dim)
-
- # D skip connection
- # [num_heads] -> [num_heads, head_dim]
- D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
- y = (y + hidden_states * D).to(y.dtype)
-
- # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
- y = y.reshape(batch_size, -1)[:, None, ...]
- else:
- # begin ssd naive implementation without einsums
- dt = nn.functional.softplus(dt + self.dt_bias)
- dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
- hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
- B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
- C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
- B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
- C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
- pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
-
- D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
-
- # Discretize x and A
- hidden_states = hidden_states * dt[..., None]
- A = A.to(hidden_states.dtype) * dt
-
- # Rearrange into blocks/chunks
- hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
-
- # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
- A = A.permute(0, 3, 1, 2)
- A_cumsum = torch.cumsum(A, dim=-1)
-
- # 1. Compute the output for each intra-chunk (diagonal blocks)
- # This is the analog of a causal mask
- L = torch.exp(segment_sum(A))
-
- # Contraction of C and B to get G (attention-weights like)
- G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
- G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
-
- # Compute M, equivalent to applying attention mask to weights
- M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
- M = M_intermediate.sum(dim=-1)
-
- # Compute Y_diag (apply to values)
- Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
-
- # 2. Compute the state for each intra-chunk
- # (right term of low-rank factorization of off-diagonal blocks; B terms)
- decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
- B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
- states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
-
- # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
- # (middle term of factorization of off-diag blocks; A terms)
- if use_precomputed_states:
- previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
- else:
- previous_states = torch.zeros_like(states[:, :1])
- states = torch.cat([previous_states, states], dim=1)
- decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
- decay_chunk = decay_chunk.transpose(1, 3)
- new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
- states, ssm_state = new_states[:, :-1], new_states[:, -1]
-
- # 4. Compute state -> output conversion per chunk
- # (left term of low-rank factorization of off-diagonal blocks; C terms)
- state_decay_out = torch.exp(A_cumsum)
- C_times_states = (C[..., None, :] * states[:, :, None, ...])
- state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
- Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
-
- # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
- y = Y_diag + Y_off
- # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
- y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
-
- y = y + D_residual
- # Cutting off padded chunks
- if pad_size > 0:
- y = y[:, :seq_len, :, :]
- y = y.reshape(batch_size, seq_len, -1)
-
- # Init cache
- if ssm_state is not None and cache_params is not None:
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
-
- scan_output = self.norm(y, gate)
-
- # end ssd naive
-
- # 4. Final linear projection
- contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
- return contextualized_states
- # fmt: on
-
- def forward(
- self,
- hidden_states,
- cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ):
- if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
- return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
- dtype = hidden_states.dtype
- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
- # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
- hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
- return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
-
-
-class BambaMLP(LlamaMLP):
- pass
-
-
-class BambaRMSNorm(LlamaRMSNorm):
- pass
-
-
-class BambaDecoderLayer(JambaAttentionDecoderLayer):
- def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
- super().__init__()
-
- del self.self_attn
-
- num_experts = 1
- ffn_layer_class = BambaMLP if num_experts == 1 else None
- self.feed_forward = ffn_layer_class(config)
-
- self.layer_type = layer_type
- if layer_type == "mamba":
- self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
- elif layer_type == "attention":
- self.self_attn = BambaAttention(config, layer_idx)
- else:
- raise ValueError("Invalid layer_type")
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence.
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
- with `head_dim` being the embedding dimension of each attention head.
- kwargs (`dict`, *optional*):
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
- into the model
- """
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # this is a hybrid decoder layer
- if self.layer_type == "mamba":
- hidden_states = self.mamba(
- hidden_states=hidden_states,
- cache_params=past_key_value,
- cache_position=cache_position,
- attention_mask=attention_mask,
- )
- self_attn_weights = None
- elif self.layer_type == "attention":
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
-
- # residual connection after attention
- hidden_states = residual + hidden_states
-
- # feed-forward
- residual = hidden_states
- hidden_states = self.pre_ff_layernorm(hidden_states)
- hidden_states = self.feed_forward(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
-
-BAMBA_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`BambaConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare BambaModel outputting raw hidden-states without any specific head on top.",
- BAMBA_START_DOCSTRING,
-)
-class BambaPreTrainedModel(PreTrainedModel):
- config_class = BambaConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["BambaDecoderLayer"]
- _skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
- _is_stateful = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, (nn.Linear, nn.Conv1d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-BAMBA_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
- self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
- Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
- `(batch_size, d_inner, d_state)` respectively.
- See the `HybridMambaAttentionDynamicCache` class for more details.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- output_router_logits (`bool`, *optional*):
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
- should not be returned during inference.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
-
-
-@add_start_docstrings(
- "The bare Bamba Model outputting raw hidden-states without any specific head on top.",
- BAMBA_START_DOCSTRING,
-)
-# Adapted from transformers.models.jamba.modeling_jamba.JambaModel
-class BambaModel(BambaPreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BambaDecoderLayer`]
-
- Args:
- config: BambaConfig
- """
-
- def __init__(self, config: BambaConfig):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- decoder_layers = []
- for i in range(config.num_hidden_layers):
- decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i]))
- self.layers = nn.ModuleList(decoder_layers)
-
- self._attn_implementation = config._attn_implementation
- self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = BambaRotaryEmbedding(config=config)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- hidden_states = inputs_embeds
-
- if use_cache and past_key_values is None:
- logger.warning_once(
- "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
- "provided, so no cache will be returned."
- )
-
- if cache_position is None:
- cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
- mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers:
- # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
- layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- layer_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=layer_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- if layer_outputs[1] is not None:
- # append attentions only of attention layers. Mamba layers return `None` as the attention weights
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.final_layernorm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if past_key_values and not past_key_values.has_previous_state:
- past_key_values.has_previous_state = True
-
- next_cache = None if not use_cache else past_key_values
-
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: HybridMambaAttentionDynamicCache,
- output_attentions: bool,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
-
- dtype, device = input_tensor.dtype, input_tensor.device
- sequence_length = input_tensor.shape[1]
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
-
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type == "cuda"
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
- return causal_mask
-
- @staticmethod
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to plcae the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[
- :, :, -sequence_length:, :
- ].to(dtype)
- padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
-
- return causal_mask
-
- def _update_mamba_mask(self, attention_mask, cache_position):
- """
- No need for zeroing states when
- 1. Cached forward
- 2. Attending to all inputs
- """
- mamba_mask = attention_mask
- if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
- mamba_mask = None
- return mamba_mask
-
-
-class BambaForCausalLM(LlamaForCausalLM):
- @add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int` or `None`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
- `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
- can save memory, which becomes pretty significant for long sequences.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, BambaForCausalLM
-
- >>> model = BambaForCausalLM.from_pretrained("...")
- >>> tokenizer = AutoTokenizer.from_pretrained("...")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- return super().forward(
- input_ids,
- attention_mask,
- position_ids,
- past_key_values,
- inputs_embeds,
- labels,
- use_cache,
- output_attentions,
- output_hidden_states,
- return_dict,
- cache_position,
- num_logits_to_keep,
- **kwargs,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- cache_position=None,
- position_ids=None,
- use_cache=True,
- **kwargs,
- ):
- # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
-
- empty_past_kv = past_key_values is None
-
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- # Exception 1: when passing input_embeds, input_ids may be missing entries
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- if not empty_past_kv:
- if inputs_embeds is not None: # Exception 1
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
- input_ids = input_ids[:, cache_position]
- else:
- past_key_values = HybridMambaAttentionDynamicCache(
- self.config, input_ids.shape[0], self.dtype, device=self.device
- )
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if not empty_past_kv:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and empty_past_kv:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "attention_mask": attention_mask,
- "num_logits_to_keep": self.config.num_logits_to_keep,
- "cache_position": cache_position,
- }
- )
- return model_inputs
-
-
-__all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"]
diff --git a/src/transformers/models/bark/__init__.py b/src/transformers/models/bark/__init__.py
index 6c21cf99976a15..4cb1a606cf6567 100644
--- a/src/transformers/models/bark/__init__.py
+++ b/src/transformers/models/bark/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,63 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+_import_structure = {
+ "configuration_bark": [
+ "BarkCoarseConfig",
+ "BarkConfig",
+ "BarkFineConfig",
+ "BarkSemanticConfig",
+ ],
+ "processing_bark": ["BarkProcessor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bark"] = [
+ "BarkFineModel",
+ "BarkSemanticModel",
+ "BarkCoarseModel",
+ "BarkModel",
+ "BarkPreTrainedModel",
+ "BarkCausalModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_bark import *
- from .convert_suno_to_hf import *
- from .generation_configuration_bark import *
- from .modeling_bark import *
- from .processing_bark import *
+ from .configuration_bark import (
+ BarkCoarseConfig,
+ BarkConfig,
+ BarkFineConfig,
+ BarkSemanticConfig,
+ )
+ from .processing_bark import BarkProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bark import (
+ BarkCausalModel,
+ BarkCoarseModel,
+ BarkFineModel,
+ BarkModel,
+ BarkPreTrainedModel,
+ BarkSemanticModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bark/configuration_bark.py b/src/transformers/models/bark/configuration_bark.py
index 932bad618aa187..a498d1dd19371d 100644
--- a/src/transformers/models/bark/configuration_bark.py
+++ b/src/transformers/models/bark/configuration_bark.py
@@ -298,6 +298,3 @@ def from_sub_model_configs(
codec_config=codec_config.to_dict(),
**kwargs,
)
-
-
-__all__ = ["BarkCoarseConfig", "BarkConfig", "BarkFineConfig", "BarkSemanticConfig"]
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 36a278263b558a..f1c77367e5beb7 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -197,6 +197,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1818,13 +1819,3 @@ def _check_and_enable_flash_attn_2(
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config
-
-
-__all__ = [
- "BarkFineModel",
- "BarkSemanticModel",
- "BarkCoarseModel",
- "BarkModel",
- "BarkPreTrainedModel",
- "BarkCausalModel",
-]
diff --git a/src/transformers/models/bark/processing_bark.py b/src/transformers/models/bark/processing_bark.py
index 0bed6ca79f410b..53715f3260422c 100644
--- a/src/transformers/models/bark/processing_bark.py
+++ b/src/transformers/models/bark/processing_bark.py
@@ -285,6 +285,3 @@ def __call__(
encoded_text["history_prompt"] = voice_preset
return encoded_text
-
-
-__all__ = ["BarkProcessor"]
diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py
index 11c3f4863f46a1..d538fbb7d34304 100644
--- a/src/transformers/models/bart/__init__.py
+++ b/src/transformers/models/bart/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,20 +13,134 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+_import_structure = {
+ "configuration_bart": ["BartConfig", "BartOnnxConfig"],
+ "tokenization_bart": ["BartTokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bart_fast"] = ["BartTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bart"] = [
+ "BartForCausalLM",
+ "BartForConditionalGeneration",
+ "BartForQuestionAnswering",
+ "BartForSequenceClassification",
+ "BartModel",
+ "BartPreTrainedModel",
+ "BartPretrainedModel",
+ "PretrainedBartModel",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_bart"] = [
+ "TFBartForConditionalGeneration",
+ "TFBartForSequenceClassification",
+ "TFBartModel",
+ "TFBartPretrainedModel",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_bart"] = [
+ "FlaxBartDecoderPreTrainedModel",
+ "FlaxBartForCausalLM",
+ "FlaxBartForConditionalGeneration",
+ "FlaxBartForQuestionAnswering",
+ "FlaxBartForSequenceClassification",
+ "FlaxBartModel",
+ "FlaxBartPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_bart import *
- from .convert_bart_original_pytorch_checkpoint_to_pytorch import *
- from .modeling_bart import *
- from .modeling_flax_bart import *
- from .modeling_tf_bart import *
- from .tokenization_bart import *
- from .tokenization_bart_fast import *
+ from .configuration_bart import BartConfig, BartOnnxConfig
+ from .tokenization_bart import BartTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bart_fast import BartTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bart import (
+ BartForCausalLM,
+ BartForConditionalGeneration,
+ BartForQuestionAnswering,
+ BartForSequenceClassification,
+ BartModel,
+ BartPreTrainedModel,
+ BartPretrainedModel,
+ PretrainedBartModel,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_bart import (
+ TFBartForConditionalGeneration,
+ TFBartForSequenceClassification,
+ TFBartModel,
+ TFBartPretrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_bart import (
+ FlaxBartDecoderPreTrainedModel,
+ FlaxBartForCausalLM,
+ FlaxBartForConditionalGeneration,
+ FlaxBartForQuestionAnswering,
+ FlaxBartForSequenceClassification,
+ FlaxBartModel,
+ FlaxBartPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py
index 4ce4316e3c0315..a3bc7f38653a8a 100644
--- a/src/transformers/models/bart/configuration_bart.py
+++ b/src/transformers/models/bart/configuration_bart.py
@@ -400,6 +400,3 @@ def _flatten_past_key_values_(self, flattened_output, name, idx, t):
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
-
-
-__all__ = ["BartConfig", "BartOnnxConfig"]
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 4e1f0b389d42ea..07c1fa622ea3b6 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -294,6 +294,7 @@ class BartFlashAttention2(BartAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -2157,15 +2158,3 @@ def _reorder_cache(past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
-
-
-__all__ = [
- "BartForCausalLM",
- "BartForConditionalGeneration",
- "BartForQuestionAnswering",
- "BartForSequenceClassification",
- "BartModel",
- "BartPreTrainedModel",
- "BartPretrainedModel",
- "PretrainedBartModel",
-]
diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py
index b346eaa39fc199..634c256fe7d81d 100644
--- a/src/transformers/models/bart/modeling_flax_bart.py
+++ b/src/transformers/models/bart/modeling_flax_bart.py
@@ -1993,14 +1993,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
-
-
-__all__ = [
- "FlaxBartDecoderPreTrainedModel",
- "FlaxBartForCausalLM",
- "FlaxBartForConditionalGeneration",
- "FlaxBartForQuestionAnswering",
- "FlaxBartForSequenceClassification",
- "FlaxBartModel",
- "FlaxBartPreTrainedModel",
-]
diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py
index 7ab9817986e6ad..5ebde8cba60c45 100644
--- a/src/transformers/models/bart/modeling_tf_bart.py
+++ b/src/transformers/models/bart/modeling_tf_bart.py
@@ -1709,6 +1709,3 @@ def build(self, input_shape=None):
if getattr(self, "classification_head", None) is not None:
with tf.name_scope(self.classification_head.name):
self.classification_head.build(None)
-
-
-__all__ = ["TFBartForConditionalGeneration", "TFBartForSequenceClassification", "TFBartModel", "TFBartPretrainedModel"]
diff --git a/src/transformers/models/bart/tokenization_bart.py b/src/transformers/models/bart/tokenization_bart.py
index 4c516cb81be0d2..5207b9c92b07ff 100644
--- a/src/transformers/models/bart/tokenization_bart.py
+++ b/src/transformers/models/bart/tokenization_bart.py
@@ -388,6 +388,3 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
text = " " + text
return (text, kwargs)
-
-
-__all__ = ["BartTokenizer"]
diff --git a/src/transformers/models/bart/tokenization_bart_fast.py b/src/transformers/models/bart/tokenization_bart_fast.py
index 4586ab4797e5ec..e9fb8497c907b9 100644
--- a/src/transformers/models/bart/tokenization_bart_fast.py
+++ b/src/transformers/models/bart/tokenization_bart_fast.py
@@ -274,6 +274,3 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
-
-
-__all__ = ["BartTokenizerFast"]
diff --git a/src/transformers/models/barthez/__init__.py b/src/transformers/models/barthez/__init__.py
index 323fe2fe8af982..084cd22bdf1d88 100644
--- a/src/transformers/models/barthez/__init__.py
+++ b/src/transformers/models/barthez/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,17 +11,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_tokenizers_available
+
+
+_import_structure = {}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_barthez"] = ["BarthezTokenizer"]
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_barthez_fast"] = ["BarthezTokenizerFast"]
if TYPE_CHECKING:
- from .tokenization_barthez import *
- from .tokenization_barthez_fast import *
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_barthez import BarthezTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_barthez_fast import BarthezTokenizerFast
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/barthez/tokenization_barthez.py b/src/transformers/models/barthez/tokenization_barthez.py
index 604f9c7c21519a..46decddb3e10ba 100644
--- a/src/transformers/models/barthez/tokenization_barthez.py
+++ b/src/transformers/models/barthez/tokenization_barthez.py
@@ -284,6 +284,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
fi.write(content_spiece_model)
return (out_vocab_file,)
-
-
-__all__ = ["BarthezTokenizer"]
diff --git a/src/transformers/models/barthez/tokenization_barthez_fast.py b/src/transformers/models/barthez/tokenization_barthez_fast.py
index a1d95ef03e4882..df8cc7757e96c0 100644
--- a/src/transformers/models/barthez/tokenization_barthez_fast.py
+++ b/src/transformers/models/barthez/tokenization_barthez_fast.py
@@ -192,6 +192,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
-
-
-__all__ = ["BarthezTokenizerFast"]
diff --git a/src/transformers/models/bartpho/__init__.py b/src/transformers/models/bartpho/__init__.py
index 597be95d8175ca..c20d7370c6566c 100644
--- a/src/transformers/models/bartpho/__init__.py
+++ b/src/transformers/models/bartpho/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,16 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available
+
+_import_structure = {}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bartpho"] = ["BartphoTokenizer"]
if TYPE_CHECKING:
- from .tokenization_bartpho import *
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bartpho import BartphoTokenizer
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bartpho/tokenization_bartpho.py b/src/transformers/models/bartpho/tokenization_bartpho.py
index e6e4f889842e8f..df121f26e255f4 100644
--- a/src/transformers/models/bartpho/tokenization_bartpho.py
+++ b/src/transformers/models/bartpho/tokenization_bartpho.py
@@ -311,6 +311,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
fp.write(f"{str(token)} \n")
return out_vocab_file, out_monolingual_vocab_file
-
-
-__all__ = ["BartphoTokenizer"]
diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py
index 0fc8919c7ea19a..c2f49240d6e64c 100644
--- a/src/transformers/models/beit/__init__.py
+++ b/src/transformers/models/beit/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,21 +11,100 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {"configuration_beit": ["BeitConfig", "BeitOnnxConfig"]}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_beit"] = ["BeitFeatureExtractor"]
+ _import_structure["image_processing_beit"] = ["BeitImageProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_beit"] = [
+ "BeitForImageClassification",
+ "BeitForMaskedImageModeling",
+ "BeitForSemanticSegmentation",
+ "BeitModel",
+ "BeitPreTrainedModel",
+ "BeitBackbone",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_beit"] = [
+ "FlaxBeitForImageClassification",
+ "FlaxBeitForMaskedImageModeling",
+ "FlaxBeitModel",
+ "FlaxBeitPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_beit import *
- from .convert_beit_unilm_to_pytorch import *
- from .feature_extraction_beit import *
- from .image_processing_beit import *
- from .modeling_beit import *
- from .modeling_flax_beit import *
+ from .configuration_beit import BeitConfig, BeitOnnxConfig
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_beit import BeitFeatureExtractor
+ from .image_processing_beit import BeitImageProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_beit import (
+ BeitBackbone,
+ BeitForImageClassification,
+ BeitForMaskedImageModeling,
+ BeitForSemanticSegmentation,
+ BeitModel,
+ BeitPreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_beit import (
+ FlaxBeitForImageClassification,
+ FlaxBeitForMaskedImageModeling,
+ FlaxBeitModel,
+ FlaxBeitPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py
index 834988258c6b75..f0f3c2582c35cc 100644
--- a/src/transformers/models/beit/configuration_beit.py
+++ b/src/transformers/models/beit/configuration_beit.py
@@ -224,6 +224,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
@property
def atol_for_validation(self) -> float:
return 1e-4
-
-
-__all__ = ["BeitConfig", "BeitOnnxConfig"]
diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py
index 141d8bc36d2bbb..59dacb4ae51f6e 100644
--- a/src/transformers/models/beit/feature_extraction_beit.py
+++ b/src/transformers/models/beit/feature_extraction_beit.py
@@ -31,6 +31,3 @@ def __init__(self, *args, **kwargs) -> None:
FutureWarning,
)
super().__init__(*args, **kwargs)
-
-
-__all__ = ["BeitFeatureExtractor"]
diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py
index af76dd2e9656cb..7398381b2229bf 100644
--- a/src/transformers/models/beit/image_processing_beit.py
+++ b/src/transformers/models/beit/image_processing_beit.py
@@ -510,6 +510,3 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple]
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
-
-
-__all__ = ["BeitImageProcessor"]
diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index 601e2801d67587..f972e021f3e2b3 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -361,68 +361,6 @@ def forward(
return outputs
-class BeitSdpaSelfAttention(BeitSelfAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- relative_position_bias: Optional["BeitRelativePositionBias"] = None,
- interpolate_pos_encoding: bool = False,
- resolution: Optional[Tuple[int]] = None,
- ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
- if output_attentions or head_mask is not None:
- logger.warning_once(
- "`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
- "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
- "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states=hidden_states,
- head_mask=head_mask,
- output_attentions=output_attentions,
- relative_position_bias=relative_position_bias,
- interpolate_pos_encoding=interpolate_pos_encoding,
- resolution=resolution,
- )
-
- mixed_query_layer = self.query(hidden_states)
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- query_layer = self.transpose_for_scores(mixed_query_layer)
-
- attn_bias = None
- if self.relative_position_bias is not None:
- height, width = resolution
- window_size = (height // self.config.patch_size, width // self.config.patch_size)
- attn_bias = self.relative_position_bias(
- window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
- )
-
- # Add shared relative position bias if provided.
- if relative_position_bias is not None:
- if attn_bias is None:
- attn_bias = relative_position_bias
- else:
- attn_bias += relative_position_bias
-
- scaling = 1 / math.sqrt(self.attention_head_size)
- context_layer = torch.nn.functional.scaled_dot_product_attention(
- query_layer,
- key_layer,
- value_layer,
- attn_mask=attn_bias,
- dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
- is_causal=False,
- scale=scaling,
- )
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- return context_layer, None
-
-
class BeitSelfOutput(nn.Module):
"""
The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
@@ -441,16 +379,10 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma
return hidden_states
-BEIT_SELF_ATTENTION_CLASSES = {
- "eager": BeitSelfAttention,
- "sdpa": BeitSdpaSelfAttention,
-}
-
-
class BeitAttention(nn.Module):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
- self.attention = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, window_size=window_size)
+ self.attention = BeitSelfAttention(config, window_size=window_size)
self.output = BeitSelfOutput(config)
self.pruned_heads = set()
@@ -768,7 +700,6 @@ class BeitPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["BeitLayer"]
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
- _supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -1645,13 +1576,3 @@ def forward(
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "BeitForImageClassification",
- "BeitForMaskedImageModeling",
- "BeitForSemanticSegmentation",
- "BeitModel",
- "BeitPreTrainedModel",
- "BeitBackbone",
-]
diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py
index 2d79c1820088a1..c1da64d263a266 100644
--- a/src/transformers/models/beit/modeling_flax_beit.py
+++ b/src/transformers/models/beit/modeling_flax_beit.py
@@ -946,11 +946,3 @@ class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel):
append_replace_return_docstrings(
FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig
)
-
-
-__all__ = [
- "FlaxBeitForImageClassification",
- "FlaxBeitForMaskedImageModeling",
- "FlaxBeitModel",
- "FlaxBeitPreTrainedModel",
-]
diff --git a/src/transformers/models/bert/__init__.py b/src/transformers/models/bert/__init__.py
index 3ed12a889321e6..17048a5d1c967a 100644
--- a/src/transformers/models/bert/__init__.py
+++ b/src/transformers/models/bert/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,26 +11,183 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tensorflow_text_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_bert": ["BertConfig", "BertOnnxConfig"],
+ "tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bert_fast"] = ["BertTokenizerFast"]
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bert"] = [
+ "BertForMaskedLM",
+ "BertForMultipleChoice",
+ "BertForNextSentencePrediction",
+ "BertForPreTraining",
+ "BertForQuestionAnswering",
+ "BertForSequenceClassification",
+ "BertForTokenClassification",
+ "BertLayer",
+ "BertLMHeadModel",
+ "BertModel",
+ "BertPreTrainedModel",
+ "load_tf_weights_in_bert",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_bert"] = [
+ "TFBertEmbeddings",
+ "TFBertForMaskedLM",
+ "TFBertForMultipleChoice",
+ "TFBertForNextSentencePrediction",
+ "TFBertForPreTraining",
+ "TFBertForQuestionAnswering",
+ "TFBertForSequenceClassification",
+ "TFBertForTokenClassification",
+ "TFBertLMHeadModel",
+ "TFBertMainLayer",
+ "TFBertModel",
+ "TFBertPreTrainedModel",
+ ]
+try:
+ if not is_tensorflow_text_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bert_tf"] = ["TFBertTokenizer"]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_bert"] = [
+ "FlaxBertForCausalLM",
+ "FlaxBertForMaskedLM",
+ "FlaxBertForMultipleChoice",
+ "FlaxBertForNextSentencePrediction",
+ "FlaxBertForPreTraining",
+ "FlaxBertForQuestionAnswering",
+ "FlaxBertForSequenceClassification",
+ "FlaxBertForTokenClassification",
+ "FlaxBertModel",
+ "FlaxBertPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_bert import *
- from .convert_bert_original_tf2_checkpoint_to_pytorch import *
- from .convert_bert_original_tf_checkpoint_to_pytorch import *
- from .convert_bert_pytorch_checkpoint_to_original_tf import *
- from .convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch import *
- from .modeling_bert import *
- from .modeling_flax_bert import *
- from .modeling_tf_bert import *
- from .tokenization_bert import *
- from .tokenization_bert_fast import *
- from .tokenization_bert_tf import *
+ from .configuration_bert import BertConfig, BertOnnxConfig
+ from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bert_fast import BertTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bert import (
+ BertForMaskedLM,
+ BertForMultipleChoice,
+ BertForNextSentencePrediction,
+ BertForPreTraining,
+ BertForQuestionAnswering,
+ BertForSequenceClassification,
+ BertForTokenClassification,
+ BertLayer,
+ BertLMHeadModel,
+ BertModel,
+ BertPreTrainedModel,
+ load_tf_weights_in_bert,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_bert import (
+ TFBertEmbeddings,
+ TFBertForMaskedLM,
+ TFBertForMultipleChoice,
+ TFBertForNextSentencePrediction,
+ TFBertForPreTraining,
+ TFBertForQuestionAnswering,
+ TFBertForSequenceClassification,
+ TFBertForTokenClassification,
+ TFBertLMHeadModel,
+ TFBertMainLayer,
+ TFBertModel,
+ TFBertPreTrainedModel,
+ )
+
+ try:
+ if not is_tensorflow_text_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bert_tf import TFBertTokenizer
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_bert import (
+ FlaxBertForCausalLM,
+ FlaxBertForMaskedLM,
+ FlaxBertForMultipleChoice,
+ FlaxBertForNextSentencePrediction,
+ FlaxBertForPreTraining,
+ FlaxBertForQuestionAnswering,
+ FlaxBertForSequenceClassification,
+ FlaxBertForTokenClassification,
+ FlaxBertModel,
+ FlaxBertPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py
index ea29fb81c435aa..613cf6a11463c2 100644
--- a/src/transformers/models/bert/configuration_bert.py
+++ b/src/transformers/models/bert/configuration_bert.py
@@ -149,6 +149,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("token_type_ids", dynamic_axis),
]
)
-
-
-__all__ = ["BertConfig", "BertOnnxConfig"]
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 0c53963cee7922..6b05fa648158a6 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1325,7 +1325,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- **loss_kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1376,7 +1375,11 @@ def forward(
lm_loss = None
if labels is not None:
- lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -1991,19 +1994,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "BertForMaskedLM",
- "BertForMultipleChoice",
- "BertForNextSentencePrediction",
- "BertForPreTraining",
- "BertForQuestionAnswering",
- "BertForSequenceClassification",
- "BertForTokenClassification",
- "BertLayer",
- "BertLMHeadModel",
- "BertModel",
- "BertPreTrainedModel",
- "load_tf_weights_in_bert",
-]
diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py
index 83358c86bd280d..772ea2bf12b2ee 100644
--- a/src/transformers/models/bert/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -1711,17 +1711,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
-
-
-__all__ = [
- "FlaxBertForCausalLM",
- "FlaxBertForMaskedLM",
- "FlaxBertForMultipleChoice",
- "FlaxBertForNextSentencePrediction",
- "FlaxBertForPreTraining",
- "FlaxBertForQuestionAnswering",
- "FlaxBertForSequenceClassification",
- "FlaxBertForTokenClassification",
- "FlaxBertModel",
- "FlaxBertPreTrainedModel",
-]
diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py
index ce862194dc7787..bb3281278adaa1 100644
--- a/src/transformers/models/bert/modeling_tf_bert.py
+++ b/src/transformers/models/bert/modeling_tf_bert.py
@@ -2108,19 +2108,3 @@ def build(self, input_shape=None):
if getattr(self, "qa_outputs", None) is not None:
with tf.name_scope(self.qa_outputs.name):
self.qa_outputs.build([None, None, self.config.hidden_size])
-
-
-__all__ = [
- "TFBertEmbeddings",
- "TFBertForMaskedLM",
- "TFBertForMultipleChoice",
- "TFBertForNextSentencePrediction",
- "TFBertForPreTraining",
- "TFBertForQuestionAnswering",
- "TFBertForSequenceClassification",
- "TFBertForTokenClassification",
- "TFBertLMHeadModel",
- "TFBertMainLayer",
- "TFBertModel",
- "TFBertPreTrainedModel",
-]
diff --git a/src/transformers/models/bert/tokenization_bert.py b/src/transformers/models/bert/tokenization_bert.py
index 42d4dd94554d41..07583b949661de 100644
--- a/src/transformers/models/bert/tokenization_bert.py
+++ b/src/transformers/models/bert/tokenization_bert.py
@@ -502,6 +502,3 @@ def tokenize(self, text):
else:
output_tokens.extend(sub_tokens)
return output_tokens
-
-
-__all__ = ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"]
diff --git a/src/transformers/models/bert/tokenization_bert_fast.py b/src/transformers/models/bert/tokenization_bert_fast.py
index 4a89e6053b988f..f4897772847029 100644
--- a/src/transformers/models/bert/tokenization_bert_fast.py
+++ b/src/transformers/models/bert/tokenization_bert_fast.py
@@ -170,6 +170,3 @@ def create_token_type_ids_from_sequences(
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
-
-
-__all__ = ["BertTokenizerFast"]
diff --git a/src/transformers/models/bert/tokenization_bert_tf.py b/src/transformers/models/bert/tokenization_bert_tf.py
index b1f49722fbdffa..ebf88eeac9bbe8 100644
--- a/src/transformers/models/bert/tokenization_bert_tf.py
+++ b/src/transformers/models/bert/tokenization_bert_tf.py
@@ -252,6 +252,3 @@ def get_config(self):
"sep_token_id": self.sep_token_id,
"pad_token_id": self.pad_token_id,
}
-
-
-__all__ = ["TFBertTokenizer"]
diff --git a/src/transformers/models/bert_generation/__init__.py b/src/transformers/models/bert_generation/__init__.py
index 3f83b1f6e5bba3..14cf8bb5879320 100644
--- a/src/transformers/models/bert_generation/__init__.py
+++ b/src/transformers/models/bert_generation/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,18 +11,61 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available
+
+
+_import_structure = {"configuration_bert_generation": ["BertGenerationConfig"]}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bert_generation"] = ["BertGenerationTokenizer"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bert_generation"] = [
+ "BertGenerationDecoder",
+ "BertGenerationEncoder",
+ "BertGenerationPreTrainedModel",
+ "load_tf_weights_in_bert_generation",
+ ]
if TYPE_CHECKING:
- from .configuration_bert_generation import *
- from .modeling_bert_generation import *
- from .tokenization_bert_generation import *
+ from .configuration_bert_generation import BertGenerationConfig
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bert_generation import BertGenerationTokenizer
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bert_generation import (
+ BertGenerationDecoder,
+ BertGenerationEncoder,
+ BertGenerationPreTrainedModel,
+ load_tf_weights_in_bert_generation,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bert_generation/configuration_bert_generation.py b/src/transformers/models/bert_generation/configuration_bert_generation.py
index 1abe7c1a1c44ab..d1d1b51b6538e2 100644
--- a/src/transformers/models/bert_generation/configuration_bert_generation.py
+++ b/src/transformers/models/bert_generation/configuration_bert_generation.py
@@ -122,6 +122,3 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
-
-
-__all__ = ["BertGenerationConfig"]
diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index aaf326aa2de8eb..800ea2bef1d631 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -996,11 +996,3 @@ def _reorder_cache(self, past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
-
-
-__all__ = [
- "BertGenerationDecoder",
- "BertGenerationEncoder",
- "BertGenerationPreTrainedModel",
- "load_tf_weights_in_bert_generation",
-]
diff --git a/src/transformers/models/bert_generation/tokenization_bert_generation.py b/src/transformers/models/bert_generation/tokenization_bert_generation.py
index 31f046863c289c..b1adb9b62b2551 100644
--- a/src/transformers/models/bert_generation/tokenization_bert_generation.py
+++ b/src/transformers/models/bert_generation/tokenization_bert_generation.py
@@ -170,6 +170,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
fi.write(content_spiece_model)
return (out_vocab_file,)
-
-
-__all__ = ["BertGenerationTokenizer"]
diff --git a/src/transformers/models/bert_japanese/__init__.py b/src/transformers/models/bert_japanese/__init__.py
index f5296087db1d00..a569c3cc54bff8 100644
--- a/src/transformers/models/bert_japanese/__init__.py
+++ b/src/transformers/models/bert_japanese/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+
+
+_import_structure = {"tokenization_bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"]}
if TYPE_CHECKING:
- from .tokenization_bert_japanese import *
+ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
index 8a841a3091623d..732e9e7aff5741 100644
--- a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
+++ b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
@@ -977,6 +977,3 @@ def tokenize(self, text):
new_pieces.append(piece)
return new_pieces
-
-
-__all__ = ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"]
diff --git a/src/transformers/models/bertweet/__init__.py b/src/transformers/models/bertweet/__init__.py
index 432622f1595d1a..42e4a23337c20c 100644
--- a/src/transformers/models/bertweet/__init__.py
+++ b/src/transformers/models/bertweet/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+
+
+_import_structure = {"tokenization_bertweet": ["BertweetTokenizer"]}
if TYPE_CHECKING:
- from .tokenization_bertweet import *
+ from .tokenization_bertweet import BertweetTokenizer
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bertweet/tokenization_bertweet.py b/src/transformers/models/bertweet/tokenization_bertweet.py
index 499238e5955fe0..f478dd0832b6e4 100644
--- a/src/transformers/models/bertweet/tokenization_bertweet.py
+++ b/src/transformers/models/bertweet/tokenization_bertweet.py
@@ -764,6 +764,3 @@ def casual_tokenize(text, preserve_case=True, reduce_len=False, strip_handles=Fa
###############################################################################
-
-
-__all__ = ["BertweetTokenizer"]
diff --git a/src/transformers/models/big_bird/__init__.py b/src/transformers/models/big_bird/__init__.py
index b89712ab5ab49f..8eda33d9ee6608 100644
--- a/src/transformers/models/big_bird/__init__.py
+++ b/src/transformers/models/big_bird/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,19 +13,133 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_sentencepiece_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+_import_structure = {
+ "configuration_big_bird": ["BigBirdConfig", "BigBirdOnnxConfig"],
+}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_big_bird"] = ["BigBirdTokenizer"]
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_big_bird_fast"] = ["BigBirdTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_big_bird"] = [
+ "BigBirdForCausalLM",
+ "BigBirdForMaskedLM",
+ "BigBirdForMultipleChoice",
+ "BigBirdForPreTraining",
+ "BigBirdForQuestionAnswering",
+ "BigBirdForSequenceClassification",
+ "BigBirdForTokenClassification",
+ "BigBirdLayer",
+ "BigBirdModel",
+ "BigBirdPreTrainedModel",
+ "load_tf_weights_in_big_bird",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_big_bird"] = [
+ "FlaxBigBirdForCausalLM",
+ "FlaxBigBirdForMaskedLM",
+ "FlaxBigBirdForMultipleChoice",
+ "FlaxBigBirdForPreTraining",
+ "FlaxBigBirdForQuestionAnswering",
+ "FlaxBigBirdForSequenceClassification",
+ "FlaxBigBirdForTokenClassification",
+ "FlaxBigBirdModel",
+ "FlaxBigBirdPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_big_bird import *
- from .convert_bigbird_original_tf_checkpoint_to_pytorch import *
- from .modeling_big_bird import *
- from .modeling_flax_big_bird import *
- from .tokenization_big_bird import *
- from .tokenization_big_bird_fast import *
+ from .configuration_big_bird import BigBirdConfig, BigBirdOnnxConfig
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_big_bird import BigBirdTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_big_bird_fast import BigBirdTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_big_bird import (
+ BigBirdForCausalLM,
+ BigBirdForMaskedLM,
+ BigBirdForMultipleChoice,
+ BigBirdForPreTraining,
+ BigBirdForQuestionAnswering,
+ BigBirdForSequenceClassification,
+ BigBirdForTokenClassification,
+ BigBirdLayer,
+ BigBirdModel,
+ BigBirdPreTrainedModel,
+ load_tf_weights_in_big_bird,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_big_bird import (
+ FlaxBigBirdForCausalLM,
+ FlaxBigBirdForMaskedLM,
+ FlaxBigBirdForMultipleChoice,
+ FlaxBigBirdForPreTraining,
+ FlaxBigBirdForQuestionAnswering,
+ FlaxBigBirdForSequenceClassification,
+ FlaxBigBirdForTokenClassification,
+ FlaxBigBirdModel,
+ FlaxBigBirdPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py
index 1019e008aa3b38..cbcf2e6bf57fd7 100644
--- a/src/transformers/models/big_bird/configuration_big_bird.py
+++ b/src/transformers/models/big_bird/configuration_big_bird.py
@@ -171,6 +171,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("attention_mask", dynamic_axis),
]
)
-
-
-__all__ = ["BigBirdConfig", "BigBirdOnnxConfig"]
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 47c78284b7f29c..958d192fa03dbc 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -3126,18 +3126,3 @@ def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int):
mask.unsqueeze_(0) # -> (1, maxlen)
mask = torch.where(mask < q_lengths, 1, 0)
return mask
-
-
-__all__ = [
- "BigBirdForCausalLM",
- "BigBirdForMaskedLM",
- "BigBirdForMultipleChoice",
- "BigBirdForPreTraining",
- "BigBirdForQuestionAnswering",
- "BigBirdForSequenceClassification",
- "BigBirdForTokenClassification",
- "BigBirdLayer",
- "BigBirdModel",
- "BigBirdPreTrainedModel",
- "load_tf_weights_in_big_bird",
-]
diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py
index 8d23180a8348cd..94eabdec451dda 100644
--- a/src/transformers/models/big_bird/modeling_flax_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -2633,16 +2633,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
-
-
-__all__ = [
- "FlaxBigBirdForCausalLM",
- "FlaxBigBirdForMaskedLM",
- "FlaxBigBirdForMultipleChoice",
- "FlaxBigBirdForPreTraining",
- "FlaxBigBirdForQuestionAnswering",
- "FlaxBigBirdForSequenceClassification",
- "FlaxBigBirdForTokenClassification",
- "FlaxBigBirdModel",
- "FlaxBigBirdPreTrainedModel",
-]
diff --git a/src/transformers/models/big_bird/tokenization_big_bird.py b/src/transformers/models/big_bird/tokenization_big_bird.py
index 194cbc68cb56ba..e435477ef3c6b4 100644
--- a/src/transformers/models/big_bird/tokenization_big_bird.py
+++ b/src/transformers/models/big_bird/tokenization_big_bird.py
@@ -319,6 +319,3 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
-
-
-__all__ = ["BigBirdTokenizer"]
diff --git a/src/transformers/models/big_bird/tokenization_big_bird_fast.py b/src/transformers/models/big_bird/tokenization_big_bird_fast.py
index 83f2fac07fae72..f4ccbb8b1797f9 100644
--- a/src/transformers/models/big_bird/tokenization_big_bird_fast.py
+++ b/src/transformers/models/big_bird/tokenization_big_bird_fast.py
@@ -227,6 +227,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
-
-
-__all__ = ["BigBirdTokenizerFast"]
diff --git a/src/transformers/models/bigbird_pegasus/__init__.py b/src/transformers/models/bigbird_pegasus/__init__.py
index 8684d999d85cb4..85621ce76d902b 100644
--- a/src/transformers/models/bigbird_pegasus/__init__.py
+++ b/src/transformers/models/bigbird_pegasus/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +13,55 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_bigbird_pegasus": [
+ "BigBirdPegasusConfig",
+ "BigBirdPegasusOnnxConfig",
+ ],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bigbird_pegasus"] = [
+ "BigBirdPegasusForCausalLM",
+ "BigBirdPegasusForConditionalGeneration",
+ "BigBirdPegasusForQuestionAnswering",
+ "BigBirdPegasusForSequenceClassification",
+ "BigBirdPegasusModel",
+ "BigBirdPegasusPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_bigbird_pegasus import *
- from .convert_bigbird_pegasus_tf_to_pytorch import *
- from .modeling_bigbird_pegasus import *
+ from .configuration_bigbird_pegasus import (
+ BigBirdPegasusConfig,
+ BigBirdPegasusOnnxConfig,
+ )
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bigbird_pegasus import (
+ BigBirdPegasusForCausalLM,
+ BigBirdPegasusForConditionalGeneration,
+ BigBirdPegasusForQuestionAnswering,
+ BigBirdPegasusForSequenceClassification,
+ BigBirdPegasusModel,
+ BigBirdPegasusPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py
index 5d9c9bf1a4b0b2..9de2a7267acba8 100644
--- a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py
@@ -407,6 +407,3 @@ def _flatten_past_key_values_(self, flattened_output, name, idx, t):
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
-
-
-__all__ = ["BigBirdPegasusConfig", "BigBirdPegasusOnnxConfig"]
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index fd52e4b8bb731c..520e7dab1f119d 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -3028,13 +3028,3 @@ def _reorder_cache(past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
-
-
-__all__ = [
- "BigBirdPegasusForCausalLM",
- "BigBirdPegasusForConditionalGeneration",
- "BigBirdPegasusForQuestionAnswering",
- "BigBirdPegasusForSequenceClassification",
- "BigBirdPegasusModel",
- "BigBirdPegasusPreTrainedModel",
-]
diff --git a/src/transformers/models/biogpt/__init__.py b/src/transformers/models/biogpt/__init__.py
index 27773fb642459c..355c87e67ba2b7 100644
--- a/src/transformers/models/biogpt/__init__.py
+++ b/src/transformers/models/biogpt/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,49 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_biogpt": ["BioGptConfig"],
+ "tokenization_biogpt": ["BioGptTokenizer"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_biogpt"] = [
+ "BioGptForCausalLM",
+ "BioGptForTokenClassification",
+ "BioGptForSequenceClassification",
+ "BioGptModel",
+ "BioGptPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_biogpt import *
- from .convert_biogpt_original_pytorch_checkpoint_to_pytorch import *
- from .modeling_biogpt import *
- from .tokenization_biogpt import *
+ from .configuration_biogpt import BioGptConfig
+ from .tokenization_biogpt import BioGptTokenizer
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_biogpt import (
+ BioGptForCausalLM,
+ BioGptForSequenceClassification,
+ BioGptForTokenClassification,
+ BioGptModel,
+ BioGptPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/biogpt/configuration_biogpt.py b/src/transformers/models/biogpt/configuration_biogpt.py
index b338092edd1d0b..18f7b6d6bf06e7 100644
--- a/src/transformers/models/biogpt/configuration_biogpt.py
+++ b/src/transformers/models/biogpt/configuration_biogpt.py
@@ -129,6 +129,3 @@ def __init__(
self.layerdrop = layerdrop
self.activation_dropout = activation_dropout
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
-
-
-__all__ = ["BioGptConfig"]
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index e9d76413600879..6bc80bc04959b6 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -1028,12 +1028,3 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.biogpt.embed_tokens = value
-
-
-__all__ = [
- "BioGptForCausalLM",
- "BioGptForTokenClassification",
- "BioGptForSequenceClassification",
- "BioGptModel",
- "BioGptPreTrainedModel",
-]
diff --git a/src/transformers/models/biogpt/tokenization_biogpt.py b/src/transformers/models/biogpt/tokenization_biogpt.py
index a898976d985f58..f9760eb604e7d2 100644
--- a/src/transformers/models/biogpt/tokenization_biogpt.py
+++ b/src/transformers/models/biogpt/tokenization_biogpt.py
@@ -356,6 +356,3 @@ def __setstate__(self, d):
)
self.sm = sacremoses
-
-
-__all__ = ["BioGptTokenizer"]
diff --git a/src/transformers/models/bit/__init__.py b/src/transformers/models/bit/__init__.py
index f46988ca2d8f88..8f298a9adf6535 100644
--- a/src/transformers/models/bit/__init__.py
+++ b/src/transformers/models/bit/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,59 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {"configuration_bit": ["BitConfig", "BitOnnxConfig"]}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bit"] = [
+ "BitForImageClassification",
+ "BitModel",
+ "BitPreTrainedModel",
+ "BitBackbone",
+ ]
+
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["image_processing_bit"] = ["BitImageProcessor"]
if TYPE_CHECKING:
- from .configuration_bit import *
- from .convert_bit_to_pytorch import *
- from .image_processing_bit import *
- from .modeling_bit import *
+ from .configuration_bit import BitConfig, BitOnnxConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bit import (
+ BitBackbone,
+ BitForImageClassification,
+ BitModel,
+ BitPreTrainedModel,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .image_processing_bit import BitImageProcessor
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/transformers/models/bit/configuration_bit.py b/src/transformers/models/bit/configuration_bit.py
index 238749f1fbe70f..8f4326a2d5a709 100644
--- a/src/transformers/models/bit/configuration_bit.py
+++ b/src/transformers/models/bit/configuration_bit.py
@@ -131,6 +131,3 @@ def __init__(
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
-
-
-__all__ = ["BitConfig"]
diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py
index c32bb934bdc528..ba234078997048 100644
--- a/src/transformers/models/bit/image_processing_bit.py
+++ b/src/transformers/models/bit/image_processing_bit.py
@@ -319,6 +319,3 @@ def preprocess(
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
-
-
-__all__ = ["BitImageProcessor"]
diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py
index 3d834671becccd..3c7e4c57b2f190 100644
--- a/src/transformers/models/bit/modeling_bit.py
+++ b/src/transformers/models/bit/modeling_bit.py
@@ -901,6 +901,3 @@ def forward(
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=None,
)
-
-
-__all__ = ["BitForImageClassification", "BitModel", "BitPreTrainedModel", "BitBackbone"]
diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py
index d1180bd200d45c..8b53b9100a4af1 100644
--- a/src/transformers/models/blenderbot/__init__.py
+++ b/src/transformers/models/blenderbot/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,22 +11,128 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_blenderbot": [
+ "BlenderbotConfig",
+ "BlenderbotOnnxConfig",
+ ],
+ "tokenization_blenderbot": ["BlenderbotTokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_blenderbot_fast"] = ["BlenderbotTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_blenderbot"] = [
+ "BlenderbotForCausalLM",
+ "BlenderbotForConditionalGeneration",
+ "BlenderbotModel",
+ "BlenderbotPreTrainedModel",
+ ]
+
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_blenderbot"] = [
+ "TFBlenderbotForConditionalGeneration",
+ "TFBlenderbotModel",
+ "TFBlenderbotPreTrainedModel",
+ ]
+
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_blenderbot"] = [
+ "FlaxBlenderbotForConditionalGeneration",
+ "FlaxBlenderbotModel",
+ "FlaxBlenderbotPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_blenderbot import *
- from .convert_blenderbot_original_pytorch_checkpoint_to_pytorch import *
- from .modeling_blenderbot import *
- from .modeling_flax_blenderbot import *
- from .modeling_tf_blenderbot import *
- from .tokenization_blenderbot import *
- from .tokenization_blenderbot_fast import *
+ from .configuration_blenderbot import (
+ BlenderbotConfig,
+ BlenderbotOnnxConfig,
+ )
+ from .tokenization_blenderbot import BlenderbotTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_blenderbot_fast import BlenderbotTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_blenderbot import (
+ BlenderbotForCausalLM,
+ BlenderbotForConditionalGeneration,
+ BlenderbotModel,
+ BlenderbotPreTrainedModel,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_blenderbot import (
+ TFBlenderbotForConditionalGeneration,
+ TFBlenderbotModel,
+ TFBlenderbotPreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_blenderbot import (
+ FlaxBlenderbotForConditionalGeneration,
+ FlaxBlenderbotModel,
+ FlaxBlenderbotPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py
index c9f323210e8c47..105d38c2559170 100644
--- a/src/transformers/models/blenderbot/configuration_blenderbot.py
+++ b/src/transformers/models/blenderbot/configuration_blenderbot.py
@@ -390,6 +390,3 @@ def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int
inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}
-
-
-__all__ = ["BlenderbotConfig", "BlenderbotOnnxConfig"]
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index ace9470d01e3b2..5c4fdfb472c37e 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1547,11 +1547,3 @@ def _reorder_cache(past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
-
-
-__all__ = [
- "BlenderbotForCausalLM",
- "BlenderbotForConditionalGeneration",
- "BlenderbotModel",
- "BlenderbotPreTrainedModel",
-]
diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
index fcef08fdeab8de..97c9653da36dee 100644
--- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -1503,6 +1503,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
append_replace_return_docstrings(
FlaxBlenderbotForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
-
-
-__all__ = ["FlaxBlenderbotForConditionalGeneration", "FlaxBlenderbotModel", "FlaxBlenderbotPreTrainedModel"]
diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
index f3476cb925b6b4..bbfe4726deef97 100644
--- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -1553,6 +1553,3 @@ def build(self, input_shape=None):
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)
-
-
-__all__ = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot.py b/src/transformers/models/blenderbot/tokenization_blenderbot.py
index 08b2a8c1283b67..1a8807214d52ba 100644
--- a/src/transformers/models/blenderbot/tokenization_blenderbot.py
+++ b/src/transformers/models/blenderbot/tokenization_blenderbot.py
@@ -405,6 +405,3 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1:
`List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
return token_ids_0 + [self.eos_token_id]
-
-
-__all__ = ["BlenderbotTokenizer"]
diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
index f649246517d271..0d24ed62c574a3 100644
--- a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
+++ b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
@@ -287,6 +287,3 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1:
`List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
return token_ids_0 + [self.eos_token_id]
-
-
-__all__ = ["BlenderbotTokenizerFast"]
diff --git a/src/transformers/models/blenderbot_small/__init__.py b/src/transformers/models/blenderbot_small/__init__.py
index 075d0070e4c4e2..e6cab05c0cae02 100644
--- a/src/transformers/models/blenderbot_small/__init__.py
+++ b/src/transformers/models/blenderbot_small/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,19 +13,122 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+_import_structure = {
+ "configuration_blenderbot_small": [
+ "BlenderbotSmallConfig",
+ "BlenderbotSmallOnnxConfig",
+ ],
+ "tokenization_blenderbot_small": ["BlenderbotSmallTokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_blenderbot_small_fast"] = ["BlenderbotSmallTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_blenderbot_small"] = [
+ "BlenderbotSmallForCausalLM",
+ "BlenderbotSmallForConditionalGeneration",
+ "BlenderbotSmallModel",
+ "BlenderbotSmallPreTrainedModel",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_blenderbot_small"] = [
+ "TFBlenderbotSmallForConditionalGeneration",
+ "TFBlenderbotSmallModel",
+ "TFBlenderbotSmallPreTrainedModel",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_blenderbot_small"] = [
+ "FlaxBlenderbotSmallForConditionalGeneration",
+ "FlaxBlenderbotSmallModel",
+ "FlaxBlenderbotSmallPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_blenderbot_small import *
- from .modeling_blenderbot_small import *
- from .modeling_flax_blenderbot_small import *
- from .modeling_tf_blenderbot_small import *
- from .tokenization_blenderbot_small import *
- from .tokenization_blenderbot_small_fast import *
+ from .configuration_blenderbot_small import (
+ BlenderbotSmallConfig,
+ BlenderbotSmallOnnxConfig,
+ )
+ from .tokenization_blenderbot_small import BlenderbotSmallTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_blenderbot_small_fast import BlenderbotSmallTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_blenderbot_small import (
+ BlenderbotSmallForCausalLM,
+ BlenderbotSmallForConditionalGeneration,
+ BlenderbotSmallModel,
+ BlenderbotSmallPreTrainedModel,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_blenderbot_small import (
+ TFBlenderbotSmallForConditionalGeneration,
+ TFBlenderbotSmallModel,
+ TFBlenderbotSmallPreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_blenderbot_small import (
+ FlaxBlenderbotSmallForConditionalGeneration,
+ FlaxBlenderbotSmallModel,
+ FlaxBlenderbotSmallPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py
index 5865486370e5b9..6ee26365de8d88 100644
--- a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py
@@ -385,6 +385,3 @@ def _flatten_past_key_values_(self, flattened_output, name, idx, t):
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
-
-
-__all__ = ["BlenderbotSmallConfig", "BlenderbotSmallOnnxConfig"]
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 8564fbf3115d96..6f79d2a7d005cc 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1499,11 +1499,3 @@ def _reorder_cache(past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
-
-
-__all__ = [
- "BlenderbotSmallForCausalLM",
- "BlenderbotSmallForConditionalGeneration",
- "BlenderbotSmallModel",
- "BlenderbotSmallPreTrainedModel",
-]
diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
index 236685ac5971f6..325ff0a20b5567 100644
--- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -1519,10 +1519,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
append_replace_return_docstrings(
FlaxBlenderbotSmallForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
-
-
-__all__ = [
- "FlaxBlenderbotSmallForConditionalGeneration",
- "FlaxBlenderbotSmallModel",
- "FlaxBlenderbotSmallPreTrainedModel",
-]
diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
index 4de98280836d4a..15764629799098 100644
--- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -1523,6 +1523,3 @@ def build(self, input_shape=None):
if getattr(self, "bias_layer", None) is not None:
with tf.name_scope(self.bias_layer.name):
self.bias_layer.build(None)
-
-
-__all__ = ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
diff --git a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py
index be950f0dbe629b..08c7be332e31ef 100644
--- a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py
@@ -217,6 +217,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
index += 1
return vocab_file, merge_file
-
-
-__all__ = ["BlenderbotSmallTokenizer"]
diff --git a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py
index ac98ce008baad8..21fb76cbfc8691 100644
--- a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py
+++ b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small_fast.py
@@ -98,6 +98,3 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
-
-
-__all__ = ["BlenderbotSmallTokenizerFast"]
diff --git a/src/transformers/models/blip/__init__.py b/src/transformers/models/blip/__init__.py
index b3b604b24307ce..f78c2500bd64f4 100644
--- a/src/transformers/models/blip/__init__.py
+++ b/src/transformers/models/blip/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,110 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_tf_available,
+ is_torch_available,
+ is_vision_available,
+)
+_import_structure = {
+ "configuration_blip": [
+ "BlipConfig",
+ "BlipTextConfig",
+ "BlipVisionConfig",
+ ],
+ "processing_blip": ["BlipProcessor"],
+}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["image_processing_blip"] = ["BlipImageProcessor"]
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_blip"] = [
+ "BlipModel",
+ "BlipPreTrainedModel",
+ "BlipForConditionalGeneration",
+ "BlipForQuestionAnswering",
+ "BlipVisionModel",
+ "BlipTextModel",
+ "BlipForImageTextRetrieval",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_blip"] = [
+ "TFBlipModel",
+ "TFBlipPreTrainedModel",
+ "TFBlipForConditionalGeneration",
+ "TFBlipForQuestionAnswering",
+ "TFBlipVisionModel",
+ "TFBlipTextModel",
+ "TFBlipForImageTextRetrieval",
+ ]
+
if TYPE_CHECKING:
- from .configuration_blip import *
- from .convert_blip_original_pytorch_to_hf import *
- from .image_processing_blip import *
- from .modeling_blip import *
- from .modeling_blip_text import *
- from .modeling_tf_blip import *
- from .modeling_tf_blip_text import *
- from .processing_blip import *
+ from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
+ from .processing_blip import BlipProcessor
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .image_processing_blip import BlipImageProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_blip import (
+ BlipForConditionalGeneration,
+ BlipForImageTextRetrieval,
+ BlipForQuestionAnswering,
+ BlipModel,
+ BlipPreTrainedModel,
+ BlipTextModel,
+ BlipVisionModel,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_blip import (
+ TFBlipForConditionalGeneration,
+ TFBlipForImageTextRetrieval,
+ TFBlipForQuestionAnswering,
+ TFBlipModel,
+ TFBlipPreTrainedModel,
+ TFBlipTextModel,
+ TFBlipVisionModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/blip/configuration_blip.py b/src/transformers/models/blip/configuration_blip.py
index c46cd2a08be28e..18db71eb14890b 100644
--- a/src/transformers/models/blip/configuration_blip.py
+++ b/src/transformers/models/blip/configuration_blip.py
@@ -324,6 +324,3 @@ def from_text_vision_configs(cls, text_config: BlipTextConfig, vision_config: Bl
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
-
-
-__all__ = ["BlipConfig", "BlipTextConfig", "BlipVisionConfig"]
diff --git a/src/transformers/models/blip/image_processing_blip.py b/src/transformers/models/blip/image_processing_blip.py
index 6bb2dd23733ee3..6f520f9fb9cb77 100644
--- a/src/transformers/models/blip/image_processing_blip.py
+++ b/src/transformers/models/blip/image_processing_blip.py
@@ -292,6 +292,3 @@ def preprocess(
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
-
-
-__all__ = ["BlipImageProcessor"]
diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py
index 27dbbee6c671ee..b623d2a8adb17b 100644
--- a/src/transformers/models/blip/modeling_blip.py
+++ b/src/transformers/models/blip/modeling_blip.py
@@ -464,8 +464,6 @@ class BlipPreTrainedModel(PreTrainedModel):
config_class = BlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
- _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
- _skip_keys_device_placement = ["past_key_value"]
def _init_weights(self, module):
"""Initialize the weights"""
@@ -1011,8 +1009,7 @@ def forward(
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
- logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
- image_embeds = image_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype)
+ logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()
@@ -1583,14 +1580,3 @@ def forward(
attentions=vision_outputs.attentions,
question_embeds=question_embeds,
)
-
-
-__all__ = [
- "BlipModel",
- "BlipPreTrainedModel",
- "BlipForConditionalGeneration",
- "BlipForQuestionAnswering",
- "BlipVisionModel",
- "BlipTextModel",
- "BlipForImageTextRetrieval",
-]
diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index db8ad939725aca..97a4f523380bc5 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -82,6 +82,7 @@ def forward(
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
+ input_ids = input_ids.to(self.word_embeddings.weight.device)
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
diff --git a/src/transformers/models/blip/modeling_tf_blip.py b/src/transformers/models/blip/modeling_tf_blip.py
index 92f61bf470d93f..6c9942b73acefb 100644
--- a/src/transformers/models/blip/modeling_tf_blip.py
+++ b/src/transformers/models/blip/modeling_tf_blip.py
@@ -1696,14 +1696,3 @@ def build(self, input_shape=None):
if getattr(self, "itm_head", None) is not None:
with tf.name_scope(self.itm_head.name):
self.itm_head.build([None, None, self.config.text_config.hidden_size])
-
-
-__all__ = [
- "TFBlipModel",
- "TFBlipPreTrainedModel",
- "TFBlipForConditionalGeneration",
- "TFBlipForQuestionAnswering",
- "TFBlipVisionModel",
- "TFBlipTextModel",
- "TFBlipForImageTextRetrieval",
-]
diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py
index edef863e404907..78e1aa58ef0443 100644
--- a/src/transformers/models/blip/processing_blip.py
+++ b/src/transformers/models/blip/processing_blip.py
@@ -134,6 +134,3 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
-
-__all__ = ["BlipProcessor"]
diff --git a/src/transformers/models/blip_2/__init__.py b/src/transformers/models/blip_2/__init__.py
index 1014e8c88102c9..329ddfe19ac66c 100644
--- a/src/transformers/models/blip_2/__init__.py
+++ b/src/transformers/models/blip_2/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,61 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+_import_structure = {
+ "configuration_blip_2": [
+ "Blip2Config",
+ "Blip2QFormerConfig",
+ "Blip2VisionConfig",
+ ],
+ "processing_blip_2": ["Blip2Processor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_blip_2"] = [
+ "Blip2Model",
+ "Blip2VisionModelWithProjection",
+ "Blip2QFormerModel",
+ "Blip2PreTrainedModel",
+ "Blip2ForConditionalGeneration",
+ "Blip2ForImageTextRetrieval",
+ "Blip2VisionModel",
+ "Blip2TextModelWithProjection",
+ ]
+
if TYPE_CHECKING:
- from .configuration_blip_2 import *
- from .convert_blip_2_original_to_pytorch import *
- from .modeling_blip_2 import *
- from .processing_blip_2 import *
+ from .configuration_blip_2 import (
+ Blip2Config,
+ Blip2QFormerConfig,
+ Blip2VisionConfig,
+ )
+ from .processing_blip_2 import Blip2Processor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_blip_2 import (
+ Blip2ForConditionalGeneration,
+ Blip2ForImageTextRetrieval,
+ Blip2Model,
+ Blip2PreTrainedModel,
+ Blip2QFormerModel,
+ Blip2TextModelWithProjection,
+ Blip2VisionModel,
+ Blip2VisionModelWithProjection,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py
index 539a3e365c9883..d690d22338a687 100644
--- a/src/transformers/models/blip_2/configuration_blip_2.py
+++ b/src/transformers/models/blip_2/configuration_blip_2.py
@@ -343,6 +343,3 @@ def from_vision_qformer_text_configs(
text_config=text_config.to_dict() if text_config is not None else None,
**kwargs,
)
-
-
-__all__ = ["Blip2Config", "Blip2QFormerConfig", "Blip2VisionConfig"]
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 99d678b1227be3..ed8ddd3c47dea3 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -2533,15 +2533,3 @@ def forward(
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
-
-
-__all__ = [
- "Blip2Model",
- "Blip2VisionModelWithProjection",
- "Blip2QFormerModel",
- "Blip2PreTrainedModel",
- "Blip2ForConditionalGeneration",
- "Blip2ForImageTextRetrieval",
- "Blip2VisionModel",
- "Blip2TextModelWithProjection",
-]
diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py
index 5d09ea7c07668b..4129920f9b3663 100644
--- a/src/transformers/models/blip_2/processing_blip_2.py
+++ b/src/transformers/models/blip_2/processing_blip_2.py
@@ -188,6 +188,3 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
-
-__all__ = ["Blip2Processor"]
diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py
index 012bbbc15c25d6..3c903b39dca23f 100644
--- a/src/transformers/models/bloom/__init__.py
+++ b/src/transformers/models/bloom/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,20 +11,91 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_bloom": ["BloomConfig", "BloomOnnxConfig"],
+}
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bloom_fast"] = ["BloomTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bloom"] = [
+ "BloomForCausalLM",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ "BloomForSequenceClassification",
+ "BloomForTokenClassification",
+ "BloomForQuestionAnswering",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_bloom"] = [
+ "FlaxBloomForCausalLM",
+ "FlaxBloomModel",
+ "FlaxBloomPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_bloom import *
- from .convert_bloom_original_checkpoint_to_pytorch import *
- from .modeling_bloom import *
- from .modeling_flax_bloom import *
- from .tokenization_bloom_fast import *
+ from .configuration_bloom import BloomConfig, BloomOnnxConfig
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bloom_fast import BloomTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bloom import (
+ BloomForCausalLM,
+ BloomForQuestionAnswering,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+ BloomPreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py
index ca10c7ce7ed4ef..dc9f6d3082ecbe 100644
--- a/src/transformers/models/bloom/configuration_bloom.py
+++ b/src/transformers/models/bloom/configuration_bloom.py
@@ -232,6 +232,3 @@ def generate_dummy_inputs(
@property
def default_onnx_opset(self) -> int:
return 13
-
-
-__all__ = ["BloomConfig", "BloomOnnxConfig"]
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 086f8ce03c62fc..b3dd3446cd848e 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -1362,13 +1362,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "BloomForCausalLM",
- "BloomModel",
- "BloomPreTrainedModel",
- "BloomForSequenceClassification",
- "BloomForTokenClassification",
- "BloomForQuestionAnswering",
-]
diff --git a/src/transformers/models/bloom/modeling_flax_bloom.py b/src/transformers/models/bloom/modeling_flax_bloom.py
index 077c2123bf95c4..187230f35ab9e4 100644
--- a/src/transformers/models/bloom/modeling_flax_bloom.py
+++ b/src/transformers/models/bloom/modeling_flax_bloom.py
@@ -732,6 +732,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
-
-
-__all__ = ["FlaxBloomForCausalLM", "FlaxBloomModel", "FlaxBloomPreTrainedModel"]
diff --git a/src/transformers/models/bloom/tokenization_bloom_fast.py b/src/transformers/models/bloom/tokenization_bloom_fast.py
index c84322637cb7e8..3ea7a1a39cd8a5 100644
--- a/src/transformers/models/bloom/tokenization_bloom_fast.py
+++ b/src/transformers/models/bloom/tokenization_bloom_fast.py
@@ -147,6 +147,3 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
-
-
-__all__ = ["BloomTokenizerFast"]
diff --git a/src/transformers/models/bridgetower/__init__.py b/src/transformers/models/bridgetower/__init__.py
index 65613444624fcf..3120ca9f2a163a 100644
--- a/src/transformers/models/bridgetower/__init__.py
+++ b/src/transformers/models/bridgetower/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,73 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_bridgetower": [
+ "BridgeTowerConfig",
+ "BridgeTowerTextConfig",
+ "BridgeTowerVisionConfig",
+ ],
+ "processing_bridgetower": ["BridgeTowerProcessor"],
+}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["image_processing_bridgetower"] = ["BridgeTowerImageProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bridgetower"] = [
+ "BridgeTowerForContrastiveLearning",
+ "BridgeTowerForImageAndTextRetrieval",
+ "BridgeTowerForMaskedLM",
+ "BridgeTowerModel",
+ "BridgeTowerPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_bridgetower import *
- from .image_processing_bridgetower import *
- from .modeling_bridgetower import *
- from .processing_bridgetower import *
+ from .configuration_bridgetower import (
+ BridgeTowerConfig,
+ BridgeTowerTextConfig,
+ BridgeTowerVisionConfig,
+ )
+ from .processing_bridgetower import BridgeTowerProcessor
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .image_processing_bridgetower import BridgeTowerImageProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bridgetower import (
+ BridgeTowerForContrastiveLearning,
+ BridgeTowerForImageAndTextRetrieval,
+ BridgeTowerForMaskedLM,
+ BridgeTowerModel,
+ BridgeTowerPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/transformers/models/bridgetower/configuration_bridgetower.py b/src/transformers/models/bridgetower/configuration_bridgetower.py
index 6a3d9072defadc..de49283493b63f 100644
--- a/src/transformers/models/bridgetower/configuration_bridgetower.py
+++ b/src/transformers/models/bridgetower/configuration_bridgetower.py
@@ -314,6 +314,3 @@ def from_text_vision_configs(
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
-
-
-__all__ = ["BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig"]
diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py
index a8b94e7c9709dc..7272093715f882 100644
--- a/src/transformers/models/bridgetower/image_processing_bridgetower.py
+++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py
@@ -538,6 +538,3 @@ def preprocess(
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
-
-
-__all__ = ["BridgeTowerImageProcessor"]
diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index 0d4338261eec4b..9a900acf500c1e 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -1973,12 +1973,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "BridgeTowerForContrastiveLearning",
- "BridgeTowerForImageAndTextRetrieval",
- "BridgeTowerForMaskedLM",
- "BridgeTowerModel",
- "BridgeTowerPreTrainedModel",
-]
diff --git a/src/transformers/models/bridgetower/processing_bridgetower.py b/src/transformers/models/bridgetower/processing_bridgetower.py
index 5519d0a34ce911..177eb12051654d 100644
--- a/src/transformers/models/bridgetower/processing_bridgetower.py
+++ b/src/transformers/models/bridgetower/processing_bridgetower.py
@@ -109,6 +109,3 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
-
-__all__ = ["BridgeTowerProcessor"]
diff --git a/src/transformers/models/bros/__init__.py b/src/transformers/models/bros/__init__.py
index 54e429863ec85b..516c6349cd120c 100644
--- a/src/transformers/models/bros/__init__.py
+++ b/src/transformers/models/bros/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,63 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_bros": ["BrosConfig"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["processing_bros"] = ["BrosProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bros"] = [
+ "BrosPreTrainedModel",
+ "BrosModel",
+ "BrosForTokenClassification",
+ "BrosSpadeEEForTokenClassification",
+ "BrosSpadeELForTokenClassification",
+ ]
if TYPE_CHECKING:
- from .configuration_bros import *
- from .convert_bros_to_pytorch import *
- from .modeling_bros import *
- from .processing_bros import *
+ from .configuration_bros import BrosConfig
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .processing_bros import BrosProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bros import (
+ BrosForTokenClassification,
+ BrosModel,
+ BrosPreTrainedModel,
+ BrosSpadeEEForTokenClassification,
+ BrosSpadeELForTokenClassification,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bros/configuration_bros.py b/src/transformers/models/bros/configuration_bros.py
index 84c9989f309fb7..8c2a3cc73a55a0 100644
--- a/src/transformers/models/bros/configuration_bros.py
+++ b/src/transformers/models/bros/configuration_bros.py
@@ -133,6 +133,3 @@ def __init__(
self.dim_bbox_sinusoid_emb_1d = self.dim_bbox_sinusoid_emb_2d // self.dim_bbox
self.dim_bbox_projection = self.hidden_size // self.num_attention_heads
self.classifier_dropout_prob = classifier_dropout_prob
-
-
-__all__ = ["BrosConfig"]
diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py
index 0e1e86c0b39f7e..c062278309b7b6 100755
--- a/src/transformers/models/bros/modeling_bros.py
+++ b/src/transformers/models/bros/modeling_bros.py
@@ -1312,12 +1312,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "BrosPreTrainedModel",
- "BrosModel",
- "BrosForTokenClassification",
- "BrosSpadeEEForTokenClassification",
- "BrosSpadeELForTokenClassification",
-]
diff --git a/src/transformers/models/bros/processing_bros.py b/src/transformers/models/bros/processing_bros.py
index 4687e7f8a86ae5..9c2e0642d8cdc4 100644
--- a/src/transformers/models/bros/processing_bros.py
+++ b/src/transformers/models/bros/processing_bros.py
@@ -107,6 +107,3 @@ def decode(self, *args, **kwargs):
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
return list(dict.fromkeys(tokenizer_input_names))
-
-
-__all__ = ["BrosProcessor"]
diff --git a/src/transformers/models/byt5/__init__.py b/src/transformers/models/byt5/__init__.py
index c4243d1970d31d..662a427383ff69 100644
--- a/src/transformers/models/byt5/__init__.py
+++ b/src/transformers/models/byt5/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+
+
+_import_structure = {"tokenization_byt5": ["ByT5Tokenizer"]}
if TYPE_CHECKING:
- from .convert_byt5_original_tf_checkpoint_to_pytorch import *
- from .tokenization_byt5 import *
+ from .tokenization_byt5 import ByT5Tokenizer
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/byt5/tokenization_byt5.py b/src/transformers/models/byt5/tokenization_byt5.py
index b39ba254b38170..21513ab4cd3ce1 100644
--- a/src/transformers/models/byt5/tokenization_byt5.py
+++ b/src/transformers/models/byt5/tokenization_byt5.py
@@ -231,6 +231,3 @@ def convert_tokens_to_string(self, tokens):
# ByT5Tokenizer has no vocab file
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
return ()
-
-
-__all__ = ["ByT5Tokenizer"]
diff --git a/src/transformers/models/camembert/__init__.py b/src/transformers/models/camembert/__init__.py
index 9d90f64de97f78..1759762f47f1a1 100644
--- a/src/transformers/models/camembert/__init__.py
+++ b/src/transformers/models/camembert/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,20 +11,128 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_camembert": ["CamembertConfig", "CamembertOnnxConfig"],
+}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_camembert"] = ["CamembertTokenizer"]
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_camembert_fast"] = ["CamembertTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_camembert"] = [
+ "CamembertForCausalLM",
+ "CamembertForMaskedLM",
+ "CamembertForMultipleChoice",
+ "CamembertForQuestionAnswering",
+ "CamembertForSequenceClassification",
+ "CamembertForTokenClassification",
+ "CamembertModel",
+ "CamembertPreTrainedModel",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_camembert"] = [
+ "TFCamembertForCausalLM",
+ "TFCamembertForMaskedLM",
+ "TFCamembertForMultipleChoice",
+ "TFCamembertForQuestionAnswering",
+ "TFCamembertForSequenceClassification",
+ "TFCamembertForTokenClassification",
+ "TFCamembertModel",
+ "TFCamembertPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_camembert import *
- from .modeling_camembert import *
- from .modeling_tf_camembert import *
- from .tokenization_camembert import *
- from .tokenization_camembert_fast import *
+ from .configuration_camembert import CamembertConfig, CamembertOnnxConfig
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_camembert import CamembertTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_camembert_fast import CamembertTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_camembert import (
+ CamembertForCausalLM,
+ CamembertForMaskedLM,
+ CamembertForMultipleChoice,
+ CamembertForQuestionAnswering,
+ CamembertForSequenceClassification,
+ CamembertForTokenClassification,
+ CamembertModel,
+ CamembertPreTrainedModel,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_camembert import (
+ TFCamembertForCausalLM,
+ TFCamembertForMaskedLM,
+ TFCamembertForMultipleChoice,
+ TFCamembertForQuestionAnswering,
+ TFCamembertForSequenceClassification,
+ TFCamembertForTokenClassification,
+ TFCamembertModel,
+ TFCamembertPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/camembert/configuration_camembert.py b/src/transformers/models/camembert/configuration_camembert.py
index eaf8c94b891481..b5738012008a00 100644
--- a/src/transformers/models/camembert/configuration_camembert.py
+++ b/src/transformers/models/camembert/configuration_camembert.py
@@ -150,6 +150,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("attention_mask", dynamic_axis),
]
)
-
-
-__all__ = ["CamembertConfig", "CamembertOnnxConfig"]
diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py
index e94e4a0a8948c5..32e8a0af2ba2d1 100644
--- a/src/transformers/models/camembert/modeling_camembert.py
+++ b/src/transformers/models/camembert/modeling_camembert.py
@@ -1698,15 +1698,3 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
-
-
-__all__ = [
- "CamembertForCausalLM",
- "CamembertForMaskedLM",
- "CamembertForMultipleChoice",
- "CamembertForQuestionAnswering",
- "CamembertForSequenceClassification",
- "CamembertForTokenClassification",
- "CamembertModel",
- "CamembertPreTrainedModel",
-]
diff --git a/src/transformers/models/camembert/modeling_tf_camembert.py b/src/transformers/models/camembert/modeling_tf_camembert.py
index 6f456723dea54a..f5ddc2242b6868 100644
--- a/src/transformers/models/camembert/modeling_tf_camembert.py
+++ b/src/transformers/models/camembert/modeling_tf_camembert.py
@@ -1787,15 +1787,3 @@ def build(self, input_shape=None):
if getattr(self, "lm_head", None) is not None:
with tf.name_scope(self.lm_head.name):
self.lm_head.build(None)
-
-
-__all__ = [
- "TFCamembertForCausalLM",
- "TFCamembertForMaskedLM",
- "TFCamembertForMultipleChoice",
- "TFCamembertForQuestionAnswering",
- "TFCamembertForSequenceClassification",
- "TFCamembertForTokenClassification",
- "TFCamembertModel",
- "TFCamembertPreTrainedModel",
-]
diff --git a/src/transformers/models/camembert/tokenization_camembert.py b/src/transformers/models/camembert/tokenization_camembert.py
index 3353bf3433c7e1..113fe1b121e2d9 100644
--- a/src/transformers/models/camembert/tokenization_camembert.py
+++ b/src/transformers/models/camembert/tokenization_camembert.py
@@ -316,6 +316,3 @@ def create_token_type_ids_from_sequences(
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
-
-
-__all__ = ["CamembertTokenizer"]
diff --git a/src/transformers/models/camembert/tokenization_camembert_fast.py b/src/transformers/models/camembert/tokenization_camembert_fast.py
index c04b5618390234..ffec8d98e194cb 100644
--- a/src/transformers/models/camembert/tokenization_camembert_fast.py
+++ b/src/transformers/models/camembert/tokenization_camembert_fast.py
@@ -196,6 +196,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
-
-
-__all__ = ["CamembertTokenizerFast"]
diff --git a/src/transformers/models/canine/__init__.py b/src/transformers/models/canine/__init__.py
index 5f9611153bbd40..93f103344d476b 100644
--- a/src/transformers/models/canine/__init__.py
+++ b/src/transformers/models/canine/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,55 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_canine": ["CanineConfig"],
+ "tokenization_canine": ["CanineTokenizer"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_canine"] = [
+ "CanineForMultipleChoice",
+ "CanineForQuestionAnswering",
+ "CanineForSequenceClassification",
+ "CanineForTokenClassification",
+ "CanineLayer",
+ "CanineModel",
+ "CaninePreTrainedModel",
+ "load_tf_weights_in_canine",
+ ]
if TYPE_CHECKING:
- from .configuration_canine import *
- from .convert_canine_original_tf_checkpoint_to_pytorch import *
- from .modeling_canine import *
- from .tokenization_canine import *
+ from .configuration_canine import CanineConfig
+ from .tokenization_canine import CanineTokenizer
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_canine import (
+ CanineForMultipleChoice,
+ CanineForQuestionAnswering,
+ CanineForSequenceClassification,
+ CanineForTokenClassification,
+ CanineLayer,
+ CanineModel,
+ CaninePreTrainedModel,
+ load_tf_weights_in_canine,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/canine/configuration_canine.py b/src/transformers/models/canine/configuration_canine.py
index 29e90327d08f02..9add399112f290 100644
--- a/src/transformers/models/canine/configuration_canine.py
+++ b/src/transformers/models/canine/configuration_canine.py
@@ -136,6 +136,3 @@ def __init__(
self.num_hash_functions = num_hash_functions
self.num_hash_buckets = num_hash_buckets
self.local_transformer_stride = local_transformer_stride
-
-
-__all__ = ["CanineConfig"]
diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py
index 9f18fc9ac3df19..c48559497a2ec0 100644
--- a/src/transformers/models/canine/modeling_canine.py
+++ b/src/transformers/models/canine/modeling_canine.py
@@ -1639,15 +1639,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "CanineForMultipleChoice",
- "CanineForQuestionAnswering",
- "CanineForSequenceClassification",
- "CanineForTokenClassification",
- "CanineLayer",
- "CanineModel",
- "CaninePreTrainedModel",
- "load_tf_weights_in_canine",
-]
diff --git a/src/transformers/models/canine/tokenization_canine.py b/src/transformers/models/canine/tokenization_canine.py
index fe2734712dca5b..024507f77877d7 100644
--- a/src/transformers/models/canine/tokenization_canine.py
+++ b/src/transformers/models/canine/tokenization_canine.py
@@ -239,6 +239,3 @@ def create_token_type_ids_from_sequences(
# CanineTokenizer has no vocab file
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
return ()
-
-
-__all__ = ["CanineTokenizer"]
diff --git a/src/transformers/models/chameleon/__init__.py b/src/transformers/models/chameleon/__init__.py
index ad00f5cd3dab3d..e8e38630d25253 100644
--- a/src/transformers/models/chameleon/__init__.py
+++ b/src/transformers/models/chameleon/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,71 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {
+ "configuration_chameleon": ["ChameleonConfig", "ChameleonVQVAEConfig"],
+ "processing_chameleon": ["ChameleonProcessor"],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_chameleon"] = [
+ "ChameleonForConditionalGeneration",
+ "ChameleonModel",
+ "ChameleonPreTrainedModel",
+ "ChameleonVQVAE",
+ ]
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["image_processing_chameleon"] = ["ChameleonImageProcessor"]
if TYPE_CHECKING:
- from .configuration_chameleon import *
- from .convert_chameleon_weights_to_hf import *
- from .image_processing_chameleon import *
- from .modeling_chameleon import *
- from .processing_chameleon import *
+ from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
+ from .processing_chameleon import ChameleonProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_chameleon import (
+ ChameleonForConditionalGeneration,
+ ChameleonModel,
+ ChameleonPreTrainedModel,
+ ChameleonVQVAE,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .image_processing_chameleon import ChameleonImageProcessor
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/chameleon/configuration_chameleon.py b/src/transformers/models/chameleon/configuration_chameleon.py
index 2cc9cdb29d46c5..9842127e7bb48f 100644
--- a/src/transformers/models/chameleon/configuration_chameleon.py
+++ b/src/transformers/models/chameleon/configuration_chameleon.py
@@ -276,6 +276,3 @@ def _rope_scaling_validation(self):
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
-
-
-__all__ = ["ChameleonConfig", "ChameleonVQVAEConfig"]
diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py
index cadaeb2e09a624..46d081973bb468 100644
--- a/src/transformers/models/chameleon/image_processing_chameleon.py
+++ b/src/transformers/models/chameleon/image_processing_chameleon.py
@@ -362,6 +362,3 @@ def blend_rgba(self, image: ImageInput) -> ImageInput:
alpha = img_rgba[:, :, 3] / 255.0
img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3]
return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB")
-
-
-__all__ = ["ChameleonImageProcessor"]
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 11bc411a00c005..3255b6f44c05fb 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -115,6 +115,8 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon
+# TODO(joao): add me back asap :)
class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@@ -125,6 +127,8 @@ def forward(self, x, position_ids):
return cos, sin
+# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon
+# TODO(joao): add me back asap :)
class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
"""ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
@@ -362,7 +366,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
+# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonFlashAttention2(ChameleonAttention):
"""
@@ -1685,6 +1689,3 @@ def prepare_inputs_for_generation(
}
)
return model_inputs
-
-
-__all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]
diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py
index 9f4bc2904c861c..e2a50d1af51b9e 100644
--- a/src/transformers/models/chameleon/processing_chameleon.py
+++ b/src/transformers/models/chameleon/processing_chameleon.py
@@ -168,6 +168,3 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
-
-__all__ = ["ChameleonProcessor"]
diff --git a/src/transformers/models/chinese_clip/__init__.py b/src/transformers/models/chinese_clip/__init__.py
index 8770bde94ecf3a..03c9665ab0d09f 100644
--- a/src/transformers/models/chinese_clip/__init__.py
+++ b/src/transformers/models/chinese_clip/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,19 +13,72 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+_import_structure = {
+ "configuration_chinese_clip": [
+ "ChineseCLIPConfig",
+ "ChineseCLIPOnnxConfig",
+ "ChineseCLIPTextConfig",
+ "ChineseCLIPVisionConfig",
+ ],
+ "processing_chinese_clip": ["ChineseCLIPProcessor"],
+}
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_chinese_clip"] = ["ChineseCLIPFeatureExtractor"]
+ _import_structure["image_processing_chinese_clip"] = ["ChineseCLIPImageProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_chinese_clip"] = [
+ "ChineseCLIPModel",
+ "ChineseCLIPPreTrainedModel",
+ "ChineseCLIPTextModel",
+ "ChineseCLIPVisionModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_chinese_clip import *
- from .convert_chinese_clip_original_pytorch_to_hf import *
- from .feature_extraction_chinese_clip import *
- from .image_processing_chinese_clip import *
- from .modeling_chinese_clip import *
- from .processing_chinese_clip import *
+ from .configuration_chinese_clip import (
+ ChineseCLIPConfig,
+ ChineseCLIPOnnxConfig,
+ ChineseCLIPTextConfig,
+ ChineseCLIPVisionConfig,
+ )
+ from .processing_chinese_clip import ChineseCLIPProcessor
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_chinese_clip import ChineseCLIPFeatureExtractor, ChineseCLIPImageProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_chinese_clip import (
+ ChineseCLIPModel,
+ ChineseCLIPPreTrainedModel,
+ ChineseCLIPTextModel,
+ ChineseCLIPVisionModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/chinese_clip/configuration_chinese_clip.py b/src/transformers/models/chinese_clip/configuration_chinese_clip.py
index c52b563cb2df9a..d50d6c842b313c 100644
--- a/src/transformers/models/chinese_clip/configuration_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/configuration_chinese_clip.py
@@ -429,6 +429,3 @@ def generate_dummy_inputs(
@property
def default_onnx_opset(self) -> int:
return 14
-
-
-__all__ = ["ChineseCLIPConfig", "ChineseCLIPOnnxConfig", "ChineseCLIPTextConfig", "ChineseCLIPVisionConfig"]
diff --git a/src/transformers/models/chinese_clip/feature_extraction_chinese_clip.py b/src/transformers/models/chinese_clip/feature_extraction_chinese_clip.py
index fd416ca93b9ff3..09aa4106b718eb 100644
--- a/src/transformers/models/chinese_clip/feature_extraction_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/feature_extraction_chinese_clip.py
@@ -31,6 +31,3 @@ def __init__(self, *args, **kwargs) -> None:
FutureWarning,
)
super().__init__(*args, **kwargs)
-
-
-__all__ = ["ChineseCLIPFeatureExtractor"]
diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py
index 2c338f5a71b9db..52349f84bffe0b 100644
--- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py
@@ -305,6 +305,3 @@ def preprocess(
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
-
-
-__all__ = ["ChineseCLIPImageProcessor"]
diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
index c9c19073b0e77a..dffa9028af4ffe 100644
--- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -1625,6 +1625,3 @@ def forward(
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
-
-
-__all__ = ["ChineseCLIPModel", "ChineseCLIPPreTrainedModel", "ChineseCLIPTextModel", "ChineseCLIPVisionModel"]
diff --git a/src/transformers/models/chinese_clip/processing_chinese_clip.py b/src/transformers/models/chinese_clip/processing_chinese_clip.py
index 53ba3d31259be9..2cfd314c649866 100644
--- a/src/transformers/models/chinese_clip/processing_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/processing_chinese_clip.py
@@ -158,6 +158,3 @@ def feature_extractor_class(self):
FutureWarning,
)
return self.image_processor_class
-
-
-__all__ = ["ChineseCLIPProcessor"]
diff --git a/src/transformers/models/clap/__init__.py b/src/transformers/models/clap/__init__.py
index aa2a04536f5d9e..4d3d3ba04e136f 100644
--- a/src/transformers/models/clap/__init__.py
+++ b/src/transformers/models/clap/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,60 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+_import_structure = {
+ "configuration_clap": [
+ "ClapAudioConfig",
+ "ClapConfig",
+ "ClapTextConfig",
+ ],
+ "processing_clap": ["ClapProcessor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_clap"] = [
+ "ClapModel",
+ "ClapPreTrainedModel",
+ "ClapTextModel",
+ "ClapTextModelWithProjection",
+ "ClapAudioModel",
+ "ClapAudioModelWithProjection",
+ ]
+ _import_structure["feature_extraction_clap"] = ["ClapFeatureExtractor"]
+
if TYPE_CHECKING:
- from .configuration_clap import *
- from .convert_clap_original_pytorch_to_hf import *
- from .feature_extraction_clap import *
- from .modeling_clap import *
- from .processing_clap import *
+ from .configuration_clap import (
+ ClapAudioConfig,
+ ClapConfig,
+ ClapTextConfig,
+ )
+ from .processing_clap import ClapProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_clap import ClapFeatureExtractor
+ from .modeling_clap import (
+ ClapAudioModel,
+ ClapAudioModelWithProjection,
+ ClapModel,
+ ClapPreTrainedModel,
+ ClapTextModel,
+ ClapTextModelWithProjection,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/clap/configuration_clap.py b/src/transformers/models/clap/configuration_clap.py
index c5b7d3b7a21a96..b2added7f0e073 100644
--- a/src/transformers/models/clap/configuration_clap.py
+++ b/src/transformers/models/clap/configuration_clap.py
@@ -389,6 +389,3 @@ def from_text_audio_configs(cls, text_config: ClapTextConfig, audio_config: Clap
"""
return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)
-
-
-__all__ = ["ClapAudioConfig", "ClapConfig", "ClapTextConfig"]
diff --git a/src/transformers/models/clap/feature_extraction_clap.py b/src/transformers/models/clap/feature_extraction_clap.py
index 42d3646065ece7..2d1f16e19442f7 100644
--- a/src/transformers/models/clap/feature_extraction_clap.py
+++ b/src/transformers/models/clap/feature_extraction_clap.py
@@ -360,6 +360,3 @@ def __call__(
input_features = input_features.convert_to_tensors(return_tensors)
return input_features
-
-
-__all__ = ["ClapFeatureExtractor"]
diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py
index 5792257e026d7d..f422b17b204f13 100644
--- a/src/transformers/models/clap/modeling_clap.py
+++ b/src/transformers/models/clap/modeling_clap.py
@@ -2302,13 +2302,3 @@ def forward(
attentions=audio_outputs.attentions,
hidden_states=audio_outputs.hidden_states,
)
-
-
-__all__ = [
- "ClapModel",
- "ClapPreTrainedModel",
- "ClapTextModel",
- "ClapTextModelWithProjection",
- "ClapAudioModel",
- "ClapAudioModelWithProjection",
-]
diff --git a/src/transformers/models/clap/processing_clap.py b/src/transformers/models/clap/processing_clap.py
index 6df9d4aa3961d0..4d1739ecf26172 100644
--- a/src/transformers/models/clap/processing_clap.py
+++ b/src/transformers/models/clap/processing_clap.py
@@ -115,6 +115,3 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
feature_extractor_input_names = self.feature_extractor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
-
-
-__all__ = ["ClapProcessor"]
diff --git a/src/transformers/models/clip/__init__.py b/src/transformers/models/clip/__init__.py
index 3bc3eff946f60f..36247e943ecaf7 100644
--- a/src/transformers/models/clip/__init__.py
+++ b/src/transformers/models/clip/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,23 +13,165 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_vision_available,
+)
+
+
+_import_structure = {
+ "configuration_clip": [
+ "CLIPConfig",
+ "CLIPOnnxConfig",
+ "CLIPTextConfig",
+ "CLIPVisionConfig",
+ ],
+ "processing_clip": ["CLIPProcessor"],
+ "tokenization_clip": ["CLIPTokenizer"],
+}
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_clip_fast"] = ["CLIPTokenizerFast"]
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_clip"] = ["CLIPFeatureExtractor"]
+ _import_structure["image_processing_clip"] = ["CLIPImageProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_clip"] = [
+ "CLIPModel",
+ "CLIPPreTrainedModel",
+ "CLIPTextModel",
+ "CLIPTextModelWithProjection",
+ "CLIPVisionModel",
+ "CLIPVisionModelWithProjection",
+ "CLIPForImageClassification",
+ ]
+
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_clip"] = [
+ "TFCLIPModel",
+ "TFCLIPPreTrainedModel",
+ "TFCLIPTextModel",
+ "TFCLIPVisionModel",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_clip"] = [
+ "FlaxCLIPModel",
+ "FlaxCLIPPreTrainedModel",
+ "FlaxCLIPTextModel",
+ "FlaxCLIPTextPreTrainedModel",
+ "FlaxCLIPTextModelWithProjection",
+ "FlaxCLIPVisionModel",
+ "FlaxCLIPVisionPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_clip import *
- from .convert_clip_original_pytorch_to_hf import *
- from .feature_extraction_clip import *
- from .image_processing_clip import *
- from .modeling_clip import *
- from .modeling_flax_clip import *
- from .modeling_tf_clip import *
- from .processing_clip import *
- from .tokenization_clip import *
- from .tokenization_clip_fast import *
+ from .configuration_clip import (
+ CLIPConfig,
+ CLIPOnnxConfig,
+ CLIPTextConfig,
+ CLIPVisionConfig,
+ )
+ from .processing_clip import CLIPProcessor
+ from .tokenization_clip import CLIPTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_clip_fast import CLIPTokenizerFast
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_clip import CLIPFeatureExtractor
+ from .image_processing_clip import CLIPImageProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_clip import (
+ CLIPForImageClassification,
+ CLIPModel,
+ CLIPPreTrainedModel,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPVisionModel,
+ CLIPVisionModelWithProjection,
+ )
+
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_clip import (
+ TFCLIPModel,
+ TFCLIPPreTrainedModel,
+ TFCLIPTextModel,
+ TFCLIPVisionModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_clip import (
+ FlaxCLIPModel,
+ FlaxCLIPPreTrainedModel,
+ FlaxCLIPTextModel,
+ FlaxCLIPTextModelWithProjection,
+ FlaxCLIPTextPreTrainedModel,
+ FlaxCLIPVisionModel,
+ FlaxCLIPVisionPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/clip/configuration_clip.py b/src/transformers/models/clip/configuration_clip.py
index 3f5cb47cdd121c..2e1f2deede00c9 100644
--- a/src/transformers/models/clip/configuration_clip.py
+++ b/src/transformers/models/clip/configuration_clip.py
@@ -417,6 +417,3 @@ def generate_dummy_inputs(
@property
def default_onnx_opset(self) -> int:
return 14
-
-
-__all__ = ["CLIPConfig", "CLIPOnnxConfig", "CLIPTextConfig", "CLIPVisionConfig"]
diff --git a/src/transformers/models/clip/feature_extraction_clip.py b/src/transformers/models/clip/feature_extraction_clip.py
index 1984d883875740..5696a63abe621e 100644
--- a/src/transformers/models/clip/feature_extraction_clip.py
+++ b/src/transformers/models/clip/feature_extraction_clip.py
@@ -31,6 +31,3 @@ def __init__(self, *args, **kwargs) -> None:
FutureWarning,
)
super().__init__(*args, **kwargs)
-
-
-__all__ = ["CLIPFeatureExtractor"]
diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py
index a5d12bd7ba2987..fa398821ca614c 100644
--- a/src/transformers/models/clip/image_processing_clip.py
+++ b/src/transformers/models/clip/image_processing_clip.py
@@ -343,6 +343,3 @@ def preprocess(
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
-
-
-__all__ = ["CLIPImageProcessor"]
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 0bd9c9c0abce2f..04a3a73de0455e 100644
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -401,6 +401,7 @@ class CLIPFlashAttention2(CLIPAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1676,14 +1677,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "CLIPModel",
- "CLIPPreTrainedModel",
- "CLIPTextModel",
- "CLIPTextModelWithProjection",
- "CLIPVisionModel",
- "CLIPVisionModelWithProjection",
- "CLIPForImageClassification",
-]
diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py
index c674d35e3daf41..265e7005b74e0e 100644
--- a/src/transformers/models/clip/modeling_flax_clip.py
+++ b/src/transformers/models/clip/modeling_flax_clip.py
@@ -1293,14 +1293,3 @@ class FlaxCLIPModel(FlaxCLIPPreTrainedModel):
overwrite_call_docstring(FlaxCLIPModel, CLIP_INPUTS_DOCSTRING + FLAX_CLIP_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxCLIPModel, output_type=FlaxCLIPOutput, config_class=CLIPConfig)
-
-
-__all__ = [
- "FlaxCLIPModel",
- "FlaxCLIPPreTrainedModel",
- "FlaxCLIPTextModel",
- "FlaxCLIPTextPreTrainedModel",
- "FlaxCLIPTextModelWithProjection",
- "FlaxCLIPVisionModel",
- "FlaxCLIPVisionPreTrainedModel",
-]
diff --git a/src/transformers/models/clip/modeling_tf_clip.py b/src/transformers/models/clip/modeling_tf_clip.py
index aedea502e88645..ca5f4aede21854 100644
--- a/src/transformers/models/clip/modeling_tf_clip.py
+++ b/src/transformers/models/clip/modeling_tf_clip.py
@@ -1455,6 +1455,3 @@ def build(self, input_shape=None):
if getattr(self, "clip", None) is not None:
with tf.name_scope(self.clip.name):
self.clip.build(None)
-
-
-__all__ = ["TFCLIPModel", "TFCLIPPreTrainedModel", "TFCLIPTextModel", "TFCLIPVisionModel"]
diff --git a/src/transformers/models/clip/processing_clip.py b/src/transformers/models/clip/processing_clip.py
index e69e65dec68d9b..60805402b4cea7 100644
--- a/src/transformers/models/clip/processing_clip.py
+++ b/src/transformers/models/clip/processing_clip.py
@@ -151,6 +151,3 @@ def feature_extractor(self):
FutureWarning,
)
return self.image_processor
-
-
-__all__ = ["CLIPProcessor"]
diff --git a/src/transformers/models/clip/tokenization_clip.py b/src/transformers/models/clip/tokenization_clip.py
index 41a73db8c1ecb2..83e79890d084b3 100644
--- a/src/transformers/models/clip/tokenization_clip.py
+++ b/src/transformers/models/clip/tokenization_clip.py
@@ -514,6 +514,3 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
index += 1
return vocab_file, merge_file
-
-
-__all__ = ["CLIPTokenizer"]
diff --git a/src/transformers/models/clip/tokenization_clip_fast.py b/src/transformers/models/clip/tokenization_clip_fast.py
index 89e7c8360310ee..48741a6293e48e 100644
--- a/src/transformers/models/clip/tokenization_clip_fast.py
+++ b/src/transformers/models/clip/tokenization_clip_fast.py
@@ -159,6 +159,3 @@ def create_token_type_ids_from_sequences(
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
-
-
-__all__ = ["CLIPTokenizerFast"]
diff --git a/src/transformers/models/clipseg/__init__.py b/src/transformers/models/clipseg/__init__.py
index 77b338e8fea31c..cb7daf11553efd 100644
--- a/src/transformers/models/clipseg/__init__.py
+++ b/src/transformers/models/clipseg/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +13,55 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+_import_structure = {
+ "configuration_clipseg": [
+ "CLIPSegConfig",
+ "CLIPSegTextConfig",
+ "CLIPSegVisionConfig",
+ ],
+ "processing_clipseg": ["CLIPSegProcessor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_clipseg"] = [
+ "CLIPSegModel",
+ "CLIPSegPreTrainedModel",
+ "CLIPSegTextModel",
+ "CLIPSegVisionModel",
+ "CLIPSegForImageSegmentation",
+ ]
+
if TYPE_CHECKING:
- from .configuration_clipseg import *
- from .convert_clipseg_original_pytorch_to_hf import *
- from .modeling_clipseg import *
- from .processing_clipseg import *
+ from .configuration_clipseg import (
+ CLIPSegConfig,
+ CLIPSegTextConfig,
+ CLIPSegVisionConfig,
+ )
+ from .processing_clipseg import CLIPSegProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_clipseg import (
+ CLIPSegForImageSegmentation,
+ CLIPSegModel,
+ CLIPSegPreTrainedModel,
+ CLIPSegTextModel,
+ CLIPSegVisionModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/clipseg/configuration_clipseg.py b/src/transformers/models/clipseg/configuration_clipseg.py
index 7be9bd4d55eb0e..5474840f357a34 100644
--- a/src/transformers/models/clipseg/configuration_clipseg.py
+++ b/src/transformers/models/clipseg/configuration_clipseg.py
@@ -391,6 +391,3 @@ def from_text_vision_configs(cls, text_config: CLIPSegTextConfig, vision_config:
"""
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
-
-
-__all__ = ["CLIPSegConfig", "CLIPSegTextConfig", "CLIPSegVisionConfig"]
diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py
index f3b963070dfb29..4ead68032b6034 100644
--- a/src/transformers/models/clipseg/modeling_clipseg.py
+++ b/src/transformers/models/clipseg/modeling_clipseg.py
@@ -1504,12 +1504,3 @@ def forward(
vision_model_output=vision_outputs,
decoder_output=decoder_outputs,
)
-
-
-__all__ = [
- "CLIPSegModel",
- "CLIPSegPreTrainedModel",
- "CLIPSegTextModel",
- "CLIPSegVisionModel",
- "CLIPSegForImageSegmentation",
-]
diff --git a/src/transformers/models/clipseg/processing_clipseg.py b/src/transformers/models/clipseg/processing_clipseg.py
index bd817ae786550d..f8eaca82334a22 100644
--- a/src/transformers/models/clipseg/processing_clipseg.py
+++ b/src/transformers/models/clipseg/processing_clipseg.py
@@ -159,6 +159,3 @@ def feature_extractor(self):
FutureWarning,
)
return self.image_processor
-
-
-__all__ = ["CLIPSegProcessor"]
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 7b8b9547ac1c33..d481d87e7ab8ed 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -283,6 +283,9 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = CohereRotaryEmbedding(config=self.config)
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -292,7 +295,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -311,7 +314,16 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -351,8 +363,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
-# TODO cyril: modular
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
class CohereFlashAttention2(CohereAttention):
"""
Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
@@ -378,7 +389,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
@@ -404,7 +415,16 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -482,7 +502,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -498,7 +518,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
@@ -517,7 +536,16 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -587,7 +615,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -761,8 +789,7 @@ def _init_weights(self, module):
"The bare Cohere Model outputting raw hidden-states without any specific head on top.",
COHERE_START_DOCSTRING,
)
-# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE
-# TODO cyril: modular
+# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE
class CohereModel(CoherePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`]
@@ -828,22 +855,31 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
-
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -852,6 +888,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -884,6 +921,9 @@ def forward(
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -893,13 +933,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
diff --git a/src/transformers/models/cohere2/__init__.py b/src/transformers/models/cohere2/__init__.py
deleted file mode 100644
index 1447f65935601f..00000000000000
--- a/src/transformers/models/cohere2/__init__.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright 2024 Cohere and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING
-
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
- from .configuration_cohere2 import *
- from .modeling_cohere2 import *
-else:
- import sys
-
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/cohere2/configuration_cohere2.py b/src/transformers/models/cohere2/configuration_cohere2.py
deleted file mode 100644
index aa22ec8eabef71..00000000000000
--- a/src/transformers/models/cohere2/configuration_cohere2.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_cohere2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# coding=utf-8
-# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
-
-
-class Cohere2Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate an Cohere
- model according to the specified arguments, defining the model architecture.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 256000):
- Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`CohereModel`]
- hidden_size (`int`, *optional*, defaults to 8192):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 22528):
- Dimension of the MLP representations.
- logit_scale (`float`, *optional*, defaults to 0.0625):
- The scaling factor for the output logits.
- num_hidden_layers (`int`, *optional*, defaults to 40):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 64):
- Number of attention heads for each attention layer in the Transformer decoder.
- num_key_value_heads (`int`, *optional*):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
- `num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 8192):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the layer normalization.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*, defaults to 0):
- Padding token id.
- bos_token_id (`int`, *optional*, defaults to 5):
- Beginning of stream token id.
- eos_token_id (`int`, *optional*, defaults to 255001):
- End of stream token id.
- tie_word_embeddings (`bool`, *optional*, defaults to `True`):
- Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
- accordingly.
- Expected contents:
- `rope_type` (`str`):
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
- 'llama3'], with 'default' being the original RoPE implementation.
- `factor` (`float`, *optional*):
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
- original maximum pre-trained length.
- `original_max_position_embeddings` (`int`, *optional*):
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
- pretraining.
- `attention_factor` (`float`, *optional*):
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
- computation. If unspecified, it defaults to value recommended by the implementation, using the
- `factor` field to infer the suggested value.
- `beta_fast` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
- ramp function. If unspecified, it defaults to 32.
- `beta_slow` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
- ramp function. If unspecified, it defaults to 1.
- `short_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `long_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `low_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
- `high_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- sliding_window (`int`, *optional*, defaults to 4096):
- Size of the sliding window attention context.
- sliding_window_pattern (`int`, *optional*, defaults to 4):
- Pattern for the sliding window attention.
- cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
-
- ```python
- >>> from transformers import Cohere2Model, Cohere2Config
-
- >>> # Initializing a Cohere Nextmodel configuration
- >>> configuration = Cohere2Config()
-
- >>> # Initializing a model from the Cohere2 configuration
- >>> model = Cohere2Model(configuration) # doctest: +SKIP
-
- >>> # Accessing the model configuration
- >>> configuration = model.config # doctest: +SKIP
- ```
- """
-
- model_type = "cohere2"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- vocab_size=256000,
- hidden_size=8192,
- intermediate_size=22528,
- logit_scale=0.0625,
- num_hidden_layers=40,
- num_attention_heads=64,
- num_key_value_heads=None,
- hidden_act="silu",
- max_position_embeddings=8192,
- initializer_range=0.02,
- layer_norm_eps=1e-5,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=5,
- eos_token_id=255001,
- tie_word_embeddings=True,
- rope_theta=10000.0,
- rope_scaling=None,
- attention_bias=False,
- attention_dropout=0.0,
- sliding_window=4096,
- sliding_window_pattern=4,
- cache_implementation="hybrid",
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.logit_scale = logit_scale
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.layer_norm_eps = layer_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.sliding_window = sliding_window
- self.sliding_window_pattern = sliding_window_pattern
- # Need to specify head_dim in the config so it can be used in the attention forward functions
- self.head_dim = hidden_size // num_attention_heads
- self.cache_implementation = cache_implementation
-
- # Validate the correctness of rotary position embeddings parameters
- rope_config_validation(self)
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
-
-__all__ = ["Cohere2Config"]
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
deleted file mode 100644
index 1ffa4bffddc3df..00000000000000
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ /dev/null
@@ -1,1079 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_cohere2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# coding=utf-8
-# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
-from ...generation import GenerationMixin
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- logging,
- replace_return_docstrings,
-)
-from .configuration_cohere2 import Cohere2Config
-
-
-if is_flash_attn_2_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "Cohere2Config"
-
-
-class Cohere2RotaryEmbedding(nn.Module):
- # Note: the forward pass of this RoPE is slightly different from Llama's, resulting in different `sin`/`cos` for
- # the same parameterization. The differences are highlighted with a comment.
-
- def __init__(
- self,
- dim=None,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- rope_type="default",
- config: Optional[Cohere2Config] = None,
- ):
- super().__init__()
- # TODO (joao): remove the `if` below, only used for BC
- self.rope_kwargs = {}
- if config is None:
- logger.warning_once(
- "`Cohere2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
- "`config` argument. All other arguments will be removed in v4.46"
- )
- self.rope_kwargs = {
- "rope_type": rope_type,
- "factor": scaling_factor,
- "dim": dim,
- "base": base,
- "max_position_embeddings": max_position_embeddings,
- }
- self.rope_type = rope_type
- self.max_seq_len_cached = max_position_embeddings
- self.original_max_seq_len = max_position_embeddings
- else:
- # BC: "rope_type" was originally "type"
- if config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.repeat_interleave(freqs, 2, dim=-1) # This line differs from Llama's implementation
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-class Cohere2LayerNorm(nn.Module):
- def __init__(self, hidden_size=None, eps=1e-5, bias=False):
- """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- mean = hidden_states.mean(-1, keepdim=True)
- variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
- hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
- hidden_states = self.weight.to(torch.float32) * hidden_states
- return hidden_states.to(input_dtype)
-
-
-def rotate_half(x):
- # Split and rotate. Note that this function is different from e.g. Llama.
- x1 = x[..., ::2]
- x2 = x[..., 1::2]
- rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
- return rot_x
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- dtype = q.dtype
- q = q.float()
- k = k.float()
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
- config: Cohere2Config,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: Optional[torch.Tensor],
- **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- key_states = repeat_kv(key, config.num_key_value_groups)
- value_states = repeat_kv(value, config.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim)
-
- if mask is not None: # no matter the length, we just slice it
- causal_mask = mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
-
-
-def flash_attention_forward(
- config: Cohere2Config,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: Optional[torch.Tensor],
- target_dtype: torch.dtype = torch.float16,
- **_kwargs,
-) -> Tuple[torch.Tensor, None]:
- if mask is not None:
- seq_len = mask.shape[1]
- query = query[:, :, :seq_len]
- value = value[:, :, :seq_len]
-
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout
- # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
- query_states = query.transpose(1, 2)
- key_states = key.transpose(1, 2)
- value_states = value.transpose(1, 2)
-
- dropout_rate = config.attention_dropout if config.training else 0.0
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- mask,
- seq_len,
- dropout=dropout_rate,
- is_causal=config.is_causal,
- sliding_window=config.sliding_window,
- use_top_left_mask=config._flash_attn_uses_top_left_mask,
- )
-
- return attn_output, None
-
-
-def sdpa_attention_forward(
- config: Cohere2Config,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: Optional[torch.Tensor],
- **_kwargs,
-) -> Tuple[torch.Tensor, None]:
- key = repeat_kv(key, config.num_key_value_groups)
- value = repeat_kv(value, config.num_key_value_groups)
-
- causal_mask = mask
- if mask is not None:
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query.device.type == "cuda" and causal_mask is not None:
- query = query.contiguous()
- key = key.contiguous()
- value = value.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- is_causal = True if causal_mask is None and query.shape[1] > 1 else False
-
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query,
- key,
- value,
- attn_mask=causal_mask,
- dropout_p=config.attention_dropout if config.training else 0.0,
- is_causal=is_causal,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, None
-
-
-COHERE2_ATTENTION_FUNCTION = {
- "flash_attention_2": flash_attention_forward,
- "eager": eager_attention_forward,
- "sdpa": sdpa_attention_forward,
-}
-
-
-class Cohere2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
-
- self.attention_dropout = config.attention_dropout
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = config.head_dim
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
-
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
-
- self.sliding_window = (
- config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- cos, sin = position_embeddings
-
- if self.sliding_window is not None:
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- cache_kwargs = {
- "sin": sin,
- "cos": cos,
- "sliding_window": self.sliding_window,
- "cache_position": cache_position,
- }
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
- logger.warning_once("Setting `attention_type` to `eager` because `output_attentions=True`")
- attention_type = "eager"
- else:
- attention_type = self.config._attn_implementation
-
- attn_output, attn_weights = COHERE2_ATTENTION_FUNCTION[attention_type](
- self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class Cohere2MLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- # Ignore copy
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
-
-
-class Cohere2DecoderLayer(nn.Module):
- def __init__(self, config: Cohere2Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.self_attn = Cohere2Attention(config, layer_idx)
-
- self.mlp = Cohere2MLP(config)
- self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
- self.config = config
- self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
- self.sliding_window = config.sliding_window
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`):
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
- with `head_dim` being the embedding dimension of each attention head.
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence
- """
-
- if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
- # Flash-attn is a 2D tensor
- if self.config._attn_implementation == "flash_attention_2":
- if past_key_value is not None: # when decoding
- attention_mask = attention_mask[:, -self.sliding_window :]
- else:
- min_dtype = torch.finfo(hidden_states.dtype).min
- sliding_window_mask = torch.tril(
- torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
- )
- attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
- if attention_mask.shape[-1] <= 1: # when decoding
- attention_mask = attention_mask[:, :, :, -self.sliding_window :]
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- # Fully Connected
- hidden_states_mlp = self.mlp(hidden_states)
-
- # Add everything together
- hidden_states = residual + hidden_states_attention + hidden_states_mlp
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-
-COHERE2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings etc.).
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`Cohere2Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare Cohere2 Model outputting raw hidden-states without any specific head on top.",
- COHERE2_START_DOCSTRING,
-)
-class Cohere2PreTrainedModel(PreTrainedModel):
- config_class = Cohere2Config
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["Cohere2DecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-COHERE2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance, see our
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
-
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
-
-
-@add_start_docstrings(
- "The bare Cohere2 Model outputting raw hidden-states without any specific head on top.",
- COHERE2_START_DOCSTRING,
-)
-class Cohere2Model(Cohere2PreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Cohere2DecoderLayer`]
- Args:
- config: Cohere2Config
- """
-
- def __init__(self, config: Cohere2Config):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList(
- [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
- self.rotary_emb = Cohere2RotaryEmbedding(config=config)
- self.gradient_checkpointing = False
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridCache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if use_cache and past_key_values is None and not self.training:
- batch_size, seq_len, _ = inputs_embeds.shape
- past_key_values = HybridCache(
- self.config,
- batch_size=batch_size,
- max_cache_len=seq_len,
- device=self.device,
- dtype=inputs_embeds.dtype,
- )
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
- hidden_states = inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- position_embeddings,
- causal_mask,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=causal_mask,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = past_key_values if use_cache else None
-
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
- @torch.no_grad()
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: HybridCache,
- output_attentions: bool,
- ):
- # Flash Attention currently doesn't support static cache but Cohere2 work only with static cache.
- # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
- # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
- # as it doesn't cause dynamic control issues.
- if self.config._attn_implementation == "flash_attention_2":
- return attention_mask
-
- dtype, device = input_tensor.dtype, input_tensor.device
- sequence_length = input_tensor.shape[1]
- if isinstance(past_key_values, HybridCache):
- target_length = past_key_values.get_max_cache_shape()
- else:
- target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
- return causal_mask
-
- @staticmethod
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to plcae the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
-
- return causal_mask
-
-
-# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere2
-class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
-
- # Ignore copy
- def __init__(self, config: Cohere2Config):
- super().__init__(config)
- self.model = Cohere2Model(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- self.logit_scale = config.logit_scale
- self.tie_word_embeddings = config.tie_word_embeddings
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- # Ignore copy
- @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **loss_kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
- Returns:
-
- Example:
-
- ```python
- >> from transformers import AutoTokenizer, Cohere2ForCausalLM
-
- >> model = Cohere2ForCausalLM.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
- >> tokenizer = AutoTokenizer.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
-
- >> prompt = "Hey, are you conscious? Can you talk to me?"
- >> inputs = tokenizer(prompt, return_tensors="pt")
-
- >> # Generate
- >> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- )
-
- hidden_states = outputs[0]
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
- logits = logits * self.logit_scale
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- cache_position=None,
- position_ids=None,
- use_cache=True,
- num_logits_to_keep=None,
- **kwargs,
- ):
- # Overwritten: has a special cache type, `HybridCache`
-
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- # Exception 1: when passing input_embeds, input_ids may be missing entries
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- if past_key_values is not None:
- if inputs_embeds is not None: # Exception 1
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
- input_ids = input_ids[:, cache_position]
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
- # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
- # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
- # batch size = 1 case, `position_ids` is already contiguous but with varying stride
- # which retriggers a capture.
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and cache_position[0] == 0:
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
- else:
- # The clone here is for the same reason as for `position_ids`.
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
- if (
- isinstance(past_key_values, HybridCache)
- and attention_mask.ndim == 2
- and not self.config._attn_implementation == "flash_attention_2"
- ):
- if model_inputs["inputs_embeds"] is not None:
- batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
- device = model_inputs["inputs_embeds"].device
- else:
- batch_size, sequence_length = model_inputs["input_ids"].shape
- device = model_inputs["input_ids"].device
-
- attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=past_key_values.get_max_cache_shape(),
- dtype=self.lm_head.weight.dtype,
- device=device,
- cache_position=cache_position,
- batch_size=batch_size,
- )
-
- if num_logits_to_keep is not None:
- model_inputs["num_logits_to_keep"] = num_logits_to_keep
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "cache_position": cache_position,
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "attention_mask": attention_mask,
- }
- )
- return model_inputs
-
-
-__all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
deleted file mode 100644
index 3e6999b29bbfa1..00000000000000
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ /dev/null
@@ -1,744 +0,0 @@
-# coding=utf-8
-# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint
-
-from ...cache_utils import Cache, HybridCache
-from ...configuration_utils import PretrainedConfig
-from ...modeling_outputs import (
- BaseModelOutputWithPast,
-)
-from ...modeling_rope_utils import rope_config_validation
-from ...utils import (
- is_flash_attn_2_available,
- logging,
-)
-from ..cohere.modeling_cohere import (
- CohereDecoderLayer,
- CohereForCausalLM,
- CohereLayerNorm,
- CoherePreTrainedModel,
- CohereRotaryEmbedding,
- apply_rotary_pos_emb,
- repeat_kv,
-)
-from ..gemma2.modeling_gemma2 import Gemma2Model
-
-
-if is_flash_attn_2_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
-logger = logging.get_logger(__name__)
-
-
-class Cohere2Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate an Cohere
- model according to the specified arguments, defining the model architecture.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 256000):
- Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`CohereModel`]
- hidden_size (`int`, *optional*, defaults to 8192):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 22528):
- Dimension of the MLP representations.
- logit_scale (`float`, *optional*, defaults to 0.0625):
- The scaling factor for the output logits.
- num_hidden_layers (`int`, *optional*, defaults to 40):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 64):
- Number of attention heads for each attention layer in the Transformer decoder.
- num_key_value_heads (`int`, *optional*):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
- `num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 8192):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the layer normalization.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*, defaults to 0):
- Padding token id.
- bos_token_id (`int`, *optional*, defaults to 5):
- Beginning of stream token id.
- eos_token_id (`int`, *optional*, defaults to 255001):
- End of stream token id.
- tie_word_embeddings (`bool`, *optional*, defaults to `True`):
- Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
- accordingly.
- Expected contents:
- `rope_type` (`str`):
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
- 'llama3'], with 'default' being the original RoPE implementation.
- `factor` (`float`, *optional*):
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
- original maximum pre-trained length.
- `original_max_position_embeddings` (`int`, *optional*):
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
- pretraining.
- `attention_factor` (`float`, *optional*):
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
- computation. If unspecified, it defaults to value recommended by the implementation, using the
- `factor` field to infer the suggested value.
- `beta_fast` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
- ramp function. If unspecified, it defaults to 32.
- `beta_slow` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
- ramp function. If unspecified, it defaults to 1.
- `short_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `long_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `low_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
- `high_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- sliding_window (`int`, *optional*, defaults to 4096):
- Size of the sliding window attention context.
- sliding_window_pattern (`int`, *optional*, defaults to 4):
- Pattern for the sliding window attention.
- cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`.
-
- ```python
- >>> from transformers import Cohere2Model, Cohere2Config
-
- >>> # Initializing a Cohere Nextmodel configuration
- >>> configuration = Cohere2Config()
-
- >>> # Initializing a model from the Cohere2 configuration
- >>> model = Cohere2Model(configuration) # doctest: +SKIP
-
- >>> # Accessing the model configuration
- >>> configuration = model.config # doctest: +SKIP
- ```
- """
-
- model_type = "cohere2"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- vocab_size=256000,
- hidden_size=8192,
- intermediate_size=22528,
- logit_scale=0.0625,
- num_hidden_layers=40,
- num_attention_heads=64,
- num_key_value_heads=None,
- hidden_act="silu",
- max_position_embeddings=8192,
- initializer_range=0.02,
- layer_norm_eps=1e-5,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=5,
- eos_token_id=255001,
- tie_word_embeddings=True,
- rope_theta=10000.0,
- rope_scaling=None,
- attention_bias=False,
- attention_dropout=0.0,
- sliding_window=4096,
- sliding_window_pattern=4,
- cache_implementation="hybrid",
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.logit_scale = logit_scale
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.layer_norm_eps = layer_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.sliding_window = sliding_window
- self.sliding_window_pattern = sliding_window_pattern
- # Need to specify head_dim in the config so it can be used in the attention forward functions
- self.head_dim = hidden_size // num_attention_heads
- self.cache_implementation = cache_implementation
-
- # Validate the correctness of rotary position embeddings parameters
- rope_config_validation(self)
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
-
-class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
- pass
-
-
-class Cohere2LayerNorm(CohereLayerNorm):
- pass
-
-
-def eager_attention_forward(
- config: Cohere2Config,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: Optional[torch.Tensor],
- **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- key_states = repeat_kv(key, config.num_key_value_groups)
- value_states = repeat_kv(value, config.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim)
-
- if mask is not None: # no matter the length, we just slice it
- causal_mask = mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
-
-
-def flash_attention_forward(
- config: Cohere2Config,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: Optional[torch.Tensor],
- target_dtype: torch.dtype = torch.float16,
- **_kwargs,
-) -> Tuple[torch.Tensor, None]:
- if mask is not None:
- seq_len = mask.shape[1]
- query = query[:, :, :seq_len]
- value = value[:, :, :seq_len]
-
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout
- # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
- query_states = query.transpose(1, 2)
- key_states = key.transpose(1, 2)
- value_states = value.transpose(1, 2)
-
- dropout_rate = config.attention_dropout if config.training else 0.0
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- mask,
- seq_len,
- dropout=dropout_rate,
- is_causal=config.is_causal,
- sliding_window=config.sliding_window,
- use_top_left_mask=config._flash_attn_uses_top_left_mask,
- )
-
- return attn_output, None
-
-
-def sdpa_attention_forward(
- config: Cohere2Config,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: Optional[torch.Tensor],
- **_kwargs,
-) -> Tuple[torch.Tensor, None]:
- key = repeat_kv(key, config.num_key_value_groups)
- value = repeat_kv(value, config.num_key_value_groups)
-
- causal_mask = mask
- if mask is not None:
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query.device.type == "cuda" and causal_mask is not None:
- query = query.contiguous()
- key = key.contiguous()
- value = value.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- is_causal = True if causal_mask is None and query.shape[1] > 1 else False
-
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query,
- key,
- value,
- attn_mask=causal_mask,
- dropout_p=config.attention_dropout if config.training else 0.0,
- is_causal=is_causal,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, None
-
-
-COHERE2_ATTENTION_FUNCTION = {
- "flash_attention_2": flash_attention_forward,
- "eager": eager_attention_forward,
- "sdpa": sdpa_attention_forward,
-}
-
-
-class Cohere2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
-
- self.attention_dropout = config.attention_dropout
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = config.head_dim
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
-
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
-
- self.sliding_window = (
- config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- cos, sin = position_embeddings
-
- if self.sliding_window is not None:
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- cache_kwargs = {
- "sin": sin,
- "cos": cos,
- "sliding_window": self.sliding_window,
- "cache_position": cache_position,
- }
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
- logger.warning_once("Setting `attention_type` to `eager` because `output_attentions=True`")
- attention_type = "eager"
- else:
- attention_type = self.config._attn_implementation
-
- attn_output, attn_weights = COHERE2_ATTENTION_FUNCTION[attention_type](
- self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class Cohere2DecoderLayer(CohereDecoderLayer):
- def __init__(self, config: Cohere2Config, layer_idx: int):
- super().__init__(config, layer_idx)
- self.self_attn = Cohere2Attention(config, layer_idx)
- self.config = config
- self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
- self.sliding_window = config.sliding_window
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`):
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
- with `head_dim` being the embedding dimension of each attention head.
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence
- """
-
- if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
- # Flash-attn is a 2D tensor
- if self.config._attn_implementation == "flash_attention_2":
- if past_key_value is not None: # when decoding
- attention_mask = attention_mask[:, -self.sliding_window :]
- else:
- min_dtype = torch.finfo(hidden_states.dtype).min
- sliding_window_mask = torch.tril(
- torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
- )
- attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
- if attention_mask.shape[-1] <= 1: # when decoding
- attention_mask = attention_mask[:, :, :, -self.sliding_window :]
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- # Fully Connected
- hidden_states_mlp = self.mlp(hidden_states)
-
- # Add everything together
- hidden_states = residual + hidden_states_attention + hidden_states_mlp
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-
-class Cohere2PreTrainedModel(CoherePreTrainedModel):
- config_class = Cohere2Config
-
-
-class Cohere2Model(Gemma2Model):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Cohere2DecoderLayer`]
- Args:
- config: Cohere2Config
- """
-
- def __init__(self, config: Cohere2Config):
- super().__init__(config)
- self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
- self.rotary_emb = Cohere2RotaryEmbedding(config=config)
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridCache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if use_cache and past_key_values is None and not self.training:
- batch_size, seq_len, _ = inputs_embeds.shape
- past_key_values = HybridCache(
- self.config,
- batch_size=batch_size,
- max_cache_len=seq_len,
- device=self.device,
- dtype=inputs_embeds.dtype,
- )
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
- hidden_states = inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- position_embeddings,
- causal_mask,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=causal_mask,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = past_key_values if use_cache else None
-
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
-
-class Cohere2ForCausalLM(CohereForCausalLM):
- def __init__(self, config: Cohere2Config):
- super().__init__(config)
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- cache_position=None,
- position_ids=None,
- use_cache=True,
- num_logits_to_keep=None,
- **kwargs,
- ):
- # Overwritten: has a special cache type, `HybridCache`
-
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- # Exception 1: when passing input_embeds, input_ids may be missing entries
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
- if past_key_values is not None:
- if inputs_embeds is not None: # Exception 1
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
- input_ids = input_ids[:, cache_position]
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
- # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
- # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
- # batch size = 1 case, `position_ids` is already contiguous but with varying stride
- # which retriggers a capture.
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and cache_position[0] == 0:
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
- else:
- # The clone here is for the same reason as for `position_ids`.
- model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
- if (
- isinstance(past_key_values, HybridCache)
- and attention_mask.ndim == 2
- and not self.config._attn_implementation == "flash_attention_2"
- ):
- if model_inputs["inputs_embeds"] is not None:
- batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
- device = model_inputs["inputs_embeds"].device
- else:
- batch_size, sequence_length = model_inputs["input_ids"].shape
- device = model_inputs["input_ids"].device
-
- attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=past_key_values.get_max_cache_shape(),
- dtype=self.lm_head.weight.dtype,
- device=device,
- cache_position=cache_position,
- batch_size=batch_size,
- )
-
- if num_logits_to_keep is not None:
- model_inputs["num_logits_to_keep"] = num_logits_to_keep
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "cache_position": cache_position,
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "attention_mask": attention_mask,
- }
- )
- return model_inputs
-
-
-__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
diff --git a/src/transformers/models/colpali/__init__.py b/src/transformers/models/colpali/__init__.py
deleted file mode 100644
index fa1b63fd009803..00000000000000
--- a/src/transformers/models/colpali/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING
-
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
- from .configuration_colpali import *
- from .modeling_colpali import *
- from .processing_colpali import *
-else:
- import sys
-
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/colpali/configuration_colpali.py b/src/transformers/models/colpali/configuration_colpali.py
deleted file mode 100644
index 045462adca4e2c..00000000000000
--- a/src/transformers/models/colpali/configuration_colpali.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""ColPali model configuration"""
-
-import logging
-from copy import deepcopy
-
-from ...configuration_utils import PretrainedConfig
-from ..auto import CONFIG_MAPPING, AutoConfig
-
-
-logger = logging.getLogger(__name__)
-
-
-class ColPaliConfig(PretrainedConfig):
- r"""
- Configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate an instance
- of `ColPaliForRetrieval` according to the specified arguments, defining the model architecture following the methodology
- from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
-
- Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
- default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
-
- The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.
-
- Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
- use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vlm_config (`PretrainedConfig`, *optional*):
- Configuration of the VLM backbone model.
- text_config (`PretrainedConfig`, *optional*):
- Configuration of the text backbone model. Overrides the `text_config` attribute of the `vlm_config` if provided.
- embedding_dim (`int`, *optional*, defaults to 128):
- Dimension of the multi-vector embeddings produced by the model.
-
- Example:
-
- ```python
- from transformers.models.colpali import ColPaliConfig, ColPaliForRetrieval
-
- config = ColPaliConfig()
- model = ColPaliForRetrieval(config)
- ```
- """
-
- model_type = "colpali"
- sub_configs = {"vlm_config": PretrainedConfig, "text_config": AutoConfig}
-
- def __init__(
- self,
- vlm_config=None,
- text_config=None,
- embedding_dim: int = 128,
- **kwargs,
- ):
- if vlm_config is None:
- vlm_config = CONFIG_MAPPING["paligemma"]()
- logger.info(
- "`vlm_config` is `None`. Initializing `vlm_config` with the `PaliGemmaConfig` with default values."
- )
- elif isinstance(vlm_config, dict):
- vlm_config = deepcopy(vlm_config)
- if "model_type" not in vlm_config:
- raise KeyError(
- "The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
- )
- elif vlm_config["model_type"] not in CONFIG_MAPPING:
- raise ValueError(
- f"The model type `{vlm_config['model_type']}` is not supported. Please provide a valid model type."
- )
- vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
- elif isinstance(vlm_config, PretrainedConfig):
- vlm_config = vlm_config
- else:
- raise TypeError(
- f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
- )
-
- self.vlm_config = vlm_config
- self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
- if isinstance(self.text_config, dict):
- text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
- self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
-
- self.embedding_dim = embedding_dim
-
- super().__init__(**kwargs)
-
-
-__all__ = ["ColPaliConfig"]
diff --git a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py
deleted file mode 100644
index 1b30f3f97acda3..00000000000000
--- a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Convert ColPali weights from the original repository to the HF model format.
-
-Original repository: https://github.com/illuin-tech/colpali.
-
-NOTE: This script was originally run using `torch==2.5.1` and with:
-
-```bash
-python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
- --model_id vidore/colpali-v1.2-merged \
- --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
- --original_vlm_name_or_path google/paligemma-3b-mix-448 \
- --output_dir vidore/colpali-v1.2-hf-internal \
- --push_to_hub
-
-python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
- --model_id vidore/colpali-v1.3-merged \
- --revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
- --original_vlm_name_or_path google/paligemma-3b-mix-448 \
- --output_dir vidore/colpali-v1.3-hf \
- --push_to_hub
-```
-"""
-
-import argparse
-import glob
-from pathlib import Path
-from typing import Any, Dict, Optional
-
-import torch
-from huggingface_hub import snapshot_download
-from safetensors import safe_open
-
-from transformers import AutoConfig
-from transformers.models.colpali import ColPaliForRetrieval
-from transformers.models.colpali.configuration_colpali import ColPaliConfig
-from transformers.utils import logging
-
-
-logging.set_verbosity_info()
-logger = logging.get_logger(__name__)
-
-
-ORIGINAL_DTYPE = torch.bfloat16
-
-
-def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
- new_state_dict = {}
- for key, value in state_dict.items():
- new_key = key
- if key.startswith("custom_text_proj"):
- new_key = key.replace("custom_text_proj", "embedding_proj_layer")
- if key.startswith("model."):
- new_key = key.replace("model.", "vlm.", 1)
- new_state_dict[new_key] = value
- return new_state_dict
-
-
-def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
- directory_path = snapshot_download(
- repo_id=model_id,
- revision=revision,
- allow_patterns=["*.safetensors"],
- )
-
- original_state_dict = {}
- for path in glob.glob(f"{directory_path}/*"):
- if path.endswith(".safetensors"):
- with safe_open(path, framework="pt", device="cpu") as f:
- for key in f.keys():
- original_state_dict[key] = f.get_tensor(key)
-
- # Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
- if "lm_head.weight" not in original_state_dict:
- original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
- "model.language_model.model.embed_tokens.weight"
- ].clone()
-
- return original_state_dict
-
-
-@torch.no_grad()
-def convert_colpali_weights_to_hf(
- model_id: str,
- output_dir: str,
- push_to_hub: bool,
- revision: Optional[str] = None,
- original_vlm_name_or_path: Optional[str] = None,
-):
- # Load the original model data
- original_config = AutoConfig.from_pretrained(
- model_id,
- revision=revision,
- )
- if original_vlm_name_or_path is not None:
- original_config._name_or_path = original_vlm_name_or_path
- if hasattr(original_config, "architectures"):
- delattr(original_config, "architectures")
-
- original_state_dict = load_original_state_dict(model_id, revision=revision)
-
- # Format the state_dict keys
- original_state_dict = rename_state_dict_keys(original_state_dict)
-
- # Create the new config
- config = ColPaliConfig(
- vlm_config=original_config,
- embedding_dim=128, # hardcoded in the original model
- )
- config.model_type = "colpali"
- config.is_composition = False
-
- # Load the untrained model
- model = ColPaliForRetrieval(config=config).to("cpu").eval()
- print("Created model with new config and randomly initialized weights")
-
- # NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
- # There are two ways to set the model's dtype:
- # - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
- # - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
- # The following snippet allows a fine-grained control over the model's dtype, making sure that all
- # the new weights' dtypes match the original model.
- for param in model.parameters():
- param.data = param.data.to(ORIGINAL_DTYPE)
- print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
-
- # Load the original weights
- model.load_state_dict(original_state_dict)
- print("Loaded original model weights")
-
- # Tie the weights (following ColPali's `__init__`` step)
- if model.vlm.language_model._tied_weights_keys is not None:
- model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
-
- # Sanity check: ensure all keys are the same
- state_dict_keys_old = set(original_state_dict.keys())
- state_dict_keys_new = set(model.state_dict().keys())
- disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
- if disjoint_keys:
- raise ValueError(f"Incompatible keys: {disjoint_keys}")
-
- # Save the model
- if push_to_hub:
- model.push_to_hub(output_dir, private=True)
- print(f"Model pushed to the hub at `{output_dir}`")
- else:
- Path(output_dir).mkdir(exist_ok=True, parents=True)
- model.save_pretrained(output_dir)
- print(f"Model saved to `{output_dir}`")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="""
- This script converts the original ColPali model to the HF model format.
-
- Example usage:
- ```bash
- python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
- --model_id vidore/colpali-v1.2-merged \
- --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
- --original_vlm_name_or_path google/paligemma-3b-mix-448 \
- --output_dir vidore/colpali-v1.2-hf \
- --push_to_hub
- ```
- """
- )
- parser.add_argument(
- "--model_id",
- help="Model ID of the original model to convert",
- )
- parser.add_argument(
- "--output_dir",
- help="Location to write HF model and tokenizer",
- )
- parser.add_argument(
- "--push_to_hub",
- help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
- action="store_true",
- default=False,
- )
- parser.add_argument(
- "--revision",
- help="Revision of the model to download",
- default=None,
- )
- parser.add_argument(
- "--original_vlm_name_or_path",
- help="Name or path of the original VLM backbone model",
- default=None,
- )
- args = parser.parse_args()
-
- convert_colpali_weights_to_hf(
- model_id=args.model_id,
- output_dir=args.output_dir,
- push_to_hub=args.push_to_hub,
- revision=args.revision,
- original_vlm_name_or_path=args.original_vlm_name_or_path,
- )
diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py
deleted file mode 100644
index d84f29a3414f0f..00000000000000
--- a/src/transformers/models/colpali/modeling_colpali.py
+++ /dev/null
@@ -1,293 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch ColPali model"""
-
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from transformers import AutoModelForImageTextToText
-
-from ...cache_utils import Cache
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- ModelOutput,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- replace_return_docstrings,
-)
-from .configuration_colpali import ColPaliConfig
-
-
-_CONFIG_FOR_DOC = "ColPaliConfig"
-
-COLPALI_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`ColPaliConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare ColPali model outputting raw hidden-states without any specific head on top.",
- COLPALI_START_DOCSTRING,
-)
-class ColPaliPreTrainedModel(PreTrainedModel):
- config_class = ColPaliConfig
- base_model_prefix = "model"
- _no_split_modules = []
-
- def _init_weights(self, module):
- std = (
- self.config.initializer_range
- if hasattr(self.config, "initializer_range")
- else self.config.vlm_config.text_config.initializer_range
- )
-
- if isinstance(module, (nn.Linear, nn.Conv2d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-@dataclass
-class ColPaliForRetrievalOutput(ModelOutput):
- """
- Base class for ColPali embeddings output.
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- The embeddings of the model.
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- image_hidden_states (`torch.FloatTensor`, *optional*):
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
- """
-
- loss: Optional[torch.FloatTensor] = None
- embeddings: torch.Tensor = None
- past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- image_hidden_states: Optional[torch.FloatTensor] = None
-
-
-COLPALI_FOR_RETRIEVAL_INPUT_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
- The tensors corresponding to the input images. Pixel values can be obtained using
- [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses
- [`SiglipImageProcessor`] for processing images). If none, ColPali will only process text (query embeddings).
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional key word arguments passed along to the vlm backbone model.
-"""
-
-
-@add_start_docstrings(
- """
- In our proposed ColPali approach, we leverage VLMs to construct efficient multi-vector embeddings directly
- from document images (“screenshots”) for document retrieval. We train the model to maximize the similarity
- between these document embeddings and the corresponding query embeddings, using the late interaction method
- introduced in ColBERT.
-
- Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a
- single model that can take into account both the textual and visual content (layout, charts, etc.) of a document.
- """
-)
-class ColPaliForRetrieval(ColPaliPreTrainedModel):
- def __init__(self, config: ColPaliConfig):
- super().__init__(config)
- self.config = config
- self.vocab_size = config.vlm_config.text_config.vocab_size
-
- vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
- if vlm.language_model._tied_weights_keys is not None:
- self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys]
- self.vlm = vlm
-
- self.embedding_dim = self.config.embedding_dim
- self.embedding_proj_layer = nn.Linear(
- self.config.vlm_config.text_config.hidden_size,
- self.embedding_dim,
- )
-
- self.post_init()
-
- @add_start_docstrings_to_model_forward(COLPALI_FOR_RETRIEVAL_INPUT_DOCSTRING)
- @replace_return_docstrings(output_type=ColPaliForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- pixel_values: torch.FloatTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple, ColPaliForRetrievalOutput]:
- r"""
- Returns:
- """
- if "pixel_values" in kwargs:
- kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.vlm(
- input_ids=input_ids,
- attention_mask=attention_mask,
- pixel_values=pixel_values,
- output_hidden_states=True,
- return_dict=return_dict,
- output_attentions=output_attentions,
- **kwargs,
- )
-
- last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
- embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
-
- # L2 normalization
- embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
-
- embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
-
- loss = None
- if not return_dict:
- output = (embeddings,) + outputs[2:]
- output[2] = output[2] if output_hidden_states is not None else None
- output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
- return (loss,) + output if loss is not None else output
-
- return ColPaliForRetrievalOutput(
- loss=loss,
- embeddings=embeddings,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states if output_hidden_states else None,
- attentions=outputs.attentions,
- image_hidden_states=outputs.image_hidden_states if pixel_values is not None else None,
- )
-
- def get_input_embeddings(self):
- return self.vlm.language_model.get_input_embeddings()
-
- def set_input_embeddings(self, value):
- self.vlm.language_model.set_input_embeddings(value)
-
- def get_output_embeddings(self):
- return self.vlm.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.vlm.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.vlm.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.vlm.language_model.get_decoder()
-
- def tie_weights(self):
- return self.vlm.language_model.tie_weights()
-
- def resize_token_embeddings(
- self,
- new_num_tokens: Optional[int] = None,
- pad_to_multiple_of: Optional[int] = None,
- mean_resizing: bool = True,
- ) -> nn.Embedding:
- model_embeds = self.vlm.language_model.resize_token_embeddings(
- new_num_tokens=new_num_tokens,
- pad_to_multiple_of=pad_to_multiple_of,
- mean_resizing=mean_resizing,
- )
-
- self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings
- self.config.vlm_config.vocab_size = model_embeds.num_embeddings
- self.vlm.vocab_size = model_embeds.num_embeddings
- self.vocab_size = model_embeds.num_embeddings
-
- return model_embeds
-
-
-__all__ = [
- "ColPaliForRetrieval",
- "ColPaliForRetrievalOutput",
- "ColPaliPreTrainedModel",
-]
diff --git a/src/transformers/models/colpali/modular_colpali.py b/src/transformers/models/colpali/modular_colpali.py
deleted file mode 100644
index ceb43e2d66f335..00000000000000
--- a/src/transformers/models/colpali/modular_colpali.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import ClassVar, List, Optional, Union
-
-from transformers.models.paligemma.processing_paligemma import (
- IMAGE_TOKEN,
- PaliGemmaProcessor,
- build_string_from_input,
- make_batched_images,
-)
-
-from ...feature_extraction_utils import BatchFeature
-from ...image_utils import ImageInput, is_valid_image
-from ...processing_utils import (
- ProcessingKwargs,
- Unpack,
-)
-from ...tokenization_utils_base import (
- PreTokenizedInput,
- TextInput,
-)
-from ...utils import (
- is_torch_available,
- logging,
-)
-
-
-if is_torch_available():
- import torch
-
-
-logger = logging.get_logger(__name__)
-
-
-class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
- _defaults = {
- "text_kwargs": {
- "padding": "longest",
- },
- "images_kwargs": {
- "data_format": "channels_first",
- "do_convert_rgb": True,
- },
- "common_kwargs": {"return_tensors": "pt"},
- }
-
-
-class ColPaliProcessor(PaliGemmaProcessor):
- r"""
- Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
- well as to compute the late-interaction retrieval score.
-
- [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
- for more information.
-
- Args:
- image_processor ([`SiglipImageProcessor`], *optional*):
- The image processor is a required input.
- tokenizer ([`LlamaTokenizerFast`], *optional*):
- The tokenizer is a required input.
- chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
- in a chat into a tokenizable string.
- """
-
- visual_prompt_prefix: ClassVar[str] = "Describe the image."
- query_prefix: ClassVar[str] = "Question: "
-
- @property
- def query_augmentation_token(self) -> str:
- """
- Return the query augmentation token.
-
- Query augmentation buffers are used as reasoning buffers during inference.
- """
- return self.tokenizer.pad_token
-
- def __call__(
- self,
- images: ImageInput = None,
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
- audio=None,
- videos=None,
- **kwargs: Unpack[ColPaliProcessorKwargs],
- ) -> BatchFeature:
- """
- Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
- wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
- both text and images at the same time.
-
- When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
- [`~LlamaTokenizerFast.__call__`].
- When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
- [`~SiglipImageProcessor.__call__`].
- Please refer to the doctsring of the above two methods for more information.
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
- text (`str`, `List[str]`, `List[List[str]]`):
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- """
- output_kwargs = self._merge_kwargs(
- ColPaliProcessorKwargs,
- tokenizer_init_kwargs=self.tokenizer.init_kwargs,
- **kwargs,
- )
- suffix = output_kwargs["text_kwargs"].pop("suffix", None)
-
- return_token_type_ids = True if suffix is not None else False
-
- if text is None and images is None:
- raise ValueError("Either text or images must be provided")
- if text is not None and images is not None:
- raise ValueError("Only one of text or images can be processed at a time")
-
- if images is not None:
- if is_valid_image(images):
- images = [images]
- elif isinstance(images, list) and is_valid_image(images[0]):
- pass
- elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
- raise ValueError("images must be an image, list of images or list of list of images")
-
- texts_doc = [self.visual_prompt_prefix] * len(images)
- images = [image.convert("RGB") for image in images]
-
- input_strings = [
- build_string_from_input(
- prompt=prompt,
- bos_token=self.tokenizer.bos_token,
- image_seq_len=self.image_seq_length,
- image_token=IMAGE_TOKEN,
- num_images=len(image_list) if isinstance(image_list, list) else 1,
- )
- for prompt, image_list in zip(texts_doc, images)
- ]
- images = make_batched_images(images)
- pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
-
- # max_length has to account for the image tokens
- if output_kwargs["text_kwargs"].get("max_length", None) is not None:
- output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
-
- inputs = self.tokenizer(
- input_strings,
- return_token_type_ids=False,
- **output_kwargs["text_kwargs"],
- )
-
- return_data = {**inputs, "pixel_values": pixel_values}
-
- if return_token_type_ids:
- labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
- return_data.update({"labels": labels})
-
- return BatchFeature(data=return_data)
-
- elif text is not None:
- if isinstance(text, str):
- text = [text]
- elif not (isinstance(text, list) and isinstance(text[0], str)):
- raise ValueError("Text must be a string or a list of strings")
-
- if suffix is None:
- suffix = self.query_augmentation_token * 10
- texts_query: List[str] = []
-
- for query in text:
- query = self.tokenizer.bos_token + self.query_prefix + query
- query += suffix # add suffix (pad tokens)
- query += "\n" # make input ISO to PaliGemma's processor
- texts_query.append(query)
-
- output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
-
- batch_query = self.tokenizer(
- texts_query,
- return_token_type_ids=False,
- **output_kwargs["text_kwargs"],
- )
-
- return batch_query
-
- def process_images(
- self,
- images: ImageInput = None,
- **kwargs: Unpack[ColPaliProcessorKwargs],
- ) -> BatchFeature:
- """
- Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
- [`ColPaliProcessor.__call__`].
-
- This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- """
- return self.__call__(images=images, **kwargs)
-
- def process_queries(
- self,
- text: Union[TextInput, List[TextInput]],
- **kwargs: Unpack[ColPaliProcessorKwargs],
- ) -> BatchFeature:
- """
- Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
- [`ColPaliProcessor.__call__`].
-
- This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
-
- Args:
- text (`str`, `List[str]`, `List[List[str]]`):
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- """
- return self.__call__(text=text, **kwargs)
-
- def score_retrieval(
- self,
- query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
- passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
- batch_size: int = 128,
- output_dtype: Optional["torch.dtype"] = None,
- output_device: Union["torch.device", str] = "cpu",
- ) -> "torch.Tensor":
- """
- Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
- query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
- image of a document page.
-
- Because the embedding tensors are multi-vector and can thus have different shapes, they
- should be fed as:
- (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
- (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
- obtained by padding the list of tensors.
-
- Args:
- query_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
- passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
- batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
- output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
- If `None`, the dtype of the input embeddings is used.
- output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
-
- Returns:
- `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
- tensor is saved on the "cpu" device.
- """
-
- if len(query_embeddings) == 0:
- raise ValueError("No queries provided")
- if len(passage_embeddings) == 0:
- raise ValueError("No passages provided")
-
- if query_embeddings[0].device != passage_embeddings[0].device:
- raise ValueError("Queries and passages must be on the same device")
-
- if query_embeddings[0].dtype != passage_embeddings[0].dtype:
- raise ValueError("Queries and passages must have the same dtype")
-
- if output_dtype is None:
- output_dtype = query_embeddings[0].dtype
-
- scores: List[torch.Tensor] = []
-
- for i in range(0, len(query_embeddings), batch_size):
- batch_scores: List[torch.Tensor] = []
- batch_queries = torch.nn.utils.rnn.pad_sequence(
- query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
- )
- for j in range(0, len(passage_embeddings), batch_size):
- batch_passages = torch.nn.utils.rnn.pad_sequence(
- passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
- )
- batch_scores.append(
- torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
- )
- scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
-
- return torch.cat(scores, dim=0)
-
-
-__all__ = [
- "ColPaliProcessor",
-]
diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py
deleted file mode 100644
index f8d68675798bc4..00000000000000
--- a/src/transformers/models/colpali/processing_colpali.py
+++ /dev/null
@@ -1,443 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/colpali/modular_colpali.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_colpali.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import ClassVar, List, Optional, Union
-
-from ...feature_extraction_utils import BatchFeature
-from ...image_utils import ImageInput, is_valid_image
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
-from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
-from ...utils import is_torch_available
-
-
-if is_torch_available():
- import torch
-
-
-class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
- _defaults = {
- "text_kwargs": {
- "padding": "longest",
- },
- "images_kwargs": {
- "data_format": "channels_first",
- "do_convert_rgb": True,
- },
- "common_kwargs": {"return_tensors": "pt"},
- }
-
-
-IMAGE_TOKEN = ""
-EXTRA_TOKENS = [f"4}>" for i in range(1024)] + [f"3}>" for i in range(128)]
-
-
-def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
- """
- Builds a string from the input prompt and image tokens.
- For example, for the call:
- build_string_from_input(
- prompt="Prefix str"
- bos_token="",
- image_seq_len=3,
- image_token="",
- )
- The output will be:
- "Initial str"
- Args:
- prompt (`List[Union[str, ImageInput]]`): The input prompt.
- bos_token (`str`): The beginning of sentence token.
- image_seq_len (`int`): The length of the image sequence.
- image_token (`str`): The image token.
- num_images (`int`): Number of images in the prompt.
- """
- return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
-
-
-def make_batched_images(images) -> List[List[ImageInput]]:
- """
- Accepts images in list or nested list format, and makes a list of images for preprocessing.
-
- Args:
- images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
- The input image.
-
- Returns:
- list: A list of images.
- """
- if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
- return [img for img_list in images for img in img_list]
-
- elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
- return images
-
- elif is_valid_image(images):
- return [images]
-
- raise ValueError(f"Could not make batched video from {images}")
-
-
-class ColPaliProcessor(ProcessorMixin):
- r"""
- Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
- well as to compute the late-interaction retrieval score.
-
- [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
- for more information.
-
- Args:
- image_processor ([`SiglipImageProcessor`], *optional*):
- The image processor is a required input.
- tokenizer ([`LlamaTokenizerFast`], *optional*):
- The tokenizer is a required input.
- chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
- in a chat into a tokenizable string.
- """
-
- attributes = ["image_processor", "tokenizer"]
- valid_kwargs = ["chat_template"]
- image_processor_class = "SiglipImageProcessor"
- tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
-
- visual_prompt_prefix: ClassVar[str] = "Describe the image."
- query_prefix: ClassVar[str] = "Question: "
-
- def __init__(
- self,
- image_processor=None,
- tokenizer=None,
- chat_template=None,
- **kwargs,
- ):
- if image_processor is None:
- raise ValueError("You need to specify an `image_processor`.")
- if tokenizer is None:
- raise ValueError("You need to specify a `tokenizer`.")
- if not hasattr(image_processor, "image_seq_length"):
- raise ValueError("Image processor is missing an `image_seq_length` attribute.")
-
- self.image_seq_length = image_processor.image_seq_length
-
- if not hasattr(tokenizer, "image_token"):
- image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
- tokens_to_add = {"additional_special_tokens": [image_token]}
- tokenizer.add_special_tokens(tokens_to_add)
- self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
- else:
- self.image_token_id = tokenizer.image_token_id
-
- tokenizer.add_tokens(EXTRA_TOKENS)
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- super().__init__(image_processor, tokenizer, chat_template=chat_template)
-
- def __call__(
- self,
- images: ImageInput = None,
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
- audio=None,
- videos=None,
- **kwargs: Unpack[ColPaliProcessorKwargs],
- ) -> BatchFeature:
- """
- Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
- wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
- both text and images at the same time.
-
- When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
- [`~LlamaTokenizerFast.__call__`].
- When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
- [`~SiglipImageProcessor.__call__`].
- Please refer to the doctsring of the above two methods for more information.
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
- text (`str`, `List[str]`, `List[List[str]]`):
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- """
- output_kwargs = self._merge_kwargs(
- ColPaliProcessorKwargs,
- tokenizer_init_kwargs=self.tokenizer.init_kwargs,
- **kwargs,
- )
- suffix = output_kwargs["text_kwargs"].pop("suffix", None)
-
- return_token_type_ids = True if suffix is not None else False
-
- if text is None and images is None:
- raise ValueError("Either text or images must be provided")
- if text is not None and images is not None:
- raise ValueError("Only one of text or images can be processed at a time")
-
- if images is not None:
- if is_valid_image(images):
- images = [images]
- elif isinstance(images, list) and is_valid_image(images[0]):
- pass
- elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
- raise ValueError("images must be an image, list of images or list of list of images")
-
- texts_doc = [self.visual_prompt_prefix] * len(images)
- images = [image.convert("RGB") for image in images]
-
- input_strings = [
- build_string_from_input(
- prompt=prompt,
- bos_token=self.tokenizer.bos_token,
- image_seq_len=self.image_seq_length,
- image_token=IMAGE_TOKEN,
- num_images=len(image_list) if isinstance(image_list, list) else 1,
- )
- for prompt, image_list in zip(texts_doc, images)
- ]
- images = make_batched_images(images)
- pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
-
- # max_length has to account for the image tokens
- if output_kwargs["text_kwargs"].get("max_length", None) is not None:
- output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
-
- inputs = self.tokenizer(
- input_strings,
- return_token_type_ids=False,
- **output_kwargs["text_kwargs"],
- )
-
- return_data = {**inputs, "pixel_values": pixel_values}
-
- if return_token_type_ids:
- labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
- return_data.update({"labels": labels})
-
- return BatchFeature(data=return_data)
-
- elif text is not None:
- if isinstance(text, str):
- text = [text]
- elif not (isinstance(text, list) and isinstance(text[0], str)):
- raise ValueError("Text must be a string or a list of strings")
-
- if suffix is None:
- suffix = self.query_augmentation_token * 10
- texts_query: List[str] = []
-
- for query in text:
- query = self.tokenizer.bos_token + self.query_prefix + query
- query += suffix # add suffix (pad tokens)
- query += "\n" # make input ISO to PaliGemma's processor
- texts_query.append(query)
-
- output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
-
- batch_query = self.tokenizer(
- texts_query,
- return_token_type_ids=False,
- **output_kwargs["text_kwargs"],
- )
-
- return batch_query
-
- def batch_decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
- refer to the docstring of this method for more information.
- """
- return self.tokenizer.batch_decode(*args, **kwargs)
-
- def decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
- the docstring of this method for more information.
- """
- return self.tokenizer.decode(*args, **kwargs)
-
- @property
- def model_input_names(self):
- tokenizer_input_names = self.tokenizer.model_input_names
- image_processor_input_names = self.image_processor.model_input_names
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
- @property
- def query_augmentation_token(self) -> str:
- """
- Return the query augmentation token.
-
- Query augmentation buffers are used as reasoning buffers during inference.
- """
- return self.tokenizer.pad_token
-
- def process_images(
- self,
- images: ImageInput = None,
- **kwargs: Unpack[ColPaliProcessorKwargs],
- ) -> BatchFeature:
- """
- Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
- [`ColPaliProcessor.__call__`].
-
- This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
-
- Args:
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
- number of channels, H and W are image height and width.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- """
- return self.__call__(images=images, **kwargs)
-
- def process_queries(
- self,
- text: Union[TextInput, List[TextInput]],
- **kwargs: Unpack[ColPaliProcessorKwargs],
- ) -> BatchFeature:
- """
- Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
- [`ColPaliProcessor.__call__`].
-
- This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
-
- Args:
- text (`str`, `List[str]`, `List[List[str]]`):
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- """
- return self.__call__(text=text, **kwargs)
-
- def score_retrieval(
- self,
- query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
- passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
- batch_size: int = 128,
- output_dtype: Optional["torch.dtype"] = None,
- output_device: Union["torch.device", str] = "cpu",
- ) -> "torch.Tensor":
- """
- Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
- query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
- image of a document page.
-
- Because the embedding tensors are multi-vector and can thus have different shapes, they
- should be fed as:
- (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
- (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
- obtained by padding the list of tensors.
-
- Args:
- query_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
- passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
- batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
- output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
- If `None`, the dtype of the input embeddings is used.
- output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
-
- Returns:
- `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
- tensor is saved on the "cpu" device.
- """
-
- if len(query_embeddings) == 0:
- raise ValueError("No queries provided")
- if len(passage_embeddings) == 0:
- raise ValueError("No passages provided")
-
- if query_embeddings[0].device != passage_embeddings[0].device:
- raise ValueError("Queries and passages must be on the same device")
-
- if query_embeddings[0].dtype != passage_embeddings[0].dtype:
- raise ValueError("Queries and passages must have the same dtype")
-
- if output_dtype is None:
- output_dtype = query_embeddings[0].dtype
-
- scores: List[torch.Tensor] = []
-
- for i in range(0, len(query_embeddings), batch_size):
- batch_scores: List[torch.Tensor] = []
- batch_queries = torch.nn.utils.rnn.pad_sequence(
- query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
- )
- for j in range(0, len(passage_embeddings), batch_size):
- batch_passages = torch.nn.utils.rnn.pad_sequence(
- passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
- )
- batch_scores.append(
- torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
- )
- scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
-
- return torch.cat(scores, dim=0)
-
-
-__all__ = ["ColPaliProcessor"]
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index 801bd19fca3b60..590509eaf9057c 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -489,6 +489,7 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1421,8 +1422,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py
index 770162285bf33b..4d252ce1f19db7 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_vision.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -362,69 +362,6 @@ def forward(
return outputs
-# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision
-class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
- interpolate_pos_encoding: bool = False,
- resolution: Optional[Tuple[int]] = None,
- ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
- if output_attentions or head_mask is not None:
- logger.warning_once(
- "`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
- "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
- "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states=hidden_states,
- head_mask=head_mask,
- output_attentions=output_attentions,
- relative_position_bias=relative_position_bias,
- interpolate_pos_encoding=interpolate_pos_encoding,
- resolution=resolution,
- )
-
- mixed_query_layer = self.query(hidden_states)
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- query_layer = self.transpose_for_scores(mixed_query_layer)
-
- attn_bias = None
- if self.relative_position_bias is not None:
- height, width = resolution
- window_size = (height // self.config.patch_size, width // self.config.patch_size)
- attn_bias = self.relative_position_bias(
- window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
- )
-
- # Add shared relative position bias if provided.
- if relative_position_bias is not None:
- if attn_bias is None:
- attn_bias = relative_position_bias
- else:
- attn_bias += relative_position_bias
-
- scaling = 1 / math.sqrt(self.attention_head_size)
- context_layer = torch.nn.functional.scaled_dot_product_attention(
- query_layer,
- key_layer,
- value_layer,
- attn_mask=attn_bias,
- dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
- is_causal=False,
- scale=scaling,
- )
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- return context_layer, None
-
-
# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
class Data2VecVisionSelfOutput(nn.Module):
"""
@@ -444,19 +381,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma
return hidden_states
-DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
- "eager": Data2VecVisionSelfAttention,
- "sdpa": Data2VecVisionSdpaSelfAttention,
-}
-
-
-# Copied from tests.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION
+# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision
class Data2VecVisionAttention(nn.Module):
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
- self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
- config, window_size=window_size
- )
+ self.attention = Data2VecVisionSelfAttention(config, window_size=window_size)
self.output = Data2VecVisionSelfOutput(config)
self.pruned_heads = set()
@@ -782,7 +711,6 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["Data2VecVisionLayer"]
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
- _supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 0d2c4297e0d473..659fa154ecf776 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -46,6 +46,7 @@
_CONFIG_FOR_DOC = "DbrxConfig"
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx
class DbrxRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -317,6 +318,7 @@ class DbrxFlashAttention2(DbrxAttention):
calls the public API of flash attention.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -843,7 +845,7 @@ def _init_weights(self, module: nn.Module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
- module.weight.data.fill_(1.0)
+ module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, DbrxExpertGLU):
diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index c9a85bcad1bd6f..6993121b6c1ebe 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -290,6 +290,7 @@ def forward(
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
# bsz x height x length x dimension
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ attention_probs.masked_fill(attention_mask, 0)
attention_probs = self.dropout(attention_probs)
if self.head_weights_proj is not None:
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index 7d2f25603a6f96..6645c1de832e12 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -267,6 +267,7 @@ def forward(
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
# bsz x height x length x dimension
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ attention_probs.masked_fill(attention_mask, 0)
attention_probs = self.dropout(attention_probs)
context_layer = torch.bmm(
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 60fea55d87be5d..b8eb9f5a8b4222 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -17,7 +17,7 @@
import math
import os
from dataclasses import dataclass
-from typing import Callable, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -25,7 +25,7 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
ModelOutput,
@@ -100,49 +100,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
return model
-# Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward
-def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
- attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
- if module.scale_attn_weights:
- attn_weights = attn_weights / torch.full(
- [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
- )
-
- # Layer-wise attention scaling
- if module.scale_attn_by_inverse_layer_idx:
- attn_weights = attn_weights / float(module.layer_idx + 1)
-
- if not module.is_cross_attention:
- # if only "normal" attention layer implements causal mask
- query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
- mask_value = torch.finfo(attn_weights.dtype).min
- # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
- mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
- attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
- if attention_mask is not None:
- # Apply the attention mask
- attn_weights = attn_weights + attention_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
- attn_weights = attn_weights.type(value.dtype)
- attn_weights = module.attn_dropout(attn_weights)
-
- # Mask heads if we want to
- if head_mask is not None:
- attn_weights = attn_weights * head_mask
-
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2)
-
- return attn_output, attn_weights
-
-
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
@@ -204,6 +161,46 @@ def prune_heads(self, heads):
self.num_heads = self.num_heads - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if self.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if self.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(self.layer_idx + 1)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
bsz, num_heads, q_seq_len, dk = query.size()
@@ -253,10 +250,25 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
+ def _split_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
+
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
+ """
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -267,7 +279,6 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
- **kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@@ -276,65 +287,32 @@ def forward(
"Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
)
- query_states = self.q_attn(hidden_states)
- key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
- query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
- shape_q = (*query_states.shape[:-1], -1, self.head_dim)
- shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
- query_states = query_states.reshape(shape_q).transpose(1, 2)
- key_states = key_states.reshape(shape_kv).transpose(1, 2)
- value_states = value_states.reshape(shape_kv).transpose(1, 2)
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
- key_states = torch.cat((past_key, key_states), dim=-2)
- value_states = torch.cat((past_value, value_states), dim=-2)
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
- present = (key_states, value_states)
+ present = (key, value)
else:
present = None
- is_cross_attention = encoder_hidden_states is not None
- is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
-
- using_eager = self.config._attn_implementation == "eager"
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
- using_eager = True
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
- # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
- # not necessarily to eager (if mentionned options are provided).
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- if using_eager and self.reorder_and_upcast_attn:
- attn_output, attn_weights = self._upcast_and_reordered_attn(
- query_states, key_states, value_states, attention_mask, head_mask
- )
+ if self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- head_mask=head_mask,
- dropout=self.attn_dropout.p if self.training else 0.0,
- is_causal=is_causal,
- **kwargs,
- )
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
- attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
diff --git a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py
index 53dec63cfc4fd8..ca80636b23565d 100644
--- a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py
+++ b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py
@@ -222,7 +222,7 @@ def __init__(
"from a PyTorch pretrained vocabulary, "
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
)
- vocab_dict = torch.load(pretrained_vocab_file, weights_only=True)
+ vocab_dict = torch.load(pretrained_vocab_file)
if vocab_dict is not None:
for key, value in vocab_dict.items():
@@ -705,7 +705,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,
# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
- corpus_dict = torch.load(resolved_corpus_file, weights_only=True)
+ corpus_dict = torch.load(resolved_corpus_file)
for key, value in corpus_dict.items():
corpus.__dict__[key] = value
corpus.vocab = vocab
@@ -784,7 +784,7 @@ def get_lm_corpus(datadir, dataset):
fn_pickle = os.path.join(datadir, "cache.pkl")
if os.path.exists(fn):
logger.info("Loading cached dataset...")
- corpus = torch.load(fn_pickle, weights_only=True)
+ corpus = torch.load(fn_pickle)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index a826272956e503..36e35594b3d3c6 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -245,6 +245,7 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
API of flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index e0e4ff424cb47d..faea670ecbf428 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -38,6 +38,7 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -112,18 +113,40 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class FalconRotaryEmbedding(nn.Module):
def __init__(
self,
- config: FalconConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[FalconConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`FalconRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -174,6 +197,33 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
+class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
+ """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`FalconLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`FalconRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
+class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
+ """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`FalconDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`FalconRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
@@ -338,7 +388,7 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -352,7 +402,16 @@ def forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
if alibi is None:
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_layer, position_ids)
+ else:
+ cos, sin = position_embeddings
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
if layer_past is not None:
@@ -469,6 +528,7 @@ class FalconFlashAttention2(FalconAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -488,7 +548,7 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -502,7 +562,16 @@ def forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
if alibi is None:
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_layer, position_ids)
+ else:
+ cos, sin = position_embeddings
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
if layer_past is not None:
@@ -626,7 +695,7 @@ def forward(
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
):
residual = hidden_states
@@ -814,6 +883,14 @@ def _init_weights(self, module: nn.Module):
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
+ # NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
+ if hard_check_only:
+ if not is_torch_greater_or_equal_than_2_0:
+ raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")
+
+ if not is_torch_greater_or_equal_than_2_0:
+ return config
+
_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
diff --git a/src/transformers/models/gemma/__init__.py b/src/transformers/models/gemma/__init__.py
index 65fb1ca5edef43..1aafae6e88c2f1 100644
--- a/src/transformers/models/gemma/__init__.py
+++ b/src/transformers/models/gemma/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,111 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_sentencepiece_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_gemma": ["GemmaConfig"],
+}
+
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_gemma"] = ["GemmaTokenizer"]
+
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"]
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_gemma"] = [
+ "GemmaForCausalLM",
+ "GemmaModel",
+ "GemmaPreTrainedModel",
+ "GemmaForSequenceClassification",
+ "GemmaForTokenClassification",
+ ]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_gemma"] = [
+ "FlaxGemmaForCausalLM",
+ "FlaxGemmaModel",
+ "FlaxGemmaPreTrainedModel",
+ ]
if TYPE_CHECKING:
- from .configuration_gemma import *
- from .modeling_flax_gemma import *
- from .modeling_gemma import *
- from .tokenization_gemma import *
- from .tokenization_gemma_fast import *
+ from .configuration_gemma import GemmaConfig
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_gemma import GemmaTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_gemma_fast import GemmaTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_gemma import (
+ GemmaForCausalLM,
+ GemmaForSequenceClassification,
+ GemmaForTokenClassification,
+ GemmaModel,
+ GemmaPreTrainedModel,
+ )
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_gemma import (
+ FlaxGemmaForCausalLM,
+ FlaxGemmaModel,
+ FlaxGemmaPreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/gemma/modeling_flax_gemma.py b/src/transformers/models/gemma/modeling_flax_gemma.py
index dfe9739ba6555d..16291f3c3abe0a 100644
--- a/src/transformers/models/gemma/modeling_flax_gemma.py
+++ b/src/transformers/models/gemma/modeling_flax_gemma.py
@@ -772,6 +772,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
-
-
-__all__ = ["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"]
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index e2ea12b03fe434..52d02995016167 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -19,7 +19,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@@ -28,21 +29,19 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -75,72 +74,24 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
-class GemmaMLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
-
-
class GemmaRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: GemmaConfig,
- device=None,
- ):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
@@ -148,12 +99,60 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
+ """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
+ """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class GemmaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ if config.hidden_activation is None:
+ logger.warning_once(
+ "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
+ "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
+ "`config.hidden_activation` if you want to override this behaviour.\n"
+ "See https://github.com/huggingface/transformers/pull/29402 for more details."
+ )
+ config.hidden_activation = "gelu_pytorch_tanh"
+ hidden_activation = config.hidden_activation
+ self.act_fn = ACT2FN[hidden_activation]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def rotate_half(x):
@@ -202,75 +201,241 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: GemmaConfig, layer_idx: int):
+ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
+ self.scaling = 1 / math.sqrt(config.head_dim)
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+ self.rotary_emb = GemmaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
)
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GemmaSdpaAttention(GemmaAttention):
+ """
+ Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from GemmaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+class GemmaFlashAttention2(GemmaAttention):
+ """
+ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -278,39 +443,73 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (GemmaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
- attn_output, attn_weights = attention_interface(
- self,
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+GEMMA_ATTENTION_CLASSES = {
+ "eager": GemmaAttention,
+ "flash_attention_2": GemmaFlashAttention2,
+ "sdpa": GemmaSdpaAttention,
+}
class GemmaDecoderLayer(nn.Module):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
-
- self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
-
+ self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -324,15 +523,33 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -340,7 +557,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -352,9 +568,13 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
return outputs
@@ -500,8 +720,10 @@ def __init__(self, config: GemmaConfig):
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = GemmaRotaryEmbedding(config=config)
+
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -545,8 +767,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False # noqa: F841
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True # noqa: F841
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -564,9 +797,6 @@ def forward(
# embed positions
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
@@ -576,6 +806,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -591,7 +822,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -602,11 +832,13 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -616,13 +848,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -746,9 +983,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -796,7 +1030,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -846,7 +1080,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -855,7 +1088,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1062,10 +1295,4 @@ def forward(
)
-__all__ = [
- "GemmaModel",
- "GemmaForCausalLM",
- "GemmaForSequenceClassification",
- "GemmaForTokenClassification",
- "GemmaPreTrainedModel",
-]
+__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"]
diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index 29b6f8a1946173..ad1348ae5e3163 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
@@ -20,17 +21,23 @@
import torch.utils.checkpoint
from torch import nn
-from ...cache_utils import Cache, DynamicCache
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
from ...configuration_utils import PretrainedConfig
-from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaFlashAttention2,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
- LlamaMLP,
LlamaModel,
+ apply_rotary_pos_emb,
+ repeat_kv,
)
from ..llama.tokenization_llama import LlamaTokenizer
@@ -344,15 +351,468 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
-class GemmaMLP(LlamaMLP):
+ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
+
+
+class GemmaRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
+ """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
+ """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ if config.hidden_activation is None:
+ logger.warning_once(
+ "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
+ "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
+ "`config.hidden_activation` if you want to override this behaviour.\n"
+ "See https://github.com/huggingface/transformers/pull/29402 for more details."
+ )
+ config.hidden_activation = "gelu_pytorch_tanh"
+ hidden_activation = config.hidden_activation
+ self.act_fn = ACT2FN[hidden_activation]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class GemmaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.scaling = 1 / math.sqrt(config.head_dim)
+
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+ self.rotary_emb = GemmaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GemmaSdpaAttention(GemmaAttention):
+ """
+ Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from GemmaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
+ """
+ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (GemmaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+GEMMA_ATTENTION_CLASSES = {
+ "eager": GemmaAttention,
+ "flash_attention_2": GemmaFlashAttention2,
+ "sdpa": GemmaSdpaAttention,
+}
+
+
+class GemmaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__(config)
+ self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+ self.mlp = GemmaMLP(config)
+ self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
class GemmaModel(LlamaModel):
+ def __init__(self, config: GemmaConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet!
+ self.post_init()
+
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -385,8 +845,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False # noqa: F841
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True # noqa: F841
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -404,9 +875,6 @@ def forward(
# embed positions
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
@@ -416,6 +884,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -431,7 +900,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -442,11 +910,13 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -456,33 +926,44 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
+# Example where we ony modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM):
- def forward(**super_kwargs):
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
- Returns:
-
- Example:
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GemmaModel(config)
+ self.post_init()
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
@@ -497,15 +978,59 @@ def forward(**super_kwargs):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
- return super().forward(**super_kwargs)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
class GemmaForSequenceClassification(LlamaForSequenceClassification):
- pass
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GemmaModel(config)
+ self.post_init()
class GemmaForTokenClassification(LlamaForTokenClassification):
- pass
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GemmaModel(config)
+ self.post_init()
__all__ = [
@@ -515,5 +1040,4 @@ class GemmaForTokenClassification(LlamaForTokenClassification):
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
- "GemmaPreTrainedModel", # noqa: F822
]
diff --git a/src/transformers/models/gemma/tokenization_gemma_fast.py b/src/transformers/models/gemma/tokenization_gemma_fast.py
index 0e6f4a20b6d6d7..fd7a979e8b7509 100644
--- a/src/transformers/models/gemma/tokenization_gemma_fast.py
+++ b/src/transformers/models/gemma/tokenization_gemma_fast.py
@@ -197,6 +197,3 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
-
-
-__all__ = ["GemmaTokenizerFast"]
diff --git a/src/transformers/models/gemma2/__init__.py b/src/transformers/models/gemma2/__init__.py
index 18905bac42cc6b..ce59dfd8c7ac5a 100644
--- a/src/transformers/models/gemma2/__init__.py
+++ b/src/transformers/models/gemma2/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,49 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+_import_structure = {
+ "configuration_gemma2": ["Gemma2Config"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_gemma2"] = [
+ "Gemma2ForCausalLM",
+ "Gemma2Model",
+ "Gemma2PreTrainedModel",
+ "Gemma2ForSequenceClassification",
+ "Gemma2ForTokenClassification",
+ ]
+
if TYPE_CHECKING:
- from .configuration_gemma2 import *
- from .modeling_gemma2 import *
+ from .configuration_gemma2 import Gemma2Config
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_gemma2 import (
+ Gemma2ForCausalLM,
+ Gemma2ForSequenceClassification,
+ Gemma2ForTokenClassification,
+ Gemma2Model,
+ Gemma2PreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py
index dc2eba7893a058..eb562b3a6893bd 100644
--- a/src/transformers/models/gemma2/configuration_gemma2.py
+++ b/src/transformers/models/gemma2/configuration_gemma2.py
@@ -153,6 +153,3 @@ def __init__(
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation
-
-
-__all__ = ["Gemma2Config"]
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 67fc6c86a3bac6..58836a5631c2c0 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -27,26 +27,32 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...generation import GenerationMixin
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal,
+ is_torch_greater_or_equal,
logging,
replace_return_docstrings,
)
from .configuration_gemma2 import Gemma2Config
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+if is_torch_greater_or_equal("2.5"):
+ from torch.nn.attention.flex_attention import flex_attention
+
logger = logging.get_logger(__name__)
@@ -86,8 +92,35 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Gemma2RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
@@ -137,118 +170,266 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
def eager_attention_forward(
- module: nn.Module,
+ config: Gemma2Config,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- dropout: float = 0.0,
- scaling: Optional[float] = None,
- softcap: Optional[float] = None,
- **kwargs,
+ mask: Optional[torch.Tensor],
+ **_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
- if scaling is None:
- scaling = module.head_dim**-0.5
+ key_states = repeat_kv(key, config.num_key_value_groups)
+ value_states = repeat_kv(value, config.num_key_value_groups)
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-
- if softcap is not None:
- attn_weights = attn_weights / softcap
+ if config.attn_logit_softcapping is not None:
+ attn_weights = attn_weights / config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
- attn_weights = attn_weights * softcap
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights * config.attn_logit_softcapping
+ if mask is not None: # no matter the length, we just slice it
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
+def flash_attention_forward(
+ config: Gemma2Config,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ target_dtype: torch.dtype = torch.float16,
+ **_kwargs,
+) -> Tuple[torch.Tensor, None]:
+ if mask is not None:
+ seq_len = mask.shape[1]
+ query = query[:, :, :seq_len]
+ value = value[:, :, :seq_len]
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
+ query_states = query.transpose(1, 2)
+ key_states = key.transpose(1, 2)
+ value_states = value.transpose(1, 2)
+
+ dropout_rate = config.attention_dropout if config.training else 0.0
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ mask,
+ seq_len,
+ dropout=dropout_rate,
+ softmax_scale=config.scaling,
+ is_causal=config.is_causal,
+ sliding_window=config.sliding_window,
+ use_top_left_mask=config._flash_attn_uses_top_left_mask,
+ softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
+ )
+
+ return attn_output, None
+
+
+def flex_attention_forward(
+ config: Gemma2Config,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ output_attentions: bool = False,
+ **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ def tanh_softcap(score, b, h, q_idx, kv_idx):
+ soft_cap = config.attn_logit_softcapping
+ score = soft_cap * torch.tanh(score / soft_cap)
+ if mask is not None:
+ return score + mask[b][0][q_idx][kv_idx]
+ return score
+
+ attn_output = flex_attention(
+ query,
+ key,
+ value,
+ score_mod=tanh_softcap,
+ enable_gqa=True,
+ scale=config.scaling,
+ return_lse=output_attentions,
+ )
+ if not output_attentions:
+ attn_weights = None
+ else:
+ attn_output, attn_weights = attn_output
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+def sdpa_attention_forward(
+ config: Gemma2Config,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ **_kwargs,
+) -> Tuple[torch.Tensor, None]:
+ key = repeat_kv(key, config.num_key_value_groups)
+ value = repeat_kv(value, config.num_key_value_groups)
+
+ causal_mask = mask
+ if mask is not None:
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query.device.type == "cuda" and causal_mask is not None:
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and query.shape[1] > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=causal_mask,
+ dropout_p=config.attention_dropout if config.training else 0.0,
+ is_causal=is_causal,
+ scale=config.scaling,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, None
+
+
+GEMMA2_ATTENTION_FUNCTION = {
+ "flash_attention_2": flash_attention_forward,
+ "flex_attention": flex_attention_forward,
+ "eager": eager_attention_forward,
+ "sdpa": sdpa_attention_forward,
+}
+
+
class Gemma2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: Gemma2Config, layer_idx: int):
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = config.query_pre_attn_scalar**-0.5
- self.attention_dropout = self.config.attention_dropout
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
+ self.scaling = config.query_pre_attn_scalar**-0.5
+ self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
+ self.attn_logit_softcapping = config.attn_logit_softcapping
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+ self.rotary_emb = Gemma2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
)
- self.attn_logit_softcapping = self.config.attn_logit_softcapping
- self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "sliding_window": self.sliding_window,
+ "cache_position": cache_position,
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
+ logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
+ attention_type = "flex_attention"
+ else:
+ attention_type = self.config._attn_implementation
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=self.attention_dropout if self.training else 0.0,
- scaling=self.scaling,
- sliding_window=self.sliding_window,
- softcap=self.attn_logit_softcapping,
- **kwargs,
+ attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
+ self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Gemma2FlashAttention2(Gemma2Attention):
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.config._attn_implementation = "flash_attention_2"
+ logger.warning_once(
+ "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
+ "attribute of the `GemmaAttention` class! It will be removed in v4.48"
+ )
+
+
+class Gemma2SdpaAttention(Gemma2Attention):
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.config._attn_implementation = "sdpa"
+ logger.warning_once(
+ "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
+ "attribute of the `GemmaAttention` class! It will be removed in v4.48"
+ )
class Gemma2DecoderLayer(nn.Module):
@@ -269,7 +450,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
@@ -296,9 +476,8 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
- position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@@ -320,72 +499,10 @@ def forward(
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class Gemma2RotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: Gemma2Config,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
+ if use_cache:
+ outputs += (present_key_value,)
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
GEMMA2_START_DOCSTRING = r"""
@@ -418,7 +535,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
- _supports_quantized_cache = True
+ _supports_quantized_cache = False
_supports_static_cache = True
def _init_weights(self, module):
@@ -432,6 +549,20 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
+ """
+ Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
+ SDPA reduces the model performance on Gemma2 because of the logits softcapping.
+ """
+ config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
+
+ # if using the default path -> swap sdpa by eager
+ if not hard_check_only and config._attn_implementation == "sdpa":
+ config._attn_implementation = "eager"
+
+ return config
+
GEMMA2_INPUTS_DOCSTRING = r"""
Args:
@@ -530,8 +661,10 @@ def __init__(self, config: Gemma2Config):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = Gemma2RotaryEmbedding(config=config)
+
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -601,9 +734,6 @@ def forward(
# embed positions
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# normalized
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
@@ -622,7 +752,6 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- position_embeddings,
causal_mask,
position_ids,
past_key_values,
@@ -633,7 +762,6 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- position_embeddings=position_embeddings,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
@@ -652,13 +780,16 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = past_key_values if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
@torch.no_grad()
def _update_causal_mask(
@@ -1149,12 +1280,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "Gemma2ForCausalLM",
- "Gemma2Model",
- "Gemma2PreTrainedModel",
- "Gemma2ForSequenceClassification",
- "Gemma2ForTokenClassification",
-]
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 48b12411361aff..7236ae2f5c9f87 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -22,27 +22,36 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import (
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal,
+ is_torch_greater_or_equal,
+ logging,
+)
from ..gemma.modeling_gemma import (
- GemmaAttention,
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
- GemmaMLP,
GemmaModel,
+ GemmaPreTrainedModel,
GemmaRMSNorm,
+ GemmaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+if is_torch_greater_or_equal("2.5"):
+ from torch.nn.attention.flex_attention import flex_attention
+
+
_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
logger = logging.get_logger(__name__)
@@ -185,106 +194,286 @@ class Gemma2RMSNorm(GemmaRMSNorm):
pass
-class Gemma2MLP(GemmaMLP):
+class Gemma2MLP(nn.Module):
def __init__(self, config):
super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
+ pass
+
def eager_attention_forward(
- module: nn.Module,
+ config: Gemma2Config,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- dropout: float = 0.0,
- scaling: Optional[float] = None,
- softcap: Optional[float] = None,
- **kwargs,
+ mask: Optional[torch.Tensor],
+ **_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
- if scaling is None:
- scaling = module.head_dim**-0.5
+ key_states = repeat_kv(key, config.num_key_value_groups)
+ value_states = repeat_kv(value, config.num_key_value_groups)
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-
- if softcap is not None:
- attn_weights = attn_weights / softcap
+ if config.attn_logit_softcapping is not None:
+ attn_weights = attn_weights / config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
- attn_weights = attn_weights * softcap
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights * config.attn_logit_softcapping
+ if mask is not None: # no matter the length, we just slice it
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
-class Gemma2Attention(GemmaAttention):
- def __init__(self, config: Gemma2Config, layer_idx: int):
- super().__init__(config, layer_idx)
- self.attn_logit_softcapping = self.config.attn_logit_softcapping
- self.attention_dropout = self.config.attention_dropout
+def flash_attention_forward(
+ config: Gemma2Config,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ target_dtype: torch.dtype = torch.float16,
+ **_kwargs,
+) -> Tuple[torch.Tensor, None]:
+ if mask is not None:
+ seq_len = mask.shape[1]
+ query = query[:, :, :seq_len]
+ value = value[:, :, :seq_len]
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
+ query_states = query.transpose(1, 2)
+ key_states = key.transpose(1, 2)
+ value_states = value.transpose(1, 2)
+
+ dropout_rate = config.attention_dropout if config.training else 0.0
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ mask,
+ seq_len,
+ dropout=dropout_rate,
+ softmax_scale=config.scaling,
+ is_causal=config.is_causal,
+ sliding_window=config.sliding_window,
+ use_top_left_mask=config._flash_attn_uses_top_left_mask,
+ softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
+ )
+
+ return attn_output, None
+
+
+def flex_attention_forward(
+ config: Gemma2Config,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ output_attentions: bool = False,
+ **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ def tanh_softcap(score, b, h, q_idx, kv_idx):
+ soft_cap = config.attn_logit_softcapping
+ score = soft_cap * torch.tanh(score / soft_cap)
+ if mask is not None:
+ return score + mask[b][0][q_idx][kv_idx]
+ return score
+
+ attn_output = flex_attention(
+ query,
+ key,
+ value,
+ score_mod=tanh_softcap,
+ enable_gqa=True,
+ scale=config.scaling,
+ return_lse=output_attentions,
+ )
+ if not output_attentions:
+ attn_weights = None
+ else:
+ attn_output, attn_weights = attn_output
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+def sdpa_attention_forward(
+ config: Gemma2Config,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: Optional[torch.Tensor],
+ **_kwargs,
+) -> Tuple[torch.Tensor, None]:
+ key = repeat_kv(key, config.num_key_value_groups)
+ value = repeat_kv(value, config.num_key_value_groups)
+
+ causal_mask = mask
+ if mask is not None:
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query.device.type == "cuda" and causal_mask is not None:
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and query.shape[1] > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=causal_mask,
+ dropout_p=config.attention_dropout if config.training else 0.0,
+ is_causal=is_causal,
+ scale=config.scaling,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, None
+
+
+GEMMA2_ATTENTION_FUNCTION = {
+ "flash_attention_2": flash_attention_forward,
+ "flex_attention": flex_attention_forward,
+ "eager": eager_attention_forward,
+ "sdpa": sdpa_attention_forward,
+}
+
+
+class Gemma2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = config.query_pre_attn_scalar**-0.5
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
+ self.attn_logit_softcapping = config.attn_logit_softcapping
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+ self.rotary_emb = Gemma2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "sliding_window": self.sliding_window,
+ "cache_position": cache_position,
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
+ logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
+ attention_type = "flex_attention"
+ else:
+ attention_type = self.config._attn_implementation
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=self.attention_dropout if self.training else 0.0,
- scaling=self.scaling,
- sliding_window=self.sliding_window,
- softcap=self.attn_logit_softcapping,
- **kwargs,
+ attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type](
+ self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Gemma2FlashAttention2(Gemma2Attention):
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.config._attn_implementation = "flash_attention_2"
+ logger.warning_once(
+ "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
+ "attribute of the `GemmaAttention` class! It will be removed in v4.48"
+ )
+
+
+class Gemma2SdpaAttention(Gemma2Attention):
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.config._attn_implementation = "sdpa"
+ logger.warning_once(
+ "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`"
+ "attribute of the `GemmaAttention` class! It will be removed in v4.48"
+ )
class Gemma2DecoderLayer(nn.Module):
@@ -305,7 +494,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
@@ -332,9 +520,8 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
- position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@@ -356,15 +543,37 @@ def forward(
if output_attentions:
outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
return outputs
-class Gemma2Model(GemmaModel):
+class Gemma2PreTrainedModel(GemmaPreTrainedModel):
+ _supports_quantized_cache = False
+
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
+ """
+ Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
+ SDPA reduces the model performance on Gemma2 because of the logits softcapping.
+ """
+ config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
+
+ # if using the default path -> swap sdpa by eager
+ if not hard_check_only and config._attn_implementation == "sdpa":
+ config._attn_implementation = "eager"
+
+ return config
+
+
+class Gemma2Model(GemmaModel, Gemma2PreTrainedModel):
def __init__(self, config: Gemma2Config):
super().__init__(config)
self.layers = nn.ModuleList(
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self.post_init()
def forward(
self,
@@ -424,9 +633,6 @@ def forward(
# embed positions
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# normalized
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
@@ -445,7 +651,6 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- position_embeddings,
causal_mask,
position_ids,
past_key_values,
@@ -456,7 +661,6 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- position_embeddings=position_embeddings,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
@@ -475,13 +679,16 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = past_key_values if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
@torch.no_grad()
def _update_causal_mask(
@@ -696,13 +903,3 @@ def __init__(self, config):
super().__init__(config)
self.model = Gemma2Model(config)
self.post_init()
-
-
-__all__ = [
- "Gemma2Config",
- "Gemma2ForCausalLM",
- "Gemma2Model",
- "Gemma2PreTrainedModel", # noqa: F822
- "Gemma2ForSequenceClassification",
- "Gemma2ForTokenClassification",
-]
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 95ad0d9719951d..16a724f69464a9 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -19,7 +19,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -28,21 +29,20 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -55,6 +55,55 @@
_CONFIG_FOR_DOC = "GlmConfig"
+class GlmRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GlmRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class GlmRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
class GlmMLP(nn.Module):
def __init__(self, config):
super().__init__()
@@ -86,32 +135,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., 0::2]
@@ -168,38 +191,54 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.is_causal = True
+ self.scaling = 1 / math.sqrt(self.head_dim)
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -209,123 +248,247 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GlmFlashAttention2(GlmAttention):
+ """
+ Glm flash attention module. This module inherits from `GlmAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (GlmRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
- attn_output, attn_weights = attention_interface(
- self,
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ softmax_scale=self.scaling,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+ if not output_attentions:
+ attn_weights = None
-class GlmRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- GlmRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
+ return attn_output, attn_weights, past_key_value
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+class GlmSdpaAttention(GlmAttention):
+ """
+ Glm attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `GlmAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
-class GlmRotaryEmbedding(nn.Module):
- def __init__(
+ # Adapted from GlmAttention.forward
+ def forward(
self,
- config: GlmConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "GlmModel is using GlmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ bsz, q_len, _ = hidden_states.size()
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ scale=self.scaling,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+GLM_ATTENTION_CLASSES = {
+ "eager": GlmAttention,
+ "flash_attention_2": GlmFlashAttention2,
+ "sdpa": GlmSdpaAttention,
+}
class GlmDecoderLayer(nn.Module):
- def __init__(self, config: GlmConfig, layer_idx: int):
+ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = GlmAttention(config=config, layer_idx=layer_idx)
+ self.self_attn = GLM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = GlmMLP(config)
self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -340,15 +503,37 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -368,9 +553,13 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
return outputs
@@ -516,8 +705,14 @@ def __init__(self, config: GlmConfig):
[GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = GlmRotaryEmbedding(config=config)
+ self.rotary_emb = GlmRotaryEmbedding(
+ dim=int(config.head_dim * config.partial_rotary_factor),
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ )
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -534,7 +729,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -562,22 +757,31 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
-
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -586,6 +790,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -618,6 +823,9 @@ def forward(
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -627,13 +835,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -757,14 +970,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
- def __init__(self, config):
+ def __init__(self, config: GlmConfig):
super().__init__(config)
self.model = GlmModel(config)
self.vocab_size = config.vocab_size
@@ -807,7 +1017,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -828,16 +1038,16 @@ def forward(
```python
>>> from transformers import AutoTokenizer, GlmForCausalLM
- >>> model = GlmForCausalLM.from_pretrained("meta-glm/Glm-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm/Glm-2-7b-hf")
+ >>> model = GlmForCausalLM.from_pretrained("google/glm-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/glm-7b")
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ "What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -857,7 +1067,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -866,7 +1075,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -897,7 +1106,7 @@ def forward(
GLM_START_DOCSTRING,
)
class GlmForSequenceClassification(GlmPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: GlmConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GlmModel(config)
@@ -993,7 +1202,7 @@ def forward(
GLM_START_DOCSTRING,
)
class GlmForTokenClassification(GlmPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: GlmConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.model = GlmModel(config)
diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py
index ec07be10fb6a55..48605c15d30be3 100644
--- a/src/transformers/models/glm/modular_glm.py
+++ b/src/transformers/models/glm/modular_glm.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
from typing import Optional
import torch
@@ -20,13 +21,26 @@
import torch.utils.checkpoint
from ...utils import logging
+from ..gemma.modeling_gemma import (
+ GemmaForCausalLM,
+ GemmaForSequenceClassification,
+ GemmaForTokenClassification,
+)
+from ..granite.modeling_granite import (
+ GraniteAttention,
+ GraniteFlashAttention2,
+ GraniteSdpaAttention,
+)
from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaForCausalLM,
- LlamaForSequenceClassification,
- LlamaForTokenClassification,
+ LlamaDecoderLayer,
+ LlamaModel,
+ LlamaPreTrainedModel,
+)
+from ..phi3.modeling_phi3 import (
+ Phi3MLP,
+ Phi3RMSNorm,
+ Phi3RotaryEmbedding,
)
-from ..phi3.modeling_phi3 import Phi3MLP
from .configuration_glm import GlmConfig
@@ -35,6 +49,14 @@
_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
+class GlmRMSNorm(Phi3RMSNorm):
+ pass
+
+
+class GlmRotaryEmbedding(Phi3RotaryEmbedding):
+ pass
+
+
class GlmMLP(Phi3MLP):
pass
@@ -88,27 +110,83 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
-class GlmAttention(LlamaAttention):
+class GlmAttention(GraniteAttention):
def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+ self.scaling = 1 / math.sqrt(self.head_dim)
-class GlmForCausalLM(LlamaForCausalLM):
+class GlmFlashAttention2(GlmAttention, GraniteFlashAttention2):
pass
-class GlmForSequenceClassification(LlamaForSequenceClassification):
+class GlmSdpaAttention(GraniteSdpaAttention):
pass
-class GlmForTokenClassification(LlamaForTokenClassification):
+GLM_ATTENTION_CLASSES = {
+ "eager": GlmAttention,
+ "flash_attention_2": GlmFlashAttention2,
+ "sdpa": GlmSdpaAttention,
+}
+
+
+class GlmDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+
+ self.mlp = GlmMLP(config)
+ self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+
+class GlmPreTrainedModel(LlamaPreTrainedModel):
pass
+class GlmModel(GlmPreTrainedModel, LlamaModel):
+ def __init__(self, config: GlmConfig):
+ super().__init__(config)
+ self.layers = nn.ModuleList(
+ [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = GlmRotaryEmbedding(
+ dim=int(config.head_dim * config.partial_rotary_factor),
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ )
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+
+class GlmForCausalLM(GemmaForCausalLM):
+ def __init__(self, config: GlmConfig):
+ super().__init__(config)
+ self.model = GlmModel(config)
+ self.post_init()
+
+
+class GlmForSequenceClassification(GemmaForSequenceClassification):
+ def __init__(self, config: GlmConfig):
+ super().__init__(config)
+ self.model = GlmModel(config)
+ self.post_init()
+
+
+class GlmForTokenClassification(GemmaForTokenClassification):
+ def __init__(self, config: GlmConfig):
+ super().__init__(config)
+ self.model = GlmModel(config)
+ self.post_init()
+
+
__all__ = [
- "GlmPreTrainedModel", # noqa: F822
- "GlmModel", # noqa: F822
+ "GlmPreTrainedModel",
+ "GlmModel",
"GlmForCausalLM",
"GlmForSequenceClassification",
"GlmForTokenClassification",
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index ad53c7804ebeea..58143192c20482 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -19,10 +19,11 @@
import os
import warnings
from dataclasses import dataclass
-from typing import Callable, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
+from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -36,13 +37,16 @@
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary
+from ...modeling_utils import PreTrainedModel, SequenceSummary
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ get_torch_version,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -50,6 +54,10 @@
from .configuration_gpt2 import GPT2Config
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
@@ -112,48 +120,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
return model
-def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
- attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
- if module.scale_attn_weights:
- attn_weights = attn_weights / torch.full(
- [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
- )
-
- # Layer-wise attention scaling
- if module.scale_attn_by_inverse_layer_idx:
- attn_weights = attn_weights / float(module.layer_idx + 1)
-
- if not module.is_cross_attention:
- # if only "normal" attention layer implements causal mask
- query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
- mask_value = torch.finfo(attn_weights.dtype).min
- # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
- mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
- attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
- if attention_mask is not None:
- # Apply the attention mask
- attn_weights = attn_weights + attention_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
- attn_weights = attn_weights.type(value.dtype)
- attn_weights = module.attn_dropout(attn_weights)
-
- # Mask heads if we want to
- if head_mask is not None:
- attn_weights = attn_weights * head_mask
-
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2)
-
- return attn_output, attn_weights
-
-
class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
@@ -214,6 +180,46 @@ def prune_heads(self, heads):
self.num_heads = self.num_heads - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if self.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if self.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(self.layer_idx + 1)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
bsz, num_heads, q_seq_len, dk = query.size()
@@ -263,10 +269,25 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
+ def _split_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
+
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
+ """
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -277,7 +298,6 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
- **kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@@ -286,65 +306,32 @@ def forward(
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)
- query_states = self.q_attn(hidden_states)
- key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
- query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
- shape_q = (*query_states.shape[:-1], -1, self.head_dim)
- shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
- query_states = query_states.reshape(shape_q).transpose(1, 2)
- key_states = key_states.reshape(shape_kv).transpose(1, 2)
- value_states = value_states.reshape(shape_kv).transpose(1, 2)
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
- key_states = torch.cat((past_key, key_states), dim=-2)
- value_states = torch.cat((past_value, value_states), dim=-2)
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
- present = (key_states, value_states)
+ present = (key, value)
else:
present = None
- is_cross_attention = encoder_hidden_states is not None
- is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
-
- using_eager = self.config._attn_implementation == "eager"
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
- using_eager = True
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
- # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
- # not necessarily to eager (if mentionned options are provided).
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- if using_eager and self.reorder_and_upcast_attn:
- attn_output, attn_weights = self._upcast_and_reordered_attn(
- query_states, key_states, value_states, attention_mask, head_mask
- )
+ if self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- head_mask=head_mask,
- dropout=self.attn_dropout.p if self.training else 0.0,
- is_causal=is_causal,
- **kwargs,
- )
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
- attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -355,6 +342,226 @@ def forward(
return outputs # a, present, (attentions)
+class GPT2FlashAttention2(GPT2Attention):
+ """
+ GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ bsz, _, _ = hidden_states.size()
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+
+ present = None
+ if use_cache is True:
+ present = (key, value)
+
+ query_length = query.shape[2]
+ tgt_len = key.shape[2]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+
+ attn_dropout = self.attn_dropout.p if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ if query.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.c_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ dropout=attn_dropout,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+ attn_output = self.c_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights_reshaped,)
+
+ return outputs
+
+
+class GPT2SdpaAttention(GPT2Attention):
+ """
+ GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
+ to adapt to the SDPA API.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+ # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ if output_attentions or head_mask is not None:
+ logger.warning_once(
+ "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # Initial attention projections
+ is_cross_attention = encoder_hidden_states is not None
+ if is_cross_attention:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ # Optional kv caching
+ if layer_past is not None:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+
+ present = None
+ if use_cache is True:
+ present = (key, value)
+
+ # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
+ if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ # Reshape outputs
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.embed_dim)
+
+ # Final projection
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ return attn_output, present, None
+
+
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -372,18 +579,22 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states
+GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
+
+
class GPT2Block(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+ attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
+ self.attn = attention_class(config=config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
- self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
+ self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 403159cdf39c9a..5326c7b907d4b1 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -278,6 +278,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
API of flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index ef23b5d208fd79..28bfbabc1fd8e0 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -36,6 +36,7 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -55,6 +56,9 @@
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
+ if not is_torch_greater_or_equal_than_1_13:
+ import torch.fx
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
@@ -274,6 +278,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 7152d72f5b7fc8..3fdb814ebab51a 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -311,7 +311,7 @@ def forward(
output_attentions: Optional[bool] = False,
padding_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, seq_len, _ = hidden_states.shape
@@ -404,7 +404,7 @@ def _attn_projections_and_rope(
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
@@ -427,7 +427,16 @@ def _attn_projections_and_rope(
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value, position_ids)
+ else:
+ cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
@@ -490,18 +499,40 @@ def __init__(self, config, layer_idx=None):
class GPTNeoXRotaryEmbedding(nn.Module):
def __init__(
self,
- config: GPTNeoXConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[GPTNeoXConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`GPTNeoXRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -552,6 +583,33 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
+class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
+ """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`GPTNeoXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`GPTNeoXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
+class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
+ """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`GPTNeoXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`GPTNeoXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -630,7 +688,7 @@ def forward(
layer_past: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index 71602f01e7d6f8..6c3f3313f57faf 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -105,7 +105,7 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
@@ -128,7 +128,16 @@ def forward(
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value, position_ids)
+ else:
+ cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
@@ -227,18 +236,40 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
def __init__(
self,
- config: GPTNeoXJapaneseConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[GPTNeoXJapaneseConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`GPTNeoXJapaneseRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -384,7 +415,7 @@ def forward(
layer_past: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
residual = hidden_states
ln_out = self.input_layernorm(hidden_states)
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 4af8f73b5f5eea..1cc9cf369d1887 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -266,6 +266,7 @@ class GPTJFlashAttention2(GPTJAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 2e045e149d95de..9cabd48a51021f 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -1,9 +1,3 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/granite/modular_granite.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_granite.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
@@ -19,24 +13,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import torch
+import torch.utils.checkpoint
from torch import nn
+from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -44,9 +43,96 @@
logger = logging.get_logger(__name__)
+
_CONFIG_FOR_DOC = "GraniteConfig"
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Granite
+class GraniteRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GraniteRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(GraniteRMSNorm)
+
+
+class GraniteRotaryEmbedding(nn.Module):
+ def __init__(self, config: GraniteConfig):
+ super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half with Llama->Granite
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -54,6 +140,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->Granite
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -81,6 +168,24 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
+class GraniteMLP(nn.Module):
+ # Copied from transformers.models.llama.modeling_llama.LlamaMLP.__init__ with Llama->Granite
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaMLP.forward with Gemma->Granite
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv with Llama->Granite
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -93,32 +198,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class GraniteAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -126,40 +205,55 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = config.attention_multiplier
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.is_causal = True
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
+ self.scaling = config.attention_multiplier
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -169,77 +263,252 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GraniteFlashAttention2(GraniteAttention):
+ """
+ Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (GraniteRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
- attn_output, attn_weights = attention_interface(
- self,
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ softmax_scale=self.scaling,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+ if not output_attentions:
+ attn_weights = None
-class GraniteRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- GraniteRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
+ return attn_output, attn_weights, past_key_value
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+class GraniteSdpaAttention(GraniteAttention):
+ """
+ Granite attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `GraniteAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from GraniteAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "GraniteModel is using GraniteSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+ bsz, q_len, _ = hidden_states.size()
-class GraniteMLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
- self.act_fn = ACT2FN[config.hidden_act]
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ scale=self.scaling,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+GRANITE_ATTENTION_CLASSES = {
+ "eager": GraniteAttention,
+ "flash_attention_2": GraniteFlashAttention2,
+ "sdpa": GraniteSdpaAttention,
+}
class GraniteDecoderLayer(nn.Module):
def __init__(self, config: GraniteConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
+
+ self.self_attn = GRANITE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = GraniteMLP(config)
self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
self.residual_multiplier = config.residual_multiplier
def forward(
@@ -251,7 +520,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -281,7 +550,7 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -298,79 +567,17 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama
+ hidden_states = residual + hidden_states * self.residual_multiplier
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class GraniteRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: GraniteConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
+ if use_cache:
+ outputs += (present_key_value,)
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
GRANITE_START_DOCSTRING = r"""
@@ -394,6 +601,7 @@ def forward(self, x, position_ids):
"The bare Granite Model outputting raw hidden-states without any specific head on top.",
GRANITE_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Granite
class GranitePreTrainedModel(PreTrainedModel):
config_class = GraniteConfig
base_model_prefix = "model"
@@ -515,9 +723,17 @@ def __init__(self, config: GraniteConfig):
[GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = GraniteRotaryEmbedding(config=config)
self.gradient_checkpointing = False
+
self.embedding_multiplier = config.embedding_multiplier
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ # rope
+ self.rotary_emb = GraniteRotaryEmbedding(config)
# Initialize weights and apply final processing
self.post_init()
@@ -534,14 +750,13 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -562,17 +777,27 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
+ inputs_embeds = inputs_embeds * self.embedding_multiplier
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -580,6 +805,7 @@ def forward(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
+ # embed positions
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -588,8 +814,9 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -615,11 +842,13 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -629,13 +858,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -645,6 +879,11 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
@@ -667,6 +906,7 @@ def _update_causal_mask(
return None
dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@@ -677,17 +917,24 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
-
+ if attention_mask is not None and attention_mask.dim() == 4:
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@@ -697,12 +944,12 @@ def _update_causal_mask(
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
@@ -759,13 +1006,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite
def __init__(self, config):
super().__init__(config)
self.model = GraniteModel(config)
@@ -808,8 +1052,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -818,11 +1060,6 @@ def forward(
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
Returns:
Example:
@@ -830,8 +1067,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, GraniteForCausalLM
- >>> model = GraniteForCausalLM.from_pretrained("meta-granite/Granite-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite/Granite-2-7b-hf")
+ >>> model = GraniteForCausalLM.from_pretrained("ibm/PowerLM-3b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -859,17 +1096,26 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
- logits = logits / self.config.logits_scaling # main diff with Llama
+ logits = self.lm_head(hidden_states)
+ logits = logits / self.config.logits_scaling
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -882,3 +1128,12 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py
deleted file mode 100644
index 698280085f1852..00000000000000
--- a/src/transformers/models/granite/modular_granite.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# coding=utf-8
-# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from ...cache_utils import Cache, DynamicCache
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...processing_utils import Unpack
-from ...utils import LossKwargs, logging
-from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
-from .configuration_granite import GraniteConfig
-
-
-logger = logging.get_logger(__name__)
-
-
-class GraniteAttention(LlamaAttention):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
- super().__init__(config, layer_idx)
- self.scaling = config.attention_multiplier
-
-
-class GraniteDecoderLayer(LlamaDecoderLayer):
- def __init__(self, config: GraniteConfig, layer_idx: int):
- super().__init__(config, layer_idx)
- self.residual_multiplier = config.residual_multiplier
- self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
- with `head_dim` being the embedding dimension of each attention head.
- kwargs (`dict`, *optional*):
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
- into the model
- """
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- hidden_states = residual + hidden_states * self.residual_multiplier
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
-
-class GraniteModel(LlamaModel):
- def __init__(self, config: GraniteConfig):
- super().__init__(config)
- self.embedding_multiplier = config.embedding_multiplier
- self.layers = nn.ModuleList(
- [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
-
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
-
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
-
- hidden_states = inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- output = BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
- return output if return_dict else output.to_tuple()
-
-
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
-class GraniteForCausalLM(LlamaForCausalLM):
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- **kwargs,
- )
-
- hidden_states = outputs[0]
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
- logits = logits / self.config.logits_scaling # main diff with Llama
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index 1c4c06bbc8d71e..4871fc3584faee 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -158,15 +158,11 @@ def extra_repr(self):
# Copied from transformers.models.granite.modeling_granite.GraniteRotaryEmbedding with Granite->GraniteMoe
class GraniteMoeRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: GraniteMoeConfig,
- device=None,
- ):
+ def __init__(self, config: GraniteMoeConfig):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
@@ -176,7 +172,7 @@ def __init__(
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@@ -417,8 +413,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe
-# no longer copied after attention refactors
+# Copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe
class GraniteMoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -463,7 +458,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -515,8 +510,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
-# TODO cyril: modular
+# Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
class GraniteMoeFlashAttention2(GraniteMoeAttention):
"""
GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stays
@@ -541,7 +535,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
@@ -623,8 +617,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe
-# TODO cyril: modular
+# Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe
class GraniteMoeSdpaAttention(GraniteMoeAttention):
"""
GraniteMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -642,7 +635,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
@@ -746,7 +739,7 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py
index 9f488b19888957..20977cff87d167 100644
--- a/src/transformers/models/hubert/configuration_hubert.py
+++ b/src/transformers/models/hubert/configuration_hubert.py
@@ -94,8 +94,6 @@ class HubertConfig(PretrainedConfig):
embeddings layer.
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
Number of groups of 1D convolutional positional embeddings layer.
- conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
- Whether to use batch norm instead of weight norm in conv_pos
do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
Whether do apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
@@ -184,7 +182,6 @@ def __init__(
conv_bias=False,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
- conv_pos_batch_norm=False,
do_stable_layer_norm=False,
apply_spec_augment=True,
mask_time_prob=0.05,
@@ -212,7 +209,6 @@ def __init__(
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
- self.conv_pos_batch_norm = conv_pos_batch_norm
self.num_feat_extract_layers = len(self.conv_dim)
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
diff --git a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
index 4966340493f35c..6478fdadf13de3 100644
--- a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
@@ -38,8 +38,7 @@
MAPPING = {
"post_extract_proj": "feature_projection.projection",
- "encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
- "encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
+ "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
@@ -77,12 +76,6 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
- elif weight_type == "running_mean":
- hf_pointer.running_mean.data = value
- elif weight_type == "running_var":
- hf_pointer.running_var.data = value
- elif weight_type == "num_batches_tracked":
- hf_pointer.num_batches_tracked.data = value
else:
hf_pointer.data = value
@@ -123,12 +116,6 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
- elif "running_mean" in name:
- weight_type = "running_mean"
- elif "running_var" in name:
- weight_type = "running_var"
- elif "num_batches_tracked" in name:
- weight_type = "num_batches_tracked"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index f2700836789ebd..57f59cf9aab94f 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -260,6 +260,7 @@ def forward(self, hidden_states):
return hidden_states
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
class HubertPositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
@@ -271,37 +272,32 @@ def __init__(self, config):
groups=config.num_conv_pos_embedding_groups,
)
- self.batch_norm = None
- if config.conv_pos_batch_norm:
- self.batch_norm = nn.BatchNorm1d(config.hidden_size)
- else:
- weight_norm = nn.utils.weight_norm
- if hasattr(nn.utils.parametrizations, "weight_norm"):
- weight_norm = nn.utils.parametrizations.weight_norm
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
- if is_deepspeed_zero3_enabled():
- import deepspeed
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
- with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
- self.conv = weight_norm(self.conv, name="weight", dim=2)
- if hasattr(self.conv, "parametrizations"):
- weight_g = self.conv.parametrizations.weight.original0
- weight_v = self.conv.parametrizations.weight.original1
- else:
- weight_g = self.conv.weight_g
- weight_v = self.conv.weight_v
- deepspeed.zero.register_external_parameter(self, weight_v)
- deepspeed.zero.register_external_parameter(self, weight_g)
- else:
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
- if self.batch_norm is not None:
- hidden_states = self.batch_norm(hidden_states)
+
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
@@ -563,6 +559,7 @@ class HubertFlashAttention2(HubertAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1629,8 +1626,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index b2ffbcbc695696..8bd24728b03885 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -444,6 +444,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 6d7295b5120d29..3d46c3bd82e788 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -272,6 +272,7 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -858,8 +859,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
-# TODO cyril: modular
+# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
"""
Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
@@ -867,6 +867,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/idefics3/configuration_idefics3.py b/src/transformers/models/idefics3/configuration_idefics3.py
index 0d385b0ee48dec..4b10d8d2d03a81 100644
--- a/src/transformers/models/idefics3/configuration_idefics3.py
+++ b/src/transformers/models/idefics3/configuration_idefics3.py
@@ -54,8 +54,7 @@ class Idefics3VisionConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_range (``, *optional*, defaults to 0.02):
Example:
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index 3a52b8b6d54d0e..31d43948fbd565 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -273,6 +273,7 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py
index 3b3756dd5ce697..efbd71d91342fd 100644
--- a/src/transformers/models/ijepa/modular_ijepa.py
+++ b/src/transformers/models/ijepa/modular_ijepa.py
@@ -155,7 +155,7 @@ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mas
self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
-_IMAGE_CLASS_CHECKPOINT = "facebook/ijepa_vith14_1k"
+_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index ae7470d789b27e..a185d5ebc6e86c 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -384,6 +384,7 @@ class JambaFlashAttention2(JambaAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -834,7 +835,6 @@ def forward(
class JambaMLP(nn.Module):
def __init__(self, config):
super().__init__()
- self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -842,9 +842,8 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 7b7fd5a90d69ed..a4bb1d78fdc5ce 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -32,7 +32,6 @@
MoeModelOutputWithPast,
SequenceClassifierOutputWithPast,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
@@ -386,55 +385,24 @@ def extra_repr(self):
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe
class JetMoeRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: JetMoeConfig,
- device=None,
- ):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
@@ -442,11 +410,6 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -523,7 +486,11 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None):
self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)
- self.rotary_emb = JetMoeRotaryEmbedding(config)
+ self.rotary_emb = JetMoeRotaryEmbedding(
+ config.kv_channels,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ )
def forward(
self,
@@ -674,6 +641,7 @@ def forward(
class JetMoeFlashAttention2(JetMoeAttention):
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5be33c26414cd7..0408bb73c7f2da 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -17,7 +17,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -27,7 +28,7 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -36,7 +37,7 @@
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
@@ -44,6 +45,7 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -82,18 +84,40 @@ def extra_repr(self):
class LlamaRotaryEmbedding(nn.Module):
def __init__(
self,
- config: LlamaConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[LlamaConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -144,6 +168,31 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -206,75 +255,167 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: LlamaConfig, layer_idx: int):
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -282,30 +423,168 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
- attn_output, attn_weights = attention_interface(
- self,
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
**kwargs,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaSdpaAttention(LlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ return attn_output, None, past_key_value
+
+
+LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ "sdpa": LlamaSdpaAttention,
+}
class LlamaDecoderLayer(nn.Module):
@@ -313,7 +592,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -328,15 +607,37 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -356,9 +657,13 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
return outputs
@@ -505,7 +810,10 @@ def __init__(self, config: LlamaConfig):
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
+
self.gradient_checkpointing = False
+ if getattr(config, "pretraining_tp", 1) != 1:
+ logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -522,7 +830,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -550,22 +858,31 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
-
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -574,6 +891,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -606,6 +924,9 @@ def forward(
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -615,13 +936,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
diff --git a/src/transformers/models/llava_next_video/__init__.py b/src/transformers/models/llava_next_video/__init__.py
index e3632c7a2a1427..d079643e73e99d 100644
--- a/src/transformers/models/llava_next_video/__init__.py
+++ b/src/transformers/models/llava_next_video/__init__.py
@@ -13,17 +13,58 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+_import_structure = {
+ "configuration_llava_next_video": ["LlavaNextVideoConfig"],
+ "processing_llava_next_video": ["LlavaNextVideoProcessor"],
+}
+
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["image_processing_llava_next_video"] = ["LlavaNextVideoImageProcessor"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_llava_next_video"] = [
+ "LlavaNextVideoForConditionalGeneration",
+ "LlavaNextVideoPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
- from .configuration_llava_next_video import *
- from .image_processing_llava_next_video import *
- from .modeling_llava_next_video import *
- from .processing_llava_next_video import *
+ from .configuration_llava_next_video import LlavaNextVideoConfig
+ from .processing_llava_next_video import LlavaNextVideoProcessor
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .image_processing_llava_next_video import LlavaNextVideoImageProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_llava_next_video import (
+ LlavaNextVideoForConditionalGeneration,
+ LlavaNextVideoPreTrainedModel,
+ )
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py
index e608e5a0d20ece..2fe889da60336b 100644
--- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py
@@ -158,6 +158,3 @@ def __init__(
self.text_config = text_config
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
-
-
-__all__ = ["LlavaNextVideoConfig"]
diff --git a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py
index f30e2c54fe90a3..59d0d9d9447252 100644
--- a/src/transformers/models/llava_next_video/image_processing_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/image_processing_llava_next_video.py
@@ -414,6 +414,3 @@ def preprocess(
data = {"pixel_values_videos": pixel_values}
return BatchFeature(data=data, tensor_type=return_tensors)
-
-
-__all__ = ["LlavaNextVideoImageProcessor"]
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 7cd7e18abaf3e0..b0a20d6c5ccd93 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -122,6 +122,21 @@ def forward(self, image_features):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
+class LlavaNextVideoMultiModalProjector(nn.Module):
+ def __init__(self, config: LlavaNextVideoConfig):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -176,21 +191,6 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()
-class LlavaNextVideoMultiModalProjector(nn.Module):
- def __init__(self, config: LlavaNextVideoConfig):
- super().__init__()
-
- self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
- self.act = ACT2FN[config.projector_hidden_act]
- self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
-
- def forward(self, image_features):
- hidden_states = self.linear_1(image_features)
- hidden_states = self.act(hidden_states)
- hidden_states = self.linear_2(hidden_states)
- return hidden_states
-
-
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -1157,6 +1157,3 @@ def get_video_features(
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
-
-
-__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index 94c1432a41b1f1..3d6431d7ea29ba 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -24,7 +24,6 @@
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
- LlavaNextPreTrainedModel,
image_size_to_num_patches,
)
@@ -219,10 +218,6 @@ def forward(self, image_features):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
-class LlavaNextVideoPreTrainedModel(LlavaNextPreTrainedModel):
- pass
-
-
class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
@@ -646,6 +641,3 @@ def prepare_inputs_for_generation(
model_inputs["image_sizes"] = image_sizes
return model_inputs
-
-
-__all__ = ["LlavaNextVideoConfig", "LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py
index 857ee28a080041..65195b77240721 100644
--- a/src/transformers/models/llava_next_video/processing_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -291,6 +291,3 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
-
-
-__all__ = ["LlavaNextVideoProcessor"]
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 4e116e7e3db585..cc35a3504255bf 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -348,6 +348,7 @@ class M2M100FlashAttention2(M2M100Attention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index 550eeb7f9665e4..c312b9b94351d2 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -44,22 +44,14 @@
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
- mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None
+ selective_state_update = None
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None
-is_fast_path_available = all(
- (
- selective_state_update,
- mamba_chunk_scan_combined,
- mamba_split_conv1d_scan_combined,
- causal_conv1d_fn,
- causal_conv1d_update,
- )
-)
+is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
_CONFIG_FOR_DOC = "Mamba2Config"
@@ -119,17 +111,6 @@ def segment_sum(input_tensor):
return tensor_segsum
-def apply_mask_to_padding_states(hidden_states, attention_mask):
- """
- Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
- """
- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
- dtype = hidden_states.dtype
- hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
- return hidden_states
-
-
class Mamba2Cache:
"""
Arguments:
@@ -139,68 +120,50 @@ class Mamba2Cache:
device: torch.device
Attributes:
- dtype: (`torch.dtype`):
- The default `dtype` used to initializing the cache.
- conv_kernel_size: (`int`):
- Model's convolution kernel size taken from config.
- n_groups: (`int`):
- Model's number of groups taken from the config - similar to tensor parallel in Transformer.
- state_size: (`int`):
- Model's SSM state size taken from config.
- num_heads: (`int`):
- The number of heads used in the linear attention / SSM.
- head_dim: (`int`):
- The respective dimension of the heads used in the linear attention / SSM.
- intermediate_size: (`int`):
- Model's intermediate_size based on (expand * hidden_dim) from config.
- conv_states: (`torch.Tensor`):
- A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states.
- ssm_states: (`torch.Tensor`):
- A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
+ seqlen_offset: int
+ dtype: torch.dtype
+ conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
+ ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""
def __init__(
self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
+ self.seqlen_offset = 0
self.dtype = dtype
self.conv_kernel_size = config.conv_kernel
- self.n_groups = config.n_groups
- self.state_size = config.state_size
- self.num_heads = config.num_heads
- self.head_dim = config.head_dim
self.intermediate_size = int(config.expand * config.hidden_size)
- self.conv_states = torch.zeros(
- config.num_hidden_layers,
- batch_size,
- self.intermediate_size + 2 * self.n_groups * self.state_size,
- self.conv_kernel_size,
- device=device,
- dtype=dtype,
- )
- self.ssm_states = torch.zeros(
- config.num_hidden_layers,
- batch_size,
- self.num_heads,
- self.head_dim,
- self.state_size,
- device=device,
- dtype=dtype,
- )
+ self.conv_states = {
+ i: torch.zeros(
+ batch_size,
+ self.intermediate_size + 2 * config.n_groups * config.state_size,
+ self.conv_kernel_size,
+ device=device,
+ dtype=dtype,
+ )
+ for i in range(config.num_hidden_layers)
+ }
+ self.ssm_states = {
+ i: torch.zeros(
+ batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
+ )
+ for i in range(config.num_hidden_layers)
+ }
+ self.activation = config.hidden_act
+ self.act = ACT2FN[config.hidden_act]
def update_conv_state(
- self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
- if cache_init:
- self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
- else:
- self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
- self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
- return self.conv_states[layer_idx]
+ conv_state = self.conv_states[layer_idx]
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
- def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
- self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
- return self.ssm_states[layer_idx]
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+ self.conv_states[layer_idx].zero_()
+ self.conv_states[layer_idx] += conv_state
+ return self.conv_states[layer_idx]
def reset(self):
self.conv_states.zero_()
@@ -306,27 +269,19 @@ def cuda_kernels_forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
- # 1. Gated MLP's linear projection
- hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
- projected_states = self.in_proj(hidden_states)
+ # set up dimensions for reshapes later
- # Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
- d_mlp = (
- projected_states.shape[-1]
- - 2 * self.intermediate_size
- - 2 * self.n_groups * self.ssm_state_size
- - self.num_heads
- ) // 2
-
- # Single step calculations via cache
- if cache_params is not None and cache_position is not None and cache_position[0] > 0:
- _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
- [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
- )
+ d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
+
+ # getting projected states from cache if it exists
+ if cache_params is not None and cache_params.seqlen_offset > 0:
+ in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
+ d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
+ split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
+ _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
- # 2. Convolution sequence transformation
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
cache_params.conv_states[self.layer_idx],
@@ -340,9 +295,8 @@ def cuda_kernels_forward(
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
-
- # 3. SSM transformation
A = -torch.exp(self.A_log.float()) # (nheads,)
+
A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
@@ -364,18 +318,20 @@ def cuda_kernels_forward(
)
hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
hidden_states = self.norm(hidden_states, gate)
-
- # 4. Final linear projection
out = self.out_proj(hidden_states)[:, None, ...]
-
- # Fused calculations or step by step if no initialized cache is found
+ # if no cache is found, calling the kernel
else:
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(hidden_states)
A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
- # 2-4. Fused kernel for conv1d, SSM, and the final projection
if self.training and cache_params is None:
- out = mamba_split_conv1d_scan_combined(
+ out, ssm_state = mamba_split_conv1d_scan_combined(
projected_states,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
@@ -392,50 +348,41 @@ def cuda_kernels_forward(
headdim=self.head_dim,
ngroups=self.n_groups,
norm_before_gate=False,
- return_final_states=False,
+ return_final_states=True,
**dt_limit_kwargs,
)
else:
- _, _, gate, hidden_states_B_C, dt = projected_states.split(
- [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
+ gate, hidden_states_B_C, time_step = torch.split(
+ projected_states,
+ [self.intermediate_size, self.conv_dim, self.num_heads],
+ dim=-1,
)
- # 2. Convolution sequence transformation
- # Init cache
- if cache_params is not None:
- hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
- conv_states = nn.functional.pad(
- hidden_states_B_C_transposed,
- (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
- )
- cache_params.update_conv_state(
- layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
- )
-
- if self.activation not in ["silu", "swish"]:
+ # 1D Convolution
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
hidden_states_B_C = self.act(
- self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
- )
+ self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
else:
hidden_states_B_C = causal_conv1d_fn(
x=hidden_states_B_C.transpose(1, 2),
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
- ).transpose(1, 2)
-
- hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
+ ).transpose(1, 2)[:, :seq_len]
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
-
- # 3. SSM transformation
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
scan_output, ssm_state = mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
- dt,
+ time_step,
A,
B.view(batch_size, seq_len, self.n_groups, -1),
C.view(batch_size, seq_len, self.n_groups, -1),
@@ -448,16 +395,11 @@ def cuda_kernels_forward(
dt_softplus=True,
**dt_limit_kwargs,
)
-
- # Init cache
if ssm_state is not None and cache_params is not None:
- cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
-
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = scan_output.view(batch_size, seq_len, -1)
# Multiply "gate" branch and apply extra normalization layer
scan_output = self.norm(scan_output, gate)
-
- # 4. Final linear projection
out = self.out_proj(scan_output)
return out
@@ -465,64 +407,60 @@ def cuda_kernels_forward(
def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
-
- # 1. Gated MLP's linear projection
- input_states = apply_mask_to_padding_states(input_states, attention_mask)
- projected_states = self.in_proj(input_states)
- d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
- _, _, gate, hidden_states_B_C, dt = projected_states.split(
+ # Gated MLP's linear projection
+ projected_states = self.in_proj(input_states.squeeze(1))
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
+ _, _, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
- # 2. Convolution sequence transformation
- if cache_params is not None and cache_position is not None and cache_position[0] > 0:
- cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
-
- # We need to guarantee that anything regarding the cache is on the same device
- conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
-
- hidden_states_B_C = torch.sum(
- conv_states * self.conv1d.weight.squeeze(1), dim=-1
- )
- if self.use_conv_bias:
- hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
- hidden_states_B_C = self.act(hidden_states_B_C)
- else:
- # Init cache
- if cache_params is not None:
- hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
- conv_states = nn.functional.pad(
- hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
+ # Convolution sequence transformation
+ if cache_params is not None:
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ ssm_state = ssm_state.to(hidden_states.device)
+ if cache_params.seqlen_offset > 0:
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
+ # handle batched generation - states are copied through
+ conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
+ else:
+ hidden_states = hidden_states.transpose(1,2)
+ conv_state = nn.functional.pad(
+ hidden_states,
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
)
- cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)
-
- hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
-
- hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
- hidden_states, B, C = torch.split(
- hidden_states_B_C,
- [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
- dim=-1
- )
-
- # 3. SSM transformation
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
+ if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+ dtype = hidden_states.dtype
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
+ device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
+ hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
A = -torch.exp(self.A_log.float()) # [num_heads]
- if cache_params is not None and cache_position is not None and cache_position[0] > 0:
- # We need to guarantee that anything regarding the cache is on the same device
- cache_device = cache_params.ssm_states.device
-
+ if cache_params is not None and cache_params.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
- dt = dt[:, 0, :][:, None, ...]
+ dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
- dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
+ dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
# [bsz, num_heads, head_dim, state_size]
- dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
+ dA = torch.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
@@ -536,12 +474,11 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
- dBx = (dB * hidden_states[..., None]).to(device=cache_device)
+ dBx = dB * hidden_states[..., None]
# State calculation
- cache_params.update_ssm_state(
- layer_idx=self.layer_idx,
- new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
+ cache_params.ssm_states[self.layer_idx].copy_(
+ cache_params.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
@@ -551,7 +488,7 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
- ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
@@ -568,9 +505,9 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
- dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
+ dt = torch.clamp(dt, self.time_step_min)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
- B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
@@ -585,6 +522,7 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
+
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
A = A.permute(0, 3, 1, 2)
A_cumsum = torch.cumsum(A, dim=-1)
@@ -593,43 +531,45 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))
- # Contraction of C and B to get G (attention-weights like)
- G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
+ # First, contraction of C and B to get G (attention-weights like)
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
- # Compute M, equivalent to applying attention mask to weights
+
+ # Step 2: Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)
- # Compute Y_diag (apply to values)
- Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
+ # Step 3: Compute Y_diag (apply to values)
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
- # 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
- decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
- B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
- states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
- # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
- # (middle term of factorization of off-diag blocks; A terms)
- if cache_params is not None and cache_position is not None and cache_position[0] > 0:
- previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
+ B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
+ # permute back B * decay states
+ states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
+ if cache_params is not None and cache_params.seqlen_offset > 0:
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
- decay_chunk = decay_chunk.transpose(1, 3)
- new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
+
+ states_permuted = states.permute(0, 2, 1, 3, 4)
+ result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
+ new_states = result.permute(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
- # 4. Compute state -> output conversion per chunk
+ # Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
+ # compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
-
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
@@ -639,10 +579,8 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
-
- # Init cache
if ssm_state is not None and cache_params is not None:
- cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = self.norm(y, gate)
@@ -978,6 +916,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
+ if use_cache:
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
+
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
@@ -1034,6 +975,10 @@ def prepare_inputs_for_generation(
):
# Overwitten -- uses `cache_params` as opposed to `past_key_values`
+ if inputs_embeds is not None:
+ past_len = inputs_embeds.shape[1] + input_ids.shape[1]
+ else:
+ past_len = input_ids.shape[1]
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
@@ -1042,18 +987,33 @@ def prepare_inputs_for_generation(
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
+ # how do we detect that we are in decoding without cache?
if cache_position[0] > 0:
input_ids = input_ids[:, -1][..., None]
-
- if attention_mask is not None:
- attention_mask = None
+ attention_mask = attention_mask[:, -1][..., None]
else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
- cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
-
+ cache_position = torch.arange(0, past_len, device=input_ids.device)
+ # if the cache is not used, we also do have to extend the attention mask here
+ # TODO there is likely a cleverer way to do this
+ extended_mask = torch.ones(
+ attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
+ cache_params = None
+
+ if attention_mask.shape[1] < past_len:
+ # we have to update manually the attention mask if
+ # we are in decoding without cache
+ # and we don't have position_ids here
+ # TODO but we should be able to use cache_position though at a later time
+ extended_mask = torch.ones(
+ attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index e272c98f06975a..95cd7c65ef32c2 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -291,6 +291,7 @@ class MBartFlashAttention2(MBartAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 1440ce1e075c95..cbdd2c663c5844 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -26,7 +26,6 @@
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
@@ -365,55 +364,24 @@ def forward(self, x: torch.Tensor):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi
class MimiRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: MimiConfig,
- device=None,
- ):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
+ # TODO(joao): add me back asap :)
def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
+ # x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
@@ -421,11 +389,6 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -494,8 +457,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi
-# no longer copied after attention refactors
+# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi
class MimiAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -531,7 +493,11 @@ def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
- self.rotary_emb = MimiRotaryEmbedding(config)
+ self.rotary_emb = MimiRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
self.sliding_window = config.sliding_window # Ignore copy
def forward(
@@ -593,8 +559,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
-# TODO cyril: modular
+# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
class MimiFlashAttention2(MimiAttention):
"""
Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays
@@ -705,8 +670,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
-# TODO cyril: modular
+# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
class MimiSdpaAttention(MimiAttention):
"""
Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 90c38895b4280b..e94281b29fe1a8 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1,19 +1,36 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/mistral/modular_mistral.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_mistral.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
+# coding=utf-8
+# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mistral model."""
+
+import math
+from typing import List, Optional, Tuple, Union
import torch
+import torch.utils.checkpoint
from torch import nn
+from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -21,42 +38,79 @@
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_mistral import MistralConfig
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
_CONFIG_FOR_DOC = "MistralConfig"
-class MistralMLP(nn.Module):
- def __init__(self, config):
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
+class MistralRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ MistralRMSNorm is equivalent to T5LayerNorm
+ """
super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class MistralRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ @torch.no_grad()
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
+ # TODO(joao): add me back asap :)
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -64,6 +118,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -91,6 +146,21 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
+class MistralMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -103,66 +173,65 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class MistralAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
- def __init__(self, config: MistralConfig, layer_idx: int):
+ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.rotary_emb = MistralRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -170,58 +239,253 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class MistralFlashAttention2(MistralAttention):
+ """
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += cache_position[0]
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
- attn_output, attn_weights = attention_interface(
- self,
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self.config, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+ if not output_attentions:
+ attn_weights = None
-class MistralRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- MistralRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
+ return attn_output, attn_weights, past_key_value
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
+# TODO(joao): add me back asap :)
+class MistralSdpaAttention(MistralAttention):
+ """
+ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+ # Adapted from MistralAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+MISTRAL_ATTENTION_CLASSES = {
+ "eager": MistralAttention,
+ "flash_attention_2": MistralFlashAttention2,
+ "sdpa": MistralSdpaAttention,
+}
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL
+# TODO(joao): add me back asap :)
class MistralDecoderLayer(nn.Module):
def __init__(self, config: MistralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
+
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -235,15 +499,33 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -251,7 +533,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -263,75 +544,14 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class MistralRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: MistralConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
+ if use_cache:
+ outputs += (present_key_value,)
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
MISTRAL_START_DOCSTRING = r"""
@@ -360,11 +580,10 @@ class MistralPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MistralDecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
+ _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
- _supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
@@ -448,7 +667,7 @@ def _init_weights(self, module):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@@ -475,10 +694,10 @@ def __init__(self, config: MistralConfig):
self.layers = nn.ModuleList(
[MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self._attn_implementation = config._attn_implementation
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = MistralRotaryEmbedding(config=config)
- self.gradient_checkpointing = False
+ self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@@ -494,36 +713,48 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -535,19 +766,17 @@ def forward(
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
)
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -561,7 +790,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -572,12 +800,13 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -587,13 +816,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -601,10 +835,11 @@ def _update_causal_mask(
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
+ use_cache: bool,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and past_key_values is not None:
+ if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
@@ -746,9 +981,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -796,7 +1028,6 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -817,8 +1048,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, MistralForCausalLM
- >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
+ >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -828,6 +1059,7 @@ def forward(
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -846,7 +1078,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -855,7 +1086,18 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Ensure tensors are on the same device
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -872,24 +1114,26 @@ def forward(
@add_start_docstrings(
"""
- The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
- output) e.g. for Named-Entity-Recognition (NER) tasks.
+ The Mistral Model transformer with a sequence classification head on top (linear layer).
+
+ [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
""",
MISTRAL_START_DOCSTRING,
)
-class MistralForTokenClassification(MistralPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
+class MistralForSequenceClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MistralModel(config)
- if getattr(config, "classifier_dropout", None) is not None:
- classifier_dropout = config.classifier_dropout
- elif getattr(config, "hidden_dropout", None) is not None:
- classifier_dropout = config.hidden_dropout
- else:
- classifier_dropout = 0.1
- self.dropout = nn.Dropout(classifier_dropout)
- self.score = nn.Linear(config.hidden_size, config.num_labels)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@@ -901,24 +1145,19 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=TokenClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TokenClassifierOutput]:
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -927,7 +1166,7 @@ def forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.model(
+ transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -938,47 +1177,67 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
- sequence_output = outputs[0]
- sequence_output = self.dropout(sequence_output)
- logits = self.score(sequence_output)
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.config)
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
if not return_dict:
- output = (logits,) + outputs[2:]
+ output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
- return TokenClassifierOutput(
+ return SequenceClassifierOutputWithPast(
loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
- The Mistral Model transformer with a sequence classification head on top (linear layer).
-
- [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
- (e.g. GPT-2) do.
-
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
+ The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
MISTRAL_START_DOCSTRING,
)
-class MistralForSequenceClassification(MistralPreTrainedModel):
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL
+class MistralForTokenClassification(MistralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = MistralModel(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@@ -990,19 +1249,24 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ ) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1011,7 +1275,7 @@ def forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- transformer_outputs = self.model(
+ outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -1022,43 +1286,23 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
- sequence_lengths = sequence_lengths.to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+ loss = self.loss_function(logits, labels, self.config)
if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
+ output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
- return SequenceClassifierOutputWithPast(
+ return TokenClassifierOutput(
loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
)
@@ -1069,13 +1313,15 @@ def forward(
""",
MISTRAL_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model
class MistralForQuestionAnswering(MistralPreTrainedModel):
base_model_prefix = "model"
+ # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral
def __init__(self, config):
super().__init__(config)
+ self.model = MistralModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
- self.model = MistralModel(config) # diff with Llama: transformer->model
# Initialize weights and apply final processing
self.post_init()
diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py
deleted file mode 100644
index 362233a21b70f4..00000000000000
--- a/src/transformers/models/mistral/modular_mistral.py
+++ /dev/null
@@ -1,350 +0,0 @@
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from ...cache_utils import Cache, SlidingWindowCache, StaticCache
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import QuestionAnsweringModelOutput
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import Unpack
-from ...utils import logging
-from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaForQuestionAnswering,
- LlamaForSequenceClassification,
- LlamaForTokenClassification,
- LlamaMLP,
- LlamaModel,
- apply_rotary_pos_emb,
- eager_attention_forward,
-)
-from .configuration_mistral import MistralConfig
-
-
-logger = logging.get_logger(__name__)
-
-_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
-
-
-class MistralMLP(LlamaMLP):
- def __init__(self, config):
- super().__init__(config)
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-
-
-class MistralAttention(LlamaAttention):
- def __init__(self, config: MistralConfig, layer_idx: int):
- super().__init__()
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
-
-class MistralDecoderLayer(LlamaDecoderLayer):
- def __init__(self, config: MistralConfig, layer_idx: int):
- super().__init__(config, layer_idx)
- self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
- self.mlp = MistralMLP(config)
-
-
-class MistralModel(LlamaModel):
- def __init__(self, config: MistralConfig):
- super().__init__(config)
- self.layers = nn.ModuleList(
- [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
-
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and past_key_values is not None:
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
- if is_padding_right:
- raise ValueError(
- "You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
- )
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if (
- self.config._attn_implementation == "sdpa"
- and not (using_static_cache or using_sliding_window_cache)
- and not output_attentions
- ):
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- sliding_window=self.config.sliding_window,
- is_training=self.training,
- ):
- return None
-
- dtype, device = input_tensor.dtype, input_tensor.device
- min_dtype = torch.finfo(dtype).min
- sequence_length = input_tensor.shape[1]
- # SlidingWindowCache or StaticCache
- if using_sliding_window_cache or using_static_cache:
- target_length = past_key_values.get_max_cache_shape()
- # DynamicCache or no cache
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- config=self.config,
- past_key_values=past_key_values,
- )
-
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type == "cuda"
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
- return causal_mask
-
- @staticmethod
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- config: MistralConfig,
- past_key_values: Cache,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to plcae the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- config (`MistralConfig`):
- The model's configuration class
- past_key_values (`Cache`):
- The cache class that is being used currently to generate
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- if config.sliding_window is not None:
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
- # the check is needed to verify is current checkpoint was trained with sliding window or not
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
- sliding_attend_mask = torch.arange(target_length, device=device) <= (
- cache_position.reshape(-1, 1) - config.sliding_window
- )
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
- causal_mask *= diagonal_attend_mask
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- if attention_mask.shape[-1] > target_length:
- attention_mask = attention_mask[:, :target_length]
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
- return causal_mask
-
-
-class MistralForCausalLM(LlamaForCausalLM):
- pass
-
-
-class MistralForTokenClassification(LlamaForTokenClassification):
- pass
-
-
-class MistralForSequenceClassification(LlamaForSequenceClassification):
- pass
-
-
-class MistralForQuestionAnswering(LlamaForQuestionAnswering):
- base_model_prefix = "model"
-
- def __init__(self, config):
- super().__init__(config)
- self.model = MistralModel(config) # diff with Llama: transformer->model
- del self.transformer
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- start_positions: Optional[torch.LongTensor] = None,
- end_positions: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
- r"""
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- sequence_output = outputs[0]
-
- logits = self.qa_outputs(sequence_output)
- start_logits, end_logits = logits.split(1, dim=-1)
- start_logits = start_logits.squeeze(-1).contiguous()
- end_logits = end_logits.squeeze(-1).contiguous()
-
- loss = None
- if start_positions is not None and end_positions is not None:
- loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
-
- if not return_dict:
- output = (start_logits, end_logits) + outputs[2:]
- return ((loss,) + output) if loss is not None else output
-
- return QuestionAnsweringModelOutput(
- loss=loss,
- start_logits=start_logits,
- end_logits=end_logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 84ed327d9be920..0f04ef255c431d 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1,9 +1,3 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/mixtral/modular_mixtral.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_mixtral.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
@@ -23,133 +17,142 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""PyTorch Mixtral model."""
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
+import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
+from ...utils.import_utils import is_torch_fx_available
from .configuration_mixtral import MixtralConfig
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+if is_torch_fx_available():
+ if not is_torch_greater_or_equal_than_1_13:
+ import torch.fx
+
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
_CONFIG_FOR_DOC = "MixtralConfig"
-class MixtralBlockSparseTop2MLP(nn.Module):
- def __init__(self, config: MixtralConfig):
- super().__init__()
- self.ffn_dim = config.intermediate_size
- self.hidden_dim = config.hidden_size
-
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-
- self.act_fn = ACT2FN[config.hidden_act]
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
- def forward(self, hidden_states):
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
- current_hidden_states = self.w2(current_hidden_states)
- return current_hidden_states
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
-class MixtralSparseMoeBlock(nn.Module):
- """
- This implementation is
- strictly equivalent to standard MoE with full capacity (no
- dropped tokens). It's faster since it formulates MoE operations
- in terms of block-sparse operations to accommodate imbalanced
- assignments of tokens to experts, whereas standard MoE either
- (1) drop tokens at the cost of reduced performance or (2) set
- capacity factor to number of experts and thus waste computation
- and memory on padding.
+ Returns:
+ The auxiliary loss.
"""
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
- def __init__(self, config):
- super().__init__()
- self.hidden_dim = config.hidden_size
- self.ffn_dim = config.intermediate_size
- self.num_experts = config.num_local_experts
- self.top_k = config.num_experts_per_tok
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
- # gating
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
- self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
- # Jitter parameters
- self.jitter_noise = config.router_jitter_noise
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- if self.training and self.jitter_noise > 0:
- hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
- hidden_states = hidden_states.view(-1, hidden_dim)
- # router_logits: (batch * sequence_length, n_experts)
- router_logits = self.gate(hidden_states)
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(hidden_states.dtype)
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
- final_hidden_states = torch.zeros(
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
)
- # One hot encode the selected experts to create an expert mask
- # this will be used to easily index which expert is going to be sollicitated
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
- # Loop over all available experts in the model and perform the computation on each expert
- for expert_idx in range(self.num_experts):
- expert_layer = self.experts[expert_idx]
- idx, top_x = torch.where(expert_mask[expert_idx])
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
- # Index the correct hidden states and compute the expert hidden state for
- # the current expert. We need to make sure to multiply the output hidden
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
- # However `index_add_` only support torch tensors for indexing so we'll use
- # the `top_x` tensor here.
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
- return final_hidden_states, router_logits
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
class MixtralRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -170,6 +173,45 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
+class MixtralRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -177,7 +219,9 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+# TODO @longjie no longer copied from Mistral after static cache
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
@@ -185,8 +229,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -197,13 +242,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -216,98 +262,412 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
+# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
- def __init__(self, config: MixtralConfig, layer_idx: int):
+ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
- self.attention_dropout = config.attention_dropout
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.attention_dropout = config.attention_dropout
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.rotary_emb = MixtralRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
+class MixtralFlashAttention2(MixtralAttention):
+ """
+ Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
)
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = (
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
+ )
+
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
- attn_output, attn_weights = attention_interface(
- self,
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self.config, "sliding_window", None),
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
+class MixtralSdpaAttention(MixtralAttention):
+ """
+ Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from MixtralAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ return attn_output, None, past_key_value
+
+
+MIXTRAL_ATTENTION_CLASSES = {
+ "eager": MixtralAttention,
+ "flash_attention_2": MixtralFlashAttention2,
+ "sdpa": MixtralSdpaAttention,
+}
+
+
+class MixtralBlockSparseTop2MLP(nn.Module):
+ def __init__(self, config: MixtralConfig):
+ super().__init__()
+ self.ffn_dim = config.intermediate_size
+ self.hidden_dim = config.hidden_size
+
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states):
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+ current_hidden_states = self.w2(current_hidden_states)
+ return current_hidden_states
+
+
+class MixtralSparseMoeBlock(nn.Module):
+ """
+ This implementation is
+ strictly equivalent to standard MoE with full capacity (no
+ dropped tokens). It's faster since it formulates MoE operations
+ in terms of block-sparse operations to accommodate imbalanced
+ assignments of tokens to experts, whereas standard MoE either
+ (1) drop tokens at the cost of reduced performance or (2) set
+ capacity factor to number of experts and thus waste computation
+ and memory on padding.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.intermediate_size
+ self.num_experts = config.num_local_experts
+ self.top_k = config.num_experts_per_tok
+
+ # gating
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+ self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+ # Jitter parameters
+ self.jitter_noise = config.router_jitter_noise
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """ """
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ if self.training and self.jitter_noise > 0:
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+ # this will be used to easily index which expert is going to be sollicitated
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+ # However `index_add_` only support torch tensors for indexing so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states, router_logits
class MixtralDecoderLayer(nn.Module):
@@ -315,7 +675,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = MixtralAttention(config, layer_idx)
+ self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.block_sparse_moe = MixtralSparseMoeBlock(config)
self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -331,8 +691,7 @@ def forward(
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -361,16 +720,14 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
- position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- **kwargs,
)
hidden_states = residual + hidden_states
@@ -385,77 +742,15 @@ def forward(
if output_attentions:
outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
if output_router_logits:
outputs += (router_logits,)
return outputs
-class MixtralRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: MixtralConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
MIXTRAL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -477,17 +772,17 @@ def forward(self, x, position_ids):
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
+# copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
+# TODO (Raushan): bring back copied after compile compatibility
class MixtralPreTrainedModel(PreTrainedModel):
config_class = MixtralConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MixtralDecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
+ _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
@@ -522,7 +817,7 @@ def _init_weights(self, module):
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -536,24 +831,17 @@ def _init_weights(self, module):
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance, see our
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
-
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -567,6 +855,9 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
@@ -580,6 +871,8 @@ def _init_weights(self, module):
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING,
)
+# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
+# TODO @longjie no longer copied from Mistral after static cache
class MixtralModel(MixtralPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
@@ -597,10 +890,10 @@ def __init__(self, config: MixtralConfig):
self.layers = nn.ModuleList(
[MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self._attn_implementation = config._attn_implementation
self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = MixtralRotaryEmbedding(config=config)
- self.gradient_checkpointing = False
+ self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@@ -610,6 +903,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value
+ # Ignore copy
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
@@ -624,8 +918,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
@@ -647,8 +940,19 @@ def forward(
)
use_cache = False
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -667,13 +971,11 @@ def forward(
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
@@ -690,7 +992,6 @@ def forward(
output_router_logits,
use_cache,
cache_position,
- position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -702,12 +1003,13 @@ def forward(
output_router_logits=output_router_logits,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -720,15 +1022,25 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = MoeModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
- return output if return_dict else output.to_tuple()
+ # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -738,14 +1050,6 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and past_key_values is not None:
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
- if is_padding_right:
- raise ValueError(
- "You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
- )
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
@@ -813,6 +1117,7 @@ def _update_causal_mask(
return causal_mask
@staticmethod
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mixtral
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
@@ -880,94 +1185,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
-def load_balancing_loss_func(
- gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
- num_experts: Optional[int] = None,
- top_k=2,
- attention_mask: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, int]:
- r"""
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
-
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
- experts is too unbalanced.
-
- Args:
- gate_logits:
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
- shape [batch_size X sequence_length, num_experts].
- num_experts:
- Number of experts
- top_k:
- The number of experts to route per-token, can be also interpreted as the `top-k` routing
- parameter.
- attention_mask (`torch.Tensor`, *optional*):
- The attention_mask used in forward function
- shape [batch_size X sequence_length] if not None.
-
- Returns:
- The auxiliary loss.
- """
- if gate_logits is None or not isinstance(gate_logits, tuple):
- return 0
-
- if isinstance(gate_logits, tuple):
- compute_device = gate_logits[0].device
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
-
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-
- if attention_mask is None:
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
- else:
- batch_size, sequence_length = attention_mask.shape
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
-
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
- expert_attention_mask = (
- attention_mask[None, :, :, None, None]
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
- .reshape(-1, top_k, num_experts)
- .to(compute_device)
- )
-
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
- expert_attention_mask, dim=0
- )
-
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
- router_per_expert_attention_mask = (
- attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
- .to(compute_device)
- )
-
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
- router_per_expert_attention_mask, dim=0
- )
-
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
- return overall_loss * num_experts
-
-
class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
@@ -977,7 +1196,6 @@ def __init__(self, config):
self.router_aux_loss_coef = config.router_aux_loss_coef
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
-
# Initialize weights and apply final processing
self.post_init()
@@ -1000,7 +1218,8 @@ def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -1016,8 +1235,8 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ **loss_kwargs,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1072,7 +1291,6 @@ def forward(
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -1081,7 +1299,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:
@@ -1126,6 +1344,7 @@ def forward(
""",
MIXTRAL_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForSequenceClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1222,6 +1441,7 @@ def forward(
""",
MIXTRAL_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForTokenClassification(MixtralPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1310,13 +1530,15 @@ def forward(
""",
MIXTRAL_START_DOCSTRING,
)
+# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL
class MixtralForQuestionAnswering(MixtralPreTrainedModel):
base_model_prefix = "model"
+ # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mixtral
def __init__(self, config):
super().__init__(config)
+ self.model = MixtralModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
- self.model = MixtralModel(config) # diff with Llama: transformer->model
# Initialize weights and apply final processing
self.post_init()
diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py
deleted file mode 100644
index a6069f69b33421..00000000000000
--- a/src/transformers/models/mixtral/modular_mixtral.py
+++ /dev/null
@@ -1,574 +0,0 @@
-# coding=utf-8
-# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Mixtral model."""
-
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-
-from ...activations import ACT2FN
-from ...cache_utils import DynamicCache
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import (
- MoeCausalLMOutputWithPast,
- MoeModelOutputWithPast,
-)
-from ...processing_utils import Unpack
-from ...utils import (
- LossKwargs,
- logging,
-)
-from ..mistral.modeling_mistral import (
- MistralAttention,
- MistralForCausalLM,
- MistralForQuestionAnswering,
- MistralForSequenceClassification,
- MistralForTokenClassification,
- MistralModel,
- MistralRMSNorm,
-)
-from .configuration_mixtral import MixtralConfig
-
-
-logger = logging.get_logger(__name__)
-
-_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
-_CONFIG_FOR_DOC = "MixtralConfig"
-
-
-def load_balancing_loss_func(
- gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
- num_experts: Optional[int] = None,
- top_k=2,
- attention_mask: Optional[torch.Tensor] = None,
-) -> Union[torch.Tensor, int]:
- r"""
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
-
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
- experts is too unbalanced.
-
- Args:
- gate_logits:
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
- shape [batch_size X sequence_length, num_experts].
- num_experts:
- Number of experts
- top_k:
- The number of experts to route per-token, can be also interpreted as the `top-k` routing
- parameter.
- attention_mask (`torch.Tensor`, *optional*):
- The attention_mask used in forward function
- shape [batch_size X sequence_length] if not None.
-
- Returns:
- The auxiliary loss.
- """
- if gate_logits is None or not isinstance(gate_logits, tuple):
- return 0
-
- if isinstance(gate_logits, tuple):
- compute_device = gate_logits[0].device
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
-
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-
- if attention_mask is None:
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
- else:
- batch_size, sequence_length = attention_mask.shape
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
-
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
- expert_attention_mask = (
- attention_mask[None, :, :, None, None]
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
- .reshape(-1, top_k, num_experts)
- .to(compute_device)
- )
-
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
- expert_attention_mask, dim=0
- )
-
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
- router_per_expert_attention_mask = (
- attention_mask[None, :, :, None]
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
- .reshape(-1, num_experts)
- .to(compute_device)
- )
-
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
- router_per_expert_attention_mask, dim=0
- )
-
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
- return overall_loss * num_experts
-
-
-class MixtralBlockSparseTop2MLP(nn.Module):
- def __init__(self, config: MixtralConfig):
- super().__init__()
- self.ffn_dim = config.intermediate_size
- self.hidden_dim = config.hidden_size
-
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, hidden_states):
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
- current_hidden_states = self.w2(current_hidden_states)
- return current_hidden_states
-
-
-class MixtralSparseMoeBlock(nn.Module):
- """
- This implementation is
- strictly equivalent to standard MoE with full capacity (no
- dropped tokens). It's faster since it formulates MoE operations
- in terms of block-sparse operations to accommodate imbalanced
- assignments of tokens to experts, whereas standard MoE either
- (1) drop tokens at the cost of reduced performance or (2) set
- capacity factor to number of experts and thus waste computation
- and memory on padding.
- """
-
- def __init__(self, config):
- super().__init__()
- self.hidden_dim = config.hidden_size
- self.ffn_dim = config.intermediate_size
- self.num_experts = config.num_local_experts
- self.top_k = config.num_experts_per_tok
-
- # gating
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-
- self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
-
- # Jitter parameters
- self.jitter_noise = config.router_jitter_noise
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """ """
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- if self.training and self.jitter_noise > 0:
- hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
- hidden_states = hidden_states.view(-1, hidden_dim)
- # router_logits: (batch * sequence_length, n_experts)
- router_logits = self.gate(hidden_states)
-
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(hidden_states.dtype)
-
- final_hidden_states = torch.zeros(
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
- )
-
- # One hot encode the selected experts to create an expert mask
- # this will be used to easily index which expert is going to be sollicitated
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
- # Loop over all available experts in the model and perform the computation on each expert
- for expert_idx in range(self.num_experts):
- expert_layer = self.experts[expert_idx]
- idx, top_x = torch.where(expert_mask[expert_idx])
-
- # Index the correct hidden states and compute the expert hidden state for
- # the current expert. We need to make sure to multiply the output hidden
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
- # However `index_add_` only support torch tensors for indexing so we'll use
- # the `top_x` tensor here.
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
- return final_hidden_states, router_logits
-
-
-class MixtralRMSNorm(MistralRMSNorm):
- pass
-
-
-class MixtralAttention(MistralAttention):
- pass
-
-
-class MixtralDecoderLayer(nn.Module):
- def __init__(self, config: MixtralConfig, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
-
- self.self_attn = MixtralAttention(config, layer_idx)
-
- self.block_sparse_moe = MixtralSparseMoeBlock(config)
- self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- output_router_logits: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_router_logits (`bool`, *optional*):
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
- should not be returned during inference.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence.
- kwargs (`dict`, *optional*):
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
- into the model
- """
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if output_router_logits:
- outputs += (router_logits,)
-
- return outputs
-
-
-class MixtralModel(MistralModel):
- def __init__(self, config: MixtralConfig):
- super().__init__(config)
- self.layers = nn.ModuleList(
- [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_router_logits: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, MoeModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_router_logits = (
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
- )
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
-
- hidden_states = inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- all_router_logits = () if output_router_logits else None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- output_router_logits,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- output_router_logits=output_router_logits,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- if output_router_logits:
- all_router_logits += (layer_outputs[-1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- output = MoeModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=past_key_values,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- router_logits=all_router_logits,
- )
- return output if return_dict else output.to_tuple()
-
-
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
-class MixtralForCausalLM(MistralForCausalLM):
- _tied_weights_keys = ["lm_head.weight"]
-
- def __init__(self, config):
- super().__init__(config)
- self.model = MixtralModel(config)
- self.router_aux_loss_coef = config.router_aux_loss_coef
- self.num_experts = config.num_local_experts
- self.num_experts_per_tok = config.num_experts_per_tok
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_router_logits: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, MixtralForCausalLM
-
- >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_router_logits = (
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
- )
-
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- output_router_logits=output_router_logits,
- return_dict=return_dict,
- cache_position=cache_position,
- **kwargs,
- )
-
- hidden_states = outputs[0]
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
-
- aux_loss = None
- if output_router_logits:
- aux_loss = load_balancing_loss_func(
- outputs.router_logits if return_dict else outputs[-1],
- self.num_experts,
- self.num_experts_per_tok,
- attention_mask,
- )
- if labels is not None:
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- if output_router_logits:
- output = (aux_loss,) + output
- return (loss,) + output if loss is not None else output
-
- return MoeCausalLMOutputWithPast(
- loss=loss,
- aux_loss=aux_loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- router_logits=outputs.router_logits,
- )
-
-
-class MixtralForSequenceClassification(MistralForSequenceClassification):
- pass
-
-
-class MixtralForTokenClassification(MistralForTokenClassification):
- pass
-
-
-class MixtralForQuestionAnswering(MistralForQuestionAnswering):
- pass
diff --git a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py
index b2c40e27bb2b40..ca22d31ee3ca5e 100644
--- a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py
+++ b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py
@@ -338,11 +338,7 @@ def write_model(
print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
if num_shards == 1:
- if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")):
- path = os.path.join(input_base_path, "consolidated.00.pth")
- else:
- path = os.path.join(input_base_path, "consolidated.pth")
- loaded = [torch.load(path, map_location="cpu", mmap=True)]
+ loaded = [torch.load(os.path.join(input_base_path, "consolidated.pth"), map_location="cpu", mmap=True)]
else:
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", mmap=True)
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 3e0c4d7a5123a7..763ad97b1e721a 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -829,8 +829,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer
@@ -859,7 +858,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
diff --git a/src/transformers/models/modernbert/__init__.py b/src/transformers/models/modernbert/__init__.py
deleted file mode 100644
index 18317742981909..00000000000000
--- a/src/transformers/models/modernbert/__init__.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING
-
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
- from .configuration_modernbert import *
- from .modeling_modernbert import *
-else:
- import sys
-
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py
deleted file mode 100644
index 13e9edf067efc4..00000000000000
--- a/src/transformers/models/modernbert/configuration_modernbert.py
+++ /dev/null
@@ -1,213 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_modernbert.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Literal
-
-from ...configuration_utils import PretrainedConfig
-
-
-class ModernBertConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
- defaults will yield a similar configuration to that of the ModernBERT-base.
- e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vocab_size (`int`, *optional*, defaults to 50368):
- Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`ModernBertModel`]
- hidden_size (`int`, *optional*, defaults to 768):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 1152):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 22):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 12):
- Number of attention heads for each attention layer in the Transformer decoder.
- hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
- The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
- if not specified.
- max_position_embeddings (`int`, *optional*, defaults to 8192):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
- The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
- norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the rms normalization layers.
- norm_bias (`bool`, *optional*, defaults to `False`):
- Whether to use bias in the normalization layers.
- pad_token_id (`int`, *optional*, defaults to 50283):
- Padding token id.
- eos_token_id (`int`, *optional*, defaults to 50282):
- End of stream token id.
- bos_token_id (`int`, *optional*, defaults to 50281):
- Beginning of stream token id.
- cls_token_id (`int`, *optional*, defaults to 50281):
- Classification token id.
- sep_token_id (`int`, *optional*, defaults to 50282):
- Separation token id.
- global_rope_theta (`float`, *optional*, defaults to 160000.0):
- The base period of the global RoPE embeddings.
- attention_bias (`bool`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- global_attn_every_n_layers (`int`, *optional*, defaults to 3):
- The number of layers between global attention layers.
- local_attention (`int`, *optional*, defaults to 128):
- The window size for local attention.
- local_rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the local RoPE embeddings.
- embedding_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the embeddings.
- mlp_bias (`bool`, *optional*, defaults to `False`):
- Whether to use bias in the MLP layers.
- mlp_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the MLP layers.
- decoder_bias (`bool`, *optional*, defaults to `True`):
- Whether to use bias in the decoder layers.
- classifier_pooling (`str`, *optional*, defaults to `"cls"`):
- The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
- CLS token doesn't attend to all tokens on long sequences.
- classifier_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the classifier.
- classifier_bias (`bool`, *optional*, defaults to `False`):
- Whether to use bias in the classifier.
- classifier_activation (`str`, *optional*, defaults to `"gelu"`):
- The activation function for the classifier.
- deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
- Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
- sparse_prediction (`bool`, *optional*, defaults to `False`):
- Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
- sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
- The index to ignore for the sparse prediction.
- reference_compile (`bool`, *optional*):
- Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
- the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
- shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
- be faster in some scenarios.
-
- Examples:
-
- ```python
- >>> from transformers import ModernBertModel, ModernBertConfig
-
- >>> # Initializing a ModernBert style configuration
- >>> configuration = ModernBertConfig()
-
- >>> # Initializing a model from the modernbert-base style configuration
- >>> model = ModernBertModel(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "modernbert"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- vocab_size=50368,
- hidden_size=768,
- intermediate_size=1152,
- num_hidden_layers=22,
- num_attention_heads=12,
- hidden_activation="gelu",
- max_position_embeddings=8192,
- initializer_range=0.02,
- initializer_cutoff_factor=2.0,
- norm_eps=1e-5,
- norm_bias=False,
- pad_token_id=50283,
- eos_token_id=50282,
- bos_token_id=50281,
- cls_token_id=50281,
- sep_token_id=50282,
- global_rope_theta=160000.0,
- attention_bias=False,
- attention_dropout=0.0,
- global_attn_every_n_layers=3,
- local_attention=128,
- local_rope_theta=10000.0,
- embedding_dropout=0.0,
- mlp_bias=False,
- mlp_dropout=0.0,
- decoder_bias=True,
- classifier_pooling: Literal["cls", "mean"] = "cls",
- classifier_dropout=0.0,
- classifier_bias=False,
- classifier_activation="gelu",
- deterministic_flash_attn=False,
- sparse_prediction=False,
- sparse_pred_ignore_index=-100,
- reference_compile=None,
- **kwargs,
- ):
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- cls_token_id=cls_token_id,
- sep_token_id=sep_token_id,
- **kwargs,
- )
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.initializer_range = initializer_range
- self.initializer_cutoff_factor = initializer_cutoff_factor
- self.norm_eps = norm_eps
- self.norm_bias = norm_bias
- self.global_rope_theta = global_rope_theta
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.hidden_activation = hidden_activation
- self.global_attn_every_n_layers = global_attn_every_n_layers
- self.local_attention = local_attention
- self.local_rope_theta = local_rope_theta
- self.embedding_dropout = embedding_dropout
- self.mlp_bias = mlp_bias
- self.mlp_dropout = mlp_dropout
- self.decoder_bias = decoder_bias
- self.classifier_pooling = classifier_pooling
- self.classifier_dropout = classifier_dropout
- self.classifier_bias = classifier_bias
- self.classifier_activation = classifier_activation
- self.deterministic_flash_attn = deterministic_flash_attn
- self.sparse_prediction = sparse_prediction
- self.sparse_pred_ignore_index = sparse_pred_ignore_index
- self.reference_compile = reference_compile
-
- if self.classifier_pooling not in ["cls", "mean"]:
- raise ValueError(
- f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
- )
-
-
-__all__ = ["ModernBertConfig"]
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
deleted file mode 100644
index 237fba6f645fa5..00000000000000
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ /dev/null
@@ -1,1311 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_modernbert.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-from typing import Dict, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
-from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- logging,
-)
-from ...utils.import_utils import is_triton_available
-from .configuration_modernbert import ModernBertConfig
-
-
-if is_flash_attn_2_available():
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
- from flash_attn.layers.rotary import RotaryEmbedding
- from flash_attn.ops.triton.rotary import apply_rotary
-else:
- RotaryEmbedding = object
-
-logger = logging.get_logger(__name__)
-
-_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base"
-_CONFIG_FOR_DOC = "ModernBertConfig"
-
-
-class ApplyRotaryEmbUnpad(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- qkv,
- cos,
- sin,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- ):
- # (total_nnz, 3, nheads, headdim)
- qkv = qkv.contiguous()
- total_nnz, _three, _nheads, headdim = qkv.shape
- # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
- # we get the same tensor
- # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
- qk = qkv[:, :2].view(total_nnz, -1, headdim)
- apply_rotary(
- qk,
- cos,
- sin,
- seqlen_offsets=0,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- interleaved=False,
- inplace=True,
- )
-
- ctx.save_for_backward(cos, sin, cu_seqlens)
- ctx.max_seqlen = max_seqlen
- return qkv
-
- @staticmethod
- def backward(ctx, do):
- cos, sin, cu_seqlens = ctx.saved_tensors
- do = do.contiguous()
- total_nnz, _three, _nheads, headdim = do.shape
- # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
- # we get the same tensor
- dqk = do[:, :2].view(total_nnz, -1, headdim)
- apply_rotary(
- dqk,
- cos,
- sin,
- seqlen_offsets=0,
- cu_seqlens=cu_seqlens,
- max_seqlen=ctx.max_seqlen,
- interleaved=False,
- inplace=True,
- conjugate=True,
- )
-
- return do, None, None, None, None, None, None
-
-
-def apply_rotary_unpadded(
- qkv,
- cos,
- sin,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
-):
- """
- Arguments:
- qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
- cos, sin: (seqlen_rotary, rotary_dim / 2)
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
- of 1st half and 2nd half (GPT-NeoX style).
- inplace: if True, apply rotary embedding in-place.
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
- Most commonly used in inference when we have KV cache.
- cu_seqlens: (batch + 1,) or None
- max_seqlen: int
- Return:
- out: (total_nnz, dim)
- rotary_dim must be <= headdim
- Apply rotary embedding to the first rotary_dim of x.
- """
- return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
-
-
-class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
- """
- The rotary position embeddings applied directly to unpadded sequences.
- """
-
- def __init__(
- self,
- dim: int,
- base: float = 10000.0,
- max_seqlen: Optional[int] = None,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- ):
- """
- max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
- up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
- the cos_sin_cache wll be recomputed during the forward pass.
- """
- super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
- self.max_seqlen = max_seqlen
-
- if max_seqlen is not None and device is not None and dtype is not None:
- self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
-
- def forward(
- self,
- qkv: torch.Tensor,
- cu_seqlens: torch.Tensor,
- max_seqlen: Optional[int] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- Apply rotary embedding *inplace* to qkv.
- qkv: (total_nnz, 3, nheads, headdim)
- cu_seqlens: (batch + 1,) cumulative sequence lengths
- max_seqlen: int max seq length in the batch
- """
- if max_seqlen is not None:
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-
- qkv = apply_rotary_unpadded(
- qkv,
- self._cos_cached,
- self._sin_cached,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- )
-
- return qkv
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
-
-
-class ModernBertEmbeddings(nn.Module):
- """
- Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
- """
-
- def __init__(self, config: ModernBertConfig):
- super().__init__()
- self.config = config
- self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
- self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.drop = nn.Dropout(config.embedding_dropout)
-
- @torch.compile(dynamic=True)
- def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
- return self.drop(self.norm(self.tok_embeddings(input_ids)))
-
- def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
- hidden_states = (
- self.compiled_embeddings(input_ids)
- if self.config.reference_compile
- else self.drop(self.norm(self.tok_embeddings(input_ids)))
- )
- return hidden_states
-
-
-class ModernBertMLP(nn.Module):
- """Applies the GLU at the end of each ModernBERT layer.
-
- Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
- and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
- """
-
- def __init__(self, config: ModernBertConfig):
- super().__init__()
- self.config = config
- self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
- self.act = ACT2FN[config.hidden_activation]
- self.drop = nn.Dropout(config.mlp_dropout)
- self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
- return self.Wo(self.drop(self.act(input) * gate))
-
-
-class ModernBertRotaryEmbedding(nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
- self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
- @torch.no_grad()
- def forward(self, x, position_ids, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- self.inv_freq.to(x.device)
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 since bfloat16 loses precision on long contexts
- # See https://github.com/huggingface/transformers/pull/29285
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def eager_attention_forward(
- module: "ModernBertAttention",
- qkv: torch.Tensor,
- attention_mask: torch.Tensor,
- sliding_window_mask: torch.Tensor,
- position_ids: Optional[torch.LongTensor],
- local_attention: Tuple[int, int],
- bs: int,
- dim: int,
- output_attentions: Optional[bool] = False,
- **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
- cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
- # query, key, value: [batch_size, heads, seq_len, head_dim]
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
- scale = module.head_dim**-0.5
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
-
- if local_attention != (-1, -1):
- attention_mask = sliding_window_mask
-
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(bs, -1, dim)
- if output_attentions:
- return (attn_output, attn_weights)
- return (attn_output,)
-
-
-def flash_attention_forward(
- module: "ModernBertAttention",
- qkv: torch.Tensor,
- rotary_emb: ModernBertUnpaddedRotaryEmbedding,
- cu_seqlens: torch.Tensor,
- max_seqlen: int,
- local_attention: Tuple[int, int],
- bs: int,
- dim: int,
- target_dtype: torch.dtype = torch.bfloat16,
- **_kwargs,
-) -> Tuple[torch.Tensor]:
- # (total_seqlen, 3, nheads, headdim)
- qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
-
- convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
- if convert_dtype:
- # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
- # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
- orig_dtype = qkv.dtype
- qkv = qkv.to(target_dtype)
-
- attn = flash_attn_varlen_qkvpacked_func(
- qkv,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- dropout_p=module.attention_dropout if module.training else 0.0,
- deterministic=module.deterministic_flash_attn,
- window_size=local_attention,
- )
- attn = attn.to(orig_dtype) # type: ignore
- else:
- attn = flash_attn_varlen_qkvpacked_func(
- qkv,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- dropout_p=module.attention_dropout if module.training else 0.0,
- deterministic=module.deterministic_flash_attn,
- window_size=local_attention,
- )
- return (attn.view(bs, dim),)
-
-
-def sdpa_attention_forward(
- module: "ModernBertAttention",
- qkv: torch.Tensor,
- attention_mask: torch.Tensor,
- sliding_window_mask: torch.Tensor,
- position_ids: Optional[torch.LongTensor],
- local_attention: Tuple[int, int],
- bs: int,
- dim: int,
- **_kwargs,
-) -> Tuple[torch.Tensor]:
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
- cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
- # query, key, value: [batch_size, heads, seq_len, head_dim]
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
- if local_attention != (-1, -1):
- attention_mask = sliding_window_mask
-
- attn_output = (
- F.scaled_dot_product_attention(
- query,
- key,
- value,
- dropout_p=module.attention_dropout if module.training else 0.0,
- attn_mask=attention_mask,
- )
- .transpose(1, 2)
- .contiguous()
- )
- attn_output = attn_output.view(bs, -1, dim)
- return (attn_output,)
-
-
-MODERNBERT_ATTENTION_FUNCTION = {
- "flash_attention_2": flash_attention_forward,
- "eager": eager_attention_forward,
- "sdpa": sdpa_attention_forward,
-}
-
-
-class ModernBertAttention(nn.Module):
- """Performs multi-headed self attention on a batch of unpadded sequences.
-
- If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
- If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
- which requires padding and unpadding inputs, adding some overhead.
-
- See `forward` method for additional details.
- """
-
- def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_id = layer_id
-
- if config.hidden_size % config.num_attention_heads != 0:
- raise ValueError(
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
- )
-
- self.attention_dropout = config.attention_dropout
- self.deterministic_flash_attn = config.deterministic_flash_attn
- self.num_heads = config.num_attention_heads
- self.head_dim = config.hidden_size // config.num_attention_heads
- self.all_head_size = self.head_dim * self.num_heads
- self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
-
- if layer_id % config.global_attn_every_n_layers != 0:
- self.local_attention = (config.local_attention // 2, config.local_attention // 2)
- else:
- self.local_attention = (-1, -1)
-
- rope_theta = config.global_rope_theta
- max_position_embeddings = config.max_position_embeddings
- if self.local_attention != (-1, -1):
- if config.local_rope_theta is not None:
- rope_theta = config.local_rope_theta
- max_position_embeddings = config.local_attention
-
- if config._attn_implementation == "flash_attention_2":
- self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
- dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
- )
- else:
- self.rotary_emb = ModernBertRotaryEmbedding(
- dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
- )
-
- self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
- self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
- self.pruned_heads = set()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- output_attentions: Optional[bool] = False,
- **kwargs,
- ) -> torch.Tensor:
- qkv = self.Wqkv(hidden_states)
-
- bs = hidden_states.shape[0]
- if self.config._attn_implementation == "flash_attention_2":
- qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
- else:
- qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
-
- attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
- self,
- qkv=qkv,
- rotary_emb=self.rotary_emb,
- local_attention=self.local_attention,
- bs=bs,
- dim=self.all_head_size,
- output_attentions=output_attentions,
- **kwargs,
- )
- hidden_states = attn_outputs[0]
- hidden_states = self.out_drop(self.Wo(hidden_states))
-
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
-
-
-class ModernBertEncoderLayer(nn.Module):
- def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
- super().__init__()
- self.config = config
- if layer_id == 0:
- self.attn_norm = nn.Identity()
- else:
- self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.attn = ModernBertAttention(config=config, layer_id=layer_id)
- self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.mlp = ModernBertMLP(config)
-
- @torch.compile(dynamic=True)
- def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.mlp(self.mlp_norm(hidden_states))
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- output_attentions: Optional[bool] = False,
- ) -> torch.Tensor:
- attn_outputs = self.attn(
- self.attn_norm(hidden_states),
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- output_attentions=output_attentions,
- )
- hidden_states = hidden_states + attn_outputs[0]
- mlp_output = (
- self.compiled_mlp(hidden_states)
- if self.config.reference_compile
- else self.mlp(self.mlp_norm(hidden_states))
- )
- hidden_states = hidden_states + mlp_output
-
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
-
-
-MODERNBERT_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`ModernBertConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare ModernBert Model outputting raw hidden-states without any specific head on top.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertPreTrainedModel(PreTrainedModel):
- config_class = ModernBertConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_flex_attn = False
-
- def _init_weights(self, module: nn.Module):
- cutoff_factor = self.config.initializer_cutoff_factor
- if cutoff_factor is None:
- cutoff_factor = 3
-
- def init_weight(module: nn.Module, std: float):
- nn.init.trunc_normal_(
- module.weight,
- mean=0.0,
- std=std,
- a=-cutoff_factor * std,
- b=cutoff_factor * std,
- )
-
- if isinstance(module, nn.Linear):
- if module.bias is not None:
- nn.init.zeros_(module.bias)
-
- stds = {
- "in": self.config.initializer_range,
- "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
- "embedding": self.config.initializer_range,
- "final_out": self.config.hidden_size**-0.5,
- }
-
- if isinstance(module, ModernBertEmbeddings):
- init_weight(module.tok_embeddings, stds["embedding"])
- elif isinstance(module, ModernBertMLP):
- init_weight(module.Wi, stds["in"])
- init_weight(module.Wo, stds["out"])
- elif isinstance(module, ModernBertAttention):
- init_weight(module.Wqkv, stds["in"])
- init_weight(module.Wo, stds["out"])
- elif isinstance(module, ModernBertPredictionHead):
- init_weight(module.dense, stds["out"])
- elif isinstance(module, ModernBertForMaskedLM):
- init_weight(module.decoder, stds["out"])
- elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
- init_weight(module.classifier, stds["final_out"])
-
- @classmethod
- def _autoset_attn_implementation(
- cls,
- config,
- use_flash_attention_2: bool = False,
- torch_dtype: Optional[torch.dtype] = None,
- device_map: Optional[Union[str, Dict[str, int]]] = None,
- check_device_map: bool = True,
- ):
- # If the user didn't specify anything, try to use flash_attention_2 if available.
- # Otherwise we fall back to the default SDPA -> Eager from the super() method.
- if config._attn_implementation_internal is None:
- config._attn_implementation_internal = "flash_attention_2"
- try:
- return cls._check_and_enable_flash_attn_2(
- config,
- torch_dtype=torch_dtype,
- device_map=device_map,
- hard_check_only=False,
- check_device_map=check_device_map,
- )
- except (ValueError, ImportError):
- config._attn_implementation_internal = None
- return super()._autoset_attn_implementation(
- config,
- use_flash_attention_2=use_flash_attention_2,
- torch_dtype=torch_dtype,
- device_map=device_map,
- check_device_map=check_device_map,
- )
-
- def _maybe_set_compile(self):
- if self.config.reference_compile is False:
- return
-
- if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
- if self.config.reference_compile:
- logger.warning_once(
- "If `accelerate` split the model across devices, `torch.compile` will not work. "
- "Falling back to non-compiled mode."
- )
- self.config.reference_compile = False
-
- if self.device.type == "mps":
- if self.config.reference_compile:
- logger.warning_once(
- "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
- "Falling back to non-compiled mode."
- )
- self.config.reference_compile = False
-
- if self.config.reference_compile is None:
- self.config.reference_compile = is_triton_available()
-
- def resize_token_embeddings(self, *args, **kwargs):
- model_embeds = super().resize_token_embeddings(*args, **kwargs)
-
- if self.config.reference_compile in {True, None}:
- if self.config.reference_compile:
- logger.warning_once(
- "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
- )
- self.config.reference_compile = False
-
- return model_embeds
-
-
-def _unpad_modernbert_input(
- inputs: torch.Tensor,
- attention_mask: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
- """
- Remove padding from input sequences.
-
- Args:
- inputs: (batch, seqlen, ...) or (batch, seqlen)
- attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
- position_ids: (batch, seqlen), int, position ids
- labels: (batch, seqlen), int, labels
-
- Returns:
- unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
- indices: (total_nnz)
- cu_seqlens: (batch + 1), the cumulative sequence lengths
- max_seqlen_in_batch: int
- unpadded_position_ids: (total_nnz) or None
- unpadded_labels: (total_nnz) or None
- """
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = int(seqlens_in_batch.max().item())
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
- if inputs.dim() == 2:
- unpadded_inputs = inputs.flatten()[indices]
- else:
- batch, seqlen, *rest = inputs.shape
- shape = batch * seqlen
- unpadded_inputs = inputs.view(shape, *rest)[indices]
-
- unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
- unpadded_labels = labels.flatten()[indices] if labels is not None else None
-
- return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
-
-
-def _pad_modernbert_output(
- inputs: torch.Tensor,
- indices: torch.Tensor,
- batch: int,
- seqlen: int,
-) -> torch.Tensor:
- """
- Add padding to sequences.
-
- Args:
- inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
- indices: (total_nnz)
- batch: int, batch size
- seqlen: int, max sequence length
-
- Returns:
- padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
- """
- if inputs.dim() == 1:
- output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
- output[indices] = inputs
- padded_inputs = output.view(batch, seqlen)
- else:
- _, *rest = inputs.shape
- output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
- output[indices] = inputs
- padded_inputs = output.view(batch, seqlen, *rest)
-
- return padded_inputs
-
-
-MODERNBERT_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
- far-away tokens in the local attention layers.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
- max_seqlen (`int`, *optional*):
- Maximum sequence length in the batch. Used to pad the output tensors.
- batch_size (`int`, *optional*):
- Batch size of the input sequences. Used to pad the output tensors.
- seq_len (`int`, *optional*):
- Sequence length of the input sequences. Used to pad the output tensors.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
- "The bare ModernBert Model outputting raw hidden-states without any specific head on top.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertModel(ModernBertPreTrainedModel):
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.config = config
- self.embeddings = ModernBertEmbeddings(config)
- self.layers = nn.ModuleList(
- [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
- )
- self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.gradient_checkpointing = False
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embeddings.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.embeddings.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=BaseModelOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- all_hidden_states = () if output_hidden_states else None
- all_self_attentions = () if output_attentions else None
-
- self._maybe_set_compile()
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
- if batch_size is None and seq_len is None:
- batch_size, seq_len = input_ids.shape[:2]
-
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-
- repad = False
- if self.config._attn_implementation == "flash_attention_2":
- if indices is None and cu_seqlens is None and max_seqlen is None:
- repad = True
- with torch.no_grad():
- input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
- inputs=input_ids, attention_mask=attention_mask
- )
- else:
- if position_ids is None:
- position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
-
- attention_mask, sliding_window_mask = self._update_attention_mask(
- attention_mask, output_attentions=output_attentions
- )
-
- hidden_states = self.embeddings(input_ids)
-
- for encoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- encoder_layer.__call__,
- hidden_states,
- attention_mask,
- sliding_window_mask,
- position_ids,
- cu_seqlens,
- max_seqlen,
- output_attentions,
- )
- else:
- layer_outputs = encoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- output_attentions=output_attentions,
- )
- hidden_states = layer_outputs[0]
- if output_attentions and len(layer_outputs) > 1:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- hidden_states = self.final_norm(hidden_states)
-
- if repad:
- hidden_states = _pad_modernbert_output(
- inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
- )
- if all_hidden_states is not None:
- all_hidden_states = tuple(
- _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
- for hs in all_hidden_states
- )
-
- if not return_dict:
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
-
- def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
- if output_attentions:
- if self.config._attn_implementation == "sdpa":
- logger.warning_once(
- "Outputting attentions is only supported with the 'eager' attention implementation, "
- 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
- )
- self.config._attn_implementation = "eager"
- elif self.config._attn_implementation != "eager":
- logger.warning_once(
- "Outputting attentions is only supported with the eager attention implementation, "
- f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
- " Setting `output_attentions=False`."
- )
-
- global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
-
- # Create position indices
- rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
- # Calculate distance between positions
- distance = torch.abs(rows - rows.T)
-
- # Create sliding window mask (1 for positions within window, 0 outside)
- window_mask = (
- (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
- )
- # Combine with existing mask
- sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
-
- return global_attention_mask, sliding_window_mask
-
-
-class ModernBertPredictionHead(nn.Module):
- def __init__(self, config: ModernBertConfig):
- super().__init__()
- self.config = config
- self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
- self.act = ACT2FN[config.classifier_activation]
- self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.norm(self.act(self.dense(hidden_states)))
-
-
-@add_start_docstrings(
- "The ModernBert Model with a decoder head on top that is used for masked language modeling.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertForMaskedLM(ModernBertPreTrainedModel):
- _tied_weights_keys = ["decoder.weight"]
-
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.config = config
- self.model = ModernBertModel(config)
- self.head = ModernBertPredictionHead(config)
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
-
- self.sparse_prediction = self.config.sparse_prediction
- self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_output_embeddings(self):
- return self.decoder
-
- def set_output_embeddings(self, new_embeddings: nn.Linear):
- self.decoder = new_embeddings
-
- @torch.compile(dynamic=True)
- def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
- return self.decoder(self.head(output))
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=MaskedLMOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self._maybe_set_compile()
-
- if self.config._attn_implementation == "flash_attention_2":
- if indices is None and cu_seqlens is None and max_seqlen is None:
- batch_size, seq_len = input_ids.shape[:2]
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
- with torch.no_grad():
- input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
- inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
- )
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- indices=indices,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- batch_size=batch_size,
- seq_len=seq_len,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = outputs[0]
-
- if self.sparse_prediction and labels is not None:
- # flatten labels and output first
- labels = labels.view(-1)
- last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
-
- # then filter out the non-masked tokens
- mask_tokens = labels != self.sparse_pred_ignore_index
- last_hidden_state = last_hidden_state[mask_tokens]
- labels = labels[mask_tokens]
-
- logits = (
- self.compiled_head(last_hidden_state)
- if self.config.reference_compile
- else self.decoder(self.head(last_hidden_state))
- )
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
-
- if self.config._attn_implementation == "flash_attention_2":
- with torch.no_grad():
- logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
- if not return_dict:
- output = (logits,)
- return ((loss,) + output) if loss is not None else output
-
- return MaskedLMOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-@add_start_docstrings(
- "The ModernBert Model with a sequence classification head on top that performs pooling.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.config = config
-
- self.model = ModernBertModel(config)
- self.head = ModernBertPredictionHead(config)
- self.drop = torch.nn.Dropout(config.classifier_dropout)
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=SequenceClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self._maybe_set_compile()
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- indices=indices,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- batch_size=batch_size,
- seq_len=seq_len,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = outputs[0]
-
- if self.config.classifier_pooling == "cls":
- last_hidden_state = last_hidden_state[:, 0]
- elif self.config.classifier_pooling == "mean":
- last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
- dim=1, keepdim=True
- )
-
- pooled_output = self.head(last_hidden_state)
- pooled_output = self.drop(pooled_output)
- logits = self.classifier(pooled_output)
-
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
-
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(logits, labels)
-
- if not return_dict:
- output = (logits,)
- return ((loss,) + output) if loss is not None else output
-
- return SequenceClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-@add_start_docstrings(
- "The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertForTokenClassification(ModernBertPreTrainedModel):
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.num_labels = config.num_labels
-
- self.model = ModernBertModel(config)
- self.head = ModernBertPredictionHead(config)
- self.drop = torch.nn.Dropout(config.classifier_dropout)
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=TokenClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self._maybe_set_compile()
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- indices=indices,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- batch_size=batch_size,
- seq_len=seq_len,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = outputs[0]
-
- last_hidden_state = self.head(last_hidden_state)
- last_hidden_state = self.drop(last_hidden_state)
- logits = self.classifier(last_hidden_state)
-
- loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-__all__ = [
- "ModernBertModel",
- "ModernBertPreTrainedModel",
- "ModernBertForMaskedLM",
- "ModernBertForSequenceClassification",
- "ModernBertForTokenClassification",
-]
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
deleted file mode 100644
index dac356146f3015..00000000000000
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ /dev/null
@@ -1,1465 +0,0 @@
-# Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-from typing import Dict, Literal, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
-from ...activations import ACT2FN
-from ...configuration_utils import PretrainedConfig
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
-from ...modeling_outputs import (
- BaseModelOutput,
- MaskedLMOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
-)
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- logging,
-)
-from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import apply_rotary_pos_emb
-
-
-if is_flash_attn_2_available():
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
- from flash_attn.layers.rotary import RotaryEmbedding
- from flash_attn.ops.triton.rotary import apply_rotary
-else:
- RotaryEmbedding = object
-
-_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base"
-_CONFIG_FOR_DOC = "ModernBertConfig"
-
-logger = logging.get_logger(__name__)
-
-
-class ModernBertConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
- defaults will yield a similar configuration to that of the ModernBERT-base.
- e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vocab_size (`int`, *optional*, defaults to 50368):
- Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`ModernBertModel`]
- hidden_size (`int`, *optional*, defaults to 768):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 1152):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 22):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 12):
- Number of attention heads for each attention layer in the Transformer decoder.
- hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
- The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
- if not specified.
- max_position_embeddings (`int`, *optional*, defaults to 8192):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
- The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
- norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the rms normalization layers.
- norm_bias (`bool`, *optional*, defaults to `False`):
- Whether to use bias in the normalization layers.
- pad_token_id (`int`, *optional*, defaults to 50283):
- Padding token id.
- eos_token_id (`int`, *optional*, defaults to 50282):
- End of stream token id.
- bos_token_id (`int`, *optional*, defaults to 50281):
- Beginning of stream token id.
- cls_token_id (`int`, *optional*, defaults to 50281):
- Classification token id.
- sep_token_id (`int`, *optional*, defaults to 50282):
- Separation token id.
- global_rope_theta (`float`, *optional*, defaults to 160000.0):
- The base period of the global RoPE embeddings.
- attention_bias (`bool`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- global_attn_every_n_layers (`int`, *optional*, defaults to 3):
- The number of layers between global attention layers.
- local_attention (`int`, *optional*, defaults to 128):
- The window size for local attention.
- local_rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the local RoPE embeddings.
- embedding_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the embeddings.
- mlp_bias (`bool`, *optional*, defaults to `False`):
- Whether to use bias in the MLP layers.
- mlp_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the MLP layers.
- decoder_bias (`bool`, *optional*, defaults to `True`):
- Whether to use bias in the decoder layers.
- classifier_pooling (`str`, *optional*, defaults to `"cls"`):
- The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
- CLS token doesn't attend to all tokens on long sequences.
- classifier_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the classifier.
- classifier_bias (`bool`, *optional*, defaults to `False`):
- Whether to use bias in the classifier.
- classifier_activation (`str`, *optional*, defaults to `"gelu"`):
- The activation function for the classifier.
- deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
- Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
- sparse_prediction (`bool`, *optional*, defaults to `False`):
- Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
- sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
- The index to ignore for the sparse prediction.
- reference_compile (`bool`, *optional*):
- Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
- the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
- shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
- be faster in some scenarios.
-
- Examples:
-
- ```python
- >>> from transformers import ModernBertModel, ModernBertConfig
-
- >>> # Initializing a ModernBert style configuration
- >>> configuration = ModernBertConfig()
-
- >>> # Initializing a model from the modernbert-base style configuration
- >>> model = ModernBertModel(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "modernbert"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- vocab_size=50368,
- hidden_size=768,
- intermediate_size=1152,
- num_hidden_layers=22,
- num_attention_heads=12,
- hidden_activation="gelu",
- max_position_embeddings=8192,
- initializer_range=0.02,
- initializer_cutoff_factor=2.0,
- norm_eps=1e-5,
- norm_bias=False,
- pad_token_id=50283,
- eos_token_id=50282,
- bos_token_id=50281,
- cls_token_id=50281,
- sep_token_id=50282,
- global_rope_theta=160000.0,
- attention_bias=False,
- attention_dropout=0.0,
- global_attn_every_n_layers=3,
- local_attention=128,
- local_rope_theta=10000.0,
- embedding_dropout=0.0,
- mlp_bias=False,
- mlp_dropout=0.0,
- decoder_bias=True,
- classifier_pooling: Literal["cls", "mean"] = "cls",
- classifier_dropout=0.0,
- classifier_bias=False,
- classifier_activation="gelu",
- deterministic_flash_attn=False,
- sparse_prediction=False,
- sparse_pred_ignore_index=-100,
- reference_compile=None,
- **kwargs,
- ):
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- cls_token_id=cls_token_id,
- sep_token_id=sep_token_id,
- **kwargs,
- )
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.initializer_range = initializer_range
- self.initializer_cutoff_factor = initializer_cutoff_factor
- self.norm_eps = norm_eps
- self.norm_bias = norm_bias
- self.global_rope_theta = global_rope_theta
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.hidden_activation = hidden_activation
- self.global_attn_every_n_layers = global_attn_every_n_layers
- self.local_attention = local_attention
- self.local_rope_theta = local_rope_theta
- self.embedding_dropout = embedding_dropout
- self.mlp_bias = mlp_bias
- self.mlp_dropout = mlp_dropout
- self.decoder_bias = decoder_bias
- self.classifier_pooling = classifier_pooling
- self.classifier_dropout = classifier_dropout
- self.classifier_bias = classifier_bias
- self.classifier_activation = classifier_activation
- self.deterministic_flash_attn = deterministic_flash_attn
- self.sparse_prediction = sparse_prediction
- self.sparse_pred_ignore_index = sparse_pred_ignore_index
- self.reference_compile = reference_compile
-
- if self.classifier_pooling not in ["cls", "mean"]:
- raise ValueError(
- f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
- )
-
-
-def _unpad_modernbert_input(
- inputs: torch.Tensor,
- attention_mask: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
- """
- Remove padding from input sequences.
-
- Args:
- inputs: (batch, seqlen, ...) or (batch, seqlen)
- attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
- position_ids: (batch, seqlen), int, position ids
- labels: (batch, seqlen), int, labels
-
- Returns:
- unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
- indices: (total_nnz)
- cu_seqlens: (batch + 1), the cumulative sequence lengths
- max_seqlen_in_batch: int
- unpadded_position_ids: (total_nnz) or None
- unpadded_labels: (total_nnz) or None
- """
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = int(seqlens_in_batch.max().item())
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
- if inputs.dim() == 2:
- unpadded_inputs = inputs.flatten()[indices]
- else:
- batch, seqlen, *rest = inputs.shape
- shape = batch * seqlen
- unpadded_inputs = inputs.view(shape, *rest)[indices]
-
- unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
- unpadded_labels = labels.flatten()[indices] if labels is not None else None
-
- return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
-
-
-def _pad_modernbert_output(
- inputs: torch.Tensor,
- indices: torch.Tensor,
- batch: int,
- seqlen: int,
-) -> torch.Tensor:
- """
- Add padding to sequences.
-
- Args:
- inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
- indices: (total_nnz)
- batch: int, batch size
- seqlen: int, max sequence length
-
- Returns:
- padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
- """
- if inputs.dim() == 1:
- output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
- output[indices] = inputs
- padded_inputs = output.view(batch, seqlen)
- else:
- _, *rest = inputs.shape
- output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
- output[indices] = inputs
- padded_inputs = output.view(batch, seqlen, *rest)
-
- return padded_inputs
-
-
-class ApplyRotaryEmbUnpad(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- qkv,
- cos,
- sin,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- ):
- # (total_nnz, 3, nheads, headdim)
- qkv = qkv.contiguous()
- total_nnz, _three, _nheads, headdim = qkv.shape
- # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
- # we get the same tensor
- # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
- qk = qkv[:, :2].view(total_nnz, -1, headdim)
- apply_rotary(
- qk,
- cos,
- sin,
- seqlen_offsets=0,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- interleaved=False,
- inplace=True,
- )
-
- ctx.save_for_backward(cos, sin, cu_seqlens)
- ctx.max_seqlen = max_seqlen
- return qkv
-
- @staticmethod
- def backward(ctx, do):
- cos, sin, cu_seqlens = ctx.saved_tensors
- do = do.contiguous()
- total_nnz, _three, _nheads, headdim = do.shape
- # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
- # we get the same tensor
- dqk = do[:, :2].view(total_nnz, -1, headdim)
- apply_rotary(
- dqk,
- cos,
- sin,
- seqlen_offsets=0,
- cu_seqlens=cu_seqlens,
- max_seqlen=ctx.max_seqlen,
- interleaved=False,
- inplace=True,
- conjugate=True,
- )
-
- return do, None, None, None, None, None, None
-
-
-def apply_rotary_unpadded(
- qkv,
- cos,
- sin,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
-):
- """
- Arguments:
- qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
- cos, sin: (seqlen_rotary, rotary_dim / 2)
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
- of 1st half and 2nd half (GPT-NeoX style).
- inplace: if True, apply rotary embedding in-place.
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
- Most commonly used in inference when we have KV cache.
- cu_seqlens: (batch + 1,) or None
- max_seqlen: int
- Return:
- out: (total_nnz, dim)
- rotary_dim must be <= headdim
- Apply rotary embedding to the first rotary_dim of x.
- """
- return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
-
-
-class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
- """
- The rotary position embeddings applied directly to unpadded sequences.
- """
-
- def __init__(
- self,
- dim: int,
- base: float = 10000.0,
- max_seqlen: Optional[int] = None,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- ):
- """
- max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
- up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
- the cos_sin_cache wll be recomputed during the forward pass.
- """
- super().__init__(dim=dim, base=base, pos_idx_in_fp32=True, device=device, interleaved=False)
- self.max_seqlen = max_seqlen
-
- if max_seqlen is not None and device is not None and dtype is not None:
- self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
-
- def forward(
- self,
- qkv: torch.Tensor,
- cu_seqlens: torch.Tensor,
- max_seqlen: Optional[int] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- Apply rotary embedding *inplace* to qkv.
- qkv: (total_nnz, 3, nheads, headdim)
- cu_seqlens: (batch + 1,) cumulative sequence lengths
- max_seqlen: int max seq length in the batch
- """
- if max_seqlen is not None:
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-
- qkv = apply_rotary_unpadded(
- qkv,
- self._cos_cached,
- self._sin_cached,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- )
-
- return qkv
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
-
-
-class ModernBertEmbeddings(nn.Module):
- """
- Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
- """
-
- def __init__(self, config: ModernBertConfig):
- super().__init__()
- self.config = config
- self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
- self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.drop = nn.Dropout(config.embedding_dropout)
-
- @torch.compile(dynamic=True)
- def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
- return self.drop(self.norm(self.tok_embeddings(input_ids)))
-
- def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
- hidden_states = (
- self.compiled_embeddings(input_ids)
- if self.config.reference_compile
- else self.drop(self.norm(self.tok_embeddings(input_ids)))
- )
- return hidden_states
-
-
-class ModernBertMLP(nn.Module):
- """Applies the GLU at the end of each ModernBERT layer.
-
- Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
- and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
- """
-
- def __init__(self, config: ModernBertConfig):
- super().__init__()
- self.config = config
- self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
- self.act = ACT2FN[config.hidden_activation]
- self.drop = nn.Dropout(config.mlp_dropout)
- self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
- return self.Wo(self.drop(self.act(input) * gate))
-
-
-class ModernBertRotaryEmbedding(nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
- self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
- @torch.no_grad()
- def forward(self, x, position_ids, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- self.inv_freq.to(x.device)
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 since bfloat16 loses precision on long contexts
- # See https://github.com/huggingface/transformers/pull/29285
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-def eager_attention_forward(
- module: "ModernBertAttention",
- qkv: torch.Tensor,
- attention_mask: torch.Tensor,
- sliding_window_mask: torch.Tensor,
- position_ids: Optional[torch.LongTensor],
- local_attention: Tuple[int, int],
- bs: int,
- dim: int,
- output_attentions: Optional[bool] = False,
- **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
- cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
- # query, key, value: [batch_size, heads, seq_len, head_dim]
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
- scale = module.head_dim**-0.5
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
-
- if local_attention != (-1, -1):
- attention_mask = sliding_window_mask
-
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(bs, -1, dim)
- if output_attentions:
- return (attn_output, attn_weights)
- return (attn_output,)
-
-
-def flash_attention_forward(
- module: "ModernBertAttention",
- qkv: torch.Tensor,
- rotary_emb: ModernBertUnpaddedRotaryEmbedding,
- cu_seqlens: torch.Tensor,
- max_seqlen: int,
- local_attention: Tuple[int, int],
- bs: int,
- dim: int,
- target_dtype: torch.dtype = torch.bfloat16,
- **_kwargs,
-) -> Tuple[torch.Tensor]:
- # (total_seqlen, 3, nheads, headdim)
- qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
-
- convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
- if convert_dtype:
- # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
- # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
- orig_dtype = qkv.dtype
- qkv = qkv.to(target_dtype)
-
- attn = flash_attn_varlen_qkvpacked_func(
- qkv,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- dropout_p=module.attention_dropout if module.training else 0.0,
- deterministic=module.deterministic_flash_attn,
- window_size=local_attention,
- )
- attn = attn.to(orig_dtype) # type: ignore
- else:
- attn = flash_attn_varlen_qkvpacked_func(
- qkv,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- dropout_p=module.attention_dropout if module.training else 0.0,
- deterministic=module.deterministic_flash_attn,
- window_size=local_attention,
- )
- return (attn.view(bs, dim),)
-
-
-def sdpa_attention_forward(
- module: "ModernBertAttention",
- qkv: torch.Tensor,
- attention_mask: torch.Tensor,
- sliding_window_mask: torch.Tensor,
- position_ids: Optional[torch.LongTensor],
- local_attention: Tuple[int, int],
- bs: int,
- dim: int,
- **_kwargs,
-) -> Tuple[torch.Tensor]:
- # qkv: [batch_size, seqlen, 3, nheads, headdim]
- cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
- query, key, value = qkv.transpose(3, 1).unbind(dim=2)
- # query, key, value: [batch_size, heads, seq_len, head_dim]
- query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
- if local_attention != (-1, -1):
- attention_mask = sliding_window_mask
-
- attn_output = (
- F.scaled_dot_product_attention(
- query,
- key,
- value,
- dropout_p=module.attention_dropout if module.training else 0.0,
- attn_mask=attention_mask,
- )
- .transpose(1, 2)
- .contiguous()
- )
- attn_output = attn_output.view(bs, -1, dim)
- return (attn_output,)
-
-
-MODERNBERT_ATTENTION_FUNCTION = {
- "flash_attention_2": flash_attention_forward,
- "eager": eager_attention_forward,
- "sdpa": sdpa_attention_forward,
-}
-
-
-class ModernBertAttention(nn.Module):
- """Performs multi-headed self attention on a batch of unpadded sequences.
-
- If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
- If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
- which requires padding and unpadding inputs, adding some overhead.
-
- See `forward` method for additional details.
- """
-
- def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_id = layer_id
-
- if config.hidden_size % config.num_attention_heads != 0:
- raise ValueError(
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
- )
-
- self.attention_dropout = config.attention_dropout
- self.deterministic_flash_attn = config.deterministic_flash_attn
- self.num_heads = config.num_attention_heads
- self.head_dim = config.hidden_size // config.num_attention_heads
- self.all_head_size = self.head_dim * self.num_heads
- self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
-
- if layer_id % config.global_attn_every_n_layers != 0:
- self.local_attention = (config.local_attention // 2, config.local_attention // 2)
- else:
- self.local_attention = (-1, -1)
-
- rope_theta = config.global_rope_theta
- max_position_embeddings = config.max_position_embeddings
- if self.local_attention != (-1, -1):
- if config.local_rope_theta is not None:
- rope_theta = config.local_rope_theta
- max_position_embeddings = config.local_attention
-
- if config._attn_implementation == "flash_attention_2":
- self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
- dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
- )
- else:
- self.rotary_emb = ModernBertRotaryEmbedding(
- dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
- )
-
- self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
- self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
- self.pruned_heads = set()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- output_attentions: Optional[bool] = False,
- **kwargs,
- ) -> torch.Tensor:
- qkv = self.Wqkv(hidden_states)
-
- bs = hidden_states.shape[0]
- if self.config._attn_implementation == "flash_attention_2":
- qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
- else:
- qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
-
- attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
- self,
- qkv=qkv,
- rotary_emb=self.rotary_emb,
- local_attention=self.local_attention,
- bs=bs,
- dim=self.all_head_size,
- output_attentions=output_attentions,
- **kwargs,
- )
- hidden_states = attn_outputs[0]
- hidden_states = self.out_drop(self.Wo(hidden_states))
-
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
-
-
-class ModernBertEncoderLayer(nn.Module):
- def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
- super().__init__()
- self.config = config
- if layer_id == 0:
- self.attn_norm = nn.Identity()
- else:
- self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.attn = ModernBertAttention(config=config, layer_id=layer_id)
- self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.mlp = ModernBertMLP(config)
-
- @torch.compile(dynamic=True)
- def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.mlp(self.mlp_norm(hidden_states))
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- output_attentions: Optional[bool] = False,
- ) -> torch.Tensor:
- attn_outputs = self.attn(
- self.attn_norm(hidden_states),
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- output_attentions=output_attentions,
- )
- hidden_states = hidden_states + attn_outputs[0]
- mlp_output = (
- self.compiled_mlp(hidden_states)
- if self.config.reference_compile
- else self.mlp(self.mlp_norm(hidden_states))
- )
- hidden_states = hidden_states + mlp_output
-
- return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
-
-
-MODERNBERT_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`ModernBertConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare ModernBert Model outputting raw hidden-states without any specific head on top.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertPreTrainedModel(PreTrainedModel):
- config_class = ModernBertConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_flex_attn = False
-
- def _init_weights(self, module: nn.Module):
- cutoff_factor = self.config.initializer_cutoff_factor
- if cutoff_factor is None:
- cutoff_factor = 3
-
- def init_weight(module: nn.Module, std: float):
- nn.init.trunc_normal_(
- module.weight,
- mean=0.0,
- std=std,
- a=-cutoff_factor * std,
- b=cutoff_factor * std,
- )
-
- if isinstance(module, nn.Linear):
- if module.bias is not None:
- nn.init.zeros_(module.bias)
-
- stds = {
- "in": self.config.initializer_range,
- "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
- "embedding": self.config.initializer_range,
- "final_out": self.config.hidden_size**-0.5,
- }
-
- if isinstance(module, ModernBertEmbeddings):
- init_weight(module.tok_embeddings, stds["embedding"])
- elif isinstance(module, ModernBertMLP):
- init_weight(module.Wi, stds["in"])
- init_weight(module.Wo, stds["out"])
- elif isinstance(module, ModernBertAttention):
- init_weight(module.Wqkv, stds["in"])
- init_weight(module.Wo, stds["out"])
- elif isinstance(module, ModernBertPredictionHead):
- init_weight(module.dense, stds["out"])
- elif isinstance(module, ModernBertForMaskedLM):
- init_weight(module.decoder, stds["out"])
- elif isinstance(module, (ModernBertForSequenceClassification, ModernBertForTokenClassification)):
- init_weight(module.classifier, stds["final_out"])
-
- @classmethod
- def _autoset_attn_implementation(
- cls,
- config,
- use_flash_attention_2: bool = False,
- torch_dtype: Optional[torch.dtype] = None,
- device_map: Optional[Union[str, Dict[str, int]]] = None,
- check_device_map: bool = True,
- ):
- # If the user didn't specify anything, try to use flash_attention_2 if available.
- # Otherwise we fall back to the default SDPA -> Eager from the super() method.
- if config._attn_implementation_internal is None:
- config._attn_implementation_internal = "flash_attention_2"
- try:
- return cls._check_and_enable_flash_attn_2(
- config,
- torch_dtype=torch_dtype,
- device_map=device_map,
- hard_check_only=False,
- check_device_map=check_device_map,
- )
- except (ValueError, ImportError):
- config._attn_implementation_internal = None
- return super()._autoset_attn_implementation(
- config,
- use_flash_attention_2=use_flash_attention_2,
- torch_dtype=torch_dtype,
- device_map=device_map,
- check_device_map=check_device_map,
- )
-
- def _maybe_set_compile(self):
- if self.config.reference_compile is False:
- return
-
- if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
- if self.config.reference_compile:
- logger.warning_once(
- "If `accelerate` split the model across devices, `torch.compile` will not work. "
- "Falling back to non-compiled mode."
- )
- self.config.reference_compile = False
-
- if self.device.type == "mps":
- if self.config.reference_compile:
- logger.warning_once(
- "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
- "Falling back to non-compiled mode."
- )
- self.config.reference_compile = False
-
- if self.config.reference_compile is None:
- self.config.reference_compile = is_triton_available()
-
- def resize_token_embeddings(self, *args, **kwargs):
- model_embeds = super().resize_token_embeddings(*args, **kwargs)
-
- if self.config.reference_compile in {True, None}:
- if self.config.reference_compile:
- logger.warning_once(
- "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
- )
- self.config.reference_compile = False
-
- return model_embeds
-
-
-MODERNBERT_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
- far-away tokens in the local attention layers.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
- max_seqlen (`int`, *optional*):
- Maximum sequence length in the batch. Used to pad the output tensors.
- batch_size (`int`, *optional*):
- Batch size of the input sequences. Used to pad the output tensors.
- seq_len (`int`, *optional*):
- Sequence length of the input sequences. Used to pad the output tensors.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
- "The bare ModernBert Model outputting raw hidden-states without any specific head on top.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertModel(ModernBertPreTrainedModel):
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.config = config
- self.embeddings = ModernBertEmbeddings(config)
- self.layers = nn.ModuleList(
- [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
- )
- self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
- self.gradient_checkpointing = False
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embeddings.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.embeddings.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=BaseModelOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- all_hidden_states = () if output_hidden_states else None
- all_self_attentions = () if output_attentions else None
-
- self._maybe_set_compile()
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-
- if batch_size is None and seq_len is None:
- batch_size, seq_len = input_ids.shape[:2]
-
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-
- repad = False
- if self.config._attn_implementation == "flash_attention_2":
- if indices is None and cu_seqlens is None and max_seqlen is None:
- repad = True
- with torch.no_grad():
- input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
- inputs=input_ids, attention_mask=attention_mask
- )
- else:
- if position_ids is None:
- position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
-
- attention_mask, sliding_window_mask = self._update_attention_mask(
- attention_mask, output_attentions=output_attentions
- )
-
- hidden_states = self.embeddings(input_ids)
-
- for encoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- encoder_layer.__call__,
- hidden_states,
- attention_mask,
- sliding_window_mask,
- position_ids,
- cu_seqlens,
- max_seqlen,
- output_attentions,
- )
- else:
- layer_outputs = encoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- output_attentions=output_attentions,
- )
- hidden_states = layer_outputs[0]
- if output_attentions and len(layer_outputs) > 1:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- hidden_states = self.final_norm(hidden_states)
-
- if repad:
- hidden_states = _pad_modernbert_output(
- inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
- )
- if all_hidden_states is not None:
- all_hidden_states = tuple(
- _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
- for hs in all_hidden_states
- )
-
- if not return_dict:
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
-
- def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
- if output_attentions:
- if self.config._attn_implementation == "sdpa":
- logger.warning_once(
- "Outputting attentions is only supported with the 'eager' attention implementation, "
- 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
- )
- self.config._attn_implementation = "eager"
- elif self.config._attn_implementation != "eager":
- logger.warning_once(
- "Outputting attentions is only supported with the eager attention implementation, "
- f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
- " Setting `output_attentions=False`."
- )
-
- global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
-
- # Create position indices
- rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
- # Calculate distance between positions
- distance = torch.abs(rows - rows.T)
-
- # Create sliding window mask (1 for positions within window, 0 outside)
- window_mask = (
- (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
- )
- # Combine with existing mask
- sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
-
- return global_attention_mask, sliding_window_mask
-
-
-class ModernBertPredictionHead(nn.Module):
- def __init__(self, config: ModernBertConfig):
- super().__init__()
- self.config = config
- self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
- self.act = ACT2FN[config.classifier_activation]
- self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.norm(self.act(self.dense(hidden_states)))
-
-
-@add_start_docstrings(
- "The ModernBert Model with a decoder head on top that is used for masked language modeling.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertForMaskedLM(ModernBertPreTrainedModel):
- _tied_weights_keys = ["decoder.weight"]
-
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.config = config
- self.model = ModernBertModel(config)
- self.head = ModernBertPredictionHead(config)
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
-
- self.sparse_prediction = self.config.sparse_prediction
- self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_output_embeddings(self):
- return self.decoder
-
- def set_output_embeddings(self, new_embeddings: nn.Linear):
- self.decoder = new_embeddings
-
- @torch.compile(dynamic=True)
- def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
- return self.decoder(self.head(output))
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=MaskedLMOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self._maybe_set_compile()
-
- if self.config._attn_implementation == "flash_attention_2":
- if indices is None and cu_seqlens is None and max_seqlen is None:
- batch_size, seq_len = input_ids.shape[:2]
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
- with torch.no_grad():
- input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
- inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
- )
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- indices=indices,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- batch_size=batch_size,
- seq_len=seq_len,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = outputs[0]
-
- if self.sparse_prediction and labels is not None:
- # flatten labels and output first
- labels = labels.view(-1)
- last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
-
- # then filter out the non-masked tokens
- mask_tokens = labels != self.sparse_pred_ignore_index
- last_hidden_state = last_hidden_state[mask_tokens]
- labels = labels[mask_tokens]
-
- logits = (
- self.compiled_head(last_hidden_state)
- if self.config.reference_compile
- else self.decoder(self.head(last_hidden_state))
- )
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
-
- if self.config._attn_implementation == "flash_attention_2":
- with torch.no_grad():
- logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
- if not return_dict:
- output = (logits,)
- return ((loss,) + output) if loss is not None else output
-
- return MaskedLMOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-@add_start_docstrings(
- "The ModernBert Model with a sequence classification head on top that performs pooling.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.config = config
-
- self.model = ModernBertModel(config)
- self.head = ModernBertPredictionHead(config)
- self.drop = torch.nn.Dropout(config.classifier_dropout)
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=SequenceClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self._maybe_set_compile()
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- indices=indices,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- batch_size=batch_size,
- seq_len=seq_len,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = outputs[0]
-
- if self.config.classifier_pooling == "cls":
- last_hidden_state = last_hidden_state[:, 0]
- elif self.config.classifier_pooling == "mean":
- last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
- dim=1, keepdim=True
- )
-
- pooled_output = self.head(last_hidden_state)
- pooled_output = self.drop(pooled_output)
- logits = self.classifier(pooled_output)
-
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
-
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(logits, labels)
-
- if not return_dict:
- output = (logits,)
- return ((loss,) + output) if loss is not None else output
-
- return SequenceClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-@add_start_docstrings(
- "The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.",
- MODERNBERT_START_DOCSTRING,
-)
-class ModernBertForTokenClassification(ModernBertPreTrainedModel):
- def __init__(self, config: ModernBertConfig):
- super().__init__(config)
- self.num_labels = config.num_labels
-
- self.model = ModernBertModel(config)
- self.head = ModernBertPredictionHead(config)
- self.drop = torch.nn.Dropout(config.classifier_dropout)
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=TokenClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- self._maybe_set_compile()
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- indices=indices,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- batch_size=batch_size,
- seq_len=seq_len,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = outputs[0]
-
- last_hidden_state = self.head(last_hidden_state)
- last_hidden_state = self.drop(last_hidden_state)
- logits = self.classifier(last_hidden_state)
-
- loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-__all__ = [
- "ModernBertConfig",
- "ModernBertModel",
- "ModernBertPreTrainedModel",
- "ModernBertForMaskedLM",
- "ModernBertForSequenceClassification",
- "ModernBertForTokenClassification",
-]
diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py
index f0281f57cf1c75..82abfa66c2e837 100644
--- a/src/transformers/models/moshi/modeling_moshi.py
+++ b/src/transformers/models/moshi/modeling_moshi.py
@@ -36,7 +36,6 @@
ModelOutput,
Seq2SeqLMOutput,
)
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
@@ -308,55 +307,24 @@ def forward(self, x, layer_idx=None):
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi
class MoshiRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: MoshiConfig,
- device=None,
- ):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
+ # TODO(joao): add me back asap :)
def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
+ # x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
@@ -364,11 +332,6 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -493,10 +456,13 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_fle
self.rotary_emb = None
if use_rope:
self.rope_theta = config.rope_theta
- self.rotary_emb = MoshiRotaryEmbedding(config)
+ self.rotary_emb = MoshiRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
- # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward
- # no longer copied after attention refactors
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -561,8 +527,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi
-# TODO cyril: modular
+# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi
class MoshiFlashAttention2(MoshiAttention):
"""
Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays
@@ -678,8 +643,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi
-# TODO cyril: modular
+# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi
class MoshiSdpaAttention(MoshiAttention):
"""
Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index f83bccb7e4f6f3..109ddfb626d26b 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -324,6 +324,7 @@ class MusicgenFlashAttention2(MusicgenAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index dc0e9b882b20cf..61f2ce414e1ddf 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -340,6 +340,7 @@ class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index a0a10bdc6f3550..a56b5c68085cb3 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -301,8 +301,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
-# TODO cyril: modular
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronFlashAttention2(NemotronAttention):
"""
Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays
@@ -416,8 +415,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
-# TODO cyril: modular
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronSdpaAttention(NemotronAttention):
"""
Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -516,8 +514,7 @@ def forward(
}
-# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
-# no longer copied after attention refactors
+# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronDecoderLayer(nn.Module):
# Ignore copy
def __init__(self, config: NemotronConfig, layer_idx: int):
@@ -539,7 +536,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
diff --git a/src/transformers/models/nougat/tokenization_nougat_fast.py b/src/transformers/models/nougat/tokenization_nougat_fast.py
index 5d0a8934c05ee1..0a7eec4ad98a4c 100644
--- a/src/transformers/models/nougat/tokenization_nougat_fast.py
+++ b/src/transformers/models/nougat/tokenization_nougat_fast.py
@@ -514,7 +514,7 @@ def post_process_single(self, generation: str, fix_markdown: bool = True) -> str
generation = generation.replace("\n* [leftmargin=*]\n", "\n")
# Remove lines with markdown headings starting with #, with numerals,
# and possibly roman numerals with trailing spaces and newlines
- generation = re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.M)
+ generation = re.sub(r"^#+ (?:\.?(?:\d|[ixv])+)*\s*(?:$|\n\s*)", "", generation, flags=re.M)
# most likely hallucinated titles
lines = generation.split("\n")
if lines[-1].startswith("#") and lines[-1].lstrip("#").startswith(" ") and len(lines) > 1:
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 11d3d99f4f72c9..d865c51e50578e 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -1,35 +1,59 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/olmo/modular_olmo.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_olmo.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
+# coding=utf-8
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OLMo model."""
+
+import math
+from typing import List, Optional, Tuple, Union
import torch
-import torch.nn as nn
import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_olmo import OlmoConfig
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
+
_CONFIG_FOR_DOC = "OlmoConfig"
@@ -47,22 +71,74 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
)
-class OlmoMLP(nn.Module):
- def __init__(self, config):
+ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm)
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo
+# TODO(joao): add me back asap :)
+class OlmoRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
+# TODO(joao): add me back asap :)
+class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
+ """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
+# TODO(joao): add me back asap :)
+class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding):
+ """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -70,6 +146,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -97,6 +174,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
+class OlmoMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -109,69 +202,83 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class OlmoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: OlmoConfig, layer_idx: int):
+ # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo
+ # TODO(joao): add me back asap :)
+ def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = OlmoRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = OlmoLinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = OlmoDynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -182,11 +289,11 @@ def forward(
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
- query_states = query_states.view(hidden_shape).transpose(1, 2)
- key_states = key_states.view(hidden_shape).transpose(1, 2)
- value_states = value_states.view(hidden_shape).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -194,42 +301,261 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class OlmoFlashAttention2(OlmoAttention):
+ """
+ OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (OlmoRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
- attn_output, attn_weights = attention_interface(
- self,
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class OlmoSdpaAttention(OlmoAttention):
+ """
+ OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `OlmoAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from OlmoAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "OlmoModel is using OlmoSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ # if attention_mask is not None and cache_position is not None:
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ return attn_output, None, past_key_value
+
+
+OLMO_ATTENTION_CLASSES = {
+ "eager": OlmoAttention,
+ "flash_attention_2": OlmoFlashAttention2,
+ "sdpa": OlmoSdpaAttention,
+}
class OlmoDecoderLayer(nn.Module):
def __init__(self, config: OlmoConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
+
+ self.self_attn = OLMO_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = OlmoMLP(config)
self.input_layernorm = OlmoLayerNorm(config.hidden_size)
self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
+ # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
+ # TODO(joao): add me back asap :)
def forward(
self,
hidden_states: torch.Tensor,
@@ -239,15 +565,33 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -255,7 +599,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -267,75 +610,14 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class OlmoRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: OlmoConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
+ if use_cache:
+ outputs += (present_key_value,)
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
OLMO_START_DOCSTRING = r"""
@@ -359,6 +641,7 @@ def forward(self, x, position_ids):
"The bare Olmo Model outputting raw hidden-states without any specific head on top.",
OLMO_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmo
class OlmoPreTrainedModel(PreTrainedModel):
config_class = OlmoConfig
base_model_prefix = "model"
@@ -480,7 +763,6 @@ def __init__(self, config: OlmoConfig):
[OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = OlmoLayerNorm(config.hidden_size)
- self.rotary_emb = OlmoRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -493,19 +775,20 @@ def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
+ # copied from transformers.models.llama.modeling_llama.LlamaModel.forward
+ # TODO(joao): add me back asap :)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -526,15 +809,25 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -542,16 +835,15 @@ def forward(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
+ # embed positions
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -565,7 +857,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -576,12 +867,13 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -591,14 +883,20 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -665,6 +963,7 @@ def _update_causal_mask(
return causal_mask
@staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
@@ -721,12 +1020,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
@@ -757,12 +1053,13 @@ def get_decoder(self):
@add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -771,7 +1068,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -792,8 +1089,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, OlmoForCausalLM
- >>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf")
+ >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-1B-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -801,8 +1098,9 @@ def forward(
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+ ```
+ """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -821,7 +1119,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -830,7 +1127,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py
deleted file mode 100644
index 2a43e6f9c75d05..00000000000000
--- a/src/transformers/models/olmo/modular_olmo.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint
-
-from ...cache_utils import Cache
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...utils import logging
-from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaMLP,
- LlamaModel,
- apply_rotary_pos_emb,
- eager_attention_forward,
-)
-from .configuration_olmo import OlmoConfig
-
-
-logger = logging.get_logger(__name__)
-
-
-class OlmoLayerNorm(nn.Module):
- """LayerNorm but with no learnable weight or bias."""
-
- def __init__(self, hidden_size: int) -> None:
- super().__init__()
- self.normalized_shape = (hidden_size,)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- orig_dtype = hidden_states.dtype
- return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
- orig_dtype
- )
-
-
-class OlmoMLP(LlamaMLP):
- def __init__(self, config):
- super().__init__(config)
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-
-
-class OlmoAttention(LlamaAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- if self.config.clip_qkv is not None:
- query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
- key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
- value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-
- query_states = query_states.view(hidden_shape).transpose(1, 2)
- key_states = key_states.view(hidden_shape).transpose(1, 2)
- value_states = value_states.view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
-
-class OlmoDecoderLayer(LlamaDecoderLayer):
- def __init__(self, config: OlmoConfig, layer_idx: int):
- super().__init__(config, layer_idx)
- self.input_layernorm = OlmoLayerNorm(config.hidden_size)
- self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
- self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
-
-
-class OlmoModel(LlamaModel):
- def __init__(self, config: OlmoConfig):
- super().__init__(config)
- self.layers = nn.ModuleList(
- [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.norm = OlmoLayerNorm(config.hidden_size)
-
-
-class OlmoForCausalLM(LlamaForCausalLM):
- pass
diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py
index 83c3263de1f552..144520f87ed7f9 100644
--- a/src/transformers/models/olmo2/configuration_olmo2.py
+++ b/src/transformers/models/olmo2/configuration_olmo2.py
@@ -5,7 +5,6 @@
# modular_olmo2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-
from ...configuration_utils import PretrainedConfig
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 49ae798e7f1101..c042669e1ed5c3 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -4,31 +4,35 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_olmo2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
-import torch.nn as nn
+from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
- LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_olmo2 import Olmo2Config
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
+
_CONFIG_FOR_DOC = "Olmo2Config"
@@ -52,6 +56,70 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo2
+# TODO(joao): add me back asap :)
+class Olmo2RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo2
+# TODO(joao): add me back asap :)
+class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding):
+ """Olmo2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo2
+# TODO(joao): add me back asap :)
+class Olmo2DynamicNTKScalingRotaryEmbedding(Olmo2RotaryEmbedding):
+ """Olmo2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids)
+ return cos, sin
+
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -98,81 +166,95 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class Olmo2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
+ # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo2
+ # TODO(joao): add me back asap :)
def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
- self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
- self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ self._init_rope()
+ self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps)
+ self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps)
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = Olmo2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = Olmo2LinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = Olmo2DynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
query_states = self.q_norm(self.q_proj(hidden_states))
key_states = self.k_norm(self.k_proj(hidden_states))
value_states = self.v_proj(hidden_states)
- query_states = query_states.view(hidden_shape).transpose(1, 2)
- key_states = key_states.view(hidden_shape).transpose(1, 2)
- value_states = value_states.view(hidden_shape).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -180,30 +262,220 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Olmo2FlashAttention2(Olmo2Attention):
+ """
+ Olmo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+
+ OLMo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (OlmoRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
- attn_output, attn_weights = attention_interface(
- self,
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Olmo2SdpaAttention(Olmo2Attention):
+ """
+ Olmo2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Olmo2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Olmo2Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Olmo2Model is using Olmo2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ causal_mask = attention_mask
+ # if attention_mask is not None and cache_position is not None:
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+ return attn_output, None, past_key_value
class Olmo2MLP(nn.Module):
@@ -218,20 +490,29 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+OLMO2_ATTENTION_CLASSES = {
+ "eager": Olmo2Attention,
+ "flash_attention_2": Olmo2FlashAttention2,
+ "sdpa": Olmo2SdpaAttention,
+}
class Olmo2DecoderLayer(nn.Module):
def __init__(self, config: Olmo2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
+
+ self.self_attn = OLMO2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = Olmo2MLP(config)
self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
+ # TODO(joao): add me back asap :)
def forward(
self,
hidden_states: torch.Tensor,
@@ -241,13 +522,31 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
residual = hidden_states
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -255,7 +554,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -270,75 +568,11 @@ def forward(
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
-
+ if use_cache:
+ outputs += (present_key_value,)
return outputs
-class Olmo2RotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: Olmo2Config,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
OLMO2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -481,7 +715,6 @@ def __init__(self, config: Olmo2Config):
[Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = Olmo2RotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -494,19 +727,20 @@ def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
+ # copied from transformers.models.llama.modeling_llama.LlamaModel.forward
+ # TODO(joao): add me back asap :)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -527,15 +761,25 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -543,16 +787,15 @@ def forward(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
+ # embed positions
hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -566,7 +809,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- position_embeddings,
)
else:
layer_outputs = decoder_layer(
@@ -577,12 +819,13 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -592,13 +835,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -722,14 +970,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO2,Llama->Olmo2
class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
- def __init__(self, config):
+ def __init__(self, config: Olmo2Config):
super().__init__(config)
self.model = Olmo2Model(config)
self.vocab_size = config.vocab_size
@@ -758,12 +1003,13 @@ def get_decoder(self):
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ # Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -772,7 +1018,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -793,8 +1039,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, Olmo2ForCausalLM
- >>> model = Olmo2ForCausalLM.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
+ >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -802,8 +1048,9 @@ def forward(
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+ ```
+ """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -822,7 +1069,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -831,7 +1077,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py
index 5f119170804466..393d17c59c1a8b 100644
--- a/src/transformers/models/olmo2/modular_olmo2.py
+++ b/src/transformers/models/olmo2/modular_olmo2.py
@@ -1,23 +1,30 @@
-from typing import Callable, Optional, Tuple
+import math
+from typing import Optional, Tuple
import torch
from torch import nn
from ...cache_utils import Cache
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import logging
-from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward
+from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
+from ..llama.modeling_llama import LlamaRMSNorm
from ..olmo.configuration_olmo import OlmoConfig
from ..olmo.modeling_olmo import (
OlmoAttention,
OlmoDecoderLayer,
+ OlmoFlashAttention2,
OlmoForCausalLM,
OlmoModel,
+ OlmoPreTrainedModel,
+ OlmoSdpaAttention,
apply_rotary_pos_emb,
+ repeat_kv,
)
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
@@ -163,30 +170,112 @@ class Olmo2RMSNorm(LlamaRMSNorm):
class Olmo2Attention(OlmoAttention):
def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx=layer_idx)
- self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
- self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
+ self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps)
+ self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Olmo2FlashAttention2(OlmoFlashAttention2, Olmo2Attention):
+ """
+ OLMo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ Olmo2Attention.__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
query_states = self.q_norm(self.q_proj(hidden_states))
key_states = self.k_norm(self.k_proj(hidden_states))
value_states = self.v_proj(hidden_states)
- query_states = query_states.view(hidden_shape).transpose(1, 2)
- key_states = key_states.view(hidden_shape).transpose(1, 2)
- value_states = value_states.view(hidden_shape).transpose(1, 2)
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -194,30 +283,129 @@ def forward(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (OlmoRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
- attn_output, attn_weights = attention_interface(
- self,
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Olmo2SdpaAttention(OlmoSdpaAttention, Olmo2Attention):
+ # Adapted from Olmo2Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Olmo2Model is using Olmo2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ causal_mask = attention_mask
+ # if attention_mask is not None and cache_position is not None:
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None, past_key_value
# The OLMo2 layers are identical to those of the OLMo model except:
@@ -228,7 +416,6 @@ def __init__(self, config: Olmo2Config, layer_idx: int):
super().__init__(config, layer_idx=layer_idx)
self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
del self.input_layernorm
def forward(
@@ -240,13 +427,12 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -254,7 +440,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -269,29 +454,36 @@ def forward(
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
-
+ if use_cache:
+ outputs += (present_key_value,)
return outputs
+class Olmo2PreTrainedModel(OlmoPreTrainedModel):
+ pass
+
+
# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
# standard layer norm for the output norm.
class Olmo2Model(OlmoModel):
def __init__(self, config: Olmo2Config):
super().__init__(config)
- self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layers = nn.ModuleList(
[Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# The heads now only need to redefine the model inside to the correct `RobertaModel`
class Olmo2ForCausalLM(OlmoForCausalLM):
- pass
+ def __init__(self, config: Olmo2Config):
+ super().__init__(config)
+ self.model = Olmo2Model(config)
__all__ = [
"Olmo2Config",
"Olmo2ForCausalLM",
"Olmo2Model",
- "Olmo2PreTrainedModel", # noqa: F822
+ "Olmo2PreTrainedModel",
]
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index fa3c2f3cd4d11b..4398e2f5c9a1fd 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -160,18 +160,40 @@ def extra_repr(self):
class OlmoeRotaryEmbedding(nn.Module):
def __init__(
self,
- config: OlmoeConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[OlmoeConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -271,8 +293,7 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -401,6 +422,7 @@ class OlmoeFlashAttention2(OlmoeAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 3350ae1a23c2b7..e4ef510f099d66 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -257,6 +257,7 @@ class OptFlashAttention2(OPTAttention):
attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py
index 5783308f831541..cb35aab66cba49 100644
--- a/src/transformers/models/paligemma/processing_paligemma.py
+++ b/src/transformers/models/paligemma/processing_paligemma.py
@@ -287,6 +287,11 @@ def __call__(
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
raise ValueError("images must be an image, list of images or list of list of images")
+ if suffix is not None and _is_str_or_image(suffix):
+ suffix = [suffix]
+ if suffix is not None:
+ suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
+
input_strings = [
build_string_from_input(
prompt=prompt,
@@ -309,11 +314,6 @@ def __call__(
)
expanded_samples.append(expanded_sample)
input_strings = [f"{sample}\n" for sample in expanded_samples]
-
- if suffix is not None and _is_str_or_image(suffix):
- suffix = [suffix]
- if suffix is not None:
- suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
# max_length has to account for the image tokens
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 8d3c20b9ace717..cd580ab0dc0f8c 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -59,18 +59,40 @@
class PersimmonRotaryEmbedding(nn.Module):
def __init__(
self,
- config: PersimmonConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[PersimmonConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -121,6 +143,33 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
+class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
+ """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`PersimmonLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`PersimmonRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
+class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
+ """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`PersimmonDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`PersimmonRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -237,7 +286,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -256,7 +305,16 @@ def forward(
value_states = value_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
# Partial rotary embedding
query_rot, query_pass = (
@@ -332,7 +390,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 477896decd5318..4613672ff2740b 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -1,19 +1,33 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/phi/modular_phi.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_phi.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
+# coding=utf-8
+# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Phi model."""
+
+import math
+from typing import List, Optional, Tuple, Union
import torch
-import torch.nn as nn
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -21,25 +35,146 @@
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ get_torch_version,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_phi import PhiConfig
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
-_CHECKPOINT_FOR_DOC = "meta-phi/Phi-2-7b-hf"
+_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
+class PhiRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[PhiConfig] = None,
+ ):
+ super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`PhiRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
+class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`PhiLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`PhiRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
+class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`PhiDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`PhiRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -47,6 +182,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -74,6 +210,23 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
+class PhiMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -86,79 +239,208 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class PhiAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: PhiConfig, layer_idx: int):
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
self.attention_dropout = config.attention_dropout
- self.is_causal = True
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.rope_theta = config.rope_theta
self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = nn.LayerNorm(
- config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
self.k_layernorm = nn.LayerNorm(
- config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
+ self.rotary_emb = PhiRotaryEmbedding(config=self.config)
+
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ if self.qk_layernorm:
+ query_states = self.q_layernorm(query_states)
+ key_states = self.k_layernorm(key_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., : self.rotary_ndims],
+ query_states[..., self.rotary_ndims :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : self.rotary_ndims],
+ key_states[..., self.rotary_ndims :],
+ )
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if past_key_value is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+ attn_weights = torch.matmul(
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights += causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.dense(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class PhiFlashAttention2(PhiAttention):
+ """
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ # PhiFlashAttention2 attention does not support output_attentions
+
+ output_attentions = False
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
- cos, sin = position_embeddings
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
# Partial rotary embedding
query_rot, query_pass = (
query_states[..., : self.rotary_ndims],
@@ -176,55 +458,206 @@ def forward(
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_dropout = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
- attn_output, attn_weights = attention_interface(
- self,
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=attn_dropout,
+ softmax_scale=None,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.dense(attn_output)
- return attn_output, attn_weights
+ if not output_attentions:
+ attn_weights = None
-class PhiMLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.activation_fn = ACT2FN[config.hidden_act]
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+ return attn_output, attn_weights, past_key_value
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.fc1(hidden_states)
- hidden_states = self.activation_fn(hidden_states)
- hidden_states = self.fc2(hidden_states)
- return hidden_states
+
+class PhiSdpaAttention(PhiAttention):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+ """
+ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from PhiAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
+ "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
+ "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
+ 'be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ if self.qk_layernorm:
+ query_states = self.q_layernorm(query_states)
+ key_states = self.k_layernorm(key_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., : self.rotary_ndims],
+ query_states[..., self.rotary_ndims :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : self.rotary_ndims],
+ key_states[..., self.rotary_ndims :],
+ )
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if past_key_value is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
+ if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.dense(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+PHI_ATTENTION_CLASSES = {
+ "eager": PhiAttention,
+ "flash_attention_2": PhiFlashAttention2,
+ "sdpa": PhiSdpaAttention,
+}
class PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig, layer_idx: int):
super().__init__()
- self.self_attn = PhiAttention(config, layer_idx=layer_idx)
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = PhiMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -234,19 +667,45 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- attn_outputs, self_attn_weights = self.self_attn(
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -255,7 +714,6 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **kwargs,
)
attn_outputs = self.resid_dropout(attn_outputs)
@@ -266,72 +724,10 @@ def forward(
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class PhiRotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: PhiConfig,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ if use_cache:
+ outputs += (present_key_value,)
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
PHI_START_DOCSTRING = r"""
@@ -360,12 +756,12 @@ class PhiPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PhiDecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
+ _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
- _supports_quantized_cache = True
_supports_static_cache = True
+ _supports_quantized_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
@@ -472,14 +868,17 @@ def __init__(self, config: PhiConfig):
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.layers = nn.ModuleList(
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
- self.rotary_emb = PhiRotaryEmbedding(config=config)
- self.gradient_checkpointing = False
- self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.rotary_emb = PhiRotaryEmbedding(config=config)
+
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
+ self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@@ -495,43 +894,54 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -539,7 +949,7 @@ def forward(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
- inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
+ inputs_embeds = self.embed_dropout(inputs_embeds)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -548,8 +958,9 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -559,9 +970,9 @@ def forward(
hidden_states,
causal_mask,
position_ids,
- past_key_values,
output_attentions,
use_cache,
+ past_key_values,
cache_position,
position_embeddings,
)
@@ -575,28 +986,36 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
- hidden_states = self.final_layernorm(hidden_states) # diff with Llama
+ hidden_states = self.final_layernorm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -663,6 +1082,7 @@ def _update_causal_mask(
return causal_mask
@staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
@@ -719,37 +1139,40 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
def __init__(self, config):
super().__init__(config)
self.model = PhiModel(config)
self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
# Initialize weights and apply final processing
self.post_init()
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
def get_input_embeddings(self):
return self.model.embed_tokens
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
def set_input_embeddings(self, value):
self.model.embed_tokens = value
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
def get_output_embeddings(self):
return self.lm_head
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
def set_decoder(self, decoder):
self.model = decoder
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
def get_decoder(self):
return self.model
@@ -760,7 +1183,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -769,7 +1192,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -790,17 +1213,18 @@ def forward(
```python
>>> from transformers import AutoTokenizer, PhiForCausalLM
- >>> model = PhiForCausalLM.from_pretrained("meta-phi/Phi-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi/Phi-2-7b-hf")
+ >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> prompt = "This is an example script ."
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
```"""
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -819,7 +1243,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -828,7 +1251,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -845,7 +1268,7 @@ def forward(
@add_start_docstrings(
"""
- The Phi Model transformer with a sequence classification head on top (linear layer).
+ The PhiModel with a sequence classification head on top (linear layer).
[`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
@@ -858,6 +1281,7 @@ def forward(
""",
PHI_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
class PhiForSequenceClassification(PhiPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -896,7 +1320,7 @@ def forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- transformer_outputs = self.model(
+ model_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -907,7 +1331,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
- hidden_states = transformer_outputs[0]
+ hidden_states = model_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
@@ -935,48 +1359,44 @@ def forward(
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
+ output = (pooled_logits,) + model_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
+ past_key_values=model_outputs.past_key_values,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
)
@add_start_docstrings(
"""
- The Phi Model transformer with a token classification head on top (a linear layer on top of the hidden-states
- output) e.g. for Named-Entity-Recognition (NER) tasks.
+ PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
""",
PHI_START_DOCSTRING,
)
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
class PhiForTokenClassification(PhiPreTrainedModel):
- def __init__(self, config):
+ def __init__(self, config: PhiConfig):
super().__init__(config)
self.num_labels = config.num_labels
+
self.model = PhiModel(config)
- if getattr(config, "classifier_dropout", None) is not None:
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
- elif getattr(config, "hidden_dropout", None) is not None:
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
- self.score = nn.Linear(config.hidden_size, config.num_labels)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
@@ -986,16 +1406,16 @@ def set_input_embeddings(self, value):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TokenClassifierOutput]:
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1004,32 +1424,38 @@ def forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.model(
+ model_outputs = self.model(
input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
past_key_values=past_key_values,
+ attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
- sequence_output = outputs[0]
- sequence_output = self.dropout(sequence_output)
- logits = self.score(sequence_output)
+
+ hidden_states = model_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.config)
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ batch_size, seq_length = labels.shape
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+ )
if not return_dict:
- output = (logits,) + outputs[2:]
+ output = (logits,) + model_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
)
diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py
deleted file mode 100644
index 0faa4629f1a768..00000000000000
--- a/src/transformers/models/phi/modular_phi.py
+++ /dev/null
@@ -1,295 +0,0 @@
-from typing import Callable, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from ...cache_utils import Cache, DynamicCache
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import (
- BaseModelOutputWithPast,
-)
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import Unpack
-from ...utils import logging
-from ..clip.modeling_clip import CLIPMLP
-from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaForCausalLM,
- LlamaForSequenceClassification,
- LlamaForTokenClassification,
- LlamaModel,
- apply_rotary_pos_emb,
- eager_attention_forward, # copied from Llama
-)
-from .configuration_phi import PhiConfig
-
-
-logger = logging.get_logger(__name__)
-
-
-class PhiAttention(LlamaAttention):
- def __init__(self, config: PhiConfig, layer_idx: int):
- super().__init__(config, layer_idx)
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
- del self.o_proj
- self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
- self.qk_layernorm = config.qk_layernorm
- if self.qk_layernorm:
- self.q_layernorm = nn.LayerNorm(
- config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
- )
- self.k_layernorm = nn.LayerNorm(
- config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- if self.qk_layernorm:
- query_states = self.q_layernorm(query_states)
- key_states = self.k_layernorm(key_states)
-
- cos, sin = position_embeddings
- # Partial rotary embedding
- query_rot, query_pass = (
- query_states[..., : self.rotary_ndims],
- query_states[..., self.rotary_ndims :],
- )
- key_rot, key_pass = (
- key_states[..., : self.rotary_ndims],
- key_states[..., self.rotary_ndims :],
- )
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
- # [batch_size, seq_length, num_heads, head_dim]
- query_states = torch.cat((query_rot, query_pass), dim=-1)
- key_states = torch.cat((key_rot, key_pass), dim=-1)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.dense(attn_output)
- return attn_output, attn_weights
-
-
-class PhiMLP(CLIPMLP):
- pass
-
-
-class PhiDecoderLayer(nn.Module):
- def __init__(self, config: PhiConfig, layer_idx: int):
- super().__init__()
- self.self_attn = PhiAttention(config, layer_idx=layer_idx)
- self.mlp = PhiMLP(config)
- self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- attn_outputs, self_attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- attn_outputs = self.resid_dropout(attn_outputs)
-
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
-
-class PhiModel(LlamaModel):
- def __init__(self, config: PhiConfig):
- super().__init__(config)
- self.layers = nn.ModuleList(
- [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.embed_dropout = nn.Dropout(config.embd_pdrop)
- self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- del self.norm
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
-
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
-
- inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
- hidden_states = inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.final_layernorm(hidden_states) # diff with Llama
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- output = BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
- return output if return_dict else output.to_tuple()
-
-
-class PhiForCausalLM(LlamaForCausalLM):
- pass
-
-
-class PhiForSequenceClassification(LlamaForSequenceClassification):
- pass
-
-
-class PhiForTokenClassification(LlamaForTokenClassification):
- pass
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 908fd982b9c73c..bae3f6d4cdaeaa 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -74,8 +74,7 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-# copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
-# TODO cyril: modular
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
class Phi3RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -432,6 +431,7 @@ class Phi3FlashAttention2(Phi3Attention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -550,8 +550,8 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
-# TODO cyril: modular
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
+# TODO @Arthur no longer copied from LLama after static cache
class Phi3SdpaAttention(Phi3Attention):
"""
Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 8f6b092da6e6ad..82763ccea62e4c 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -33,6 +33,7 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -50,6 +51,9 @@
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
+ if not is_torch_greater_or_equal_than_1_13:
+ import torch.fx
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
@@ -182,6 +186,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -907,12 +912,10 @@ class PhimoePreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PhimoeDecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
+ _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py
index 03886d4a528478..b65fbd634ba789 100644
--- a/src/transformers/models/pixtral/modeling_pixtral.py
+++ b/src/transformers/models/pixtral/modeling_pixtral.py
@@ -216,7 +216,6 @@ def forward(
class PixtralMLP(nn.Module):
def __init__(self, config):
super().__init__()
- self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -224,9 +223,8 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 36fb1ddf1390ac..36c5271c5c5e6c 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -1,19 +1,36 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_qwen2.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+from typing import List, Optional, Tuple, Union
import torch
+import torch.utils.checkpoint
from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -22,41 +39,140 @@
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_qwen2 import Qwen2Config
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
-_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
+
+_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
_CONFIG_FOR_DOC = "Qwen2Config"
-class Qwen2MLP(nn.Module):
- def __init__(self, config):
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
+class Qwen2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen2RMSNorm is equivalent to T5LayerNorm
+ """
super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
+class Qwen2RotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[Qwen2Config] = None,
+ ):
+ super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -64,6 +180,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -91,6 +208,22 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
+class Qwen2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -103,160 +236,391 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class Qwen2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
- def __init__(self, config: Qwen2Config, layer_idx: int):
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
- self.attention_dropout = config.attention_dropout
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
self.is_causal = True
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
- cos, sin = position_embeddings
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Qwen2FlashAttention2(Qwen2Attention):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- sliding_window = None
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=sliding_window, # main diff with Llama
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
+ if not output_attentions:
+ attn_weights = None
-class Qwen2RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- Qwen2RMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
+ return attn_output, attn_weights, past_key_value
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+class Qwen2SdpaAttention(Qwen2Attention):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+QWEN2_ATTENTION_CLASSES = {
+ "eager": Qwen2Attention,
+ "flash_attention_2": Qwen2FlashAttention2,
+ "sdpa": Qwen2SdpaAttention,
+}
class Qwen2DecoderLayer(nn.Module):
def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
- self.mlp = Qwen2MLP(config)
- self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
if config.sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -265,7 +629,6 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **kwargs,
)
hidden_states = residual + hidden_states
@@ -276,75 +639,14 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class Qwen2RotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: Qwen2Config,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
+ if use_cache:
+ outputs += (present_key_value,)
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
QWEN2_START_DOCSTRING = r"""
@@ -373,7 +675,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Qwen2DecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
+ _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -413,7 +715,7 @@ def _init_weights(self, module):
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -488,10 +790,11 @@ def __init__(self, config: Qwen2Config):
self.layers = nn.ModuleList(
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
- self.gradient_checkpointing = False
+ self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@@ -507,43 +810,54 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -559,8 +873,9 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -586,11 +901,13 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -600,14 +917,20 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
+ # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -626,21 +949,30 @@ def _update_causal_mask(
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not (using_static_cache or using_sliding_window_cache)
+ and not output_attentions
+ ):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
+ sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
- if using_static_cache:
+ # SlidingWindowCache or StaticCache
+ if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
+ # DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
@@ -657,6 +989,8 @@ def _update_causal_mask(
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
+ config=self.config,
+ past_key_values=past_key_values,
)
if (
@@ -668,12 +1002,12 @@ def _update_causal_mask(
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
@@ -682,7 +1016,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
- **kwargs,
+ config: Qwen2Config,
+ past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -690,13 +1025,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
Args:
attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
@@ -705,6 +1038,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
+ config (`Qwen2Config`):
+ The model's configuration class
+ past_key_values (`Cache`):
+ The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
@@ -714,25 +1051,30 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ if config.sliding_window is not None:
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+ # the check is needed to verify is current checkpoint was trained with sliding window or not
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
+ cache_position.reshape(-1, 1) - config.sliding_window
+ )
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+ causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.shape[-1] > target_length:
+ attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
-
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -771,7 +1113,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -780,7 +1122,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -801,8 +1143,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, Qwen2ForCausalLM
- >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -812,6 +1154,7 @@ def forward(
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -830,7 +1173,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -839,7 +1181,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -888,10 +1230,10 @@ def set_input_embeddings(self, value):
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def forward(
self,
- input_ids: Optional[torch.LongTensor] = None,
+ input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -943,8 +1285,27 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
-
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -965,6 +1326,7 @@ def forward(
""",
QWEN2_START_DOCSTRING,
)
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -1053,22 +1415,24 @@ def forward(
""",
QWEN2_START_DOCSTRING,
)
+# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2
class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
- base_model_prefix = "transformer"
+ base_model_prefix = "model"
+ # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2
def __init__(self, config):
super().__init__(config)
- self.transformer = Qwen2Model(config)
+ self.model = Qwen2Model(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
- return self.transformer.embed_tokens
+ return self.model.embed_tokens
def set_input_embeddings(self, value):
- self.transformer.embed_tokens = value
+ self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
def forward(
@@ -1097,7 +1461,7 @@ def forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.transformer(
+ outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py
deleted file mode 100644
index 718abd01090c2b..00000000000000
--- a/src/transformers/models/qwen2/modular_qwen2.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from ...cache_utils import Cache
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import Unpack
-from ...utils import logging
-from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaForQuestionAnswering,
- LlamaForSequenceClassification,
- LlamaForTokenClassification,
- LlamaMLP,
- LlamaModel,
- apply_rotary_pos_emb,
- eager_attention_forward,
-)
-from .configuration_qwen2 import Qwen2Config
-
-
-logger = logging.get_logger(__name__)
-
-
-class Qwen2MLP(LlamaMLP):
- def __init__(self, config):
- super().__init__(config)
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-
-
-class Qwen2Attention(LlamaAttention):
- def __init__(self, config: Qwen2Config, layer_idx: int):
- super().__init__(config, layer_idx)
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- sliding_window = None
- if (
- self.config.use_sliding_window
- and getattr(self.config, "sliding_window", None) is not None
- and self.layer_idx >= self.config.max_window_layers
- ):
- sliding_window = self.config.sliding_window
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=sliding_window, # main diff with Llama
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
-
-class Qwen2DecoderLayer(LlamaDecoderLayer):
- def __init__(self, config: Qwen2Config, layer_idx: int):
- super().__init__()
- self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
- self.mlp = Qwen2MLP(config)
- if config.sliding_window and config._attn_implementation != "flash_attention_2":
- logger.warning_once(
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
- "unexpected results may be encountered."
- )
-
-
-class Qwen2Model(LlamaModel):
- pass
-
-
-class Qwen2ForCausalLM(LlamaForCausalLM):
- pass
-
-
-class Qwen2ForSequenceClassification(LlamaForSequenceClassification):
- pass
-
-
-class Qwen2ForTokenClassification(LlamaForTokenClassification):
- pass
-
-
-class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
- pass
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 44a5b5ce315570..ce0e427048cf23 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -223,6 +223,7 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 1ce41509a5c0d1..6c5cbec2220e23 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -169,18 +169,40 @@ def extra_repr(self):
class Qwen2MoeRotaryEmbedding(nn.Module):
def __init__(
self,
- config: Qwen2MoeConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[Qwen2MoeConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`Qwen2MoeRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -296,8 +318,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-# copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
-# no longer copied after attention refactors
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe
class Qwen2MoeAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
@@ -347,7 +368,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -359,7 +380,17 @@ def forward(
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
@@ -398,8 +429,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
-# TODO cyril: modular
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
"""
Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
@@ -409,6 +439,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
config.max_window_layers layers.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -426,7 +457,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, q_len, _ = hidden_states.size()
@@ -438,7 +469,16 @@ def forward(
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -509,8 +549,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
-# TODO cyril: modular
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
"""
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -528,7 +567,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -543,8 +582,6 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
@@ -557,7 +594,16 @@ def forward(
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -696,7 +742,7 @@ def forward(
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -1558,10 +1604,11 @@ def forward(
class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel):
base_model_prefix = "model"
+ # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe
def __init__(self, config):
super().__init__(config)
+ self.model = Qwen2MoeModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
- self.model = Qwen2MoeModel(config) # diff with Llama: transformer->model
# Initialize weights and apply final processing
self.post_init()
diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
index ef98ae5e3f508f..55042327de4ec3 100644
--- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -163,16 +163,6 @@ class Qwen2VLConfig(PretrainedConfig):
model_type = "qwen2_vl"
sub_configs = {"vision_config": Qwen2VLVisionConfig}
keys_to_ignore_at_inference = ["past_key_values"]
- # Default tensor parallel plan for base model `Qwen2VL`
- base_model_tp_plan = {
- "layers.*.self_attn.q_proj": "colwise",
- "layers.*.self_attn.k_proj": "colwise",
- "layers.*.self_attn.v_proj": "colwise",
- "layers.*.self_attn.o_proj": "rowwise",
- "layers.*.mlp.gate_proj": "colwise",
- "layers.*.mlp.up_proj": "colwise",
- "layers.*.mlp.down_proj": "rowwise",
- }
def __init__(
self,
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 566141d3f75c27..f7648f4a53d1af 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -460,7 +460,6 @@ def extra_repr(self):
class Qwen2MLP(nn.Module):
def __init__(self, config):
super().__init__()
- self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -468,9 +467,8 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -539,7 +537,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -547,11 +545,20 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
@@ -623,7 +630,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, q_len, _ = hidden_states.size()
@@ -631,12 +638,22 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Because the input can be padded, the absolute sequence length depends on the max position id.
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
@@ -725,7 +742,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -741,7 +758,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
@@ -750,11 +766,20 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
@@ -831,7 +856,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 74fc2085c36519..2b3cf7eb0cb82e 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -77,6 +77,7 @@ def __init__(self, dim, base=10000, device=None):
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
@@ -184,7 +185,7 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- cos, sin = self.rotary_emb(value_states, position_ids)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
# Partial rotary embedding
query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1)
diff --git a/src/transformers/models/sam/configuration_sam.py b/src/transformers/models/sam/configuration_sam.py
index 22a237615d1280..b0045655d2066b 100644
--- a/src/transformers/models/sam/configuration_sam.py
+++ b/src/transformers/models/sam/configuration_sam.py
@@ -46,8 +46,6 @@ class SamPromptEncoderConfig(PretrainedConfig):
The non-linear activation function in the encoder and pooler.
"""
- base_config_key = "prompt_encoder_config"
-
def __init__(
self,
hidden_size=256,
@@ -104,8 +102,6 @@ class SamMaskDecoderConfig(PretrainedConfig):
"""
- base_config_key = "mask_decoder_config"
-
def __init__(
self,
hidden_size=256,
@@ -185,8 +181,6 @@ class SamVisionConfig(PretrainedConfig):
hidden_size`.
"""
- base_config_key = "vision_config"
-
def __init__(
self,
hidden_size=768,
@@ -284,11 +278,6 @@ class SamConfig(PretrainedConfig):
```"""
model_type = "sam"
- sub_configs = {
- "prompt_encoder_config": SamPromptEncoderConfig,
- "mask_decoder_config": SamMaskDecoderConfig,
- "vision_config": SamVisionConfig,
- }
def __init__(
self,
diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index b935bc9e421e01..c99fb9d7e869f8 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -246,47 +246,6 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarit
return out
-class SamSdpaAttention(SamAttention):
- """
- SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
- values. Using SDPA instead of the default attention.
- """
-
- def __init__(self, config, downsample_rate=None):
- super().__init__(config, downsample_rate)
-
- def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
- # Input projections
- query = self.q_proj(query)
- key = self.k_proj(key)
- value = self.v_proj(value)
-
- point_batch_size = query.shape[1]
- # Separate into heads
- query = self._separate_heads(query, self.num_attention_heads)
- key = self._separate_heads(key, self.num_attention_heads)
- value = self._separate_heads(value, self.num_attention_heads)
-
- # Scaled dot product attention
- attn_mask = None
- if attention_similarity is not None:
- attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
-
- out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
-
- # Get output
- out = self._recombine_heads(out, point_batch_size)
- out = self.out_proj(out)
-
- return out
-
-
-SAM_ATTENTION_CLASSES = {
- "eager": SamAttention,
- "sdpa": SamSdpaAttention,
-}
-
-
class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
@@ -307,21 +266,18 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
- self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
+ self.self_attn = SamAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
- self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
- config, downsample_rate=attention_downsample_rate
- )
+ self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
- self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
- config, downsample_rate=attention_downsample_rate
- )
+ self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
+
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
@@ -388,7 +344,7 @@ def __init__(self, config: SamMaskDecoderConfig):
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
- self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
+ self.final_attn_token_to_image = SamAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
def forward(
@@ -475,7 +431,7 @@ def forward(self, hidden_states):
class SamMaskDecoder(nn.Module):
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()
- self.config = config
+
self.hidden_size = config.hidden_size
self.num_multimask_outputs = config.num_multimask_outputs
@@ -900,118 +856,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch
return outputs
-class SamVisionSdpaAttention(SamVisionAttention):
- """
- Multi-head Attention block with relative position embeddings.
- Using SDPA instead of the default attention.
- """
-
- def __init__(self, config, window_size):
- super().__init__(config, window_size)
-
- def add_decomposed_rel_pos(
- self,
- query: torch.Tensor,
- rel_pos_h: torch.Tensor,
- rel_pos_w: torch.Tensor,
- q_size: Tuple[int, int],
- k_size: Tuple[int, int],
- ) -> torch.Tensor:
- """
- Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
- This method is reimplemented to follow the implementation in:
- https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
- This implementation is more memory efficient when using SDPA in the forward method.
- Args:
- q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
- rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
- rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
- q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
- k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
- Returns:
- attn (Tensor): attention map with added relative positional embeddings.
- """
- query_height, query_width = q_size
- key_height, key_width = k_size
- relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
- relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
-
- batch_size, _, dim = query.shape
- reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
- rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
- rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
- rel_h = rel_h.unsqueeze(-1)
- rel_w = rel_w.unsqueeze(-2)
- rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
- rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
-
- return rel_h, rel_w
-
- def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
- batch_size, height, width, _ = hidden_states.shape
- # qkv with shape (3, B, nHead, H * W, C)
- qkv = (
- self.qkv(hidden_states)
- .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
- .permute(2, 0, 3, 1, 4)
- )
- # q, k, v with shape (B * nHead, H * W, C)
- query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
-
- rel_h, rel_w = None, None
- if self.use_rel_pos:
- rel_h, rel_w = self.add_decomposed_rel_pos(
- query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
- )
-
- query = query.view(batch_size, self.num_attention_heads, height * width, -1)
- key = key.view(batch_size, self.num_attention_heads, height * width, -1)
- value = value.view(batch_size, self.num_attention_heads, height * width, -1)
-
- if self.use_rel_pos:
- rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
- rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
- attn_bias = (rel_h + rel_w).view(
- batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
- )
- attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
- else:
- attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
-
- attn_output = (
- attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
- .permute(0, 2, 3, 1, 4)
- .reshape(batch_size, height, width, -1)
- )
-
- attn_output = self.proj(attn_output)
-
- if output_attentions:
- # For output_attentions, calculate the attention weights
- attn_weights = (query @ key.transpose(-2, -1)) * self.scale
- if attn_bias is not None:
- attn_weights = attn_weights + attn_bias
- attn_weights = F.softmax(attn_weights, dim=-1)
- outputs = (attn_output, attn_weights)
- else:
- outputs = (attn_output, None)
-
- return outputs
-
-
-SAM_VISION_ATTENTION_CLASSES = {
- "eager": SamVisionAttention,
- "sdpa": SamVisionSdpaAttention,
-}
-
-
class SamVisionLayer(nn.Module):
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
+ self.attn = SamVisionAttention(config, window_size)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.window_size = window_size
@@ -1222,8 +1071,6 @@ class SamPreTrainedModel(PreTrainedModel):
base_model_prefix = "sam"
main_input_name = "pixel_values"
_no_split_modules = ["SamVisionAttention"]
- supports_gradient_checkpointing = True
- _supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 6aa967416d5477..c5c3b202846705 100755
--- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -293,8 +293,6 @@ def format_speech_generation_kwargs(kwargs):
elif key.startswith("speech_"):
key = key[len("speech_") :]
kwargs_speech[key] = value
- elif key == "generation_config":
- kwargs_text[key] = value
else:
# If the key is already in a specific config, then it's been set with a
# submodules specific value and we don't override
diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
index 978000086e2c3b..a8068eb0ad01ea 100644
--- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
+++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
@@ -421,8 +421,6 @@ def format_speech_generation_kwargs(kwargs):
elif key.startswith("speech_"):
key = key[len("speech_") :]
kwargs_speech[key] = value
- elif key == "generation_config":
- kwargs_text[key] = value
else:
# If the key is already in a specific config, then it's been set with a
# submodules specific value and we don't override
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index 8dc3e2297d4525..8638d93385843d 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -563,6 +563,7 @@ class SEWFlashAttention2(SEWAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -882,15 +883,15 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
if self._use_flash_attention_2:
# make sure padded tokens output 0
- hidden_states[~expand_attention_mask] = 0.0
+ hidden_states[~attention_mask] = 0.0
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# make sure padded tokens output 0
- hidden_states[~expand_attention_mask] = 0.0
+ hidden_states[~attention_mask] = 0.0
+
input_lengths = (attention_mask.long()).sum(-1)
# apply pooling formula to get real output_lengths
output_lengths = input_lengths // self.config.squeeze_factor
@@ -1473,8 +1474,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index 2df687f4cc362a..5cccc0218e6ccf 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -1175,8 +1175,7 @@ def forward(
)
else:
# make sure padded tokens output 0
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_attention_mask.bool()] = 0.0
+ hidden_states[~attention_mask.bool()] = 0.0
input_lengths = (attention_mask.long()).sum(-1)
# apply pooling formula to get real output_lengths
@@ -1722,8 +1721,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index 9a2dfe013716a7..a42bcd0e17461e 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -438,6 +438,7 @@ class SiglipFlashAttention2(SiglipAttention):
is_causal = False
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
index 3bff8f6acd290d..0d2b911bebe582 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -491,8 +491,6 @@ def forward(
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
- if "num_items_in_batch" in kwargs_encoder:
- kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
if encoder_outputs is None:
if inputs is None:
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 88dc437cdcb91d..004e4ff3f6c030 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -65,18 +65,40 @@
class StableLmRotaryEmbedding(nn.Module):
def __init__(
self,
- config: StableLmConfig,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[StableLmConfig] = None,
):
super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ if config is None:
+ logger.warning_once(
+ "`StableLmRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
@@ -127,6 +149,33 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
+class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
+ """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`StableLmLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`StableLmRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
+class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
+ """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+ "`StableLmDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
+ "`StableLmRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -167,7 +216,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class StableLmMLP(nn.Module):
def __init__(self, config):
super().__init__()
- self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -175,9 +223,8 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
class StableLmLayerNormPerHead(nn.Module):
@@ -260,7 +307,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@@ -276,7 +323,16 @@ def forward(
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
# Partial rotary embedding
query_rot, query_pass = (
@@ -347,7 +403,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -362,8 +418,6 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
@@ -380,7 +434,16 @@ def forward(
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
# Partial rotary embedding
query_rot, query_pass = (
@@ -452,6 +515,7 @@ class StableLmFlashAttention2(StableLmAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -469,7 +533,7 @@ def forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# StableLmFlashAttention2 attention does not support output_attentions
@@ -493,7 +557,16 @@ def forward(
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
# Partial rotary embedding
query_rot, query_pass = (
@@ -577,7 +650,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
diff --git a/src/transformers/models/starcoder2/__init__.py b/src/transformers/models/starcoder2/__init__.py
index 6349255ed3a475..d9dc2cd1e5001c 100644
--- a/src/transformers/models/starcoder2/__init__.py
+++ b/src/transformers/models/starcoder2/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2024 BigCode and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,52 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+
+
+_import_structure = {
+ "configuration_starcoder2": ["Starcoder2Config"],
+}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_starcoder2"] = [
+ "Starcoder2ForCausalLM",
+ "Starcoder2Model",
+ "Starcoder2PreTrainedModel",
+ "Starcoder2ForSequenceClassification",
+ "Starcoder2ForTokenClassification",
+ ]
if TYPE_CHECKING:
- from .configuration_starcoder2 import *
- from .modeling_starcoder2 import *
+ from .configuration_starcoder2 import Starcoder2Config
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_starcoder2 import (
+ Starcoder2ForCausalLM,
+ Starcoder2ForSequenceClassification,
+ Starcoder2ForTokenClassification,
+ Starcoder2Model,
+ Starcoder2PreTrainedModel,
+ )
+
+
else:
import sys
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/starcoder2/configuration_starcoder2.py b/src/transformers/models/starcoder2/configuration_starcoder2.py
index 7f21d1f12d8b22..5749eb68358468 100644
--- a/src/transformers/models/starcoder2/configuration_starcoder2.py
+++ b/src/transformers/models/starcoder2/configuration_starcoder2.py
@@ -197,6 +197,3 @@ def __init__(
eos_token_id=eos_token_id,
**kwargs,
)
-
-
-__all__ = ["Starcoder2Config"]
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 3b4fdbcb81ccc4..5ecffc8719bec9 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -24,7 +24,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@@ -33,7 +34,6 @@
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -41,24 +41,115 @@
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
+from ...modeling_utils import PreTrainedModel
from ...utils import (
- LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_starcoder2 import Starcoder2Config
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
_CONFIG_FOR_DOC = "Starcoder2Config"
+class Starcoder2RotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[Starcoder2Config] = None,
+ ):
+ super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`Starcoder2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.46"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
class Starcoder2MLP(nn.Module):
def __init__(self, config: Starcoder2Config):
super().__init__()
@@ -122,111 +213,336 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
class Starcoder2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
- self.attention_dropout = config.attention_dropout
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.rope_theta = config.rope_theta
+ self.use_bias = config.use_bias
self.is_causal = True
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
+ self.attention_dropout = config.attention_dropout
self.residual_dropout = config.residual_dropout
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
+
+ self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
+
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
- cos, sin = position_embeddings
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights += causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Starcoder2FlashAttention2(Starcoder2Attention):
+ """
+ Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
- attn_output, attn_weights = attention_interface(
- self,
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self.config, "sliding_window", None),
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
- attn_output = nn.functional.dropout(
- attn_output, p=self.residual_dropout, training=self.training
- ) # diff with Llama
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ if not output_attentions:
+ attn_weights = None
- return attn_output, attn_weights
+ return attn_output, attn_weights, past_key_value
+
+
+class Starcoder2SdpaAttention(Starcoder2Attention):
+ """
+ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+ # The difference with Mistral is that here it uses dropout
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ return attn_output, None, past_key_value
+
+
+STARCODER2_ATTENTION_CLASSES = {
+ "eager": Starcoder2Attention,
+ "flash_attention_2": Starcoder2FlashAttention2,
+ "sdpa": Starcoder2SdpaAttention,
+}
class Starcoder2DecoderLayer(nn.Module):
def __init__(self, config: Starcoder2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
+
+ self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
self.mlp = Starcoder2MLP(config)
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@@ -235,19 +551,41 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -256,7 +594,6 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- **kwargs,
)
hidden_states = residual + hidden_states
@@ -267,75 +604,14 @@ def forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
+
if output_attentions:
outputs += (self_attn_weights,)
- return outputs
-
-
-class Starcoder2RotaryEmbedding(nn.Module):
- def __init__(
- self,
- config: Starcoder2Config,
- device=None,
- ):
- super().__init__()
- self.rope_kwargs = {}
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
+ if use_cache:
+ outputs += (present_key_value,)
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ return outputs
STARCODER2_START_DOCSTRING = r"""
@@ -364,7 +640,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Starcoder2DecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
+ _skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -404,7 +680,7 @@ def _init_weights(self, module):
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -479,11 +755,12 @@ def __init__(self, config: Starcoder2Config):
self.layers = nn.ModuleList(
[Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ self._attn_implementation = config._attn_implementation
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.rotary_emb = Starcoder2RotaryEmbedding(config=config)
+
self.gradient_checkpointing = False
self.embedding_dropout = config.embedding_dropout
-
# Initialize weights and apply final processing
self.post_init()
@@ -499,43 +776,54 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -544,9 +832,7 @@ def forward(
)
hidden_states = inputs_embeds
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.embedding_dropout, training=self.training
- ) # main diff with Llama
+ hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -554,25 +840,41 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
- )
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -582,13 +884,18 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
def _update_causal_mask(
self,
@@ -599,14 +906,6 @@ def _update_causal_mask(
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and past_key_values is not None:
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
- if is_padding_right:
- raise ValueError(
- "You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to "
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
- )
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
@@ -741,9 +1040,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -782,7 +1078,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -791,7 +1087,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
+ **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -812,8 +1108,8 @@ def forward(
```python
>>> from transformers import AutoTokenizer, Starcoder2ForCausalLM
- >>> model = Starcoder2ForCausalLM.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf")
+ >>> model = Starcoder2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -823,6 +1119,7 @@ def forward(
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -841,7 +1138,6 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- **kwargs,
)
hidden_states = outputs[0]
@@ -850,7 +1146,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1055,12 +1351,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
-
-
-__all__ = [
- "Starcoder2ForCausalLM",
- "Starcoder2Model",
- "Starcoder2PreTrainedModel",
- "Starcoder2ForSequenceClassification",
- "Starcoder2ForTokenClassification",
-]
diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py
index 32d64cd167ba50..b5d74bf7feb39f 100644
--- a/src/transformers/models/starcoder2/modular_starcoder2.py
+++ b/src/transformers/models/starcoder2/modular_starcoder2.py
@@ -19,7 +19,8 @@
# limitations under the License.
"""PyTorch Starcoder2 model."""
-from typing import Callable, List, Optional, Tuple, Union
+import math
+from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -27,32 +28,40 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
)
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-from ...processing_utils import Unpack
-from ...utils import add_start_docstrings_to_model_forward, logging
-from ..mistral.modeling_mistral import (
- MistralAttention,
- MistralDecoderLayer,
- MistralForCausalLM,
- MistralForSequenceClassification,
- MistralForTokenClassification,
- MistralModel,
+from ...utils import (
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+)
+from ..llama.modeling_llama import (
+ LlamaForSequenceClassification,
+ LlamaForTokenClassification,
+ LlamaRotaryEmbedding,
apply_rotary_pos_emb,
- eager_attention_forward,
+ repeat_kv,
)
+from ..qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM, Qwen2Model, Qwen2PreTrainedModel
from .configuration_starcoder2 import Starcoder2Config
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Starcoder2Config"
_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
+class Starcoder2RotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
class Starcoder2MLP(nn.Module):
def __init__(self, config: Starcoder2Config):
super().__init__()
@@ -70,90 +79,359 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states
-class Starcoder2Attention(MistralAttention):
+class Starcoder2Attention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.rope_theta = config.rope_theta
+ self.use_bias = config.use_bias
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
self.residual_dropout = config.residual_dropout
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
+
+ self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
+ bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- cos, sin = position_embeddings
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights += causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Starcoder2FlashAttention2(Starcoder2Attention):
+ """
+ Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
- attn_output, attn_weights = attention_interface(
- self,
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama
- **kwargs,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self.config, "sliding_window", None),
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
- attn_output = nn.functional.dropout(
- attn_output, p=self.residual_dropout, training=self.training
- ) # diff with Llama
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
- return attn_output, attn_weights
+class Starcoder2SdpaAttention(Starcoder2Attention):
+ """
+ Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
-class Starcoder2DecoderLayer(MistralDecoderLayer):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+ # The difference with Mistral is that here it uses dropout
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ return attn_output, None, past_key_value
+
+
+STARCODER2_ATTENTION_CLASSES = {
+ "eager": Starcoder2Attention,
+ "flash_attention_2": Starcoder2FlashAttention2,
+ "sdpa": Starcoder2SdpaAttention,
+}
+
+
+class Starcoder2DecoderLayer(Qwen2DecoderLayer, nn.Module):
def __init__(self, config: Starcoder2Config, layer_idx: int):
- super().__init__(self)
- self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
+ nn.Module.__init__(self)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
self.mlp = Starcoder2MLP(config)
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+class Starcoder2PreTrainedModel(Qwen2PreTrainedModel):
+ pass
+
+
STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined
-class Starcoder2Model(MistralModel):
+class Starcoder2Model(Qwen2Model):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Starcoder2DecoderLayer`]
+
+ Args:
+ config: Starcoder2Config
+ """
+
def __init__(self, config: Starcoder2Config):
super().__init__(config)
- self.layers = nn.ModuleList(
- [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.embedding_dropout = config.embedding_dropout
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
def forward(
@@ -161,43 +439,54 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
-
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
@@ -206,9 +495,7 @@ def forward(
)
hidden_states = inputs_embeds
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.embedding_dropout, training=self.training
- ) # main diff with Llama
+ hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -216,25 +503,41 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
- )
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -244,31 +547,27 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)
- output = BaseModelOutputWithPast(
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
+ past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
- return output if return_dict else output.to_tuple()
-class Starcoder2ForCausalLM(MistralForCausalLM):
+class Starcoder2ForCausalLM(Qwen2ForCausalLM):
pass
-class Starcoder2ForSequenceClassification(MistralForSequenceClassification):
+class Starcoder2ForSequenceClassification(LlamaForSequenceClassification):
pass
-class Starcoder2ForTokenClassification(MistralForTokenClassification):
+class Starcoder2ForTokenClassification(LlamaForTokenClassification):
pass
-
-
-__all__ = [
- "Starcoder2ForCausalLM",
- "Starcoder2Model",
- "Starcoder2PreTrainedModel", # noqa: F822
- "Starcoder2ForSequenceClassification",
- "Starcoder2ForTokenClassification",
-]
diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py
index dcdd85460b39bd..1075de299a9f40 100644
--- a/src/transformers/models/superpoint/modeling_superpoint.py
+++ b/src/transformers/models/superpoint/modeling_superpoint.py
@@ -25,6 +25,7 @@
)
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -313,7 +314,7 @@ def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor:
divisor = divisor.to(keypoints)
keypoints /= divisor
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
- kwargs = {"align_corners": True}
+ kwargs = {"align_corners": True} if is_torch_greater_or_equal_than_1_13 else {}
# [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
keypoints = keypoints.view(batch_size, 1, -1, 2)
descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py
index 2ea0d38a23f933..b74a27ae5ce589 100644
--- a/src/transformers/models/tapas/modeling_tapas.py
+++ b/src/transformers/models/tapas/modeling_tapas.py
@@ -31,6 +31,7 @@
from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
+ is_torch_greater_or_equal_than_1_12,
prune_linear_layer,
)
from ...utils import (
@@ -45,6 +46,12 @@
logger = logging.get_logger(__name__)
+if not is_torch_greater_or_equal_than_1_12:
+ logger.warning(
+ f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
+ "TapasModel. Please upgrade torch."
+ )
+
_CONFIG_FOR_DOC = "TapasConfig"
_CHECKPOINT_FOR_DOC = "google/tapas-base"
diff --git a/src/transformers/models/timm_wrapper/__init__.py b/src/transformers/models/timm_wrapper/__init__.py
deleted file mode 100644
index 9fbc4150412a73..00000000000000
--- a/src/transformers/models/timm_wrapper/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING
-
-from ...utils import _LazyModule
-from ...utils.import_utils import define_import_structure
-
-
-if TYPE_CHECKING:
- from .configuration_timm_wrapper import *
- from .modeling_timm_wrapper import *
- from .processing_timm_wrapper import *
-else:
- import sys
-
- _file = globals()["__file__"]
- sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
deleted file mode 100644
index 691a2b2b76ec3f..00000000000000
--- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration for TimmWrapper models"""
-
-from typing import Any, Dict
-
-from ...configuration_utils import PretrainedConfig
-from ...utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class TimmWrapperConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`].
-
- It is used to instantiate a timm model according to the specified arguments, defining the model.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- do_pooling (`bool`, *optional*, defaults to `True`):
- Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
-
- Example:
- ```python
- >>> from transformers import TimmWrapperModel
-
- >>> # Initializing a timm model
- >>> model = TimmWrapperModel.from_pretrained("timm/resnet18.a1_in1k")
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```
- """
-
- model_type = "timm_wrapper"
-
- def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs):
- self.initializer_range = initializer_range
- self.do_pooling = do_pooling
- super().__init__(**kwargs)
-
- @classmethod
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
- # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
- # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
- # and to avoid duplicate attributes
-
- num_labels_in_kwargs = kwargs.pop("num_labels", None)
- num_labels_in_dict = config_dict.pop("num_classes", None)
-
- # passed num_labels has priority over num_classes in config_dict
- kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
-
- # pop num_classes from "pretrained_cfg",
- # it is not necessary to have it, only root one is used in timm
- if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
- config_dict["pretrained_cfg"].pop("num_classes", None)
-
- return super().from_dict(config_dict, **kwargs)
-
- def to_dict(self) -> Dict[str, Any]:
- output = super().to_dict()
- output["num_classes"] = self.num_labels
- return output
-
-
-__all__ = ["TimmWrapperConfig"]
diff --git a/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py b/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py
deleted file mode 100644
index 02075a50fb2676..00000000000000
--- a/src/transformers/models/timm_wrapper/image_processing_timm_wrapper.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from typing import Any, Dict, Optional, Tuple, Union
-
-import torch
-
-from ...image_processing_utils import BaseImageProcessor, BatchFeature
-from ...image_transforms import to_pil_image
-from ...image_utils import ImageInput, make_list_of_images
-from ...utils import TensorType, logging, requires_backends
-from ...utils.import_utils import is_timm_available, is_torch_available
-
-
-if is_timm_available():
- import timm
-
-if is_torch_available():
- import torch
-
-
-logger = logging.get_logger(__name__)
-
-
-class TimmWrapperImageProcessor(BaseImageProcessor):
- """
- Wrapper class for timm models to be used within transformers.
-
- Args:
- pretrained_cfg (`Dict[str, Any]`):
- The configuration of the pretrained model used to resolve evaluation and
- training transforms.
- architecture (`Optional[str]`, *optional*):
- Name of the architecture of the model.
- """
-
- main_input_name = "pixel_values"
-
- def __init__(
- self,
- pretrained_cfg: Dict[str, Any],
- architecture: Optional[str] = None,
- **kwargs,
- ):
- requires_backends(self, "timm")
- super().__init__(architecture=architecture)
-
- self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False)
- self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False)
-
- # useful for training, see examples/pytorch/image-classification/run_image_classification.py
- self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True)
-
- # If `ToTensor` is in the transforms, then the input should be numpy array or PIL image.
- # Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used
- # which can handle both numpy arrays / PIL images and tensors.
- self._not_supports_tensor_input = any(
- transform.__class__.__name__ == "ToTensor" for transform in self.val_transforms.transforms
- )
-
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
- """
- output = super().to_dict()
- output.pop("train_transforms", None)
- output.pop("val_transforms", None)
- output.pop("_not_supports_tensor_input", None)
- return output
-
- @classmethod
- def get_image_processor_dict(
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- """
- Get the image processor dict for the model.
- """
- image_processor_filename = kwargs.pop("image_processor_filename", "config.json")
- return super().get_image_processor_dict(
- pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
- )
-
- def preprocess(
- self,
- images: ImageInput,
- return_tensors: Optional[Union[str, TensorType]] = "pt",
- ) -> BatchFeature:
- """
- Preprocess an image or batch of images.
-
- Args:
- images (`ImageInput`):
- Image to preprocess. Expects a single or batch of images
- return_tensors (`str` or `TensorType`, *optional*):
- The type of tensors to return.
- """
- if return_tensors != "pt":
- raise ValueError(f"return_tensors for TimmWrapperImageProcessor must be 'pt', but got {return_tensors}")
-
- if self._not_supports_tensor_input and isinstance(images, torch.Tensor):
- images = images.cpu().numpy()
-
- # If the input is a torch tensor, then no conversion is needed
- # Otherwise, we need to pass in a list of PIL images
- if isinstance(images, torch.Tensor):
- images = self.val_transforms(images)
- # Add batch dimension if a single image
- images = images.unsqueeze(0) if images.ndim == 3 else images
- else:
- images = make_list_of_images(images)
- images = [to_pil_image(image) for image in images]
- images = torch.stack([self.val_transforms(image) for image in images])
-
- return BatchFeature({"pixel_values": images}, tensor_type=return_tensors)
-
- def save_pretrained(self, *args, **kwargs):
- # disable it to make checkpoint the same as in `timm` library.
- logger.warning_once(
- "The `save_pretrained` method is disabled for TimmWrapperImageProcessor. "
- "The image processor configuration is saved directly in `config.json` when "
- "`save_pretrained` is called for saving the model."
- )
-
-
-__all__ = ["TimmWrapperImageProcessor"]
diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
deleted file mode 100644
index dfb14dfccec4c6..00000000000000
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ /dev/null
@@ -1,363 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import torch
-from torch import Tensor, nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
-from ...modeling_outputs import ImageClassifierOutput, ModelOutput
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- add_start_docstrings_to_model_forward,
- is_timm_available,
- replace_return_docstrings,
- requires_backends,
-)
-from .configuration_timm_wrapper import TimmWrapperConfig
-
-
-if is_timm_available():
- import timm
-
-
-@dataclass
-class TimmWrapperModelOutput(ModelOutput):
- """
- Output class for models TimmWrapperModel, containing the last hidden states, an optional pooled output,
- and optional hidden states.
-
- Args:
- last_hidden_state (`torch.FloatTensor`):
- The last hidden state of the model, output before applying the classification head.
- pooler_output (`torch.FloatTensor`, *optional*):
- The pooled output derived from the last hidden state, if applicable.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*):
- A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers.
- Returned if `output_hidden_states=True` is set or if `config.output_hidden_states=True`.
- attentions (`tuple(torch.FloatTensor)`, *optional*):
- A tuple containing the intermediate attention weights of the model at the output of each layer.
- Returned if `output_attentions=True` is set or if `config.output_attentions=True`.
- Note: Currently, Timm models do not support attentions output.
- """
-
- last_hidden_state: torch.FloatTensor
- pooler_output: Optional[torch.FloatTensor] = None
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
-
-
-TIMM_WRAPPER_INPUTS_DOCSTRING = r"""
- Args:
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
- Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`TimmWrapperImageProcessor.preprocess`]
- for details.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- **kwargs:
- Additional keyword arguments passed along to the `timm` model forward.
-"""
-
-
-class TimmWrapperPreTrainedModel(PreTrainedModel):
- main_input_name = "pixel_values"
- config_class = TimmWrapperConfig
- _no_split_modules = []
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["vision", "timm"])
- super().__init__(*args, **kwargs)
-
- @staticmethod
- def _fix_state_dict_key_on_load(key):
- """
- Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
- We don't want this behavior for timm wrapped models. Instead, this method adds a
- "timm_model." prefix to enable loading official timm Hub checkpoints.
- """
- if "timm_model." not in key:
- return f"timm_model.{key}"
- return key
-
- def _fix_state_dict_key_on_save(self, key):
- """
- Overrides original method to remove "timm_model." prefix from state_dict keys.
- Makes the saved checkpoint compatible with the `timm` library.
- """
- return key.replace("timm_model.", "")
-
- def load_state_dict(self, state_dict, *args, **kwargs):
- """
- Override original method to fix state_dict keys on load for cases when weights are loaded
- without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
- """
- state_dict = self._fix_state_dict_keys_on_load(state_dict)
- return super().load_state_dict(state_dict, *args, **kwargs)
-
- def _init_weights(self, module):
- """
- Initialize weights function to properly initialize Linear layer weights.
- Since model architectures may vary, we assume only the classifier requires
- initialization, while all other weights should be loaded from the checkpoint.
- """
- if isinstance(module, (nn.Linear)):
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
- if module.bias is not None:
- module.bias.data.zero_()
-
-
-class TimmWrapperModel(TimmWrapperPreTrainedModel):
- """
- Wrapper class for timm models to be used in transformers.
- """
-
- def __init__(self, config: TimmWrapperConfig):
- super().__init__(config)
- # using num_classes=0 to avoid creating classification head
- self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
- self.post_init()
-
- @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=TimmWrapperModelOutput, config_class=TimmWrapperConfig)
- def forward(
- self,
- pixel_values: torch.FloatTensor,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[Union[bool, List[int]]] = None,
- return_dict: Optional[bool] = None,
- do_pooling: Optional[bool] = None,
- **kwargs,
- ) -> Union[TimmWrapperModelOutput, Tuple[Tensor, ...]]:
- r"""
- do_pooling (`bool`, *optional*):
- Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. If `None` is passed, the
- `do_pooling` value from the config is used.
-
- Returns:
-
- Examples:
- ```python
- >>> import torch
- >>> from PIL import Image
- >>> from urllib.request import urlopen
- >>> from transformers import AutoModel, AutoImageProcessor
-
- >>> # Load image
- >>> image = Image.open(urlopen(
- ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
- ... ))
-
- >>> # Load model and image processor
- >>> checkpoint = "timm/resnet50.a1_in1k"
- >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
- >>> model = AutoModel.from_pretrained(checkpoint).eval()
-
- >>> # Preprocess image
- >>> inputs = image_processor(image)
-
- >>> # Forward pass
- >>> with torch.no_grad():
- ... outputs = model(**inputs)
-
- >>> # Get pooled output
- >>> pooled_output = outputs.pooler_output
-
- >>> # Get last hidden state
- >>> last_hidden_state = outputs.last_hidden_state
- ```
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- do_pooling = do_pooling if do_pooling is not None else self.config.do_pooling
-
- if output_attentions:
- raise ValueError("Cannot set `output_attentions` for timm models.")
-
- if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
- raise ValueError(
- "The 'output_hidden_states' option cannot be set for this timm model. "
- "To enable this feature, the 'forward_intermediates' method must be implemented "
- "in the timm model (available in timm versions > 1.*). Please consider using a "
- "different architecture or updating the timm package to a compatible version."
- )
-
- pixel_values = pixel_values.to(self.device, self.dtype)
-
- if output_hidden_states:
- # to enable hidden states selection
- if isinstance(output_hidden_states, (list, tuple)):
- kwargs["indices"] = output_hidden_states
- last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
- else:
- last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs)
- hidden_states = None
-
- if do_pooling:
- # classification head is not created, applying pooling only
- pooler_output = self.timm_model.forward_head(last_hidden_state)
- else:
- pooler_output = None
-
- if not return_dict:
- outputs = (last_hidden_state, pooler_output, hidden_states)
- outputs = tuple(output for output in outputs if output is not None)
- return outputs
-
- return TimmWrapperModelOutput(
- last_hidden_state=last_hidden_state,
- pooler_output=pooler_output,
- hidden_states=hidden_states,
- )
-
-
-class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
- """
- Wrapper class for timm models to be used in transformers for image classification.
- """
-
- def __init__(self, config: TimmWrapperConfig):
- super().__init__(config)
-
- if config.num_labels == 0:
- raise ValueError(
- "You are trying to load weights into `TimmWrapperForImageClassification` from a checkpoint with no classifier head. "
- "Please specify the number of classes, e.g. `model = TimmWrapperForImageClassification.from_pretrained(..., num_labels=10)`, "
- "or use `TimmWrapperModel` for feature extraction."
- )
-
- self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
- self.num_labels = config.num_labels
- self.post_init()
-
- @add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=TimmWrapperConfig)
- def forward(
- self,
- pixel_values: torch.FloatTensor,
- labels: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[Union[bool, List[int]]] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[ImageClassifierOutput, Tuple[Tensor, ...]]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-
- Returns:
-
- Examples:
- ```python
- >>> import torch
- >>> from PIL import Image
- >>> from urllib.request import urlopen
- >>> from transformers import AutoModelForImageClassification, AutoImageProcessor
-
- >>> # Load image
- >>> image = Image.open(urlopen(
- ... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
- ... ))
-
- >>> # Load model and image processor
- >>> checkpoint = "timm/resnet50.a1_in1k"
- >>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
- >>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
-
- >>> # Preprocess image
- >>> inputs = image_processor(image)
-
- >>> # Forward pass
- >>> with torch.no_grad():
- ... logits = model(**inputs).logits
-
- >>> # Get top 5 predictions
- >>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
- ```
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-
- if output_attentions:
- raise ValueError("Cannot set `output_attentions` for timm models.")
-
- if output_hidden_states and not hasattr(self.timm_model, "forward_intermediates"):
- raise ValueError(
- "The 'output_hidden_states' option cannot be set for this timm model. "
- "To enable this feature, the 'forward_intermediates' method must be implemented "
- "in the timm model (available in timm versions > 1.*). Please consider using a "
- "different architecture or updating the timm package to a compatible version."
- )
-
- pixel_values = pixel_values.to(self.device, self.dtype)
-
- if output_hidden_states:
- # to enable hidden states selection
- if isinstance(output_hidden_states, (list, tuple)):
- kwargs["indices"] = output_hidden_states
- last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
- logits = self.timm_model.forward_head(last_hidden_state)
- else:
- logits = self.timm_model(pixel_values, **kwargs)
- hidden_states = None
-
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(logits, labels)
-
- if not return_dict:
- outputs = (loss, logits, hidden_states)
- outputs = tuple(output for output in outputs if output is not None)
- return outputs
-
- return ImageClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=hidden_states,
- )
-
-
-__all__ = ["TimmWrapperPreTrainedModel", "TimmWrapperModel", "TimmWrapperForImageClassification"]
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index f355eb03bdb82f..6ce5e77706d358 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -595,6 +595,7 @@ class UniSpeechFlashAttention2(UniSpeechAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1876,8 +1877,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 0fd6e7cb2c04e1..52d82ea739426b 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -612,6 +612,7 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1886,8 +1887,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index e8abdcfe5cc82d..98ecfb3927a342 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -254,7 +254,6 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
- return_tensors = "pt" if return_tensors is None else return_tensors
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 5168904a3579d9..bf1bb7746ce802 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -38,6 +38,7 @@
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -658,6 +659,7 @@ class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1589,7 +1591,7 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
cache_dir=cache_dir,
)
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
state_dict = torch.load(
weight_path,
map_location="cpu",
@@ -2375,8 +2377,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
index 7774c7a4069d02..6f1d5576df7316 100644
--- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
+++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py
@@ -1359,8 +1359,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 494654a6774754..933bf8f6dc0bcd 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -878,8 +878,7 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_attention_mask] = 0.0
+ hidden_states[~attention_mask] = 0.0
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
@@ -1792,8 +1791,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index 3e5e3790005377..4df192fda5efa3 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -691,8 +691,7 @@ def forward(
if attention_mask is not None:
# make sure padded tokens output 0
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_attention_mask] = 0
+ hidden_states[~attention_mask] = 0.0
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -777,8 +776,7 @@ def forward(
if attention_mask is not None:
# make sure padded tokens are not attended to
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_attention_mask] = 0
+ hidden_states[~attention_mask] = 0
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -1510,8 +1508,7 @@ def forward(
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
- expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states[~expand_padding_mask] = 0.0
+ hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
logits = self.classifier(pooled_output)
diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 360c0c0b687bab..2f58375f3de751 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -133,12 +133,9 @@ def _pad_to_max_length(
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
- return_token_timestamps=False,
- force_unique_generate_call=False,
):
max_total_length = 0
sequences = []
- token_timestamps_list = []
if padding_side not in ["right", "left"]:
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
@@ -148,74 +145,31 @@ def _pad_to_max_length(
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
- if force_unique_generate_call:
- sequences_list = []
- timestamps_list = []
- for segments in current_segments:
- result = segments[0]["result"]
- sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
- if return_token_timestamps:
- timestamps_list.append(result["token_timestamps"])
-
- sequences = torch.stack(sequences_list, dim=0)
- if return_token_timestamps:
- token_timestamps = torch.stack(timestamps_list, dim=0)
- return sequences, token_timestamps
- return sequences
-
for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
- if return_token_timestamps:
- token_timestamps = torch.cat(
- [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
- dim=-1,
- )
if cut_off_length is not None:
sequence = sequence[-cut_off_length:]
- if return_token_timestamps:
- token_timestamps = token_timestamps[-cut_off_length:]
if bos_token_tensor is not None:
sequence = torch.cat([bos_token_tensor, sequence])
- if return_token_timestamps:
- token_timestamps = torch.cat(
- [torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
- )
+
sequences.append(sequence)
- if return_token_timestamps:
- token_timestamps_list.append(token_timestamps)
max_total_length = max(max_total_length, len(sequences[-1]))
elif bos_token_tensor is not None:
sequences.append(bos_token_tensor)
- if return_token_timestamps:
- token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
else:
sequences.append(torch.tensor([], device=device))
- if return_token_timestamps:
- token_timestamps_list.append(torch.tensor([], device=device))
max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
-
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
- if return_token_timestamps:
- token_timestamps_list[i] = F.pad(
- token_timestamps_list[i],
- pad=pad,
- value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
- )
sequences = torch.stack(sequences, dim=0)
-
- if return_token_timestamps:
- token_timestamps = torch.stack(token_timestamps_list, dim=0)
- return sequences, token_timestamps
- else:
- return sequences
+ return sequences
class WhisperGenerationMixin(GenerationMixin):
@@ -358,7 +312,6 @@ def generate(
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
- force_unique_generate_call: Optional[bool] = None,
**kwargs,
):
"""
@@ -382,7 +335,7 @@ def generate(
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
- generation_config ([`~generation.GenerationConfig`], *optional*):
+ generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
@@ -479,39 +432,27 @@ def generate(
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
- force_unique_generate_call (`bool`, *optional*):
- Whether to force a unique call to the underlying GenerationMixin's [~generation.GenerationMixin.generate] method. This is useful for assisted decoding and testing purposes to ensure
- that only one call to [~generation.GenerationMixin.generate] is made and therefore decoder input token ids and eos token ids are returned.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
Return:
- [`~utils.ModelOutput`] or `Dict[str, Any]` or `torch.LongTensor`:
+ [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
- A:
- - [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id.
- - `Dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`.
- - `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id.
+ If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
- The possible [`~utils.ModelOutput`] types are:
- - [`~generation.GenerateEncoderDecoderOutput`]
- - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
- `segments` is a list of lists (one list per batch element) of `segment`.
- A `segment` is a dictionary with keys `start`, `end`, `tokens`, `idxs`, and `result`.
- - `start`: the start timestamp of the segment.
- - `end`: the end timestamp of the segment.
- - `tokens`: the tokens of the segment, excluding the decoder input ids and end of sequence id.
- - `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's [~generation.GenerationMixin.generate] (present in `result`).
- - `result`: the result of the underlying call to GenerationMixin's [~generation.GenerationMixin.generate].
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
- When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's [~generation.GenerationMixin.generate], with outputs stored in `result` of each `segment`.
+ else only the generated output sequence ids are returned.
Example:
- - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. It is necessary to set `return_timestamps=True`.
- Indeed, long-form transcription uses a sequential algorithm based on timestamps predictions, with heuristics like compression ratio threshold, log probability threshold and temperature fallback. This algorithm is described in the [the Whisper original paper](https://cdn.openai.com/papers/whisper.pdf), section *3.8. Long-form Transcription*.
+ - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
```python
>>> import torch
@@ -542,9 +483,7 @@ def generate(
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
```
- - *Shortform transcription*: If passed mel input features are <= 30 seconds, there are two possibilities:
- - `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's [~generation.GenerationMixin.generate].
- - `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription.
+ - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
```python
>>> import torch
@@ -631,21 +570,11 @@ def generate(
# 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1]
- num_beams = kwargs.get(
- "num_beams",
- generation_config.num_beams
- if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
- else 1,
- )
- if "assistant_model" in kwargs:
- # speculative decoding: the model should be able to return eos token
- generation_config.begin_suppress_tokens = None
-
logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
- num_beams=num_beams,
+ num_beams=kwargs.get("num_beams", 1),
device=device,
)
@@ -689,19 +618,6 @@ def generate(
batch_size=cur_bsz,
generation_config=generation_config,
)
- # 5bis speculative decoding: ensure the assistant model does only one call to generate and therefore returns decoder input token ids and eos token id
- # we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
- if "assistant_model" in kwargs:
- assistant_model = kwargs["assistant_model"]
- assistant_model.generation_config.force_unique_generate_call = True
-
- if force_unique_generate_call is None:
- if hasattr(generation_config, "force_unique_generate_call"):
- force_unique_generate_call = generation_config.force_unique_generate_call
- elif hasattr(self.generation_config, "force_unique_generate_call"):
- force_unique_generate_call = self.generation_config.force_unique_generate_call
- else:
- force_unique_generate_call = False
# 6 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
@@ -716,9 +632,7 @@ def generate(
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
- time_offset = (
- seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
- )
+ time_offset = seek.to(torch.float64) * time_precision / input_stride
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
# 6.2 cut out next 30s segment from input features
@@ -813,15 +727,14 @@ def generate(
prev_idx=prev_i,
idx=i,
return_token_timestamps=return_token_timestamps,
- decoder_input_ids=decoder_input_ids,
)
- seek[prev_i] += segment_offset
-
current_segments[prev_i] += segments
- if force_unique_generate_call:
- break
+ if is_shortform:
+ seek[prev_i] += max_frames[i]
+ else:
+ seek[prev_i] += segment_offset
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@@ -831,62 +744,51 @@ def generate(
else current_segments
)
- # if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
- # -> we can return a ModelOutput
- # otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
- if (
- return_dict_in_generate
- and generation_config.return_dict_in_generate
- and (force_unique_generate_call or not return_timestamps)
- ):
- # only one call to generate_with_fallback, we can return a ModelOutput
- outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
- if num_return_sequences > 1:
- if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
- outputs.encoder_attentions = tuple(
- outputs.encoder_attentions[i][::num_return_sequences]
- for i in range(len(outputs.encoder_attentions))
- )
- if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
- outputs.encoder_hidden_states = tuple(
- outputs.encoder_hidden_states[i][::num_return_sequences]
- for i in range(len(outputs.encoder_hidden_states))
- )
- return outputs
-
- padded_outputs = _pad_to_max_length(
- current_segments=final_segments,
- pad_token_id=generation_config.pad_token_id,
- device=self.device,
- padding_side="right",
- return_token_timestamps=return_token_timestamps,
- force_unique_generate_call=force_unique_generate_call,
+ sequences = _pad_to_max_length(
+ final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)
- if return_dict_in_generate and generation_config.return_dict_in_generate:
- logger.warning_once(
- "You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
- )
- return_segments = True
- elif not return_segments and not return_token_timestamps:
- return padded_outputs
+ # 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
+ if return_segments:
+ return {"sequences": sequences, "segments": final_segments}
- if return_token_timestamps:
- sequences, token_timestamps = padded_outputs
- outputs = {
- "sequences": sequences,
- "token_timestamps": token_timestamps,
- }
- else:
- sequences = padded_outputs
- outputs = {
- "sequences": sequences,
- }
+ if is_shortform:
+ # add eos token:
+ if generation_config.max_new_tokens is None and generation_config.max_length is None:
+ eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
+ sequences = torch.cat([sequences, eos_tokens], dim=-1)
- if return_segments:
- outputs["segments"] = final_segments
+ if return_token_timestamps:
+ outputs = {}
+ outputs["sequences"] = sequences
+ outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
+ else:
+ outputs = sequences
- return outputs
+ if return_dict_in_generate and generation_config.return_dict_in_generate:
+ dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
+
+ if num_return_sequences > 1:
+ if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
+ dict_outputs.encoder_attentions = tuple(
+ dict_outputs.encoder_attentions[i][::num_return_sequences]
+ for i in range(len(dict_outputs.encoder_attentions))
+ )
+ if (
+ hasattr(dict_outputs, "encoder_hidden_states")
+ and dict_outputs.encoder_hidden_states is not None
+ ):
+ dict_outputs.encoder_hidden_states = tuple(
+ dict_outputs.encoder_hidden_states[i][::num_return_sequences]
+ for i in range(len(dict_outputs.encoder_hidden_states))
+ )
+ if return_token_timestamps:
+ dict_outputs["token_timestamps"] = outputs["token_timestamps"]
+ return dict_outputs
+
+ return outputs
+
+ return sequences
def generate_with_fallback(
self,
@@ -982,14 +884,22 @@ def generate_with_fallback(
new_decoder_attention_mask = []
for i, seek_sequence in enumerate(seek_sequences):
- # remove all padding tokens, except for the eos token
+ # make sure we cut a predicted EOS token if we are not finished with the generation yet
+ prev_i = batch_idx_map[fallback_index_map[i]]
+ is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
+
+ # remove eos token id
+ if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
+ seek_sequence = seek_sequence[:-1]
+ if return_token_timestamps and not is_shortform:
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
+
+ # remove all padding tokens
if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
- if generation_config.pad_token_id == generation_config.eos_token_id:
- # we do not remove the eos token id since it is needed for avg logprob calculation in _need_fallback
- num_paddings -= 1
- if num_paddings != 0:
- seek_sequence = seek_sequence[:-num_paddings]
+ seek_sequence = seek_sequence[:-num_paddings]
+ if return_token_timestamps and not is_shortform:
+ seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
# check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback(
@@ -1002,10 +912,6 @@ def generate_with_fallback(
temperature,
)
- # remove eos token
- if seek_sequence[-1] == generation_config.eos_token_id:
- seek_sequence = seek_sequence[:-1]
-
seek_sequence_list[fallback_index_map[i]] = seek_sequence
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
is_low_temperature = temperature is None or temperature < 0.5
@@ -1048,19 +954,14 @@ def _prepare_segments(prompt_ids, batch_size, generation_config):
return current_segments
def _postprocess_outputs(
- self,
- seek_outputs,
- decoder_input_ids,
- return_token_timestamps,
- generation_config,
- is_shortform,
+ self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
):
# remove all previously passed decoder input ids
- # should happen only if it is the first generated segment
- start_idx = decoder_input_ids.shape[-1]
+ start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
if isinstance(seek_outputs, torch.Tensor):
- return seek_outputs[:, start_idx:], seek_outputs
+ seek_outputs = seek_outputs[:, start_idx:]
+ return seek_outputs, seek_outputs
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
@@ -1070,6 +971,9 @@ def _postprocess_outputs(
num_frames=num_frames,
num_input_ids=decoder_input_ids.shape[-1],
)
+ seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
+
+ seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
if beam_indices is not None and key == "scores":
@@ -1105,7 +1009,7 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None
return values[batch_idx].cpu()
- sequence_tokens = seek_outputs["sequences"][:, start_idx:]
+ sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{
k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
@@ -1120,7 +1024,7 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
# Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
outputs = {}
for key in seek_outputs[0].keys():
- if key in ["sequences", "beam_indices", "token_timestamps"]:
+ if key in ["sequences", "beam_indices"]:
outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
outputs[key] = tuple(
@@ -1151,10 +1055,6 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
else:
outputs[key] = None
- token_timestamps = outputs.get("token_timestamps", None)
- if token_timestamps is not None:
- model_output_type = dict
-
return model_output_type(**outputs)
def _need_fallback(
@@ -1181,9 +1081,7 @@ def _need_fallback(
else:
scores = seek_outputs[index]["scores"]
logprobs = self._retrieve_avg_logprobs(
- scores,
- seek_sequence,
- temperature,
+ scores, seek_sequence, generation_config.eos_token_id, temperature
)
if logprobs < generation_config.logprob_threshold:
@@ -1279,6 +1177,13 @@ def _maybe_warn_unused_inputs(
if no_speech_threshold is not None:
logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
+ # when passing temperature as a list it cannot just be ignored => throw error in this case
+ if isinstance(temperature, (list, tuple)):
+ raise ValueError(
+ f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
+ f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
+ )
+
@staticmethod
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
@@ -1861,7 +1766,7 @@ def _retrieve_compression_ratio(tokens, vocab_size):
return compression_ratio
@staticmethod
- def _retrieve_avg_logprobs(scores, tokens, temperature):
+ def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
rescale_temperature = temperature if temperature > 0.0 else 1
scores = torch.stack(scores).to(tokens.device)
@@ -1873,10 +1778,10 @@ def _retrieve_avg_logprobs(scores, tokens, temperature):
logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
# retrieve logprob of selected tokens and sum
- # don't remove the eos token logprob! it counts in avg_logprob calculation in the original implementation
- sum_logprobs = sum(logprobs[i][tokens[i]] for i in range(logprobs.shape[0]))
+ sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
+ length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
- avg_logprobs = sum_logprobs / len(tokens)
+ avg_logprobs = sum_logprobs / (length + 1)
return avg_logprobs
@staticmethod
@@ -1892,7 +1797,6 @@ def _retrieve_segment(
prev_idx,
idx,
return_token_timestamps,
- decoder_input_ids,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1901,8 +1805,6 @@ def _retrieve_segment(
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
- idx_offset = decoder_input_ids.shape[-1]
- device = seek_sequence.device
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
# "end of segment" prediction and slice the decoding into segments accordingly
@@ -1926,20 +1828,15 @@ def _retrieve_segment(
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
segments.append(
{
- "start": time_offset[prev_idx]
- + start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
- * time_precision,
- "end": time_offset[prev_idx]
- + end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
- * time_precision,
+ "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
+ "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
"tokens": sliced_tokens,
- "idxs": (idx_offset + last_slice, idx_offset + current_slice),
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
- token_timestamps[idx_offset + last_slice : idx_offset + current_slice] + time_offset[prev_idx]
+ token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
)
last_slice = current_slice
@@ -1959,22 +1856,17 @@ def _retrieve_segment(
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
- last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
- torch.float32 if device.type == "mps" else torch.float64
- )
+ last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
segments = [
{
"start": time_offset[prev_idx],
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
- "idxs": (idx_offset, idx_offset + len(seek_sequence)),
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
- segments[-1]["token_timestamps"] = (
- token_timestamps[idx_offset : idx_offset + len(seek_sequence)] + time_offset[prev_idx]
- )
+ segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index fb01823a29c017..ce3df3e16707e5 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -354,6 +354,7 @@ class WhisperFlashAttention2(WhisperAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 3b7348eadd4785..dee7f898fcf93a 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -312,6 +312,7 @@ class ZambaFlashAttention2(ZambaAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -773,7 +774,6 @@ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache
class ZambaMLP(nn.Module):
def __init__(self, config):
super().__init__()
- self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -781,9 +781,8 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
class ZambaAttentionDecoderLayer(nn.Module):
diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py
index 2f4a42d2818005..5cbbdcdc04b756 100644
--- a/src/transformers/models/zoedepth/modeling_zoedepth.py
+++ b/src/transformers/models/zoedepth/modeling_zoedepth.py
@@ -417,7 +417,7 @@ def __init__(self, n_classes=256, act=torch.softmax):
self.k = n_classes
self.act = act
self.register_buffer("k_idx", torch.arange(0, n_classes).view(1, -1, 1, 1), persistent=False)
- self.register_buffer("k_minus_1", torch.tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False)
+ self.register_buffer("k_minus_1", torch.Tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False)
def forward(self, probabilities, temperature=1.0, eps=1e-4):
"""Compute the log binomial distribution for probabilities.
diff --git a/src/transformers/pipelines/audio_utils.py b/src/transformers/pipelines/audio_utils.py
index 72a5f51db6129a..4a8a93c9683a82 100644
--- a/src/transformers/pipelines/audio_utils.py
+++ b/src/transformers/pipelines/audio_utils.py
@@ -68,7 +68,7 @@ def ffmpeg_microphone(
The name of the format of the audio samples to be returned by ffmpeg. The standard is `f32le`, `s16le`
could also be used.
ffmpeg_input_device (`str`, *optional*):
- The identifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
+ The indentifier of the input device to be used by ffmpeg (i.e. ffmpeg's '-i' argument). If unset,
the default input device will be used. See `https://www.ffmpeg.org/ffmpeg-devices.html#Input-Devices`
for how to specify and list input devices.
ffmpeg_additional_args (`list[str]`, *optional*):
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 95c8748375ce0a..5bdf8a355ddfaa 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -34,11 +34,12 @@
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
+is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
+is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
+is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
-# Cache this result has it's a C FFI call which can be pretty time-consuming
-_torch_distributed_available = torch.distributed.is_available()
-if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
+if is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 47b54cd27bcebe..38bebd2d8410e4 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -29,7 +29,6 @@
QuantizationMethod,
QuantoConfig,
TorchAoConfig,
- VptqConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -43,7 +42,6 @@
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
-from .quantizer_vptq import VptqHfQuantizer
AUTO_QUANTIZER_MAPPING = {
@@ -59,7 +57,6 @@
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
"bitnet": BitNetHfQuantizer,
- "vptq": VptqHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -75,7 +72,6 @@
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
- "vptq": VptqConfig,
}
@@ -177,14 +173,13 @@ def merge_quantization_configs(
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
if (
- isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
+ isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
and quantization_config_from_args is not None
):
# special case for GPTQ / AWQ / FbgemmFp8 config collision
loading_attr_dict = quantization_config_from_args.get_loading_attributes()
for attr, val in loading_attr_dict.items():
setattr(quantization_config, attr, val)
-
warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
if warning_msg != "":
diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 7b81c93edf1fac..0c14c236d26036 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -111,7 +111,7 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg
" Please double check your model architecture, or submit an issue on github if you think this is a bug."
)
- def _process_model_after_weight_loading(self, model, **kwargs):
+ def _process_model_after_weight_loading(self, model):
if self.quantization_config.do_fuse:
from ..integrations import fuse_awq_modules
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index 5064f2c019d74e..61e940886d942f 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -12,11 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import os
-
from ..utils import is_compressed_tensors_available, is_torch_available, logging
-from ..utils.quantization_config import CompressedTensorsConfig
+from ..utils.quantization_config import QuantizationConfigMixin
from .base import HfQuantizer
@@ -35,13 +32,12 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
requires_calibration = True
required_packages = ["compressed_tensors"]
- def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
+
from compressed_tensors.compressors import ModelCompressor
self.compressor = ModelCompressor.from_compression_config(quantization_config)
- self.run_compressed = quantization_config.run_compressed
- self.quantization_config = quantization_config
def validate_environment(self, *args, **kwargs):
if not is_compressed_tensors_available():
@@ -67,57 +63,20 @@ def _process_model_before_weight_loading(self, model, **kwargs):
from compressed_tensors.quantization import apply_quantization_config
ct_quantization_config = self.compressor.quantization_config
+ apply_quantization_config(model, ct_quantization_config, run_compressed=True)
- if self.run_compressed and self.is_quantization_compressed:
- apply_quantization_config(model, ct_quantization_config, run_compressed=True)
- elif not self.is_quantization_compressed:
- apply_quantization_config(model, ct_quantization_config)
-
- def _process_model_after_weight_loading(self, model, **kwargs):
- """Decompress loaded model if necessary - need for qat"""
-
- if (self.is_quantization_compressed and not self.run_compressed) or self.is_sparsification_compressed:
- config = kwargs.get("config", None)
- cache_path = config._name_or_path
-
- if not os.path.exists(cache_path):
- from transformers.utils import cached_file
-
- config_file_path = cached_file(cache_path, "config.json")
- cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
-
- if self.is_quantization_compressed and not self.run_compressed:
- from compressed_tensors.quantization import QuantizationStatus
-
- self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
- self.compressor.decompress(model_path=cache_path, model=model)
+ def _process_model_after_weight_loading(self, model, **kwargs) -> None:
+ pass
@property
- def is_quantization_compressed(self):
- from compressed_tensors.quantization import QuantizationStatus
-
- return (
- self.quantization_config.quantization_config is not None
- and self.quantization_config.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
- )
-
- @property
- def is_sparsification_compressed(self):
- from compressed_tensors.config.base import CompressionFormat
-
- return (
- self.quantization_config.sparsity_config is not None
- and self.quantization_config.sparsity_config.format != CompressionFormat.dense.value
- )
-
- @property
- def is_trainable(self):
+ def is_trainable(self) -> bool:
+ """Models quantized using compressed tensors can be finetuned"""
return True
+ @property
def is_qat_trainable(self) -> bool:
"""Loaded Models can carry out quantization aware training"""
- # models need to be decompressed carry out qat
- return not self.run_compressed or not self.is_quantization_compressed
+ return True
def is_serializable(self, safe_serialization=None) -> bool:
"""Models quantized using compressed tensors can be saved to disk"""
diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py
index 230e8efe150672..d91019dea15226 100644
--- a/src/transformers/quantizers/quantizer_quanto.py
+++ b/src/transformers/quantizers/quantizer_quanto.py
@@ -197,7 +197,7 @@ def _process_model_before_weight_loading(
)
model.config.quantization_config = self.quantization_config
- def _process_model_after_weight_loading(self, model, **kwargs):
+ def _process_model_after_weight_loading(self, model):
return model
@property
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 10d2b184ef146b..e6c2dc1ce36b3f 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -195,7 +195,7 @@ def create_quantized_param(
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
- def _process_model_after_weight_loading(self, model, **kwargs):
+ def _process_model_after_weight_loading(self, model):
"""No process required for torchao quantized model"""
return
diff --git a/src/transformers/quantizers/quantizer_vptq.py b/src/transformers/quantizers/quantizer_vptq.py
deleted file mode 100644
index 1672c3ebc5a7d3..00000000000000
--- a/src/transformers/quantizers/quantizer_vptq.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Optional
-
-from .base import HfQuantizer
-
-
-if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
-
-from ..utils import is_accelerate_available, is_torch_available, is_vptq_available, logging
-from ..utils.quantization_config import QuantizationConfigMixin
-
-
-if is_torch_available():
- import torch
-
-logger = logging.get_logger(__name__)
-
-
-class VptqHfQuantizer(HfQuantizer):
- """
- Quantizer of the VPTQ method. Enables the loading of prequantized models.
- """
-
- requires_calibration = True
- required_packages = ["vptq"]
-
- def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
- super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
-
- def validate_environment(self, *args, **kwargs):
- if not is_accelerate_available():
- raise ImportError("Using `vptq` quantization requires Accelerate: `pip install accelerate`")
-
- if not is_vptq_available():
- raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")
-
- def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
- if torch_dtype is None:
- if torch.cuda.is_available():
- torch_dtype = torch.float16
- logger.info(
- "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
- )
- else:
- import vptq
-
- device_availability = getattr(vptq, "device_availability", lambda device: False)
- if device_availability("cpu") is True:
- raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
- torch_dtype = torch.float32
- logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
- return torch_dtype
-
- def _process_model_before_weight_loading(
- self,
- model: "PreTrainedModel",
- **kwargs,
- ):
- """
- we don't have param like modules_to_not_convert to indicate which layers should not be quantized
- because `quantization_config` include the layers that should be quantized
- """
- from ..integrations import replace_with_vptq_linear
-
- modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
- self.quantization_config.modules_to_not_convert or []
- )
-
- replace_with_vptq_linear(
- model,
- quantization_config=self.quantization_config,
- modules_to_not_convert=modules_to_not_convert,
- )
- model.config.quantization_config = self.quantization_config
-
- def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
- return model
-
- @property
- def is_trainable(self, model: Optional["PreTrainedModel"] = None):
- return False
-
- def is_serializable(self, safe_serialization=None):
- return True
diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py
index f1612d3ea57c98..5c0179350ea2ef 100644
--- a/src/transformers/safetensors_conversion.py
+++ b/src/transformers/safetensors_conversion.py
@@ -67,7 +67,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
# security breaches.
pr = previous_pr(api, model_id, pr_title, token=token)
- if pr is None or (not private and pr.author != "SFconvertbot"):
+ if pr is None or (not private and pr.author != "SFConvertBot"):
spawn_conversion(token, private, model_id)
pr = previous_pr(api, model_id, pr_title, token=token)
else:
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 2f523ed36d983f..30f7b5a68fb2c0 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -14,7 +14,6 @@
import collections
import contextlib
-import copy
import doctest
import functools
import gc
@@ -29,7 +28,6 @@
import subprocess
import sys
import tempfile
-import threading
import time
import unittest
from collections import defaultdict
@@ -143,7 +141,6 @@
is_torchdynamo_available,
is_torchvision_available,
is_vision_available,
- is_vptq_available,
strtobool,
)
@@ -1144,13 +1141,6 @@ def require_aqlm(test_case):
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
-def require_vptq(test_case):
- """
- Decorator marking a test that requires vptq
- """
- return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
-
-
def require_eetq(test_case):
"""
Decorator marking a test that requires eetq
@@ -1397,53 +1387,6 @@ def assert_screenout(out, what):
assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
-def set_model_tester_for_less_flaky_test(test_case):
- if hasattr(test_case.model_tester, "num_hidden_layers"):
- test_case.model_tester.num_hidden_layers = 1
- if (
- hasattr(test_case.model_tester, "vision_config")
- and "num_hidden_layers" in test_case.model_tester.vision_config
- ):
- test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
- test_case.model_tester.vision_config["num_hidden_layers"] = 1
- if hasattr(test_case.model_tester, "text_config") and "num_hidden_layers" in test_case.model_tester.text_config:
- test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
- test_case.model_tester.text_config["num_hidden_layers"] = 1
-
-
-def set_config_for_less_flaky_test(config):
- target_attrs = [
- "rms_norm_eps",
- "layer_norm_eps",
- "norm_eps",
- "norm_epsilon",
- "layer_norm_epsilon",
- "batch_norm_eps",
- ]
- for target_attr in target_attrs:
- setattr(config, target_attr, 1.0)
-
- # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
- # (We don't need the original epsilon values to check eager/sdpa matches)
- attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]
- for attr in attrs:
- if hasattr(config, attr):
- for target_attr in target_attrs:
- setattr(getattr(config, attr), target_attr, 1.0)
-
-
-def set_model_for_less_flaky_test(model):
- # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
- target_names = ("LayerNorm", "GroupNorm", "BatchNorm", "RMSNorm", "BatchNorm2d", "BatchNorm1d")
- target_attrs = ["eps", "epsilon", "variance_epsilon"]
- if is_torch_available() and isinstance(model, torch.nn.Module):
- for module in model.modules():
- if type(module).__name__.endswith(target_names):
- for attr in target_attrs:
- if hasattr(module, attr):
- setattr(module, attr, 1.0)
-
-
class CaptureStd:
"""
Context manager to capture:
@@ -2368,28 +2311,12 @@ class RequestCounter:
def __enter__(self):
self._counter = defaultdict(int)
- self._thread_id = threading.get_ident()
- self._extra_info = []
-
- def patched_with_thread_info(func):
- def wrap(*args, **kwargs):
- self._extra_info.append(threading.get_ident())
- return func(*args, **kwargs)
-
- return wrap
-
- self.patcher = patch.object(
- urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
- )
+ self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
self.mock = self.patcher.start()
return self
def __exit__(self, *args, **kwargs) -> None:
- assert len(self.mock.call_args_list) == len(self._extra_info)
-
- for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
- if thread_id != self._thread_id:
- continue
+ for call in self.mock.call_args_list:
log = call.args[0] % call.args[1:]
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
if method in log:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index de0bc87b26b676..0bfcc4aa303665 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -28,7 +28,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from inspect import isfunction
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from packaging import version
@@ -799,13 +799,12 @@ def as_tensor(value, dtype=None):
return self
- def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
+ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
"""
- Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
+ Send all values to device by calling `v.to(device)` (PyTorch only).
Args:
device (`str` or `torch.device`): The device to put the tensors on.
- non_blocking (`bool`): Whether to perform the copy asynchronously.
Returns:
[`BatchEncoding`]: The same instance after modification.
@@ -817,10 +816,7 @@ def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False)
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
- self.data = {
- k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
- for k, v in self.data.items()
- }
+ self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self
@@ -1527,7 +1523,7 @@ def get_vocab(self) -> Dict[str, int]:
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
- tools: Optional[List[Union[Dict, Callable]]] = None,
+ tools: Optional[List[Dict]] = None,
documents: Optional[List[Dict[str, str]]] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c878d2b345cc31..af908e48e4b8c4 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -75,6 +75,7 @@
from .processing_utils import ProcessorMixin
from .pytorch_utils import (
ALL_LAYERNORM_LAYERS,
+ is_torch_greater_or_equal_than_1_13,
is_torch_greater_or_equal_than_2_3,
)
from .tokenization_utils_base import PreTrainedTokenizerBase
@@ -2250,7 +2251,7 @@ def _inner_training_loop(
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
- delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+ delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
@@ -2303,13 +2304,12 @@ def _inner_training_loop(
# In case of auto_find_batch_size=True
# Remove FSDP wrapping from sub-models.
self.model = unwrap_model(self.model, recursive=True)
+ # configure fsdp plugin for qlora if any
+ self._fsdp_qlora_plugin_updates()
if delay_optimizer_creation:
if use_accelerator_prepare:
- # configure fsdp plugin for qlora if any
- self._fsdp_qlora_plugin_updates()
- if self.accelerator.mixed_precision != "fp8":
- self.model = self.accelerator.prepare(self.model)
+ self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# prepare using `accelerator` prepare
@@ -2516,7 +2516,6 @@ def _inner_training_loop(
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
- and self.accelerator.distributed_type != DistributedType.DEEPSPEED
else contextlib.nullcontext
)
with context():
@@ -2777,7 +2776,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
)
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
# If the model is on the GPU, it still works!
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2898,7 +2897,7 @@ def _load_best_model(self):
or os.path.exists(best_safe_adapter_model_path)
):
has_been_loaded = True
- weights_only_kwarg = {"weights_only": True}
+ weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
@@ -2939,22 +2938,7 @@ def _load_best_model(self):
active_adapter = model.active_adapter
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
- try:
- model.load_adapter(self.state.best_model_checkpoint, active_adapter)
- except RuntimeError as exc:
- if model.peft_config[active_adapter].is_prompt_learning:
- # for context: https://github.com/huggingface/peft/issues/2256
- msg = (
- "When using prompt learning PEFT methods such as "
- f"{model.peft_config[active_adapter].peft_type.value}, setting "
- "load_best_model_at_end=True can lead to errors, it is recommended "
- "to set this to False and to load the model manually from the checkpoint "
- "directory using PeftModel.from_pretrained(base_model, ) after training "
- "has finished."
- )
- raise RuntimeError(msg) from exc
- else:
- raise
+ model.load_adapter(self.state.best_model_checkpoint, active_adapter)
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys
@@ -3665,7 +3649,10 @@ def training_step(
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
- loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+ if self.model_accepts_loss_kwargs:
+ loss = self.compute_loss(model, inputs)
+ else:
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
@@ -4188,7 +4175,7 @@ def evaluation_loop(
start_time = time.time()
model = (
self.accelerator.prepare(model)
- if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8")
+ if self.is_deepspeed_enabled or self.is_fsdp_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
self.model_preparation_time = round(time.time() - start_time, 4)
@@ -5145,6 +5132,10 @@ def get_batch_samples(self, epoch_iterator, num_batches):
except StopIteration:
break
+ # Keep default behavior the same
+ if not self.model_accepts_loss_kwargs:
+ return batch_samples, None
+
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
@@ -5154,8 +5145,4 @@ def get_batch_samples(self, epoch_iterator, num_batches):
if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
-
- if torch.is_tensor(num_items_in_batch):
- num_items_in_batch = num_items_in_batch.item()
-
return batch_samples, num_items_in_batch
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index da95329e184567..5f78860fe6c115 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -56,7 +56,12 @@
import torch_xla.core.xla_model as xm
if is_torch_available():
- from torch.optim.lr_scheduler import LRScheduler
+ from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+
+ if is_torch_greater_or_equal_than_2_0:
+ from torch.optim.lr_scheduler import LRScheduler
+ else:
+ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
logger = logging.get_logger(__name__)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 6950e8e66d3ac1..6b141cff39e1f7 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -71,6 +71,8 @@
import torch
import torch.distributed as dist
+ from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+
if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import DistributedType
@@ -1155,7 +1157,7 @@ class TrainingArguments:
},
)
dataloader_prefetch_factor: Optional[int] = field(
- default=None,
+ default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
metadata={
"help": (
"Number of batches loaded in advance by each worker. "
@@ -1700,6 +1702,14 @@ def __post_init__(self):
raise ValueError(
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
)
+ elif not is_torch_xpu_available():
+ # xpu
+ from .pytorch_utils import is_torch_greater_or_equal_than_1_12
+
+ if not is_torch_greater_or_equal_than_1_12:
+ raise ValueError(
+ "Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
+ )
if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
@@ -2046,7 +2056,11 @@ def __post_init__(self):
if self.use_cpu:
self.dataloader_pin_memory = False
- if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
+ if (
+ (not is_torch_available() or is_torch_greater_or_equal_than_2_0)
+ and self.dataloader_num_workers == 0
+ and self.dataloader_prefetch_factor is not None
+ ):
raise ValueError(
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
" when --dataloader_num_workers > 1."
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 2edfcdcd101c78..08d23e0e6a5d41 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -55,8 +55,6 @@
is_tensor,
is_tf_symbolic_tensor,
is_tf_tensor,
- is_timm_config_dict,
- is_timm_local_checkpoint,
is_torch_device,
is_torch_dtype,
is_torch_tensor,
@@ -233,7 +231,6 @@
is_training_run_on_sagemaker,
is_uroman_available,
is_vision_available,
- is_vptq_available,
requires_backends,
torch_only_method,
)
diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py
index 72bec701e14daf..c64a2c4dcb3468 100644
--- a/src/transformers/utils/chat_template_utils.py
+++ b/src/transformers/utils/chat_template_utils.py
@@ -15,7 +15,6 @@
import inspect
import json
import re
-import types
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
@@ -98,7 +97,7 @@ def _parse_type_hint(hint: str) -> Dict:
"Couldn't parse this type hint, likely due to a custom class or object: ", hint
)
- elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
+ elif origin is Union:
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
if len(subtypes) == 1:
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 3079dccf917efd..ba93fa9e1d4165 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -813,9 +813,6 @@ def __init__(self, *args, **kwargs):
MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
-MODEL_FOR_RETRIEVAL_MAPPING = None
-
-
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None
@@ -1167,27 +1164,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class BambaForCausalLM(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class BambaModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class BambaPreTrainedModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
class BarkCausalModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -2261,41 +2237,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class Cohere2ForCausalLM(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class Cohere2Model(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class Cohere2PreTrainedModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class ColPaliForRetrieval(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class ColPaliPreTrainedModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
@@ -6460,41 +6401,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class ModernBertForMaskedLM(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class ModernBertForSequenceClassification(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class ModernBertForTokenClassification(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class ModernBertModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class ModernBertPreTrainedModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
class MoshiForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
@@ -9179,27 +9085,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
-class TimmWrapperForImageClassification(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class TimmWrapperModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
-class TimmWrapperPreTrainedModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
-
class TrOCRForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_timm_and_torchvision_objects.py b/src/transformers/utils/dummy_timm_and_torchvision_objects.py
deleted file mode 100644
index 8b67b5dac58db1..00000000000000
--- a/src/transformers/utils/dummy_timm_and_torchvision_objects.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# This file is autogenerated by the command `make fix-copies`, do not edit.
-from ..utils import DummyObject, requires_backends
-
-
-class TimmWrapperImageProcessor(metaclass=DummyObject):
- _backends = ["timm", "torchvision"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["timm", "torchvision"])
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 45fa3d9ca68c51..101b34182a7309 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -60,6 +60,7 @@
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
+from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from .import_utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
@@ -634,9 +635,10 @@ def to_concrete(t):
operator.getitem: operator_getitem,
}
-_MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
- torch_nn_functional_scaled_dot_product_attention
-)
+if is_torch_greater_or_equal_than_2_0:
+ _MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
+ torch_nn_functional_scaled_dot_product_attention
+ )
class HFProxy(Proxy):
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index a997da79e8419d..26ec82b20fd40e 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -16,8 +16,6 @@
"""
import inspect
-import json
-import os
import tempfile
import warnings
from collections import OrderedDict, UserDict
@@ -26,7 +24,7 @@
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial, wraps
-from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
+from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
import numpy as np
from packaging import version
@@ -869,36 +867,3 @@ class LossKwargs(TypedDict, total=False):
"""
num_items_in_batch: Optional[int]
-
-
-def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
- """Checks whether a config dict is a timm config dict."""
- return "pretrained_cfg" in config_dict
-
-
-def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
- """
- Checks whether a checkpoint is a timm model checkpoint.
- """
- if pretrained_model_path is None:
- return False
-
- # in case it's Path, not str
- pretrained_model_path = str(pretrained_model_path)
-
- is_file = os.path.isfile(pretrained_model_path)
- is_dir = os.path.isdir(pretrained_model_path)
-
- # pretrained_model_path is a file
- if is_file and pretrained_model_path.endswith(".json"):
- with open(pretrained_model_path, "r") as f:
- config_dict = json.load(f)
- return is_timm_config_dict(config_dict)
-
- # pretrained_model_path is a directory with a config.json
- if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
- with open(os.path.join(pretrained_model_path, "config.json"), "r") as f:
- config_dict = json.load(f)
- return is_timm_config_dict(config_dict)
-
- return False
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index cfc8b88fd81ed6..32a647594741dd 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -93,13 +93,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
-VPTQ_MIN_VERSION = "0.0.4"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
-_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
@@ -194,7 +192,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_tiktoken_available = _is_package_available("tiktoken")
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
-_triton_available = _is_package_available("triton")
+
_torch_version = "N/A"
_torch_available = False
@@ -818,10 +816,6 @@ def is_aqlm_available():
return _aqlm_available
-def is_vptq_available(min_version: str = VPTQ_MIN_VERSION):
- return _vptq_available and version.parse(_vptq_version) >= version.parse(min_version)
-
-
def is_av_available():
return _av_available
@@ -1249,10 +1243,6 @@ def is_liger_kernel_available():
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")
-def is_triton_available():
- return _triton_available
-
-
# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 44e47e4f6e65c2..bacbca94cd823f 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -39,7 +39,6 @@ class QuantizationMethod(str, Enum):
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
- VPTQ = "vptq"
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
@@ -995,102 +994,6 @@ def post_init(self):
self.linear_weights_not_to_quantize = []
-@dataclass
-class VptqLayerConfig(QuantizationConfigMixin):
- """
- This is used to explain vptq config params for each layer
- Args:
- enable_norm (`bool`, *optional*, defaults to `True`): to control if we have scale/bias for fp-weight
- enable_perm (`bool`, *optional*, defaults to `True`): to perm input_channel or not
- group_num (`int`, *optional*, defaults to `1`): how many single groups for vector-quantization
- group_size (`int`, *optional*, defaults to `-1`): depends on out-features
- indices_as_float (`bool`, *optional*, defaults to `False`): for Finetuning
- is_indice_packed (`bool`, *optional*, defaults to `True`): should always be True
- num_centroids (`list`, *optional*, defaults to `[-1, -1]`): centriod numbers of clusters
- num_res_centroids (`list`, *optional*, defaults to `[-1, -1]`): ditto for residual
- outlier_size (`int`, *optional*, defaults to `1`): outliers
- vector_lens (`list`, *optional*, defaults to `[-1, -1]`): centroid vector length in quantization
- """
-
- def __init__(
- self,
- enable_norm: bool = True,
- enable_perm: bool = True,
- group_num: int = 1,
- group_size: int = -1,
- in_features: int = -1,
- indices_as_float: bool = False,
- is_indice_packed: bool = True,
- num_centroids: tuple = [-1, -1],
- num_res_centroids: tuple = [-1, -1],
- out_features: int = -1,
- outlier_size: int = 0,
- vector_lens: tuple = [-1, -1],
- **kwargs,
- ):
- self.enable_norm = enable_norm
- self.enable_perm = enable_perm
- self.group_num = group_num
- self.group_size = group_size
- self.in_features = in_features
- self.indices_as_float = indices_as_float
- self.is_indice_packed = is_indice_packed
- self.num_centroids = num_centroids
- self.num_res_centroids = num_res_centroids
- self.out_features = out_features
- self.outlier_size = outlier_size
- self.vector_lens = vector_lens
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct
- """
- if self.is_indice_packed is False:
- raise ValueError("is_indice_packed should always be True")
-
-
-@dataclass
-class VptqConfig(QuantizationConfigMixin):
- """
- This is a wrapper class about `vptq` parameters.
-
- Args:
- enable_proxy_error (`bool`, *optional*, defaults to `False`): calculate proxy error for each layer
- config_for_layers (`Dict`, *optional*, defaults to `{}`): quantization params for each layer
- shared_layer_config (`Dict`, *optional*, defaults to `{}`): shared quantization params among layers
- modules_to_not_convert (`list`, *optional*, default to `None`):
- The list of modules to not quantize, useful for quantizing models that explicitly require to have
- some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
- kwargs (`Dict[str, Any]`, *optional*):
- Additional parameters from which to initialize the configuration object.
- """
-
- def __init__(
- self,
- enable_proxy_error: bool = False,
- config_for_layers: Dict[str, Any] = {},
- shared_layer_config: Dict[str, Any] = {},
- modules_to_not_convert: Optional[List] = None,
- **kwargs,
- ):
- self.quant_method = QuantizationMethod.VPTQ
- self.enable_proxy_error = enable_proxy_error
- self.config_for_layers: Dict[str, Any] = config_for_layers
- self.shared_layer_config: Dict[str, Any] = shared_layer_config
- self.modules_to_not_convert = modules_to_not_convert
- self.post_init()
-
- def post_init(self):
- r"""
- Safety checker that arguments are correct
- """
- for layer_name, layer_param in self.config_for_layers.items():
- VptqLayerConfig(**layer_param)
- if self.enable_proxy_error is True:
- raise ValueError("enable_proxy_error should always be False until we support training")
-
-
@dataclass
class QuantoConfig(QuantizationConfigMixin):
"""
@@ -1174,8 +1077,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
dictionary mapping group name to a quantization scheme definition
format (`str`, *optional*, defaults to `"dense"`):
- format the model is represented as. Set `run_compressed` True to execute model as the
- compressed format if not `dense`
+ format the model is represented as
quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
@@ -1188,8 +1090,6 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
configuration for sparsity compression
quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
do not override, should be compressed-tensors
- run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
- emulate compressed model execution if True, otherwise use default submodule
"""
def __init__(
@@ -1202,17 +1102,14 @@ def __init__(
ignore: Optional[List[str]] = None,
sparsity_config: Dict[str, Any] = None,
quant_method: str = "compressed-tensors",
- run_compressed: bool = True,
**kwargs,
):
+ from compressed_tensors import QuantizationConfig
from compressed_tensors.config import SparsityCompressionConfig
- from compressed_tensors.quantization import QuantizationConfig
self.quantization_config = None
self.sparsity_config = None
- self.run_compressed = run_compressed
-
# parse from dict to load nested QuantizationScheme objects
if config_groups or kv_cache_scheme:
self.quantization_config = QuantizationConfig.parse_obj(
@@ -1224,7 +1121,6 @@ def __init__(
"kv_cache_scheme": kv_cache_scheme,
"global_compression_ratio": global_compression_ratio,
"ignore": ignore,
- "run_compressed": run_compressed,
**kwargs,
}
)
@@ -1253,7 +1149,6 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
Returns:
[`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
-
"""
if "quantization_config" in config_dict:
@@ -1305,9 +1200,6 @@ def to_diff_dict(self) -> Dict[str, Any]:
return serializable_config_dict
- def get_loading_attributes(self):
- return {"run_compressed": self.run_compressed}
-
@dataclass
class FbgemmFp8Config(QuantizationConfigMixin):
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
deleted file mode 100644
index 03fd51324b022f..00000000000000
--- a/tests/generation/test_candidate_generator.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import unittest
-
-import numpy as np
-
-from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
-
-
-class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
- def test_no_intersection(self):
- prompt = np.array([[1, 2, 3]])
- prompt_plus_new_tokens = np.array([[4, 5, 6]])
- result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
- self.assertEqual(result, (None, None, None))
-
- def test_complete_overlap(self):
- prompt = np.array([[1, 2, 3]])
- prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
- discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
- prompt, prompt_plus_new_tokens
- )
- self.assertEqual(discrep_length, 0)
- np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
- np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
- def test_partial_overlap(self):
- prompt = np.array([[1, 2, 3]])
- prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
- discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
- prompt, prompt_plus_new_tokens
- )
- self.assertEqual(discrep_length, 0)
- np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
- np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
- def test_no_new_tokens(self):
- prompt = np.array([[1, 2, 3]])
- prompt_plus_new_tokens = np.array([[1, 2, 3]])
- discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
- prompt, prompt_plus_new_tokens
- )
- self.assertEqual(discrep_length, 0)
- np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
- np.testing.assert_array_equal(discrep_only, np.array([[]]))
diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py
index be8c37334d02fc..c82a5e99e0ded0 100644
--- a/tests/generation/test_streamers.py
+++ b/tests/generation/test_streamers.py
@@ -17,15 +17,7 @@
from queue import Empty
from threading import Thread
-import pytest
-
-from transformers import (
- AsyncTextIteratorStreamer,
- AutoTokenizer,
- TextIteratorStreamer,
- TextStreamer,
- is_torch_available,
-)
+from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
from ..test_modeling_common import ids_tensor
@@ -128,43 +120,3 @@ def test_iterator_streamer_timeout(self):
streamer_text = ""
for new_text in streamer:
streamer_text += new_text
-
-
-@require_torch
-@pytest.mark.asyncio(loop_scope="class")
-class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
- async def test_async_iterator_streamer_matches_non_streaming(self):
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- model.config.eos_token_id = -1
-
- input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
- greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
- greedy_text = tokenizer.decode(greedy_ids[0])
-
- streamer = AsyncTextIteratorStreamer(tokenizer)
- generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
- thread.start()
- streamer_text = ""
- async for new_text in streamer:
- streamer_text += new_text
-
- self.assertEqual(streamer_text, greedy_text)
-
- async def test_async_iterator_streamer_timeout(self):
- tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
- model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
- model.config.eos_token_id = -1
-
- input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
- streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
- generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
- thread.start()
-
- # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
- with self.assertRaises(TimeoutError):
- streamer_text = ""
- async for new_text in streamer:
- streamer_text += new_text
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 4ac22e77779022..76ab793e3a36c0 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -37,9 +37,6 @@
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
- set_config_for_less_flaky_test,
- set_model_for_less_flaky_test,
- set_model_tester_for_less_flaky_test,
slow,
torch_device,
)
@@ -1205,7 +1202,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
"prophetnet",
"seamlessm4t",
"clvp",
- "fuyu",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1924,13 +1920,11 @@ def test_generate_with_static_cache(self):
Tests that generating with static cache give almost same results as with dynamic cache, and the output cache
has the expected shapes
"""
- set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- set_config_for_less_flaky_test(config)
main_input = inputs_dict[model_class.main_input_name]
if config.is_encoder_decoder:
@@ -1943,8 +1937,6 @@ def test_generate_with_static_cache(self):
for dtype in (torch.float32, torch.float16):
model = model_class(config).to(torch_device).to(dtype).eval()
- set_model_for_less_flaky_test(model)
-
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"return_dict_in_generate": True, # Required to return `past_key_values`
@@ -2320,7 +2312,6 @@ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the
# standard cache format (e.g.gptbigcode )
models_without_standard_cache = (
- "bamba",
"ctrl",
"fsmt",
"gptbigcode",
diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py
index b6f1da56c6782e..d3458530ac349e 100644
--- a/tests/models/aria/test_modeling_aria.py
+++ b/tests/models/aria/test_modeling_aria.py
@@ -45,7 +45,8 @@
if is_torch_available():
import torch
-
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py b/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py
index ff33de487df324..fbe250908633db 100644
--- a/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py
+++ b/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py
@@ -50,7 +50,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values
-class ASTFeatureExtractionTester:
+class ASTFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/auto/test_image_processing_auto.py b/tests/models/auto/test_image_processing_auto.py
index 1becf25ae7c33c..c0046ae1c363cd 100644
--- a/tests/models/auto/test_image_processing_auto.py
+++ b/tests/models/auto/test_image_processing_auto.py
@@ -140,7 +140,6 @@ def test_image_processor_not_found(self):
def test_use_fast_selection(self):
checkpoint = "hf-internal-testing/tiny-random-vit"
- # TODO: @yoni, change in v4.48 (when use_fast set to True by default)
# Slow image processor is selected by default
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
self.assertIsInstance(image_processor, ViTImageProcessor)
diff --git a/tests/models/bamba/__init__.py b/tests/models/bamba/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py
deleted file mode 100644
index 45819e66b73c08..00000000000000
--- a/tests/models/bamba/test_modeling_bamba.py
+++ /dev/null
@@ -1,603 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Testing suite for the PyTorch Bamba model."""
-
-import inspect
-import unittest
-
-import pytest
-from parameterized import parameterized
-
-from transformers import AutoTokenizer, BambaConfig, is_torch_available
-from transformers.testing_utils import (
- require_torch,
- slow,
- torch_device,
-)
-
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
-
-if is_torch_available():
- import torch
-
- from transformers import (
- BambaForCausalLM,
- BambaModel,
- )
- from transformers.models.bamba.modeling_bamba import (
- HybridMambaAttentionDynamicCache,
- )
-
-
-class BambaModelTester:
- def __init__(
- self,
- parent,
- batch_size=13,
- seq_length=7,
- is_training=True,
- use_input_mask=True,
- use_labels=True,
- vocab_size=99,
- hidden_size=32,
- num_hidden_layers=4,
- num_attention_heads=4,
- num_key_value_heads=2,
- intermediate_size=64,
- hidden_act="silu",
- attention_dropout=0.0,
- attn_layer_indices=None,
- attn_rotary_emb=8,
- max_position_embeddings=512,
- type_vocab_size=16,
- initializer_range=0.02,
- num_labels=3,
- pad_token_id=0,
- mamba_n_groups=1,
- mamba_n_heads=16,
- mamba_d_state=16,
- mamba_d_conv=4,
- mamba_expand=2,
- mamba_chunk_size=16,
- scope=None,
- ):
- self.parent = parent
- self.batch_size = batch_size
- self.seq_length = seq_length
- self.is_training = is_training
- self.use_input_mask = use_input_mask
- self.use_labels = use_labels
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.num_key_value_heads = num_key_value_heads
- self.intermediate_size = intermediate_size
- self.hidden_act = hidden_act
- self.attention_dropout = attention_dropout
- self.attn_layer_indices = attn_layer_indices
- self.attn_rotary_emb = attn_rotary_emb
- self.max_position_embeddings = max_position_embeddings
- self.type_vocab_size = type_vocab_size
- self.initializer_range = initializer_range
- self.num_labels = num_labels
- self.pad_token_id = pad_token_id
- self.scope = scope
- self.mamba_n_groups = mamba_n_groups
- self.mamba_n_heads = mamba_n_heads
- self.mamba_d_state = mamba_d_state
- self.mamba_d_conv = mamba_d_conv
- self.mamba_expand = mamba_expand
- self.mamba_chunk_size = mamba_chunk_size
-
- def prepare_config_and_inputs(self):
- input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
- input_mask = None
- if self.use_input_mask:
- input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
- token_labels = None
- if self.use_labels:
- token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-
- config = self.get_config()
-
- return config, input_ids, input_mask, token_labels
-
- def prepare_config_and_inputs_for_common(self):
- config_and_inputs = self.prepare_config_and_inputs()
- (
- config,
- input_ids,
- input_mask,
- token_labels,
- ) = config_and_inputs
- inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
- return config, inputs_dict
-
- def get_config(self):
- # Fix for SDPA tests, force at least 4 layers
- if self.num_hidden_layers < 4:
- self.num_hidden_layers = 4
- if self.attn_layer_indices is None:
- d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
- if len(d) == 0:
- raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.")
- d = d[-1] # get the largest divisor
- self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]
-
- return BambaConfig(
- vocab_size=self.vocab_size,
- hidden_size=self.hidden_size,
- num_hidden_layers=self.num_hidden_layers,
- num_attention_heads=self.num_attention_heads,
- num_key_value_heads=self.num_key_value_heads,
- intermediate_size=self.intermediate_size,
- hidden_act=self.hidden_act,
- attention_dropout=self.attention_dropout,
- attn_layer_indices=self.attn_layer_indices,
- attn_rotary_emb=self.attn_rotary_emb,
- max_position_embeddings=self.max_position_embeddings,
- initializer_range=self.initializer_range,
- pad_token_id=self.pad_token_id,
- mamba_n_groups=self.mamba_n_groups,
- mamba_n_heads=self.mamba_n_heads,
- mamba_d_state=self.mamba_d_state,
- mamba_d_conv=self.mamba_d_conv,
- mamba_expand=self.mamba_expand,
- mamba_chunk_size=self.mamba_chunk_size,
- )
-
- def create_and_check_model(
- self,
- config,
- input_ids,
- input_mask,
- token_labels,
- ):
- model = BambaModel(config=config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask)
- result = model(input_ids)
- self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
- def create_and_check_for_causal_lm(
- self,
- config,
- input_ids,
- input_mask,
- token_labels,
- ):
- model = BambaForCausalLM(config=config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask, labels=token_labels)
- result = model(input_ids, attention_mask=input_mask)
- result = model(input_ids, labels=token_labels)
- result = model(input_ids)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
-
- def create_and_check_decoder_model_past_large_inputs(
- self,
- config,
- input_ids,
- input_mask,
- token_labels,
- ):
- # config.is_decoder = True
- # config.add_cross_attention = True
- model = BambaForCausalLM(config=config)
- model.to(torch_device)
- model.eval()
-
- # first forward pass
- # Attention: Jamba needs the cache to be initialized to return a cache!
- past_key_values = HybridMambaAttentionDynamicCache(
- config, input_ids.shape[0], model.dtype, device=model.device
- )
- outputs = model(
- input_ids,
- attention_mask=input_mask,
- past_key_values=past_key_values,
- use_cache=True,
- )
- past_key_values = outputs.past_key_values
-
- # create hypothetical multiple next token and extent to next_input_ids
- next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
- next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
-
- # append to next input_ids and
- next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
- next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
-
- output_from_no_past = model(
- next_input_ids,
- attention_mask=next_attention_mask,
- output_hidden_states=True,
- )["hidden_states"][0]
- output_from_past = model(
- next_tokens,
- attention_mask=next_attention_mask,
- past_key_values=past_key_values,
- output_hidden_states=True,
- cache_position=torch.arange(
- input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
- ),
- )["hidden_states"][0]
-
- # select random slice
- random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
- output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
- output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
-
- self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
-
- # test that outputs are equal for slice
- self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
-
-
-@require_torch
-class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (
- (
- BambaModel,
- BambaForCausalLM,
- )
- if is_torch_available()
- else ()
- )
- all_generative_model_classes = (BambaForCausalLM,) if is_torch_available() else ()
- pipeline_model_mapping = (
- {
- "feature-extraction": BambaModel,
- "text-generation": BambaForCausalLM,
- }
- if is_torch_available()
- else {}
- )
- test_headmasking = False
- test_pruning = False
- fx_compatible = False
-
- # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
- # This is because we are hitting edge cases with the causal_mask buffer
- model_split_percents = [0.5, 0.7, 0.8]
-
- def setUp(self):
- self.model_tester = BambaModelTester(self)
- self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_model(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_model(*config_and_inputs)
-
- def test_for_casual_lm(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
-
- def test_decoder_model_past_with_large_inputs(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
- def test_initialization(self):
- r"""
- Overriding the test_initialization test as the A_log and D params of the Bamba mixer are initialized differently
- """
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- configs_no_init = _config_zero_init(config)
- for model_class in self.all_model_classes:
- model = model_class(config=configs_no_init)
- for name, param in model.named_parameters():
- if param.requires_grad:
- if "A_log" in name:
- A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)[None, :]
- self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
- elif "D" in name:
- D = torch.ones(config.mamba_n_heads, dtype=torch.float32)
- self.assertTrue(torch.allclose(param.data, D, atol=1e-5, rtol=1e-5))
- else:
- self.assertIn(
- ((param.data.mean() * 1e9).round() / 1e9).item(),
- [0.0, 1.0],
- msg=f"Parameter {name} of model {model_class} seems not properly initialized",
- )
-
- def test_mismatched_shapes_have_properly_initialized_weights(self):
- r"""
- Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
- Bamba mixer are initialized differently and we tested that in test_initialization
- """
- self.skipTest(reason="Cumbersome and redundant for Bamba")
-
- def test_attention_outputs(self):
- r"""
- Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers
- """
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.return_dict = True
-
- seq_len = getattr(self.model_tester, "seq_length", None)
- encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
- encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
-
- expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices)
-
- for model_class in self.all_model_classes:
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = False
- config.return_dict = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), expected_num_attentions)
-
- # check that output_attentions also work using config
- del inputs_dict["output_attentions"]
- config.output_attentions = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
- attentions = outputs.attentions
- self.assertEqual(len(attentions), expected_num_attentions)
-
- self.assertListEqual(
- list(attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
- out_len = len(outputs)
-
- # Check attention is always last and order is fine
- inputs_dict["output_attentions"] = True
- inputs_dict["output_hidden_states"] = True
- model = model_class(config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-
- added_hidden_states = 1
- self.assertEqual(out_len + added_hidden_states, len(outputs))
-
- self_attentions = outputs.attentions
-
- self.assertEqual(len(self_attentions), expected_num_attentions)
- self.assertListEqual(
- list(self_attentions[0].shape[-3:]),
- [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
- )
-
- @unittest.skip(reason="Bamba has its own special cache type")
- @parameterized.expand([(1, False), (1, True), (4, False)])
- def test_new_cache_format(self, num_beams, do_sample):
- pass
-
- def test_batching_equivalence(self):
- # need to disable the tril input mask
- orig = self.model_tester.use_input_mask
- self.model_tester.use_input_mask = False
- super().test_batching_equivalence()
- self.model_tester.use_input_mask = orig
-
- # essentially the same test in test_utils, just adjustment for rtol for this model
- @pytest.mark.generate
- def test_left_padding_compatibility(self):
- # NOTE: left-padding results in small numerical differences. This is expected.
- # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
-
- # First, filter out models that don't support left padding
- # - The model must have generative capabilities
- if len(self.all_generative_model_classes) == 0:
- self.skipTest(reason="No generative architecture available for this model.")
-
- # - The model must support padding
- if not self.has_attentions:
- self.skipTest(reason="This model doesn't support padding.")
-
- # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
- decoder_only_classes = []
- for model_class in self.all_generative_model_classes:
- config, _ = self.prepare_config_and_inputs_for_generate()
- if config.is_encoder_decoder:
- continue
- else:
- decoder_only_classes.append(model_class)
- if len(decoder_only_classes) == 0:
- self.skipTest(reason="No decoder-only architecture available for this model.")
-
- # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
- # added support for it yet. We skip these models for now.
- has_encoder_attributes = any(
- attr_name
- for attr_name in config.to_dict().keys()
- if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
- )
- if has_encoder_attributes:
- self.skipTest(
- reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
- )
-
- # Then, test left-padding
- def _prepare_model_kwargs(input_ids, attention_mask, signature):
- model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
- if "position_ids" in signature:
- position_ids = torch.cumsum(attention_mask, dim=-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- model_kwargs["position_ids"] = position_ids
- if "cache_position" in signature:
- cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
- model_kwargs["cache_position"] = cache_position
- return model_kwargs
-
- for model_class in decoder_only_classes:
- config, inputs_dict = self.prepare_config_and_inputs_for_generate()
- input_ids = inputs_dict["input_ids"]
-
- # - for left padding we absolutely need to use an all ones
- # attention mask, so we do not use the one in inputs_dict
- attention_mask = torch.ones_like(input_ids)
-
- model = model_class(config).to(torch_device).eval()
- signature = inspect.signature(model.forward).parameters.keys()
-
- # no cache as some models require special cache classes to be init outside forward
- model.generation_config.use_cache = False
-
- # Without padding
- model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
- next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
-
- # With left-padding (length 32)
- # can hardcode pad_token to be 0 as we'll do attn masking anyway
- pad_token_id = (
- config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
- )
- pad_size = (input_ids.shape[0], 32)
- padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
- padded_input_ids = torch.cat((padding, input_ids), dim=1)
- padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
- model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
- next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
-
- # They should result in very similar logits
- torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-1)
-
-
-@slow
-@require_torch
-class BambaModelIntegrationTest(unittest.TestCase):
- model = None
- tokenizer = None
- # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
- # Depending on the hardware we get different logits / generations
- cuda_compute_capability_major_version = None
-
- @classmethod
- def setUpClass(cls):
- model_id = "ibm-fms/Bamba-9B"
- cls.model = BambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
- cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
-
- # feels a bit forced to have to do this for the generation test
- cls.tokenizer.pad_token_id = cls.model.config.pad_token_id
- cls.tokenizer.padding_side = "left"
-
- if is_torch_available() and torch.cuda.is_available():
- # 8 is for A100 / A10 and 7 for T4
- cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
- def test_simple_generate(self):
- # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
- #
- # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
- # considering differences in hardware processing and potential deviations in generated text.
- EXPECTED_TEXTS = {
- # 7: "",
- 8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
- # 9: """,
- }
-
- self.model.to(torch_device)
-
- input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
- "input_ids"
- ].to(torch_device)
- out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
- output_sentence = self.tokenizer.decode(out[0, :])
- self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
-
- # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
- if self.cuda_compute_capability_major_version == 8:
- with torch.no_grad():
- logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits
-
- EXPECTED_LOGITS_NO_GRAD = torch.tensor(
- [
- 149., 142., 146., 142., 143., 144., 142., 145.,
- 142., 146., 144., 146., 147., 147., 148., 145.,
- 147., 145., 145., 145., 145., 144., 144., 144.,
- 144., 145., 147., 146., 144., 144., 148., 147.,
- 148., 147., 147., 147., 146., 146., 148., 148.
- ], dtype=torch.bfloat16) # fmt: skip
-
- torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1)
-
- def test_simple_batched_generate_with_padding(self):
- # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
- #
- # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
- # considering differences in hardware processing and potential deviations in generated text.
- EXPECTED_TEXTS = {
- 7: [],
- 8: [
- "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
- "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
- ],
- 9: [],
- }
-
- self.model.to(torch_device)
-
- inputs = self.tokenizer(
- ["Hey how are you doing on this lovely evening?", "I am late! I need to"],
- padding=True,
- return_tensors="pt",
- ).to(torch_device)
- out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
- output_sentences = self.tokenizer.batch_decode(out)
- self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
- self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
-
- # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
- if self.cuda_compute_capability_major_version == 8:
- with torch.no_grad():
- logits = self.model(input_ids=inputs["input_ids"]).logits
-
- EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
- [
- 149., 142., 146., 142., 143., 144., 142., 145.,
- 142., 146., 144., 146., 147., 147., 148., 145.,
- 147., 145., 145., 145., 145., 144., 144., 144.,
- 144., 145., 147., 146., 144., 144., 148., 147.,
- 148., 147., 147., 147., 146., 146., 148., 148.
- ], dtype=torch.bfloat16) # fmt: skip
-
- EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
- [
- 182., 178., 177., 174., 176., 176., 178., 178.,
- 177., 179., 176., 183., 180., 182., 179., 174.,
- 178., 176., 176., 175., 175., 175., 174., 173.,
- 174., 182., 180., 176., 177., 177., 180., 176.,
- 178., 177., 177., 175., 176., 177., 175., 177.
- ], dtype=torch.bfloat16) # fmt: skip
-
- torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1)
- torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1)
diff --git a/tests/models/beit/test_image_processing_beit.py b/tests/models/beit/test_image_processing_beit.py
index 58175c6fe18c02..526a78a563ea36 100644
--- a/tests/models/beit/test_image_processing_beit.py
+++ b/tests/models/beit/test_image_processing_beit.py
@@ -33,7 +33,7 @@
from transformers import BeitImageProcessor
-class BeitImageProcessingTester:
+class BeitImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py
index e54273f7839965..ac64f0fd3b0b11 100644
--- a/tests/models/beit/test_modeling_beit.py
+++ b/tests/models/beit/test_modeling_beit.py
@@ -14,35 +14,18 @@
# limitations under the License.
"""Testing suite for the PyTorch BEiT model."""
-import inspect
-import tempfile
import unittest
-import numpy as np
from datasets import load_dataset
from packaging import version
-from parameterized import parameterized
from transformers import BeitConfig
-from transformers.testing_utils import (
- require_torch,
- require_torch_multi_gpu,
- require_torch_sdpa,
- require_vision,
- slow,
- torch_device,
-)
-from transformers.utils import (
- cached_property,
- is_torch_available,
- is_torch_bf16_available_on_device,
- is_torch_fp16_available_on_device,
- is_vision_available,
-)
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -91,8 +74,6 @@ def __init__(
scope=None,
out_indices=[1, 2, 3, 4],
out_features=["stage1", "stage2", "stage3", "stage4"],
- attn_implementation="eager",
- mask_ratio=0.5,
):
self.parent = parent
self.vocab_size = vocab_size
@@ -119,8 +100,6 @@ def __init__(
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
- self.num_masks = int(mask_ratio * self.seq_length)
- self.attn_implementation = attn_implementation
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -152,7 +131,6 @@ def get_config(self):
initializer_range=self.initializer_range,
out_indices=self.out_indices,
out_features=self.out_features,
- attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
@@ -409,193 +387,6 @@ def test_model_from_pretrained(self):
model = BeitModel.from_pretrained(model_name)
self.assertIsNotNone(model)
- @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
- @require_torch_sdpa
- def test_eager_matches_sdpa_inference(self, torch_dtype: str):
- # The common test modifies the num_hidden_layers to be 1. However, for Beit we want to
- # avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code
- # related to attention masks in the original common tests is not required as the Beit
- # model does not handle attention masks. Furthermore, some extra code like modifying
- # the norm layers eps values for specialized configs and checking for the 'noise'
- # has been omitted to simply the test.
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- if not self.all_model_classes[0]._supports_sdpa:
- self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
-
- if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
- self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
-
- if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
- self.skipTest(
- f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
- )
-
- # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
- if torch_dtype == "float16":
- torch_dtype = torch.float16
- elif torch_dtype == "bfloat16":
- torch_dtype = torch.bfloat16
- elif torch_dtype == "float32":
- torch_dtype = torch.float32
-
- atols = {
- ("cpu", False, torch.float32): 1e-6,
- ("cpu", False, torch.float16): 5e-3,
- ("cpu", False, torch.bfloat16): 1e-2,
- ("cpu", True, torch.float32): 1e-6,
- ("cpu", True, torch.float16): 5e-3,
- ("cpu", True, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float32): 1e-6,
- ("cuda", False, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float16): 5e-3,
- ("cuda", True, torch.float32): 1e-6,
- ("cuda", True, torch.bfloat16): 1e-2,
- ("cuda", True, torch.float16): 5e-3,
- }
- rtols = {
- ("cpu", False, torch.float32): 1e-4,
- ("cpu", False, torch.float16): 5e-3,
- ("cpu", False, torch.bfloat16): 1e-2,
- ("cpu", True, torch.float32): 1e-4,
- ("cpu", True, torch.float16): 5e-3,
- ("cpu", True, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float32): 1e-4,
- ("cuda", False, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float16): 5e-3,
- ("cuda", True, torch.float32): 1e-4,
- ("cuda", True, torch.bfloat16): 3e-2,
- ("cuda", True, torch.float16): 5e-3,
- }
-
- def get_mean_reldiff(failcase, x, ref, atol, rtol):
- return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
-
- for model_class in self.all_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- config.rms_norm_eps = 1.0
- config.layer_norm_eps = 1.0
- config.norm_eps = 1.0
- config.norm_epsilon = 1.0
- config.layer_norm_epsilon = 1.0
-
- model = model_class(config)
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True)
- model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
-
- model_eager = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch_dtype,
- attn_implementation="eager",
- use_mask_token=True,
- )
- model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
-
- # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
- for x in model_eager.modules():
- if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
- x.eps = 1.0
- for x in model_sdpa.modules():
- if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
- x.eps = 1.0
-
- # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
- # but it would be nicer to have an efficient way to use parameterized.expand
- fail_cases = []
- for padding_side in ["left", "right"]:
- for use_mask in [False, True]:
- for output_attentions in [True, False]:
- can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
- if not (self.has_attentions and can_output_attn) and output_attentions:
- continue
- # TODO: if we can also check with `batch_size=1` without being flaky?
- for batch_size in [7]:
- dummy_input = inputs_dict[model.main_input_name]
-
- if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
- dummy_input = dummy_input.to(torch_dtype)
-
- dummy_input = dummy_input[:batch_size]
- for enable_kernels in [False, True]:
- failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
- processed_inputs = {
- model.main_input_name: dummy_input,
- "output_hidden_states": True,
- }
-
- if (
- self.has_attentions
- and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
- ):
- processed_inputs["output_attentions"] = output_attentions
-
- if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
- dummy_mask = torch.ones((self.model_tester.num_masks,))
- mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0)
- dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
- dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
- processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
-
- with torch.no_grad():
- with sdpa_kernel(
- enable_flash=enable_kernels,
- enable_math=True,
- enable_mem_efficient=enable_kernels,
- ):
- prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
- outputs_eager = model_eager(**prepared_inputs)
- outputs_sdpa = model_sdpa(**prepared_inputs)
-
- logits_eager = outputs_eager.hidden_states[-1]
- logits_sdpa = outputs_sdpa.hidden_states[-1]
- if torch_device in ["cpu", "cuda"]:
- atol = atols[torch_device, enable_kernels, torch_dtype]
- rtol = rtols[torch_device, enable_kernels, torch_dtype]
- elif torch_device == "xpu":
- # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
- # which is implemented on PyTorch level using aten operators and is
- # device agnostic with respect to implementation of each aten operator.
- atol = atols["cuda", False, torch_dtype]
- rtol = rtols["cuda", False, torch_dtype]
- else:
- atol = 1e-7
- rtol = 1e-4
-
- # Masked tokens output slightly deviates - we don't mind that.
- if use_mask:
- _logits_sdpa = torch.zeros_like(input=logits_sdpa)
- _logits_eager = torch.zeros_like(input=logits_eager)
-
- _logits_sdpa[:-1] = logits_sdpa[:-1]
- _logits_eager[:-1] = logits_eager[:-1]
-
- if padding_side == "left":
- _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
- _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
-
- elif padding_side == "right":
- _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
- _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
-
- logits_sdpa = _logits_sdpa
- logits_eager = _logits_eager
-
- results = [
- torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
- for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
- ]
- # If 80% batch elements have matched results, it's fine
- if np.mean(results) < 0.8:
- fail_cases.append(
- get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
- )
-
- self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
-
# We will verify our results on an image of cute cats
def prepare_img():
diff --git a/tests/models/clap/test_feature_extraction_clap.py b/tests/models/clap/test_feature_extraction_clap.py
index 0d6c00b79ddec4..d0e913df828b84 100644
--- a/tests/models/clap/test_feature_extraction_clap.py
+++ b/tests/models/clap/test_feature_extraction_clap.py
@@ -53,7 +53,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
@require_torchaudio
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester with Whisper->Clap
-class ClapFeatureExtractionTester:
+class ClapFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/clip/test_image_processing_clip.py b/tests/models/clip/test_image_processing_clip.py
index ef4fdc819b2c4e..740399d13fbb11 100644
--- a/tests/models/clip/test_image_processing_clip.py
+++ b/tests/models/clip/test_image_processing_clip.py
@@ -26,7 +26,7 @@
from transformers import CLIPImageProcessor
-class CLIPImageProcessingTester:
+class CLIPImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/clvp/test_feature_extraction_clvp.py b/tests/models/clvp/test_feature_extraction_clvp.py
index b57cb65ebb210d..1f059ca46944e1 100644
--- a/tests/models/clvp/test_feature_extraction_clvp.py
+++ b/tests/models/clvp/test_feature_extraction_clvp.py
@@ -57,7 +57,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
-class ClvpFeatureExtractionTester:
+class ClvpFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/cohere/test_modeling_cohere.py b/tests/models/cohere/test_modeling_cohere.py
index d02dee553b4668..cd3b2f978e7ab7 100644
--- a/tests/models/cohere/test_modeling_cohere.py
+++ b/tests/models/cohere/test_modeling_cohere.py
@@ -40,11 +40,6 @@
# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
- config_class = CohereConfig
- if is_torch_available():
- model_class = CohereModel
- for_causal_lm_class = CohereForCausalLM
-
def __init__(
self,
parent,
@@ -56,7 +51,7 @@ def __init__(
use_labels=True,
vocab_size=99,
hidden_size=32,
- num_hidden_layers=4,
+ num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
@@ -120,7 +115,7 @@ def prepare_config_and_inputs(self):
# Ignore copy
def get_config(self):
- return self.config_class(
+ return CohereConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -134,12 +129,13 @@ def get_config(self):
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
+ eos_token_id=self.pad_token_id,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
- model = self.model_class(config=config)
+ model = CohereModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
@@ -159,7 +155,7 @@ def create_and_check_model_as_decoder(
encoder_attention_mask,
):
config.add_cross_attention = True
- model = self.model_class(config)
+ model = CohereModel(config)
model.to(torch_device)
model.eval()
result = model(
@@ -188,7 +184,7 @@ def create_and_check_for_causal_lm(
encoder_hidden_states,
encoder_attention_mask,
):
- model = self.for_causal_lm_class(config=config)
+ model = CohereForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -208,7 +204,7 @@ def create_and_check_decoder_model_past_large_inputs(
):
config.is_decoder = True
config.add_cross_attention = True
- model = self.for_causal_lm_class(config=config)
+ model = CohereForCausalLM(config=config)
model.to(torch_device)
model.eval()
@@ -285,7 +281,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
)
test_headmasking = False
test_pruning = False
- fx_compatible = False
+ fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
diff --git a/tests/models/cohere2/__init__.py b/tests/models/cohere2/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py
deleted file mode 100644
index 8e1a4834d1ed41..00000000000000
--- a/tests/models/cohere2/test_modeling_cohere2.py
+++ /dev/null
@@ -1,347 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Testing suite for the PyTorch Cohere2 model."""
-
-import unittest
-
-from packaging import version
-from parameterized import parameterized
-from pytest import mark
-
-from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, HybridCache, is_torch_available, pipeline
-from transformers.generation.configuration_utils import GenerationConfig
-from transformers.testing_utils import (
- require_flash_attn,
- require_read_token,
- require_torch,
- require_torch_gpu,
- slow,
- torch_device,
-)
-
-from ...models.cohere.test_modeling_cohere import CohereModelTest, CohereModelTester
-from ...test_configuration_common import ConfigTester
-
-
-if is_torch_available():
- import torch
-
- from transformers import (
- Cohere2ForCausalLM,
- Cohere2Model,
- )
-
-
-class Cohere2ModelTester(CohereModelTester):
- config_class = Cohere2Config
- if is_torch_available():
- model_class = Cohere2Model
- for_causal_lm_class = Cohere2ForCausalLM
-
-
-@require_torch
-class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
- all_model_classes = (Cohere2Model, Cohere2ForCausalLM) if is_torch_available() else ()
- all_generative_model_classes = (Cohere2ForCausalLM,) if is_torch_available() else ()
- pipeline_model_mapping = (
- {
- "feature-extraction": Cohere2Model,
- "text-generation": Cohere2ForCausalLM,
- }
- if is_torch_available()
- else {}
- )
- _is_stateful = True
-
- def setUp(self):
- self.model_tester = Cohere2ModelTester(self)
- self.config_tester = ConfigTester(self, config_class=Cohere2Config, hidden_size=37)
-
- @unittest.skip("Failing because of unique cache (HybridCache)")
- def test_model_outputs_equivalence(self, **kwargs):
- pass
-
- @unittest.skip("Cohere2's forcefully disables sdpa due to softcapping")
- def test_sdpa_can_dispatch_non_composite_models(self):
- pass
-
- @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
- @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
- def test_eager_matches_sdpa_inference(self):
- pass
-
- @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
- def test_eager_matches_sdpa_generate(self):
- pass
-
- @parameterized.expand([("random",), ("same",)])
- @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
- def test_assisted_decoding_matches_greedy_search(self, assistant_type):
- pass
-
- @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
- def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
- pass
-
- @unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
- def test_assisted_decoding_sample(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache which is not compatible with dola decoding")
- def test_dola_decoding_sample(self):
- pass
-
- @parameterized.expand([(1, False), (1, True), (4, False)])
- @unittest.skip("Cohere2 has HybridCache and doesn't support old tuple format at all")
- def test_new_cache_format(self, num_beams, do_sample):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
- def test_generate_continue_from_past_key_values(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation")
- def test_beam_search_low_memory(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
- def test_contrastive_generate(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
- def test_contrastive_generate_dict_outputs_use_cache(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
- def test_contrastive_generate_low_memory(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
- def test_generate_with_static_cache(self):
- pass
-
- @unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
- def test_generate_from_inputs_embeds_with_static_cache(self):
- pass
-
- # overwrite because HybridCache has fixed length for key/values
- def _check_attentions_for_generate(
- self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
- ):
- self.assertIsInstance(attentions, tuple)
- self.assertListEqual(
- [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
- )
- self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
-
- for idx, iter_attentions in enumerate(attentions):
- tgt_len = min_length + idx if not use_cache else 1
- src_len = min_length + idx if not use_cache else max_length
-
- expected_shape = (
- batch_size * num_beam_groups,
- config.num_attention_heads,
- tgt_len,
- src_len,
- )
- # check attn size
- self.assertListEqual(
- [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
- )
-
- # overwrite because HybridCache has fixed length for key/values
- def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
- self.assertIsInstance(past_key_values, HybridCache)
-
- # check shape key, value (batch, head, max_seq_length, head_features)
- head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- num_key_value_heads = (
- config.num_attention_heads
- if getattr(config, "num_key_value_heads", None) is None
- else config.num_key_value_heads
- )
- num_hidden_layers = config.num_hidden_layers
-
- # we should get `max_length` in shape, not `max_length - embeds_length`
- # `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
- static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
- static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
- self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
- self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
-
- @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
- def test_sdpa_equivalence(self):
- pass
-
-
-@slow
-@require_torch_gpu
-class Cohere2IntegrationTest(unittest.TestCase):
- input_text = ["Hello I am doing", "Hi today"]
- # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
- # Depending on the hardware we get different logits / generations
- cuda_compute_capability_major_version = None
-
- @classmethod
- def setUpClass(cls):
- if is_torch_available() and torch.cuda.is_available():
- # 8 is for A100 / A10 and 7 for T4
- cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
-
- @require_read_token
- @unittest.skip("Cohere2 has not been released yet")
- def test_model_bf16(self):
- model_id = "CohereForAI/command-r7b-12-2024"
- EXPECTED_TEXTS = [
- "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
- "Hi today I'm going to be talking about the history of the United States. The United States of America",
- ]
-
- model = AutoModelForCausalLM.from_pretrained(
- model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
- ).to(torch_device)
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
-
- output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
-
- self.assertEqual(output_text, EXPECTED_TEXTS)
-
- @require_read_token
- @unittest.skip("Cohere2 has not been released yet")
- def test_model_fp16(self):
- model_id = "CohereForAI/command-r7b-12-2024"
- EXPECTED_TEXTS = [
- "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
- "Hi today I'm going to be talking about the history of the United States. The United States of America",
- ]
-
- model = AutoModelForCausalLM.from_pretrained(
- model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, attn_implementation="eager"
- ).to(torch_device)
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
-
- output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
- output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
-
- self.assertEqual(output_text, EXPECTED_TEXTS)
-
- @require_read_token
- @unittest.skip("Cohere2 has not been released yet")
- def test_model_pipeline_bf16(self):
- # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Cohere2 before this PR
- model_id = "CohereForAI/command-r7b-12-2024"
- # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
- EXPECTED_TEXTS = [
- "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
- "Hi today I'm going to be talking about the history of the United States. The United States of America",
- ]
-
- model = AutoModelForCausalLM.from_pretrained(
- model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
- ).to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
- output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
-
- self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
- self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
-
- @require_read_token
- @require_flash_attn
- @require_torch_gpu
- @mark.flash_attn_test
- @slow
- @unittest.skip("Cohere2 has not been released yet")
- def test_model_flash_attn(self):
- # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for Gemma2, especially in long context
- model_id = "CohereForAI/command-r7b-12-2024"
- EXPECTED_TEXTS = [
- 'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
- "Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
- ] # fmt: skip
-
- model = AutoModelForCausalLM.from_pretrained(
- model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
- ).to(torch_device)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
-
- output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
- output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
-
- self.assertEqual(output_text, EXPECTED_TEXTS)
-
- @slow
- @require_read_token
- @unittest.skip("Cohere2 has not been released yet")
- def test_export_static_cache(self):
- if version.parse(torch.__version__) < version.parse("2.5.0"):
- self.skipTest(reason="This test requires torch >= 2.5 to run.")
-
- from transformers.integrations.executorch import (
- TorchExportableModuleWithStaticCache,
- convert_and_export_with_cache,
- )
-
- tokenizer = AutoTokenizer.from_pretrained(
- "CohereForAI/command-r7b-12-2024", pad_token="", padding_side="right"
- )
- EXPECTED_TEXT_COMPLETION = [
- "Hello I am doing a project for my school and I need to know how to make a program that will take a number",
- ]
- max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
- "input_ids"
- ].shape[-1]
-
- # Load model
- device = "cpu"
- dtype = torch.bfloat16
- cache_implementation = "static"
- attn_implementation = "sdpa"
- batch_size = 1
- model = AutoModelForCausalLM.from_pretrained(
- "CohereForAI/command-r7b-12-2024",
- device_map=device,
- torch_dtype=dtype,
- attn_implementation=attn_implementation,
- generation_config=GenerationConfig(
- use_cache=True,
- cache_implementation=cache_implementation,
- max_length=max_generation_length,
- cache_config={
- "batch_size": batch_size,
- "max_cache_len": max_generation_length,
- },
- ),
- )
-
- prompts = ["Hello I am doing"]
- prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
- prompt_token_ids = prompt_tokens["input_ids"]
- max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
-
- # Static Cache + export
- exported_program = convert_and_export_with_cache(model)
- ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
- exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
- )
- ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
- self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
diff --git a/tests/models/colpali/__init__.py b/tests/models/colpali/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py
deleted file mode 100644
index 646726ac700ee5..00000000000000
--- a/tests/models/colpali/test_modeling_colpali.py
+++ /dev/null
@@ -1,368 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Testing suite for the PyTorch ColPali model."""
-
-import gc
-import unittest
-from typing import ClassVar
-
-import torch
-from datasets import load_dataset
-from parameterized import parameterized
-
-from tests.test_configuration_common import ConfigTester
-from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
-from transformers import (
- is_torch_available,
- is_vision_available,
-)
-from transformers.models.colpali.configuration_colpali import ColPaliConfig
-from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
-from transformers.models.colpali.processing_colpali import ColPaliProcessor
-from transformers.testing_utils import (
- require_torch,
- require_torch_sdpa,
- require_vision,
- slow,
- torch_device,
-)
-
-
-if is_torch_available():
- import torch
-
-if is_vision_available():
- pass
-
-
-class ColPaliForRetrievalModelTester:
- def __init__(
- self,
- parent,
- ignore_index=-100,
- image_token_index=0,
- projector_hidden_act="gelu",
- seq_length=25,
- vision_feature_select_strategy="default",
- vision_feature_layer=-1,
- projection_dim=32,
- text_config={
- "model_type": "gemma",
- "seq_length": 128,
- "is_training": True,
- "use_token_type_ids": False,
- "use_labels": True,
- "vocab_size": 99,
- "hidden_size": 32,
- "num_hidden_layers": 2,
- "num_attention_heads": 4,
- "num_key_value_heads": 1,
- "head_dim": 8,
- "intermediate_size": 37,
- "hidden_activation": "gelu_pytorch_tanh",
- "hidden_dropout_prob": 0.1,
- "attention_probs_dropout_prob": 0.1,
- "max_position_embeddings": 512,
- "type_vocab_size": 16,
- "type_sequence_label_size": 2,
- "initializer_range": 0.02,
- "num_labels": 3,
- "num_choices": 4,
- "pad_token_id": 1,
- },
- is_training=False,
- vision_config={
- "use_labels": True,
- "image_size": 20,
- "patch_size": 5,
- "num_image_tokens": 4,
- "num_channels": 3,
- "is_training": True,
- "hidden_size": 32,
- "projection_dim": 32,
- "num_key_value_heads": 1,
- "num_hidden_layers": 2,
- "num_attention_heads": 4,
- "intermediate_size": 37,
- "dropout": 0.1,
- "attention_dropout": 0.1,
- "initializer_range": 0.02,
- },
- use_cache=False,
- embedding_dim=128,
- ):
- self.parent = parent
- self.ignore_index = ignore_index
- # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
- self.image_token_index = image_token_index
- self.projector_hidden_act = projector_hidden_act
- self.vision_feature_select_strategy = vision_feature_select_strategy
- self.vision_feature_layer = vision_feature_layer
- self.text_config = text_config
- self.vision_config = vision_config
- self.seq_length = seq_length
- self.projection_dim = projection_dim
- self.pad_token_id = text_config["pad_token_id"]
-
- self.num_hidden_layers = text_config["num_hidden_layers"]
- self.vocab_size = text_config["vocab_size"]
- self.hidden_size = text_config["hidden_size"]
- self.num_attention_heads = text_config["num_attention_heads"]
- self.is_training = is_training
-
- self.batch_size = 3
- self.num_channels = vision_config["num_channels"]
- self.image_size = vision_config["image_size"]
- self.encoder_seq_length = seq_length
- self.use_cache = use_cache
-
- self.embedding_dim = embedding_dim
- self.vlm_config = {
- "model_type": "paligemma",
- "text_config": self.text_config,
- "vision_config": self.vision_config,
- "ignore_index": self.ignore_index,
- "image_token_index": self.image_token_index,
- "projector_hidden_act": self.projector_hidden_act,
- "projection_dim": self.projection_dim,
- "vision_feature_select_strategy": self.vision_feature_select_strategy,
- "vision_feature_layer": self.vision_feature_layer,
- }
-
- def get_config(self):
- return ColPaliConfig(
- vlm_config=self.vlm_config,
- embedding_dim=self.embedding_dim,
- )
-
- def prepare_config_and_inputs(self):
- pixel_values = floats_tensor(
- [
- self.batch_size,
- self.vision_config["num_channels"],
- self.vision_config["image_size"],
- self.vision_config["image_size"],
- ]
- )
- config = self.get_config()
-
- return config, pixel_values
-
- def prepare_config_and_inputs_for_common(self):
- config_and_inputs = self.prepare_config_and_inputs()
- config, pixel_values = config_and_inputs
- input_ids = ids_tensor([self.batch_size, self.seq_length], config.vlm_config.text_config.vocab_size - 1) + 1
- attention_mask = input_ids.ne(1).to(torch_device)
- # set the 16 first tokens to be image, and ensure that no other tokens are image tokens
- # do not change this unless you modified image size or patch size
- input_ids[input_ids == config.vlm_config.image_token_index] = self.pad_token_id
- input_ids[:, :16] = config.vlm_config.image_token_index
- inputs_dict = {
- "pixel_values": pixel_values,
- "input_ids": input_ids,
- "attention_mask": attention_mask,
- "labels": input_ids,
- "token_type_ids": torch.zeros_like(input_ids),
- }
- return config, inputs_dict
-
-
-@require_torch
-class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
- """
- Model tester for `ColPaliForRetrieval`.
- """
-
- all_model_classes = (ColPaliForRetrieval,) if is_torch_available() else ()
- fx_compatible = False
- test_torchscript = False
- test_pruning = False
- test_resize_embeddings = True
- test_head_masking = False
-
- def setUp(self):
- self.model_tester = ColPaliForRetrievalModelTester(self)
- self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False)
-
- # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
-
- def test_inputs_embeds(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- inputs = self._prepare_for_class(inputs_dict, model_class)
-
- input_ids = inputs["input_ids"]
- del inputs["input_ids"]
- del inputs["pixel_values"]
-
- wte = model.get_input_embeddings()
- inputs["inputs_embeds"] = wte(input_ids)
-
- with torch.no_grad():
- model(**inputs)
-
- # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
- # while some other models require pixel_values to be present
- def test_inputs_embeds_matches_input_ids(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- inputs = self._prepare_for_class(inputs_dict, model_class)
- input_ids = inputs["input_ids"]
- del inputs["input_ids"]
- del inputs["pixel_values"]
-
- inputs_embeds = model.get_input_embeddings()(input_ids)
-
- with torch.no_grad():
- out_ids = model(input_ids=input_ids, **inputs)[0]
- out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
- self.assertTrue(torch.allclose(out_embeds, out_ids))
-
- @slow
- @require_vision
- def test_colpali_forward_inputs(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- inputs = self._prepare_for_class(inputs_dict, model_class)
-
- with torch.no_grad():
- outputs = model(**inputs, return_dict=True)
-
- self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
-
- @unittest.skip(
- reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
- )
- def test_training_gradient_checkpointing(self):
- pass
-
- @unittest.skip(
- reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
- )
- def test_training_gradient_checkpointing_use_reentrant(self):
- pass
-
- @unittest.skip(
- reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
- )
- def test_training_gradient_checkpointing_use_reentrant_false(self):
- pass
-
- @require_torch_sdpa
- @slow
- @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
- def test_eager_matches_sdpa_inference(self, torch_dtype: str):
- self.skipTest(
- "Due to custom causal mask, there is a slightly too big difference between eager and sdpa in bfloat16."
- )
-
- @unittest.skip(
- reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
- )
- def test_model_parallelism(self):
- pass
-
- @unittest.skip(
- reason="PaliGemmma's SigLip encoder uses the same initialization scheme as the Flax original implementation"
- )
- def test_initialization(self):
- pass
-
- # TODO extend valid outputs to include this test @Molbap
- @unittest.skip(reason="PaliGemma has currently one output format.")
- def test_model_outputs_equivalence(self):
- pass
-
- @unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
- def test_sdpa_can_dispatch_on_flash(self):
- pass
-
- @unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
- def test_sdpa_can_compile_dynamic(self):
- pass
-
-
-@require_torch
-class ColPaliModelIntegrationTest(unittest.TestCase):
- model_name: ClassVar[str] = "vidore/colpali-v1.2-hf"
-
- def setUp(self):
- self.processor = ColPaliProcessor.from_pretrained(self.model_name)
-
- def tearDown(self):
- gc.collect()
- torch.cuda.empty_cache()
-
- @slow
- def test_model_integration_test(self):
- """
- Test if the model is able to retrieve the correct pages for a small and easy dataset.
- """
- model = ColPaliForRetrieval.from_pretrained(
- self.model_name,
- torch_dtype=torch.bfloat16,
- device_map=torch_device,
- ).eval()
-
- # Load the test dataset
- ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
-
- # Preprocess the examples
- batch_images = self.processor(images=ds["image"]).to(torch_device)
- batch_queries = self.processor(text=ds["query"]).to(torch_device)
-
- # Run inference
- with torch.inference_mode():
- image_embeddings = model(**batch_images).embeddings
- query_embeddings = model(**batch_queries).embeddings
-
- # Compute retrieval scores
- scores = self.processor.score_retrieval(
- query_embeddings=query_embeddings,
- passage_embeddings=image_embeddings,
- ) # (len(qs), len(ps))
-
- assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
- assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
-
- # Check if the maximum scores per row are in the diagonal of the matrix score
- self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
-
- # Further validation: fine-grained check, with a hardcoded score from the original implementation
- expected_scores = torch.tensor(
- [
- [15.5625, 6.5938, 14.4375],
- [12.2500, 16.2500, 11.0000],
- [15.0625, 11.7500, 21.0000],
- ],
- dtype=scores.dtype,
- )
-
- assert torch.allclose(scores, expected_scores, atol=1), f"Expected scores {expected_scores}, got {scores}"
diff --git a/tests/models/colpali/test_processing_colpali.py b/tests/models/colpali/test_processing_colpali.py
deleted file mode 100644
index 42592460fa28ed..00000000000000
--- a/tests/models/colpali/test_processing_colpali.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import shutil
-import tempfile
-import unittest
-
-import torch
-
-from transformers import GemmaTokenizer
-from transformers.models.colpali.processing_colpali import ColPaliProcessor
-from transformers.testing_utils import get_tests_dir, require_torch, require_vision
-from transformers.utils import is_vision_available
-from transformers.utils.dummy_vision_objects import SiglipImageProcessor
-
-from ...test_processing_common import ProcessorTesterMixin
-
-
-if is_vision_available():
- from transformers import (
- ColPaliProcessor,
- PaliGemmaProcessor,
- SiglipImageProcessor,
- )
-
-SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
-
-
-@require_vision
-class ColPaliProcessorTest(ProcessorTesterMixin, unittest.TestCase):
- processor_class = ColPaliProcessor
-
- def setUp(self):
- self.tmpdirname = tempfile.mkdtemp()
- image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
- image_processor.image_seq_length = 0
- tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
- processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
- processor.save_pretrained(self.tmpdirname)
-
- def tearDown(self):
- shutil.rmtree(self.tmpdirname)
-
- @require_torch
- @require_vision
- def test_process_images(self):
- # Processor configuration
- image_input = self.prepare_image_inputs()
- image_processor = self.get_component("image_processor")
- tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
- image_processor.image_seq_length = 14
-
- # Get the processor
- processor = self.processor_class(
- tokenizer=tokenizer,
- image_processor=image_processor,
- )
-
- # Process the image
- batch_feature = processor.process_images(images=image_input, return_tensors="pt")
-
- # Assertions
- self.assertIn("pixel_values", batch_feature)
- self.assertEqual(batch_feature["pixel_values"].shape, torch.Size([1, 3, 384, 384]))
-
- @require_torch
- @require_vision
- def test_process_queries(self):
- # Inputs
- queries = [
- "Is attention really all you need?",
- "Are Benjamin, Antoine, Merve, and Jo best friends?",
- ]
-
- # Processor configuration
- image_processor = self.get_component("image_processor")
- tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
- image_processor.image_seq_length = 14
-
- # Get the processor
- processor = self.processor_class(
- tokenizer=tokenizer,
- image_processor=image_processor,
- )
-
- # Process the image
- batch_feature = processor.process_queries(text=queries, return_tensors="pt")
-
- # Assertions
- self.assertIn("input_ids", batch_feature)
- self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
- self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))
-
- # The following tests are overwritten as ColPaliProcessor can only take one of images or text as input at a time
-
- def test_tokenizer_defaults_preserved_by_kwargs(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
-
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
- input_str = self.prepare_text_inputs()
- inputs = processor(text=input_str, return_tensors="pt")
- self.assertEqual(inputs[self.text_input_name].shape[-1], 117)
-
- def test_image_processor_defaults_preserved_by_image_kwargs(self):
- """
- We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor.
- We then check that the mean of the pixel_values is less than or equal to 0 after processing.
- Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied.
- """
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor_components["image_processor"] = self.get_component(
- "image_processor", do_rescale=True, rescale_factor=-1
- )
- processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
-
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
-
- image_input = self.prepare_image_inputs()
-
- inputs = processor(images=image_input, return_tensors="pt")
- self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
-
- def test_kwargs_overrides_default_tokenizer_kwargs(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest")
-
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
- input_str = self.prepare_text_inputs()
- inputs = processor(text=input_str, return_tensors="pt", max_length=112, padding="max_length")
- self.assertEqual(inputs[self.text_input_name].shape[-1], 112)
-
- def test_kwargs_overrides_default_image_processor_kwargs(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor_components["image_processor"] = self.get_component(
- "image_processor", do_rescale=True, rescale_factor=1
- )
- processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
-
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
-
- image_input = self.prepare_image_inputs()
-
- inputs = processor(images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
- self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
-
- def test_unstructured_kwargs(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
-
- input_str = self.prepare_text_inputs()
- inputs = processor(
- text=input_str,
- return_tensors="pt",
- do_rescale=True,
- rescale_factor=-1,
- padding="max_length",
- max_length=76,
- )
-
- self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
-
- def test_unstructured_kwargs_batched(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
-
- image_input = self.prepare_image_inputs(batch_size=2)
- inputs = processor(
- images=image_input,
- return_tensors="pt",
- do_rescale=True,
- rescale_factor=-1,
- padding="longest",
- max_length=76,
- )
-
- self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
-
- def test_doubly_passed_kwargs(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
-
- image_input = self.prepare_image_inputs()
- with self.assertRaises(ValueError):
- _ = processor(
- images=image_input,
- images_kwargs={"do_rescale": True, "rescale_factor": -1},
- do_rescale=True,
- return_tensors="pt",
- )
-
- def test_structured_kwargs_nested(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
-
- input_str = self.prepare_text_inputs()
-
- # Define the kwargs for each modality
- all_kwargs = {
- "common_kwargs": {"return_tensors": "pt"},
- "images_kwargs": {"do_rescale": True, "rescale_factor": -1},
- "text_kwargs": {"padding": "max_length", "max_length": 76},
- }
-
- inputs = processor(text=input_str, **all_kwargs)
- self.skip_processor_without_typed_kwargs(processor)
-
- self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
-
- def test_structured_kwargs_nested_from_dict(self):
- if "image_processor" not in self.processor_class.attributes:
- self.skipTest(f"image_processor attribute not present in {self.processor_class}")
- processor_components = self.prepare_components()
- processor = self.processor_class(**processor_components)
- self.skip_processor_without_typed_kwargs(processor)
- image_input = self.prepare_image_inputs()
-
- # Define the kwargs for each modality
- all_kwargs = {
- "common_kwargs": {"return_tensors": "pt"},
- "images_kwargs": {"do_rescale": True, "rescale_factor": -1},
- "text_kwargs": {"padding": "max_length", "max_length": 76},
- }
-
- inputs = processor(images=image_input, **all_kwargs)
- self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
diff --git a/tests/models/conditional_detr/test_image_processing_conditional_detr.py b/tests/models/conditional_detr/test_image_processing_conditional_detr.py
index 4e46161a7bd0fa..32b135bcd220bd 100644
--- a/tests/models/conditional_detr/test_image_processing_conditional_detr.py
+++ b/tests/models/conditional_detr/test_image_processing_conditional_detr.py
@@ -35,7 +35,7 @@
from transformers import ConditionalDetrImageProcessor
-class ConditionalDetrImageProcessingTester:
+class ConditionalDetrImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/dac/test_feature_extraction_dac.py b/tests/models/dac/test_feature_extraction_dac.py
index 598a7c725eccb2..019a4f07c6abcb 100644
--- a/tests/models/dac/test_feature_extraction_dac.py
+++ b/tests/models/dac/test_feature_extraction_dac.py
@@ -51,7 +51,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
# Copied from transformers.tests.encodec.test_feature_extraction_dac.EncodecFeatureExtractionTester with Encodec->Dac
-class DacFeatureExtractionTester:
+class DacFeatureExtractionTester(unittest.TestCase):
# Ignore copy
def __init__(
self,
diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py
index 02276d905fa402..c729d88d614fbc 100644
--- a/tests/models/data2vec/test_modeling_data2vec_vision.py
+++ b/tests/models/data2vec/test_modeling_data2vec_vision.py
@@ -14,32 +14,14 @@
# limitations under the License.
"""Testing suite for the PyTorch Data2VecVision model."""
-import inspect
-import tempfile
import unittest
-import numpy as np
-from parameterized import parameterized
-
from transformers import Data2VecVisionConfig
-from transformers.testing_utils import (
- require_torch,
- require_torch_multi_gpu,
- require_torch_sdpa,
- require_vision,
- slow,
- torch_device,
-)
-from transformers.utils import (
- cached_property,
- is_torch_available,
- is_torch_bf16_available_on_device,
- is_torch_fp16_available_on_device,
- is_vision_available,
-)
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -84,8 +66,6 @@ def __init__(
num_labels=3,
scope=None,
out_indices=[0, 1, 2, 3],
- attn_implementation="eager",
- mask_ratio=0.5,
):
self.parent = parent
self.vocab_size = 100
@@ -111,8 +91,6 @@ def __init__(
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
- self.num_masks = int(mask_ratio * self.seq_length)
- self.attn_implementation = attn_implementation
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -143,7 +121,6 @@ def get_config(self):
is_decoder=False,
initializer_range=self.initializer_range,
out_indices=self.out_indices,
- attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
@@ -323,194 +300,6 @@ def test_model_from_pretrained(self):
model = Data2VecVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
- @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
- @require_torch_sdpa
- # Copied from tests.models.beit.test_modeling_beit.BeitModelTest.test_eager_matches_sdpa_inference with Beit->Data2VecVision
- def test_eager_matches_sdpa_inference(self, torch_dtype: str):
- # The common test modifies the num_hidden_layers to be 1. However, for Data2VecVision we want to
- # avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code
- # related to attention masks in the original common tests is not required as the Data2VecVision
- # model does not handle attention masks. Furthermore, some extra code like modifying
- # the norm layers eps values for specialized configs and checking for the 'noise'
- # has been omitted to simply the test.
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- if not self.all_model_classes[0]._supports_sdpa:
- self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
-
- if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
- self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
-
- if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
- self.skipTest(
- f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
- )
-
- # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
- if torch_dtype == "float16":
- torch_dtype = torch.float16
- elif torch_dtype == "bfloat16":
- torch_dtype = torch.bfloat16
- elif torch_dtype == "float32":
- torch_dtype = torch.float32
-
- atols = {
- ("cpu", False, torch.float32): 1e-6,
- ("cpu", False, torch.float16): 5e-3,
- ("cpu", False, torch.bfloat16): 1e-2,
- ("cpu", True, torch.float32): 1e-6,
- ("cpu", True, torch.float16): 5e-3,
- ("cpu", True, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float32): 1e-6,
- ("cuda", False, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float16): 5e-3,
- ("cuda", True, torch.float32): 1e-6,
- ("cuda", True, torch.bfloat16): 1e-2,
- ("cuda", True, torch.float16): 5e-3,
- }
- rtols = {
- ("cpu", False, torch.float32): 1e-4,
- ("cpu", False, torch.float16): 5e-3,
- ("cpu", False, torch.bfloat16): 1e-2,
- ("cpu", True, torch.float32): 1e-4,
- ("cpu", True, torch.float16): 5e-3,
- ("cpu", True, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float32): 1e-4,
- ("cuda", False, torch.bfloat16): 1e-2,
- ("cuda", False, torch.float16): 5e-3,
- ("cuda", True, torch.float32): 1e-4,
- ("cuda", True, torch.bfloat16): 3e-2,
- ("cuda", True, torch.float16): 5e-3,
- }
-
- def get_mean_reldiff(failcase, x, ref, atol, rtol):
- return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
-
- for model_class in self.all_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- config.rms_norm_eps = 1.0
- config.layer_norm_eps = 1.0
- config.norm_eps = 1.0
- config.norm_epsilon = 1.0
- config.layer_norm_epsilon = 1.0
-
- model = model_class(config)
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True)
- model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
-
- model_eager = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch_dtype,
- attn_implementation="eager",
- use_mask_token=True,
- )
- model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
-
- # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
- for x in model_eager.modules():
- if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
- x.eps = 1.0
- for x in model_sdpa.modules():
- if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
- x.eps = 1.0
-
- # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
- # but it would be nicer to have an efficient way to use parameterized.expand
- fail_cases = []
- for padding_side in ["left", "right"]:
- for use_mask in [False, True]:
- for output_attentions in [True, False]:
- can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
- if not (self.has_attentions and can_output_attn) and output_attentions:
- continue
- # TODO: if we can also check with `batch_size=1` without being flaky?
- for batch_size in [7]:
- dummy_input = inputs_dict[model.main_input_name]
-
- if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
- dummy_input = dummy_input.to(torch_dtype)
-
- dummy_input = dummy_input[:batch_size]
- for enable_kernels in [False, True]:
- failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
- processed_inputs = {
- model.main_input_name: dummy_input,
- "output_hidden_states": True,
- }
-
- if (
- self.has_attentions
- and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
- ):
- processed_inputs["output_attentions"] = output_attentions
-
- if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
- dummy_mask = torch.ones((self.model_tester.num_masks,))
- mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0)
- dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
- dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
- processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
-
- with torch.no_grad():
- with sdpa_kernel(
- enable_flash=enable_kernels,
- enable_math=True,
- enable_mem_efficient=enable_kernels,
- ):
- prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
- outputs_eager = model_eager(**prepared_inputs)
- outputs_sdpa = model_sdpa(**prepared_inputs)
-
- logits_eager = outputs_eager.hidden_states[-1]
- logits_sdpa = outputs_sdpa.hidden_states[-1]
- if torch_device in ["cpu", "cuda"]:
- atol = atols[torch_device, enable_kernels, torch_dtype]
- rtol = rtols[torch_device, enable_kernels, torch_dtype]
- elif torch_device == "xpu":
- # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
- # which is implemented on PyTorch level using aten operators and is
- # device agnostic with respect to implementation of each aten operator.
- atol = atols["cuda", False, torch_dtype]
- rtol = rtols["cuda", False, torch_dtype]
- else:
- atol = 1e-7
- rtol = 1e-4
-
- # Masked tokens output slightly deviates - we don't mind that.
- if use_mask:
- _logits_sdpa = torch.zeros_like(input=logits_sdpa)
- _logits_eager = torch.zeros_like(input=logits_eager)
-
- _logits_sdpa[:-1] = logits_sdpa[:-1]
- _logits_eager[:-1] = logits_eager[:-1]
-
- if padding_side == "left":
- _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
- _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
-
- elif padding_side == "right":
- _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
- _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
-
- logits_sdpa = _logits_sdpa
- logits_eager = _logits_eager
-
- results = [
- torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
- for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
- ]
- # If 80% batch elements have matched results, it's fine
- if np.mean(results) < 0.8:
- fail_cases.append(
- get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
- )
-
- self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
-
# We will verify our results on an image of cute cats
def prepare_img():
diff --git a/tests/models/detr/test_image_processing_detr.py b/tests/models/detr/test_image_processing_detr.py
index a0b469f2de92ff..f91c520873668f 100644
--- a/tests/models/detr/test_image_processing_detr.py
+++ b/tests/models/detr/test_image_processing_detr.py
@@ -19,7 +19,7 @@
import numpy as np
-from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import AnnotationFormatTestMixin, ImageProcessingTestMixin, prepare_image_inputs
@@ -669,7 +669,6 @@ def test_longest_edge_shortest_edge_resizing_strategy(self):
@slow
@require_torch_gpu
- @require_torchvision
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
@@ -725,7 +724,6 @@ def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
@slow
@require_torch_gpu
- @require_torchvision
def test_fast_processor_equivalence_cpu_gpu_coco_panoptic_annotations(self):
# prepare image, target and masks_path
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
diff --git a/tests/models/encodec/test_feature_extraction_encodec.py b/tests/models/encodec/test_feature_extraction_encodec.py
index 112f1022c00e8f..e56517ac410661 100644
--- a/tests/models/encodec/test_feature_extraction_encodec.py
+++ b/tests/models/encodec/test_feature_extraction_encodec.py
@@ -50,7 +50,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
-class EnCodecFeatureExtractionTester:
+class EnCodecFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index 1c4051f2e2645c..64ebedcb45984b 100644
--- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -733,6 +733,15 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa:
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@require_torch
class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py
index 3ad46a92bc0938..ce04fae94ea904 100644
--- a/tests/models/falcon/test_modeling_falcon.py
+++ b/tests/models/falcon/test_modeling_falcon.py
@@ -50,6 +50,8 @@
FalconModel,
)
from transformers.models.falcon.modeling_falcon import (
+ FalconDynamicNTKScalingRotaryEmbedding,
+ FalconLinearScalingRotaryEmbedding,
FalconRotaryEmbedding,
)
@@ -453,9 +455,11 @@ def test_model_rope_scaling_from_config(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
- # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ hidden_size = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
@@ -468,7 +472,11 @@ def test_model_rope_scaling(self):
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
- original_rope = FalconRotaryEmbedding(config).to(torch_device)
+ original_rope = FalconRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
@@ -476,8 +484,12 @@ def test_model_rope_scaling(self):
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
- config.rope_scaling = {"type": "linear", "factor": scaling_factor}
- linear_scaling_rope = FalconRotaryEmbedding(config).to(torch_device)
+ linear_scaling_rope = FalconLinearScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
@@ -490,8 +502,12 @@ def test_model_rope_scaling(self):
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
- config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
- ntk_scaling_rope = FalconRotaryEmbedding(config).to(torch_device)
+ ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
index f02e8f167636eb..893132f4337dd4 100644
--- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
+++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
@@ -43,6 +43,9 @@
FalconMambaModel,
)
from transformers.cache_utils import MambaCache
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
+else:
+ is_torch_greater_or_equal_than_2_0 = False
# Copied from transformers.tests.models.mamba.MambaModelTester with Mamba->FalconMamba,mamba->falcon_mamba
@@ -243,6 +246,9 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict
+@unittest.skipIf(
+ not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
+)
@require_torch
# Copied from transformers.tests.models.mamba.MambaModelTest with Mamba->Falcon,mamba->falcon_mamba,FalconMambaCache->MambaCache
class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index 88ccdc8ee45a2d..012444b472c0fc 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -507,7 +507,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
else {}
)
all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
- fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+ fx_compatible = True
test_missing_keys = False
test_model_parallel = True
diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
index 281594492500b0..1db484c4062c35 100644
--- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
+++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
@@ -37,6 +37,9 @@
GPTBigCodeModel,
)
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
+else:
+ is_torch_greater_or_equal_than_1_12 = False
class GPTBigCodeModelTester:
@@ -501,6 +504,10 @@ class GPTBigCodeMHAModelTest(GPTBigCodeModelTest):
multi_query = False
+@unittest.skipIf(
+ not is_torch_greater_or_equal_than_1_12,
+ reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.",
+)
@slow
@require_torch
class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py
index 6d5e081d50b152..435133e93860ac 100644
--- a/tests/models/gpt_neox/test_modeling_gpt_neox.py
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -37,7 +37,11 @@
GPTNeoXForTokenClassification,
GPTNeoXModel,
)
- from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
+ from transformers.models.gpt_neox.modeling_gpt_neox import (
+ GPTNeoXDynamicNTKScalingRotaryEmbedding,
+ GPTNeoXLinearScalingRotaryEmbedding,
+ GPTNeoXRotaryEmbedding,
+ )
class GPTNeoXModelTester:
@@ -366,8 +370,12 @@ def test_model_rope_scaling_from_config(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+ # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ hidden_size = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
@@ -380,7 +388,11 @@ def test_model_rope_scaling(self):
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
- original_rope = GPTNeoXRotaryEmbedding(config).to(torch_device)
+ original_rope = GPTNeoXRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rotary_emb_base,
+ ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
@@ -388,8 +400,12 @@ def test_model_rope_scaling(self):
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
- config.rope_scaling = {"type": "linear", "factor": scaling_factor}
- linear_scaling_rope = GPTNeoXRotaryEmbedding(config).to(torch_device)
+ linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rotary_emb_base,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
@@ -402,8 +418,12 @@ def test_model_rope_scaling(self):
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
- config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
- ntk_scaling_rope = GPTNeoXRotaryEmbedding(config).to(torch_device)
+ ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rotary_emb_base,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
index 50840bbcfaa6dc..afc741cd502dec 100644
--- a/tests/models/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -41,6 +41,9 @@
GPTJForSequenceClassification,
GPTJModel,
)
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
+else:
+ is_torch_greater_or_equal_than_1_12 = False
class GPTJModelTester:
@@ -360,9 +363,15 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
test_model_parallel = False
test_head_masking = False
+ @unittest.skipIf(
+ not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+."
+ )
def test_torch_fx(self):
super().test_torch_fx()
+ @unittest.skipIf(
+ not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+."
+ )
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py
index 686544825c3551..60eb964927278a 100644
--- a/tests/models/granite/test_modeling_granite.py
+++ b/tests/models/granite/test_modeling_granite.py
@@ -14,12 +14,14 @@
# limitations under the License.
"""Testing suite for the PyTorch Granite model."""
+import tempfile
import unittest
from parameterized import parameterized
from transformers import GraniteConfig, is_torch_available, set_seed
from transformers.testing_utils import (
+ require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
@@ -415,6 +417,33 @@ def test_model_rope_scaling(self):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
+ @require_flash_attn
+ @require_torch_gpu
+ @slow
+ def test_use_flash_attention_2_true(self):
+ """
+ NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model = model_class(config)
+ model.save_pretrained(tmp_dir)
+
+ new_model = GraniteForCausalLM.from_pretrained(
+ tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+ ).to("cuda")
+
+ self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
+
+ has_flash = False
+ for name, submodule in new_model.named_modules():
+ if "FlashAttention" in submodule.__class__.__name__:
+ has_flash = True
+ break
+ if not has_flash:
+ raise ValueError("The flash model should have flash attention layers")
+
@require_torch_gpu
class GraniteIntegrationTest(unittest.TestCase):
diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py
index 31307865a77da7..97af65667ed048 100644
--- a/tests/models/granitemoe/test_modeling_granitemoe.py
+++ b/tests/models/granitemoe/test_modeling_granitemoe.py
@@ -14,12 +14,14 @@
# limitations under the License.
"""Testing suite for the PyTorch GraniteMoe model."""
+import tempfile
import unittest
from parameterized import parameterized
from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
+ require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
@@ -414,6 +416,33 @@ def test_model_rope_scaling(self):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
+ @require_flash_attn
+ @require_torch_gpu
+ @slow
+ def test_use_flash_attention_2_true(self):
+ """
+ NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model = model_class(config)
+ model.save_pretrained(tmp_dir)
+
+ new_model = GraniteMoeForCausalLM.from_pretrained(
+ tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+ ).to("cuda")
+
+ self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
+
+ has_flash = False
+ for name, submodule in new_model.named_modules():
+ if "FlashAttention" in submodule.__class__.__name__:
+ has_flash = True
+ break
+ if not has_flash:
+ raise ValueError("The flash model should have flash attention layers")
+
@require_torch_gpu
class GraniteMoeIntegrationTest(unittest.TestCase):
diff --git a/tests/models/grounding_dino/test_image_processing_grounding_dino.py b/tests/models/grounding_dino/test_image_processing_grounding_dino.py
index 5cc1e6c232c26e..bb8b9272efc952 100644
--- a/tests/models/grounding_dino/test_image_processing_grounding_dino.py
+++ b/tests/models/grounding_dino/test_image_processing_grounding_dino.py
@@ -37,7 +37,7 @@
from transformers import GroundingDinoImageProcessor
-class GroundingDinoImageProcessingTester:
+class GroundingDinoImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py
index 191d2f8c88c380..86f2b4119324ae 100644
--- a/tests/models/hubert/test_modeling_hubert.py
+++ b/tests/models/hubert/test_modeling_hubert.py
@@ -943,40 +943,3 @@ def test_inference_distilhubert(self):
self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
-
- def test_inference_hubert_25hz(self):
- model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
-
- sample = self._load_datasamples(1)
- input_speech = torch.tensor(sample[0], dtype=torch.float, device=torch_device).unsqueeze(0)
-
- with torch.no_grad():
- outputs = model(input_speech, output_hidden_states=True).hidden_states[11]
-
- # expected outputs taken from the original textlesslib implementation by:
- # model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans',
- # vocab_size=500, deduplicate=False, need_f0=False)
- # model(wav)['dense']
- expected_outputs_first = torch.tensor(
- [
- [0.0267, 0.1776, -0.1706, -0.4559],
- [-0.2430, -0.2943, -0.1864, -0.1187],
- [-0.1812, -0.4239, -0.1916, -0.0858],
- [-0.1495, -0.4758, -0.4036, 0.0302],
- ],
- device=torch_device,
- )
- expected_outputs_last = torch.tensor(
- [
- [0.3366, -0.2734, -0.1415, -0.3055],
- [0.2329, -0.3580, -0.1421, -0.3197],
- [0.1631, -0.4301, -0.1965, -0.2956],
- [0.3342, -0.2185, -0.2253, -0.2363],
- ],
- device=torch_device,
- )
- expected_output_sum = 1681.7603
-
- self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
- self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
- self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
diff --git a/tests/models/idefics/test_image_processing_idefics.py b/tests/models/idefics/test_image_processing_idefics.py
index ad208881578cfb..2f7a8993df5348 100644
--- a/tests/models/idefics/test_image_processing_idefics.py
+++ b/tests/models/idefics/test_image_processing_idefics.py
@@ -36,7 +36,7 @@
from transformers import IdeficsImageProcessor
-class IdeficsImageProcessingTester:
+class IdeficsImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 94229b13d2cbfe..12004cc3c8ad89 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -44,6 +44,9 @@
from transformers import IdeficsForVisionText2Text, IdeficsModel, IdeficsProcessor
from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
@@ -324,6 +327,7 @@ def test_eager_matches_sdpa_generate(self):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
+@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
@@ -590,6 +594,7 @@ def test_sdpa_can_dispatch_non_composite_models(self):
pass
+@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
@@ -813,6 +818,7 @@ def test_sdpa_can_dispatch_non_composite_models(self):
pass
+@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
@require_vision
class IdeficsModelIntegrationTest(TestCasePlus):
diff --git a/tests/models/idefics2/test_image_processing_idefics2.py b/tests/models/idefics2/test_image_processing_idefics2.py
index bf9634b398b678..624fdd6c98b3e5 100644
--- a/tests/models/idefics2/test_image_processing_idefics2.py
+++ b/tests/models/idefics2/test_image_processing_idefics2.py
@@ -34,7 +34,7 @@
import torch
-class Idefics2ImageProcessingTester:
+class Idefics2ImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index 974628c8b4324f..ae8c91f29d4d46 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -48,6 +48,8 @@
if is_torch_available():
import torch
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
@@ -360,6 +362,15 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa and model_sdpa.config.model_type != "falcon":
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@require_torch
class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index c25fa1180649fa..5bfd4c3f3c0e83 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -40,6 +40,8 @@
Idefics3ForConditionalGeneration,
Idefics3Model,
)
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py
index 723ddcf7988826..27a79bc6724285 100644
--- a/tests/models/ijepa/test_modeling_ijepa.py
+++ b/tests/models/ijepa/test_modeling_ijepa.py
@@ -250,7 +250,7 @@ def test_for_image_classification(self):
@slow
def test_model_from_pretrained(self):
- model_name = "facebook/ijepa_vith14_1k"
+ model_name = "jmtzt/ijepa_vith14_1k"
model = IJepaModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@@ -266,11 +266,11 @@ def prepare_img():
class IJepaModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
- return ViTImageProcessor.from_pretrained("facebook/ijepa_vith14_1k") if is_vision_available() else None
+ return ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k") if is_vision_available() else None
@slow
def test_inference_no_head(self):
- model = IJepaModel.from_pretrained("facebook/ijepa_vith14_1k").to(torch_device)
+ model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
@@ -299,7 +299,7 @@ def test_inference_fp16(self):
A small test to make sure that inference work in half precision without any problem.
"""
model = IJepaModel.from_pretrained(
- "facebook/ijepa_vith14_1k",
+ "jmtzt/ijepa_vith14_1k",
torch_dtype=torch.float16,
device_map="auto",
)
@@ -319,7 +319,7 @@ def test_inference_interpolate_pos_encoding(self):
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
- model = IJepaModel.from_pretrained("facebook/ijepa_vith14_1k").to(torch_device)
+ model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index feca640bb4a119..9e67f4f7381e24 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -14,8 +14,10 @@
# limitations under the License.
"""Testing suite for the PyTorch LLaMA model."""
+import tempfile
import unittest
+import pytest
from packaging import version
from parameterized import parameterized
@@ -23,6 +25,7 @@
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
cleanup,
+ require_flash_attn,
require_read_token,
require_torch,
require_torch_accelerator,
@@ -48,7 +51,7 @@
LlamaModel,
LlamaTokenizer,
)
- from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+ from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding
class LlamaModelTester:
@@ -305,7 +308,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
- fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+ fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
@@ -486,6 +489,43 @@ def test_model_rope_scaling(self):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
+ def test_rope_class_retrocompatibility(self):
+ # Delete me when we remove compatibility for the old API :)
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ scaling_factor = 10
+ short_input_length = 10
+ long_input_length = int(config.max_position_embeddings * 1.5)
+ config.rope_scaling = {"type": "linear", "factor": 10}
+
+ # Inputs
+ x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
+ position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
+ position_ids_short = position_ids_short.unsqueeze(0)
+ position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
+ position_ids_long = position_ids_long.unsqueeze(0)
+
+ # Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called
+ old_api_rope = LlamaLinearScalingRotaryEmbedding(
+ config.hidden_size // config.num_attention_heads,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
+ old_cos_short, old_sin_short = old_api_rope(x, position_ids_short)
+ old_cos_long, old_sin_long = old_api_rope(x, position_ids_long)
+
+ # New API
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+ new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
+ new_cos_short, new_sin_short = new_api_rope(x, position_ids_short)
+ new_cos_long, new_sin_long = new_api_rope(x, position_ids_long)
+
+ # The results should match
+ torch.testing.assert_close(old_cos_short, new_cos_short)
+ torch.testing.assert_close(old_sin_short, new_sin_short)
+ torch.testing.assert_close(old_cos_long, new_cos_long)
+ torch.testing.assert_close(old_sin_long, new_sin_long)
+
def test_model_loading_old_rope_configs(self):
def _reinitialize_config(base_config, new_kwargs):
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
@@ -540,6 +580,38 @@ def _reinitialize_config(base_config, new_kwargs):
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
+ @require_flash_attn
+ @require_torch_gpu
+ @slow
+ @pytest.mark.flash_attn_test
+ def test_use_flash_attention_2_true(self):
+ """
+ NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model = model_class(config)
+ model.save_pretrained(tmp_dir)
+
+ new_model = LlamaForCausalLM.from_pretrained(
+ tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+ ).to("cuda")
+
+ self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
+
+ has_flash = False
+ for name, submodule in new_model.named_modules():
+ if "FlashAttention" in submodule.__class__.__name__:
+ has_flash = True
+ break
+ if not has_flash:
+ raise ValueError("The flash model should have flash attention layers")
+
+ @unittest.skip("Broken by the loss update will fix soon @ArthurZucker")
+ def test_torch_fx_output_loss(self, *args, **kwargs):
+ pass
+
@require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase):
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index b4a959a00d2a0c..3d08ab35e0f630 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -43,7 +43,8 @@
if is_torch_available():
import torch
-
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/llava_next/test_image_processing_llava_next.py b/tests/models/llava_next/test_image_processing_llava_next.py
index 4b3f5e0dd3ff42..fc399298c39a46 100644
--- a/tests/models/llava_next/test_image_processing_llava_next.py
+++ b/tests/models/llava_next/test_image_processing_llava_next.py
@@ -34,7 +34,7 @@
from transformers import LlavaNextImageProcessor
-class LlavaNextImageProcessingTester:
+class LlavaNextImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index 14b0fb8cc07db7..c258ce96b94e48 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -48,7 +48,8 @@
import torch
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
-
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/llava_next_video/test_image_processing_llava_next_video.py b/tests/models/llava_next_video/test_image_processing_llava_next_video.py
index 385475c262f197..8c525fa256da07 100644
--- a/tests/models/llava_next_video/test_image_processing_llava_next_video.py
+++ b/tests/models/llava_next_video/test_image_processing_llava_next_video.py
@@ -33,7 +33,7 @@
from transformers import LlavaNextVideoImageProcessor
-class LlavaNextVideoProcessingTester:
+class LlavaNextVideoProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index c431f91bf5102f..a6fb341ff9bf56 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -48,6 +48,8 @@
if is_torch_available():
import torch
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/llava_onevision/test_image_processing_llava_onevision.py b/tests/models/llava_onevision/test_image_processing_llava_onevision.py
index f392f2b8956d4b..47b6ef86c5dd10 100644
--- a/tests/models/llava_onevision/test_image_processing_llava_onevision.py
+++ b/tests/models/llava_onevision/test_image_processing_llava_onevision.py
@@ -33,7 +33,7 @@
from transformers import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor
-class LlavaOnevisionImageProcessingTester:
+class LlavaOnevisionImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index 6965d2033ec730..a217eee2c70671 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -48,6 +48,8 @@
if is_torch_available():
import torch
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 455022140f7c5b..d432dfa93df487 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -38,6 +38,9 @@
MambaModel,
)
from transformers.models.mamba.modeling_mamba import MambaCache
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
+else:
+ is_torch_greater_or_equal_than_2_0 = False
class MambaModelTester:
@@ -236,6 +239,9 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict
+@unittest.skipIf(
+ not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
+)
@require_torch
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 17cbdc1e8d51dd..9b3a9563b58ddc 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -21,7 +21,6 @@
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
-from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -37,6 +36,9 @@
Mamba2Model,
)
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
+else:
+ is_torch_greater_or_equal_than_2_0 = False
class Mamba2ModelTester:
@@ -101,10 +103,6 @@ def prepare_config_and_inputs(
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
- # Only left padding is valid
- attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
- attention_mask[0, :1] = 0
-
sequence_labels = None
token_labels = None
choice_labels = None
@@ -120,7 +118,7 @@ def prepare_config_and_inputs(
return (
config,
input_ids,
- attention_mask,
+ None,
sequence_labels,
token_labels,
choice_labels,
@@ -160,57 +158,10 @@ def prepare_config_and_inputs_for_common(self):
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
- def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args):
- model = Mamba2Model(config=config)
- model.to(torch_device)
- model.eval()
-
- output_whole = model(input_ids, attention_mask=attention_mask).last_hidden_state
-
- outputs = model(
- input_ids[:, :-1],
- attention_mask=attention_mask[:, :-1],
- use_cache=True,
- cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
- )
- output_one = outputs.last_hidden_state
-
- # Using the state computed on the first inputs, we will get the same output
- outputs = model(
- input_ids[:, -1:],
- attention_mask=attention_mask[:, -1:],
- use_cache=True,
- cache_params=outputs.cache_params,
- cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1, device=input_ids.device),
- )
- output_two = outputs.last_hidden_state
-
- self.parent.assertTrue(
- torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-3, rtol=1e-3)
- )
-
- def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False):
- model = Mamba2Model(config)
- model.eval()
-
- if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()):
- self.parent.skipTest(
- "This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found."
- )
- if torch_device != "cuda":
- self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.")
-
- model.to(torch_device)
- if gradient_checkpointing:
- model.gradient_checkpointing_enable()
-
- token_emb = model.embeddings(input_ids)
- outputs_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb)
- outputs_slow = model.layers[0].mixer.torch_forward(token_emb)
-
- self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
-
+@unittest.skipIf(
+ not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
+)
@require_torch
class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else ()
@@ -233,14 +184,6 @@ def setUp(self):
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)
- def test_mamba2_caching(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_mamba2_caching(*config_and_inputs)
-
- def test_mamba2_slow_vs_fast_forward(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)
-
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -256,6 +199,23 @@ def test_initialization(self):
def test_tied_weights_keys(self):
pass
+ @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
+ def test_generate_without_input_ids(self):
+ pass
+
+ @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
+ @parameterized.expand([("greedy", 1), ("beam search", 2)])
+ def test_generate_from_inputs_embeds(self, _, num_beams):
+ pass
+
+ @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
+ def test_greedy_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
+ def test_beam_search_generate_dict_outputs_use_cache(self):
+ pass
+
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass
diff --git a/tests/models/markuplm/test_feature_extraction_markuplm.py b/tests/models/markuplm/test_feature_extraction_markuplm.py
index 381483d65559db..4541cb9480bbe8 100644
--- a/tests/models/markuplm/test_feature_extraction_markuplm.py
+++ b/tests/models/markuplm/test_feature_extraction_markuplm.py
@@ -26,7 +26,7 @@
from transformers import MarkupLMFeatureExtractor
-class MarkupLMFeatureExtractionTester:
+class MarkupLMFeatureExtractionTester(unittest.TestCase):
def __init__(self, parent):
self.parent = parent
diff --git a/tests/models/mask2former/test_image_processing_mask2former.py b/tests/models/mask2former/test_image_processing_mask2former.py
index b298336a81ceb2..7468c3fd476a6e 100644
--- a/tests/models/mask2former/test_image_processing_mask2former.py
+++ b/tests/models/mask2former/test_image_processing_mask2former.py
@@ -39,7 +39,7 @@
from PIL import Image
-class Mask2FormerImageProcessingTester:
+class Mask2FormerImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/maskformer/test_image_processing_maskformer.py b/tests/models/maskformer/test_image_processing_maskformer.py
index 8b3c7db762a57d..23e517a32626f7 100644
--- a/tests/models/maskformer/test_image_processing_maskformer.py
+++ b/tests/models/maskformer/test_image_processing_maskformer.py
@@ -38,7 +38,7 @@
from PIL import Image
-class MaskFormerImageProcessingTester:
+class MaskFormerImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index d9e6b9d7bfe7c0..c5ea050edf92ef 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -316,7 +316,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
)
test_headmasking = False
test_pruning = False
- fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+ fx_compatible = True
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index 9abbf444d0b0b4..931bb1f17beccf 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -314,7 +314,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
)
test_headmasking = False
test_pruning = False
- fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+ fx_compatible = True
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
diff --git a/tests/models/modernbert/__init__.py b/tests/models/modernbert/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
deleted file mode 100644
index 4fce0cd86352f0..00000000000000
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ /dev/null
@@ -1,367 +0,0 @@
-# coding=utf-8
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-import unittest
-
-import pytest
-
-from transformers import ModernBertConfig, is_torch_available
-from transformers.models.auto import get_values
-from transformers.testing_utils import (
- CaptureLogger,
- require_flash_attn,
- require_torch,
- require_torch_gpu,
- slow,
- torch_device,
-)
-
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
-from ...test_pipeline_mixin import PipelineTesterMixin
-
-
-if is_torch_available():
- import torch
-
- from transformers import (
- MODEL_FOR_PRETRAINING_MAPPING,
- ModernBertForMaskedLM,
- ModernBertForSequenceClassification,
- ModernBertForTokenClassification,
- ModernBertModel,
- logging,
- )
-
-
-class ModernBertModelTester:
- def __init__(
- self,
- parent,
- batch_size=13,
- seq_length=7,
- is_training=True,
- use_input_mask=True,
- use_labels=True,
- vocab_size=99,
- pad_token_id=0,
- hidden_size=32,
- num_hidden_layers=2,
- num_attention_heads=4,
- intermediate_size=37,
- hidden_activation="gelu",
- mlp_dropout=0.0,
- attention_dropout=0.0,
- embedding_dropout=0.0,
- classifier_dropout=0.0,
- max_position_embeddings=512,
- type_vocab_size=16,
- type_sequence_label_size=2,
- initializer_range=0.02,
- num_labels=3,
- num_choices=4,
- scope=None,
- ):
- self.parent = parent
- self.batch_size = batch_size
- self.seq_length = seq_length
- self.is_training = is_training
- self.use_input_mask = use_input_mask
- self.use_labels = use_labels
- self.vocab_size = vocab_size
- self.pad_token_id = pad_token_id
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
- self.hidden_activation = hidden_activation
- self.mlp_dropout = mlp_dropout
- self.attention_dropout = attention_dropout
- self.embedding_dropout = embedding_dropout
- self.classifier_dropout = classifier_dropout
- self.max_position_embeddings = max_position_embeddings
- self.type_vocab_size = type_vocab_size
- self.type_sequence_label_size = type_sequence_label_size
- self.initializer_range = initializer_range
- self.num_labels = num_labels
- self.num_choices = num_choices
- self.scope = scope
-
- def prepare_config_and_inputs(self):
- input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
- input_mask = None
- if self.use_input_mask:
- input_mask = random_attention_mask([self.batch_size, self.seq_length])
-
- sequence_labels = None
- token_labels = None
- choice_labels = None
- if self.use_labels:
- sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
- token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
- choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
- config = self.get_config()
-
- return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
-
- def get_config(self):
- """
- Returns a tiny configuration by default.
- """
- config = ModernBertConfig(
- vocab_size=self.vocab_size,
- pad_token_id=self.pad_token_id,
- hidden_size=self.hidden_size,
- num_hidden_layers=self.num_hidden_layers,
- num_attention_heads=self.num_attention_heads,
- intermediate_size=self.intermediate_size,
- hidden_activation=self.hidden_activation,
- mlp_dropout=self.mlp_dropout,
- attention_dropout=self.attention_dropout,
- embedding_dropout=self.embedding_dropout,
- classifier_dropout=self.classifier_dropout,
- max_position_embeddings=self.max_position_embeddings,
- type_vocab_size=self.type_vocab_size,
- is_decoder=False,
- initializer_range=self.initializer_range,
- )
- if test := os.environ.get("PYTEST_CURRENT_TEST", False):
- test_name = test.split(":")[-1].split(" ")[0]
-
- # If we're testing `test_retain_grad_hidden_states_attentions`, we normally get an error
- # that compilation doesn't work. Users can then set compile=False when loading the model,
- # much like here. We're testing whether it works once they've done that.
- if test_name == "test_retain_grad_hidden_states_attentions":
- config.reference_compile = False
- # Some tests require attentions to be outputted, in that case we'll set the attention implementation to eager
- # as the others don't support outputted attentions
- if test_name in (
- "test_attention_outputs",
- "test_hidden_states_output",
- "test_retain_grad_hidden_states_attentions",
- ):
- config._attn_implementation = "eager"
- return config
-
- def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
- model = ModernBertModel(config=config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask)
- result = model(input_ids)
- result = model(input_ids)
- self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
- def create_and_check_for_masked_lm(
- self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
- ):
- model = ModernBertForMaskedLM(config=config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask, labels=token_labels)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
-
- def create_and_check_for_sequence_classification(
- self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
- ):
- config.num_labels = self.num_labels
- model = ModernBertForSequenceClassification(config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
-
- def create_and_check_for_token_classification(
- self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
- ):
- config.num_labels = self.num_labels
- model = ModernBertForTokenClassification(config=config)
- model.to(torch_device)
- model.eval()
- result = model(input_ids, attention_mask=input_mask, labels=token_labels)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
-
- def prepare_config_and_inputs_for_common(self):
- config_and_inputs = self.prepare_config_and_inputs()
- (
- config,
- input_ids,
- input_mask,
- sequence_labels,
- token_labels,
- choice_labels,
- ) = config_and_inputs
- inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
- return config, inputs_dict
-
-
-@require_torch
-class ModernBertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
- test_torchscript = False
-
- all_model_classes = (
- (
- ModernBertModel,
- ModernBertForMaskedLM,
- ModernBertForSequenceClassification,
- ModernBertForTokenClassification,
- )
- if is_torch_available()
- else ()
- )
- all_generative_model_classes = ()
- pipeline_model_mapping = (
- {
- "feature-extraction": ModernBertModel,
- "fill-mask": ModernBertForMaskedLM,
- "text-classification": ModernBertForSequenceClassification,
- "token-classification": ModernBertForTokenClassification,
- "zero-shot": ModernBertForSequenceClassification,
- }
- if is_torch_available()
- else {}
- )
- fx_compatible = False
- test_head_masking = False
- test_pruning = False
- model_split_percents = [0.5, 0.8, 0.9]
-
- # special case for ForPreTraining model
- def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
- inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
-
- if inputs_dict.get("output_attentions", False):
- inputs_dict["output_attentions"] = True
-
- if return_labels:
- if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
- inputs_dict["labels"] = torch.zeros(
- (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
- )
- inputs_dict["next_sentence_label"] = torch.zeros(
- self.model_tester.batch_size, dtype=torch.long, device=torch_device
- )
- return inputs_dict
-
- def setUp(self):
- self.model_tester = ModernBertModelTester(self)
- self.config_tester = ConfigTester(self, config_class=ModernBertConfig, hidden_size=37)
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_model(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_model(*config_and_inputs)
-
- def test_model_various_embeddings(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- for type in ["absolute", "relative_key", "relative_key_query"]:
- config_and_inputs[0].position_embedding_type = type
- self.model_tester.create_and_check_model(*config_and_inputs)
-
- def test_initialization(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- configs_no_init = _config_zero_init(config)
- for model_class in self.all_model_classes:
- model = model_class(config=configs_no_init)
- for name, param in model.named_parameters():
- # The classifier.weight from ModernBertForSequenceClassification and ModernBertForTokenClassification
- # are initialized without `initializer_range`, so they're not set to ~0 via the _config_zero_init
- if param.requires_grad and not (
- name == "classifier.weight"
- and model_class in [ModernBertForSequenceClassification, ModernBertForTokenClassification]
- ):
- self.assertIn(
- ((param.data.mean() * 1e9).round() / 1e9).item(),
- [0.0, 1.0],
- msg=f"Parameter {name} of model {model_class} seems not properly initialized",
- )
-
- @unittest.skip("ModernBert doesn't use `inputs_embeds` as input.")
- def test_inputs_embeds(self):
- pass
-
- def test_for_masked_lm(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
-
- def test_for_sequence_classification(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
-
- def test_for_token_classification(self):
- config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
-
- def test_for_warning_if_padding_and_no_attention_mask(self):
- (
- config,
- input_ids,
- input_mask,
- sequence_labels,
- token_labels,
- choice_labels,
- ) = self.model_tester.prepare_config_and_inputs()
-
- # Set pad tokens in the input_ids
- input_ids[0, 0] = config.pad_token_id
-
- # Check for warnings if the attention_mask is missing.
- logger = logging.get_logger("transformers.modeling_utils")
- # clear cache so we can test the warning is emitted (from `warning_once`).
- logger.warning_once.cache_clear()
-
- with CaptureLogger(logger) as cl:
- model = ModernBertModel(config=config)
- model.to(torch_device)
- model.eval()
- model(input_ids, attention_mask=None)
- self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
-
- @unittest.skip("ModernBert doesn't use separate classes for SDPA, but a function instead.")
- def test_sdpa_can_dispatch_non_composite_models(self):
- pass
-
- @slow
- def test_model_from_pretrained(self):
- model_name = "google-bert/bert-base-uncased"
- model = ModernBertModel.from_pretrained(model_name)
- self.assertIsNotNone(model)
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_inference_equivalence_right_padding(self):
- self.skipTest(reason="ModernBert flash attention does not support right padding")
-
- @require_flash_attn
- @require_torch_gpu
- @pytest.mark.flash_attn_test
- @slow
- def test_flash_attn_2_conversion(self):
- self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
-
-
-@require_torch
-class ModernBertModelIntegrationTest(unittest.TestCase):
- """
- These still need to be written, once public models are available.
- """
diff --git a/tests/models/musicgen_melody/test_feature_extraction_musicgen_melody.py b/tests/models/musicgen_melody/test_feature_extraction_musicgen_melody.py
index bdd1cb1e12871d..697e3fb146ec17 100644
--- a/tests/models/musicgen_melody/test_feature_extraction_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_feature_extraction_musicgen_melody.py
@@ -69,7 +69,7 @@ def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
@require_torch
@require_torchaudio
-class MusicgenMelodyFeatureExtractionTester:
+class MusicgenMelodyFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index 98b554be65fbf9..bc8baa2746adde 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -41,9 +41,6 @@
require_torch_gpu,
require_torch_sdpa,
require_torchaudio,
- set_config_for_less_flaky_test,
- set_model_for_less_flaky_test,
- set_model_tester_for_less_flaky_test,
slow,
torch_device,
)
@@ -519,11 +516,8 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
- set_model_tester_for_less_flaky_test(self)
-
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- set_config_for_less_flaky_test(config)
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
@@ -540,9 +534,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
)
model_eager = model_eager.eval().to(torch_device)
- set_model_for_less_flaky_test(model_eager)
- set_model_for_less_flaky_test(model_sdpa)
-
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
@@ -1537,11 +1528,8 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
- set_model_tester_for_less_flaky_test(self)
-
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- set_config_for_less_flaky_test(config)
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
@@ -1558,9 +1546,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
)
model_eager = model_eager.eval().to(torch_device)
- set_model_for_less_flaky_test(model_eager)
- set_model_for_less_flaky_test(model_sdpa)
-
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
diff --git a/tests/models/oneformer/test_image_processing_oneformer.py b/tests/models/oneformer/test_image_processing_oneformer.py
index 7ac52f76b48adf..853bf241dd9fdb 100644
--- a/tests/models/oneformer/test_image_processing_oneformer.py
+++ b/tests/models/oneformer/test_image_processing_oneformer.py
@@ -39,7 +39,7 @@
from PIL import Image
-class OneFormerImageProcessorTester:
+class OneFormerImageProcessorTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/oneformer/test_processor_oneformer.py b/tests/models/oneformer/test_processor_oneformer.py
index dae50040ec042b..3a8a378b49009e 100644
--- a/tests/models/oneformer/test_processor_oneformer.py
+++ b/tests/models/oneformer/test_processor_oneformer.py
@@ -59,7 +59,7 @@ def prepare_metadata(class_info_file, repo_path="shi-labs/oneformer_demo"):
return metadata
-class OneFormerProcessorTester:
+class OneFormerProcessorTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index f973e1211dc081..5ffea7ffe55087 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -40,7 +40,8 @@
if is_torch_available():
import torch
-
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py
index e783cea95a63b3..99d84f9b5b5b09 100644
--- a/tests/models/persimmon/test_modeling_persimmon.py
+++ b/tests/models/persimmon/test_modeling_persimmon.py
@@ -46,7 +46,11 @@
PersimmonForTokenClassification,
PersimmonModel,
)
- from transformers.models.persimmon.modeling_persimmon import PersimmonRotaryEmbedding
+ from transformers.models.persimmon.modeling_persimmon import (
+ PersimmonDynamicNTKScalingRotaryEmbedding,
+ PersimmonLinearScalingRotaryEmbedding,
+ PersimmonRotaryEmbedding,
+ )
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
@@ -417,9 +421,12 @@ def test_model_rope_scaling_from_config(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
- # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Persimmon
+ # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ hidden_size = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
@@ -432,7 +439,11 @@ def test_model_rope_scaling(self):
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
- original_rope = PersimmonRotaryEmbedding(config).to(torch_device)
+ original_rope = PersimmonRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
@@ -440,8 +451,12 @@ def test_model_rope_scaling(self):
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
- config.rope_scaling = {"type": "linear", "factor": scaling_factor}
- linear_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device)
+ linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
@@ -454,8 +469,12 @@ def test_model_rope_scaling(self):
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
- config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
- ntk_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device)
+ ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py
index c7b59d278e4fe6..eae6789bef252e 100644
--- a/tests/models/phi/test_modeling_phi.py
+++ b/tests/models/phi/test_modeling_phi.py
@@ -42,7 +42,11 @@
PhiForTokenClassification,
PhiModel,
)
- from transformers.models.phi.modeling_phi import PhiRotaryEmbedding
+ from transformers.models.phi.modeling_phi import (
+ PhiDynamicNTKScalingRotaryEmbedding,
+ PhiLinearScalingRotaryEmbedding,
+ PhiRotaryEmbedding,
+ )
class PhiModelTester:
@@ -396,9 +400,12 @@ def test_model_rope_scaling_from_config(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
- # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Phi
+ # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ hidden_size = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
@@ -411,7 +418,11 @@ def test_model_rope_scaling(self):
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
- original_rope = PhiRotaryEmbedding(config).to(torch_device)
+ original_rope = PhiRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
@@ -419,8 +430,12 @@ def test_model_rope_scaling(self):
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
- config.rope_scaling = {"type": "linear", "factor": scaling_factor}
- linear_scaling_rope = PhiRotaryEmbedding(config).to(torch_device)
+ linear_scaling_rope = PhiLinearScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
@@ -433,8 +448,12 @@ def test_model_rope_scaling(self):
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
- config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
- ntk_scaling_rope = PhiRotaryEmbedding(config).to(torch_device)
+ ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
diff --git a/tests/models/pix2struct/test_image_processing_pix2struct.py b/tests/models/pix2struct/test_image_processing_pix2struct.py
index 6b12b3827dabd9..2d5616b5b78b29 100644
--- a/tests/models/pix2struct/test_image_processing_pix2struct.py
+++ b/tests/models/pix2struct/test_image_processing_pix2struct.py
@@ -34,7 +34,7 @@
from transformers import Pix2StructImageProcessor
-class Pix2StructImageProcessingTester:
+class Pix2StructImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py
index a45ead50612933..8b49b5aa60b99a 100644
--- a/tests/models/pixtral/test_image_processing_pixtral.py
+++ b/tests/models/pixtral/test_image_processing_pixtral.py
@@ -38,7 +38,7 @@
from transformers import PixtralImageProcessorFast
-class PixtralImageProcessingTester:
+class PixtralImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py
index 3e5667caf45e3e..0c36cb5a4e0554 100644
--- a/tests/models/pixtral/test_modeling_pixtral.py
+++ b/tests/models/pixtral/test_modeling_pixtral.py
@@ -33,7 +33,8 @@
if is_torch_available():
import torch
-
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
pass
diff --git a/tests/models/pop2piano/test_feature_extraction_pop2piano.py b/tests/models/pop2piano/test_feature_extraction_pop2piano.py
index 6b4b1b987a2f1f..c6766147975962 100644
--- a/tests/models/pop2piano/test_feature_extraction_pop2piano.py
+++ b/tests/models/pop2piano/test_feature_extraction_pop2piano.py
@@ -48,7 +48,7 @@
from transformers import Pop2PianoFeatureExtractor
-class Pop2PianoFeatureExtractionTester:
+class Pop2PianoFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index ecfa9189d12e62..6c32a66e03626c 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -327,7 +327,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
- fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+ fx_compatible = True
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index 8974d6923b391c..42b521e518e22e 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -41,6 +41,8 @@
if is_torch_available():
import torch
+else:
+ is_torch_greater_or_equal_than_2_0 = False
class Qwen2AudioModelTester:
@@ -204,6 +206,15 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa and model_sdpa.config.model_type != "falcon":
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@require_torch
class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
index 21d11047ff1be8..abc7b57919b083 100644
--- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
+++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -352,7 +352,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
)
test_headmasking = False
test_pruning = False
- fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+ fx_compatible = True
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
diff --git a/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py b/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py
index a6004349b49d11..d69addb9a10cca 100644
--- a/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_image_processing_qwen2_vl.py
@@ -34,7 +34,7 @@
from transformers import Qwen2VLImageProcessor
-class Qwen2VLImageProcessingTester:
+class Qwen2VLImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index 2c27e1a03a647c..93ed33ae774458 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -47,6 +47,8 @@
if is_torch_available():
import torch
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/rag/test_modeling_rag.py b/tests/models/rag/test_modeling_rag.py
index b219d5c74edff0..3e3f7b9c457589 100644
--- a/tests/models/rag/test_modeling_rag.py
+++ b/tests/models/rag/test_modeling_rag.py
@@ -33,7 +33,7 @@
require_sentencepiece,
require_tokenizers,
require_torch,
- require_torch_non_multi_accelerator,
+ require_torch_non_multi_gpu,
slow,
torch_device,
)
@@ -678,7 +678,7 @@ def config_and_inputs(self):
@require_retrieval
@require_sentencepiece
@require_tokenizers
-@require_torch_non_multi_accelerator
+@require_torch_non_multi_gpu
class RagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
@@ -1002,7 +1002,7 @@ def test_rag_token_generate_batch(self):
torch_device
)
- if torch_device != "cpu":
+ if torch_device == "cuda":
rag_token.half()
input_dict = tokenizer(
diff --git a/tests/models/rt_detr/test_image_processing_rt_detr.py b/tests/models/rt_detr/test_image_processing_rt_detr.py
index 2be3ea3e7651c2..e7bfbae3f9c27a 100644
--- a/tests/models/rt_detr/test_image_processing_rt_detr.py
+++ b/tests/models/rt_detr/test_image_processing_rt_detr.py
@@ -16,7 +16,7 @@
import requests
-from transformers.testing_utils import require_torch, require_torch_gpu, require_torchvision, require_vision, slow
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -374,7 +374,6 @@ def test_batched_coco_detection_annotations(self):
@slow
@require_torch_gpu
- @require_torchvision
# Copied from tests.models.detr.test_image_processing_detr.DetrImageProcessingTest.test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations
def test_fast_processor_equivalence_cpu_gpu_coco_detection_annotations(self):
# prepare image and target
diff --git a/tests/models/rwkv/test_modeling_rwkv.py b/tests/models/rwkv/test_modeling_rwkv.py
index 0bc5c2de070135..5e82956e3efa6c 100644
--- a/tests/models/rwkv/test_modeling_rwkv.py
+++ b/tests/models/rwkv/test_modeling_rwkv.py
@@ -33,6 +33,9 @@
RwkvForCausalLM,
RwkvModel,
)
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
+else:
+ is_torch_greater_or_equal_than_2_0 = False
class RwkvModelTester:
@@ -228,6 +231,9 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict
+@unittest.skipIf(
+ not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
+)
@require_torch
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
@@ -434,6 +440,9 @@ def test_left_padding_compatibility(self):
pass
+@unittest.skipIf(
+ not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
+)
@slow
class RWKVIntegrationTests(unittest.TestCase):
def setUp(self):
diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py
index 351016716a0cf1..7faace0096c8de 100644
--- a/tests/models/sam/test_modeling_sam.py
+++ b/tests/models/sam/test_modeling_sam.py
@@ -14,13 +14,12 @@
# limitations under the License.
"""Testing suite for the PyTorch SAM model."""
-import tempfile
import unittest
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
-from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device
+from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -296,7 +295,6 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
- _is_composite = True
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
@@ -313,13 +311,22 @@ def is_pipeline_test_to_skip(
def setUp(self):
self.model_tester = SamModelTester(self)
- common_properties = ["initializer_range"]
- self.config_tester = ConfigTester(
- self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties
+ self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
+ self.prompt_encoder_config_tester = ConfigTester(
+ self,
+ config_class=SamPromptEncoderConfig,
+ has_text_modality=False,
+ num_attention_heads=12,
+ num_hidden_layers=2,
+ )
+ self.mask_decoder_config_tester = ConfigTester(
+ self, config_class=SamMaskDecoderConfig, has_text_modality=False
)
def test_config(self):
- self.config_tester.run_common_tests()
+ self.vision_config_tester.run_common_tests()
+ self.prompt_encoder_config_tester.run_common_tests()
+ self.mask_decoder_config_tester.run_common_tests()
@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
@@ -443,68 +450,6 @@ def test_model_from_pretrained(self):
model = SamModel.from_pretrained(model_name)
self.assertIsNotNone(model)
- @require_torch_sdpa
- def test_sdpa_can_compile_dynamic(self):
- self.skipTest(reason="SAM model can't be compiled dynamic yet")
-
- @require_torch_sdpa
- def test_sdpa_can_dispatch_composite_models(self):
- """
- Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
- This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention".
- In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model
- is loaded, because we manually replicate requested attn implementation on each sub-config when loading.
- See https://github.com/huggingface/transformers/pull/32238 for more info
-
- The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model
- that has a different set of sub-configs has to overwrite this test.
- """
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- if not self._is_composite:
- self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
-
- for model_class in self.all_model_classes:
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
- model_sdpa = model_sdpa.eval().to(torch_device)
-
- model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
- model_eager = model_eager.eval().to(torch_device)
-
- # Root model determines SDPA support
- attn_impl = "sdpa" if model._supports_sdpa else "eager"
-
- # Check config propagation to submodels that support it
- self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
- self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl)
- self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl)
-
- self.assertTrue(model_eager.config._attn_implementation == "eager")
- self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager")
- self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager")
-
- # Verify SDPA/eager layer presence
- has_sdpa = False
- for name, submodule in model_sdpa.named_modules():
- class_name = submodule.__class__.__name__
- if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
- has_sdpa = True
- break
-
- if not has_sdpa and attn_impl == "sdpa":
- raise ValueError("The SDPA model should have SDPA attention layers")
-
- for name, submodule in model_eager.named_modules():
- class_name = submodule.__class__.__name__
- if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
- raise ValueError("The eager model should not have SDPA attention layers")
-
def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
diff --git a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
index 7c13f97b64d7e3..8830660c097c5b 100644
--- a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
+++ b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
@@ -52,7 +52,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
-class SeamlessM4TFeatureExtractionTester:
+class SeamlessM4TFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
index 276375c7e85439..451fff0b35fb8c 100644
--- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
+++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
@@ -589,11 +589,6 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
- # TODO: @ydshieh: refer to #34968
- @unittest.skip(reason="Failing on multi-gpu runner")
- def test_retain_grad_hidden_states_attentions(self):
- pass
-
@require_torch
class SeamlessM4Tv2ModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -840,13 +835,7 @@ def test_generation_languages(self):
def test_speech_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input()
- from transformers.testing_utils import set_config_for_less_flaky_test, set_model_for_less_flaky_test
-
- set_config_for_less_flaky_test(config)
-
model = SeamlessM4Tv2Model(config=config)
- set_model_for_less_flaky_test(model)
-
self.update_generation(model)
model.save_pretrained(self.tmpdirname)
model.to(torch_device)
@@ -858,11 +847,6 @@ def test_speech_generation(self):
state_dict = model.state_dict()
text_model = SeamlessM4Tv2ForTextToSpeech.from_pretrained(self.tmpdirname)
- # Even if this component is loaded after `model.save_pretrained` which is after
- # `set_model_for_less_flaky_test(model)`, we still need to apply `set_model_for_less_flaky_test` here as the
- # `eps` attribute in the model's norm layers is not set from the config.
- set_model_for_less_flaky_test(text_model)
-
self.update_generation(text_model)
text_model.to(torch_device)
text_model.eval()
@@ -870,11 +854,6 @@ def test_speech_generation(self):
output_text = self.factory_generation_speech_test(model, input_text)
speech_model = SeamlessM4Tv2ForSpeechToSpeech.from_pretrained(self.tmpdirname)
- # Even if this component is loaded after `model.save_pretrained` which is after
- # `set_model_for_less_flaky_test(model)`, we still need to apply `set_model_for_less_flaky_test` here as the
- # `eps` attribute in the model's norm layers is not set from the config.
- set_model_for_less_flaky_test(speech_model)
-
self.update_generation(speech_model)
speech_model.to(torch_device)
speech_model.eval()
diff --git a/tests/models/segformer/test_image_processing_segformer.py b/tests/models/segformer/test_image_processing_segformer.py
index dba2de7e483038..223993000181a3 100644
--- a/tests/models/segformer/test_image_processing_segformer.py
+++ b/tests/models/segformer/test_image_processing_segformer.py
@@ -33,7 +33,7 @@
from transformers import SegformerImageProcessor
-class SegformerImageProcessingTester:
+class SegformerImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/seggpt/test_image_processing_seggpt.py b/tests/models/seggpt/test_image_processing_seggpt.py
index 74e78f0082016b..f79b7ea44370dc 100644
--- a/tests/models/seggpt/test_image_processing_seggpt.py
+++ b/tests/models/seggpt/test_image_processing_seggpt.py
@@ -35,7 +35,7 @@
from transformers import SegGptImageProcessor
-class SegGptImageProcessingTester:
+class SegGptImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
index 897d4b056f1977..7dcb7c406ae287 100644
--- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
+++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
@@ -500,6 +500,15 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa:
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@require_torch
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
diff --git a/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
index 2a4ad0894911c0..9023e8467f736c 100644
--- a/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
+++ b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
@@ -48,7 +48,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
@require_torchaudio
-class Speech2TextFeatureExtractionTester:
+class Speech2TextFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/speecht5/test_feature_extraction_speecht5.py b/tests/models/speecht5/test_feature_extraction_speecht5.py
index 70d60f92238acd..5ec632e7e76c63 100644
--- a/tests/models/speecht5/test_feature_extraction_speecht5.py
+++ b/tests/models/speecht5/test_feature_extraction_speecht5.py
@@ -50,7 +50,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
-class SpeechT5FeatureExtractionTester:
+class SpeechT5FeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py
index c8aa55399035d2..91044a4eb750d1 100644
--- a/tests/models/stablelm/test_modeling_stablelm.py
+++ b/tests/models/stablelm/test_modeling_stablelm.py
@@ -44,7 +44,11 @@
StableLmForTokenClassification,
StableLmModel,
)
- from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding
+ from transformers.models.stablelm.modeling_stablelm import (
+ StableLmDynamicNTKScalingRotaryEmbedding,
+ StableLmLinearScalingRotaryEmbedding,
+ StableLmRotaryEmbedding,
+ )
# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
@@ -402,9 +406,12 @@ def test_model_rope_scaling_from_config(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
- # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->StableLm
+ # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->StableLm
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ hidden_size = config.hidden_size
+ num_heads = config.num_attention_heads
+ head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
@@ -417,7 +424,11 @@ def test_model_rope_scaling(self):
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
- original_rope = StableLmRotaryEmbedding(config).to(torch_device)
+ original_rope = StableLmRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
@@ -425,8 +436,12 @@ def test_model_rope_scaling(self):
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
- config.rope_scaling = {"type": "linear", "factor": scaling_factor}
- linear_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device)
+ linear_scaling_rope = StableLmLinearScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
@@ -439,8 +454,12 @@ def test_model_rope_scaling(self):
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
- config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
- ntk_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device)
+ ntk_scaling_rope = StableLmDynamicNTKScalingRotaryEmbedding(
+ head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ scaling_factor=scaling_factor,
+ ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
diff --git a/tests/models/superpoint/test_image_processing_superpoint.py b/tests/models/superpoint/test_image_processing_superpoint.py
index e11fd08422ed3c..c2eae872004c77 100644
--- a/tests/models/superpoint/test_image_processing_superpoint.py
+++ b/tests/models/superpoint/test_image_processing_superpoint.py
@@ -33,7 +33,7 @@
from transformers import SuperPointImageProcessor
-class SuperPointImageProcessingTester:
+class SuperPointImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py
index 05618f4a4efd8c..4ee159d6bddd1d 100644
--- a/tests/models/tapas/test_modeling_tapas.py
+++ b/tests/models/tapas/test_modeling_tapas.py
@@ -60,6 +60,9 @@
reduce_mean,
reduce_sum,
)
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
+else:
+ is_torch_greater_or_equal_than_1_12 = False
class TapasModelTester:
@@ -408,6 +411,7 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict
+@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
@@ -574,6 +578,7 @@ def prepare_tapas_batch_inputs_for_training():
return table, queries, answer_coordinates, answer_text, float_answer
+@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
class TapasModelIntegrationTest(unittest.TestCase):
@cached_property
@@ -925,6 +930,10 @@ def test_inference_classification_head(self):
self.assertTrue(torch.allclose(outputs.logits, expected_tensor, atol=0.05))
+# Below: tests for Tapas utilities which are defined in modeling_tapas.py.
+# These are based on segmented_tensor_test.py of the original implementation.
+# URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py
+@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
class TapasUtilitiesTest(unittest.TestCase):
def _prepare_tables(self):
diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
index 9a3a2578fd16b3..0a911f7182b4a0 100644
--- a/tests/models/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -23,7 +23,7 @@
import pandas as pd
from parameterized import parameterized
-from transformers import AddedToken
+from transformers import AddedToken, is_torch_available
from transformers.models.tapas.tokenization_tapas import (
VOCAB_FILES_NAMES,
BasicTokenizer,
@@ -45,6 +45,12 @@
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings
+if is_torch_available():
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
+else:
+ is_torch_greater_or_equal_than_1_12 = False
+
+
@require_tokenizers
@require_pandas
class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -1042,6 +1048,7 @@ def test_token_type_ids(self):
# Do the same test as modeling common.
self.assertIn(0, output["token_type_ids"][0])
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
@slow
def test_torch_encode_plus_sent_to_model(self):
diff --git a/tests/models/timm_wrapper/__init__.py b/tests/models/timm_wrapper/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
diff --git a/tests/models/timm_wrapper/test_image_processing_timm_wrapper.py b/tests/models/timm_wrapper/test_image_processing_timm_wrapper.py
deleted file mode 100644
index 49d864178d14b3..00000000000000
--- a/tests/models/timm_wrapper/test_image_processing_timm_wrapper.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tempfile
-import unittest
-
-import numpy as np
-
-from transformers.testing_utils import require_torch, require_torchvision, require_vision
-from transformers.utils import is_torch_available, is_vision_available
-
-
-if is_torch_available():
- import torch
-
-if is_vision_available():
- from PIL import Image
-
- from transformers import TimmWrapperConfig, TimmWrapperImageProcessor
-
-
-@require_torch
-@require_vision
-@require_torchvision
-class TimmWrapperImageProcessingTest(unittest.TestCase):
- image_processing_class = TimmWrapperImageProcessor if is_vision_available() else None
-
- def setUp(self):
- super().setUp()
- self.temp_dir = tempfile.TemporaryDirectory()
- config = TimmWrapperConfig.from_pretrained("timm/resnet18.a1_in1k")
- config.save_pretrained(self.temp_dir.name)
-
- def tearDown(self):
- self.temp_dir.cleanup()
-
- def test_load_from_hub(self):
- image_processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet18.a1_in1k")
- self.assertIsInstance(image_processor, TimmWrapperImageProcessor)
-
- def test_load_from_local_dir(self):
- image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
- self.assertIsInstance(image_processor, TimmWrapperImageProcessor)
-
- def test_image_processor_properties(self):
- image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
- self.assertTrue(hasattr(image_processor, "data_config"))
- self.assertTrue(hasattr(image_processor, "val_transforms"))
- self.assertTrue(hasattr(image_processor, "train_transforms"))
-
- def test_image_processor_call_numpy(self):
- image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
-
- single_image = np.random.randint(256, size=(256, 256, 3), dtype=np.uint8)
- batch_images = [single_image, single_image, single_image]
-
- # single image
- pixel_values = image_processor(single_image).pixel_values
- self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
-
- # batch images
- pixel_values = image_processor(batch_images).pixel_values
- self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
-
- def test_image_processor_call_pil(self):
- image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
-
- single_image = Image.fromarray(np.random.randint(256, size=(256, 256, 3), dtype=np.uint8))
- batch_images = [single_image, single_image, single_image]
-
- # single image
- pixel_values = image_processor(single_image).pixel_values
- self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
-
- # batch images
- pixel_values = image_processor(batch_images).pixel_values
- self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
-
- def test_image_processor_call_tensor(self):
- image_processor = TimmWrapperImageProcessor.from_pretrained(self.temp_dir.name)
-
- single_image = torch.from_numpy(np.random.randint(256, size=(3, 256, 256), dtype=np.uint8)).float()
- batch_images = [single_image, single_image, single_image]
-
- # single image
- pixel_values = image_processor(single_image).pixel_values
- self.assertEqual(pixel_values.shape, (1, 3, 224, 224))
-
- # batch images
- pixel_values = image_processor(batch_images).pixel_values
- self.assertEqual(pixel_values.shape, (3, 3, 224, 224))
diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
deleted file mode 100644
index 6f63c0aa147d09..00000000000000
--- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
+++ /dev/null
@@ -1,366 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import tempfile
-import unittest
-
-from transformers.testing_utils import (
- require_bitsandbytes,
- require_timm,
- require_torch,
- require_vision,
- slow,
- torch_device,
-)
-from transformers.utils.import_utils import is_timm_available, is_torch_available, is_vision_available
-
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
-
-if is_torch_available():
- import torch
-
- from transformers import TimmWrapperConfig, TimmWrapperForImageClassification, TimmWrapperModel
-
-
-if is_timm_available():
- import timm
-
-
-if is_vision_available():
- from PIL import Image
-
- from transformers import TimmWrapperImageProcessor
-
-
-class TimmWrapperModelTester:
- def __init__(
- self,
- parent,
- model_name="timm/resnet18.a1_in1k",
- batch_size=3,
- image_size=32,
- num_channels=3,
- is_training=True,
- ):
- self.parent = parent
- self.model_name = model_name
- self.batch_size = batch_size
- self.image_size = image_size
- self.num_channels = num_channels
- self.is_training = is_training
-
- def prepare_config_and_inputs(self):
- pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
- config = self.get_config()
-
- return config, pixel_values
-
- def get_config(self):
- return TimmWrapperConfig.from_pretrained(self.model_name)
-
- def create_and_check_model(self, config, pixel_values):
- model = TimmWrapperModel(config=config)
- model.to(torch_device)
- model.eval()
- with torch.no_grad():
- result = model(pixel_values)
- self.parent.assertEqual(
- result.feature_map[-1].shape,
- (self.batch_size, model.channels[-1], 14, 14),
- )
-
- def prepare_config_and_inputs_for_common(self):
- config_and_inputs = self.prepare_config_and_inputs()
- config, pixel_values = config_and_inputs
- inputs_dict = {"pixel_values": pixel_values}
- return config, inputs_dict
-
-
-@require_torch
-@require_timm
-class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
- all_model_classes = (TimmWrapperModel, TimmWrapperForImageClassification) if is_torch_available() else ()
- pipeline_model_mapping = (
- {"image-feature-extraction": TimmWrapperModel, "image-classification": TimmWrapperForImageClassification}
- if is_torch_available()
- else {}
- )
-
- test_resize_embeddings = False
- test_head_masking = False
- test_pruning = False
- has_attentions = False
- test_model_parallel = False
-
- def setUp(self):
- self.config_class = TimmWrapperConfig
- self.model_tester = TimmWrapperModelTester(self)
- self.config_tester = ConfigTester(
- self,
- config_class=self.config_class,
- has_text_modality=False,
- common_properties=[],
- model_name="timm/resnet18.a1_in1k",
- )
-
- def test_config(self):
- self.config_tester.run_common_tests()
-
- def test_hidden_states_output(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
-
- # check all hidden states
- with torch.no_grad():
- outputs = model(**inputs_dict, output_hidden_states=True)
- self.assertTrue(
- len(outputs.hidden_states) == 5, f"expected 5 hidden states, but got {len(outputs.hidden_states)}"
- )
- expected_shapes = [[16, 16], [8, 8], [4, 4], [2, 2], [1, 1]]
- resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
- self.assertListEqual(expected_shapes, resulted_shapes)
-
- # check we can select hidden states by indices
- with torch.no_grad():
- outputs = model(**inputs_dict, output_hidden_states=[-2, -1])
- self.assertTrue(
- len(outputs.hidden_states) == 2, f"expected 2 hidden states, but got {len(outputs.hidden_states)}"
- )
- expected_shapes = [[2, 2], [1, 1]]
- resulted_shapes = [list(h.shape[2:]) for h in outputs.hidden_states]
- self.assertListEqual(expected_shapes, resulted_shapes)
-
- @unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds")
- def test_inputs_embeds(self):
- pass
-
- @unittest.skip(reason="TimmWrapper models doesn't have inputs_embeds")
- def test_model_get_set_embeddings(self):
- pass
-
- @unittest.skip(reason="TimmWrapper doesn't support output_attentions=True.")
- def test_torchscript_output_attentions(self):
- pass
-
- @unittest.skip(reason="TimmWrapper doesn't support this.")
- def test_retain_grad_hidden_states_attentions(self):
- pass
-
- @unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
- def test_initialization(self):
- pass
-
- @unittest.skip(reason="Need to use a timm model and there is no tiny model available.")
- def test_model_is_small(self):
- pass
-
- def test_forward_signature(self):
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
- signature = inspect.signature(model.forward)
- # signature.parameters is an OrderedDict => so arg_names order is deterministic
- arg_names = [*signature.parameters.keys()]
-
- expected_arg_names = ["pixel_values"]
- self.assertListEqual(arg_names[:1], expected_arg_names)
-
- def test_do_pooling_option(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- config.do_pooling = False
-
- model = TimmWrapperModel._from_config(config)
-
- # check there is no pooling
- with torch.no_grad():
- output = model(**inputs_dict)
- self.assertIsNone(output.pooler_output)
-
- # check there is pooler output
- with torch.no_grad():
- output = model(**inputs_dict, do_pooling=True)
- self.assertIsNotNone(output.pooler_output)
-
-
-# We will verify our results on an image of cute cats
-def prepare_img():
- image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
- return image
-
-
-@require_torch
-@require_timm
-@require_vision
-class TimmWrapperModelIntegrationTest(unittest.TestCase):
- # some popular ones
- model_names_to_test = [
- "vit_small_patch16_384.augreg_in21k_ft_in1k",
- "resnet50.a1_in1k",
- "tf_mobilenetv3_large_minimal_100.in1k",
- "swin_tiny_patch4_window7_224.ms_in1k",
- "ese_vovnet19b_dw.ra_in1k",
- "hrnet_w18.ms_aug_in1k",
- ]
-
- @slow
- def test_inference_image_classification_head(self):
- checkpoint = "timm/resnet18.a1_in1k"
- model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
- image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
-
- image = prepare_img()
- inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
-
- # forward pass
- with torch.no_grad():
- outputs = model(**inputs)
-
- # verify the shape and logits
- expected_shape = torch.Size((1, 1000))
- self.assertEqual(outputs.logits.shape, expected_shape)
-
- expected_label = 281 # tabby cat
- self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)
-
- expected_slice = torch.tensor([-11.2618, -9.6192, -10.3205]).to(torch_device)
- resulted_slice = outputs.logits[0, :3]
- is_close = torch.allclose(resulted_slice, expected_slice, atol=1e-3)
- self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
-
- @slow
- @require_bitsandbytes
- def test_inference_image_classification_quantized(self):
- from transformers import BitsAndBytesConfig
-
- checkpoint = "timm/vit_small_patch16_384.augreg_in21k_ft_in1k"
-
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
- model = TimmWrapperForImageClassification.from_pretrained(
- checkpoint, quantization_config=quantization_config, device_map=torch_device
- ).eval()
- image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
-
- image = prepare_img()
- inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
-
- # forward pass
- with torch.no_grad():
- outputs = model(**inputs)
-
- # verify the shape and logits
- expected_shape = torch.Size((1, 1000))
- self.assertEqual(outputs.logits.shape, expected_shape)
-
- expected_label = 281 # tabby cat
- self.assertEqual(torch.argmax(outputs.logits).item(), expected_label)
-
- expected_slice = torch.tensor([-2.4043, 1.4492, -0.5127]).to(outputs.logits.dtype)
- resulted_slice = outputs.logits[0, :3].cpu()
- is_close = torch.allclose(resulted_slice, expected_slice, atol=0.1)
- self.assertTrue(is_close, f"Expected {expected_slice}, but got {resulted_slice}")
-
- @slow
- def test_transformers_model_for_classification_is_equivalent_to_timm(self):
- # check that wrapper logits are the same as timm model logits
-
- image = prepare_img()
-
- for model_name in self.model_names_to_test:
- checkpoint = f"timm/{model_name}"
-
- with self.subTest(msg=model_name):
- # prepare inputs
- image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
- pixel_values = image_processor(images=image).pixel_values.to(torch_device)
-
- # load models
- model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map=torch_device).eval()
- timm_model = timm.create_model(model_name, pretrained=True).to(torch_device).eval()
-
- with torch.inference_mode():
- outputs = model(pixel_values)
- timm_outputs = timm_model(pixel_values)
-
- # check shape is the same
- self.assertEqual(outputs.logits.shape, timm_outputs.shape)
-
- # check logits are the same
- diff = (outputs.logits - timm_outputs).max().item()
- self.assertLess(diff, 1e-4)
-
- @slow
- def test_transformers_model_is_equivalent_to_timm(self):
- # check that wrapper logits are the same as timm model logits
-
- image = prepare_img()
-
- models_to_test = ["vit_small_patch16_224.dino"] + self.model_names_to_test
-
- for model_name in models_to_test:
- checkpoint = f"timm/{model_name}"
-
- with self.subTest(msg=model_name):
- # prepare inputs
- image_processor = TimmWrapperImageProcessor.from_pretrained(checkpoint)
- pixel_values = image_processor(images=image).pixel_values.to(torch_device)
-
- # load models
- model = TimmWrapperModel.from_pretrained(checkpoint, device_map=torch_device).eval()
- timm_model = timm.create_model(model_name, pretrained=True, num_classes=0).to(torch_device).eval()
-
- with torch.inference_mode():
- outputs = model(pixel_values)
- timm_outputs = timm_model(pixel_values)
-
- # check shape is the same
- self.assertEqual(outputs.pooler_output.shape, timm_outputs.shape)
-
- # check logits are the same
- diff = (outputs.pooler_output - timm_outputs).max().item()
- self.assertLess(diff, 1e-4)
-
- @slow
- def test_save_load_to_timm(self):
- # test that timm model can be loaded to transformers, saved and then loaded back into timm
-
- model = TimmWrapperForImageClassification.from_pretrained(
- "timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True
- )
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
-
- # there is no direct way to load timm model from folder, use the same config + path to weights
- timm_model = timm.create_model(
- "resnet18", num_classes=10, checkpoint_path=f"{tmpdirname}/model.safetensors"
- )
-
- # check that all weights are the same after reload
- different_weights = []
- for (name1, param1), (name2, param2) in zip(
- model.timm_model.named_parameters(), timm_model.named_parameters()
- ):
- if param1.shape != param2.shape or not torch.equal(param1, param2):
- different_weights.append((name1, name2))
-
- if different_weights:
- self.fail(f"Found different weights after reloading: {different_weights}")
diff --git a/tests/models/univnet/test_feature_extraction_univnet.py b/tests/models/univnet/test_feature_extraction_univnet.py
index 2917d206dfde34..dfa335d15383ee 100644
--- a/tests/models/univnet/test_feature_extraction_univnet.py
+++ b/tests/models/univnet/test_feature_extraction_univnet.py
@@ -50,7 +50,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values
-class UnivNetFeatureExtractionTester:
+class UnivNetFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index 8286b3c94fb9da..4f501fc10a028f 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -41,6 +41,8 @@
if is_torch_available():
import torch
+else:
+ is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index 2b517034bffb15..77e2a19fea4861 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -441,6 +441,15 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa:
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@require_torch
class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
diff --git a/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
index e62bfe704d1d93..c9386a160f843d 100644
--- a/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
+++ b/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
@@ -21,13 +21,13 @@
from transformers import BertTokenizerFast
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer
from transformers.testing_utils import require_tokenizers, require_vision
-from transformers.utils import IMAGE_PROCESSOR_NAME, is_torchvision_available, is_vision_available
+from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
- from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor, ViTImageProcessorFast
+ from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor
@require_tokenizers
@@ -63,8 +63,6 @@ def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_image_processor(self, **kwargs):
- if is_torchvision_available():
- return ViTImageProcessorFast.from_pretrained(self.tmpdirname, **kwargs)
return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
def tearDown(self):
@@ -83,7 +81,7 @@ def test_save_load_pretrained_default(self):
self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
- self.assertIsInstance(processor.image_processor, (ViTImageProcessor, ViTImageProcessorFast))
+ self.assertIsInstance(processor.image_processor, ViTImageProcessor)
def test_save_load_pretrained_additional_features(self):
processor = VisionTextDualEncoderProcessor(
@@ -102,7 +100,7 @@ def test_save_load_pretrained_additional_features(self):
self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
- self.assertIsInstance(processor.image_processor, (ViTImageProcessor, ViTImageProcessorFast))
+ self.assertIsInstance(processor.image_processor, ViTImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
@@ -112,8 +110,8 @@ def test_image_processor(self):
image_input = self.prepare_image_inputs()
- input_feat_extract = image_processor(image_input, return_tensors="pt")
- input_processor = processor(images=image_input, return_tensors="pt")
+ input_feat_extract = image_processor(image_input, return_tensors="np")
+ input_processor = processor(images=image_input, return_tensors="np")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
diff --git a/tests/models/wav2vec2/test_feature_extraction_wav2vec2.py b/tests/models/wav2vec2/test_feature_extraction_wav2vec2.py
index 2a92ce3ac39f88..29e4bf3e28701a 100644
--- a/tests/models/wav2vec2/test_feature_extraction_wav2vec2.py
+++ b/tests/models/wav2vec2/test_feature_extraction_wav2vec2.py
@@ -44,7 +44,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values
-class Wav2Vec2FeatureExtractionTester:
+class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py
index 4b2353bce0027e..a8295542f4e377 100644
--- a/tests/models/whisper/test_feature_extraction_whisper.py
+++ b/tests/models/whisper/test_feature_extraction_whisper.py
@@ -50,7 +50,7 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values
-class WhisperFeatureExtractionTester:
+class WhisperFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py
index 504b6174fc52ad..73303e374c8484 100644
--- a/tests/models/whisper/test_modeling_tf_whisper.py
+++ b/tests/models/whisper/test_modeling_tf_whisper.py
@@ -17,22 +17,14 @@
from __future__ import annotations
import inspect
-import os
import tempfile
import traceback
import unittest
import numpy as np
-from transformers import GenerationConfig, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
-from transformers.testing_utils import (
- is_tf_available,
- require_read_token,
- require_tf,
- require_tokenizers,
- run_test_in_subprocess,
- slow,
-)
+from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
+from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available
@@ -757,9 +749,7 @@ def _test_large_generation(in_queue, out_queue, timeout):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(
- input_features,
- do_sample=False,
- max_length=20,
+ input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -782,29 +772,13 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
- # update generation config
- generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
-
- token = os.getenv("HF_HUB_READ_TOKEN", True)
- ds = load_dataset(
- "mozilla-foundation/common_voice_6_1",
- "ja",
- split="test",
- streaming=True,
- trust_remote_code=True,
- token=token,
- )
+ ds = load_dataset("legacy-datasets/common_voice", "ja", split="test", streaming=True, trust_remote_code=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
input_speech = next(iter(ds))["audio"]["array"]
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(
- input_features,
- do_sample=False,
- max_length=20,
- language="<|ja|>",
- task="transcribe",
- generation_config=generation_config,
+ input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -812,12 +786,7 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
generated_ids = model.generate(
- input_features,
- do_sample=False,
- max_length=20,
- language="<|en|>",
- task="transcribe",
- generation_config=generation_config,
+ input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -825,12 +794,7 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
generated_ids = model.generate(
- input_features,
- do_sample=False,
- max_length=20,
- language="<|ja|>",
- task="translate",
- generation_config=generation_config,
+ input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -861,10 +825,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
# fmt: off
EXPECTED_IDS = [
- [50258, 50259, 50359, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
- [50258, 50259, 50359, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
- [50258, 50259, 50359, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
- [50258, 50259, 50359, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
+ [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
+ [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
+ [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
+ [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
]
# fmt: on
@@ -872,10 +836,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
# fmt: off
EXPECTED_TRANSCRIPT = [
- " Mr. Quilter is the apostle of the middle classes and we are glad",
+ " Mr. Quilter is the apostle of the middle classes and we are glad to",
" Nor is Mr. Quilter's manner less interesting than his matter.",
- " He tells us that at this festive season of the year, with Christmas and roast",
- " He has grave doubts whether Sir Frederick Layton's work is really Greek after all"
+ " He tells us that at this festive season of the year, with Christmas and roast beef",
+ " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
]
# fmt: on
@@ -1045,7 +1009,6 @@ def test_large_generation(self):
run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)
@slow
- @require_read_token
def test_large_generation_multilingual(self):
run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 2eff406a3b56fc..faab43854cce11 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -445,11 +445,6 @@ def setUp(self):
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
self.maxDiff = 3000
- def prepare_config_and_inputs_for_generate(self, batch_size=2):
- config, inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
- inputs_dict["force_unique_generate_call"] = True
- return config, inputs_dict
-
def test_config(self):
self.config_tester.run_common_tests()
@@ -1896,8 +1891,8 @@ def test_large_batched_generation_multilingual(self):
"ja",
split="test",
streaming=True,
- trust_remote_code=True,
token=token,
+ trust_remote_code=True,
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
@@ -2149,16 +2144,11 @@ def test_small_longform_timestamps_generation(self):
},
{
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
- # "timestamp": (39.80, 45.36),
- # above is the expected output on A100.
- # on CI T4s, due to sligth difference in floating points operations, expected is below
- "timestamp": (39.80, 45.38),
+ "timestamp": (39.80, 45.36),
},
{
"text": " can discover in it but little of rocky Ithaca.",
- # "timestamp": (45.36, 49.0),
- # see above
- "timestamp": (45.38, 49.0),
+ "timestamp": (45.36, 49.0),
},
{
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
@@ -2285,20 +2275,20 @@ def test_tiny_token_timestamp_generation(self):
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
- [0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200],
- [0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000],
+ [0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400],
+ [0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400],
[0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800],
- [0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200]
+ [0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600]
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
@slow
- def test_small_token_timestamp_generation(self):
+ def test_large_token_timestamp_generation(self):
set_seed(0)
- processor = WhisperProcessor.from_pretrained("openai/whisper-small")
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+ processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)
input_speech = self._load_datasamples(4)
@@ -2315,10 +2305,10 @@ def test_small_token_timestamp_generation(self):
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
- [0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
- [0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
- [0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600],
- [0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800]
+ [0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
+ [0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
+ [0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000],
+ [0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800]
])
# fmt: on
@@ -3341,7 +3331,6 @@ def test_tiny_static_generation_long_form(self):
# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4
- torch._dynamo.reset()
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
diff --git a/tests/models/yolos/test_image_processing_yolos.py b/tests/models/yolos/test_image_processing_yolos.py
index 55a4be5c09926b..67508532e9c829 100644
--- a/tests/models/yolos/test_image_processing_yolos.py
+++ b/tests/models/yolos/test_image_processing_yolos.py
@@ -36,7 +36,7 @@
from transformers import YolosImageProcessor
-class YolosImageProcessingTester:
+class YolosImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index bdbccee5ad364c..48fd6da3d67e4c 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -17,19 +17,10 @@
import tempfile
import unittest
-from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from packaging import version
-from transformers import (
- AutoModelForCausalLM,
- AutoModelForSequenceClassification,
- AutoTokenizer,
- OPTForCausalLM,
- Trainer,
- TrainingArguments,
- logging,
-)
+from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
from transformers.testing_utils import (
CaptureLogger,
require_bitsandbytes,
@@ -674,76 +665,3 @@ def test_peft_load_adapter_training_inference_mode_false(self):
else:
assert not module.training
assert all(not p.requires_grad for p in module.parameters())
-
- def test_prefix_tuning_trainer_load_best_model_at_end_error(self):
- # Original issue: https://github.com/huggingface/peft/issues/2256
- # There is a potential error when using load_best_model_at_end=True with a prompt learning PEFT method. This is
- # because Trainer uses load_adapter under the hood but with some prompt learning methods, there is an
- # optimization on the saved model to remove parameters that are not required for inference, which in turn
- # requires a change to the model architecture. This is why load_adapter will fail in such cases and users should
- # instead set load_best_model_at_end=False and use PeftModel.from_pretrained. As this is not obvious, we now
- # intercept the error and add a helpful error message.
- # This test checks this error message. It also tests the "happy path" (i.e. no error) when using LoRA.
- from peft import LoraConfig, PrefixTuningConfig, TaskType, get_peft_model
-
- # create a small sequence classification dataset (binary classification)
- dataset = []
- for i, row in enumerate(os.__doc__.splitlines()):
- dataset.append({"text": row, "label": i % 2})
- ds_train = Dataset.from_list(dataset)
- ds_valid = ds_train
- datasets = DatasetDict(
- {
- "train": ds_train,
- "val": ds_valid,
- }
- )
-
- # tokenizer for peft-internal-testing/tiny-OPTForCausalLM-lora cannot be loaded, thus using
- # hf-internal-testing/tiny-random-OPTForCausalLM
- model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
- tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left", model_type="opt")
-
- def tokenize_function(examples):
- return tokenizer(examples["text"], max_length=128, truncation=True, padding="max_length")
-
- tokenized_datasets = datasets.map(tokenize_function, batched=True)
- # lora works, prefix-tuning is expected to raise an error
- peft_configs = {
- "lora": LoraConfig(task_type=TaskType.SEQ_CLS),
- "prefix-tuning": PrefixTuningConfig(
- task_type=TaskType.SEQ_CLS,
- inference_mode=False,
- prefix_projection=True,
- num_virtual_tokens=10,
- ),
- }
-
- for peft_type, peft_config in peft_configs.items():
- base_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
- base_model.config.pad_token_id = tokenizer.pad_token_id
- peft_model = get_peft_model(base_model, peft_config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- training_args = TrainingArguments(
- output_dir=tmpdirname,
- num_train_epochs=3,
- eval_strategy="epoch",
- save_strategy="epoch",
- load_best_model_at_end=True,
- )
- trainer = Trainer(
- model=peft_model,
- args=training_args,
- train_dataset=tokenized_datasets["train"],
- eval_dataset=tokenized_datasets["val"],
- )
-
- if peft_type == "lora":
- # LoRA works with load_best_model_at_end
- trainer.train()
- else:
- # prefix tuning does not work, but at least users should get a helpful error message
- msg = "When using prompt learning PEFT methods such as PREFIX_TUNING"
- with self.assertRaisesRegex(RuntimeError, msg):
- trainer.train()
diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py
index e2141dc7cc2f66..9481ab200063f8 100644
--- a/tests/pipelines/test_pipelines_table_question_answering.py
+++ b/tests/pipelines/test_pipelines_table_question_answering.py
@@ -20,6 +20,7 @@
AutoTokenizer,
TableQuestionAnsweringPipeline,
TFAutoModelForTableQuestionAnswering,
+ is_torch_available,
pipeline,
)
from transformers.testing_utils import (
@@ -32,6 +33,12 @@
)
+if is_torch_available():
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
+else:
+ is_torch_greater_or_equal_than_1_12 = False
+
+
@is_pipeline_test
class TQAPipelineTests(unittest.TestCase):
# Putting it there for consistency, but TQA do not have fast tokenizer
@@ -143,6 +150,7 @@ def test_small_model_tf(self):
},
)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_small_model_pt(self, torch_dtype="float32"):
model_id = "lysandre/tiny-tapas-random-wtq"
@@ -245,10 +253,12 @@ def test_small_model_pt(self, torch_dtype="float32"):
},
)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_small_model_pt_fp16(self):
self.test_small_model_pt(torch_dtype="float16")
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"):
model_id = "lysandre/tiny-tapas-random-sqa"
@@ -368,6 +378,7 @@ def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"):
},
)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_slow_tokenizer_sqa_pt_fp16(self):
self.test_slow_tokenizer_sqa_pt(torch_dtype="float16")
@@ -494,6 +505,7 @@ def test_slow_tokenizer_sqa_tf(self):
},
)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_wtq_pt(self, torch_dtype="float32"):
@@ -539,6 +551,7 @@ def test_integration_wtq_pt(self, torch_dtype="float32"):
]
self.assertListEqual(results, expected_results)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_wtq_pt_fp16(self):
@@ -593,6 +606,7 @@ def test_integration_wtq_tf(self):
]
self.assertListEqual(results, expected_results)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_sqa_pt(self, torch_dtype="float32"):
@@ -618,6 +632,7 @@ def test_integration_sqa_pt(self, torch_dtype="float32"):
]
self.assertListEqual(results, expected_results)
+ @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_sqa_pt_fp16(self):
diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py
index e07e2ad392a3e6..dac2ce6b30ec22 100644
--- a/tests/pipelines/test_pipelines_text_to_audio.py
+++ b/tests/pipelines/test_pipelines_text_to_audio.py
@@ -27,6 +27,7 @@
require_torch,
require_torch_accelerator,
require_torch_or_tf,
+ run_test_using_subprocess,
slow,
torch_device,
)
@@ -66,8 +67,10 @@ def test_small_musicgen_pt(self):
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+ # TODO: @ylacombe: `SeamlessM4TForTextToSpeech.generate` has issue with `generation_config`. See issue #34811
@slow
@require_torch
+ @run_test_using_subprocess
def test_medium_seamless_m4t_pt(self):
speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index c4287362b6bc1c..3eae429abb206a 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -53,8 +53,6 @@ def get_some_linear_layer(model):
except AttributeError:
# for AutoModelforCausalLM
return model.model.decoder.layers[0].fc1
- elif model.config.model_type == "llama":
- return model.model.layers[0].mlp.gate_proj
else:
return model.transformer.h[0].mlp.dense_4h_to_h
@@ -108,7 +106,6 @@ class Base4bitTest(unittest.TestCase):
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
- EXPECTED_OUTPUTS.add("Hello my name is John and I am 25 years old.")
MAX_NEW_TOKENS = 10
def setUp(self):
@@ -388,14 +385,14 @@ def test_inference_without_keep_in_fp32(self):
# test with `google-t5/t5-small`
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
# test with `flan-t5-small`
model = T5ForConditionalGeneration.from_pretrained(
self.dense_act_model_name, load_in_4bit=True, device_map="auto"
)
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
T5ForConditionalGeneration._keep_in_fp32_modules = modules
@@ -413,14 +410,14 @@ def test_inference_with_keep_in_fp32(self):
# there was a bug with decoders - this test checks that it is fixed
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit))
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
# test with `flan-t5-small`
model = T5ForConditionalGeneration.from_pretrained(
self.dense_act_model_name, load_in_4bit=True, device_map="auto"
)
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
@@ -558,8 +555,6 @@ def test_training(self):
if torch.cuda.is_available():
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
- elif torch.xpu.is_available():
- self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
else:
self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))
@@ -593,18 +588,11 @@ def test_training(self):
@apply_skip_if_not_implemented
-@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed")
class Bnb4BitGPT2Test(Bnb4BitTest):
model_name = "openai-community/gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
-@apply_skip_if_not_implemented
-class Bnb4BitLlamaTest(Bnb4BitTest):
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
- EXPECTED_RELATIVE_DIFFERENCE = 2.9461410686392764
-
-
@require_bitsandbytes
@require_accelerate
@require_torch
@@ -684,7 +672,7 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
out_0 = model_0(**encoded_input)
out_1 = model_1(**encoded_input)
- self.assertTrue(torch.allclose(out_0["logits"], out_1["logits"], atol=0.05))
+ self.assertTrue(torch.equal(out_0["logits"], out_1["logits"]))
# comparing generate() outputs
encoded_input = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
@@ -746,14 +734,6 @@ class GPTSerializationTest(BaseSerializationTest):
model_name = "openai-community/gpt2-xl"
-class LlamaSerializationTest(BaseSerializationTest):
- """
- default BaseSerializationTest config tested with Llama family model
- """
-
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
-
@require_bitsandbytes
@require_accelerate
@require_torch_gpu_if_bnb_not_multi_backend_enabled
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 26e8cb2fc731ec..567aa956271b70 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -48,8 +48,6 @@
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
- elif model.config.model_type == "llama":
- return model.model.layers[0].mlp.gate_proj
return model.transformer.h[0].mlp.dense_4h_to_h
@@ -67,12 +65,12 @@ def get_some_linear_layer(model):
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only"""
- def __init__(self, module: nn.Module, rank: int, dtype: torch.dtype):
+ def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
- nn.Linear(module.in_features, rank, bias=False, dtype=dtype),
- nn.Linear(rank, module.out_features, bias=False, dtype=dtype),
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
@@ -516,14 +514,14 @@ def test_inference_without_keep_in_fp32(self):
# test with `google-t5/t5-small`
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
# test with `flan-t5-small`
model = T5ForConditionalGeneration.from_pretrained(
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
)
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
T5ForConditionalGeneration._keep_in_fp32_modules = modules
@@ -542,14 +540,14 @@ def test_inference_with_keep_in_fp32(self):
# there was a bug with decoders - this test checks that it is fixed
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
# test with `flan-t5-small`
model = T5ForConditionalGeneration.from_pretrained(
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
)
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
def test_inference_with_keep_in_fp32_serialized(self):
@@ -573,14 +571,14 @@ def test_inference_with_keep_in_fp32_serialized(self):
# there was a bug with decoders - this test checks that it is fixed
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
# test with `flan-t5-small`
model = T5ForConditionalGeneration.from_pretrained(
self.dense_act_model_name, load_in_8bit=True, device_map="auto"
)
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(model.device)
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
_ = model.generate(**encoded_input)
@@ -860,36 +858,29 @@ def test_training(self):
if torch.cuda.is_available():
self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()})
- elif torch.xpu.is_available():
- self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"})
else:
self.assertTrue(all(param.device.type == "cpu" for param in model.parameters()))
for param in model.parameters():
param.requires_grad = False # freeze the model - train adapters later
- # cast all non INT8 parameters to fp32
- if param.dtype in (torch.float16, torch.bfloat16) and param.__class__.__name__ != "Params4bit":
+ if param.ndim == 1:
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
param.data = param.data.to(torch.float32)
# Step 2: add adapters
for _, module in model.named_modules():
if isinstance(module, OPTAttention):
- module.q_proj = LoRALayer(module.q_proj, rank=16, dtype=model.dtype)
- module.k_proj = LoRALayer(module.k_proj, rank=16, dtype=model.dtype)
- module.v_proj = LoRALayer(module.v_proj, rank=16, dtype=model.dtype)
+ module.q_proj = LoRALayer(module.q_proj, rank=16)
+ module.k_proj = LoRALayer(module.k_proj, rank=16)
+ module.v_proj = LoRALayer(module.v_proj, rank=16)
# Step 3: dummy batch
batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device)
# Step 4: Check if the gradient is not None
- if torch_device in {"xpu", "cpu"}:
- # XPU and CPU finetune do not support autocast for now.
+ with torch.autocast(torch_device):
out = model.forward(**batch)
out.logits.norm().backward()
- else:
- with torch.autocast(torch_device):
- out = model.forward(**batch)
- out.logits.norm().backward()
for module in model.modules():
if isinstance(module, LoRALayer):
@@ -900,7 +891,6 @@ def test_training(self):
@apply_skip_if_not_implemented
-@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed")
class MixedInt8GPT2Test(MixedInt8Test):
model_name = "openai-community/gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
@@ -932,30 +922,3 @@ def test_int8_from_pretrained(self):
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
-
-
-class MixedInt8LlamaTest(MixedInt8Test):
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
- EXPECTED_RELATIVE_DIFFERENCE = 1.7869331026479096
- EXPECTED_OUTPUTS = set()
- EXPECTED_OUTPUTS.add("Hello my name is John Smith and I am a software engineer. I")
-
- def test_int8_from_pretrained(self):
- r"""
- Test whether loading a 8bit model from the Hub works as expected
- """
- from bitsandbytes.nn import Int8Params
-
- model_id = "Jiqing/TinyLlama-1.1B-Chat-v1.0-bnb-8bit"
-
- model = AutoModelForCausalLM.from_pretrained(model_id)
-
- linear = get_some_linear_layer(model)
- self.assertTrue(linear.weight.__class__ == Int8Params)
- self.assertTrue(hasattr(linear.weight, "SCB"))
-
- # generate
- encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
- output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
-
- self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
diff --git a/tests/quantization/compressed_tensor/test_load_sparse_model.py b/tests/quantization/compressed_tensor/test_load_sparse_model.py
deleted file mode 100644
index 8992cd3d9bd470..00000000000000
--- a/tests/quantization/compressed_tensor/test_load_sparse_model.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import gc
-import unittest
-
-from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_compressed_tensors, require_torch
-from transformers.utils import is_torch_available
-
-
-if is_torch_available():
- import torch
-
-
-@require_compressed_tensors
-@require_torch
-class CompressedTensorsTest(unittest.TestCase):
- model_sparse_uncompressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_uncompressed"
- model_sparse_compressed = "horheynm/llama2.c_stories15M_pruned_50.2of4_compressed"
-
- prompt = "Paris is the capital of which country?"
-
- stubs = [model_sparse_uncompressed, model_sparse_compressed]
-
- def tearDown(self):
- gc.collect()
- torch.cuda.empty_cache()
- gc.collect()
-
- def test_compressed_uncompressed_model_shapes(self):
- """
- Check that the weights are the same between
- uncompressed and compressed-decompressed model
- Sparse compressed modules' weights are "packed" and shape/value will
- differ
- """
-
- def _has_nested_attr(obj, attr_path):
- attrs = attr_path.split(".")
- for attr in attrs:
- if not hasattr(obj, attr):
- return None
- obj = getattr(obj, attr)
- return obj
-
- from compressed_tensors.quantization.utils import iter_named_leaf_modules
-
- uncompressed_model = AutoModelForCausalLM.from_pretrained(
- self.model_sparse_uncompressed,
- )
-
- compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
- self.model_sparse_compressed,
- )
-
- for name, submodule in iter_named_leaf_modules(
- uncompressed_model,
- ):
- if comp_decomp_obj := _has_nested_attr(compressed_model_decompressed, name):
- if hasattr(submodule, "weight"):
- assert torch.equal(submodule.weight, comp_decomp_obj.weight)
-
- def test_run_compressed_outputs_match(self):
- """Check that uncompressed and compressed-decompressed model outputs are the same"""
-
- from transformers import AutoTokenizer
-
- for stub in self.stubs:
- tokenizer = AutoTokenizer.from_pretrained(stub)
- input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
-
- uncompressed_model = AutoModelForCausalLM.from_pretrained(
- self.model_sparse_uncompressed,
- )
- output_rc_true = uncompressed_model.generate(input_ids, max_new_tokens=100)
-
- compressed_model_decompressed = AutoModelForCausalLM.from_pretrained(
- self.model_sparse_compressed,
- )
- output_rc_false = compressed_model_decompressed.generate(input_ids, max_new_tokens=100)
-
- assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])
diff --git a/tests/quantization/compressed_tensor/test_run_compressed_model.py b/tests/quantization/compressed_tensor/test_run_compressed_model.py
deleted file mode 100644
index b168ca382ccefa..00000000000000
--- a/tests/quantization/compressed_tensor/test_run_compressed_model.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import gc
-import unittest
-
-from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_compressed_tensors, require_torch
-from transformers.utils import is_torch_available
-
-
-if is_torch_available():
- import torch
-
-
-@require_compressed_tensors
-@require_torch
-class CompressedTensorsTest(unittest.TestCase):
- tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
- tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
-
- prompt = "Paris is the capital of which country?"
-
- stubs = [tinyllama_w4a16, tinyllama_w8a8]
-
- def tearDown(self):
- gc.collect()
- torch.cuda.empty_cache()
- gc.collect()
-
- def test_default_run_compressed__True(self):
- from compressed_tensors.linear.compressed_linear import CompressedLinear
- from compressed_tensors.quantization.utils import iter_named_leaf_modules
-
- for stub in self.stubs:
- model = AutoModelForCausalLM.from_pretrained(
- stub,
- )
- compressed_linear_counts = 0
-
- for _, submodule in iter_named_leaf_modules(
- model,
- ):
- if isinstance(submodule, CompressedLinear):
- compressed_linear_counts += 1
-
- # some linear models are not compressed - ex. lm_head
- assert compressed_linear_counts > 0
-
- def test_default_run_compressed__False(self):
- from compressed_tensors.linear.compressed_linear import CompressedLinear
- from compressed_tensors.quantization.utils import iter_named_leaf_modules
-
- from transformers.utils.quantization_config import CompressedTensorsConfig
-
- quantization_config = CompressedTensorsConfig(run_compressed=False)
-
- for stub in self.stubs:
- model = AutoModelForCausalLM.from_pretrained(
- stub,
- quantization_config=quantization_config,
- )
- compressed_linear_counts = 0
-
- for _, submodule in iter_named_leaf_modules(
- model,
- ):
- if isinstance(submodule, CompressedLinear):
- compressed_linear_counts += 1
-
- # No modules should be CompressedLinear
- assert compressed_linear_counts == 0
-
- def test_run_compressed_outputs_match(self):
- """Check that run_compressed=True/False output are the same"""
-
- from transformers import AutoTokenizer
- from transformers.utils.quantization_config import CompressedTensorsConfig
-
- quantization_config = CompressedTensorsConfig(run_compressed=False)
-
- for stub in self.stubs:
- tokenizer = AutoTokenizer.from_pretrained(stub)
- input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids
-
- model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
- stub,
- )
- output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)
-
- model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
- stub,
- quantization_config=quantization_config,
- )
- output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)
-
- assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])
diff --git a/tests/quantization/eetq_integration/test_eetq.py b/tests/quantization/eetq_integration/test_eetq.py
index f14fa076e4bb76..2c01f8145cba0e 100644
--- a/tests/quantization/eetq_integration/test_eetq.py
+++ b/tests/quantization/eetq_integration/test_eetq.py
@@ -119,7 +119,7 @@ def test_quantized_model_conversion(self):
self.assertEqual(nb_linears - 1, nb_eetq_linear)
- # Try with `modules_to_not_convert`
+ # Try with `linear_weights_not_to_quantize`
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = EetqConfig(modules_to_not_convert=["fc1"])
@@ -128,7 +128,7 @@ def test_quantized_model_conversion(self):
for module in model.modules():
if isinstance(module, EetqLinear):
nb_eetq_linear += 1
- # 25 corresponds to the lm_head along with 24 fc1 layers.
+
self.assertEqual(nb_linears - 25, nb_eetq_linear)
def test_quantized_model(self):
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 508975865c27af..1171e82e5285d5 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -45,8 +45,7 @@ class GgufIntegrationTests(unittest.TestCase):
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
bloom_model_id = "afrideva/bloom-560m-GGUF"
original_bloom_model_id = "bigscience/bloom-560m"
- falcon7b_model_id_q2 = "xaviviro/falcon-7b-quantized-gguf"
- falcon7b_model_id_fp16 = "medmekk/falcon-7b-gguf"
+ falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
original_flacon7b_model_id = "tiiuae/falcon-7b"
t5_model_id = "repetitio/flan-t5-small"
@@ -616,9 +615,9 @@ def test_falcon40b_q2_k(self):
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_falcon7b_q2_k(self):
- tokenizer = AutoTokenizer.from_pretrained(self.falcon7b_model_id_q2, gguf_file=self.q2_k_falcon7b_model_id)
+ tokenizer = AutoTokenizer.from_pretrained(self.falcon7b_model_id, gguf_file=self.q2_k_falcon7b_model_id)
model = AutoModelForCausalLM.from_pretrained(
- self.falcon7b_model_id_q2,
+ self.falcon7b_model_id,
gguf_file=self.q2_k_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
@@ -632,7 +631,7 @@ def test_falcon7b_q2_k(self):
def test_falcon7b_weights_conversion_fp16(self):
quantized_model = AutoModelForCausalLM.from_pretrained(
- self.falcon7b_model_id_fp16,
+ self.falcon7b_model_id,
gguf_file=self.fp16_falcon7b_model_id,
device_map="auto",
torch_dtype=torch.float16,
diff --git a/tests/quantization/vptq_integration/__init__.py b/tests/quantization/vptq_integration/__init__.py
deleted file mode 100644
index e69de29bb2d1d6..00000000000000
diff --git a/tests/quantization/vptq_integration/test_vptq.py b/tests/quantization/vptq_integration/test_vptq.py
deleted file mode 100644
index faa9a5879d1dcc..00000000000000
--- a/tests/quantization/vptq_integration/test_vptq.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import tempfile
-import unittest
-
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, VptqConfig
-from transformers.testing_utils import (
- require_accelerate,
- require_torch_gpu,
- require_torch_multi_gpu,
- require_vptq,
- slow,
- torch_device,
-)
-from transformers.utils import is_accelerate_available, is_torch_available
-
-
-if is_torch_available():
- import torch
-
-if is_accelerate_available():
- from accelerate import init_empty_weights
-
-
-class VptqConfigTest(unittest.TestCase):
- def test_to_dict(self):
- """
- Makes sure the config format is properly set
- """
- quantization_config = VptqConfig()
- vptq_orig_config = quantization_config.to_dict()
-
- self.assertEqual(quantization_config.quant_config, vptq_orig_config["quant_config"])
-
-
-@slow
-@require_torch_gpu
-@require_vptq
-@require_accelerate
-class VptqTest(unittest.TestCase):
- model_name = "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft"
-
- input_text = "Hello my name is"
- max_new_tokens = 32
-
- EXPECTED_OUTPUT = "Hello my name is Sarah and I am a 25 year old woman from the United States. I am a college graduate and I am currently working as a marketing specialist for a small"
-
- device_map = "cuda"
-
- # called only once for all test in this class
- @classmethod
- def setUpClass(cls):
- """
- Setup quantized model
- """
- cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
- cls.quantized_model = AutoModelForCausalLM.from_pretrained(
- cls.model_name,
- device_map=cls.device_map,
- )
-
- def tearDown(self):
- gc.collect()
- torch.cuda.empty_cache()
- gc.collect()
-
- def test_quantized_model(self):
- """
- Simple test that checks if the quantized model is working properly
- """
- input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
-
- output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
- self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
-
- def test_raise_if_non_quantized(self):
- model_id = "facebook/opt-125m"
- quantization_config = VptqConfig()
-
- with self.assertRaises(ValueError):
- _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
-
- def test_save_pretrained(self):
- """
- Simple test that checks if the quantized model is working properly after being saved and loaded
- """
- with tempfile.TemporaryDirectory() as tmpdirname:
- self.quantized_model.save_pretrained(tmpdirname)
- model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
-
- input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
-
- output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
- self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
-
- @require_torch_multi_gpu
- def test_quantized_model_multi_gpu(self):
- """
- Simple test that checks if the quantized model is working properly with multiple GPUs
- """
- input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
-
- quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
-
- self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
-
- output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
-
- self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
-
- def test_quantized_model_conversion(self):
- """
- Simple test that checks if the quantized model has been converted properly
- """
- from vptq import VQuantLinear
-
- from transformers.integrations import replace_with_vptq_linear
-
- model_id = "facebook/opt-350m"
- config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
- modules_to_not_convert = ["lm_head"]
- names = [
- "q_proj",
- "k_proj",
- "v_proj",
- "out_proj",
- "fc1",
- "fc2",
- ]
- value = {
- "enable_norm": True,
- "enable_perm": True,
- "group_num": 1,
- "group_size": 128,
- "indices_as_float": False,
- "num_centroids": [-1, 128],
- "num_res_centroids": [-1, 128],
- "outlier_size": 0,
- "vector_lens": [-1, 12],
- }
- shared_layer_config = {}
- for name in names:
- shared_layer_config[name] = value
- for i in range(24):
- modules_to_not_convert.append("model.decoder.layers.{layer_idx}.fc1".format(layer_idx=i))
- layer_configs = {}
- layer_configs["model.decoder.project_out"] = value
- layer_configs["model.decoder.project_in"] = value
- quantization_config = VptqConfig(config_for_layers=layer_configs, shared_layer_config=shared_layer_config)
-
- with init_empty_weights():
- model = AutoModelForCausalLM.from_config(config)
-
- nb_linears = 0
- for module in model.modules():
- if isinstance(module, torch.nn.Linear):
- nb_linears += 1
-
- model, _ = replace_with_vptq_linear(model, quantization_config=quantization_config)
- nb_vptq_linear = 0
- for module in model.modules():
- if isinstance(module, VQuantLinear):
- nb_vptq_linear += 1
-
- self.assertEqual(nb_linears - 1, nb_vptq_linear)
-
- # Try with `linear_weights_not_to_quantize`
- with init_empty_weights():
- model = AutoModelForCausalLM.from_config(config)
- quantization_config = VptqConfig(config_for_layers=layer_configs, shared_layer_config=shared_layer_config)
- model, _ = replace_with_vptq_linear(
- model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
- )
- nb_vptq_linear = 0
- for module in model.modules():
- if isinstance(module, VQuantLinear):
- nb_vptq_linear += 1
- # 25 comes from 24 decoder.layers.{layer_idx}.fc1
- # and the last lm_head
- self.assertEqual(nb_linears - 25, nb_vptq_linear)
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index 221552175a93e3..7d89b43ce35ba4 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -228,15 +228,14 @@ def test_image_processor_from_and_save_pretrained(self):
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_image_processor_save_load_with_autoimageprocessor(self):
- for i, image_processing_class in enumerate(self.image_processor_list):
+ for image_processing_class in self.image_processor_list:
image_processor_first = image_processing_class(**self.image_processor_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
- use_fast = i == 1
- image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=use_fast)
+ image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 929bbb13a56e80..99d0a8058c67f8 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -89,9 +89,6 @@
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
- set_config_for_less_flaky_test,
- set_model_for_less_flaky_test,
- set_model_tester_for_less_flaky_test,
slow,
torch_device,
)
@@ -122,7 +119,6 @@
from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding
- from transformers.cache_utils import DynamicCache
from transformers.modeling_utils import load_state_dict, no_init_weights
from transformers.pytorch_utils import id_tensor_storage
@@ -1289,11 +1285,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
)
for i in range(model.config.num_hidden_layers)
)
- empty_pkv = (
- DynamicCache.from_legacy_cache(empty_pkv)
- if model_class._supports_cache_class
- else empty_pkv
- )
cache_length = 9
cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1304,11 +1295,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
)
for i in range(model.config.num_hidden_layers)
)
- non_empty_pkv = (
- DynamicCache.from_legacy_cache(non_empty_pkv)
- if model_class._supports_cache_class
- else non_empty_pkv
- )
inps = copy.deepcopy(inputs_to_test[0])
@@ -2485,7 +2471,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
return new_tf_outputs, new_pt_outputs
# Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
- def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
@@ -2541,8 +2527,6 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, nam
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
- if isinstance(pt_output, DynamicCache):
- pt_output = pt_output.to_legacy_cache()
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(tf_outputs, tf.Tensor):
@@ -2718,7 +2702,7 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
- def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
@@ -2728,6 +2712,7 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
+
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
@@ -3458,9 +3443,6 @@ def test_mismatched_shapes_have_properly_initialized_weights(self):
"Data2VecAudioForSequenceClassification",
"UniSpeechForSequenceClassification",
"PvtForImageClassification",
- "ModernBertForSequenceClassification",
- "ModernBertForTokenClassification",
- "TimmWrapperForImageClassification",
]
special_param_names = [
r"^bit\.",
@@ -3481,7 +3463,6 @@ def test_mismatched_shapes_have_properly_initialized_weights(self):
r"^swiftformer\.",
r"^swinv2\.",
r"^transformers\.models\.swiftformer\.",
- r"^timm_model\.",
r"^unispeech\.",
r"^unispeech_sat\.",
r"^vision_model\.",
@@ -3613,6 +3594,34 @@ def test_model_is_small(self):
num_params < 1000000
), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_conversion(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
+ ).to(torch_device)
+
+ for _, module in model.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ return
+
+ self.assertTrue(False, "FlashAttention2 modules not found in model")
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3870,6 +3879,15 @@ def test_sdpa_can_dispatch_non_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa and model_sdpa.config.model_type != "falcon":
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
@@ -3922,6 +3940,15 @@ def test_sdpa_can_dispatch_composite_models(self):
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
+ has_sdpa = False
+ for name, submodule in model_sdpa.named_modules():
+ class_name = submodule.__class__.__name__
+ if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+ has_sdpa = True
+ break
+ if not has_sdpa and any(module_attn == "sdpa" for module_attn in [text_attn, vision_attn]):
+ raise ValueError("The SDPA model should have SDPA attention layers")
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
@@ -3979,11 +4006,34 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str):
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
- set_model_tester_for_less_flaky_test(self)
+ if hasattr(self.model_tester, "num_hidden_layers"):
+ self.model_tester.num_hidden_layers = 1
+ if hasattr(self.model_tester, "vision_config") and "num_hidden_layers" in self.model_tester.vision_config:
+ self.model_tester.vision_config = copy.deepcopy(self.model_tester.vision_config)
+ self.model_tester.vision_config["num_hidden_layers"] = 1
+ if hasattr(self.model_tester, "text_config") and "num_hidden_layers" in self.model_tester.text_config:
+ self.model_tester.text_config = copy.deepcopy(self.model_tester.text_config)
+ self.model_tester.text_config["num_hidden_layers"] = 1
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
- set_config_for_less_flaky_test(config)
+
+ config.rms_norm_eps = 1.0
+ config.layer_norm_eps = 1.0
+ config.norm_eps = 1.0
+ config.norm_epsilon = 1.0
+ config.layer_norm_epsilon = 1.0
+
+ # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
+ # (We don't need the original epsilon values to check eager/sdpa matches)
+ for attr in ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]:
+ if hasattr(config, attr):
+ getattr(config, attr).rms_norm_eps = 1.0
+ getattr(config, attr).layer_norm_eps = 1.0
+ getattr(config, attr).norm_eps = 1.0
+ getattr(config, attr).norm_epsilon = 1.0
+ getattr(config, attr).layer_norm_epsilon = 1.0
+
model = model_class(config)
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
@@ -3994,12 +4044,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
- try:
- model_sdpa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa"
- )
- except ValueError:
- model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
+ model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
@@ -4009,8 +4054,13 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
- set_model_for_less_flaky_test(model_eager)
- set_model_for_less_flaky_test(model_sdpa)
+ # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
+ for x in model_eager.modules():
+ if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
+ x.eps = 1.0
+ for x in model_sdpa.modules():
+ if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
+ x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
@@ -4150,20 +4200,16 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
- if hasattr(outputs_eager, "vision_hidden_states"):
- logits_eager = outputs_eager.vision_hidden_states[-1]
- logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
- else:
- logits_eager = (
- outputs_eager.hidden_states[-1]
- if not is_encoder_decoder
- else outputs_eager.decoder_hidden_states[-1]
- )
- logits_sdpa = (
- outputs_sdpa.hidden_states[-1]
- if not is_encoder_decoder
- else outputs_sdpa.decoder_hidden_states[-1]
- )
+ logits_eager = (
+ outputs_eager.hidden_states[-1]
+ if not is_encoder_decoder
+ else outputs_eager.decoder_hidden_states[-1]
+ )
+ logits_sdpa = (
+ outputs_sdpa.hidden_states[-1]
+ if not is_encoder_decoder
+ else outputs_sdpa.decoder_hidden_states[-1]
+ )
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
@@ -4239,8 +4285,6 @@ def test_sdpa_can_dispatch_on_flash(self):
)
if config.model_type in ["idefics", "idefics2", "idefics3"]:
self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
- if config.model_type in ["sam"]:
- self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings")
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index bfe1648de049e1..c7d098be3ea8f2 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -23,7 +23,6 @@
import transformers
from transformers import is_flax_available, is_torch_available
-from transformers.cache_utils import DynamicCache
from transformers.models.auto import get_values
from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
@@ -181,7 +180,7 @@ def recursive_check(tuple_object, dict_object):
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
- def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
@@ -191,6 +190,7 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
+
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
@@ -235,8 +235,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
- if isinstance(pt_output, DynamicCache):
- pt_output = pt_output.to_legacy_cache()
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(fx_outputs, jnp.ndarray):
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 9dc712ab67b682..eb328d83e9e7a4 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -484,7 +484,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla
return new_tf_outputs, new_pt_outputs
- def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
+ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
@@ -495,7 +495,6 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, nam
attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element
being a named field in the output.
"""
- from transformers.cache_utils import DynamicCache
self.assertEqual(type(name), str)
if attributes is not None:
@@ -541,8 +540,6 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, nam
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
- if isinstance(pt_output, DynamicCache):
- pt_output = pt_output.to_legacy_cache()
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(tf_outputs, tf.Tensor):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d33be2789761da..f7b4a8637bff85 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -750,102 +750,11 @@ def test_model_init(self):
self.check_trained_model(trainer.model, alternate_seed=True)
@slow
- def test_gradient_accumulation_loss_alignment_with_model_loss(self):
+ def test_gradient_accumulation_loss_alignment(self):
set_seed(42)
import datasets
- model_name = "nickypro/tinyllama-110M"
- dataset_name = "wikitext"
- dataset_config = "wikitext-2-raw-v1"
- dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
- dataset = dataset.train_test_split(test_size=0.2)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- tokenizer.pad_token = tokenizer.eos_token
-
- def tokenize_function(examples):
- return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
-
- tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
-
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
-
- model = AutoModelForCausalLM.from_pretrained(model_name)
-
- base_loss_callback = StoreLossCallback()
-
- args_kwargs = {
- "report_to": "none",
- "logging_steps": 1,
- "max_steps": 20,
- "learning_rate": 3e-4,
- "disable_tqdm": True,
- }
-
- args = TrainingArguments(
- "./generation",
- **args_kwargs,
- )
- trainer = Trainer(
- model,
- args,
- train_dataset=tokenized_dataset["train"],
- callbacks=[base_loss_callback],
- data_collator=data_collator,
- )
- assert trainer.model_accepts_loss_kwargs
- trainer.train()
-
- grad_accum_loss_callback = StoreLossCallback()
- args = TrainingArguments(
- "./generation",
- **args_kwargs,
- gradient_accumulation_steps=2,
- per_device_train_batch_size=4,
- )
- set_seed(42)
- model = AutoModelForCausalLM.from_pretrained(model_name)
- trainer = Trainer(
- model,
- args,
- train_dataset=tokenized_dataset["train"],
- callbacks=[grad_accum_loss_callback],
- data_collator=data_collator,
- )
- trainer.train()
-
- set_seed(42)
- model = AutoModelForCausalLM.from_pretrained(model_name)
- broken_loss_callback = StoreLossCallback()
- trainer = Trainer(
- model,
- args,
- train_dataset=tokenized_dataset["train"],
- callbacks=[broken_loss_callback],
- data_collator=data_collator,
- )
- # disable model_accepts_loss_kwargs
- trainer.model_accepts_loss_kwargs = False
- trainer.train()
-
- # Calculate the difference between the base loss and the grad_accum loss
- diff_truth = [
- abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
- ]
- diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
-
- # all diff truth should be quite close
- self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
-
- # max diff broken should be very off
- self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
-
- @slow
- def test_gradient_accumulation_loss_alignment_with_loss_func(self):
- set_seed(42)
- import datasets
-
- model_name = "roneneldan/TinyStories-33M"
+ model_name = "distilgpt2"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
@@ -927,16 +836,15 @@ def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_ite
trainer.train()
# Calculate the difference between the base loss and the grad_accum loss
- diff_truth = [
- abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
- ]
- diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
-
- # all diff truth should be quite close
- self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
-
- # max diff broken should be very off
- self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
+ diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
+ diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+ # These should be quite close
+ for diff in diff_truth:
+ self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
+
+ # These should be very off
+ for diff in diff_broken:
+ self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 383f0cbe60e1c9..458ddeee5ff8be 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -563,17 +563,32 @@ def test_model_from_pretrained_attn_implementation(self):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
+ mistral_attention_classes = {
+ "eager": "MistralAttention",
+ "sdpa": "MistralSdpaAttention",
+ "flash_attention_2": "MistralFlashAttention2",
+ }
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, attn_implementation=requested_attn_implementation
)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+ for module in model.modules():
+ if "Attention" in module.__class__.__name__:
+ self.assertEqual(
+ module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+ )
config = AutoConfig.from_pretrained(TINY_MISTRAL)
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+ for module in model.modules():
+ if "Attention" in module.__class__.__name__:
+ self.assertEqual(
+ module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+ )
def test_model_from_config_attn_implementation(self):
# test that the model can be instantiated with attn_implementation of either
@@ -587,6 +602,11 @@ def test_model_from_config_attn_implementation(self):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
+ mistral_attention_classes = {
+ "eager": "MistralAttention",
+ "sdpa": "MistralSdpaAttention",
+ "flash_attention_2": "MistralFlashAttention2",
+ }
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
# Ensure the config was set correctly
@@ -594,6 +614,11 @@ def test_model_from_config_attn_implementation(self):
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
model = AutoModelForCausalLM.from_config(config)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+ for module in model.modules():
+ if "Attention" in module.__class__.__name__:
+ self.assertEqual(
+ module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+ )
config = AutoConfig.from_pretrained(TINY_MISTRAL)
# When the config is not set, the default is "eager"
@@ -601,6 +626,11 @@ def test_model_from_config_attn_implementation(self):
self.assertEqual(config._attn_implementation_internal, None)
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+ for module in model.modules():
+ if "Attention" in module.__class__.__name__:
+ self.assertEqual(
+ module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+ )
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
@@ -608,6 +638,11 @@ def test_model_from_config_attn_implementation(self):
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
+ for module in model.modules():
+ if "Attention" in module.__class__.__name__:
+ self.assertEqual(
+ module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
+ )
def test_torch_dtype_byte_sizes(self):
torch_dtypes_and_bytes = [
@@ -1715,26 +1750,6 @@ def test_save_and_load_config_with_custom_generation(self):
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)
- def test_load_model_with_state_dict_only(self):
- model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
- state_dict = model.state_dict()
- config = model.config
-
- model_loaded = BertModel.from_pretrained(
- pretrained_model_name_or_path=None, config=config, state_dict=state_dict
- )
- self.assertTrue(check_models_equal(model, model_loaded))
-
- def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
- model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
- state_dict = model.state_dict()
- config = model.config
-
- model_loaded = BertModel.from_pretrained(
- pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True
- )
- self.assertTrue(check_models_equal(model, model_loaded))
-
@slow
@require_torch
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 116e26e7834f26..1c81c08fd845b1 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -34,9 +34,6 @@
SPECIAL_CASES_TO_ALLOW = {
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
# periods and offsers are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
- "BambaConfig": [
- "attn_layer_indices",
- ],
"JambaConfig": [
"max_position_embeddings",
"attn_layer_offset",
@@ -50,7 +47,6 @@
# `cache_implementation` should be in the default generation config, but we don't yet support per-model
# generation configs (TODO joao)
"Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
- "Cohere2Config": ["cache_implementation"],
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`
@@ -310,10 +306,6 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"backbone_config",
"use_timm_backbone",
"backbone_kwargs",
- # rope attributes may not appear directly in the modeling but are used
- "rope_theta",
- "partial_rotary_factor",
- "pretraining_tp",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
index 341fc42b9c68ee..d243dd0c35b612 100644
--- a/utils/check_config_docstrings.py
+++ b/utils/check_config_docstrings.py
@@ -41,7 +41,6 @@
"RagConfig",
"SpeechEncoderDecoderConfig",
"TimmBackboneConfig",
- "TimmWrapperConfig",
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
"LlamaConfig",
diff --git a/utils/check_table.py b/utils/check_table.py
index 957bfd5af6af6f..5876818449558e 100644
--- a/utils/check_table.py
+++ b/utils/check_table.py
@@ -87,7 +87,7 @@ def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> str
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
-_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
+_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# This is to make sure the transformers module imported is the one in the repo.
@@ -157,7 +157,6 @@ def _center_text(text: str, width: int) -> str:
"LayoutXLM": "LayoutLMv2",
"Llama2": "LLaMA",
"Llama3": "LLaMA",
- "Falcon3": "LLaMA",
"MADLAD-400": "T5",
"MatCha": "Pix2Struct",
"mBART-50": "mBART",
diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py
index 0df782d1c21740..f25a8fb5ca6ff1 100644
--- a/utils/create_dependency_mapping.py
+++ b/utils/create_dependency_mapping.py
@@ -1,48 +1,40 @@
import ast
-from collections import defaultdict
+from collections import defaultdict, deque
# Function to perform topological sorting
def topological_sort(dependencies):
- new_dependencies = {}
+ # Create a graph and in-degree count for each node
graph = defaultdict(list)
+ in_degree = defaultdict(int)
+
+ # Build the graph
for node, deps in dependencies.items():
for dep in deps:
- if "example" not in node and "auto" not in dep:
- graph[dep.split(".")[-2]].append(node.split("/")[-2])
- new_dependencies[node.split("/")[-2]] = node
+ graph[dep].append(node) # node depends on dep
+ in_degree[node] += 1 # increase in-degree of node
- # Create a graph and in-degree count for each node
- def filter_one_by_one(filtered_list, reverse):
- if len(reverse) == 0:
- return filtered_list
+ # Add all nodes with zero in-degree to the queue
+ zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])
- graph = defaultdict(list)
- # Build the graph
- for node, deps in reverse.items():
- for dep in deps:
- graph[dep].append(node)
+ sorted_list = []
+ # Perform topological sorting
+ while zero_in_degree_queue:
+ current = zero_in_degree_queue.popleft()
+ sorted_list.append(current)
- base_modules = set(reverse.keys()) - set(graph.keys())
- if base_modules == reverse.keys():
- # we are at the end
- return filtered_list + list(graph.keys())
- to_add = []
- for k in graph.keys():
- if len(graph[k]) == 1 and graph[k][0] in base_modules:
- if graph[k][0] in reverse:
- del reverse[graph[k][0]]
- if k not in filtered_list:
- to_add += [k]
- for k in base_modules:
- if k not in filtered_list:
- to_add += [k]
- filtered_list += list(to_add)
- return filter_one_by_one(filtered_list, reverse)
+ # For each node that current points to, reduce its in-degree
+ for neighbor in graph[current]:
+ in_degree[neighbor] -= 1
+ if in_degree[neighbor] == 0:
+ zero_in_degree_queue.append(neighbor)
- final_order = filter_one_by_one([], graph)
+ # Handle nodes that have no dependencies and were not initially part of the loop
+ for node in dependencies:
+ if node not in sorted_list:
+ sorted_list.append(node)
- return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
+ return sorted_list
# Function to extract class and import info from a file
@@ -54,7 +46,7 @@ def extract_classes_and_imports(file_path):
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
module = node.module if isinstance(node, ast.ImportFrom) else None
- if module and (".modeling_" in module):
+ if module and "transformers" in module:
imports.add(module)
return imports
@@ -64,7 +56,7 @@ def map_dependencies(py_files):
dependencies = defaultdict(set)
# First pass: Extract all classes and map to files
for file_path in py_files:
- # dependencies[file_path].add(None)
+ dependencies[file_path].add(None)
class_to_file = extract_classes_and_imports(file_path)
for module in class_to_file:
dependencies[file_path].add(module)
@@ -74,4 +66,4 @@ def map_dependencies(py_files):
def find_priority_list(py_files):
dependencies = map_dependencies(py_files)
ordered_classes = topological_sort(dependencies)
- return ordered_classes
+ return ordered_classes[::-1]
diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 28fcc4fc7b9e1a..e8d117cd2af08f 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -1678,7 +1678,7 @@ def save_modeling_file(modular_file, converted_file):
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
- default=["all"],
+ default=["src/transformers/models/aria/modular_aria.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
diff --git a/utils/pr_slow_ci_models.py b/utils/pr_slow_ci_models.py
index c6a24c0f219ae7..391e99fc2276f8 100644
--- a/utils/pr_slow_ci_models.py
+++ b/utils/pr_slow_ci_models.py
@@ -15,20 +15,19 @@
"""
This script is used to get the models for which to run slow CI.
-A new model added in a pull request will be included, as well as models specified in a GitHub pull request's comment
-with a prefix `run-slow`, `run_slow` or `run slow`. For example, the commit message `run_slow: bert, gpt2` will give
-`bert` and `gpt2`.
+A new model added in a pull request will be included, as well as models specified in a commit message with a prefix
+`[run-slow]`, `[run_slow]` or `[run slow]`. For example, the commit message `[run_slow]bert, gpt2` will give `bert` and
+`gpt2`.
Usage:
```bash
-python utils/pr_slow_ci_models.py
+python utils/pr_slow_ci_models.py.py
```
"""
import argparse
import re
-import string
from pathlib import Path
from typing import List
@@ -90,7 +89,7 @@ def get_new_python_files() -> List[str]:
def get_new_model():
new_files = get_new_python_files()
- reg = re.compile(r"src/transformers/models/(.*)/modeling_.*\.py")
+ reg = re.compile(r"src/transformers/(models/.*)/modeling_.*\.py")
new_model = ""
for x in new_files:
@@ -102,53 +101,45 @@ def get_new_model():
return new_model
-def parse_message(message: str) -> str:
+def parse_commit_message(commit_message: str) -> str:
"""
- Parses a GitHub pull request's comment to find the models specified in it to run slow CI.
+ Parses the commit message to find the models specified in it to run slow CI.
Args:
- message (`str`): The body of a GitHub pull request's comment.
+ commit_message (`str`): The commit message of the current commit.
Returns:
- `str`: The substring in `message` after `run-slow`, run_slow` or run slow`. If no such prefix is found, the
- empty string is returned.
+ `str`: The substring in `commit_message` after `[run-slow]`, [run_slow]` or [run slow]`. If no such prefix is
+ found, the empty string is returned.
"""
- if message is None:
+ if commit_message is None:
return ""
- message = message.strip().lower()
-
- # run-slow: model_1, model_2
- if not message.startswith(("run-slow", "run_slow", "run slow")):
+ command_search = re.search(r"\[([^\]]*)\](.*)", commit_message)
+ if command_search is None:
return ""
- message = message[len("run slow") :]
- # remove leading `:`
- while message.strip().startswith(":"):
- message = message.strip()[1:]
-
- return message
-
-def get_models(message: str):
- models = parse_message(message)
- return models.replace(",", " ").split()
+ command = command_search.groups()[0]
+ command = command.lower().replace("-", " ").replace("_", " ")
+ run_slow = command == "run slow"
+ if run_slow:
+ models = command_search.groups()[1].strip()
+ return models
+ else:
+ return ""
-def check_model_names(model_name: str):
- allowed = string.ascii_letters + string.digits + "_"
- return not (model_name.startswith("_") or model_name.endswith("_")) and all(c in allowed for c in model_name)
+def get_models(commit_message: str):
+ models = parse_commit_message(commit_message)
+ return [f"models/{x}" for x in models.replace(",", " ").split()]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--message", type=str, default="", help="The content of a comment.")
+ parser.add_argument("--commit_message", type=str, default="", help="The commit message.")
args = parser.parse_args()
new_model = get_new_model()
- specified_models = get_models(args.message)
+ specified_models = get_models(args.commit_message)
models = ([] if new_model == "" else [new_model]) + specified_models
- # a guard for strange model names
- models = [model for model in models if check_model_names(model)]
- # Add "models/"
- models = [f"models/{model}" for model in models]
print(sorted(set(models)))
diff --git a/utils/process_circleci_workflow_test_reports.py b/utils/process_circleci_workflow_test_reports.py
deleted file mode 100644
index 944bc47a7e2fa4..00000000000000
--- a/utils/process_circleci_workflow_test_reports.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import json
-import os
-
-import requests
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--workflow_id", type=str, required=True)
- args = parser.parse_args()
- workflow_id = args.workflow_id
-
- r = requests.get(
- f"https://circleci.com/api/v2/workflow/{workflow_id}/job",
- headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")},
- )
- jobs = r.json()["items"]
-
- os.makedirs("outputs", exist_ok=True)
-
- workflow_summary = {}
- # for each job, download artifacts
- for job in jobs:
- project_slug = job["project_slug"]
- if job["name"].startswith(("tests_", "examples_", "pipelines_")):
- url = f'https://circleci.com/api/v2/project/{project_slug}/{job["job_number"]}/artifacts'
- r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
- job_artifacts = r.json()["items"]
-
- os.makedirs(job["name"], exist_ok=True)
- os.makedirs(f'outputs/{job["name"]}', exist_ok=True)
-
- job_test_summaries = {}
- for artifact in job_artifacts:
- if artifact["path"].startswith("reports/") and artifact["path"].endswith("/summary_short.txt"):
- node_index = artifact["node_index"]
- url = artifact["url"]
- r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
- test_summary = r.text
- job_test_summaries[node_index] = test_summary
-
- summary = {}
- for node_index, node_test_summary in job_test_summaries.items():
- for line in node_test_summary.splitlines():
- if line.startswith("PASSED "):
- test = line[len("PASSED ") :]
- summary[test] = "passed"
- elif line.startswith("FAILED "):
- test = line[len("FAILED ") :].split()[0]
- summary[test] = "failed"
- # failed before passed
- summary = dict(sorted(summary.items(), key=lambda x: (x[1], x[0])))
- workflow_summary[job["name"]] = summary
-
- # collected version
- with open(f'outputs/{job["name"]}/test_summary.json', "w") as fp:
- json.dump(summary, fp, indent=4)
-
- new_workflow_summary = {}
- for job_name, job_summary in workflow_summary.items():
- for test, status in job_summary.items():
- if test not in new_workflow_summary:
- new_workflow_summary[test] = {}
- new_workflow_summary[test][job_name] = status
-
- for test, result in new_workflow_summary.items():
- new_workflow_summary[test] = dict(sorted(result.items()))
- new_workflow_summary = dict(sorted(new_workflow_summary.items()))
-
- with open("outputs/test_summary.json", "w") as fp:
- json.dump(new_workflow_summary, fp, indent=4)
diff --git a/utils/release.py b/utils/release.py
index d5b74602e68c09..b0349a80b49802 100644
--- a/utils/release.py
+++ b/utils/release.py
@@ -45,14 +45,12 @@
import argparse
import os
import re
-from pathlib import Path
import packaging.version
# All paths are defined with the intent that this script should be run from the root of the repo.
PATH_TO_EXAMPLES = "examples/"
-PATH_TO_MODELS = "src/transformers/models"
# This maps a type of file to the pattern to look for when searching where the version is defined, as well as the
# template to follow when replacing it with the new version.
REPLACE_PATTERNS = {
@@ -119,17 +117,6 @@ def global_version_update(version: str, patch: bool = False):
update_version_in_examples(version)
-def remove_conversion_scripts():
- """
- Delete the scripts that convert models from older, unsupported formats. We don't want to include these
- in release wheels because they often have to open insecure file types (pickle, Torch .bin models). This results in
- vulnerability scanners flagging us and can cause compliance issues for users with strict security policies.
- """
- model_dir = Path(PATH_TO_MODELS)
- for conversion_script in list(model_dir.glob("**/convert*.py")):
- conversion_script.unlink()
-
-
def get_version() -> packaging.version.Version:
"""
Reads the current version in the main __init__.
@@ -144,7 +131,7 @@ def pre_release_work(patch: bool = False):
"""
Do all the necessary pre-release steps:
- figure out the next minor release version and ask confirmation
- - update the version everywhere
+ - update the version eveywhere
- clean-up the model list in the main README
Args:
@@ -168,15 +155,13 @@ def pre_release_work(patch: bool = False):
print(f"Updating version to {version}.")
global_version_update(version, patch=patch)
- print("Deleting conversion scripts.")
- remove_conversion_scripts()
def post_release_work():
"""
- Do all the necessary post-release steps:
+ Do all the necesarry post-release steps:
- figure out the next dev version and ask confirmation
- - update the version everywhere
+ - update the version eveywhere
- clean-up the model list in the main README
"""
# First let's get the current version
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index c641ccb21e2984..906e85e1de61a5 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -995,7 +995,9 @@ def _print_list(l) -> str:
def infer_tests_to_run(
- output_file: str, diff_with_last_commit: bool = False, filter_models: bool = False, test_all: bool = False
+ output_file: str,
+ diff_with_last_commit: bool = False,
+ filter_models: bool = False,
):
"""
The main function called by the test fetcher. Determines the tests to run from the diff.
@@ -1016,11 +1018,7 @@ def infer_tests_to_run(
Whether or not to filter the tests to core models only, when a file modified results in a lot of model
tests.
"""
- if not test_all:
- modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
- else:
- modified_files = [str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)]
- print("\n### test_all is TRUE, FETCHING ALL FILES###\n")
+ modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
# Create the map that will give us all impacted modules.
@@ -1232,6 +1230,5 @@ def create_test_list_from_filter(full_test_list, out_path):
args.output_file,
diff_with_last_commit=diff_with_last_commit,
filter_models=False,
- test_all=commit_flags["test_all"],
)
filter_tests(args.output_file, ["repo_utils"])
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 8e4a7e3fe5340e..b6ee1e7c8c13c2 100755
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -56,7 +56,7 @@
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes.
-_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
+_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Fill this with tuples (pipeline_tag, model_mapping, auto_model)