diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7f50c8327c9..43299fa0421 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -13,4 +13,5 @@ Summarize the changes made and its impact. - [ ] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) +- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [ ] New/Existing tests provide coverage for changes diff --git a/.github/workflows/_produce-data.yaml b/.github/workflows/_produce-data.yaml index fd547b44aa9..c66c5bb5702 100644 --- a/.github/workflows/_produce-data.yaml +++ b/.github/workflows/_produce-data.yaml @@ -22,7 +22,9 @@ on: - "(Single-card) Model perf tests" - "(Single-card) Device perf tests" - "(Single-card) Demo tests" + - "(Single-card) Tests for new models" - "Nightly fast dispatch tests" + - "(Single-card) Tests for new models" - "(T3K) T3000 demo tests" - "(T3K) T3000 model perf tests" - "(T3K) T3000 perplexity tests" @@ -39,6 +41,7 @@ on: - "(TGG) TGG frequent tests" - "ttnn - Run sweeps" - "Blackhole post-commit tests" + - "Custom test dispatch" types: - completed diff --git a/.github/workflows/all-post-commit-workflows.yaml b/.github/workflows/all-post-commit-workflows.yaml index dcaf74380c1..57bae427f2f 100644 --- a/.github/workflows/all-post-commit-workflows.yaml +++ b/.github/workflows/all-post-commit-workflows.yaml @@ -2,6 +2,11 @@ name: "All post-commit tests" on: workflow_call: + inputs: + build-type: + required: false + default: Release + type: string workflow_dispatch: inputs: build-type: diff --git a/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml b/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml index 762324fb3a3..0af646345b1 100644 --- a/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml +++ b/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml @@ -149,8 +149,8 @@ jobs: fail-fast: false matrix: test-config: - - model: "wh_b0_unstable" - cmd: ./tests/scripts/single_card/nightly/run_wh_b0_unstable.sh + - model: "stable_diffusion" + cmd: pytest --timeout 900 -n auto tests/nightly/single_card/stable_diffusion - model: "mamba 1" cmd: pytest --timeout 900 -n auto tests/nightly/single_card/mamba --splits 6 --group 1 - model: "mamba 2" diff --git a/.github/workflows/full-new-models-suite.yaml b/.github/workflows/full-new-models-suite.yaml new file mode 100644 index 00000000000..15ecd80c104 --- /dev/null +++ b/.github/workflows/full-new-models-suite.yaml @@ -0,0 +1,60 @@ +name: "(Single-card) Tests for new models" + +on: + workflow_dispatch: + inputs: + build-type: + required: false + default: Release + type: choice + options: + - Release + - Debug + - RelWithDebInfo + - CI + +permissions: + actions: read + contents: write + pull-requests: write + pages: write + id-token: write + packages: write + +jobs: + build-docker-image-2004: + uses: ./.github/workflows/build-docker-artifact.yaml + secrets: inherit + with: + os: ubuntu-20.04-amd64 + build-artifact: + needs: build-docker-image-2004 + uses: ./.github/workflows/build-artifact.yaml + secrets: inherit + with: + build-docker: false + build-type: ${{ inputs.build-type || 'Release' }} + build-artifact-profiler: + needs: build-docker-image-2004 + uses: ./.github/workflows/build-artifact.yaml + with: + tracy: true + build-docker: false + build-type: ${{ inputs.build-type || 'Release' }} + secrets: inherit + device-perf-single-card: + needs: build-artifact-profiler + uses: ./.github/workflows/perf-device-models-impl.yaml + secrets: inherit + e2e-model-perf-single-card: + needs: build-artifact + uses: ./.github/workflows/perf-models-impl.yaml + secrets: inherit + nightly-single-card: + needs: build-artifact + uses: ./.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml + secrets: inherit + demos-single-card: + needs: build-artifact + uses: ./.github/workflows/single-card-demo-tests-impl.yaml + secrets: inherit diff --git a/.github/workflows/t3000-demo-tests-impl.yaml b/.github/workflows/t3000-demo-tests-impl.yaml index f71636bdb15..9ad4ab1b818 100644 --- a/.github/workflows/t3000-demo-tests-impl.yaml +++ b/.github/workflows/t3000-demo-tests-impl.yaml @@ -16,7 +16,7 @@ jobs: test-group: [ { name: "t3k_falcon40b_tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 50, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k_llama3_tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 30, owner_id: U03PUAKE719}, # Miguel Tairum - # { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_llama3_70b_tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k_falcon7b_tests", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 90, owner_id: U05RWH3QUPM}, #Salar Hosseini { name: "t3k_mixtral_tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 50, owner_id: U03PUAKE719}, # Miguel Tairum diff --git a/.github/workflows/t3000-frequent-tests-impl.yaml b/.github/workflows/t3000-frequent-tests-impl.yaml index 542e85187c6..11a2df7b146 100644 --- a/.github/workflows/t3000-frequent-tests-impl.yaml +++ b/.github/workflows/t3000-frequent-tests-impl.yaml @@ -18,9 +18,10 @@ jobs: { name: "t3k ethernet tests", arch: wormhole_b0, cmd: run_t3000_ethernet_tests, timeout: 60, owner_id: ULMEPM2MA}, #Sean Nijjar { name: "t3k trace stress tests", arch: wormhole_b0, cmd: run_t3000_trace_stress_tests, timeout: 120, owner_id: U03NG0A5ND7}, #Aditya Saigal { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 120, owner_id: U04S2UV6L8N}, #Sofija Jovic - # { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich - # { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k llama3 tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 45, owner_id: U03PUAKE719}, #Miguel Tairum Cruz + { name: "t3k llama3 accuracy tests", arch: wormhole_b0, cmd: run_t3000_llama3_accuracy_tests, timeout: 45, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k llama2_70b tests", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich # { name: "t3k llama3_70b tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich # FIXME issue #14934 { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml index f05ee8e7810..303de478fd7 100644 --- a/.github/workflows/t3000-unit-tests-impl.yaml +++ b/.github/workflows/t3000-unit-tests-impl.yaml @@ -20,8 +20,8 @@ jobs: { name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 30, owner_id: U053W15B6JF}, #Djordje Ivanovic { name: "t3k llama3-small tests", arch: wormhole_b0, cmd: run_t3000_llama3-small_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k llama3.2-11b tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz - # { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich - # { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich + { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich { name: "t3k llama3.1-70b tests", arch: wormhole_b0, cmd: run_t3000_llama3.1-70b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz { name: "t3k grok tests", arch: wormhole_b0, cmd: run_t3000_grok_tests, timeout: 30, owner_id: U03HY7MK4BT}, #Mark O'Connor diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 99ea50f9860..22ef67bb730 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -348,6 +348,7 @@ on: - transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads_kv_input - transformer.attention_softmax.attention_softmax - transformer.attention_softmax.attention_softmax_ + - transformer.rotary_embedding.rotary_embedding - data_movement.stack.stack_pytorch2 - data_movement.repeat.repeat_pytorch2 - data_movement.split.split_pytorch2 diff --git a/.gitmodules b/.gitmodules index d20bc574cc3..4029f9918d6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "third_party/pybind11"] - path = tt_metal/third_party/pybind11 - url = https://github.com/pybind/pybind11.git [submodule "third_party/lfs"] path = tt_metal/third_party/lfs url = https://github.com/tenstorrent-metal/lfs.git diff --git a/CMakeLists.txt b/CMakeLists.txt index a4146a61cb9..f929b62c989 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,18 +259,6 @@ if(ENABLE_TRACY) add_link_options(-rdynamic) endif() -if(WITH_PYTHON_BINDINGS) - # Can't use the `REUSE_FROM` option bc tt_lib and ttnn have different build flags :( - add_library(pch_pybinds INTERFACE) - target_precompile_headers( - pch_pybinds - INTERFACE - ${PROJECT_SOURCE_DIR}/tt_metal/third_party/pybind11/include/pybind11/operators.h - ${PROJECT_SOURCE_DIR}/tt_metal/third_party/pybind11/include/pybind11/pybind11.h - ${PROJECT_SOURCE_DIR}/tt_metal/third_party/pybind11/include/pybind11/stl.h - ) -endif() - ############################################################################################################################ # Build subdirectories ############################################################################################################################ diff --git a/CODEOWNERS b/CODEOWNERS index aa80b7671c4..3b74d00a047 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -173,10 +173,6 @@ tests/**/dtx/ @mywoodstock @sankarmanoj-tt tests/**/*test*conv*.py @mywoodstock @sankarmanoj-tt tests/python_api_testing/conv/ @mywoodstock @sankarmanoj-tt tests/python_api_testing/unit_testing/fallback_ops @tt-aho -tests/ttnn/integration_tests/stable_diffusion @esmalTT @uaydonat @mywoodstock -tests/device_perf_tests/stable_diffusion/test_perf_stable_diffusion.py @esmalTT @uaydonat @mywoodstock -tests/ttnn/integration_tests/unet @esmalTT @uaydonat @mywoodstock -tests/nightly/wh_b0_only_eth/experimental/functional_unet @esmalTT @uaydonat @mywoodstock scripts/profiler/ @mo-tenstorrent scripts/docker @tenstorrent/metalium-developers-infra diff --git a/README.md b/README.md index 2765715fea1..d06621a08fa 100644 --- a/README.md +++ b/README.md @@ -21,29 +21,34 @@ --- ## LLMs -| Model | Batch | Hardware | ttft (ms) | t/s/u | Target
t/s/u | t/s | Release | -|---------------------------------------------------------------|-------|----------------------------------------------------------|----------|-------|-----------------|--------|---------------------------------------------------------------------------| -| [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | | -| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 17.6 | 26 | 563.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | -| [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | -| [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 48 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | -| [LLaMA-3.1-8B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 202 | 28.6 | 23 | 28.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | -| [LLaMA-3.2-1B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 90.8 | 160 | 90.8 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | -| [LLaMA-3.2-3B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 112 | 49.1 | 60 | 49.1 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | -| [Falcon7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 97 | 14.6 | 26 | 3737.6 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | -| [LLaMA-3.1-70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 190 | 15.1 | 20 | 483.2 | [v0.53.0-rc36](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc36) | -| [Falcon40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | -| [Mixtral7Bx8 (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 230 | 14.6 | 33 | 467.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | -| [Falcon7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 242 | 4.4 | 26 | 4505.6 | [v0.53.0-rc33](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc33) | -| [LLaMA-3.1-70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 190 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -> **Last Update:** December 2, 2024 +| Model | Batch | Hardware | ttft (ms) | t/s/u | Target
t/s/u | t/s | TT-Metalium Release | vLLM Tenstorrent Repo Release | +|---------------------------------------------------------------|-------|----------------------------------------------------------|-----------|-------|-----------------|--------|---------------------------------------------------|---------------------------------------------------------------------------------------------------| +| [Falcon 7B (decode only)](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | | | +| [Falcon 7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 17.6 | 26 | 563.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | | +| [Mistral 7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | | +| [Mamba 2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 48 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | | +| [Llama 3.1 8B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 202 | 28.6 | 23 | 28.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | | +| [Llama 3.2 1B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 90.8 | 160 | 90.8 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | | +| [Llama 3.2 3B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 112 | 49.1 | 60 | 49.1 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | | +| [Falcon 7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 97 | 14.6 | 26 | 3737.6 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | | +| [Llama 3.1 70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 190 | 15.1 | 20 | 483.2 | [v0.53.0-rc36](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc36) | [384f179](https://github.com/tenstorrent/vllm/tree/384f1790c3be16e1d1b10de07252be2e66d00935) | +| [Falcon 40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | | +| [Mixtral 8x7B (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 230 | 14.6 | 33 | 467.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | | +| [Falcon 7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 242 | 4.4 | 26 | 4505.6 | [v0.53.0-rc33](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc33) | | +| [Llama 3.1 70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 190 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | | + +> **Last Update:** December 7, 2024 +> > **Notes:** +> +> - ttft = time to first token | t/s/u = tokens/second/user | t/s = tokens/second; where t/s = t/s/u * batch. > - TP = Tensor Parallel, DP = Data Parallel; Defines parallelization factors across multiple devices. > - The reported LLM performance is for an input sequence length (number of rows filled in the KV cache) of 128 for all models except Mamba (which can accept any sequence length). > - The t/s/u reported is the throughput of the first token generated after prefill, i.e. 1 / inter token latency. ## CNNs + | Model | Batch | Hardware | fps | Target fps | Release | |-----------------------------------------------------------------------------|-------|----------------------------------------------------------|---------|------------|-------------| | [ResNet-50 (224x224)](./models/demos/grayskull/resnet50) | 20 | [e150](https://tenstorrent.com/hardware/grayskull) | 5,100 | 10,000 | | @@ -55,11 +60,11 @@ | [ViT (224x224)](./models/demos/grayskull/vit) | 9 | [e150](https://tenstorrent.com/hardware/grayskull) | 1,360 | 2,000 | | | [ViT (224x224)](./models/demos/wormhole/vit) | 8 | [n150](https://tenstorrent.com/hardware/wormhole) | 912 | 1,600 | | | [Stable Diffusion 1.4 (512x512)](./models/demos/wormhole/stable_diffusion) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.167 | 0.3 | | -| [Yolo V4 (320x320)](./models/demos/yolov4) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 95 | 300 | | -| [Segformer Semantic Segmentation (512x512)](./models/demos/segformer) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 90 | 300 | | - +| [YOLOv4 (320x320)](./models/demos/yolov4) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 95 | 300 | | +| [SegFormer Semantic Segmentation (512x512)](./models/demos/segformer) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 90 | 300 | | ## NLPs + | Model | Batch | Hardware | sen/sec | Target sen/sec | Release | |-----------------------------------------------------|-------|----------------------------------------------------|---------|----------------|---------| | [BERT-Large](./models/demos/metal_BERT_large_11/) | 12 | [e150](https://tenstorrent.com/hardware/grayskull) | 370 | 410 | | @@ -68,9 +73,11 @@ | [Bloom](.models/demos/grayskull/functional_bloom) | | [e150](https://tenstorrent.com/hardware/grayskull) | 70 | | | ## Model Updates + For the latest model updates and features, please see [MODEL_UPDATES.md](models/MODEL_UPDATES.md) ## TT-NN Tech Reports + - [Advanced Performance Optimizations for Models](./tech_reports/AdvancedPerformanceOptimizationsForModels/AdvancedPerformanceOptimizationsForModels.md) (updated Dec 4th) - [Programming Mesh of Devices](./tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md) (updated Sept 9th) - [ViT Implementation in TT-NN on GS](./tech_reports/ViT-TTNN/vit.md) (updated Sept 22nd) @@ -78,8 +85,8 @@ For the latest model updates and features, please see [MODEL_UPDATES.md](models/ - [YOLOv4 Implementation in TT-NN on WH](./tech_reports/YoloV4-TTNN/yolov4.md) (updated November 8th) ## Benchmarks -- [Matrix Multiply FLOPS on WH](./tech_reports/GEMM_FLOPS/GEMM_FLOPS.md) (updated November 13th) +- [Matrix Multiply FLOPS on WH](./tech_reports/GEMM_FLOPS/GEMM_FLOPS.md) (updated November 13th) --- @@ -89,7 +96,6 @@ For the latest model updates and features, please see [MODEL_UPDATES.md](models/ **TT-Metalium** is our low-level programming model, enabling kernel development for Tenstorrent hardware. -

[Programming Guide](./METALIUM_GUIDE.md) | [API Reference](https://docs.tenstorrent.com/tt-metalium/latest/tt_metal/apis/index.html) @@ -102,6 +108,7 @@ For the latest model updates and features, please see [MODEL_UPDATES.md](models/ Get started with [simple kernels](https://docs.tenstorrent.com/tt-metalium/latest/tt_metal/examples/index.html). ## TT-Metalium Tech Reports + - [Matrix Engine](./tech_reports/matrix_engine/matrix_engine.md) (updated Sept 6th) - [Data Formats](./tech_reports/data_formats/data_formats.md) (updated Sept 7th) - [Reconfiguring Data Formats](./tech_reports/data_formats/reconfig_data_format.md) (updated Oct 17th) @@ -113,24 +120,36 @@ Get started with [simple kernels](https://docs.tenstorrent.com/tt-metalium/lates - [CNNs on TT Architectures](./tech_reports/CNNs/ttcnn.md) (updated Sept 6th) - [Ethernet and Multichip Basics](./tech_reports/EthernetMultichip/BasicEthernetGuide.md) (Updated Sept 20th) - [Collective Communication Library (CCL)](./tech_reports/EthernetMultichip/CclDeveloperGuide.md) (Updated Sept 20th) -- [Blackhole Bring-Up Prgramming Guide](./tech_reports/Blackhole/BlackholeBringUpProgrammingGuide.md) (Updated Oct 30th) +- [Blackhole Bring-Up Programming Guide](./tech_reports/Blackhole/BlackholeBringUpProgrammingGuide.md) (Updated Oct 30th) ## TT-Metalium Programming Examples + ### Hello World + - [Hello World! Compute Kernel](./tech_reports/prog_examples/hello_world_compute/hello_world_compute.md) - [Hello World! Data Movement Kernel](./tech_reports/prog_examples/hello_world_data_movement/hello_world_data_movement.md) + ### Add Integers + - [Add 2 Integers in Baby RiscV](./tech_reports/prog_examples/add_2_integers_in_riscv/add_2_integers_in_riscv.md) - [Add 2 Integers in Compute Kernel](./tech_reports/prog_examples/add_2_integers_in_compute/add_2_integers_in_compute.md) + ### Simple Tensor Manipulation + - [Sharding](./tech_reports/prog_examples/shard_data_rm/shard_data_rm.md) - [Padding](./tech_reports/prog_examples/pad_multi_core/pad_multi_core.md) + ### DRAM Data Movement + - [Dram Loopback Data Movement](./tech_reports/prog_examples/dram_loopback/dram_loopback.md) + ### Eltwise + - [Eltwise Unary OP in Vector Engine (SFPU)](./tech_reports/prog_examples/eltwise_sfpu/eltwise_sfpu.md) - [Eltwise Binary OP in Matrix Engine (FPU)](./tech_reports/prog_examples/eltwise_binary/eltwise_binary.md) + ### Matmul + - [Matmul OP on a Single_core](./tech_reports/prog_examples/matmul_single_core/matmul_single_core.md) - [Matmul OP on Multi_core (Basic)](./tech_reports/prog_examples/matmul_multi_core/matmul_multi_core.md) - [Matmul Multi_core Reuse (Optimized)](./tech_reports/prog_examples/matmul_multi_core_optimized/data_reuse.md) diff --git a/conftest.py b/conftest.py index f1a753ca243..6e43d1a6499 100644 --- a/conftest.py +++ b/conftest.py @@ -257,7 +257,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic mesh_shape=ttnn.MeshShape(2, 2), dispatch_core_config=dispatch_core_config, **device_params, - offset=(0, 1), + offset=ttnn.MeshOffset(0, 1), mesh_type=ttnn.MeshType.Ring, ) diff --git a/dependencies/CMakeLists.txt b/dependencies/CMakeLists.txt index 7369f655c1c..e14310435a4 100644 --- a/dependencies/CMakeLists.txt +++ b/dependencies/CMakeLists.txt @@ -85,3 +85,9 @@ CPMAddPackage(NAME fmt GITHUB_REPOSITORY fmtlib/fmt GIT_TAG 11.0.1) ############################################################################################################################ CPMAddPackage(NAME range-v3 GITHUB_REPOSITORY ericniebler/range-v3 GIT_TAG 0.12.0) + +############################################################################################################################ +# pybind11 : https://github.com/pybind/pybind11 +############################################################################################################################ + +CPMAddPackage(NAME pybind11 GITHUB_REPOSITORY pybind/pybind11 GIT_TAG b8f28551cc3a98ea9fbfc15c05b513c8f2d23e84) diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py index a38aa60a770..1d9ac8acba0 100644 --- a/models/demos/convnet_mnist/tt/convnet_mnist.py +++ b/models/demos/convnet_mnist/tt/convnet_mnist.py @@ -19,21 +19,23 @@ def convnet_mnist( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat16, weights_dtype=ttnn.bfloat16, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=False, - packer_l1_accum_enabled=False, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=True, deallocate_activation=True, reallocate_halo_output=True, ) - + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) - [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + x = ttnn.conv2d( input_tensor=x, weight_tensor=parameters.conv1.weight, in_channels=1, @@ -47,9 +49,12 @@ def convnet_mnist( input_height=input_tensor.shape[1], input_width=input_tensor.shape[2], conv_config=conv_config, + compute_config=compute_config, conv_op_cache={}, debug=True, groups=1, + return_output_dim=False, + return_weights_and_bias=False, ) x = ttnn.relu(x) @@ -76,7 +81,7 @@ def convnet_mnist( dilation=[1, 1], ) - [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + x, [out_height, out_width] = ttnn.conv2d( input_tensor=x, weight_tensor=parameters.conv2.weight, in_channels=32, @@ -93,6 +98,8 @@ def convnet_mnist( conv_op_cache={}, debug=False, groups=1, + return_output_dim=True, + return_weights_and_bias=False, ) x = ttnn.relu(x) diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md index dd060a14c1c..f0dbf00ec4b 100644 --- a/models/demos/llama3/PERF.md +++ b/models/demos/llama3/PERF.md @@ -12,16 +12,16 @@ This configuration uses bfp4 MLP FF1+FF3 for all models. |-------|--------|-----------|-----------|---------------| | 1b | N150 | 79 | 98 | 90.5 | | 1b | N300 | 81 | 98 | 101.7 | -| 1b | T3K | 81 | 98 | 97.5 | +| 1b | T3K | 81 | 98 | 96.8 | | 3b | N150 | 85 | 96 | 49.0 | | 3b | N300 | 88 | 97 | 56.9 | | 3b | T3K | 88 | 97 | 54.5 | | 8b | N150 | 86 | 98 | 28.4 | | 8b | N300 | 84 | 98 | 38.6 | -| 8b | T3K | 84 | 98 | 52.6 | +| 8b | T3K | 84 | 97 | 52.6 | | 11b | N300 | 86 | 97 | 38.6 | | 11b | T3K | 84 | 98 | 52.6 | -| 70b | T3K | 95 | 100 | 14.3 | +| 70b | T3K | 94 | 100 | 14.3 | ## LlamaOptimizations.accuracy @@ -40,4 +40,4 @@ This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model. | 8b | T3K | 88 | 97 | 49.9 | | 11b | N300 | 90 | 97 | 33.8 | | 11b | T3K | 88 | 97 | 52.6 | -| 70b | T3K | 95 | 100 | 14.5 | +| 70b | T3K | 94 | 100 | 14.5 | diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index a828830d330..f3b5b998fcb 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -26,7 +26,6 @@ ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding -from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.tt.model_config import TtModelArgs @@ -227,6 +226,7 @@ def run_llama3_demo( optimizations=optimizations, max_seq_len=max_seq_len, ) + tokenizer = Tokenizer(model_args.tokenizer_path) # Check max sequence length compatibility with model and architecture. Refer to README for more information @@ -259,34 +259,13 @@ def run_llama3_demo( ), "T3K only supports a max context length of 128k tokens for Llama3.1-8B and Llama3.2-11B" if llama_model_name == "3.1-70B": assert tt_device_name in ["T3K", "TG"], "Llama3.1-70B is only supported on T3K or TG" + assert max_seq_len <= 64 * 1024, "T3K only supports a max context length of 64k tokens for Llama3.1-70B" logger.info("Loading weights...") profiler.start("weight_loading") state_dict = model_args.load_state_dict() profiler.end("weight_loading") - # Setup RoPE transformation matrices - rope_setup = TtLlamaRotarySetup( - mesh_device, - batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.use_scaled_rope, - ) - transformation_mats_decode = rope_setup.get_trans_mats() - - transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats_prefill = ttnn.from_torch( - transformation_mats_prefill_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} - page_table_tt = None if paged_attention: @@ -314,7 +293,6 @@ def run_llama3_demo( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) tt_embd = TtLlamaEmbedding( @@ -384,7 +362,7 @@ def run_llama3_demo( :, decoding_pos[batch_id] :, : ] = 0 # Zero out the tokens after the prefill length - prefill_input = model_args.prepare_inputs_ttnn_prefill( + prefill_input = model_args.prepare_residual_tensor_prefill( pt_prefill_input[batch_id], ) @@ -476,7 +454,7 @@ def run_llama3_demo( ) # Get cos/sin matrices for the current position of each user - rot_mats, rot_mat_idxs = rope_setup.get_rot_mats(current_pos, return_rot_idxs=True) + rot_mats, rot_mat_idxs = tt_model.rope_setup.get_rot_mats(current_pos, return_rot_idxs=True) # Compile logger.info(f"Compiling model trace...") decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) @@ -519,7 +497,7 @@ def run_llama3_demo( decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"]) - rot_mats = rope_setup.get_rot_mats(rot_mat_idxs) + rot_mats = tt_model.rope_setup.get_rot_mats(rot_mat_idxs) tt_out = tt_model( decode_input, current_pos_tensor, @@ -562,7 +540,7 @@ def run_llama3_demo( # Reset the current position and output token tensors for the real decode run ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor) ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok) - rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True) + rot_mat_idxs_reset = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True) ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs) profiler.end(f"capture_trace_{batch_idx}") @@ -591,7 +569,7 @@ def run_llama3_demo( # TODO This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it being uint32. # If this tensor is int32, it won't be supported by ttnn.embedding current_pos += 1 - rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True) + rot_mat_idxs_updated = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True) ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs) # Write to host diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py index ca3d5b498e3..ac7c5a60b2e 100644 --- a/models/demos/llama3/demo/multimodal_demo_chat.py +++ b/models/demos/llama3/demo/multimodal_demo_chat.py @@ -21,7 +21,7 @@ IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) -from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision +from models.demos.llama3.tt.generator import LlamaGenerator from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model @@ -67,7 +67,7 @@ def test_llama_multimodal_demo_chat( model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) tokenizer = Tokenizer(model_path=tokenizer_path) formatter = ChatFormat(tokenizer) - generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) + generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) # image understanding dialogs = [] diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py index 2029c43458b..4bea26c781b 100644 --- a/models/demos/llama3/demo/multimodal_demo_text.py +++ b/models/demos/llama3/demo/multimodal_demo_text.py @@ -23,7 +23,7 @@ IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/")) from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model -from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision +from models.demos.llama3.tt.generator import LlamaGenerator @pytest.mark.parametrize( @@ -73,7 +73,7 @@ def test_llama_multimodal_demo_text( model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) tokenizer = Tokenizer(model_path=tokenizer_path) formatter = ChatFormat(tokenizer) - generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) + generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter) with open(IMG_PATH / "dog.jpg", "rb") as f: img = PIL_Image.open(f).convert("RGB") diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index b4946c3eecf..cda3c2ed957 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -23,10 +23,10 @@ import ttnn import time -from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision +from models.demos.llama3.tt.generator import LlamaGenerator -def get_sampler(temperature, top_p, tokenizer): +def get_batch_sampler(temperature, top_p, tokenizer): def sample(logits): if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) @@ -34,15 +34,14 @@ def sample(logits): else: next_token = torch.argmax(logits[:, -1], dim=-1) - next_token = next_token.reshape(-1) - token = next_token[0].item() - text = tokenizer.decode(next_token.tolist()) - return token, text + next_tokens = next_token.reshape(-1) + texts = [tokenizer.decode([next_tokens[i].item()]) for i in range(len(next_tokens))] + return next_tokens, texts return sample -def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16): +def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16, use_paged_kv_cache=False): from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer from models.demos.llama3.tt.model_config import TtModelArgs @@ -56,6 +55,7 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn weight_cache_path=tt_model_args.weight_cache_path(dtype), dtype=dtype, configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, ) return tt_model_args, model @@ -70,32 +70,30 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn indirect=True, ) @pytest.mark.parametrize( - "warmup_iters", - (0, 1), - ids=["cold", "warm"], + "test_type,max_seq_len", + (("normal", 512),), + ids=["normal"], ) @pytest.mark.parametrize( - "test_case", + "warmup_iters, enable_trace, max_batch_size", [ - "normal", + (0, False, 1), # batch1-notrace + (0, True, 1), # batch1-trace + (0, True, 32), # batch32-trace ], -) -@pytest.mark.parametrize( - "enable_trace", - (False, True), - ids=["no_trace", "yes_trace"], + ids=["batch1-notrace", "batch1-trace", "batch32-trace"], ) @pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True) def test_llama_multimodal_demo_text( mesh_device, warmup_iters, - test_case, enable_trace, + max_batch_size, + test_type, + max_seq_len, temperature: float = 0, top_p: float = 0.9, - max_seq_len: int = 512, - max_batch_size: int = 1, - max_gen_len: Optional[int] = 200, + max_gen_len: Optional[int] = 500, model_parallel_size: Optional[int] = None, ): """ @@ -107,7 +105,7 @@ def test_llama_multimodal_demo_text( mesh_device.enable_program_cache() mesh_device.enable_async(True) model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) - generator = LlamaVision(model, model_args, mesh_device) + generator = LlamaGenerator(model, model_args, mesh_device) tokenizer = Tokenizer(model_path=tokenizer_path) formatter = ChatFormat(tokenizer) @@ -132,96 +130,106 @@ def test_llama_multimodal_demo_text( [UserMessage(content=[ImageMedia(image=ocr_image), "What is the full text of this image? Do OCR"])], [UserMessage(content=[ImageMedia(image=clutter), "What objects are in this image?"])], ] + if len(dialogs) < max_batch_size: + dialogs *= max_batch_size // len(dialogs) - sampler = get_sampler(temperature, top_p, tokenizer) - - for iter_num in range(warmup_iters + 1): - for dialog in dialogs: - for msg in dialog: - print(f"{msg.role.capitalize()}: {msg.content}\n") + assert len(dialogs) % max_batch_size == 0 + num_batches = len(dialogs) // max_batch_size - if iter_num <= warmup_iters: - logger.info(f"Warmup iteration {iter_num}") + sampler = get_batch_sampler(temperature, top_p, tokenizer) - model_input = formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) + for iter_num in range(warmup_iters + 1): + logger.info(f"Iteration {iter_num}") + for batch_idx in range(num_batches): + batch_dialogs = dialogs[batch_idx * max_batch_size : (batch_idx + 1) * max_batch_size] + for dialog in batch_dialogs: + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") + batch_model_input = [ + formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + ] # Do initial prefill - vision_images = model_input.vision.images - vision_mask = model_input.vision.mask - prompt_tokens = model_input.tokens - prefill_len = len(prompt_tokens) - total_len = prefill_len + max_gen_len # Prepares mask for full length of output - # Create tokens tensor + vision_images = [model_input.vision.images for model_input in batch_model_input] + vision_mask = [model_input.vision.mask for model_input in batch_model_input] + prompt_tokens = [model_input.tokens for model_input in batch_model_input] + # Get max length of prompts in batch + prefill_lens = torch.tensor([len(tokens) for tokens in prompt_tokens], dtype=torch.long) + total_lens = prefill_lens + max_gen_len + + # Create padded tokens tensor for batch pad_id = tokenizer.pad_id - bsz = 1 - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) - tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long) + bsz = len(prompt_tokens) + tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long) + + # Fill in actual tokens for each sequence in batch + for i, seq in enumerate(prompt_tokens): + tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.long) + prefill_start = time.perf_counter() - prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S - ( - xattn_caches, - cross_attention_masks, - full_text_row_masked_out_mask, - logits, - ) = generator.prefill_forward_single_user( + batch_logits, batch_xattn_masks, batch_text_masks = generator.prefill_forward( vision_images, vision_mask, - prompt_tokens_tensor, + tokens, xattn_caches, - user_id=0, - total_len=total_len, - prefill_len=prefill_len, + total_lens, + prefill_lens, ) - prefill_end = time.perf_counter() - - next_token, text = sampler(logits) - tokens[0, prefill_len] = next_token + prefill_end = time.perf_counter() + next_tokens, next_texts = sampler(batch_logits) + for i, (next_token, next_text) in enumerate(zip(next_tokens, next_texts)): + tokens[i, prefill_lens[i]] = next_token + print(f"Next tokens: {next_tokens}") + print(f"Next texts: {next_texts}") decode_times = [] for gen_idx in range(max_gen_len - 1): decode_start = time.perf_counter() - position_id = prefill_len + gen_idx - next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S + position_id = prefill_lens + gen_idx + next_token_tensor = next_tokens.reshape(max_batch_size, 1) if enable_trace: logits = generator.easy_trace( position_id, next_token_tensor, - cross_attention_masks, - full_text_row_masked_out_mask, + batch_xattn_masks, + batch_text_masks, xattn_caches, ) else: logits = generator.decode_forward( position_id, next_token_tensor, - cross_attention_masks, - full_text_row_masked_out_mask, + batch_xattn_masks, + batch_text_masks, xattn_caches, ) - next_token, text = sampler(logits) + next_tokens, next_texts = sampler(logits) # Update next token - tokens[0, position_id + 1] = next_token + tokens[torch.arange(max_batch_size), position_id + 1] = next_tokens decode_end = time.perf_counter() decode_times.append(decode_end - decode_start) - if text in ["<|eot_id|>", "<|eom_id|>"]: - break - - # Log full text output + # Disable checking for eot until I have more robust code for batch > 1 + # if text in ["<|eot_id|>", "<|eom_id|>"]: + # break + # Log full text output for each user in batch vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] - # Remove <|image|> tokens since they break the tokenizer - tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id for t in tokens[0].tolist()[: position_id + 2] - ] - text = tokenizer.decode(tokens_out) - logger.info(f"Full text: {text}") + + for user_id in range(max_batch_size): + # Remove <|image|> tokens since they break the tokenizer + tokens_out = [ + t if t not in vision_tokens else tokenizer.pad_id + for t in tokens[user_id].tolist()[: position_id[user_id] + 2] + ] + text = tokenizer.decode(tokens_out) + logger.info(f"User {user_id} full text: {text}") prefill_time_ms = (prefill_end - prefill_start) * 1000 logger.info(f"Prefill time: {prefill_time_ms:.2f} ms") decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000 - logger.info(f"Decode time: {decode_time_ms:.2f} ms") + logger.info(f"Average decode time per token: {decode_time_ms:.2f} ms") # ttnn.release_trace(generator.mesh_device, trace_id) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt index 594568609ba..8f68983a2b6 100755 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -733,7 +733,7 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update): command_shortcuts = { "demo": "pytest models/demos/llama3/demo/demo.py -k performance-batch-1", "demo-32": "pytest models/demos/llama3/demo/demo.py -k performance-batch-32", - "demo-long": "pytest models/demos/llama3/demo/demo.py -k long", + "demo-long": "pytest models/demos/llama3/demo/demo.py -k performance-long", "attention": "pytest models/demos/llama3/tests/test_llama_attention.py", "attention-prefill": "pytest models/demos/llama3/tests/test_llama_attention_prefill.py", "mlp": "pytest models/demos/llama3/tests/test_llama_mlp.py", diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index cc34f091e17..462a004b133 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -34,9 +34,10 @@ ) @pytest.mark.parametrize( "batch", - (1,), + (1, 2), ids=[ "batch_1", + "batch_2", ], ) def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset_seeds, ensure_gc): @@ -46,6 +47,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset mesh_device.enable_async(True) model_args = TtModelArgs(mesh_device) + model_args.max_seq_len = text_seq_len state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu")) # Ref model needs partial state dict, but our models use full state dict keys as cached weight names @@ -91,12 +93,15 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset """ pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) - pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] + # slice out repeated KV heads + pt_xattn_cache_chunks = [ + x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache + ] # Preallocate K and V caches tt_xattn_cache = [ ttnn.from_torch( - torch.zeros(batch, n_heads, vision_seq_len, head_dim), + torch.zeros(batch, n_kv_heads, vision_seq_len, head_dim), device=mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -109,9 +114,10 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset """ Test forward, prefill and decode! """ - for i in range(10): - seq_len = text_seq_len if i == 0 else 1 + n_iter = 10 + for i in range(n_iter): mode = "prefill" if i == 0 else "decode" + seq_len = text_seq_len if mode == "prefill" else 1 pt_x = (torch.rand(batch, seq_len, dim) * 2) - 1 tt_x = pt_x.clone() @@ -150,18 +156,18 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset if mode == "prefill": outputs = [] for b in range(batch): - tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( + tt_tensor_xattn_tokens = model_args.prepare_residual_tensor_prefill( tt_xattn_tokens[b : b + 1], force_replicated=True, ) - tt_tensor_x = model_args.prepare_inputs_ttnn_prefill( + tt_tensor_x = model_args.prepare_residual_tensor_prefill( tt_x[b : b + 1], force_replicated=True, ) tt_xattn_mask = ttnn.from_torch( - xattn_mask_expand[b : b + 1], + xattn_mask[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -169,7 +175,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset tt_full_text_mask = ttnn.from_torch( full_text_mask_expand[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -190,18 +196,17 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim) else: - tt_x = model_args.prepare_inputs_ttnn_decode( + tt_x = model_args.prepare_residual_tensor_decode( tt_x, - ttnn.DRAM_MEMORY_CONFIG, + model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"], force_replicated=True, ) - tt_x = ttnn.interleaved_to_sharded(tt_x, model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"]) xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -218,7 +223,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset tt_full_text_mask = ttnn.from_torch( full_text_mask_expand, device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -239,7 +244,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset mode=mode, ) - tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1)) tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim) passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required) @@ -251,12 +256,13 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( batch, - n_heads, + n_kv_heads, vision_seq_len, head_dim, ) for x in tt_xattn_cache ] + for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch): passing, pcc_message = comp_pcc(pt, tt, pcc_required) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 7448601b8ce..1d9da2fbcca 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -75,6 +75,7 @@ def test_llama_cross_attention_transformer_text_inference( dim = model_args.dim head_dim = model_args.head_dim n_heads = model_args.n_heads + n_kv_heads = model_args.n_kv_heads reference_model = llama_reference_mod.CrossAttentionTransformerText(args=model_args) reference_model.setup_cache(model_args.max_batch_size, torch.float32) reference_model.load_state_dict(partial_state_dict) @@ -107,15 +108,18 @@ def test_llama_cross_attention_transformer_text_inference( # unstack k/v pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks] pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx] - pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks] + pt_xattn_cache_chunks = [ + x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache_chunks + ] # Iterate over batch # Preallocate K and V caches tt_xattn_cache = tt_model.setup_cache(max_batch_size=batch) # Test forward pass of the model - n_iter = 10 + prev_pos = 0 + n_iter = 10 # tokens = torch.randint(100, 1000, (batch, text_seq_len+n_iter), dtype=torch.long)#, device="cuda" tokens = torch.randint(0, model_args.vocab_size, (batch, text_seq_len + n_iter), dtype=torch.long) for i in range(n_iter): @@ -177,17 +181,17 @@ def test_llama_cross_attention_transformer_text_inference( if mode == "prefill": outputs = [] for b in range(batch): - tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill( + tt_tensor_vision_tokens = model_args.prepare_residual_tensor_prefill( tt_vision_tokens[b : b + 1], force_replicated=True, ) - tt_h = model_args.prepare_inputs_ttnn_prefill( + tt_h = model_args.prepare_residual_tensor_prefill( h[b : b + 1], ) tt_xattn_mask = ttnn.from_torch( - xattn_mask_expand[b : b + 1], + xattn_mask[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -195,7 +199,7 @@ def test_llama_cross_attention_transformer_text_inference( tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -203,7 +207,7 @@ def test_llama_cross_attention_transformer_text_inference( tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), @@ -212,16 +216,6 @@ def test_llama_cross_attention_transformer_text_inference( rot_mats = get_prefill_rot_mat( model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len ) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - tt_out = tt_model( tt_h, xattn_mask=tt_xattn_mask, @@ -229,8 +223,7 @@ def test_llama_cross_attention_transformer_text_inference( full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD, xattn_caches=tt_xattn_cache, current_pos=None, - rot_mat=rot_mats, - transformation_mats=transformation_mats, + rot_mats=rot_mats, user_id=b, mode=mode, text_only_inference=TEXT_ONLY, @@ -245,7 +238,7 @@ def test_llama_cross_attention_transformer_text_inference( pcc_required = prefill_pcc_required else: - tt_h = model_args.prepare_inputs_ttnn_decode( + tt_h = model_args.prepare_residual_tensor_decode( h, model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) @@ -265,14 +258,14 @@ def test_llama_cross_attention_transformer_text_inference( model_args.num_devices, start_pos=cur_pos - 1, ) - - transformation_mats = None + tt_rope_id = tt_model.rope_setup.get_rot_idxs(position_ids) + rot_mats = tt_model.rope_setup.get_rot_mats(tt_rope_id) xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -289,7 +282,7 @@ def test_llama_cross_attention_transformer_text_inference( tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -311,8 +304,7 @@ def test_llama_cross_attention_transformer_text_inference( full_text_row_masked_out_mask_11SD=None, xattn_caches=tt_xattn_cache, current_pos=tt_position_id, - rot_mat=rot_mats, - transformation_mats=transformation_mats, + rot_mats=rot_mats, mode=mode, text_only_inference=TEXT_ONLY, ) @@ -332,7 +324,7 @@ def test_llama_cross_attention_transformer_text_inference( tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( batch, - n_heads, + n_kv_heads, vision_seq_len, head_dim, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 96637e5090c..3f6a9253e5d 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -30,9 +30,10 @@ ) @pytest.mark.parametrize( "batch", - (1,), + (1, 2), ids=[ "batch_1", + "batch_2", ], ) def test_llama_cross_attention_transformer_block_inference( @@ -57,6 +58,7 @@ def test_llama_cross_attention_transformer_block_inference( dim = model_args.dim head_dim = model_args.head_dim n_heads = model_args.n_heads + n_kv_heads = model_args.n_kv_heads reference_model = llama_reference_mod.CrossAttentionTransformerBlock(args=model_args, layer_id=0, no_ffn=False) reference_model.load_state_dict(partial_state_dict) @@ -83,12 +85,14 @@ def test_llama_cross_attention_transformer_block_inference( """ pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens) pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0) - pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache] + pt_xattn_cache_chunks = [ + x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache + ] # Preallocate K and V caches tt_xattn_cache = [ ttnn.from_torch( - torch.zeros(batch, n_heads, vision_seq_len, head_dim), + torch.zeros(batch, n_kv_heads, vision_seq_len, head_dim), device=mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -145,17 +149,17 @@ def test_llama_cross_attention_transformer_block_inference( if mode == "prefill": outputs = [] for b in range(batch): - tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill( + tt_tensor_xattn_tokens = model_args.prepare_residual_tensor_prefill( tt_xattn_tokens[b : b + 1], force_replicated=True, ) - tt_tensor_x = model_args.prepare_inputs_ttnn_prefill( + tt_tensor_x = model_args.prepare_residual_tensor_prefill( tt_x[b : b + 1], ) tt_xattn_mask = ttnn.from_torch( - xattn_mask_expand[b : b + 1], + xattn_mask[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -163,7 +167,7 @@ def test_llama_cross_attention_transformer_block_inference( tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -171,7 +175,7 @@ def test_llama_cross_attention_transformer_block_inference( tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD[b : b + 1], device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), @@ -193,15 +197,15 @@ def test_llama_cross_attention_transformer_block_inference( tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim) else: - tt_x = model_args.prepare_inputs_ttnn_decode( + tt_x = model_args.prepare_residual_tensor_decode( tt_x, - ttnn.DRAM_MEMORY_CONFIG, + model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous() tt_xattn_mask = ttnn.from_torch( xattn_mask_expand, device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -218,7 +222,7 @@ def test_llama_cross_attention_transformer_block_inference( tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH, device=mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), @@ -252,7 +256,7 @@ def test_llama_cross_attention_transformer_block_inference( tt_xattn_cache_torch = [ ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view( batch, - n_heads, + n_kv_heads, vision_seq_len, head_dim, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index 844937a518b..3d9e6977145 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -73,7 +73,7 @@ def test_llama_attention_inference(batch, num_chunks, mesh_device, use_program_c mask = encoder_utils.build_encoder_attention_mask(pt_block_input, ar, ntok, num_chunks, 1) pt_block_input = pt_block_input.reshape(batch, -1, dim) - attention_input = model_args.prepare_inputs_ttnn_prefill( + attention_input = model_args.prepare_residual_tensor_prefill( tt_attention_input.view(num_chunks, ntok, dim), force_replicated=True, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 001aa518828..23096202e29 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -84,7 +84,7 @@ def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_progra mask = encoder_utils.build_encoder_attention_mask(pt_block_input, ar, ntok, num_chunks, 1) pt_block_input = pt_block_input.reshape(batch, -1, dim) - attention_input = model_args.prepare_inputs_ttnn_prefill( + attention_input = model_args.prepare_residual_tensor_prefill( tt_attention_input.view(num_chunks, ntok, dim), force_replicated=True, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index 03f3310a0e3..502736ac790 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -113,7 +113,7 @@ def test_llama_image_transformer_inference( mask = encoder_utils.build_encoder_attention_mask(pt_block_input, ar, ntok, num_chunks, 1) pt_block_input = pt_block_input.reshape(batch, -1, dim) - attention_input = model_args.prepare_inputs_ttnn_prefill( + attention_input = model_args.prepare_residual_tensor_prefill( tt_attention_input.view(num_chunks, ntok, dim), force_replicated=True, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py deleted file mode 100644 index f55a47891ac..00000000000 --- a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py +++ /dev/null @@ -1,154 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 -from pathlib import Path -from typing import Optional -from loguru import logger - -from PIL import Image as PIL_Image -from termcolor import cprint - -import llama_models.llama3.reference_impl.generation as llama_reference_generation - -from llama_models.llama3.api.datatypes import ImageMedia - -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) - -THIS_DIR = Path(__file__).parent.parent.parent.resolve() / "reference/llama_models/models/scripts/" - -import torch -import pytest -import os -import ttnn - - -def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16): - from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer - from models.demos.llama3.tt.model_config import TtModelArgs - - tt_model_args = TtModelArgs(mesh_device) - checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True) - model = CrossAttentionTransformer( - model_args, - mesh_device, - checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - ) - model.setup_cache(model_args.max_batch_size, torch.float32) - return model - - -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_llama_vision_model( - mesh_device, - temperature: float = 0, - max_seq_len: int = 512, - max_batch_size: int = 4, - max_gen_len: Optional[int] = 50, - model_parallel_size: Optional[int] = None, -): - """ - This test runs the Llama3.2 vision model on CPU and TT concurrently. - It does not use teacher forcing and compares output logits at every token. - """ - mesh_device.enable_program_cache() - mesh_device.enable_async(True) - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - - logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'") - generator_pt = llama_reference_generation.Llama.build( - ckpt_dir, - tokenizer_path=tokenizer_path, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - model_parallel_size=model_parallel_size, - ) - - generator_tt = llama_reference_generation.Llama(generator_pt.model, generator_pt.tokenizer, generator_pt.args) - logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices") - model = create_multimodal_model(generator_tt.args, mesh_device) - generator_tt.model = model - - # with open(THIS_DIR / "resources/dog.jpg", "rb") as f: - # img = PIL_Image.open(f).convert("RGB") - - # with open(THIS_DIR / "resources/pasta.jpeg", "rb") as f: - # img2 = PIL_Image.open(f).convert("RGB") - - with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f: - ocr_image = PIL_Image.open(f).convert("RGB") - - # with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f: - # clutter = PIL_Image.open(f).convert("RGB") - - interleaved_contents = [ - # text only - # "The color of the sky is blue but sometimes it can also be", - # image understanding - # [ImageMedia(image=img), "If I had to write a haiku for this one"], - # [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"], - [ImageMedia(image=ocr_image), "The full text in this image is as follows"], - # [ImageMedia(image=clutter), "The count of vases, books, and miscellaneous items in this image is"], - ] - - for content in interleaved_contents: - logger.info(f"Generating text for content: {content}") - model_input = generator_pt.formatter.encode_content(content) - gen_pt = generator_pt.generate( - model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True - ) - gen_tt = generator_tt.generate( - model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True - ) - - for out_idx, (token_pt, token_tt) in enumerate(zip(gen_pt, gen_tt)): - logger.info(f"Comparing output token {out_idx}") - out_pt, out_tt = token_pt[1], token_tt[1] - out_pt = out_pt[0, -1] - out_tt = out_tt[0, -1] - passing, pcc_message = comp_pcc(out_pt, out_tt, 0.90) - print(f"PCC: {pcc_message}") - # Check shapes of logprobs - - ref_argmax = torch.argmax(out_pt).item() - ref_logprob = out_pt[ref_argmax].item() - ref_token = generator_pt.tokenizer.decode([ref_argmax]) - - # Reference model: top-5 tokens - ref_top5_vals, ref_top5_idxs = torch.topk(out_pt, 5) - ref_top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in ref_top5_idxs] - ref_top5_logprobs = ref_top5_vals.tolist() - - # Test model: top-5 tokens - top5_vals, top5_idxs = torch.topk(out_tt, 5) - top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in top5_idxs] - top5_logprobs = top5_vals.tolist() - - def entropy(logits): - probs = torch.softmax(logits, dim=-1) - return -(probs * torch.log(probs)).sum().item() - - # Print the information - print(f"Token Position {out_idx}:") - print(f" Reference | Test") - print(f" Entropy: {entropy(out_pt):.4f} | {entropy(out_tt):.4f}") - print(f" Top-5 Tokens:") - for rank in range(5): - print( - f" {rank+1}. Token='{ref_top5_tokens[rank]}' @ {ref_top5_logprobs[rank]:.2f} | '{top5_tokens[rank]}' @ {top5_logprobs[rank]:.2f}" - ) - print() diff --git a/models/demos/llama3/tests/test_interleaved_to_sharded.py b/models/demos/llama3/tests/test_interleaved_to_sharded.py index b69d7d2459b..9edc9a89dd0 100644 --- a/models/demos/llama3/tests/test_interleaved_to_sharded.py +++ b/models/demos/llama3/tests/test_interleaved_to_sharded.py @@ -80,7 +80,7 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds): mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - decode_input = model_args.prepare_inputs_ttnn_decode( + decode_input = model_args.prepare_residual_tensor_decode( tt_decode_input, ttnn.L1_MEMORY_CONFIG, ) diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py index 2ae973a907d..b19cb086066 100644 --- a/models/demos/llama3/tests/test_llama_accuracy.py +++ b/models/demos/llama3/tests/test_llama_accuracy.py @@ -9,13 +9,11 @@ import ttnn from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, - get_rot_transformation_mat, HostEmbedding, PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations -from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.demo.demo import preprocess_inputs_prefill from pathlib import Path @@ -141,28 +139,6 @@ def test_tt_model_accuracy( N = prefill_len + decode_len input_ids = reference_tokens[:, : N + 1] # Shape [1, N+1] - # Setup RoPE transformation matrices - rope_setup = TtLlamaRotarySetup( - mesh_device, - model_args.max_batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.use_scaled_rope, - ) - transformation_mats_decode = rope_setup.get_trans_mats() - - transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats_prefill = ttnn.from_torch( - transformation_mats_prefill_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill} - page_table_tt = None paged_attention_config = None @@ -193,7 +169,6 @@ def test_tt_model_accuracy( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) # Initialize embedding @@ -226,7 +201,7 @@ def test_tt_model_accuracy( model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_lens[0] ) - prefill_input = model_args.prepare_inputs_ttnn_prefill( + prefill_input = model_args.prepare_residual_tensor_prefill( pt_prefill_input[batch_id], ) @@ -256,7 +231,7 @@ def test_tt_model_accuracy( ) # Get cos/sin matrices for the current position of each user - rot_mats = rope_setup.get_rot_mats(current_pos) + rot_mats = tt_model.rope_setup.get_rot_mats(current_pos) # Print table header logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}") @@ -276,7 +251,7 @@ def test_tt_model_accuracy( # Get embedding pt_decode_input = embd(ref_token).view(1, 1, -1) # Prepare input for TT model - decode_input = model_args.prepare_inputs_ttnn_decode( + decode_input = model_args.prepare_residual_tensor_decode( pt_decode_input, model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) @@ -309,7 +284,7 @@ def test_tt_model_accuracy( # Update rot_mats for next iteration current_pos += 1 - rot_mats = rope_setup.get_rot_mats(current_pos) + rot_mats = tt_model.rope_setup.get_rot_mats(current_pos) # Get reference top5 tokens and probabilities for this position ref_top5_tokens = top5_tokens[prefill_len + i] diff --git a/models/demos/llama3/tests/test_llama_attention.py b/models/demos/llama3/tests/test_llama_attention.py index 8690b91d3b9..edb9ac99a43 100644 --- a/models/demos/llama3/tests/test_llama_attention.py +++ b/models/demos/llama3/tests/test_llama_attention.py @@ -100,8 +100,7 @@ def test_llama_attention_inference( model_args.use_scaled_rope, ) - transformation_mats = rope_setup.get_trans_mats() - transformation_mats = {"decode": transformation_mats} + transformation_mats = rope_setup.get_both_trans_mats() page_table_tt = None paged_attention_config = None @@ -158,7 +157,7 @@ def test_llama_attention_inference( tt_attention_input = pt_attention_input.clone() - attention_input = model_args.prepare_inputs_ttnn_decode( + attention_input = model_args.prepare_residual_tensor_decode( tt_attention_input, model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"], force_replicated=True, diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py index ef33adc4481..4335bdb4ee1 100644 --- a/models/demos/llama3/tests/test_llama_attention_prefill.py +++ b/models/demos/llama3/tests/test_llama_attention_prefill.py @@ -141,7 +141,7 @@ def test_llama_attention_inference( pt_attention_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1 tt_attention_input = pt_attention_input.clone() - attention_input = model_args.prepare_inputs_ttnn_prefill( + attention_input = model_args.prepare_residual_tensor_prefill( tt_attention_input, force_replicated=True, ) diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py index 5d24d3b4298..316c811aaf3 100644 --- a/models/demos/llama3/tests/test_llama_decoder.py +++ b/models/demos/llama3/tests/test_llama_decoder.py @@ -94,8 +94,7 @@ def test_llama_decoder_inference( model_args.rope_theta, model_args.use_scaled_rope, ) - transformation_mats = rope_setup.get_trans_mats() - transformation_mats = {"decode": transformation_mats} + transformation_mats = rope_setup.get_both_trans_mats() # Prepare page table for paged attention page_table_tt = None @@ -155,7 +154,7 @@ def test_llama_decoder_inference( pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() - decode_input = model_args.prepare_inputs_ttnn_decode( + decode_input = model_args.prepare_residual_tensor_decode( tt_decode_input, # ttnn.DRAM_MEMORY_CONFIG, model_args.model_config["DECODE_RESIDUAL_MEMCFG"], diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py index 0c40e21b773..622e67f91b4 100644 --- a/models/demos/llama3/tests/test_llama_decoder_prefill.py +++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py @@ -142,7 +142,7 @@ def test_llama_decoder_inference( logger.info(f"[Decoder] Generating token {i}") pt_decode_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() - decode_input = model_args.prepare_inputs_ttnn_prefill( + decode_input = model_args.prepare_residual_tensor_prefill( tt_decode_input, ) positions = torch.LongTensor(range(max_seq_len)) diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index cd425579a23..37e0e438419 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -15,7 +15,6 @@ ) from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations from models.demos.llama3.tt.llama_model import TtTransformer -from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( @@ -191,18 +190,6 @@ def test_llama_model_inference( generation_start_pos = 0 generation_length = iterations - # Setup RoPE transformation matrices - rope_setup = TtLlamaRotarySetup( - mesh_device, - model_args.max_batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.use_scaled_rope, - ) - transformation_mats = rope_setup.get_trans_mats() - transformation_mats = {"decode": transformation_mats} - page_table_tt = None paged_attention_config = None @@ -234,7 +221,6 @@ def test_llama_model_inference( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) logger.info("Model and caches loaded.") @@ -269,13 +255,13 @@ def test_llama_model_inference( for i in range(generation_length): logger.info(f"[Llama3 Model] Generating token {i}") - decode_input = model_args.prepare_inputs_ttnn_decode( + decode_input = model_args.prepare_residual_tensor_decode( tt_decode_input, model_args.model_config["DECODE_RESIDUAL_MEMCFG"], ) # Get cos/sin matrices for the current position of each user - rot_mats = rope_setup.get_rot_mats(current_pos) + rot_mats = tt_model.rope_setup.get_rot_mats(current_pos) # Run TT model tt_out = tt_model( diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index 934c91d5746..e30c25cc8f4 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -93,7 +93,7 @@ def test_llama_model_inference( pcc = 0.91 # TODO Look on improving PCC else: # performance mode assert optimizations == LlamaOptimizations.performance - pcc = 0.87 # TODO Look on improving PCC + pcc = 0.869 # TODO Look on improving PCC mesh_device.enable_async(True) @@ -143,17 +143,6 @@ def test_llama_model_inference( # pre-compute the rotational embedding matrix and send to device rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len) - transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) - transformation_mats_prefill = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) - transformation_mats = {"prefill": transformation_mats_prefill} - # Setup page table page_table_tt = None paged_attention_config = None @@ -185,7 +174,6 @@ def test_llama_model_inference( dtype=dtype, state_dict=state_dict, weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, paged_attention_config=paged_attention_config, ) @@ -200,7 +188,7 @@ def test_llama_model_inference( tt_prefill_input = pt_prefill_input - tt_prefill_input = model_args.prepare_inputs_ttnn_prefill( + tt_prefill_input = model_args.prepare_residual_tensor_prefill( pt_prefill_input, ) for i in range(1): diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/generator.py similarity index 59% rename from models/demos/llama3/tt/multimodal/vision_generator.py rename to models/demos/llama3/tt/generator.py index b00fbf3ff73..c42450e48d3 100644 --- a/models/demos/llama3/tt/multimodal/vision_generator.py +++ b/models/demos/llama3/tt/generator.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 + import ttnn import torch +from loguru import logger from llama_models.llama3.api.datatypes import ( InterleavedTextMedia, @@ -15,10 +17,16 @@ TokenResult, sample_top_p, ) +from models.demos.llama3.tt.llama_common import ( + copy_host_to_device, + get_padded_prefill_len, + num_blocks_in_seq, + get_block_size, +) -class LlamaVision: - def __init__(self, model, model_args, mesh_device, vllm=False, tokenizer=None, formatter=None): +class LlamaGenerator: + def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=None): """ Creating a LlamaVision wrapper requires only a mesh_device and model_args. With model_args you have the checkpoint location, can specify max batch size @@ -32,10 +40,133 @@ def __init__(self, model, model_args, mesh_device, vllm=False, tokenizer=None, f self.model = model self.model_args = model_args self.mesh_device = mesh_device - self.vllm = vllm self.tokenizer = tokenizer self.formatter = formatter + def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None): + batch, batch_seq_len = tokens.shape + output_logits = torch.zeros(batch, 1, self.model_args.vocab_size) + prompt_lens = prompt_lens if prompt_lens is not None else torch.tensor([batch_seq_len] * batch) + + if page_table is not None: + assert isinstance( + page_table, torch.Tensor + ), "page_table must be a torch.Tensor when passing into prefill_forward" + + for user_id in range(batch): + seq_len = prompt_lens[user_id] + last_token_idx = seq_len - 1 + + prefill_seq_len = get_padded_prefill_len(seq_len) + prefill_ids = torch.cat( + [tokens[user_id : user_id + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1 + ) + if page_table is not None: + page_table_user = self._get_prefill_user_page_table(page_table, kv_cache, seq_len) + + logits = self.prefill_forward_single_user_text( + prefill_ids, + page_table=page_table_user if page_table is not None else None, + user_id=user_id, + last_token_idx=last_token_idx, + kv_cache=kv_cache, + ) + + # Since we give unpadded_seq_len, only the tile containing the last token is returned + output_logits[user_id] = logits + + return output_logits + + def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None): + prefill_input, rot_mats_prefill, page_table_tt = self.model.prepare_inputs_prefill( + tokens, + page_table=page_table, + ) + + tt_logits = self.model.ttnn_prefill_forward( + prefill_input, + rot_mats=rot_mats_prefill, + user_id=user_id, + page_table=page_table_tt, + get_last_token=(last_token_idx // 32) * 32, + ) + + logits = self.model.process_output_prefill(tt_logits, last_token_idx=(last_token_idx % 32)) + + return logits + + def decode_forward_text( + self, + tokens, + current_pos, + page_table=None, + ): + """ + Performs text decode step. + Returns logits + """ + tt_tokens, tt_current_pos, tt_rot_mats, tt_page_table = self.model.prepare_inputs_decode( + tokens, current_pos, page_table + ) + + tt_logits = self.model.ttnn_decode_forward( + tt_tokens, + tt_current_pos, + rot_mats=tt_rot_mats, + page_table=tt_page_table, + ) + + logits = self.model.process_output_decode(tt_logits) + return logits + + def capture_trace_text( + self, + tokens, + current_pos, + page_table=None, + ): + """ + Captures a trace for the decode_forward method. + """ + + # Compile run + self.decode_forward_text(tokens, current_pos, page_table) + + # Get inputs ready for trace run + host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table) + + device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) + + trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) + transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs) + tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs) + + ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) + + return trace_id, tt_out_trace, *device_inputs + + def decode_forward_trace_text( + self, + trace_id, + device_inputs, + tt_out_trace, + tokens, + current_pos, + page_table=None, + ): + host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table) + + device_inputs = copy_host_to_device( + host_tensors=host_inputs, + device_tensors=device_inputs, + ) + + ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False) + + logits = self.model.process_output_decode(tt_out_trace) + + return logits + def prefill_forward_single_user( self, vision_images, @@ -45,28 +176,37 @@ def prefill_forward_single_user( user_id, total_len, prefill_len, + page_table=None, + kv_cache=None, ): """ Performs vision encode step then text prefill. Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits) """ B = tokens.shape[0] + last_token_idx = prefill_len - 1 vision_tokens, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( batch_images=[vision_images], batch_masks=[vision_mask], total_len=total_len, ) + if page_table is not None: + page_table = self._get_prefill_user_page_table(page_table, kv_cache, prefill_len) + ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_full_text_mask_expand_11SD, - tt_position_id, rot_mats, - transformation_mats, + tt_page_table, ) = self.model.prepare_inputs_prefill( - tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + prefill_len=prefill_len, + page_table=page_table, ) tt_logits = self.model.ttnn_prefill_forward( @@ -75,24 +215,76 @@ def prefill_forward_single_user( tt_full_text_mask_expand_1NSH, tt_full_text_mask_expand_11SD, xattn_caches, - tt_position_id, rot_mats, - transformation_mats, user_id, vision_tokens, + page_table=tt_page_table, + kv_cache=kv_cache, + get_last_token=(last_token_idx // 32) * 32, ) - logits = self.model.process_output_prefill(tt_logits, B, prefill_len) + del tt_page_table + + logits = self.model.process_output_prefill(tt_logits, B, last_token_idx=(last_token_idx % 32)) return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits + def prefill_forward( + self, + vision_images, + vision_masks, + tokens: torch.Tensor, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + ): + """ + Batched version of prefill_forward_single_user for vision model. + """ + batch, batch_seq_len = tokens.shape + output_logits = torch.zeros(batch, 1, self.model_args.vocab_size) + output_xattn_masks = [] + output_full_text_row_masked_out_masks = [] + + for user_id in range(batch): + print(f"Prefilling User {user_id}") + seq_len = prompt_lens[user_id] + ( + xattn_caches, + cross_attention_masks, + full_text_row_masked_out_mask, + logits, + ) = self.prefill_forward_single_user( + vision_images=vision_images[user_id], + vision_mask=vision_masks[user_id], + tokens=tokens[user_id : user_id + 1, :seq_len], # Keep batch dimension + xattn_caches=xattn_caches, + user_id=user_id, + total_len=total_lens[user_id], + prefill_len=seq_len, + page_table=page_table, + kv_cache=kv_cache, + ) + output_logits[user_id] = logits + output_xattn_masks.append(cross_attention_masks) + output_full_text_row_masked_out_masks.append(full_text_row_masked_out_mask) + + logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...") + + return output_logits, output_xattn_masks, output_full_text_row_masked_out_masks + def decode_forward( self, - position_id, + start_pos, tokens, cross_attention_masks, full_text_row_masked_out_mask, xattn_caches, + page_table=None, + kv_cache=None, + prompt_lens=None, ): """ Performs text decode step. @@ -101,19 +293,18 @@ def decode_forward( # forward_decode should be traced callable # decorator does compilation, capture, execute - # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor. B, S = tokens.shape + assert S == 1 ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - _, tt_position_id, - rot_mats, - _, + tt_rot_mats, + tt_page_table, ) = self.model.prepare_inputs_decode( - tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=start_pos, page_table=page_table ) tt_logits = self.model.ttnn_decode_forward( @@ -122,7 +313,9 @@ def decode_forward( tt_full_text_mask_expand_1NSH, xattn_caches, tt_position_id, - rot_mats, + tt_rot_mats, + page_table=tt_page_table, + kv_cache=kv_cache, ) logits = self.model.process_output_decode(tt_logits, B, S) @@ -143,10 +336,9 @@ def capture_trace( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - _, tt_position_id, - rot_mats, - _, + tt_rot_mats, + tt_page_table, ) = self.model.prepare_inputs_decode( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id ) @@ -158,7 +350,7 @@ def capture_trace( tt_full_text_mask_expand_1NSH, xattn_caches, tt_position_id, - rot_mats, + tt_rot_mats, ) # Get inputs ready for trace run @@ -166,9 +358,8 @@ def capture_trace( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - _, tt_position_id, - rot_mats, + tt_rope_id, _, ) = self.model.prepare_decode_inputs_host( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id @@ -179,9 +370,10 @@ def capture_trace( tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, - rot_mats, - ) = self.model.copy_host_to_device( - (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats) + tt_rope_id, + ) = copy_host_to_device( + (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id), + mesh_device=self.mesh_device, ) trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0) @@ -189,36 +381,30 @@ def capture_trace( B = tokens.shape[0] # Do on-device transformations of inputs before forward ( - tt_h, + tt_h_transform, + tt_rot_mats, tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform, ) = self.model.transform_decode_inputs_device( tt_h, + tt_rope_id, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B=B, ) tt_logits_rm = self.model.ttnn_decode_forward( - tt_h, + tt_h_transform, tt_xattn_mask_transform, tt_full_text_mask_expand_1NSH_transform, xattn_caches, tt_position_id, - rot_mats, + tt_rot_mats, ) ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) - return ( - trace_id, - tt_logits_rm, - tt_h_trace_input, - tt_xattn_mask, - tt_full_text_mask_expand_1NSH, - tt_position_id, - rot_mats, - ) + return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id def decode_forward_trace( self, @@ -233,28 +419,27 @@ def decode_forward_trace( trace_xattn_mask, trace_full_text_mask_expand_1NSH, trace_position_id, - trace_rot_mats, + trace_rope_id, ): ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - _, tt_position_id, - rot_mats, + tt_rope_id, _, ) = self.model.prepare_decode_inputs_host( tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id ) - self.model.copy_host_to_device( - host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats), + copy_host_to_device( + host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id), device_tensors=( trace_h, trace_xattn_mask, trace_full_text_mask_expand_1NSH, trace_position_id, - trace_rot_mats, + trace_rope_id, ), ) @@ -284,7 +469,7 @@ def easy_trace( tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, - rot_mats, + tt_rope_id, ) = self.capture_trace( position_id, tokens, @@ -298,7 +483,7 @@ def easy_trace( "tt_xattn_mask": tt_xattn_mask, "tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH, "tt_position_id": tt_position_id, - "rot_mats": rot_mats, + "tt_rope_id": tt_rope_id, } self.trace_outputs = { "tt_logits_rm": tt_logits_rm, @@ -316,7 +501,7 @@ def easy_trace( self.trace_inputs["tt_xattn_mask"], self.trace_inputs["tt_full_text_mask_expand_1NSH"], self.trace_inputs["tt_position_id"], - self.trace_inputs["rot_mats"], + self.trace_inputs["tt_rope_id"], ) def generate( @@ -351,6 +536,8 @@ def generate( prefill_len=prefill_len, ) + logits = logits.view(1, 1, self.model_args.max_vocab_size) + def sample(logits): if temperature > 0: probs = torch.softmax(logits[:, -1] / temperature, dim=-1) @@ -368,14 +555,14 @@ def sample(logits): ) for gen_idx in range(max_gen_len - 1): - position_id = prefill_len + gen_idx + position_id = torch.tensor([prefill_len + gen_idx]) next_token_tensor = next_token.reshape(1, 1) # B, S logits = self.decode_forward( position_id, next_token_tensor, - cross_attention_masks, - full_text_row_masked_out_mask, + [cross_attention_masks], + [full_text_row_masked_out_mask], xattn_caches, ) @@ -442,3 +629,9 @@ def text_completion( generation = self.tokenizer.decode(tokens) return CompletionPrediction(generation=generation) + + def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len): + # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly + block_size = get_block_size(kv_cache) + num_blocks = num_blocks_in_seq(prefill_len, block_size) + return page_table[:, :num_blocks] diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py new file mode 100644 index 00000000000..f962b2801b1 --- /dev/null +++ b/models/demos/llama3/tt/generator_vllm.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Union +import torch +import PIL +from llama_models.llama3.api.chat_format import create_vision_mask + +from models.demos.llama3.tt.generator import LlamaGenerator +from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model + +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext + + +def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): + """ + Based on vllm.model_executor.models.mllama.py::input_processor_for_mllama(). + Note that vLLM's input_processor_for_mllama performs additional processing to handle chunking which we do not yet support. + """ + + # Move encoder_prompt to prompt. If the user does not explicitly provide separate + # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt. + # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt + # must contain the full text prompt. + if inputs.get("prompt") is None: + inputs["prompt"] = inputs["encoder_prompt"] + inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] + + return inputs + + +@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) +class TtMllamaForConditionalGeneration(LlamaGenerator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.MLLAMA_IMAGE_TOKEN_ID = 128256 + self.max_gen_len = self.model_args.max_seq_len - 1 # TODO: double check what this should be + + @classmethod + def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size): + max_seq_len = 512 # TODO: Increase to 131072 once it's verified to work + model_args, model = create_multimodal_model(mesh_device, max_batch_size, max_seq_len, use_paged_kv_cache=True) + return cls(model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args.model_cache_path + + def prefill_forward( + self, + tokens: torch.Tensor, + images: List[PIL.Image.Image], + xattn_caches, + start_pos, + page_table: torch.Tensor = None, + kv_cache=None, + prompt_lens=None, + ): + """ + Replaces prefill_forward from LlamaGenerator with a version that supports mask creation. + """ + batch = tokens.shape[0] + + vision_images = [] + vision_masks = [] + total_lens = [] + for user_id in range(batch): + vision_images.append([images[user_id]]) + prompt_tokens = [int(tokens[user_id, i]) for i in range(prompt_lens[user_id])] + vision_masks.append(create_vision_mask(prompt_tokens, self.MLLAMA_IMAGE_TOKEN_ID)) + total_lens.append(prompt_lens[user_id] + self.max_gen_len) + + return super().prefill_forward( + vision_images, vision_masks, tokens, xattn_caches, total_lens, prompt_lens, page_table, kv_cache + ) diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index a925044554a..b93438e469d 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -23,6 +23,7 @@ def __init__( transformation_mats, configuration, paged_attention_config=None, + use_paged_kv_cache=False, ): super().__init__() @@ -56,6 +57,7 @@ def __init__( self.ccl_topology = configuration.ccl_topology() self.is_multichip = configuration.is_multichip + self.layer_num = layer_num layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) if configuration.dummy_weights or (weight_cache_path is None): cache_name = lambda _: None @@ -144,6 +146,17 @@ def __init__( cache_file_name=cache_name("wo_height_sharded"), ) + if not use_paged_kv_cache: + # vLLM provides its own kv cache + self.init_kv_cache(configuration, weight_cache_path) + + self.scale = self.head_dim**-0.5 + + def init_kv_cache(self, configuration, weight_cache_path): + """ + Generates empty KV cache and pushed to device memory + """ + if self.paged_attention_config: cache_k = torch.zeros( ( @@ -194,14 +207,13 @@ def __init__( for k_or_v in [cache_k, cache_v] ] - self.scale = self.head_dim**-0.5 - def forward_decode( self, x: ttnn.Tensor, current_pos, rot_mats=None, page_table=None, + kv_cache=None, ) -> ttnn.Tensor: """ x: (seq_len, 1, batch, dim) @@ -262,8 +274,12 @@ def forward_decode( ### # KV update ### - keys = self.layer_past[0] - values = self.layer_past[1] + if kv_cache: + keys = kv_cache[self.layer_num][0] + values = kv_cache[self.layer_num][1] + else: + keys = self.layer_past[0] + values = self.layer_past[1] # k_heads, [seqlen, n_kv_heads, bsz, head_dim] # v_heads [seqlen, n_kv_heads, bsz, head_dim] # keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim] @@ -272,9 +288,6 @@ def forward_decode( values, v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table ) - self.layer_past[0] = keys - self.layer_past[1] = values - ttnn.deallocate(k_heads_1BKD) ttnn.deallocate(v_heads_1BKD) @@ -362,7 +375,7 @@ def forward_decode( dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) return dense_out_sharded - def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): + def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None, kv_cache=None): seq_len = x_11SH.shape[-2] assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" ### @@ -425,7 +438,10 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): ttnn.deallocate(k_heads_1KSD_pre_rot) # Fill KV-Cache - keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] + if kv_cache: + keys_BKSD, values_BKSD = kv_cache[self.layer_num][0], kv_cache[self.layer_num][1] + else: + keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=ttnn.bfloat8_b) v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=ttnn.bfloat8_b) @@ -451,8 +467,14 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): ttnn.deallocate(v_heads_1VSD) if page_table: - ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill, page_table, batch_idx=user_id) - ttnn.experimental.paged_fill_cache(values_BKSD, v_fill, page_table, batch_idx=user_id) + # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values. + # Assume that the page table does not have padding, so we can use it to get the unpadded page len. + block_size = keys_BKSD.shape[2] + page_len = page_table.shape[1] * block_size + k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill + v_fill_sliced = v_fill[:, :, :page_len, :] if page_len < v_fill.shape[2] else v_fill + ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill_sliced, page_table, batch_idx=user_id) + ttnn.experimental.paged_fill_cache(values_BKSD, v_fill_sliced, page_table, batch_idx=user_id) else: ttnn.fill_cache( keys_BKSD, @@ -469,8 +491,6 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): ttnn.deallocate(k_fill) ttnn.deallocate(v_fill) - self.layer_past = [keys_BKSD, values_BKSD] - # SDPA # reshaping to put group in batch dim to do sdpa on 8 MQAs in parallel @@ -550,8 +570,17 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None): else: return output_11SH - def forward(self, x, current_pos, rot_mats=None, user_id=0, mode="decode", page_table=None): + def forward( + self, + x, + current_pos, + rot_mats=None, + user_id=0, + mode="decode", + page_table=None, + kv_cache=None, + ): if mode == "prefill": - return self.forward_prefill(x, rot_mats, user_id, page_table) + return self.forward_prefill(x, rot_mats, user_id, page_table=page_table, kv_cache=kv_cache) else: - return self.forward_decode(x, current_pos, rot_mats, page_table) + return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index b9b5484cb89..fd7f368557f 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -209,6 +209,27 @@ def num_to_core_range_set(x): ) +def copy_host_to_device(host_tensors, device_tensors=None, mesh_device=None): + """ + Helper function which copies host tensors to device tensors. + If no device_tensors are provided, it creates new device tensors and returns them. + """ + if device_tensors is None: + assert mesh_device is not None, "mesh_device is required when device_tensors is None" + ret = [] + for i in range(len(host_tensors)): + on_device = ttnn.to_device(host_tensors[i], device=mesh_device) if host_tensors[i] else None + ret.append(on_device) + return ret + else: + for i in range(len(host_tensors)): + if host_tensors[i] is None: + assert device_tensors[i] is None + continue + ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i]) + return device_tensors + + def calculate_hidden_dim(dim, ffn_dim_multiplier, multiple_of): """Helper function based on logic used in reference model: https://github.com/meta-llama/llama-models/blob/e4a6ed52a142bb9b5106dcbf48e41f97f8e7378e/models/llama3/reference_impl/model.py#L227C7-L231C83 @@ -295,3 +316,29 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True ), pt_out, ) + + +def get_padded_prefill_len(seq_len): + """ + If seq_len is less than 32, pad to 32 + If seq_len is more than 32, pad to whichever is smaller: a power of 2 or a multiple of 1024 + TODO: Generalize for max_mm_seq_len different from 1024 + """ + if seq_len <= 32: + return 32 + pow_2_pad = nearest_pow_2(seq_len) + mult_1024_pad = 1024 * math.ceil(seq_len / 1024) + min_extended_pad = min(pow_2_pad, mult_1024_pad) + return min_extended_pad + + +def get_block_size(kv_cache): + return kv_cache[0][0].shape[2] + + +def num_blocks_in_seq(seq_len, block_size): + return math.ceil(seq_len / block_size) + + +def nearest_pow_2(x): + return 2 ** math.ceil(math.log2(x)) diff --git a/models/demos/llama3/tt/llama_decoder.py b/models/demos/llama3/tt/llama_decoder.py index e5edfce889a..ad1bdf9b59a 100644 --- a/models/demos/llama3/tt/llama_decoder.py +++ b/models/demos/llama3/tt/llama_decoder.py @@ -20,6 +20,7 @@ def __init__( weight_cache_path, transformation_mats, paged_attention_config=None, + use_paged_kv_cache=False, ): super().__init__() @@ -48,6 +49,7 @@ def __init__( transformation_mats=transformation_mats, configuration=args, paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, ) self.feed_forward = TtLlamaMLP( mesh_device=mesh_device, @@ -97,6 +99,7 @@ def forward( user_id=0, mode="decode", page_table=None, + kv_cache=None, ) -> ttnn.Tensor: # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode) skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG @@ -112,7 +115,8 @@ def forward( rot_mats, user_id, mode, - page_table, + page_table=page_table, + kv_cache=kv_cache, ) # Here x and attn_out are both fractured across devices h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg) diff --git a/models/demos/llama3/tt/llama_embedding.py b/models/demos/llama3/tt/llama_embedding.py index 89b6fb1b3f0..6259c17619f 100644 --- a/models/demos/llama3/tt/llama_embedding.py +++ b/models/demos/llama3/tt/llama_embedding.py @@ -23,7 +23,7 @@ def __init__( base_name = args.get_state_dict_prefix("", None) + "tok_embeddings.weight" torch_weight = self.state_dict[base_name].unsqueeze(0).unsqueeze(0) - cache_name = weight_cache_path / base_name + cache_name = None if args.dummy_weights else weight_cache_path / base_name self.weights = ttnn.as_tensor( torch_weight, dtype=dtype, diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index e04ed2c4cf8..9c55182115f 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -14,6 +14,9 @@ from models.common.lightweightmodule import LightweightModule from models.demos.llama3.tt.distributed_norm import DistributedNorm from models.demos.llama3.tt.lm_head import LMHead +from models.demos.llama3.tt.llama_common import copy_host_to_device, get_prefill_rot_mat, HostEmbedding +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup +from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding class TtTransformer(LightweightModule): @@ -24,7 +27,6 @@ def __init__( mesh_device, state_dict, weight_cache_path, - transformation_mats, paged_attention_config=None, ): super().__init__() @@ -38,6 +40,24 @@ def __init__( self.grid_size = self.args.max_grid_size state_dict_prefix = args.get_state_dict_prefix("", None) + self.embd = TtLlamaEmbedding( + mesh_device=mesh_device, + args=args, + weight_cache_path=args.weight_cache_path(dtype), + state_dict=state_dict, + dtype=ttnn.bfloat16, # Row major layout requires bfloat16 + ) + + self.rope_setup = TtLlamaRotarySetup( + mesh_device, + args.max_batch_size, + args.head_dim, + args.max_seq_len, + args.rope_theta, + args.use_scaled_rope, + ) + self.trans_mats_dict = self.rope_setup.get_both_trans_mats() + self.layers = [ TtTransformerBlock( args=args, @@ -46,7 +66,7 @@ def __init__( state_dict=state_dict, weight_cache_path=weight_cache_path, layer_num=i, - transformation_mats=transformation_mats, + transformation_mats=self.trans_mats_dict, paged_attention_config=paged_attention_config, ) for i in range(self.n_layers) @@ -76,6 +96,167 @@ def __init__( weight_cache_path=weight_cache_path, ) + def prepare_inputs_prefill(self, tokens, page_table=None): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + TODO: Debate whether this function is responsible for padding + """ + + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + tokens_embd = self.embd(tokens) + + tt_rot_mats_prefill = get_prefill_rot_mat( + self.args.head_dim, self.args.max_seq_len, self.mesh_device, seq_len=S + ) + + if page_table is not None: + tt_page_table = ttnn.from_torch( + page_table, + device=self.mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + else: + tt_page_table = None + + return tokens_embd, tt_rot_mats_prefill, tt_page_table + + def prepare_inputs_decode(self, *inputs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + Its implementation can take advantage of a few other functions which the + model must implement. + """ + host_inputs = self.prepare_decode_inputs_host(*inputs) + device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) # Helper function + transformed_device_inputs = self.transform_decode_inputs_device(*device_inputs) + return transformed_device_inputs + + def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): + """ + Inputs are torch tensors or python types. Outputs are ttnn tensors on host. + NOTE: Tokens and current_pos are padded to batch + """ + B = tokens.shape[-1] + assert current_pos.shape[0] == B, "Batch size mismatch" + assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size" + + tokens = ttnn.from_torch( + tokens, + device=None, + dtype=ttnn.uint32, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + rope_idxs = self.rope_setup.get_rot_idxs(current_pos, on_host=True) + current_pos_tt = ttnn.from_torch( + current_pos, + device=None, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + if page_table is not None: + page_table = ttnn.from_torch( + page_table, + device=None, + dtype=ttnn.int32, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + return tokens, current_pos_tt, rope_idxs, page_table + + def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_table=None): + """ + Inputs are ttnn tensors on device. This function applies any on-device + transformations which should happen before forward decode. + For example: tilize, reshape, shard. + Return transformed device tensors + + Get rope sin/cos + Embed tokens + """ + tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs) + tt_tokens = self.embd(tokens) + tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) + return tt_tokens, current_pos, tt_rot_mats, page_table + + def process_output_prefill(self, tt_out, last_token_idx): + """ + Input is ttnn device tensor of logits. Output is torch logits tensor. + NOTE: In this model, prefill always uses get_last_token + """ + logits = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, -1))[ + 0, 0, last_token_idx, : + ] + return logits + + def process_output_decode(self, tt_out): + """ + Input is ttnn device tensor of logits. Output is torch logits tensor + """ + if self.args.num_devices > 1: + tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_out_rm = ttnn.untilize(tt_out, use_multicore=True) + if self.args.num_devices > 1: + return ttnn.to_torch(ttnn.get_device_tensors(tt_out_rm)[0]).float() + else: + return ttnn.to_torch(tt_out_rm).float() + + def ttnn_prefill_forward( + self, + x, + rot_mats, + user_id, + page_table=None, + get_last_token=-1, + ): + """ + This method will take device tensors and any other args to run forward. + It returns ttnn device tensors. + """ + return self.forward( + x, + current_pos=None, + rot_mats=rot_mats, + transformation_mats=None, + user_id=user_id, + mode="prefill", + page_table=page_table, + get_last_token=get_last_token, + ) + + def ttnn_decode_forward( + self, + x, + current_pos, + rot_mats, + page_table=None, + ): + """ + This method will take device tensors and any other args to run forward. + It returns ttnn device tensors. + """ + return self.forward( + x, + current_pos, + rot_mats=rot_mats, + mode="decode", + page_table=page_table, + ) + def forward( self, x: ttnn.Tensor, diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py index 576ce982e8c..c1b982308bc 100644 --- a/models/demos/llama3/tt/llama_rope.py +++ b/models/demos/llama3/tt/llama_rope.py @@ -87,9 +87,21 @@ def __init__( mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, ) - def get_trans_mats(self): + # TODO: Colman, should this be TILE_SIZE or head_dim? Why should it be different for prefill and decode? + prefill_trans_mat_torch = get_rot_transformation_mat(dhead=head_dim) + self.transformation_mat_prefill = ttnn.from_torch( + prefill_trans_mat_torch, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_both_trans_mats(self): assert self.transformation_mat is not None, "Transformation matrix not initialized" - return self.transformation_mat + assert self.transformation_mat_prefill is not None, "Prefill Transformation matrix not initialized" + return {"decode": self.transformation_mat, "prefill": self.transformation_mat_prefill} def get_rot_idxs(self, position_idxs, on_host=False): assert isinstance(position_idxs, torch.Tensor), "Position ids must be a torch tensor" diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index c3f9f385e7a..4ddb684fe9c 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -563,7 +563,7 @@ def find_largest_divisor(n, max_divisor=8): fuse_batch=False, ) self.model_config["VISION_XATTN_DENSE_PROGCFG"] = lambda seq_len: self.matmul_config( - m=seq_len, + m=min(seq_len, 1024), k=self.dim // self.num_devices, n=self.dim, grid_size=(8, 8), @@ -589,23 +589,21 @@ def find_largest_divisor(n, max_divisor=8): fuse_batch=seq_len <= max_seq, ) - xattn_cache_y_cores = ( - 16 // self.num_devices - ) # Based on seqlen, this formula gives us a valid number of y cores - xattn_cache_x_cores = 8 - self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config( - # using n_heads since xattn repeats KV to match Q - ( - nearest_32( - (self.n_heads // self.num_devices) * seq_len // (xattn_cache_y_cores * xattn_cache_x_cores) + def _get_xattn_kv_prefill_mem_cfg(seq_len): + M = (self.n_kv_heads // self.num_devices) * seq_len + cores_x, cores_y = self.find_grid(M // self.tile_size) + return ttnn.create_sharded_memory_config( + ( + nearest_32(M // (cores_x * cores_y)), + self.head_dim, ), - self.head_dim, - ), - ttnn.CoreGrid(y=xattn_cache_y_cores, x=xattn_cache_x_cores), - ttnn.ShardStrategy.HEIGHT, - ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) + ttnn.CoreGrid(y=cores_y, x=cores_x), + ttnn.ShardStrategy.HEIGHT, + ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = _get_xattn_kv_prefill_mem_cfg self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok) # RMS NORM @@ -648,7 +646,7 @@ def ccl_topology(self): return ttnn.Topology.Linear return None - def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False): + def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False): """ Prepare inputs for decode mode. x: (batch, seq, dim) @@ -698,7 +696,7 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, o x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) return x - def prepare_inputs_ttnn_prefill(self, x_bsh, force_replicated=False): + def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): """ Prepare inputs for prefill mode. x: (batch, seq, hidden_dim) diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index a4d1bb59885..f5ff04f7e3e 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -79,7 +79,8 @@ def __init__( mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), math_fidelity=ttnn.MathFidelity.HiFi2, math_approx_mode=True, fp32_dest_acc_en=True, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index d7032fd59ba..5aa338a012b 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -187,12 +187,7 @@ def compute_xattn_kv_cache(self, xattn_tokens, user_id, xattn_cache): xk = self.k_norm(xk, mode="decode") - # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward - xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1) - xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1) - k_cache, v_cache = xattn_cache - # Work around fill_cache memory constraint by making these sharded k_fill = ttnn.interleaved_to_sharded(xk, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) v_fill = ttnn.interleaved_to_sharded(xv, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y)) @@ -312,27 +307,22 @@ def forward_prefill( xq = self.q_norm(xq, mode="prefill") - scores = ttnn.matmul( - xq, - ttnn.transpose(k_cache_user, -1, -2), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi2, - program_config=self.model_config["VISION_XATTN_SCORE_PROGCFG"](seq_len, cache_seq_len), + program_config = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(), + q_chunk_size=128, + k_chunk_size=128, + exp_approx_mode=False, ) - scores = ttnn.multiply(scores, self.scale) - # WARNING: This add is buggy if xattn_mask has to be broadcasted to n_local_heads. Workaround is to broadcast on host side - scores = ttnn.add(scores, xattn_mask) - scores = ttnn.softmax(scores, dim=-1, numeric_stable=True) - - output = ttnn.matmul( - scores, + output = ttnn.transformer.scaled_dot_product_attention( + xq, + k_cache_user, v_cache_user, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.compute_kernel_config_hifi4, - program_config=self.model_config["VISION_XATTN_OUTPUT_PROGCFG"](seq_len, cache_seq_len), + is_causal=False, + attn_mask=xattn_mask, + scale=self.scale, + program_config=program_config, + compute_kernel_config=self.compute_kernel_config_sdpa, ) # WARNING: this broadcast is also broken, must broadcast on host diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index f657abb8672..a7ce9def430 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -5,12 +5,15 @@ import math import ttnn import torch +from tqdm import tqdm from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock from models.demos.llama3.tt.distributed_norm import DistributedNorm from models.common.rmsnorm import RMSNorm import ttnn from models.common.lightweightmodule import LightweightModule +from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.utility_functions import ( nearest_32, @@ -41,6 +44,7 @@ def __init__( weight_cache_path, dtype, configuration, + use_paged_kv_cache=False, ): super().__init__() self.vocab_size = configuration.vocab_size @@ -116,12 +120,31 @@ def __init__( self.num_frozen_embeddings = self.tok_embeddings.num_embeddings self._thresh = self.num_frozen_embeddings - 1 + self.rope_setup = TtLlamaRotarySetup( + mesh_device, + configuration.max_batch_size, + configuration.head_dim, + configuration.max_seq_len, + configuration.rope_theta, + configuration.use_scaled_rope, + ) + self.trans_mats_dict = self.rope_setup.get_both_trans_mats() + # transformer blocks self.layers = [] self.cross_attention_layers = [] - for i in range(configuration.n_layers): + for i in tqdm(range(configuration.n_layers), desc="Loading text transformer layers"): layer_id = i - block = TtTransformerBlock(configuration, mesh_device, dtype, state_dict, layer_id, weight_cache_path) + block = TtTransformerBlock( + configuration, + mesh_device, + dtype, + state_dict, + layer_id, + weight_cache_path, + transformation_mats=self.trans_mats_dict, + use_paged_kv_cache=use_paged_kv_cache, + ) self.layers.append(block) if layer_id in self.fusion_schedule: xa_layer_id = self.fusion_schedule.index(layer_id) @@ -224,7 +247,7 @@ def setup_cache(self, max_batch_size): [ ttnn.from_torch( torch.zeros( - max_batch_size, self.configuration.n_heads, vision_seq_len, self.configuration.head_dim + max_batch_size, self.configuration.n_kv_heads, vision_seq_len, self.configuration.head_dim ), device=self.mesh_device, layout=ttnn.TILE_LAYOUT, @@ -247,14 +270,14 @@ def forward( full_text_row_masked_out_mask_11SD: ttnn.Tensor, xattn_caches, current_pos, - rot_mat=None, - transformation_mats=None, + rot_mats=None, user_id=0, mode="decode", page_table=None, - # get_last_token=-1, + kv_cache=None, text_only_inference=False, vision_tokens=None, + get_last_token=-1, ): for idx, ( layer, @@ -275,12 +298,15 @@ def forward( h = layer( h, current_pos, - rot_mat=rot_mat, - transformation_mats=transformation_mats, + rot_mats=rot_mats, user_id=user_id, mode=mode, + page_table=page_table, + kv_cache=kv_cache, ) + if get_last_token != -1: + h = ttnn.slice(h, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, h.shape[-1])) h = self.norm(h, mode=mode) # TODO: Switch to using dram-sharded LM head and remove this diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index 1761bc7ac66..9d8c3760af0 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -126,6 +126,11 @@ def forward( user_id=0, vision_tokens=None, ): + skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + assert ( + x_11SH.memory_config() == skip_mem_cfg + ), f"decoder input memcfg mismatch: {x_11SH.memory_config()} != {skip_mem_cfg}" + attn_out = self.attention( x_11SH=self.attention_norm(x_11SH, mode=mode), xattn_mask=xattn_mask, diff --git a/models/demos/llama3/tt/multimodal/llama_image_transformer.py b/models/demos/llama3/tt/multimodal/llama_image_transformer.py index ea86d302748..e9bef2377ed 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_transformer.py +++ b/models/demos/llama3/tt/multimodal/llama_image_transformer.py @@ -2,10 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional -import torch +from tqdm import tqdm -import ttnn from models.utility_functions import ( nearest_32, ) @@ -41,7 +39,7 @@ def __init__( configuration=configuration, gated=gated, ) - for i in range(layers) + for i in tqdm(range(layers), desc=f"Loading vision transformer layers") ] def forward(self, x, return_intermediate=None, mask=None): diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 96149d5a0f9..0b1f36fd6f4 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -29,6 +29,7 @@ get_prefill_rot_mat, get_rot_transformation_mat, get_single_rot_mat, + copy_host_to_device, ) from models.utility_functions import ( nearest_32, @@ -128,6 +129,7 @@ def __init__( weight_cache_path, dtype, configuration, + use_paged_kv_cache=False, ) -> None: super().__init__() @@ -159,6 +161,7 @@ def __init__( weight_cache_path=configuration.weight_cache_path(ttnn.bfloat8_b), dtype=ttnn.bfloat8_b, configuration=configuration, + use_paged_kv_cache=use_paged_kv_cache, ) self.image_res = configuration.vision_chunk_size self.max_num_chunks = configuration.vision_max_num_chunks @@ -268,7 +271,6 @@ def compute_vision_tokens_masks( def validate_inputs(self, tokens, position_ids): batch, seq_len = tokens.shape[:2] - assert batch == 1, f"Only batch 1 is supported, got {batch}" assert ( seq_len <= self.configuration.max_seq_len ), f"Sequence length {seq_len} exceeds max sequence length {self.configuration.max_seq_len}" @@ -279,7 +281,9 @@ def prepare_inputs_common(self, position_ids, tokens): h = self.text_model.get_partially_trainable_embedding(tokens) return h - def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len): + def prepare_inputs_prefill( + self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len, page_table=None + ): B = tokens.shape[0] assert B == 1, f"Only batch 1 is supported, got {B}" S = tokens.shape[1] @@ -287,26 +291,16 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma h = self.prepare_inputs_common(position_ids, tokens) padded_seq_len = _get_padded_prefill_seqlen(S) - tt_position_id = ttnn.from_torch( - position_ids, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - xattn_mask = cross_attention_masks[:, :, position_ids] - xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1) - xattn_mask_expand = torch.nn.functional.pad( - xattn_mask_expand, - (0, 0, 0, padded_seq_len - xattn_mask_expand.shape[2]), + xattn_mask = torch.nn.functional.pad( + xattn_mask, + (0, 0, 0, padded_seq_len - xattn_mask.shape[2]), "constant", get_negative_inf_value(torch.float32), ) tt_xattn_mask = ttnn.from_torch( - xattn_mask_expand, + xattn_mask, device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, @@ -314,6 +308,7 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) + tt_xattn_mask = ttnn.typecast(tt_xattn_mask, ttnn.bfloat4_b) full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] full_text_mask = torch.nn.functional.pad( @@ -331,65 +326,75 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) + tt_full_text_mask_expand_1NSH = ttnn.typecast(tt_full_text_mask_expand_1NSH, ttnn.bfloat4_b) h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0) - tt_h = self.configuration.prepare_inputs_ttnn_prefill( + tt_h = self.configuration.prepare_residual_tensor_prefill( h, ) rot_mats = get_prefill_rot_mat( self.configuration.head_dim, self.configuration.max_seq_len, self.mesh_device, seq_len=S ) - transformation_mat_torch = get_rot_transformation_mat(self.configuration.head_dim) - transformation_mats = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim) tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD, device=self.mesh_device, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), ) + if isinstance(page_table, torch.Tensor): + # Support vLLM tensor page_table input + page_table = ttnn.as_tensor( + page_table, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + return ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_full_text_mask_expand_11SD, - tt_position_id, rot_mats, - transformation_mats, + page_table, ) - def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): + def prepare_inputs_decode( + self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=None + ): ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - _tt_full_text_mask_expand_11SD, tt_position_id, - rot_mats, - _transformation_mats, - ) = self.prepare_decode_inputs_host(tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id) + tt_rope_id, + tt_page_table, + ) = self.prepare_decode_inputs_host( + tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=page_table + ) ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, - rot_mats, - ) = self.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats)) + tt_rope_id, + tt_page_table, + ) = copy_host_to_device( + (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, tt_page_table), + mesh_device=self.mesh_device, + ) - tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device( + tt_h, tt_rot_mats, tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device( tt_h, + tt_rope_id, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B=tokens.shape[0], @@ -399,34 +404,35 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - _tt_full_text_mask_expand_11SD, tt_position_id, - rot_mats, - _transformation_mats, + tt_rot_mats, + tt_page_table, ) - def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id): + def prepare_decode_inputs_host( + self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=None + ): B = tokens.shape[0] assert ( B == self.configuration.max_batch_size ), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}" - position_ids = torch.tensor([position_id], dtype=torch.long) - h = self.prepare_inputs_common(position_ids, tokens) - tt_h = self.configuration.prepare_inputs_ttnn_decode( + h = self.prepare_inputs_common(position_id, tokens) + tt_h = self.configuration.prepare_residual_tensor_decode( h, - None, # on_host tensors have no memory_config + None, on_host=True, ) tt_position_id = ttnn.from_torch( - position_ids, + position_id, device=None, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - xattn_mask = cross_attention_masks[:, :, position_ids] + tt_rope_id = self.text_model.rope_setup.get_rot_idxs(position_id, on_host=True) + xattn_mask = torch.cat([cross_attention_masks[i][:, :, position_id[i]] for i in range(B)], dim=1).unsqueeze(0) xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1) xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous() @@ -438,7 +444,9 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - full_text_mask = full_text_row_masked_out_mask[:, :, position_ids] + full_text_mask = torch.cat( + [full_text_row_masked_out_mask[i][:, :, position_id[i]] for i in range(B)], dim=1 + ).unsqueeze(0) full_text_mask_expand_1NSH = full_text_mask.expand( -1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim ) @@ -451,44 +459,25 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - rot_mats, _ = get_single_rot_mat( - self.configuration.head_dim, - self.mesh_device, - self.configuration.num_devices, - start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1 - # TODO: B must match max_batch_size, be careful - on_host=True, - ) - - transformation_mats = None - tt_full_text_mask_expand_11SD = None + if isinstance(page_table, torch.Tensor): + # Support vLLM tensor page_table input + page_table = ttnn.as_tensor( + page_table, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) return ( tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, - tt_full_text_mask_expand_11SD, tt_position_id, - rot_mats, - transformation_mats, + tt_rope_id, + page_table, ) - def copy_host_to_device(self, host_tensors, device_tensors=None): - """ - Helper function which copies host tensors to device tensors - """ - if device_tensors is None: - ret = [] - for i in range(len(host_tensors)): - on_device = ttnn.to_device(host_tensors[i], device=self.mesh_device) - ret.append(on_device) - return ret - else: - for i in range(len(host_tensors)): - ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i]) - return device_tensors - - def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): + def transform_decode_inputs_device(self, tt_h, tt_rope_id, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B): """ Does any transformations on device tensors which are necessary before ttnn_decode_forward """ @@ -499,6 +488,8 @@ def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_ tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"]) + tt_rot_mats = self.text_model.rope_setup.get_rot_mats(tt_rope_id) + tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -531,12 +522,11 @@ def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_ ), ) - return (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH) + return (tt_h, tt_rot_mats, tt_xattn_mask, tt_full_text_mask_expand_1NSH) - def process_output_prefill(self, tt_out, B, S): - padded_seq_len = _get_padded_prefill_seqlen(S) + def process_output_prefill(self, tt_out, B, last_token_idx): tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() - tt_out = tt_out[0].reshape(B, padded_seq_len, -1)[:, :S, :] + tt_out = tt_out[0, 0, last_token_idx, :] return tt_out def process_output_decode(self, tt_out, B, S): @@ -554,6 +544,8 @@ def forward( text_only_inference: bool = False, user_id=0, vision_tokens=None, + page_table=None, + kv_cache=None, ) -> torch.Tensor: """ This method takes torch tensors in, returns torch tensors. @@ -573,12 +565,13 @@ def forward( tt_full_text_mask_expand_11SD, tt_position_id, rot_mats, - transformation_mats, + tt_page_table, ) = prepare_fn( tokens, cross_attention_masks, full_text_row_masked_out_mask, pos_arg, + page_table=page_table, ) logits = self.text_model.forward( @@ -588,10 +581,11 @@ def forward( full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD, xattn_caches=xattn_caches, current_pos=tt_position_id, - rot_mat=rot_mats, - transformation_mats=transformation_mats, + rot_mats=rot_mats, user_id=user_id, mode=mode, + page_table=tt_page_table, + kv_cache=kv_cache, text_only_inference=text_only_inference, vision_tokens=vision_tokens, ) @@ -607,11 +601,12 @@ def ttnn_prefill_forward( full_text_mas_expand_1NSH, full_text_mask_expand_11SD, xattn_caches, - position_id, rot_mats, - transformation_mats, user_id, vision_tokens, + page_table=None, + kv_cache=None, + get_last_token=-1, ): """ This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors. @@ -622,12 +617,14 @@ def ttnn_prefill_forward( full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH, full_text_row_masked_out_mask_11SD=full_text_mask_expand_11SD, xattn_caches=xattn_caches, - current_pos=position_id, - rot_mat=rot_mats, - transformation_mats=transformation_mats, + current_pos=None, + rot_mats=rot_mats, user_id=user_id, mode="prefill", + page_table=page_table, + kv_cache=kv_cache, vision_tokens=vision_tokens, + get_last_token=get_last_token, ) tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) return tt_out @@ -640,6 +637,8 @@ def ttnn_decode_forward( xattn_caches, position_id, rot_mats, + page_table=None, + kv_cache=None, ): """ This method runs decode forward. It takes ttnn tensors in, returns ttnn tensors. @@ -651,8 +650,10 @@ def ttnn_decode_forward( full_text_row_masked_out_mask_11SD=None, xattn_caches=xattn_caches, current_pos=position_id, - rot_mat=rot_mats, + rot_mats=rot_mats, mode="decode", + page_table=page_table, + kv_cache=kv_cache, ) tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT) return tt_out @@ -720,11 +721,11 @@ def _pad_masks( def _get_padded_prefill_seqlen(seq_len): """ If seq_len is less than 128, pad to 128 - If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024 + If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 2048 (text model requires mult of 2048, while vision allows mult of 1024) """ if seq_len < 128: return 128 else: - mult_1024 = 1024 * math.ceil(seq_len / 1024) + mult_2k = 2048 * math.ceil(seq_len / 2048) pow_2 = 2 ** math.ceil(math.log2(seq_len)) - return min(mult_1024, pow_2) + return min(mult_2k, pow_2) diff --git a/models/demos/segformer/tt/common.py b/models/demos/segformer/tt/common.py index 5f52fe0e507..d777116d232 100644 --- a/models/demos/segformer/tt/common.py +++ b/models/demos/segformer/tt/common.py @@ -40,12 +40,8 @@ def __call__(self, device, input_tensor): conv_config = ttnn.Conv2dConfig( dtype=self.dtype, weights_dtype=ttnn.bfloat16, - math_fidelity=ttnn.MathFidelity.LoFi, activation=self.activation, shard_layout=self.shard_layout, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=False, - packer_l1_accum_enabled=False, input_channels_alignment=16 if input_tensor.shape[3] < 16 else 32, transpose_shards=False, reshard_if_not_optimal=self.reshard, @@ -54,10 +50,17 @@ def __call__(self, device, input_tensor): enable_act_double_buffer=True, enable_split_reader=False, ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) if self.act_block_h is not None: conv_config.act_block_h_override = self.act_block_h - [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d( + [output_tensor, [_out_height, _out_width]] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.weights, bias_tensor=self.bias, @@ -71,7 +74,10 @@ def __call__(self, device, input_tensor): input_height=input_tensor.shape[1], input_width=input_tensor.shape[2], conv_config=conv_config, + compute_config=compute_config, groups=self.groups, + return_output_dim=True, + return_weights_and_bias=False, ) return output_tensor, _out_height, _out_width diff --git a/models/demos/segformer/tt/ttnn_segformer_decode_head.py b/models/demos/segformer/tt/ttnn_segformer_decode_head.py index 6aed216c578..4be9a957a8c 100644 --- a/models/demos/segformer/tt/ttnn_segformer_decode_head.py +++ b/models/demos/segformer/tt/ttnn_segformer_decode_head.py @@ -78,7 +78,7 @@ def __call__(self, encoder_hidden_states: ttnn.bfloat8_b, parameters) -> ttnn.Te encoder_hidden_state = ttnn.upsample( encoder_hidden_state, - scale_factor=(128 // encoder_hidden_state.shape[2], 128 // encoder_hidden_state.shape[2], 1), + scale_factor=(128 // encoder_hidden_state.shape[2], 128 // encoder_hidden_state.shape[2]), mode="bilinear", ) diff --git a/models/demos/squeezebert/README.md b/models/demos/squeezebert/README.md new file mode 100644 index 00000000000..70adcab9c2e --- /dev/null +++ b/models/demos/squeezebert/README.md @@ -0,0 +1,30 @@ +# SqueezeBERT demo + +Demo showcasing SqueezeBERT running on Grayskull - e150 and Wormhole - n150, n300 using ttnn. + +## Introduction +SqueezeBERT is a bidirectional transformer similar to the BERT model. The key difference between the BERT architecture and the SqueezeBERT architecture is that SqueezeBERT uses grouped convolutions instead of fully-connected layers for the Q, K, V and FFN layers. + + +## Details +The entry point to functional_squeezebert model is squeezebert_for_question_answering in `models/demos/squeezebert/tt/ttnn_functional_squeezebert.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `squeezebert/squeezebert-uncased` version from huggingface as our reference. + +### Sequence Size: 384 +Sequence size determines the maximum length of input sequences processed by the model, optimizing performance and compatibility. It's recommended to set the sequence_size to 384 + +### Batch size: 8 +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 8 + +## How to Run + +Use `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo[models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-models/demos/squeezebert/demo/input_data.json-8-384-device_params0]` to run the demo. + +If you wish to run the demo with a different input use `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo[models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased--8-384-device_params0]`. This file is expected to have exactly 8 inputs. + +Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo_squadv2[3-models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-8-384-device_params0]`. + +If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo_squadv2[-models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-8-384-device_params0]` + + +## Inputs +The demo receives inputs from respective `input_data.json` by default. To modify the inputs or specify a different path, adjust the input_path parameter in the command accordingly. It's recommended to avoid direct modifications to the input_data.json file. diff --git a/models/demos/squeezebert/demo/demo.py b/models/demos/squeezebert/demo/demo.py new file mode 100644 index 00000000000..6f7dc817d1d --- /dev/null +++ b/models/demos/squeezebert/demo/demo.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import json +import torch +import pytest +import evaluate + +from loguru import logger +from ttnn.model_preprocessing import * +from models.utility_functions import ( + profiler, + skip_for_wormhole_b0, + disable_compilation_reports, + disable_persistent_kernel_cache, +) +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.squeezebert.tt import ttnn_functional_squeezebert +from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch + +from transformers import SqueezeBertForQuestionAnswering, pipeline, SqueezeBertTokenizer + + +def load_inputs(input_path, batch): + with open(input_path) as f: + input_data = json.load(f) + assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries." + + context = [] + question = [] + for i in range(batch): + context.append(input_data[i]["context"]) + question.append(input_data[i]["question"]) + + return context, question + + +def positional_ids(config, input_ids, past_key_values_length=0): + seq_length = input_ids.size(1) + position_ids = torch.arange(config.max_position_embeddings, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0)[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = position_ids.expand_as(input_ids) + + return position_ids + + +def run_squeezebert_question_and_answering_inference( + device, + use_program_cache, + model_name, + batch_size, + sequence_size, + squeezebert, + input_path, +): + disable_persistent_kernel_cache() + + hugging_face_reference_model = SqueezeBertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + hugging_face_reference_model.eval() + state_dict = hugging_face_reference_model.state_dict() + + tokenizer = SqueezeBertTokenizer.from_pretrained(model_name) + config = hugging_face_reference_model.config + nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + + tt_model_name = f"ttnn_{model_name}" + + def convert_to_ttnn(model, name): + return not isinstance(model, torch.nn.Conv1d) + + profiler.start(f"preprocessing_parameter") + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: hugging_face_reference_model, + convert_to_ttnn=convert_to_ttnn, + custom_preprocessor=squeezebert.custom_preprocessor, + device=device, + ) + profiler.end(f"preprocessing_parameter") + + context, question = load_inputs(input_path, batch_size) + + preprocess_params, _, postprocess_params = nlp._sanitize_parameters() + preprocess_params["max_seq_len"] = sequence_size + inputs = nlp._args_parser({"context": context, "question": question}) + + preprocessed_inputs = [] + for i in range(batch_size): + model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params)) + + single_input = { + "example": model_input["example"], + "inputs": model_input, + } + preprocessed_inputs.append(single_input) + + squeezebert_input = tokenizer.batch_encode_plus( + zip(question, context), + max_length=sequence_size, + padding="max_length", + truncation=True, + return_attention_mask=True, + return_token_type_ids=True, + return_tensors="pt", + ) + + profiler.start(f"preprocessing_input") + position_ids = positional_ids(config, squeezebert_input.input_ids) + ttnn_squeezebert_inputs = squeezebert.preprocess_inputs( + squeezebert_input["input_ids"], + squeezebert_input["token_type_ids"], + position_ids, + squeezebert_input["attention_mask"], + device=device, + ) + profiler.end(f"preprocessing_input") + + profiler.start(f"inference_time") + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=device, + reader_patterns_cache=None, + ) + profiler.end(f"inference_time") + + tt_output = ttnn.to_torch(ttnn.from_device(tt_output)).reshape(batch_size, 1, sequence_size, -1).to(torch.float32) + + tt_start_logits = tt_output[..., :, 0].squeeze(1) + tt_end_logits = tt_output[..., :, 1].squeeze(1) + + model_answers = {} + profiler.start("post_processing_output_to_string") + for i in range(batch_size): + tt_res = { + "start": tt_start_logits[i], + "end": tt_end_logits[i], + "example": preprocessed_inputs[i]["example"], + **preprocessed_inputs[i]["inputs"], + } + tt_answer = nlp.postprocess([tt_res], **postprocess_params) + + logger.info(f"answer: {tt_answer['answer']}\n") + model_answers[i] = tt_answer["answer"] + + profiler.end("post_processing_output_to_string") + + measurements = { + "preprocessing_parameter": profiler.get("preprocessing_parameter"), + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + + return measurements + + +def run_squeezebert_question_and_answering_inference_squad_v2( + device, + use_program_cache, + model_name, + batch_size, + sequence_size, + squeezebert, + n_iterations, +): + disable_persistent_kernel_cache() + + hugging_face_reference_model = SqueezeBertForQuestionAnswering.from_pretrained(model_name, torchscript=False) + hugging_face_reference_model.eval() + state_dict = hugging_face_reference_model.state_dict() + + tokenizer = SqueezeBertTokenizer.from_pretrained(model_name) + config = hugging_face_reference_model.config + tt_model_name = ttnn_functional_squeezebert + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: hugging_face_reference_model, + custom_preprocessor=squeezebert.custom_preprocessor, + device=device, + ) + + nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer) + + attention_mask = True + token_type_ids = True + inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size) + squad_metric = evaluate.load("squad_v2") + + with torch.no_grad(): + pred_labels = [] + cpu_pred_labels = [] + true_labels = [] + i = 0 + for batch in inputs_squadv2: + if i < n_iterations: + batch_data = batch[0] + curr_batch_size = batch_data["input_ids"].shape[0] + position_ids = positional_ids(config, batch_data.input_ids) + ttnn_squeezebert_inputs = squeezebert.preprocess_inputs( + batch_data["input_ids"], + batch_data["token_type_ids"], + position_ids, + batch_data["attention_mask"], + device=device, + ) + + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=device, + reader_patterns_cache=None, + ) + tt_output = ( + ttnn.to_torch(ttnn.from_device(tt_output)) + .reshape(batch_size, 1, sequence_size, -1) + .to(torch.float32) + ) + + cpu_output = hugging_face_reference_model(**batch_data) + references = batch[1] + question = batch[2] + context = batch[3] + + cpu_predictions, tt_predictions = squadv2_answer_decode_batch( + hugging_face_reference_model, + tokenizer, + nlp, + references, + cpu_output, + tt_output, + curr_batch_size, + question, + context, + ) + pred_labels.extend(tt_predictions) + cpu_pred_labels.extend(cpu_predictions) + true_labels.extend(references) + + del tt_output + i += 1 + eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels) + cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels) + logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") + logger.info(f"\tCPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}") + + tolerance = 0.03 + assert ( + abs(eval_score["exact"] - cpu_eval_score["exact"]) <= tolerance + and abs(eval_score["f1"] - cpu_eval_score["f1"]) <= tolerance + ), ( + f"Expected Exact Match : {cpu_eval_score['exact']}, Actual Exact Match: {eval_score['exact']}; " + f"Expected F1 Score : {cpu_eval_score['f1']}, Actual F1 Score: {eval_score['f1']}" + ) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, sequence_size", + [ + (8, 384), + ], +) +@pytest.mark.parametrize( + "model_name, input_loc", + ((["squeezebert/squeezebert-uncased", "models/demos/squeezebert/demo/input_data.json"]),), +) +@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert]) +def test_demo(input_loc, batch_size, sequence_size, model_name, squeezebert, device, use_program_cache, reset_seeds): + disable_persistent_kernel_cache() + disable_compilation_reports() + + return run_squeezebert_question_and_answering_inference( + device=device, + use_program_cache=use_program_cache, + model_name=model_name, + batch_size=batch_size, + sequence_size=sequence_size, + squeezebert=squeezebert, + input_path=input_loc, + ) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, sequence_size", + [ + (8, 384), + ], +) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert]) +@pytest.mark.parametrize( + "n_iterations", + ((3),), +) +def test_demo_squadv2( + batch_size, sequence_size, model_name, squeezebert, n_iterations, device, use_program_cache, reset_seeds +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + return run_squeezebert_question_and_answering_inference_squad_v2( + device=device, + use_program_cache=use_program_cache, + model_name=model_name, + batch_size=batch_size, + sequence_size=sequence_size, + squeezebert=squeezebert, + n_iterations=n_iterations, + ) diff --git a/models/demos/squeezebert/demo/input_data.json b/models/demos/squeezebert/demo/input_data.json new file mode 100644 index 00000000000..950b8d36323 --- /dev/null +++ b/models/demos/squeezebert/demo/input_data.json @@ -0,0 +1,50 @@ +[ + { + "context" : "Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. The prophet and founding hero of modern archaeology, Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art.", + "question" : "What discipline did Winkelmann create?" + }, + { + "context" : "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.", + "question" : "Who ruled the duchy of Normandy" + }, + { + "context" : "In many countries, there is a Gender pay gap in favor of males in the labor market. Several factors other than discrimination may contribute to this gap. On average, women are more likely than men to consider factors other than pay when looking for work, and may be less willing to travel or relocate. Thomas Sowell, in his book Knowledge and Decisions, claims that this difference is due to women not taking jobs due to marriage or pregnancy, but income studies show that that does not explain the entire difference. A U.S. Census's report stated that in US once other factors are accounted for there is still a difference in earnings between women and men. The income gap in other countries ranges from 53% in Botswana to -40% in Bahrain.", + "question" : "Who does a gender pay gap tend to favor?" + }, + { + "context" : "Most of the Huguenot congregations (or individuals) in North America eventually affiliated with other Protestant denominations with more numerous members. The Huguenots adapted quickly and often married outside their immediate French communities, which led to their assimilation. Their descendants in many families continued to use French first names and surnames for their children well into the nineteenth century. Assimilated, the French made numerous contributions to United States economic life, especially as merchants and artisans in the late Colonial and early Federal periods. For example, E.I. du Pont, a former student of Lavoisier, established the Eleutherian gunpowder mills.", + "question" : "How were Huguenot settlers assimilated into North American society at large?" + }, + { + "context" : "In the laboratory, biostratigraphers analyze rock samples from outcrop and drill cores for the fossils found in them. These fossils help scientists to date the core and to understand the depositional environment in which the rock units formed. Geochronologists precisely date rocks within the stratigraphic section in order to provide better absolute bounds on the timing and rates of deposition. Magnetic stratigraphers look for signs of magnetic reversals in igneous rock units within the drill cores. Other scientists perform stable isotope studies on the rocks to gain information about past climate.", + "question" : "Who analyzes rock samples from drill cores in the lab?" + }, + { + "context" : "Neutrophils and macrophages are phagocytes that travel throughout the body in pursuit of invading pathogens. Neutrophils are normally found in the bloodstream and are the most abundant type of phagocyte, normally representing 50% to 60% of the total circulating leukocytes. During the acute phase of inflammation, particularly as a result of bacterial infection, neutrophils migrate toward the site of inflammation in a process called chemotaxis, and are usually the first cells to arrive at the scene of infection. Macrophages are versatile cells that reside within tissues and produce a wide array of chemicals including enzymes, complement proteins, and regulatory factors such as interleukin 1. Macrophages also act as scavengers, ridding the body of worn-out cells and other debris, and as antigen-presenting cells that activate the adaptive immune system.", + "question" : "What is the process in which neutrophils move towards the site of inflammation called?" + }, + { + "context" : "In Afghanistan, the mujahideen's victory against the Soviet Union in the 1980s did not lead to justice and prosperity, due to a vicious and destructive civil war between political and tribal warlords, making Afghanistan one of the poorest countries on earth. In 1992, the Democratic Republic of Afghanistan ruled by communist forces collapsed, and democratic Islamist elements of mujahdeen founded the Islamic State of Afghanistan. In 1996, a more conservative and anti-democratic Islamist movement known as the Taliban rose to power, defeated most of the warlords and took over roughly 80% of Afghanistan.", + "question" : "When did the Democratic Republic of Afghanistan collapse?" + }, + { + "context" : "The largest single sensory feature is the aboral organ (at the opposite end from the mouth). Its main component is a statocyst, a balance sensor consisting of a statolith, a solid particle supported on four bundles of cilia, called \"balancers\", that sense its orientation. The statocyst is protected by a transparent dome made of long, immobile cilia. A ctenophore does not automatically try to keep the statolith resting equally on all the balancers. Instead its response is determined by the animal's \"mood\", in other words the overall state of the nervous system. For example, if a ctenophore with trailing tentacles captures prey, it will often put some comb rows into reverse, spinning the mouth towards the prey.", + "question" : "What is the main component of the aboral organ?" + }, + { + "context": "Mark Rothko was a Latvian-born American abstract painter. He is best known for his color field paintings that depicted irregular and painterly rectangular regions of color, which he produced from 1949 to 1970. Although Rothko did not personally subscribe to any one school, he is associated with the American Abstract Expressionist movement of modern art. Originally emigrating to Portland, Oregon, from Russian Empire (Latvia) with his family, Rothko later moved to New York City where his youthful period of artistic production dealt primarily with urban scenery.", + "question": "what is Rothko best known for?" + }, + { + "context": "Malignant narcissism is a psychological syndrome that could include aspects of narcissistic personality disorder (NPD) alongside a mix of antisocial, paranoid and sadistic personality disorder traits. The importance of malignant narcissism and of projection as a defense mechanism has been confirmed in paranoia, as well as the patient's vulnerability to malignant narcissistic regression. A person with malignant narcissism exhibits paranoia in addition to the symptoms of a Narcissistic Personality Disorder. Because a malignant narcissist's personality cannot tolerate any criticism, being mocked typically causes paranoia.", + "question": "What symptoms a malignant narcissist might exhibit in addition to the symptoms of a NPD patient?" + }, + { + "context": "The 14 July Revolution, also known as the 1958 Iraqi military coup, was a coup d'état that took place on 14 July 1958 in Iraq which resulted in the toppling of King Faisal II and the overthrow of the Hashemite-led Kingdom of Iraq. The Iraqi Republic established in its wake ended the Hashemite Arab Federation between Iraq and Jordan that had been established just six months earlier. In July 1958, units of the Royal Iraqi Army were dispatched to Jordan in support of King Hussein. A group of Iraqi Free Officers, led by Brigadier Abd al-Karim Qasim and Colonel Abdul Salam Arif, took advantage of the opportunity and instead marched on Baghdad. On 14 July, revolutionary forces seized control of the capital and proclaimed a new republic, headed by a Revolutionary Council.", + "question": "When was the Hashemite Arab Federation formed?" + }, + { + "context": "The Tasmanian devil is a carnivorous marsupial of the family Dasyuridae. It was formerly present across mainland Australia, but became extinct there around 3,500 years ago. The size of a small dog, the Tasmanian devil became the largest carnivorous marsupial in the world following the extinction of the thylacine in 1936. It is related to quolls, and distantly related to the thylacine. It is characterised by its stocky and muscular build, black fur, pungent odour, extremely loud and disturbing screech, keen sense of smell, and ferocity when feeding. The Tasmanian devil's large head and neck allow it to generate among the strongest bites per unit body mass of any extant predatory land mammal. It hunts prey and scavenges on carrion.", + "question": "What allows Tasmanian devil to generate strong bites?" + } +] diff --git a/models/demos/squeezebert/tests/test_perf_device_squeezebert.py b/models/demos/squeezebert/tests/test_perf_device_squeezebert.py new file mode 100644 index 00000000000..7f8acbca401 --- /dev/null +++ b/models/demos/squeezebert/tests/test_perf_device_squeezebert.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from models.utility_functions import is_grayskull +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report + + +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, test", + [ + [8, "sequence_size=384-batch_size=8-model_name=squeezebert/squeezebert-uncased"], + ], +) +def test_perf_device_bare_metal(batch_size, test): + subdir = "ttnn_squeezebert" + num_iterations = 1 + margin = 0.03 + expected_perf = 114.8 if is_grayskull() else 284.5 + + command = f"pytest tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py::test_squeezebert_for_question_answering" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True) + prep_device_perf_report( + model_name=f"ttnn_squeezebert_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/models/demos/squeezebert/tests/test_performance.py b/models/demos/squeezebert/tests/test_performance.py new file mode 100644 index 00000000000..e08fcd6aa2d --- /dev/null +++ b/models/demos/squeezebert/tests/test_performance.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import time +import torch +import pytest +import transformers +from loguru import logger + +from models.utility_functions import is_grayskull +from models.perf.perf_utils import prep_perf_report + +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.squeezebert.tt import ttnn_functional_squeezebert +from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask + +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) + + +def preprocess_inputs( + input_ids, + token_type_ids, + position_ids, + attention_mask, +): + batch_size, *_ = input_ids.shape + + input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32) + token_type_ids = ttnn.from_torch(token_type_ids, dtype=ttnn.uint32) + position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32) + + if attention_mask is not None: + attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape) + attention_mask = attention_mask.expand((batch_size, -1, -1, -1)) + attention_mask = torch.clamp(attention_mask, min=-100000) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + ) + return input_ids, token_type_ids, position_ids, attention_mask + + +def get_expected_times(squeezebert): + return {ttnn_functional_squeezebert: (13.5, 11.5) if is_grayskull() else (16.5, 8.5)}[squeezebert] + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert]) +def test_performance(device, use_program_cache, model_name, sequence_size, squeezebert): + disable_persistent_kernel_cache() + + num_iterations = 2 + batch_size = 8 + + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + rf_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained(model_name) + state_dict = rf_model.state_dict() + + input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(1, sequence_size) + + if squeezebert == ttnn_functional_squeezebert: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown squeezebert: {squeezebert}") + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: rf_model, + custom_preprocessor=squeezebert.custom_preprocessor, + device=device, + ) + + ttnn_squeezebert_inputs_on_cpu = preprocess_inputs( + input_ids, + torch_token_type_ids, + position_ids, + torch_attention_mask, + ) + + start = time.time() + ttnn_squeezebert_inputs = [ + ttnn.to_device(tensor, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) if tensor is not None else tensor + for tensor in ttnn_squeezebert_inputs_on_cpu + ] + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + + tt_output = ttnn.from_device(tt_output, blocking=False) + ttnn.synchronize_device(device) + end = time.time() + inference_and_compile_time = end - start + enable_persistent_kernel_cache() + + start = time.time() + for _ in range(num_iterations): + ttnn_squeezebert_inputs = [ + ttnn.to_device(tensor, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) if tensor is not None else tensor + for tensor in ttnn_squeezebert_inputs_on_cpu + ] + tt_output = squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + tt_output = ttnn.from_device(tt_output, blocking=False) + ttnn.synchronize_device(device) + end = time.time() + average_inference_time = (end - start) / num_iterations + + expected_compile_time, expected_inference_time = get_expected_times(squeezebert) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=average_inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") + logger.info(f"Inference time: {average_inference_time}") + logger.info(f"Samples per second: {1 / average_inference_time * batch_size}") + + assert ( + average_inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {average_inference_time}" diff --git a/models/demos/squeezebert/tt/ttnn_functional_squeezebert.py b/models/demos/squeezebert/tt/ttnn_functional_squeezebert.py new file mode 100644 index 00000000000..b0fa2c60431 --- /dev/null +++ b/models/demos/squeezebert/tt/ttnn_functional_squeezebert.py @@ -0,0 +1,515 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +from torch import nn +from models.utility_functions import is_grayskull +from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores +from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask + + +def transpose_for_scores(config, x, device, permute_tensor: bool): + new_x_shape = (x.shape[0], config.num_attention_heads, config.attention_head_size, x.shape[-1]) + x = ttnn.from_device(x) + x = ttnn.reshape(x, new_x_shape) + x = ttnn.to_device(x, device) + + if permute_tensor: + x = ttnn.permute(x, (0, 1, 3, 2)) + + return x + + +def transpose_output(config, x, device): + all_head_size = config.num_attention_heads * config.attention_head_size + if len(x.shape) == 4: + x = ttnn.permute(x, (0, 1, 3, 2)) + + new_x_shape = (x.shape[0], all_head_size, x.shape[3]) + x = ttnn.reshape(x, new_x_shape) + + return x + + +def permute_reshape(hidden_states, shape=(0, 2, 1), reshape=True): + bs, *_ = hidden_states.shape + hidden_states = ttnn.permute(hidden_states, (0, 2, 1)) + if reshape: + hidden_states = ttnn.reshape(hidden_states, (bs, hidden_states.shape[-2], hidden_states.shape[-1])) + + return hidden_states + + +def ttnn_conv1d( + device, + tt_input_tensor, + weights, + conv_params, + bias, + *, + output_dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + math_fidelity=ttnn.MathFidelity.LoFi, + deallocate_activation=False, + act_block_h=None, + height_sharding=True, + use_shallow_conv_variant=False, + fp32_accum=False, + packer_l1_acc=False, + debug=False, + groups=4, + math_approx=True, + activation="", + reallocate_halo=False, + reshard=False, +): + weights = ttnn.from_torch(weights, dtype=ttnn.float32) + bias = ttnn.from_torch(bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype=ttnn.float32) + + conv_config = ttnn.Conv1dConfig( + dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat8_b, + activation=activation, + input_channels_alignment=(16 if use_shallow_conv_variant else 32), + deallocate_activation=deallocate_activation, + reallocate_halo_output=reallocate_halo, + act_block_h_override=32, + reshard_if_not_optimal=reshard, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + core_grid=get_shard_grid_from_num_cores(56, device), + ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=math_fidelity, + math_approx_mode=math_approx, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, + ) + + [tt_output_tensor_on_device, out_length, [weights_device, bias_device]] = ttnn.Conv1d( + input_tensor=tt_input_tensor, + weight_tensor=weights, + in_channels=tt_input_tensor.shape[-1], + out_channels=weights.shape[0], + device=device, + bias_tensor=bias, + kernel_size=1, + stride=1, + padding=0, + batch_size=tt_input_tensor.shape[0], + input_length=tt_input_tensor.shape[1], + conv_config=conv_config, + compute_config=compute_config, + conv_op_cache={}, + debug=debug, + groups=groups, + return_output_dim=True, + return_weights_and_bias=True, + ) + + tt_output_tensor_on_device = ttnn.squeeze(tt_output_tensor_on_device, 0) + tt_output_tensor_on_device = ttnn.reshape( + tt_output_tensor_on_device, (tt_input_tensor.shape[0], out_length, tt_output_tensor_on_device.shape[-1]) + ) + + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + + return tt_output_tensor + + +def squeezebert_conv_layernorm( + config, + hidden_states, + input_tensor, + *, + state_dict, + base_addr, + parameters, + device, + cin, + cout, + groups, +): + torch_hidden_states = ttnn.to_torch(hidden_states).to(torch.float32) + self_output_conv1d_ = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups) + self_output_conv1d_.weight = nn.Parameter(state_dict[f"{base_addr}conv1d.weight"]) + self_output_conv1d_.bias = nn.Parameter(state_dict[f"{base_addr}conv1d.bias"]) + + torch_self_output = self_output_conv1d_(torch_hidden_states) + self_output = ttnn.from_torch(torch_self_output, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + self_output_layernorm = ttnn.add(self_output, input_tensor) + self_output_layernorm = permute_reshape(self_output_layernorm) + + attention_output = ttnn.layer_norm( + self_output_layernorm, + weight=parameters.layernorm.weight, + bias=parameters.layernorm.bias, + epsilon=config.layer_norm_eps, + ) + ttnn.deallocate(self_output_layernorm) + attention_output = permute_reshape(attention_output) + + return attention_output + + +def squeezebert_attention( + config, + hidden_states, + attention_mask, + *, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + num_cores_x=12, +): + num_heads = config.num_attention_heads + batch_size, hidden_size, _ = hidden_states.shape + head_size = hidden_size // num_heads + config.attention_head_size = head_size + + hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT) + hidden_states = permute_reshape(hidden_states) + hidden_states = ttnn.from_device(hidden_states) + mixed_query_layer = ttnn_conv1d( + device, + hidden_states, + nn.Parameter(state_dict[f"{base_addr}query.weight"]), + conv_params=[1, 0], + bias=nn.Parameter(state_dict[f"{base_addr}query.bias"]), + ) + mixed_query_layer = ttnn.to_device(mixed_query_layer, device) + mixed_query_layer = ttnn.permute(mixed_query_layer, (0, 2, 1)) + + mixed_key_layer = ttnn_conv1d( + device, + hidden_states, + nn.Parameter(state_dict[f"{base_addr}key.weight"]), + conv_params=[1, 0], + bias=nn.Parameter(state_dict[f"{base_addr}key.bias"]), + ) + mixed_key_layer = ttnn.to_device(mixed_key_layer, device) + mixed_key_layer = ttnn.permute(mixed_key_layer, (0, 2, 1)) + + mixed_value_layer = ttnn_conv1d( + device, + hidden_states, + nn.Parameter(state_dict[f"{base_addr}value.weight"]), + conv_params=[1, 0], + bias=nn.Parameter(state_dict[f"{base_addr}value.bias"]), + ) + mixed_value_layer = ttnn.to_device(mixed_value_layer, device) + mixed_value_layer = ttnn.permute(mixed_value_layer, (0, 2, 1)) + + query = transpose_for_scores(config, mixed_query_layer, device, True) + key = transpose_for_scores(config, mixed_key_layer, device, False) + value = transpose_for_scores(config, mixed_value_layer, device, True) + + ttnn.deallocate(mixed_query_layer) + ttnn.deallocate(mixed_key_layer) + ttnn.deallocate(mixed_value_layer) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, attention_mask=attention_mask, head_size=head_size + ) + + context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x), + ) + context_layer = transpose_output(config, context_layer, device) + + return context_layer + + +def squeezebert_intermediate( + config, + hidden_states, + *, + state_dict, + base_addr, + parameters, + device, + num_cores_x=12, +): + torch_hidden_states = ttnn.to_torch(hidden_states).to(torch.float32) + + torch_conv_ = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.intermediate_size, + kernel_size=1, + groups=config.intermediate_groups, + ) + torch_conv_.weight = nn.Parameter(state_dict[f"{base_addr}conv1d.weight"]) + torch_conv_.bias = nn.Parameter(state_dict[f"{base_addr}conv1d.bias"]) + + torch_conv_output = torch_conv_(torch_hidden_states) + ttnn_conv_output = ttnn.from_torch(torch_conv_output, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + output = ttnn.gelu(ttnn_conv_output) + return output + + +def squeezebert_layer( + config, + hidden_states, + attention_mask, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, +): + multi_head_attention_output = squeezebert_attention( + config, + hidden_states=hidden_states, + attention_mask=attention_mask, + state_dict=state_dict, + base_addr=f"{base_addr}attention.", + parameters=parameters.attention, + device=device, + reader_patterns_cache=reader_patterns_cache, + ) + + attention_output = squeezebert_conv_layernorm( + config, + hidden_states=multi_head_attention_output, + input_tensor=hidden_states, + state_dict=state_dict, + base_addr=f"{base_addr}post_attention.", + parameters=parameters.post_attention, + device=device, + cin=config.hidden_size, + cout=config.hidden_size, + groups=config.post_attention_groups, + ) + ttnn.deallocate(hidden_states) + ttnn.deallocate(multi_head_attention_output) + + intermediate = squeezebert_intermediate( + config, + attention_output, + state_dict=state_dict, + base_addr=f"{base_addr}intermediate.", + parameters=parameters.intermediate, + device=device, + ) + + output = squeezebert_conv_layernorm( + config, + hidden_states=intermediate, + input_tensor=attention_output, + state_dict=state_dict, + base_addr=f"{base_addr}output.", + parameters=parameters.output, + device=device, + cin=config.intermediate_size, + cout=config.hidden_size, + groups=config.output_groups, + ) + + return output + + +def squeezebert_encoder( + config, + hidden_states, + attention_mask, + *, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, +): + hidden_states = permute_reshape(hidden_states) + encoder_output = None + + for layer_idx, encoder_parameters in enumerate(parameters.layers): + encoder_output = squeezebert_layer( + config, + hidden_states, + attention_mask, + state_dict, + base_addr=f"{base_addr}layers.{layer_idx}.", + parameters=encoder_parameters, + device=device, + reader_patterns_cache=reader_patterns_cache, + ) + encoder_output = ttnn.reallocate(encoder_output) + hidden_states = encoder_output + + hidden_states = permute_reshape(hidden_states) + + return hidden_states + + +def squeezebert( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, +): + word_embeddings = ttnn.embedding( + input_ids, + parameters.embeddings.word_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + padding_idx=config.pad_token_id, + ) + ttnn.deallocate(input_ids) + + token_type_embeddings = ttnn.embedding( + token_type_ids, + parameters.embeddings.token_type_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + ) + ttnn.deallocate(token_type_ids) + + word_plus_token_type_embeddings = word_embeddings + token_type_embeddings + ttnn.deallocate(word_embeddings) + ttnn.deallocate(token_type_embeddings) + + position_embeddings = ttnn.embedding( + position_ids, + parameters.embeddings.position_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + ) + ttnn.deallocate(position_ids) + + embeddings = word_plus_token_type_embeddings + position_embeddings + ttnn.deallocate(word_plus_token_type_embeddings) + ttnn.deallocate(position_embeddings) + + encoder_input = ttnn.layer_norm( + embeddings, + weight=parameters.embeddings.LayerNorm.weight, + bias=parameters.embeddings.LayerNorm.bias, + memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull() else ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(embeddings) + + encoder_output = squeezebert_encoder( + config=config, + hidden_states=encoder_input, + attention_mask=attention_mask, + state_dict=state_dict, + base_addr=f"{base_addr}encoder.", + parameters=parameters.encoder, + device=device, + reader_patterns_cache=reader_patterns_cache, + ) + ttnn.deallocate(encoder_input) + + return encoder_output + + +def squeezebert_for_question_answering( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + *, + state_dict, + base_addr, + parameters, + device, + reader_patterns_cache, + name="transformer", +): + squeezebert_output = squeezebert( + config, + input_ids, + token_type_ids, + position_ids, + attention_mask, + state_dict, + base_addr, + parameters=parameters.transformer, + device=device, + reader_patterns_cache=reader_patterns_cache, + ) + qa_outputs = ttnn.linear( + squeezebert_output, + parameters.qa_outputs.weight, + bias=parameters.qa_outputs.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return qa_outputs + + +def preprocess_inputs( + input_ids, + token_type_ids, + position_ids, + attention_mask, + device, +): + import torch + + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + token_type_ids = ttnn.from_torch( + token_type_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + + if attention_mask is not None: + attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32) + attention_mask = attention_mask.expand((batch_size, -1, -1, -1)) + attention_mask = torch.clamp(attention_mask, min=-100000) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return input_ids, token_type_ids, position_ids, attention_mask + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, nn.Conv1d): + weight = model.weight + bias = model.bias + + while bias.dim() < 4: + bias = bias.unsqueeze(0).unsqueeze(0).unsqueeze(0) + parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.float32) + parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.float32) + + return parameters diff --git a/models/demos/t3000/llama3_70b/README.md b/models/demos/t3000/llama3_70b/README.md index 6555cd36dbf..80f344040d4 100644 --- a/models/demos/t3000/llama3_70b/README.md +++ b/models/demos/t3000/llama3_70b/README.md @@ -1,32 +1,74 @@ -# Llama3-70B Demo +# Llama3/3.1-70B Demo + +## Table of Contents + +- [One command run](#one-command-run) +- [How to Run](#how-to-run) + - [Running the demo from TT-Metalium](#running-the-demo-from-tt-metalium) + - [Serving the model from vLLM](#serving-the-model-from-vllm) + +## One command run + +```bash +chmod +x ./models/demos/t3000/llama3_70b/setup_llama.sh && ./models/demos/t3000/llama3_70b/setup_llama.sh +``` + +Where, `TT_METAL_COMMIT_SHA_OR_TAG` and `TT_VLLM_COMMIT_SHA_OR_TAG` are found in the root [README](/README.md#llms) under "Release" version, respectively. + +Example: + +```bash +./models/demos/t3000/llama3_70b/setup_llama.sh llama-3.1-70b-instruct v0.53.0-rc36 384f1790c3be16e1d1b10de07252be2e66d00935 +``` + +Follow prompts as they come up in CLI to select appropriate weights for Llama 3.1 70B Instruct. + +Prerequisites: + +- Submit request to access weights from Meta: [Llama Downloads](https://www.llama.com/llama-downloads) +- Submit permissions on HuggingFace and have a HF personal access token: [Llama 3.1 70B Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) + +Steps run: + +- Setup environment +- Build `tt-metal` +- Download Llama 3.1 70B Instruct weights +- Install vLLM +- Deploy vLLM server ## How to Run -1. **Download the Llama3-70B weights from Meta (https://llama.meta.com/):** +Note: This guide requires the installation / build of `tt-metal`. Please refer to the [installation instructions](/INSTALLING.md) for the release corresponding to [README](/README.md#llms). + +1. **Download the Llama3/3.1-70B weights from Meta ():** 2. **Repack the weights:** + ```bash # This concatenates the sharded checkpoints and makes it easier for us to load. python models/demos/t3000/llama2_70b/scripts/repack_weights.py ``` + Note: Use `5` for `chunk_size`. Once the weights are repacked, move the `params.json` file from the `checkpoint_dir` to the `repacked_output_dir`. -### Running the Demo +### Running the demo from TT-Metalium After setting up the repacked weights and tokenizer, you can run the demo using the commands below: 1. **Prepare the weight cache directory:** + ```bash # Make a directory for us to cache weights into. This speeds up subsequent runs. mkdir ``` 2. **Set up environment variables:** + ```bash export LLAMA3_CKPT_DIR= - export LLAMA3_TOKENIZER_PATH= # Path needs to include the tokenizer.model file + export LLAMA3_TOKENIZER_PATH=/tokenizer.model # Path needs to include the tokenizer.model file export LLAMA3_CACHE_PATH= export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml @@ -38,13 +80,11 @@ After setting up the repacked weights and tokenizer, you can run the demo using # export LLAMA3_CKPT_DIR="/home/llama-data-repacked/llama-3-70b/" # export LLAMA3_TOKENIZER_PATH="/home/llama-data-repacked/tokenizer.model" # export LLAMA3_CACHE_PATH="/home/llama-data-cache/weights-cache" - - ``` 3. **Run the demo:** - NOTE: Run the following comand twice. + Note: Run the following command twice. 1. The first run will cache the weights. This will take some time. 2. The second run will use the cached weights, thereby running much faster. @@ -58,31 +98,74 @@ After setting up the repacked weights and tokenizer, you can run the demo using The above demo does not achieve peak performance because we log outputs to the screen. The following perf test will print an accurate end-to-end throughput number. For best performance, ensure that tt-metal is built in release mode (default), and ensure the host's CPU frequency governors are set to `performance` -- instructions for setting the frequency governor vary by machine. This performance test runs with sequence length 128 and batch size 32. + ```bash pytest -svv models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py::test_Llama_perf_host[wormhole_b0-True-device_params0-gen128-llama3] ``` -## Details +#### Details Supported context lengths and batch sizes for the Llama3.1-70B demo are as follows: | Context Length | Max Batch Size | -|----------------|------------| -| 2k | 32 | -| 8k | 16 | -| 128k | 1 | +|----------------|----------------| +| 2k | 32 | +| 8k | 16 | +| 128k | 1 | - **Input File:** Uses `./demo/data/multi_prompt.json`. - **Model Configuration:** Utilizes a pretrained model. - **Hardware Requirements:** Runs on an 8-chip T3000 machine using tensor parallelism. The host machine must have at least 512 GB of memory. - **Demo arguments:** - - `context: [short_context, long_context, 128k_context]`: Select between short context (batch 32, sequence_length 2k) and long context (batch 16, sequence length 8k) and full context (batch 1, sequence length 128k) - - `ground_truth: [check_disabled, check_enabled]`: Enable or disable ground truth checking, used for testing - - `sampling: [greedy, sampling]`: Select between greedy decoding and top-k/top-p sampling - - `implementation: [tt-70b-T3000]`: Run the 70B model on the Tenstorrent backend - - `num_layers: [1L, 2L, 10L, 80L]`: Select 80L to run the full model - - `decode_only: [decode_only, prefill_decode]`: Use `prefill_decode`. Alternately, `decode_only` implements prefill via decode. - - `chat: [text_completion, chat_completion]`: Run in text_completion mode for the pretrained model or chat_completion for the finetuned model - - `llama_version: [llama3, llama2]`: Select the Llama3 model + - `context: [short_context, long_context, 128k_context]`: Select between short context (batch 32, sequence_length 2k) and long context (batch 16, sequence length 8k) and full context (batch 1, sequence length 128k) + - `ground_truth: [check_disabled, check_enabled]`: Enable or disable ground truth checking, used for testing + - `sampling: [greedy, sampling]`: Select between greedy decoding and top-k/top-p sampling + - `implementation: [tt-70b-T3000]`: Run the 70B model on the Tenstorrent backend + - `num_layers: [1L, 2L, 10L, 80L]`: Select 80L to run the full model + - `decode_only: [decode_only, prefill_decode]`: Use `prefill_decode`. Alternately, `decode_only` implements prefill via decode. + - `chat: [text_completion, chat_completion]`: Run in text_completion mode for the pretrained model or chat_completion for the finetuned model + - `llama_version: [llama3, llama2]`: Select the Llama3 model Ensure you follow these guidelines to successfully run the Llama3-70B demo. + +### Serving the model from vLLM + +1. Complete Step 1 and Step 2 of [Running the Demo from TT-Metalium](#running-the-demo-from-tt-metalium) + +2. **Install vLLM** + + ```bash + # Installing from within `tt-metal` + export VLLM_TARGET_DEVICE="tt" + git clone https://github.com/tenstorrent/vllm.git + cd vllm + git checkout TT_VLLM_COMMIT_SHA_OR_TAG + pip install -e . + cd .. + ``` + + > **Note:** TT_VLLM_COMMIT_SHA_OR_TAG is the vLLM Release version from [README](/README.md#llms) + +3. **Running the server** + + ```bash + python vllm/examples/server_example_tt.py + ``` + +4. **Interact with server** + + In a separate terminal window, run: + + ```bash + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Meta-Llama-3.1-70B", + "prompt": "Write a poem about RISC-V", + "max_tokens": 128, + "temperature": 1, + "top_p": 0.9, + "top_k": 10, + "stream": false + }' + ``` diff --git a/models/demos/t3000/llama3_70b/setup_llama.sh b/models/demos/t3000/llama3_70b/setup_llama.sh new file mode 100644 index 00000000000..636ce070b2b --- /dev/null +++ b/models/demos/t3000/llama3_70b/setup_llama.sh @@ -0,0 +1,250 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +# +# Purpose: Setup and deploy Llama 3.1 70B Instruct model with dependencies. + +set -euo pipefail + +# Function to display usage information +usage() { + cat < + +Description: + This script sets up and deploys the Llama model along with its dependencies. + +Arguments: + The type of model to deploy. Supported options: + - llama-3.1-70b-instruct + - llama-3.1-70b + - llama-3.1-8b-instruct + - llama-3.1-8b + - llama-3-70b-instruct + - llama-3-70b + - llama-3-8b-instruct + - llama-3-8b + The commit SHA or tag to use for TT_METAL. + The commit SHA or tag to use for vLLM. + +Options: + -h, --help Display this help message. + +Examples: + # Deploy the llama-3.1-70b-instruct model + $0 llama-3.1-70b-instruct main dev + + # Deploy with specific commit SHAs + $0 llama-3.1-70b-instruct v0.53.0-rc36 384f1790c3be16e1d1b10de07252be2e66d00935 + +EOF + exit 0 +} + +# helper +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + usage +fi + +# Require commit SHA or tag for TT_METAL and vLLM +TT_METAL_COMMIT_SHA_OR_TAG=${2:-""} +TT_VLLM_COMMIT_SHA_OR_TAG=${3:-""} + +# Ensure required arguments are passed +if [[ -z "${TT_METAL_COMMIT_SHA_OR_TAG}" || -z "${TT_VLLM_COMMIT_SHA_OR_TAG}" ]]; then + echo "❌ Error: Both TT_METAL_COMMIT_SHA_OR_TAG and TT_VLLM_COMMIT_SHA_OR_TAG are required." + usage +fi + +# Defined variables +DEFAULT_PERSISTENT_VOLUME_ROOT=~/persistent_volume +DEFAULT_LLAMA_REPO=~/llama-models + +# functions +error_exit() { + echo "⛔ Error: $1" >&2 + exit 1 +} + +print_step() { + echo -e "\n👉 $1...\n" +} + +setup_model_environment() { + print_step "Setting up model environment for $1" + case "$1" in + "llama-3.1-70b-instruct") + MODEL="llama-3.1-70b-instruct" + META_MODEL_NAME="Meta-Llama-3.1-70B-Instruct" + META_DIR_FILTER="llama3_1" + REPACKED=1 + ;; + "llama-3.1-70b") + MODEL="llama-3.1-70b" + META_MODEL_NAME="Meta-Llama-3.1-70B" + META_DIR_FILTER="llama3_1" + REPACKED=1 + ;; + "llama-3.1-8b-instruct") + MODEL="llama-3.1-8b-instruct" + META_MODEL_NAME="Meta-Llama-3.1-8B-Instruct" + META_DIR_FILTER="llama3_1" + REPACKED=0 + ;; + "llama-3.1-8b") + MODEL_NAME="llama-3.1-8b" + META_MODEL_NAME="Meta-Llama-3.1-8B" + META_DIR_FILTER="llama3_1" + REPACKED=0 + ;; + "llama-3-70b-instruct") + MODEL="llama-3-70b-instruct" + META_MODEL_NAME="Meta-Llama-3-70B-Instruct" + META_DIR_FILTER="llama3" + REPACKED=1 + ;; + "llama-3-70b") + MODEL="llama-3-70b" + META_MODEL_NAME="Meta-Llama-3-70B" + META_DIR_FILTER="llama3" + REPACKED=1 + ;; + "llama-3-8b-instruct") + MODEL="llama-3-8b-instruct" + META_MODEL_NAME="Meta-Llama-3-8B-Instruct" + META_DIR_FILTER="llama3" + REPACKED=0 + ;; + "llama-3-8b") + MODEL="llama-3-8b" + META_MODEL_NAME="Meta-Llama-3-8B" + META_DIR_FILTER="llama3" + REPACKED=0 + ;; + *) + echo "⛔ Invalid model choice." + usage + exit 1 + ;; + esac + + if [ "${REPACKED}" -eq 1 ]; then + echo "REPACKED is enabled." + REPACKED_STR="repacked-" + else + echo "REPACKED is disabled." + REPACKED_STR="" + fi +} + +setup_environment() { + print_step "Setting up environment" + export LLAMA3_CKPT_DIR="${DEFAULT_PERSISTENT_VOLUME_ROOT}/model_weights/${REPACKED_STR}${MODEL}" + export LLAMA3_TOKENIZER_PATH="${LLAMA3_CKPT_DIR}/tokenizer.model" + export LLAMA3_CACHE_PATH="${DEFAULT_PERSISTENT_VOLUME_ROOT}/tt_metal_cache/cache_${REPACKED_STR}${MODEL}" + export ARCH_NAME=wormhole_b0 + export TT_METAL_HOME=$(pwd) + export PYTHONPATH=$(pwd) + echo "Environment variables set." +} + +check_and_build_tt_metal() { + print_step "Checking and building tt-metal" + pushd "${TT_METAL_HOME}" >/dev/null + if [[ ! -d "python_env" ]]; then + git checkout "${TT_METAL_COMMIT_SHA_OR_TAG}" + git submodule update --init --recursive + git submodule foreach 'git lfs fetch --all && git lfs pull' + ./build_metal.sh + ./create_venv.sh + source python_env/bin/activate + pip install -r models/demos/t3000/llama2_70b/reference/llama/requirements.txt + else + echo "🔔 tt-metal Python environment already exists. Skipping build." + source python_env/bin/activate + fi + popd >/dev/null +} + +clone_repo() { + local REPO_PATH=$1 + local REPO_URL=$2 + local COMMIT_HASH=$3 + + print_step "Cloning Llama repository" + if [[ ! -d "${REPO_PATH}" ]]; then + git clone "${REPO_URL}" "${REPO_PATH}" + pushd "${REPO_PATH}" >/dev/null + git checkout "${COMMIT_HASH}" + popd >/dev/null + else + echo "🔔 Repository already exists at ${REPO_PATH}, skipping clone." + fi +} + +setup_weights() { + print_step "Setting up weights" + local LLAMA_REPO=$1 + local LLAMA_DIR="${LLAMA_REPO}/models/${META_DIR_FILTER}" + local LLAMA_WEIGHTS_DIR="${LLAMA_DIR}/${META_MODEL_NAME}" + local WEIGHTS_DIR="${LLAMA3_CKPT_DIR}" + + mkdir -p "${WEIGHTS_DIR}" "${LLAMA3_CACHE_PATH}" + + if [[ -d "${LLAMA_WEIGHTS_DIR}" && -n "$(ls -A "${LLAMA_WEIGHTS_DIR}")" ]]; then + echo "Weights already downloaded in ${LLAMA_WEIGHTS_DIR}" + else + print_step "Downloading weights" + pushd "${LLAMA_DIR}" >/dev/null + [[ -x "./download.sh" ]] && ./download.sh || error_exit "Download script not found!" + popd >/dev/null + fi + + huggingface-cli login + + if [ "${REPACKED}" -eq 1 ]; then + print_step "Repacking weights" + source python_env/bin/activate + cp "${LLAMA_WEIGHTS_DIR}/tokenizer.model" "${WEIGHTS_DIR}/tokenizer.model" + cp "${LLAMA_WEIGHTS_DIR}/params.json" "${WEIGHTS_DIR}/params.json" + python models/demos/t3000/llama2_70b/scripts/repack_weights.py "${LLAMA_WEIGHTS_DIR}" "${WEIGHTS_DIR}" 5 + else + cp -rf "${LLAMA_WEIGHTS_DIR}" "${WEIGHTS_DIR}" + fi + + echo "🔔 Using weights directory ${WEIGHTS_DIR}" +} + +install_vllm() { + print_step "Installing vLLM" + if [[ ! -d "vllm" ]]; then + source python_env/bin/activate + export VLLM_TARGET_DEVICE="tt" + git clone https://github.com/tenstorrent/vllm.git + pushd vllm >/dev/null + git checkout "${TT_VLLM_COMMIT_SHA_OR_TAG}" + pip install -e . + popd >/dev/null + else + echo "🔔 vLLM already installed. Skipping install." + fi +} + +deploy_server() { + print_step "Deploying Llama server" + source python_env/bin/activate + export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml + python vllm/examples/server_example_tt.py + echo "✅ Deployment complete! Interact via http://localhost:8000." +} + +# ---- MAIN ---- +MODEL_TYPE=$1 +setup_model_environment "$MODEL_TYPE" +setup_environment +check_and_build_tt_metal +clone_repo "${DEFAULT_LLAMA_REPO}" "https://github.com/meta-llama/llama-models.git" "685ac4c107c75ce8c291248710bf990a876e1623" +setup_weights "${DEFAULT_LLAMA_REPO}" +install_vllm +deploy_server diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py index cfe555d0367..ce49bfbfa51 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py @@ -167,7 +167,7 @@ def run_downsample_if_req( shard_layout = ( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED ) - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -183,13 +183,17 @@ def run_downsample_if_req( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=shard_layout, deallocate_activation=True, reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -214,7 +218,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -230,14 +234,18 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -277,7 +285,7 @@ def __call__( ) # if ds_out_mem_config and ds_out_mem_config != ttnn.get_memory_config(out): # out = ttnn.to_memory_config(out, ds_out_mem_config) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -293,7 +301,6 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, @@ -303,12 +310,17 @@ def __call__( else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -324,13 +336,17 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_weights_and_bias=True, + return_output_dim=False, ) if not self.run_downsample_before_conv2: @@ -546,7 +562,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config ) - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -562,13 +578,17 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -857,7 +877,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config ) - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -873,13 +893,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 44d90cb0f34..a8944b654c3 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -160,7 +160,7 @@ def run_downsample_if_req( ): if self.downsample: logger.debug(f"Running downsample") - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -176,7 +176,6 @@ def run_downsample_if_req( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, @@ -184,7 +183,6 @@ def run_downsample_if_req( reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16), reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, - packer_l1_accum_enabled=packer_l1_accum_enabled, enable_act_double_buffer=enable_act_double_buffer if height_sharding else True @@ -194,7 +192,14 @@ def run_downsample_if_req( enable_split_reader=enable_split_reader, enable_subblock_padding=enable_subblock_padding, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + packer_l1_acc=packer_l1_accum_enabled, + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -226,7 +231,7 @@ def __call__( # conv1 is 1x1 conv logger.debug(f"Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -242,16 +247,21 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, - packer_l1_accum_enabled=packer_l1_acc, + ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + packer_l1_acc=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -307,7 +317,7 @@ def __call__( reallocate_halo_output = batch_size == 20 logger.debug(f"Running conv2") - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -323,7 +333,6 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, @@ -333,13 +342,19 @@ def __call__( else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, - packer_l1_accum_enabled=packer_l1_acc, enable_act_double_buffer=enable_act_double_buffer, enable_weights_double_buffer=True, enable_split_reader=enable_split_reader, enable_subblock_padding=enable_subblock_padding, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + packer_l1_acc=packer_l1_acc, + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) logger.debug( @@ -358,7 +373,7 @@ def __call__( # conv3 is 1x1 conv logger.debug(f"Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -374,15 +389,20 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, - packer_l1_accum_enabled=packer_l1_acc, + ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + packer_l1_acc=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not run_downsample_before_conv2: @@ -569,19 +589,22 @@ def __init__( self.conv1_config = ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=dealloc_input, input_channels_alignment=input_channels_alignment, act_block_h_override=act_block_h_override, transpose_shards=self.transpose_shards, - packer_l1_accum_enabled=True if whb0_and_b16 else False, enable_act_double_buffer=True if whb0_and_b16 else False, enable_split_reader=True if whb0_and_b16 or not is_wormhole_b0() else False, enable_subblock_padding=False, shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, reshard_if_not_optimal=False, ) + self.conv1_compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + packer_l1_acc=True if whb0_and_b16 else False, + ) if whb0_and_b16: # Issue #13145: Temp workaround for Galaxy to avoid hangs if type(device) == ttnn.MeshDevice and device.get_num_devices() > 8: @@ -719,7 +742,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt logger.debug(f"==== first conv") # first conv - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=fold_output_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -733,7 +756,10 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt input_height=self.conv1_input_height, input_width=self.conv1_input_width, conv_config=self.conv1_config, + compute_config=self.conv1_compute_config, conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 if self.batch_size == 20: diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py index 5c0750003c1..a5427f1fc87 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py @@ -162,7 +162,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -178,7 +178,6 @@ def run_downsample_if_req( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, @@ -186,7 +185,12 @@ def run_downsample_if_req( reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -209,7 +213,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -225,14 +229,18 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -270,7 +278,7 @@ def __call__( # self.conv1_input_channels == 256 and # self.downsample ) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -286,7 +294,6 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, @@ -296,12 +303,17 @@ def __call__( else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -317,13 +329,17 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not self.run_downsample_before_conv2: @@ -516,7 +532,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt elif batch_size == 20: act_block_h_override = 640 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -532,13 +548,17 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -819,7 +839,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c else: act_block_h_override = 0 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -835,13 +855,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py index f2e266e1d8b..6bc5013bbf6 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py @@ -164,7 +164,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -180,7 +180,6 @@ def run_downsample_if_req( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, @@ -188,7 +187,12 @@ def run_downsample_if_req( reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -211,7 +215,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -227,14 +231,18 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -273,7 +281,7 @@ def __call__( logger.info( f"Running conv2 with reallocate_halo_output={reallocate_halo_output}, input_height={input_height}, conv2_output_channels={self.conv2_output_channels}" ) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -289,7 +297,6 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, @@ -299,12 +306,17 @@ def __call__( else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -320,13 +332,17 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if not self.run_downsample_before_conv2: @@ -541,7 +557,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config ) - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -557,13 +573,17 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -872,7 +892,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c elif batch_size == 20: act_block_h_override = 640 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -888,13 +908,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py index 45d93ebf685..d59a6c75238 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py @@ -163,7 +163,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -179,7 +179,6 @@ def run_downsample_if_req( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, @@ -188,7 +187,12 @@ def run_downsample_if_req( reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -216,7 +220,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -232,7 +236,6 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding @@ -240,7 +243,12 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if is_wormhole_b0(): @@ -321,7 +329,7 @@ def __call__( # self.conv1_input_channels == 256 and # self.downsample ) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -337,7 +345,6 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, reallocate_halo_output=reallocate_halo_output, @@ -348,12 +355,17 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -369,14 +381,18 @@ def __call__( conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=height_sharding, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not self.run_downsample_before_conv2: @@ -581,7 +597,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt else: act_block_h_override = 0 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -597,14 +613,18 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, reallocate_halo_output=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -915,7 +935,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c elif batch_size == 1: act_block_h_override = 256 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -931,13 +951,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", deallocate_activation=True, input_channels_alignment=16 if not is_wormhole_b0() else 32, act_block_h_override=act_block_h_override, ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"] + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/vgg/tests/test_perf_vgg.py b/models/demos/vgg/tests/test_perf_vgg.py index 4d5bdf30e06..9cc0397bb07 100644 --- a/models/demos/vgg/tests/test_perf_vgg.py +++ b/models/demos/vgg/tests/test_perf_vgg.py @@ -79,17 +79,6 @@ def test_vgg( "ACTIVATIONS_DTYPE": act_dtype, } - conv_config = ttnn.Conv2dConfig( - dtype=model_config["ACTIVATIONS_DTYPE"], - weights_dtype=model_config["WEIGHTS_DTYPE"], - math_fidelity=model_config["MATH_FIDELITY"], - activation="relu", - deallocate_activation=True, - input_channels_alignment=16, - act_block_h_override=0, - transpose_shards=True, - ) - torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1) torch_input_tensor = torch.permute(torch_batched_tensor, (0, 2, 3, 1)) tt_batched_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py index 0748c745d16..82f5dd1c03d 100644 --- a/models/demos/vgg/tt/ttnn_vgg.py +++ b/models/demos/vgg/tt/ttnn_vgg.py @@ -90,10 +90,6 @@ def ttnn_vgg16( conv_config = ttnn.Conv2dConfig( dtype=model_config["ACTIVATIONS_DTYPE"], weights_dtype=model_config["WEIGHTS_DTYPE"], - math_fidelity=model_config["MATH_FIDELITY"], - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=False, - packer_l1_accum_enabled=False, activation="relu", deallocate_activation=False, input_channels_alignment=32, @@ -107,13 +103,20 @@ def ttnn_vgg16( ) if h_override[iter_conv_id] is not None: conv_config.act_block_h_override = h_override[iter_conv_id] + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=model_config["MATH_FIDELITY"], + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT) tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias # Call ttnn.conv conv_op_cache = {} - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_x, weight_tensor=tt_weight, in_channels=conv_ttnn_params[iter_conv_id][0], @@ -127,7 +130,10 @@ def ttnn_vgg16( input_height=conv_ttnn_params[iter_conv_id][2], input_width=conv_ttnn_params[iter_conv_id][3], conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) tt_x = ttnn.from_device(tt_output_tensor_on_device) ttnn.deallocate(tt_output_tensor_on_device) @@ -214,9 +220,6 @@ def ttnn_vgg11( conv_config = ttnn.Conv2dConfig( dtype=model_config["ACTIVATIONS_DTYPE"], weights_dtype=model_config["WEIGHTS_DTYPE"], - math_fidelity=model_config["MATH_FIDELITY"], - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, activation="relu", deallocate_activation=False, input_channels_alignment=32, @@ -230,13 +233,19 @@ def ttnn_vgg11( if height_override_11[iter_conv_id] is not None: conv_config.act_block_h_override = height_override_11[iter_conv_id] + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=model_config["MATH_FIDELITY"], + math_approx_mode=True, + fp32_dest_acc_en=True, + ) tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT) tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias # Call ttnn.conv conv_op_cache = {} - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_x, weight_tensor=tt_weight, in_channels=conv_ttnn_params_2[iter_conv_id][0], @@ -250,7 +259,10 @@ def ttnn_vgg11( input_height=conv_ttnn_params_2[iter_conv_id][2], input_width=conv_ttnn_params_2[iter_conv_id][3], conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) tt_x = ttnn.from_device(tt_output_tensor_on_device) ttnn.deallocate(tt_output_tensor_on_device) diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py index a2700198f83..c7e897ad484 100644 --- a/models/demos/wormhole/mamba/tt/mamba_conv.py +++ b/models/demos/wormhole/mamba/tt/mamba_conv.py @@ -54,11 +54,14 @@ def prepare_conv_config(self): self.conv1d_config = ttnn.Conv1dConfig( dtype=self.config.output_dtype, weights_dtype=self.config.weights_dtype, - math_fidelity=self.config.math_fidelity, shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, input_channels_alignment=32, deallocate_activation=True, ) + self.conv1d_compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=self.config.math_fidelity, + ) def prepare_input(self, input_tensor): # input_tensor (1, 1, B, 2E) @@ -87,7 +90,7 @@ def __call__(self, input_tensor): input_tensor_splits = self.prepare_input(input_tensor) output_tensor_splits = [] for i in range(self.config.channels_split_factor): - [tt_output_tensor_on_device, out_length, weights_device, _] = ttnn.Conv1d( + [tt_output_tensor_on_device, out_length, [weights_device, _]] = ttnn.Conv1d( input_tensor=input_tensor_splits[i], weight_tensor=self.tt_weight_tensor_splits[i], in_channels=self.config.input_channels // self.config.channels_split_factor, @@ -100,9 +103,12 @@ def __call__(self, input_tensor): batch_size=1, input_length=self.config.input_length, conv_config=self.conv1d_config, + compute_config=self.conv1d_compute_config, conv_op_cache={}, debug=False, groups=self.config.groups // self.config.channels_split_factor, + return_output_dim=True, + return_weights_and_bias=True, ) self.tt_weight_tensor_splits[i] = weights_device output_tensor_splits.append(ttnn.sharded_to_interleaved(tt_output_tensor_on_device)) diff --git a/models/demos/wormhole/stable_diffusion/test_multiple_iterations.py b/models/demos/wormhole/stable_diffusion/test_multiple_iterations.py deleted file mode 100644 index 8db6aee6f39..00000000000 --- a/models/demos/wormhole/stable_diffusion/test_multiple_iterations.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import json -import torch -import pytest -import numpy as np -from PIL import Image -from loguru import logger -from tqdm.auto import tqdm -from datasets import load_dataset - -from transformers import CLIPTextModel, CLIPTokenizer -from diffusers import ( - AutoencoderKL, - UNet2DConditionModel, - LMSDiscreteScheduler, -) -from models.utility_functions import ( - comp_allclose_and_pcc, - enable_persistent_kernel_cache, - disable_persistent_kernel_cache, -) -from models.utility_functions import skip_for_wormhole_b0 -from ttnn.model_preprocessing import preprocess_model_parameters -from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor -from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_unet_2d_condition_model import ( - UNet2DConditionModel as UNet2D, -) - -from torchvision.transforms import ToTensor - - -def load_inputs(input_path): - with open(input_path) as f: - input_data = json.load(f) - assert input_data, "Input data is empty." - prompt = [item["prompt"] for item in input_data] - return prompt - - -def constant_prop_time_embeddings(timesteps, sample, time_proj): - timesteps = timesteps[None] - timesteps = timesteps.expand(sample.shape[0]) - t_emb = time_proj(timesteps) - return t_emb - - -def save_image_and_latents(latents, iter, vae, pre_fix="", pre_fix2=""): - pre_fix = "" if pre_fix == "" else f"{pre_fix}_" - pre_fix2 = "" if pre_fix2 == "" else f"{pre_fix2}_" - _latents = 1 / 0.18215 * latents - - with torch.no_grad(): - image = vae.decode(_latents).sample - # Image post-processing - image = (image / 2 + 0.5).clamp(0, 1) - image = image.detach().cpu().permute(0, 2, 3, 1).numpy() - images = (image * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images][0] - pil_images.save(f"{pre_fix}{pre_fix2}image_iter_{iter}.png") - - -def guide(noise_pred, guidance_scale, t): # will return latents - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - return noise_pred - - -def latent_expansion(latents, scheduler, t): - latent_model_input = torch.cat([latents] * 2, dim=0) - latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) - return latent_model_input - - -def calculate_fid_score(imgs_path1, imgs_path2): - fid = FrechetInceptionDistance(normalize=True) - fid.update(imgs_path1, real=False) - fid.update(imgs_path2, real=True) - return fid.compute() - - -def preprocess_images(image_paths): - images = [] - for image_path in image_paths: - image = Image.open(image_path) - image = image.convert("RGB") - image = image.resize((299, 299)) - image = ToTensor()(image) - images.append(image) - return torch.stack(images) - - -def run_demo_inference_diffusiondb(device, reset_seeds, input_path, num_inference_steps, image_size): - disable_persistent_kernel_cache() - - height, width = image_size - - experiment_name = f"diffusiondb_{height}x{width}" - input_prompt = [ - "oil painting frame of Breathtaking mountain range with a clear river running through it, surrounded by tall trees and misty clouds, serene, peaceful, mountain landscape, high detail" - ] - logger.info(f"input_prompts: {input_prompt}") - - # 1. Load the autoencoder model which will be used to decode the latents into image space. - vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae") - - # 2. Load the tokenizer and text encoder to tokenize and encode the text. - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - - # 3. The UNet model for generating the latents. - unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet") - - # 4. load the K-LMS scheduler with some fitting parameters. - ttnn_scheduler = LMSDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - ) - - torch_device = "cpu" - vae.to(torch_device) - text_encoder.to(torch_device) - unet.to(torch_device) - - guidance_scale = 7.5 # Scale for classifier-free guidance - generator = torch.manual_seed(174) # 10233 Seed generator to create the inital latent noise - batch_size = len(input_prompt) - - ## First, we get the text_embeddings for the prompt. These embeddings will be used to condition the UNet model. - # Tokenizer and Text Encoder - text_input = tokenizer( - input_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] - max_length = text_input.input_ids.shape[-1] - uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt") - uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] - - # For classifier-free guidance, we need to do two forward passes: one with the conditioned input (text_embeddings), - # and another with the unconditional embeddings (uncond_embeddings). - # In practice, we can concatenate both into a single batch to avoid doing two forward passes. - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - ttnn_text_embeddings = ttnn.from_torch(text_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) - - vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - # Initial random noise - latents = torch.randn( - (batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor), - generator=generator, - ) - latents = latents.to(torch_device) - - ttnn_scheduler.set_timesteps(num_inference_steps) - - latents = latents * ttnn_scheduler.init_noise_sigma - ttnn_latents = torch.tensor(latents) - - iter = 0 - config = unet.config - - parameters = preprocess_model_parameters( - initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device - ) - input_height = 64 - input_width = 64 - reader_patterns_cache = {} if height == 512 and width == 512 else None - - model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache) - # # Denoising loop - for t in tqdm(ttnn_scheduler.timesteps): - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. - ttnn_latent_model_input = latent_expansion(ttnn_latents, ttnn_scheduler, t) - ttnn_latent_model_input = ttnn.from_torch( - ttnn_latent_model_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device - ) - - _t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj) - _t = _t.unsqueeze(0).unsqueeze(0) - _t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) - - # predict the noise residual - with torch.no_grad(): - ttnn_output = model( - ttnn_latent_model_input, # input - timestep=_t, - encoder_hidden_states=ttnn_text_embeddings, - class_labels=None, - attention_mask=None, - cross_attention_kwargs=None, - return_dict=True, - config=config, - ) - noise_pred = ttnn.to_torch(ttnn_output) - - # perform guidance - noise_pred = guide(noise_pred, guidance_scale, t) - - ttnn_latents = ttnn_scheduler.step(noise_pred, t, ttnn_latents).prev_sample - save_image_and_latents(ttnn_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="") - - iter += 1 - enable_persistent_kernel_cache() - - latents = ttnn_latents - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - with torch.no_grad(): - image = vae.decode(latents).sample - - # Image post-processing - image = (image / 2 + 0.5).clamp(0, 1) - image = image.detach().cpu().permute(0, 2, 3, 1).numpy() - images = (image * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images][0] - ttnn_output_path = f"{experiment_name}_ttnn.png" - pil_images.save(ttnn_output_path) - - ref_paths = [ref_img_path, ref_img_path] - ttnn_paths = [ttnn_output_path, ttnn_output_path] - - ref_images = preprocess_images(ref_paths) - ttnn_images = preprocess_images(ttnn_paths) - - -def test_tt2_multiple_iteration(device, reset_seeds, input_path): - # 30 iterations, generate 512x512 image - return run_demo_inference_diffusiondb(device, reset_seeds, input_path, 30, (512, 512)) diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_basic_transformer_block.py b/models/demos/wormhole/stable_diffusion/tests/test_basic_transformer_block.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_basic_transformer_block.py rename to models/demos/wormhole/stable_diffusion/tests/test_basic_transformer_block.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_cross_attention.py b/models/demos/wormhole/stable_diffusion/tests/test_cross_attention.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_cross_attention.py rename to models/demos/wormhole/stable_diffusion/tests/test_cross_attention.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_cross_attn_up_block_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_cross_attn_up_block_2d_new_conv.py rename to models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_demo.py b/models/demos/wormhole/stable_diffusion/tests/test_demo.py similarity index 91% rename from tests/ttnn/integration_tests/stable_diffusion/test_demo.py rename to models/demos/wormhole/stable_diffusion/tests/test_demo.py index 5c8dc03b967..36b73e70e3c 100644 --- a/tests/ttnn/integration_tests/stable_diffusion/test_demo.py +++ b/models/demos/wormhole/stable_diffusion/tests/test_demo.py @@ -29,6 +29,8 @@ ((512, 512),), ) def test_demo_sd(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size): + if device.core_grid.y != 8: + pytest.skip("Needs 8x8 Grid") demo(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size) @@ -48,4 +50,6 @@ def test_demo_sd(device, reset_seeds, input_path, num_prompts, num_inference_ste ((512, 512),), ) def test_demo_sd_db(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size): + if device.core_grid.y != 8: + pytest.skip("Needs 8x8 Grid") demo_db(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size) diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_embedding.py b/models/demos/wormhole/stable_diffusion/tests/test_embedding.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_embedding.py rename to models/demos/wormhole/stable_diffusion/tests/test_embedding.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_feedforward.py b/models/demos/wormhole/stable_diffusion/tests/test_feedforward.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_feedforward.py rename to models/demos/wormhole/stable_diffusion/tests/test_feedforward.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_geglu.py b/models/demos/wormhole/stable_diffusion/tests/test_geglu.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_geglu.py rename to models/demos/wormhole/stable_diffusion/tests/test_geglu.py diff --git a/tests/device_perf_tests/stable_diffusion/test_perf_stable_diffusion.py b/models/demos/wormhole/stable_diffusion/tests/test_perf.py similarity index 100% rename from tests/device_perf_tests/stable_diffusion/test_perf_stable_diffusion.py rename to models/demos/wormhole/stable_diffusion/tests/test_perf.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py similarity index 66% rename from tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d_new_conv.py rename to models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py index 51afb5afd0d..91a0f3755e5 100644 --- a/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py @@ -25,90 +25,6 @@ def ttnn_to_torch(input): return input -@skip_for_grayskull() -@pytest.mark.parametrize( - "batch_size, in_channels, input_height, input_width, index1,index2,block_name,out_channels", - [ - (2, 320, 32, 32, 0, 0, "down", None), - (2, 320, 16, 16, 0, 0, "down", None), - (2, 640, 16, 16, 1, 1, "down", None), - (2, 640, 8, 8, 1, 1, "down", None), - (2, 1280, 8, 8, 2, 1, "down", None), - (2, 1280, 4, 4, 2, 1, "down", None), - (2, 2560, 4, 4, 0, 0, "up", 1280), - (2, 2560, 8, 8, 0, 0, "up", 1280), - (2, 1920, 8, 8, 2, 0, "up", 640), - (2, 1920, 16, 16, 2, 0, "up", 640), - (2, 1280, 16, 16, 3, 0, "down", None), - (2, 960, 16, 16, 3, 0, "up", 320), - (2, 960, 32, 32, 3, 0, "up", 320), - (2, 640, 32, 32, 3, 1, "up", 320), - ], -) -def test_resnet_block_2d_256x256( - device, batch_size, in_channels, input_height, input_width, index1, index2, block_name, out_channels -): - pytest.skip() - # setup pytorch model - model_name = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float32) - - model = pipe.unet - model.eval() - - parameters = preprocess_model_parameters( - model_name=model_name, initialize_model=lambda: model, custom_preprocessor=custom_preprocessor, device=device - ) - - if block_name == "up": - parameters = parameters.up_blocks[index1].resnets[index2] - resnet = pipe.unet.up_blocks[index1].resnets[index2] - elif block_name == "down": - parameters = parameters.down_blocks[index1].resnets[index2] - resnet = pipe.unet.down_blocks[index1].resnets[index2] - else: - parameters = parameters.mid_block.resnets[index2] - resnet = pipe.unet.mid_block.resnets[index2] - - ############ start of residual block ############# - temb_channels = 1280 - groups = 32 - time_embedding_norm = "default" - output_scale_factor = 1 - use_in_shortcut = None - ########## end of residual block ############# - hidden_states_shape = [batch_size, in_channels, input_height, input_width] - temb_shape = [1, 1, 2, 1280] - - input = torch.randn(hidden_states_shape) - temb = torch.randn(temb_shape) - - torch_output = resnet(input, temb.squeeze(0).squeeze(0)) - - input = ttnn.from_torch(input, ttnn.bfloat16) - input = ttnn.to_layout(input, ttnn.TILE_LAYOUT) - input = ttnn.to_device(input, device, memory_config=ttnn.L1_MEMORY_CONFIG) - - temb = ttnn.from_torch(temb, ttnn.bfloat16) - temb = ttnn.to_layout(temb, ttnn.TILE_LAYOUT) - temb = ttnn.to_device(temb, device, memory_config=ttnn.L1_MEMORY_CONFIG) - ttnn_output = resnetBlock2D( - input, - temb=temb, - temb_channels=temb_channels, - time_embedding_norm=time_embedding_norm, - in_channels=in_channels, - out_channels=out_channels, - use_in_shortcut=use_in_shortcut, - groups=groups, - output_scale_factor=output_scale_factor, - parameters=parameters, - device=device, - ) - ttnn_output = ttnn_to_torch(ttnn_output) - assert_with_pcc(torch_output, ttnn_output, pcc=0.99) - - @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.parametrize( diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_matmuls.py b/models/demos/wormhole/stable_diffusion/tests/test_sharded_matmuls.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_sharded_matmuls.py rename to models/demos/wormhole/stable_diffusion/tests/test_sharded_matmuls.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_transformer_2d_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_transformer_2d_model.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_transformer_2d_model_new_conv.py rename to models/demos/wormhole/stable_diffusion/tests/test_transformer_2d_model.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py similarity index 98% rename from tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model_new_conv.py rename to models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py index 35b1253ea54..72efdb4e178 100644 --- a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py @@ -63,7 +63,6 @@ def unsqueeze_all_params_to_4d(params): @skip_for_grayskull() -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="#10923: CB / L1 buffer clash") @pytest.mark.parametrize( "device_params", [{"l1_small_size": 32768}], ids=["device_params=l1_small_size_24576"], indirect=True ) @@ -204,7 +203,7 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_ # print(iter) # print(f"Time taken for 50 iterations: {total_time}") # print(f"Samples per second: {50 / total_time}") - passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.99) + passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.981) print(output) assert passing diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_upblock_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_upblock_2d.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_upblock_2d_new_conv.py rename to models/demos/wormhole/stable_diffusion/tests/test_upblock_2d.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_upsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_upsample_2d.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_upsample_2d_new_conv.py rename to models/demos/wormhole/stable_diffusion/tests/test_upsample_2d.py diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_upsample_nearest_2d.py b/models/demos/wormhole/stable_diffusion/tests/test_upsample_nearest_2d.py similarity index 100% rename from tests/ttnn/integration_tests/stable_diffusion/test_upsample_nearest_2d.py rename to models/demos/wormhole/stable_diffusion/tests/test_upsample_nearest_2d.py diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py index 2ad02078d71..570d2457f1a 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py @@ -126,11 +126,7 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, shard_layout=self.shard_layout, input_channels_alignment=32, transpose_shards=False, @@ -140,10 +136,17 @@ def __call__( if hidden_states.memory_config() != self.input_memory_config: hidden_states = ttnn.to_memory_config(hidden_states, self.input_memory_config) + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) if self.conv_config_override and "act_block_h" in self.conv_config_override: conv_config.act_block_h_override = self.conv_config_override["act_block_h"] - [hidden_states, _out_height, _out_width, self.conv_weights, self.conv_bias] = ttnn.conv2d( + [hidden_states, [self.conv_weights, self.conv_bias]] = ttnn.conv2d( input_tensor=hidden_states, in_channels=self.in_channels, out_channels=self.out_channels, @@ -157,7 +160,10 @@ def __call__( weight_tensor=self.conv_weights, bias_tensor=self.conv_bias, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) # hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting( # self.device, diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py index 4cedbdea78c..45024d9c9d9 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py @@ -459,19 +459,22 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=self.conv1_shard_layout, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=False, ) + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) if self.conv1_config_override and "act_block_h" in self.conv2_config_override: conv_config.act_block_h_override = self.conv1_config_override["act_block_h"] - [hidden_states, _out_height, _out_width, self.conv1s_weights[0], self.conv1s_bias[0]] = ttnn.conv2d( + [hidden_states, [self.conv1s_weights[0], self.conv1s_bias[0]]] = ttnn.conv2d( input_tensor=hidden_states, weight_tensor=self.conv1s_weights[0], in_channels=self.conv1_in_channels, @@ -485,7 +488,10 @@ def __call__( input_height=self.conv1_input_height, input_width=self.conv1_input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) else: @@ -529,26 +535,26 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=False, ) - + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) if self.conv1_config_override and "act_block_h" in self.conv2_config_override: conv_config.act_block_h_override = self.conv1_config_override["act_block_h"] [ split_hidden_states[i], - _out_height, - _out_width, - self.conv1s_weights[i], - self.conv1s_bias[i], + [_out_height, _out_width], + [self.conv1s_weights[i], self.conv1s_bias[i]], ] = ttnn.conv2d( input_tensor=split_hidden_states[i], weight_tensor=self.conv1s_weights[i], @@ -563,7 +569,10 @@ def __call__( input_height=self.conv1_input_height, input_width=self.conv1_input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if i != 0: split_hidden_states[i] = ttnn.add( @@ -658,19 +667,22 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=False, ) + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) if self.conv2_config_override and "act_block_h" in self.conv2_config_override: conv_config.act_block_h_override = self.conv2_config_override["act_block_h"] - [hidden_states, _out_height, _out_width, self.conv2_weights, self.conv2_bias] = ttnn.conv2d( + [hidden_states, [_out_height, _out_width], [self.conv2_weights, self.conv2_bias]] = ttnn.conv2d( input_tensor=hidden_states, weight_tensor=self.conv2_weights, bias_tensor=self.conv2_bias, @@ -684,7 +696,10 @@ def __call__( input_height=self.conv2_input_height, input_width=self.conv2_input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut @@ -702,17 +717,24 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=False, ) - [input_tensor, _out_height, _out_width, self.conv_shortcut_weights, self.conv_shortcut_bias] = ttnn.conv2d( + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) + [ + input_tensor, + [_out_height, _out_width], + [self.conv_shortcut_weights, self.conv_shortcut_bias], + ] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv_shortcut_weights, in_channels=self.conv_shortcut_in_channels, @@ -726,7 +748,10 @@ def __call__( input_height=self.conv_shortcut_input_height, input_width=self.conv_shortcut_input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states): diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py index 12e4d543207..e89a957357e 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py @@ -242,14 +242,17 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, input_channels_alignment=32, - fp32_dest_acc_enabled=self.compute_kernel_config.fp32_dest_acc_en, transpose_shards=False, ) - [hidden_states, _out_height, _out_width, self.proj_in_conv_weights, self.proj_in_conv_bias] = ttnn.conv2d( + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=self.compute_kernel_config.fp32_dest_acc_en, + ) + [hidden_states, [self.proj_in_conv_weights, self.proj_in_conv_bias]] = ttnn.conv2d( input_tensor=hidden_states, in_channels=self.proj_in_in_channels, out_channels=self.proj_in_out_channels, @@ -263,7 +266,10 @@ def __call__( weight_tensor=self.proj_in_conv_weights, bias_tensor=self.proj_in_conv_bias, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) inner_dim = hidden_states.shape[-1] @@ -293,10 +299,8 @@ def __call__( # hidden_states = ttnn.to_memory_config(hidden_states, self.proj_out.conv.input_sharded_memory_config) [ hidden_states, - _out_height, - _out_width, - self.proj_out_conv_weights, - self.proj_out_conv_bias, + [_out_height, _out_width], + [self.proj_out_conv_weights, self.proj_out_conv_bias], ] = ttnn.conv2d( input_tensor=hidden_states, in_channels=self.proj_out_in_channels, @@ -312,6 +316,8 @@ def __call__( bias_tensor=self.proj_out_conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if output_bfloat16: diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py index 9cbdfff2f48..1003c1efc4e 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py @@ -383,18 +383,21 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, shard_layout=shard_layout, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=True, ) + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) - [sample, _out_height, _out_width, self.conv_in_weights, self.conv_in_bias] = ttnn.conv2d( + [sample, [self.conv_in_weights, self.conv_in_bias]] = ttnn.conv2d( input_tensor=sample, weight_tensor=self.conv_in_weights, bias_tensor=self.conv_in_bias, @@ -408,7 +411,10 @@ def __call__( input_height=self.input_height, input_width=self.input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) sample = ttnn.reallocate(sample) # TODO: Test remove @@ -646,18 +652,21 @@ def __call__( conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, input_channels_alignment=32, act_block_h_override=64, transpose_shards=False, reshard_if_not_optimal=True, ) - [sample, _out_height, _out_width, self.conv_out_weights, self.conv_out_bias] = ttnn.conv2d( + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) + [sample, [self.conv_out_weights, self.conv_out_bias]] = ttnn.conv2d( input_tensor=sample, in_channels=self.conv_out_in_channels, out_channels=self.conv_out_out_channels, @@ -671,7 +680,10 @@ def __call__( weight_tensor=self.conv_out_weights, bias_tensor=self.conv_out_bias, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) sample = ttnn.to_memory_config(sample, ttnn.L1_MEMORY_CONFIG) sample = ttnn.clone(sample, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16) diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py index 622a63065db..52e9fb5c913 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py @@ -91,19 +91,22 @@ def __call__(self, input, in_channels, out_channels): conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat8_b, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation="", shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=True, - packer_l1_accum_enabled=False, input_channels_alignment=32, transpose_shards=False, reshard_if_not_optimal=False, # Reshard has error : 1616 Bytes unique+common runtime args targeting kernel reshard_reader on (x=0,y=0) are too large. Cannot be written as they will run into memory region reserved for result. Max allowable size is 1024 Bytes ) + compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) if self.conv_config_override and "act_block_h" in self.conv_config_override: conv_config.act_block_h_override = self.conv_config_override["act_block_h"] - [tt_out, _out_height, _out_width, self.conv_weight_tensor, self.conv_bias_tensor] = ttnn.conv2d( + [tt_out, [self.conv_weight_tensor, self.conv_bias_tensor]] = ttnn.conv2d( input_tensor=tt_out, in_channels=self.conv_in_channels, out_channels=self.conv_out_channels, @@ -117,6 +120,9 @@ def __call__(self, input, in_channels, out_channels): weight_tensor=self.conv_weight_tensor, bias_tensor=self.conv_bias_tensor, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) return tt_out diff --git a/models/demos/yolov4/ttnn/common.py b/models/demos/yolov4/ttnn/common.py index b293a6db751..1579f9112f9 100644 --- a/models/demos/yolov4/ttnn/common.py +++ b/models/demos/yolov4/ttnn/common.py @@ -80,13 +80,9 @@ def __call__(self, device, input_tensor): conv_config = ttnn.Conv2dConfig( dtype=ttnn.bfloat16, weights_dtype=ttnn.bfloat8_b, - math_fidelity=ttnn.MathFidelity.LoFi, activation=self.activation, shard_layout=self.shard_layout, - math_approx_mode_enabled=True, - fp32_dest_acc_enabled=False, act_block_w_div=1, - packer_l1_accum_enabled=False, input_channels_alignment=16 if self.input_params[3] < 16 else 32, transpose_shards=False, reshard_if_not_optimal=self.reshard, @@ -96,10 +92,17 @@ def __call__(self, device, input_tensor): enable_act_double_buffer=self.enable_act_double_buffer, output_layout=self.output_layout, ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=False, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) if self.act_block_h is not None: conv_config.act_block_h_override = self.act_block_h - [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d( + output_tensor, [self.weights, self.bias] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.weights, bias_tensor=self.bias, @@ -113,5 +116,8 @@ def __call__(self, device, input_tensor): input_height=self.input_params[1], input_width=self.input_params[2], conv_config=conv_config, + compute_config=compute_config, + return_output_dim=False, + return_weights_and_bias=True, ) return output_tensor diff --git a/models/demos/yolov4/ttnn/neck.py b/models/demos/yolov4/ttnn/neck.py index f7e3b541278..d86d9faa527 100644 --- a/models/demos/yolov4/ttnn/neck.py +++ b/models/demos/yolov4/ttnn/neck.py @@ -262,7 +262,7 @@ def __call__(self, device, input_tensor): ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec ) - output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config) + output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2), memory_config=out_sharded_mem_config) output_tensor_upsample_1 = ttnn.sharded_to_interleaved(output_tensor_upsample_1, ttnn.L1_MEMORY_CONFIG) output_tensor_upsample_1 = ttnn.reshape(output_tensor_upsample_1, (1, 1, 400, 256)) output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT) @@ -336,7 +336,7 @@ def __call__(self, device, input_tensor): ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec ) - output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config) + output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2), memory_config=out_sharded_mem_config) output_tensor_upsample_2 = ttnn.sharded_to_interleaved(output_tensor_upsample_2, ttnn.L1_MEMORY_CONFIG) output_tensor_upsample_2 = ttnn.reshape(output_tensor_upsample_2, (1, 1, 1600, 128)) output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT) diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 215399ea23b..8a5157d51dc 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -114,10 +114,8 @@ def __init__( self.conv_config = ttnn.Conv2dConfig( dtype=activation_dtype, weights_dtype=weights_dtype, - math_fidelity=ttnn.MathFidelity.LoFi, shard_layout=shard_layout, deallocate_activation=self.deallocate_activation, - packer_l1_accum_enabled=False, enable_act_double_buffer=( conv.use_activation_double_buffer if "use_activation_double_buffer" in conv else False ), @@ -128,6 +126,12 @@ def __init__( input_channels_alignment=conv.input_channels_alignment if "input_channels_alignment" in conv else 32, reshard_if_not_optimal=reshard_if_not_optimal, ) + self.compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) config_override = conv.conv_blocking_and_parallelization_config_override if config_override and "act_block_h" in config_override: self.conv_config.act_block_h_override = config_override["act_block_h"] @@ -143,7 +147,7 @@ def __init__( self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper) def __call__(self, x): - x, _, _, self.weight, self.bias = ttnn.conv2d( + x, [self.weight, self.bias] = ttnn.conv2d( input_tensor=x, weight_tensor=self.weight, bias_tensor=self.bias, @@ -157,8 +161,11 @@ def __call__(self, x): stride=self.stride, padding=self.padding, conv_config=self.conv_config, + compute_config=self.compute_config, conv_op_cache=self.cache, groups=2, + return_output_dim=False, + return_weights_and_bias=True, ) return x @@ -257,7 +264,7 @@ def upsample(self, x): else: x = ttnn.interleaved_to_sharded(x, shardspec) - x = ttnn.upsample(x, (2, 2, 1), memory_config=x.memory_config()) + x = ttnn.upsample(x, (2, 2), memory_config=x.memory_config()) x = ttnn.reshape( x, (1, 1, self.conv1.batch_size * self.conv1.input_height * self.conv1.input_width, x.shape[-1]) ) diff --git a/models/perf/benchmarking_utils.py b/models/perf/benchmarking_utils.py index 5ca7ae269c8..8136c4ef0c1 100644 --- a/models/perf/benchmarking_utils.py +++ b/models/perf/benchmarking_utils.py @@ -16,6 +16,24 @@ def __init__(self): self.start_times = dict() self.end_times = dict() + def __call__(self, step_name: str, iteration: int = 0): + # Return a context manager for this step + return self.StepContext(self, step_name, iteration) + + class StepContext: + def __init__(self, profiler, step_name: str, iteration: int): + self.profiler = profiler + self.step_name = step_name + self.iteration = iteration + + def __enter__(self): + self.profiler.start(self.step_name, self.iteration) + return self.profiler + + def __exit__(self, exc_type, exc_val, exc_tb): + self.profiler.end(self.step_name, self.iteration) + return False + def start(self, step_name: str, iteration: int = 0): self.start_times[(iteration, step_name)] = datetime.now(tz=pytz.UTC) diff --git a/tests/nightly/single_card/stable_diffusion/test_basic_transformer_block.py b/tests/nightly/single_card/stable_diffusion/test_basic_transformer_block.py new file mode 120000 index 00000000000..61408ffa9e7 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_basic_transformer_block.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_basic_transformer_block.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_cross_attention.py b/tests/nightly/single_card/stable_diffusion/test_cross_attention.py new file mode 120000 index 00000000000..c161012b886 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_cross_attention.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_cross_attention.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_cross_attn_up_block_2d.py b/tests/nightly/single_card/stable_diffusion/test_cross_attn_up_block_2d.py new file mode 120000 index 00000000000..8fce2d91ed2 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_cross_attn_up_block_2d.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_demo.py b/tests/nightly/single_card/stable_diffusion/test_demo.py new file mode 120000 index 00000000000..c375047f633 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_demo.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_demo.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_embedding.py b/tests/nightly/single_card/stable_diffusion/test_embedding.py new file mode 120000 index 00000000000..3e89c128424 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_embedding.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_embedding.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_feedforward.py b/tests/nightly/single_card/stable_diffusion/test_feedforward.py new file mode 120000 index 00000000000..915332488d5 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_feedforward.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_feedforward.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_geglu.py b/tests/nightly/single_card/stable_diffusion/test_geglu.py new file mode 120000 index 00000000000..5880ea6e17d --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_geglu.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_geglu.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_resnet_block_2d.py b/tests/nightly/single_card/stable_diffusion/test_resnet_block_2d.py new file mode 120000 index 00000000000..1b6513e5b50 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_resnet_block_2d.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_sharded_matmuls.py b/tests/nightly/single_card/stable_diffusion/test_sharded_matmuls.py new file mode 120000 index 00000000000..d5d12d47849 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_sharded_matmuls.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_sharded_matmuls.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_transformer_2d_model.py b/tests/nightly/single_card/stable_diffusion/test_transformer_2d_model.py new file mode 120000 index 00000000000..d82d4a899f6 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_transformer_2d_model.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_transformer_2d_model.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_unet_2d_condition_model.py b/tests/nightly/single_card/stable_diffusion/test_unet_2d_condition_model.py new file mode 120000 index 00000000000..c25a861ed35 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_unet_2d_condition_model.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_upblock_2d.py b/tests/nightly/single_card/stable_diffusion/test_upblock_2d.py new file mode 120000 index 00000000000..3997b30be69 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_upblock_2d.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_upblock_2d.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_upsample_2d.py b/tests/nightly/single_card/stable_diffusion/test_upsample_2d.py new file mode 120000 index 00000000000..88a98649844 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_upsample_2d.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_upsample_2d.py \ No newline at end of file diff --git a/tests/nightly/single_card/stable_diffusion/test_upsample_nearest_2d.py b/tests/nightly/single_card/stable_diffusion/test_upsample_nearest_2d.py new file mode 120000 index 00000000000..815ccb622b4 --- /dev/null +++ b/tests/nightly/single_card/stable_diffusion/test_upsample_nearest_2d.py @@ -0,0 +1 @@ +../../../../models/demos/wormhole/stable_diffusion/tests/test_upsample_nearest_2d.py \ No newline at end of file diff --git a/tests/nightly/single_card/wh_b0_unstable/tests/ttnn/integration_tests/stable_diffusion b/tests/nightly/single_card/wh_b0_unstable/tests/ttnn/integration_tests/stable_diffusion deleted file mode 120000 index 608e08f48e2..00000000000 --- a/tests/nightly/single_card/wh_b0_unstable/tests/ttnn/integration_tests/stable_diffusion +++ /dev/null @@ -1 +0,0 @@ -../../../../../../../tests/ttnn/integration_tests/stable_diffusion \ No newline at end of file diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 7956d1c7b03..7c42512474d 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -41,6 +41,8 @@ run_perf_models_other() { env pytest -n auto models/demos/mnist/tests -m $test_marker + env pytest -n auto models/demos/squeezebert/tests/test_performance.py -m $test_marker + ## Merge all the generated reports env python models/perf/merge_perf_results.py } @@ -71,7 +73,7 @@ run_perf_models_cnn_javelin() { # Run tests env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/experimental/functional_unet/tests/test_unet_perf.py -m $test_marker - env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto tests/device_perf_tests/stable_diffusion -m $test_marker --timeout=480 + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/wormhole/stable_diffusion/tests -m $test_marker --timeout=480 ## Merge all the generated reports env python models/perf/merge_perf_results.py @@ -81,7 +83,7 @@ run_device_perf_models() { set -eo pipefail local test_marker=$1 - env pytest tests/device_perf_tests/stable_diffusion -m $test_marker --timeout=600 + env pytest models/demos/wormhole/stable_diffusion/tests -m $test_marker --timeout=600 env pytest models/demos/distilbert/tests -m $test_marker @@ -93,6 +95,8 @@ run_device_perf_models() { env pytest models/demos/mnist/tests -m $test_marker + env pytest models/demos/squeezebert/tests -m $test_marker + if [ "$tt_arch" == "grayskull" ]; then #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with #Model Device perf regression tests to make sure thy run on no-soft-reset BMs diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 5f5642483f6..11a4e96a895 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -15,6 +15,20 @@ run_common_func_tests() { # Qwen7B QWEN_DIR=/mnt/MLPerf/tt_dnn-models/qwen/Qwen2-7B-Instruct WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml FAKE_DEVICE=N150 pytest -n auto models/demos/qwen/demo/demo.py -k instruct --timeout 420; fail+=$? + # Llama3 Accuracy tests + # Llama3.2-1B + llama1b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-1B-Instruct/ + # Llama3.2-3B + llama3b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-3B-Instruct/ + # Llama3.1-8B (11B weights are the same) + llama8b=/mnt/MLPerf/tt_dnn-models/llama/Meta-Llama-3.1-8B-Instruct/ + + # Run Llama3 accuracy tests for 1B, 3B, 8B weights + for llama_dir in "$llama1b" "$llama3b" "$llama8b"; do + LLAMA_DIR=$llama_dir WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/llama3/tests/test_llama_accuracy.py -k perf --timeout 420; fail+=$? + echo "LOG_METAL: Llama3 accuracy tests for $llama_dir completed" + done + #VGG11/VGG16 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/vgg/demo/demo.py --timeout 600; fail+=$? @@ -39,6 +53,9 @@ run_common_func_tests() { # Mnist pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$? + # SqueezeBERT + pytest --disable-warnings models/demos/squeezebert/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh index 805fa83e97b..627769a5a9e 100755 --- a/tests/scripts/t3000/run_t3000_demo_tests.sh +++ b/tests/scripts/t3000/run_t3000_demo_tests.sh @@ -93,7 +93,7 @@ run_t3000_llama3_vision_tests() { pip install -r models/demos/llama3/requirements.txt for fake_device in "$n300" "$t3k"; do - FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/simple_vision_demo.py -k "cold and yes_trace" --timeout 600; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/simple_vision_demo.py -k "batch1-trace" --timeout 600; fail+=$? echo "LOG_METAL: Llama3 vision tests for $fake_device completed" done diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh index 0058a3fc9e3..3ade2f43355 100755 --- a/tests/scripts/t3000/run_t3000_frequent_tests.sh +++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh @@ -63,7 +63,7 @@ run_t3000_llama3_tests() { # Run test model for llama3 - 1B, 3B, 8B and 11B weights for llama_dir in "$llama1b" "$llama3b" "$llama8b" "$llama11b"; do LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_model.py -k full ; fail+=$? - # LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_model_prefill.py ; fail+=$? # FIXME Issue #14843 + LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_model_prefill.py ; fail+=$? echo "LOG_METAL: Llama3 tests for $llama_dir completed" done @@ -96,6 +96,40 @@ run_t3000_llama3_70b_tests() { fi } +run_t3000_llama3_accuracy_tests() { + # Record the start time + fail=0 + start_time=$(date +%s) + + echo "LOG_METAL: Running run_t3000_llama3_accuracy_tests" + + wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml + # Llama3.2-1B + llama1b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-1B-Instruct/ + # Llama3.2-3B + llama3b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-3B-Instruct/ + # Llama3.1-8B + llama8b=/mnt/MLPerf/tt_dnn-models/llama/Meta-Llama-3.1-8B-Instruct/ + # Llama3.2-11B + llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/ + # Llama3.1-70B + llama70b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.1-70B-Instruct/ + + # Run test accuracy llama3 - 1B, 3B, 8B, 11B and 70B weights + for llama_dir in "$llama1b" "$llama3b" "$llama8b" "$llama11b" "$llama70b"; do + LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_accuracy.py -k perf ; fail+=$? + echo "LOG_METAL: Llama3 accuracy tests for $llama_dir completed" + done + + # Record the end time + end_time=$(date +%s) + duration=$((end_time - start_time)) + echo "LOG_METAL: run_t3000_llama3_accuracy_tests $duration seconds to complete" + if [[ $fail -ne 0 ]]; then + exit 1 + fi +} + run_t3000_llama3.2-11b-vision_freq_tests() { # Record the start time fail=0 @@ -277,6 +311,9 @@ run_t3000_tests() { # Run llama3-70b tests run_t3000_llama3_70b_tests + # Run llama3 accuracy tests + run_t3000_llama3_accuracy_tests + # Run Llama3.2-11B Vision tests run_t3000_llama3.2-11b-vision_freq_tests diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index 6b33b853a07..60c671c6e83 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -197,8 +197,8 @@ run_t3000_llama3.2-11b-vision_unit_tests() { LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? - LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$? - LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py -k "batch_1" ; fail+=$? + LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py -k "batch_1" ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$? LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$? @@ -232,8 +232,8 @@ run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests() { FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$? - FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$? - FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py -k "batch_1" ; fail+=$? + FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py -k "batch_1" ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$? FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$? diff --git a/tests/sweep_framework/sweep_utils/conv2d_common.py b/tests/sweep_framework/sweep_utils/conv2d_common.py index 55769adb984..c7509247213 100644 --- a/tests/sweep_framework/sweep_utils/conv2d_common.py +++ b/tests/sweep_framework/sweep_utils/conv2d_common.py @@ -117,18 +117,20 @@ def run_full( conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - math_fidelity=math_fidelity, shard_layout=None, deallocate_activation=deallocate_activation, - fp32_dest_acc_enabled=fp32_accum, - packer_l1_accum_enabled=packer_l1_acc, override_sharding_config=override_sharding_config, output_layout=output_layout, enable_act_double_buffer=enable_act_double_buffer, enable_split_reader=enable_split_reader, enable_subblock_padding=enable_subblock_padding, ) - + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=math_fidelity, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, + ) if override_sharding_config: if len(core_grid) == 2: conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange(core_grid[0], core_grid[1])}) @@ -137,7 +139,7 @@ def run_full( {ttnn.CoreRange(core_grid[0], core_grid[1]), ttnn.CoreRange(core_grid[2], core_grid[3])} ) start_time = start_measuring_time() - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -152,7 +154,10 @@ def run_full( input_height=input_height, input_width=input_width, conv_config=conv_config, + compute_config=compute_config, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -220,7 +225,7 @@ def run_short( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) start_time = start_measuring_time() - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -235,6 +240,8 @@ def run_short( input_height=input_height, input_width=input_width, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/sweep_framework/sweep_utils/utils.py b/tests/sweep_framework/sweep_utils/utils.py index 6f2199055d8..ef6349c71e9 100644 --- a/tests/sweep_framework/sweep_utils/utils.py +++ b/tests/sweep_framework/sweep_utils/utils.py @@ -220,6 +220,28 @@ def gen_split_qkv_heads_spec( } +def gen_rotary_embedding_spec( + input_shape_list, + cache_size_list, + use_token_idx_list=[True, False], +): + for input_shape, cache_size, use_token_idx in itertools.product( + input_shape_list, cache_size_list, use_token_idx_list + ): + input_shape_ = input_shape.copy() + if use_token_idx is True: + token_idx = random.randint(1, cache_size - 1) + input_shape_[0] = 1 + else: + token_idx = None + + yield { + "input_shape": input_shape_, + "cache_size": cache_size, + "token_idx": token_idx, + } + + def gen_complex_tensor(input_shape, low, high, dtype=ttnn.bfloat16): torch_real = gen_func_with_cast_tt(partial(torch_random, low=-100, high=100, dtype=torch.float32), dtype)( input_shape diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py index 97fa8debcbb..c41f6be9092 100644 --- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py +++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py @@ -19,1551 +19,1551 @@ "input_specs": [ # Contains following params # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation] + [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], + [1, 960, 960, 3, 3, 1, 5, 1, 1, 0, 2, 960, False, 1], + [1, 960, 960, 3, 3, 5, 1, 1, 1, 2, 0, 960, False, 1], + [1, 960, 960, 7, 7, 3, 3, 1, 1, 1, 1, 960, False, 1], + [1, 960, 960, 7, 7, 5, 5, 1, 1, 2, 2, 960, False, 1], + [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], + [1, 96, 96, 112, 112, 3, 3, 2, 2, 1, 1, 96, False, 1], + [1, 96, 96, 113, 113, 3, 3, 2, 2, 0, 0, 96, False, 1], + [1, 96, 96, 121, 121, 3, 3, 2, 2, 0, 0, 96, False, 1], + [1, 96, 96, 131, 131, 3, 3, 2, 2, 0, 0, 96, False, 1], + [1, 96, 96, 28, 28, 5, 5, 2, 2, 2, 2, 96, False, 1], + [1, 92, 92, 14, 14, 3, 3, 1, 1, 1, 1, 92, False, 1], + [1, 144, 144, 28, 28, 3, 3, 1, 1, 1, 1, 9, False, 1], + [1, 144, 144, 56, 56, 3, 3, 2, 2, 1, 1, 9, False, 1], + [1, 216, 216, 28, 28, 3, 3, 1, 1, 1, 1, 9, False, 1], + [1, 216, 216, 56, 56, 3, 3, 2, 2, 1, 1, 9, False, 1], + [1, 432, 432, 14, 14, 3, 3, 1, 1, 1, 1, 9, False, 1], + [1, 432, 432, 28, 28, 3, 3, 2, 2, 1, 1, 9, False, 1], + [1, 88, 88, 28, 28, 3, 3, 1, 1, 1, 1, 88, False, 1], + [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], + [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], + [1, 80, 80, 14, 14, 3, 3, 1, 1, 1, 1, 80, False, 1], + [1, 80, 80, 7, 7, 3, 3, 1, 1, 1, 1, 80, False, 1], + [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 8, False, 1], + [1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 8, False, 1], + [1, 1344, 1344, 14, 14, 3, 3, 1, 1, 1, 1, 8, False, 1], + [1, 1344, 1344, 28, 28, 3, 3, 2, 2, 1, 1, 8, False, 1], + [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 8, False, 1], + [1, 448, 448, 56, 56, 3, 3, 2, 2, 1, 1, 8, False, 1], + [1, 8, 8, 112, 112, 3, 3, 1, 1, 1, 1, 8, False, 1], + [1, 728, 728, 19, 19, 3, 3, 1, 1, 1, 1, 728, False, 1], + [1, 728, 728, 38, 38, 3, 3, 2, 2, 1, 1, 728, False, 1], + [1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 728, False, 1], + [1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1], + [1, 720, 720, 21, 21, 5, 5, 2, 2, 0, 0, 720, False, 1], + [1, 72, 72, 28, 28, 1, 5, 1, 1, 0, 2, 72, False, 1], + [1, 72, 72, 28, 28, 5, 1, 1, 1, 2, 0, 72, False, 1], + [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 72, False, 1], + [1, 72, 72, 56, 56, 3, 3, 2, 2, 1, 1, 72, False, 1], + [1, 72, 72, 56, 56, 5, 5, 2, 2, 2, 2, 72, False, 1], + [1, 72, 72, 80, 80, 3, 3, 1, 1, 1, 1, 72, False, 1], + [1, 72, 72, 80, 80, 5, 5, 2, 2, 2, 2, 72, False, 1], + [1, 168, 168, 28, 28, 3, 3, 1, 1, 1, 1, 7, False, 1], + [1, 168, 168, 56, 56, 3, 3, 2, 2, 1, 1, 7, False, 1], + [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 7, False, 1], + [1, 896, 896, 28, 28, 3, 3, 2, 2, 1, 1, 7, False, 1], + [1, 672, 672, 14, 14, 3, 3, 1, 1, 1, 1, 672, False, 1], + [1, 672, 672, 14, 14, 5, 5, 1, 1, 2, 2, 672, False, 1], + [1, 672, 672, 14, 14, 5, 5, 2, 2, 2, 2, 672, False, 1], + [1, 672, 672, 15, 15, 5, 5, 1, 1, 2, 2, 672, False, 1], + [1, 672, 672, 17, 17, 5, 5, 2, 2, 0, 0, 672, False, 1], + [1, 672, 672, 19, 19, 5, 5, 2, 2, 0, 0, 672, False, 1], + [1, 672, 672, 20, 20, 3, 3, 1, 1, 1, 1, 672, False, 1], + [1, 672, 672, 20, 20, 5, 5, 2, 2, 2, 2, 672, False, 1], + [1, 672, 672, 24, 24, 3, 3, 1, 1, 1, 1, 672, False, 1], + [1, 672, 672, 24, 24, 5, 5, 1, 1, 2, 2, 672, False, 1], + [1, 672, 672, 7, 7, 1, 5, 1, 1, 0, 2, 672, False, 1], + [1, 672, 672, 7, 7, 5, 1, 1, 1, 2, 0, 672, False, 1], + [1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 640, True, 1], + [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 64, False, 1], + [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 1, 1, 64, False, 1], + [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 1, 1, 64, False, 1], + [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 64, False, 1], + [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 64, False, 1], + [1, 512, 512, 56, 56, 3, 3, 2, 2, 1, 1, 64, False, 1], + [1, 64, 64, 112, 112, 3, 3, 1, 1, 1, 1, 64, False, 1], + [1, 64, 64, 112, 112, 3, 3, 2, 2, 1, 1, 64, False, 1], + [1, 64, 64, 150, 150, 3, 3, 1, 1, 1, 1, 64, False, 1], + [1, 64, 64, 160, 160, 3, 3, 2, 2, 1, 1, 64, False, 1], + [1, 64, 64, 2, 2, 3, 3, 2, 2, 1, 1, 64, False, 1], + [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 64, False, 1], + [1, 1512, 1512, 14, 14, 3, 3, 2, 2, 1, 1, 63, False, 1], + [1, 60, 60, 28, 28, 3, 3, 1, 1, 1, 1, 60, False, 1], + [1, 1392, 1392, 14, 14, 3, 3, 1, 1, 1, 1, 6, False, 1], + [1, 1392, 1392, 28, 28, 3, 3, 2, 2, 1, 1, 6, False, 1], + [1, 48, 48, 112, 112, 3, 3, 2, 2, 1, 1, 6, False, 1], + [1, 720, 720, 14, 14, 3, 3, 1, 1, 1, 1, 6, False, 1], + [1, 720, 720, 28, 28, 3, 3, 2, 2, 1, 1, 6, False, 1], + [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 576, False, 1], + [1, 576, 576, 14, 14, 3, 3, 2, 2, 1, 1, 576, False, 1], + [1, 576, 576, 19, 19, 3, 3, 1, 1, 1, 1, 576, False, 1], + [1, 576, 576, 19, 19, 5, 5, 1, 1, 2, 2, 576, False, 1], + [1, 576, 576, 7, 7, 5, 5, 1, 1, 2, 2, 576, False, 1], + [1, 56, 56, 14, 14, 3, 3, 1, 1, 1, 1, 56, False, 1], + [1, 440, 440, 14, 14, 3, 3, 2, 2, 1, 1, 55, False, 1], + [1, 440, 440, 7, 7, 3, 3, 1, 1, 1, 1, 55, False, 1], + [1, 528, 528, 17, 17, 3, 3, 1, 1, 1, 1, 528, False, 1], + [1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1], + [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 512, False, 1], + [1, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, False, 1], + [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 512, False, 1], + [1, 512, 512, 28, 28, 3, 3, 1, 1, 2, 2, 512, False, 1], + [1, 512, 512, 5, 5, 3, 3, 1, 1, 1, 1, 512, False, 1], + [1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, 512, True, 1], + [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 5, False, 1], + [1, 120, 120, 56, 56, 3, 3, 2, 2, 1, 1, 5, False, 1], + [1, 784, 784, 14, 14, 3, 3, 2, 2, 1, 1, 49, False, 1], + [1, 784, 784, 7, 7, 3, 3, 1, 1, 1, 1, 49, False, 1], + [1, 480, 480, 10, 10, 3, 3, 1, 1, 1, 1, 480, False, 1], + [1, 480, 480, 10, 10, 5, 5, 1, 1, 2, 2, 480, False, 1], + [1, 480, 480, 14, 14, 3, 3, 1, 1, 1, 1, 480, False, 1], + [1, 480, 480, 14, 14, 5, 5, 1, 1, 2, 2, 480, False, 1], + [1, 480, 480, 15, 15, 3, 3, 1, 1, 1, 1, 480, False, 1], + [1, 480, 480, 15, 15, 5, 5, 1, 1, 2, 2, 480, False, 1], + [1, 480, 480, 20, 20, 3, 3, 1, 1, 1, 1, 480, False, 1], + [1, 480, 480, 7, 7, 1, 5, 1, 1, 0, 2, 480, False, 1], + [1, 480, 480, 7, 7, 3, 3, 1, 1, 1, 1, 480, False, 1], + [1, 480, 480, 7, 7, 5, 1, 1, 1, 2, 0, 480, False, 1], + [1, 48, 48, 112, 112, 3, 3, 2, 2, 1, 1, 48, False, 1], + [1, 672, 672, 14, 14, 3, 3, 2, 2, 1, 1, 42, False, 1], + [1, 672, 672, 7, 7, 3, 3, 1, 1, 1, 1, 42, False, 1], + [1, 40, 40, 14, 14, 3, 3, 1, 1, 1, 1, 40, False, 1], + [1, 40, 40, 28, 28, 3, 3, 2, 2, 1, 1, 40, False, 1], + [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 192, 192, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 224, 224, 112, 112, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 448, 448, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 512, 512, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 64, 64, 112, 112, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 64, 64, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 64, 64, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 672, 672, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 672, 672, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1], + [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], + [1, 384, 384, 14, 14, 3, 3, 1, 1, 1, 1, 384, False, 1], + [1, 912, 912, 14, 14, 3, 3, 2, 2, 1, 1, 38, False, 1], + [1, 912, 912, 7, 7, 3, 3, 1, 1, 1, 1, 38, False, 1], + [1, 888, 888, 14, 14, 3, 3, 2, 2, 1, 1, 37, False, 1], + [1, 888, 888, 7, 7, 3, 3, 1, 1, 1, 1, 37, False, 1], + [1, 2016, 2016, 14, 14, 3, 3, 2, 2, 1, 1, 36, False, 1], + [1, 36, 36, 56, 56, 3, 3, 1, 1, 1, 1, 36, False, 1], + [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 336, False, 1], + [1, 336, 336, 49, 49, 3, 3, 2, 2, 0, 0, 336, False, 1], + [1, 336, 336, 48, 48, 5, 5, 1, 1, 2, 2, 336, False, 1], + [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 1, 1, 32, False, 1], + [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 1, 1, 32, False, 1], + [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 1, 1, 32, False, 1], + [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 256, 256, 56, 56, 3, 3, 2, 2, 1, 1, 32, False, 1], + [1, 32, 32, 112, 112, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 32, 32, 120, 120, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 32, 32, 130, 130, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 32, 32, 150, 150, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 32, 32, 190, 190, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 512, 512, 28, 28, 3, 3, 2, 2, 1, 1, 32, False, 1], + [1, 512, 512, 56, 56, 3, 3, 2, 2, 1, 1, 32, False, 1], + [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 32, False, 1], + [1, 72, 72, 112, 112, 3, 3, 2, 2, 1, 1, 3, False, 1], + [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 3, False, 1], + [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1], + [1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1], + [1, 288, 288, 14, 14, 5, 5, 2, 2, 2, 2, 288, False, 1], + [1, 288, 288, 33, 33, 5, 5, 1, 1, 2, 2, 288, False, 1], + [1, 288, 288, 35, 35, 3, 3, 2, 2, 0, 0, 288, False, 1], + [1, 288, 288, 38, 38, 5, 5, 1, 1, 2, 2, 288, False, 1], + [1, 288, 288, 39, 39, 3, 3, 2, 2, 0, 0, 288, False, 1], + [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1], + [1, 3024, 3024, 14, 14, 3, 3, 2, 2, 1, 1, 27, False, 1], + [1, 208, 208, 14, 14, 3, 3, 1, 1, 1, 1, 26, False, 1], + [1, 208, 208, 28, 28, 3, 3, 2, 2, 1, 1, 26, False, 1], + [1, 256, 256, 10, 10, 3, 3, 2, 2, 1, 1, 256, False, 1], + [1, 256, 256, 2, 2, 3, 3, 1, 1, 1, 1, 256, False, 1], + [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 256, False, 1], + [1, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, False, 1], + [1, 256, 256, 3, 3, 3, 3, 1, 1, 1, 1, 256, False, 1], + [1, 256, 256, 38, 38, 3, 3, 1, 1, 1, 1, 256, False, 1], + [1, 256, 256, 64, 64, 3, 3, 1, 1, 1, 1, 256, True, 1], + [1, 256, 256, 75, 75, 3, 3, 2, 2, 1, 1, 256, False, 1], + [1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, 256, True, 1], + [1, 256, 256, 75, 75, 3, 3, 1, 1, 1, 1, 256, False, 1], + [1, 400, 400, 14, 14, 3, 3, 2, 2, 1, 1, 25, False, 1], + [1, 400, 400, 7, 7, 3, 3, 1, 1, 1, 1, 25, False, 1], + [1, 240, 240, 14, 14, 1, 5, 1, 1, 0, 2, 240, False, 1], + [1, 240, 240, 14, 14, 3, 3, 1, 1, 1, 1, 240, False, 1], + [1, 240, 240, 14, 14, 5, 1, 1, 1, 2, 0, 240, False, 1], + [1, 240, 240, 14, 14, 5, 5, 1, 1, 2, 2, 240, False, 1], + [1, 240, 240, 28, 28, 3, 3, 2, 2, 1, 1, 240, False, 1], + [1, 240, 240, 28, 28, 5, 5, 1, 1, 2, 2, 240, False, 1], + [1, 240, 240, 29, 29, 3, 3, 2, 2, 0, 0, 240, False, 1], + [1, 240, 240, 30, 30, 5, 5, 1, 1, 2, 2, 240, False, 1], + [1, 240, 240, 31, 31, 3, 3, 2, 2, 0, 0, 240, False, 1], + [1, 240, 240, 40, 40, 3, 3, 2, 2, 1, 1, 240, False, 1], + [1, 24, 24, 112, 112, 3, 3, 1, 1, 1, 1, 24, False, 1], + [1, 24, 24, 56, 56, 5, 5, 2, 2, 2, 2, 24, False, 1], + [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 24, False, 1], + [1, 576, 576, 28, 28, 3, 3, 2, 2, 1, 1, 24, False, 1], + [1, 224, 224, 7, 7, 3, 3, 1, 1, 1, 1, 224, False, 1], + [1, 1008, 1008, 14, 14, 3, 3, 2, 2, 1, 1, 21, False, 1], + [1, 1008, 1008, 7, 7, 3, 3, 1, 1, 1, 1, 21, False, 1], + [1, 2048, 2048, 15, 20, 3, 3, 1, 1, 1, 1, 2048, True, 1], + [1, 200, 200, 14, 14, 3, 3, 1, 1, 1, 1, 200, False, 1], + [1, 200, 200, 20, 20, 3, 3, 1, 1, 1, 1, 200, False, 1], + [1, 200, 200, 7, 7, 1, 5, 1, 1, 0, 2, 200, False, 1], + [1, 200, 200, 7, 7, 5, 1, 1, 1, 2, 0, 200, False, 1], + [1, 20, 20, 28, 28, 3, 3, 1, 1, 1, 1, 20, False, 1], + [1, 320, 320, 14, 14, 3, 3, 1, 1, 1, 1, 20, False, 1], + [1, 320, 320, 28, 28, 3, 3, 2, 2, 1, 1, 20, False, 1], + [1, 224, 224, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 240, 240, 28, 28, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 240, 240, 56, 56, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 32, 32, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 48, 48, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 48, 48, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 96, 96, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 96, 96, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 256, 256, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 336, 336, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], + [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1], + [1, 192, 192, 14, 14, 3, 3, 1, 1, 1, 1, 192, False, 1], + [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 192, False, 1], + [1, 192, 192, 28, 28, 3, 3, 2, 2, 1, 1, 192, False, 1], + [1, 192, 192, 75, 75, 3, 3, 1, 1, 1, 1, 192, False, 1], + [1, 192, 192, 79, 79, 5, 5, 2, 2, 0, 0, 192, False, 1], + [1, 192, 192, 95, 95, 3, 3, 1, 1, 1, 1, 192, False, 1], + [1, 192, 192, 99, 99, 5, 5, 2, 2, 0, 0, 192, False, 1], + [1, 184, 184, 14, 14, 3, 3, 1, 1, 1, 1, 184, False, 1], + [1, 184, 184, 20, 20, 3, 3, 1, 1, 1, 1, 184, False, 1], + [1, 184, 184, 7, 7, 1, 5, 1, 1, 0, 2, 184, False, 1], + [1, 184, 184, 7, 7, 5, 1, 1, 1, 2, 0, 184, False, 1], + [1, 288, 288, 14, 14, 3, 3, 1, 1, 1, 1, 18, False, 1], + [1, 288, 288, 28, 28, 3, 3, 2, 2, 1, 1, 18, False, 1], + [1, 408, 408, 14, 14, 3, 3, 1, 1, 1, 1, 17, False, 1], + [1, 408, 408, 28, 28, 3, 3, 2, 2, 1, 1, 17, False, 1], + [1, 1632, 1632, 12, 12, 3, 3, 1, 1, 1, 1, 1632, False, 1], + [1, 1632, 1632, 12, 12, 5, 5, 1, 1, 2, 2, 1632, False, 1], + [1, 160, 160, 28, 28, 3, 3, 1, 1, 1, 1, 160, False, 1], + [1, 16, 16, 112, 112, 3, 3, 1, 1, 1, 1, 16, False, 1], + [1, 16, 16, 112, 112, 3, 3, 2, 2, 1, 1, 16, False, 1], + [1, 16, 16, 160, 160, 3, 3, 1, 1, 1, 1, 16, False, 1], + [1, 1920, 1920, 14, 14, 3, 3, 2, 2, 1, 1, 16, False, 1], + [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 1, 1, 16, False, 1], + [1, 3712, 3712, 14, 14, 3, 3, 2, 2, 1, 1, 16, False, 1], + [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 16, False, 1], + [1, 896, 896, 28, 28, 3, 3, 2, 2, 1, 1, 16, False, 1], + [1, 1536, 1536, 10, 10, 3, 3, 1, 1, 1, 1, 1536, False, 1], + [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1], + [1, 144, 144, 14, 14, 5, 5, 1, 1, 2, 2, 144, False, 1], + [1, 144, 144, 151, 151, 3, 3, 2, 2, 0, 0, 144, False, 1], + [1, 144, 144, 191, 191, 3, 3, 2, 2, 0, 0, 144, False, 1], + [1, 144, 144, 56, 56, 3, 3, 1, 1, 1, 1, 144, False, 1], + [1, 144, 144, 56, 56, 3, 3, 2, 2, 1, 1, 144, False, 1], + [1, 144, 144, 59, 59, 5, 5, 2, 2, 0, 0, 144, False, 1], + [1, 144, 144, 60, 60, 3, 3, 1, 1, 1, 1, 144, False, 1], + [1, 144, 144, 63, 63, 5, 5, 2, 2, 0, 0, 144, False, 1], + [1, 144, 144, 65, 65, 3, 3, 1, 1, 1, 1, 144, False, 1], + [1, 144, 144, 69, 69, 5, 5, 2, 2, 0, 0, 144, False, 1], + [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 14, False, 1], + [1, 336, 336, 28, 28, 3, 3, 2, 2, 1, 1, 14, False, 1], + [1, 1392, 1392, 10, 10, 3, 3, 1, 1, 1, 1, 1392, False, 1], + [1, 1392, 1392, 10, 10, 5, 5, 1, 1, 2, 2, 1392, False, 1], + [1, 104, 104, 28, 28, 3, 3, 1, 1, 1, 1, 13, False, 1], + [1, 104, 104, 56, 56, 3, 3, 2, 2, 1, 1, 13, False, 1], + [1, 1280, 1280, 30, 40, 3, 3, 1, 1, 1, 1, 1280, True, 1], + [1, 128, 128, 1, 1, 3, 3, 1, 1, 1, 1, 128, False, 1], + [1, 128, 128, 128, 128, 3, 3, 1, 1, 1, 1, 128, True, 1], + [1, 128, 128, 150, 150, 3, 3, 1, 1, 1, 1, 128, False, 1], + [1, 128, 128, 150, 150, 3, 3, 2, 2, 1, 1, 128, False, 1], + [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 128, False, 1], + [1, 128, 128, 3, 3, 3, 3, 2, 2, 1, 1, 128, False, 1], + [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 128, False, 1], + [1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, False, 1], + [1, 128, 128, 75, 75, 3, 3, 1, 1, 1, 1, 128, False, 1], + [1, 128, 128, 5, 5, 3, 3, 2, 2, 1, 1, 128, False, 1], + [1, 1248, 1248, 9, 9, 3, 3, 1, 1, 1, 1, 1248, False, 1], + [1, 1248, 1248, 9, 9, 5, 5, 1, 1, 2, 2, 1248, False, 1], + [1, 120, 120, 14, 14, 1, 5, 1, 1, 0, 2, 120, False, 1], + [1, 120, 120, 14, 14, 5, 1, 1, 1, 2, 0, 120, False, 1], + [1, 120, 120, 14, 14, 5, 5, 1, 1, 2, 2, 120, False, 1], + [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 120, False, 1], + [1, 120, 120, 28, 28, 5, 5, 1, 1, 2, 2, 120, False, 1], + [1, 120, 120, 40, 40, 5, 5, 1, 1, 2, 2, 120, False, 1], + [1, 12, 12, 56, 56, 3, 3, 1, 1, 1, 1, 12, False, 1], + [1, 1152, 1152, 7, 7, 3, 3, 1, 1, 1, 1, 1152, False, 1], + [1, 1152, 1152, 7, 7, 5, 5, 1, 1, 2, 2, 1152, False, 1], + [1, 1152, 1152, 8, 8, 3, 3, 1, 1, 1, 1, 1152, False, 1], + [1, 1152, 1152, 8, 8, 5, 5, 1, 1, 2, 2, 1152, False, 1], + [1, 112, 112, 14, 14, 5, 5, 2, 2, 2, 2, 112, False, 1], + [1, 1232, 1232, 14, 14, 3, 3, 1, 1, 1, 1, 11, False, 1], + [1, 1232, 1232, 28, 28, 3, 3, 2, 2, 1, 1, 11, False, 1], + [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1], + [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], + [1, 1024, 1024, 10, 10, 3, 3, 1, 1, 1, 1, 1024, False, 1], + [1, 1024, 1024, 16, 16, 3, 3, 1, 1, 1, 1, 1024, True, 1], + [1, 1024, 1024, 19, 19, 3, 3, 2, 2, 1, 1, 1024, False, 1], + [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1024, False, 1], + [1, 100, 100, 14, 14, 3, 3, 1, 1, 1, 1, 100, False, 1], + [1, 160, 160, 14, 14, 3, 3, 1, 1, 1, 1, 10, False, 1], + [1, 160, 160, 28, 28, 3, 3, 2, 2, 1, 1, 10, False, 1], [1, 16, 1, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 32, 1, 28, 28, 3, 3, 1, 1, 1, 1, 0, True, 1], - [1, 100, 100, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1008, 1008, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 1008, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1008, 1008, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1008, 1008, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 32, 1, 28, 28, 3, 3, 1, 1, 0, 0, 1, True, 1], + [1, 192, 1008, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1008, 1008, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 40, 102, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1536, 1024, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1024, 1024, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1536, 1024, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 1024, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 255, 1024, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 1024, 16, 16, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 1024, 19, 19, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 1024, 19, 19, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 128, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 255, 1024, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 1024, 16, 16, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 384, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 1024, 19, 19, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 24, 1024, 19, 19, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 1024, 19, 19, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 256, 1024, 19, 19, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 546, 1024, 19, 19, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 1024, 45, 80, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 1024, 45, 80, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 1024, 50, 68, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 1024, 50, 68, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 1024, 50, 68, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 1024, 45, 80, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 1024, 45, 80, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 1024, 50, 68, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 1024, 50, 68, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 1024, 50, 68, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 12, 104, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 26, 104, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 104, 104, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 104, 104, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 208, 104, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 208, 104, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 104, 104, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 132, 1056, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 264, 1056, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 1056, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1056, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1056, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1056, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 12, 104, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 26, 104, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 104, 104, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 208, 104, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 208, 104, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 132, 1056, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 264, 1056, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 1056, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1056, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1056, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1056, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 462, 1072, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 1088, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 768, 1088, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1088, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 440, 110, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 192, 1104, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1104, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1232, 112, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 448, 112, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 896, 112, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 112, 112, 14, 14, 5, 5, 2, 2, 2, 2, 2, False, 1], + [1, 128, 1088, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 768, 1088, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1088, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 440, 110, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 192, 1104, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1104, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1232, 112, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 448, 112, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 896, 112, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 224, 112, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 336, 112, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 112, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 112, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 112, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 112, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 112, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 112, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1120, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1120, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1152, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1152, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1152, 1152, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1152, 1152, 7, 7, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 128, 1152, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1152, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 1152, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1152, 1152, 8, 8, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1152, 1152, 8, 8, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 192, 1152, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 1152, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 336, 112, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 112, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 112, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 112, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 112, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 112, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 112, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1120, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1120, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1152, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1152, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1152, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1152, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 1152, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1152, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 1152, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 40, 116, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 34, 118, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 1184, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1184, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 104, 12, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 120, 12, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 48, 12, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 12, 12, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 12, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 30, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 32, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 480, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 120, 120, 14, 14, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 120, 120, 14, 14, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 120, 120, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 48, 120, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 720, 120, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 120, 120, 28, 28, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 120, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 20, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 120, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 40, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 120, 40, 40, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 40, 120, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 120, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 1200, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1200, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1216, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1216, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 1184, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1184, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 104, 12, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 120, 12, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 48, 12, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 12, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 30, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 32, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 480, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 48, 120, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 720, 120, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 20, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 336, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 336, 120, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 40, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 120, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1200, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1200, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1216, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1216, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 46, 122, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 112, 1232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 308, 1232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1232, 1232, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1232, 1232, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1232, 1232, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 124, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1248, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1248, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1248, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1248, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1248, 1248, 9, 9, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1248, 1248, 9, 9, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 208, 1248, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 352, 1248, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 128, 1, 1, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 24, 128, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 546, 128, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 128, 10, 10, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 112, 1232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 308, 1232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1232, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 124, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1248, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1248, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1248, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1248, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 208, 1248, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 352, 1248, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 128, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 546, 128, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 128, 10, 10, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 128, 128, 100, 136, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 128, 100, 136, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 512, 128, 100, 136, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 128, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 128, 128, 112, 112, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 64, 128, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 128, 128, 128, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 128, 128, 128, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 64, 128, 128, 128, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 128, 128, 128, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 64, 128, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 64, 128, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 128, 128, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 128, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 128, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 128, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 32, 128, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 128, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 128, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 150, 150, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 128, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 512, 128, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 128, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 128, 150, 150, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 128, 180, 320, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 128, 2, 2, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 128, 200, 272, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 128, 128, 180, 320, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 256, 128, 2, 2, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 128, 200, 272, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 64, 128, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 128, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 2, True, 1], - [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 128, 128, 28, 28, 3, 3, 1, 1, 2, 2, 1, True, 1], [1, 16, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 19, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 19, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 192, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 128, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 256, 128, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 288, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 128, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 256, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 128, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 256, 128, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 288, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 288, 128, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 32, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 38, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 128, 3, 3, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 128, 3, 3, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 128, 3, 3, 3, 3, 1, 1, 1, 1, 0, True, 1], + [1, 38, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 128, 3, 3, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 128, 3, 3, 3, 3, 1, 1, 0, 0, 1, True, 1], [1, 64, 128, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 256, 128, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 128, 5, 5, 3, 3, 1, 1, 1, 1, 0, True, 1], + [1, 256, 128, 5, 5, 3, 3, 1, 1, 0, 0, 1, True, 1], + [1, 128, 128, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 128, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 128, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 256, 128, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 256, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 32, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 128, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 128, 60, 80, 4, 4, 4, 4, 4, 4, 0, True, 1], - [1, 320, 128, 60, 80, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 64, 128, 60, 80, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 64, 128, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 128, 60, 80, 4, 4, 4, 4, 0, 0, 1, True, 1], + [1, 320, 128, 60, 80, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 64, 128, 60, 80, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 64, 128, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 128, 128, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 64, 64, 2, 2, 2, 2, 2, 2, 0, True, 1], + [1, 128, 128, 64, 64, 2, 2, 2, 2, 0, 0, 1, True, 1], [1, 256, 128, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 32, 128, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 75, 75, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 128, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 128, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 128, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 128, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 128, 75, 75, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 128, 128, 90, 160, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 128, 90, 160, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1280, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 640, 1280, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1280, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 1280, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1296, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1296, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1312, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1312, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1056, 132, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 528, 132, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1344, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1344, 1344, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1344, 1344, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 1344, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1344, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 816, 136, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1376, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1376, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 174, 1392, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 348, 1392, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1392, 1392, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1392, 1392, 10, 10, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 232, 1392, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 1392, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1392, 1392, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1392, 1392, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 1392, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1392, 1392, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 1392, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1408, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1408, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 512, 128, 90, 160, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1280, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 640, 1280, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1280, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 1280, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1296, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1296, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1312, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1312, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1056, 132, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 528, 132, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1344, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1344, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1344, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 816, 136, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1376, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1376, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 174, 1392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 348, 1392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 232, 1392, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 384, 1392, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1392, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1392, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1408, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1408, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 68, 142, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1512, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 16, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 36, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 40, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 576, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 144, 144, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1], + [1, 1512, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 16, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 36, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 40, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 576, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 288, 144, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 48, 144, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 144, 151, 151, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 144, 144, 191, 191, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 144, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 144, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 48, 144, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 28, 144, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 144, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 40, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 40, 144, 30, 30, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 48, 144, 33, 33, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 144, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 144, 144, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 144, 144, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 144, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 144, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 144, 59, 59, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 144, 144, 60, 60, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 24, 144, 60, 60, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 144, 63, 63, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 144, 144, 65, 65, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 24, 144, 65, 65, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 144, 69, 69, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 1024, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 32, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 144, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 40, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 144, 30, 30, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 48, 144, 33, 33, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 144, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 144, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 144, 60, 60, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 144, 65, 65, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 144, 144, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 18, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 36, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 144, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 144, 95, 95, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1440, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1440, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1472, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1472, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1488, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1488, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1504, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1504, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 1512, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1512, 1512, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 18, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 36, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 144, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 144, 95, 95, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1440, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1440, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1472, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1472, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1488, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1488, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1504, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1504, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 1512, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 58, 152, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1536, 1536, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 1536, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1536, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1536, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1536, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 1536, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 1536, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 1536, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1536, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1536, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1536, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 1536, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 384, 1536, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 68, 156, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 1568, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1568, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1584, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1584, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 16, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 16, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 16, 16, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 16, 16, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 16, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 8, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 16, 120, 120, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 16, 130, 130, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 16, 16, 14, 14, 2, 2, 2, 2, 2, 2, 0, True, 1], + [1, 128, 1568, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1568, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1584, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1584, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 16, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 16, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 16, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 8, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 16, 120, 120, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 16, 130, 130, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 16, 16, 14, 14, 2, 2, 2, 2, 0, 0, 1, True, 1], [1, 4, 16, 14, 14, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 48, 16, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 16, 16, 160, 160, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 16, 16, 160, 160, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 16, 160, 160, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 16, 16, 160, 160, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 16, 160, 160, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 16, 16, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 16, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 32, 16, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 32, 16, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 16, 16, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 16, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 16, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 160, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 160, 160, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 16, 16, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 16, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 16, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 160, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 320, 160, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 400, 160, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 400, 160, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 960, 160, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 160, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 160, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 160, 160, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 160, 160, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 960, 160, 3, 3, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 160, 32, 32, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 256, 160, 32, 32, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 128, 160, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 400, 160, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 400, 160, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 960, 160, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 160, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 160, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 960, 160, 3, 3, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 160, 32, 32, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 256, 160, 32, 32, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 128, 160, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 320, 160, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 480, 160, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 960, 160, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 160, 73, 73, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1600, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1600, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1632, 1632, 12, 12, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1632, 1632, 12, 12, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 272, 1632, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1632, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1632, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1632, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1632, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1664, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1664, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 168, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 168, 168, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 168, 168, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 408, 168, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 408, 168, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 168, 168, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 1680, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1680, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1696, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1696, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 480, 160, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 960, 160, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 160, 73, 73, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1600, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1600, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 272, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1632, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1632, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1632, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1632, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1664, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1664, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 168, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 168, 168, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 408, 168, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 408, 168, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 192, 1680, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1680, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1696, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1696, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 46, 172, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 1728, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1728, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1728, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1728, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1392, 174, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 696, 174, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 1760, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1760, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1776, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1776, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 896, 1792, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1792, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 216, 18, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 72, 18, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 144, 18, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 18, 18, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 72, 18, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 18, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 1728, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1728, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1728, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1728, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1392, 174, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 696, 174, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 1760, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1760, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1776, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1776, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 896, 1792, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1792, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 216, 18, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 72, 18, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 144, 18, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 18, 18, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 72, 18, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 128, 18, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 18, 18, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 18, 18, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 32, 18, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 36, 18, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 1824, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1824, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1824, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 184, 184, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 40, 184, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 184, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 184, 184, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 80, 184, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 184, 184, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 184, 184, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 128, 185, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 1856, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1872, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1872, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 1888, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 192, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 18, 18, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 32, 18, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 36, 18, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 192, 1824, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1824, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1824, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 184, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 184, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 184, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 185, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 1856, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1872, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1872, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 1888, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 192, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 48, 192, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 192, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 17, 17, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 192, 192, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1], - [1, 224, 192, 17, 17, 1, 7, 1, 1, 1, 1, 0, False, 1], - [1, 128, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 16, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 192, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 432, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 432, 192, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 64, 192, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 192, 17, 17, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 192, 192, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1], + [1, 224, 192, 17, 17, 1, 7, 1, 1, 0, 3, 1, False, 1], + [1, 128, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 16, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 432, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 432, 192, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 48, 192, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 64, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 224, 192, 35, 35, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 48, 192, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 56, 192, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 192, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 48, 192, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 56, 192, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 192, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 192, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 48, 192, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1152, 192, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1152, 192, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 384, 192, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 48, 192, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 192, 71, 71, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 192, 192, 75, 75, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 192, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 79, 79, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 1152, 192, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 95, 95, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 192, 95, 95, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 192, 99, 99, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 192, 1920, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1920, 1920, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 1920, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 784, 196, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 192, 192, 71, 71, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 32, 192, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1152, 192, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 192, 95, 95, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1920, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1920, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 784, 196, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 40, 196, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 1968, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 1968, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 20, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 20, 20, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 200, 200, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 40, 200, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 200, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 200, 200, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 80, 200, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 200, 200, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 200, 200, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 224, 2016, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 192, 2016, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2016, 2016, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 2016, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 2048, 23, 40, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 2048, 23, 40, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 2048, 25, 34, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 2048, 25, 34, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 512, 2048, 25, 34, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 2048, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 2064, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 2064, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 26, 208, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 52, 208, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 208, 208, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 208, 208, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 440, 208, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 440, 208, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 208, 208, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1248, 208, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 2112, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 18, 216, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 54, 216, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 216, 216, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 216, 216, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 576, 216, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 216, 216, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 2160, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 192, 1968, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 1968, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 20, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 40, 200, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 200, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 200, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 224, 2016, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 192, 2016, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 2016, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 2048, 23, 40, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 2048, 23, 40, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 2048, 25, 34, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 2048, 25, 34, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 512, 2048, 25, 34, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 2048, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 2064, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 2064, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 26, 208, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 52, 208, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 208, 208, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 440, 208, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 440, 208, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1248, 208, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 2112, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 18, 216, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 54, 216, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 216, 216, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 576, 216, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 2160, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 78, 218, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 888, 222, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 2016, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 56, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 896, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 224, 224, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 224, 224, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 224, 224, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1], - [1, 256, 224, 17, 17, 1, 7, 1, 1, 1, 1, 0, False, 1], - [1, 256, 224, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1], - [1, 128, 224, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 224, 35, 35, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 128, 224, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 224, 224, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 448, 224, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 448, 224, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 224, 224, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 224, 224, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 58, 232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1392, 232, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 232, 232, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 696, 232, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 696, 232, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 888, 222, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 2016, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 56, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 896, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 224, 224, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1], + [1, 256, 224, 17, 17, 1, 7, 1, 1, 0, 3, 1, False, 1], + [1, 256, 224, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1], + [1, 128, 224, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 224, 35, 35, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 128, 224, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 224, 224, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 448, 224, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 448, 224, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 224, 224, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 58, 232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1392, 232, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 232, 232, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 696, 232, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 696, 232, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 68, 236, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 72, 24, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 96, 24, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 24, 24, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 72, 24, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 96, 24, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 64, 24, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 144, 24, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 40, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 88, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 144, 24, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 88, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 14, 24, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 144, 24, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 24, 56, 56, 5, 5, 2, 2, 2, 2, 2, False, 1], - [1, 36, 24, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 24, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 24, 60, 60, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 24, 65, 65, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 24, 80, 80, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 240, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 960, 240, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 240, 240, 14, 14, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 240, 240, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 240, 240, 14, 14, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 240, 240, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 40, 240, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 240, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 240, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 240, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 240, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 240, 240, 28, 28, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 240, 240, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 240, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 40, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 720, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 240, 29, 29, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 240, 240, 30, 30, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 40, 240, 30, 30, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 240, 31, 31, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 240, 240, 40, 40, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 192, 240, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 240, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 144, 24, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 36, 24, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 24, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 24, 60, 60, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 24, 65, 65, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 24, 80, 80, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 240, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 960, 240, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 40, 240, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 240, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 240, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 240, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 720, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 240, 30, 30, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 240, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 16, 256, 1, 1, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 256, 256, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 364, 256, 1, 1, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 10, 10, 3, 3, 2, 2, 2, 2, 1, False, 1], [1, 256, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 100, 136, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 256, 256, 100, 136, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 256, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 36, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 128, 256, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 150, 256, 128, 128, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 150, 256, 128, 128, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 256, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 13, 17, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 256, 256, 13, 17, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 36, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 819, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1024, 256, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 256, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1024, 256, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 256, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 256, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 256, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 512, 256, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 512, 256, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 256, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 512, 256, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 512, 256, 16, 16, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 17, 17, 1, 7, 1, 1, 1, 1, 0, False, 1], - [1, 320, 256, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1], - [1, 128, 256, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 256, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 256, 19, 19, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 24, 256, 2, 2, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 256, 2, 2, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 546, 256, 2, 2, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 64, 256, 2, 2, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 256, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 256, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 256, 17, 17, 1, 7, 1, 1, 0, 3, 1, False, 1], + [1, 320, 256, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1], + [1, 128, 256, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 256, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 256, 19, 19, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 24, 256, 2, 2, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 546, 256, 2, 2, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 64, 256, 2, 2, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 256, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 256, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 25, 34, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 256, 256, 25, 34, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 36, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 20, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 256, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 32, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 512, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 512, 256, 28, 28, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 64, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 512, 256, 28, 28, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 64, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 16, 256, 3, 3, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 24, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 256, 3, 3, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 24, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 364, 256, 3, 3, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 546, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 256, 32, 32, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 256, 32, 32, 2, 2, 2, 2, 2, 2, 0, True, 1], + [1, 546, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 256, 32, 32, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 256, 32, 32, 2, 2, 2, 2, 0, 0, 1, True, 1], [1, 256, 256, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 512, 256, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 38, 38, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 256, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 256, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 728, 256, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 256, 45, 80, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 728, 256, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 256, 45, 80, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 45, 80, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 256, 5, 5, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 128, 256, 5, 5, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 24, 256, 5, 5, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 512, 256, 5, 5, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 512, 256, 5, 5, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 546, 256, 5, 5, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1024, 256, 50, 68, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1024, 256, 50, 68, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 36, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 18, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 56, 56, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 256, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 256, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 256, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 36, 256, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 512, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 256, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 64, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 256, 64, 64, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 256, 56, 56, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 256, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 256, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 36, 256, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 512, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 64, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 256, 64, 64, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 256, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 255, 256, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 256, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 512, 256, 64, 64, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 256, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 255, 256, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 256, 64, 64, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 1024, 256, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 256, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 512, 256, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 36, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 819, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 75, 75, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 256, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 256, 75, 75, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 256, 256, 90, 160, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 104, 26, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 208, 26, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 262, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1056, 264, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1632, 272, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 256, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 256, 75, 75, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 256, 256, 90, 160, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 104, 26, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 208, 26, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 262, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1056, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1632, 272, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 160, 272, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 34, 276, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 16, 28, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 72, 288, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 288, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 288, 14, 14, 5, 5, 2, 2, 2, 2, 2, False, 1], - [1, 288, 288, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 288, 288, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 288, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 288, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 88, 288, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 288, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 288, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 288, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 288, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 288, 288, 33, 33, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 48, 288, 33, 33, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 288, 35, 35, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 288, 288, 38, 38, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 48, 288, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 288, 39, 39, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 192, 288, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 288, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 72, 288, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 288, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 288, 288, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 288, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 288, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 88, 288, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 288, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 288, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 288, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 48, 288, 33, 33, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 48, 288, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 288, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 288, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 134, 296, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 3, 224, 224, 4, 4, 4, 4, 4, 4, 0, True, 1], - [1, 16, 3, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 16, 3, 224, 224, 7, 7, 1, 1, 1, 1, 3, False, 1], - [1, 32, 3, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 128, 3, 224, 224, 4, 4, 4, 4, 0, 0, 1, True, 1], + [1, 16, 3, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 16, 3, 224, 224, 7, 7, 1, 1, 3, 3, 1, False, 1], + [1, 32, 3, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 64, 3, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 3, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 64, 3, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 64, 3, 224, 224, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 64, 3, 224, 224, 7, 7, 2, 2, 2, 2, 3, False, 1], - [1, 96, 3, 224, 224, 4, 4, 4, 4, 4, 4, 0, True, 1], - [1, 96, 3, 224, 224, 7, 7, 2, 2, 2, 2, 3, False, 1], - [1, 32, 3, 225, 225, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 32, 3, 241, 241, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 128, 3, 256, 256, 4, 4, 4, 4, 4, 4, 0, True, 1], + [1, 64, 3, 224, 224, 7, 7, 2, 2, 3, 3, 1, False, 1], + [1, 96, 3, 224, 224, 4, 4, 4, 4, 0, 0, 1, True, 1], + [1, 96, 3, 224, 224, 7, 7, 2, 2, 3, 3, 1, False, 1], + [1, 32, 3, 225, 225, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 32, 3, 241, 241, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 128, 3, 256, 256, 4, 4, 4, 4, 0, 0, 1, True, 1], [1, 32, 3, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 96, 3, 256, 256, 4, 4, 4, 4, 4, 4, 0, True, 1], - [1, 32, 3, 261, 261, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 32, 3, 299, 299, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 96, 3, 256, 256, 4, 4, 4, 4, 0, 0, 1, True, 1], + [1, 32, 3, 261, 261, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 32, 3, 299, 299, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 64, 3, 300, 300, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 32, 3, 301, 301, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 16, 3, 320, 320, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 32, 3, 384, 384, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 120, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 336, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 3024, 3024, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 32, 3, 301, 301, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 16, 3, 320, 320, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 32, 3, 384, 384, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 120, 30, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 336, 30, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 116, 304, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1232, 308, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1232, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 58, 310, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 120, 32, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 16, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 224, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 224, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 232, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 232, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 256, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 32, 32, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 32, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 32, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 336, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 48, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 48, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 64, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 120, 32, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 16, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 224, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 224, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 232, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 232, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 256, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 32, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 336, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 336, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 48, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 48, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 64, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 64, 32, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 32, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 72, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 80, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 96, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 16, 32, 120, 120, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 32, 120, 120, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 64, 32, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 72, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 80, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 96, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 16, 32, 120, 120, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 2, 32, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 32, 32, 128, 128, 8, 8, 8, 8, 8, 8, 0, True, 1], + [1, 32, 32, 128, 128, 8, 8, 8, 8, 0, 0, 1, True, 1], [1, 64, 32, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 32, 128, 128, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 16, 32, 130, 130, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 32, 130, 130, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 64, 32, 128, 128, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 16, 32, 130, 130, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 32, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 64, 32, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 64, 32, 147, 147, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 32, 149, 149, 3, 3, 1, 1, 1, 1, 0, False, 1], - [1, 24, 32, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 32, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 32, 32, 149, 149, 3, 3, 1, 1, 0, 0, 1, False, 1], + [1, 24, 32, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 64, 32, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 24, 32, 190, 190, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 32, 190, 190, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 528, 32, 192, 192, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 1, 32, 256, 256, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 24, 32, 190, 190, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 528, 32, 192, 192, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1, 32, 256, 256, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 64, 32, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 32, 26, 26, 3, 3, 1, 1, 1, 1, 0, True, 1], - [1, 192, 32, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 64, 32, 26, 26, 3, 3, 1, 1, 0, 0, 1, True, 1], + [1, 192, 32, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 96, 32, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 2, 32, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 32, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 32, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 32, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 32, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 32, 32, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 32, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 32, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 64, 32, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 32, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 2, 32, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 128, 32, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 32, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 32, 95, 95, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 36, 320, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 80, 320, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 320, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 320, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 320, 320, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 192, 32, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 32, 95, 95, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 36, 320, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 80, 320, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 320, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 320, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 40, 320, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 784, 320, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 784, 320, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 320, 320, 17, 17, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 128, 320, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 320, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 320, 320, 30, 40, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 512, 320, 30, 40, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 64, 320, 30, 40, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1280, 320, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1280, 320, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 328, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 30, 336, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 84, 336, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 336, 336, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 888, 336, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 336, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 336, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 336, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 56, 336, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 336, 49, 49, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 192, 336, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 336, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 336, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 336, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 784, 320, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 784, 320, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 320, 320, 17, 17, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 128, 320, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 320, 30, 40, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 512, 320, 30, 40, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 64, 320, 30, 40, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1280, 320, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1280, 320, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 328, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 30, 336, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 84, 336, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 336, 336, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 888, 336, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 336, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 336, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 56, 336, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 336, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 336, 336, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 336, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 336, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 20, 34, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1392, 348, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 352, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 352, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1280, 352, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 36, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 320, 36, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 144, 36, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 18, 36, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 36, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1392, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 352, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 352, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1280, 352, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 36, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 320, 36, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 144, 36, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 18, 36, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 36, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 36, 36, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 36, 36, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 64, 36, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 36, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 36, 36, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 36, 36, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 64, 36, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 36, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 68, 360, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 98, 368, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 3712, 3712, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1280, 384, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 384, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 384, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 384, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 384, 35, 35, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 384, 35, 35, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 64, 384, 35, 35, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 384, 35, 35, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 384, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 384, 64, 64, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 384, 8, 8, 1, 3, 1, 1, 1, 1, 0, False, 1], - [1, 256, 384, 8, 8, 3, 1, 1, 1, 1, 1, 1, False, 1], - [1, 448, 384, 8, 8, 3, 1, 1, 1, 1, 1, 1, False, 1], - [1, 4, 4, 7, 7, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 144, 40, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 120, 40, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 40, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 40, 40, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 80, 40, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 40, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 40, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 40, 40, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 60, 40, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 40, 30, 30, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 40, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 40, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1280, 384, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 384, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 384, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 384, 35, 35, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 384, 384, 35, 35, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 64, 384, 35, 35, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 384, 35, 35, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 384, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 384, 64, 64, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 384, 8, 8, 1, 3, 1, 1, 0, 1, 1, False, 1], + [1, 256, 384, 8, 8, 3, 1, 1, 1, 1, 0, 1, False, 1], + [1, 448, 384, 8, 8, 3, 1, 1, 1, 1, 0, 1, False, 1], + [1, 4, 4, 7, 7, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 144, 40, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 120, 40, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 40, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 40, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 40, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 40, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 60, 40, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 40, 30, 30, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 40, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 40, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 14, 40, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 400, 400, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 400, 400, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 400, 400, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 408, 408, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 408, 408, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 912, 408, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 408, 408, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 416, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 416, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 400, 400, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 408, 408, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 912, 408, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 416, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 416, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 116, 428, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1008, 432, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 432, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 432, 432, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 432, 432, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 432, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 432, 432, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 110, 440, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 52, 440, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 440, 440, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 440, 440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 440, 440, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 112, 448, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 56, 448, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1280, 448, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 448, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1232, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1232, 448, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 128, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 448, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 896, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 896, 448, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 256, 448, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 448, 448, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 448, 448, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 512, 448, 8, 8, 1, 3, 1, 1, 1, 1, 0, False, 1], + [1, 1008, 432, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 432, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 432, 432, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 432, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 110, 440, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 52, 440, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 440, 440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 448, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 56, 448, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1280, 448, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 448, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1232, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1232, 448, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 128, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 448, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 896, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 896, 448, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 256, 448, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 448, 8, 8, 1, 3, 1, 1, 0, 1, 1, False, 1], [1, 16, 46, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 168, 466, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 12, 48, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 48, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 48, 48, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 48, 48, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 48, 48, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 144, 48, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 48, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 48, 33, 33, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 288, 48, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 104, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 104, 48, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 12, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 48, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 48, 48, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 48, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 12, 48, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 48, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 144, 48, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 288, 48, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 288, 48, 33, 33, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 288, 48, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 104, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 104, 48, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 12, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 48, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 48, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 48, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 120, 480, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 24, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 480, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 480, 480, 10, 10, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 546, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 80, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 16, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 480, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 480, 480, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 56, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 480, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 480, 15, 15, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 480, 480, 15, 15, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 80, 480, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 480, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 480, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 480, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 480, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 480, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 480, 480, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 480, 480, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 1000, 512, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 512, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 512, 10, 10, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 120, 480, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 24, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 546, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 80, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 16, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 56, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 480, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 480, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 480, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 480, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 480, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1000, 512, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 512, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 512, 10, 10, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 24, 512, 10, 10, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 546, 512, 10, 10, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1024, 512, 100, 136, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 1024, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1024, 512, 100, 136, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1024, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 1024, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 512, 14, 14, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 112, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 1024, 512, 14, 14, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 112, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 64, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 512, 15, 20, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 64, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 512, 15, 20, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1024, 512, 16, 16, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 512, 16, 16, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 512, 16, 16, 2, 2, 2, 2, 2, 2, 0, True, 1], + [1, 256, 512, 16, 16, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 512, 16, 16, 2, 2, 2, 2, 0, 0, 1, True, 1], [1, 512, 512, 16, 16, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 512, 19, 19, 3, 3, 1, 1, 1, 1, 6, True, 1], + [1, 1024, 512, 19, 19, 3, 3, 1, 1, 6, 6, 1, True, 1], [1, 512, 512, 19, 19, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 2048, 512, 23, 40, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 2048, 512, 23, 40, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 512, 23, 40, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 2048, 512, 25, 34, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 2048, 512, 25, 34, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 512, 25, 34, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1024, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 512, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 128, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 19, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 38, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 2, False, 1], - [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], + [1, 1024, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 128, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 19, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 38, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 512, 28, 28, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 512, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 512, 512, 28, 28, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 512, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 512, 512, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 896, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 896, 512, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 1024, 512, 32, 32, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 255, 512, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 512, 32, 32, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 896, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 896, 512, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1024, 512, 32, 32, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 255, 512, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 512, 32, 32, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 256, 512, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 16, 512, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 512, 512, 45, 80, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 512, 5, 5, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 512, 5, 5, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 512, 5, 5, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 546, 512, 5, 5, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 512, 512, 50, 68, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 512, 512, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 512, 512, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 512, 512, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 512, 512, 45, 80, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 128, 512, 5, 5, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 512, 5, 5, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 546, 512, 5, 5, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 512, 512, 50, 68, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 1024, 512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2048, 512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 512, 8, 8, 1, 3, 1, 1, 1, 1, 0, False, 1], - [1, 256, 512, 8, 8, 3, 1, 1, 1, 1, 1, 1, False, 1], - [1, 128, 512, 90, 160, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 512, 90, 160, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 208, 52, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 440, 52, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 132, 528, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 528, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 528, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 528, 528, 17, 17, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 88, 528, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 528, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 528, 528, 96, 96, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 216, 54, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 576, 54, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 256, 512, 8, 8, 1, 3, 1, 1, 0, 1, 1, False, 1], + [1, 256, 512, 8, 8, 3, 1, 1, 1, 1, 0, 1, False, 1], + [1, 128, 512, 90, 160, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 512, 90, 160, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 208, 52, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 440, 52, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 132, 528, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 528, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 528, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 88, 528, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 528, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 528, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 216, 54, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 576, 54, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 24, 54, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 544, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 544, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 196, 544, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 544, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 224, 56, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 448, 56, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 56, 56, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 336, 56, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 576, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 54, 576, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1512, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1512, 576, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 192, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 576, 576, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 576, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 136, 576, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 576, 19, 19, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 576, 576, 19, 19, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 96, 576, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 576, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 576, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 576, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 576, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 576, 7, 7, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 96, 576, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 232, 58, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 696, 58, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 128, 544, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 224, 56, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 448, 56, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 336, 56, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 576, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 54, 576, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1512, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1512, 576, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 192, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 576, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 136, 576, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 576, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 576, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 576, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 576, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 576, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 232, 58, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 696, 58, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 20, 58, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 60, 60, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 608, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 608, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 608, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 608, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 28, 62, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 192, 624, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 624, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 64, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 64, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 192, 624, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 624, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 64, 1, 1, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 64, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 64, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 128, 64, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 128, 64, 112, 112, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 64, 64, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 64, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 64, 64, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 64, 64, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 64, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 64, 120, 160, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 64, 64, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 64, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 128, 64, 120, 160, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 32, 64, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 64, 64, 120, 160, 8, 8, 8, 8, 8, 8, 0, True, 1], + [1, 64, 64, 120, 160, 8, 8, 8, 8, 0, 0, 1, True, 1], [1, 128, 64, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 64, 64, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 384, 64, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 64, 147, 147, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 128, 64, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 64, 150, 150, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 384, 64, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 64, 147, 147, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 128, 64, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 64, 150, 150, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 128, 64, 150, 150, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 64, 64, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 64, 160, 160, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 64, 64, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 64, 64, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 64, 64, 180, 320, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 64, 2, 2, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 64, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 64, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 256, 64, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 64, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 64, 64, 200, 272, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 1, 64, 224, 224, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1, 64, 224, 224, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 64, 64, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 64, 64, 224, 224, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 64, 256, 256, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 32, 64, 256, 256, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 64, 256, 256, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 32, 64, 256, 256, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 32, 64, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 64, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 256, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 64, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 128, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 64, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 256, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 64, 64, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 32, 64, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 64, 64, 300, 300, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 96, 64, 35, 35, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 64, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 128, 64, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 128, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 64, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 128, 64, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 14, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 144, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 64, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 144, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 64, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 192, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 24, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 64, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 64, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 24, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 32, 64, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 128, 64, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 160, 64, 64, 64, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 64, 64, 64, 64, 4, 4, 4, 4, 4, 4, 0, True, 1], - [1, 64, 64, 73, 73, 1, 7, 1, 1, 1, 1, 0, False, 1], - [1, 64, 64, 73, 73, 7, 1, 1, 1, 1, 1, 3, False, 1], - [1, 96, 64, 73, 73, 3, 3, 1, 1, 1, 1, 0, False, 1], - [1, 24, 64, 80, 80, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 640, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 128, 640, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 160, 64, 64, 64, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 64, 64, 64, 64, 4, 4, 4, 4, 0, 0, 1, True, 1], + [1, 64, 64, 73, 73, 1, 7, 1, 1, 0, 3, 1, False, 1], + [1, 64, 64, 73, 73, 7, 1, 1, 1, 3, 0, 1, False, 1], + [1, 96, 64, 73, 73, 3, 3, 1, 1, 0, 0, 1, False, 1], + [1, 24, 64, 80, 80, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 640, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 640, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 160, 640, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 640, 654, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 168, 672, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 80, 672, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 56, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 672, 672, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 672, 672, 14, 14, 5, 5, 2, 2, 2, 2, 2, False, 1], - [1, 672, 672, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 112, 672, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 15, 15, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 672, 672, 17, 17, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 672, 672, 19, 19, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 112, 672, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 672, 20, 20, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 546, 672, 20, 20, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 672, 672, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 672, 672, 20, 20, 5, 5, 2, 2, 2, 2, 2, False, 1], - [1, 112, 672, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 672, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 24, 24, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 672, 672, 24, 24, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 1344, 672, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1344, 672, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 192, 672, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 672, 672, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 672, 672, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 672, 672, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 672, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 672, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 640, 654, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 168, 672, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 80, 672, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 56, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 672, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 672, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 672, 20, 20, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 546, 672, 20, 20, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 112, 672, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 672, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1344, 672, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1344, 672, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 192, 672, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 672, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 672, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 672, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 40, 68, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 174, 696, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 58, 696, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 696, 696, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 704, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 704, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 18, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 20, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 24, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 288, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 72, 72, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 144, 72, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 18, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 36, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 174, 696, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 58, 696, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 696, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 704, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 704, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 18, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 20, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 24, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 288, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 8, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 144, 72, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 18, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 36, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 72, 72, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 20, 72, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 72, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 40, 72, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 72, 28, 28, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 72, 72, 28, 28, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 40, 72, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 12, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 168, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 168, 72, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 216, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 216, 72, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 24, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 72, 72, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 72, 72, 56, 56, 5, 5, 2, 2, 2, 2, 2, False, 1], - [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 72, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 72, 80, 80, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 72, 72, 80, 80, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 72, 72, 80, 80, 5, 5, 2, 2, 2, 2, 2, False, 1], - [1, 192, 720, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 720, 720, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 720, 720, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 120, 720, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 720, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 720, 720, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 208, 720, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 728, 728, 19, 19, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 728, 728, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 728, 728, 38, 38, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 728, 728, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 736, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 512, 736, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 736, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 20, 72, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 72, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 72, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 40, 72, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 12, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 168, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 168, 72, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 216, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 216, 72, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 24, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 72, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 72, 80, 80, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 720, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 120, 720, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 720, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 208, 720, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 728, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 728, 728, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 736, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 512, 736, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 736, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 334, 740, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 768, 768, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 768, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 768, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 768, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 768, 32, 32, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 768, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 224, 768, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 768, 768, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 768, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 768, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 384, 768, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 768, 32, 32, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 768, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 224, 768, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 16, 78, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 34, 78, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 24, 78, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 196, 784, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 80, 784, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 784, 784, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 784, 784, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 784, 784, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 224, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 232, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 48, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 528, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 64, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 72, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 8, 8, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 320, 80, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 784, 80, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 480, 80, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 80, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 100, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 184, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 200, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 80, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 92, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 80, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 184, 80, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 200, 80, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 80, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 80, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 80, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 80, 80, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 196, 784, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 80, 784, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 784, 784, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 16, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 224, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 232, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 48, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 528, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 64, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 72, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 320, 80, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 784, 80, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 480, 80, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 80, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 100, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 184, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 200, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 480, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 92, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 480, 80, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 184, 80, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 200, 80, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 480, 80, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 80, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 80, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 80, 80, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 80, 80, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 184, 80, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 200, 80, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 480, 80, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 80, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 800, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 800, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 184, 80, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 200, 80, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 480, 80, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 800, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 800, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 272, 800, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 232, 816, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 816, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 136, 816, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 832, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 32, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 384, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 48, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 336, 84, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 888, 84, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 864, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 864, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 864, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 528, 88, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 88, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 88, 88, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 222, 888, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 84, 888, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 888, 888, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 888, 888, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 888, 888, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 112, 896, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 224, 896, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 128, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2016, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2016, 896, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 2048, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 896, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 256, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 896, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 896, 896, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 896, 896, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 128, 896, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 912, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 912, 912, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 912, 912, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 92, 92, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 928, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 928, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 232, 816, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 816, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 136, 816, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 832, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 32, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 384, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 48, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 336, 84, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 888, 84, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 864, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 864, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 864, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 528, 88, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 88, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 222, 888, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 84, 888, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 888, 888, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 112, 896, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 224, 896, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 128, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2016, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2016, 896, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 2048, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2048, 896, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 256, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 896, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 896, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 912, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 928, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 928, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 28, 94, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 24, 96, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 96, 96, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 96, 96, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 96, 96, 113, 113, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 96, 96, 121, 121, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 96, 96, 131, 131, 3, 3, 2, 2, 2, 2, 0, False, 1], + [1, 24, 96, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 208, 96, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 40, 96, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 96, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 96, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 40, 96, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 576, 96, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 576, 96, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 128, 96, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 96, 96, 28, 28, 5, 5, 2, 2, 2, 2, 2, False, 1], [1, 96, 96, 35, 35, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 96, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 24, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 96, 96, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 96, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 96, 60, 60, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 24, 96, 65, 65, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 96, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 240, 960, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 272, 960, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 960, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 192, 960, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 960, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 960, 960, 27, 27, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 960, 960, 3, 3, 1, 5, 1, 1, 1, 1, 0, False, 1], - [1, 960, 960, 3, 3, 5, 1, 1, 1, 1, 1, 2, False, 1], - [1, 128, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 160, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 320, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 80, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 960, 960, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 960, 960, 7, 7, 5, 5, 1, 1, 1, 1, 2, False, 1], + [1, 128, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 96, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 24, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 96, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 96, 60, 60, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 24, 96, 65, 65, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 576, 96, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 240, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 272, 960, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 960, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 192, 960, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 960, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 160, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 320, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 80, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 20, 98, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 128, 992, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 128, 992, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 256, 1024, 128, 128, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 1024, 14, 14, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 2048, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 1024, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 128, 992, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 128, 992, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 1024, 14, 14, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 2048, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 512, 1024, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 2048, 1024, 45, 80, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 2048, 1024, 50, 68, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2904, 1056, 48, 48, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 128, 128, 112, 112, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 128, 128, 5, 5, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1280, 1280, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2520, 1344, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 3712, 1392, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1512, 1512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1536, 1536, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 1536, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 448, 1632, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1920, 1920, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2016, 2016, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 2048, 15, 20, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1024, 2048, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 2048, 2048, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1056, 2112, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 576, 216, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 232, 232, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], + [1, 2048, 1024, 45, 80, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 2048, 1024, 50, 68, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2904, 1056, 48, 48, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 128, 128, 112, 112, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2520, 1344, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 3712, 1392, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1512, 1512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1536, 1536, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2048, 1536, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1920, 1920, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2016, 2016, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 2048, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 2048, 2048, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1056, 2112, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 576, 216, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 232, 232, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1], [1, 232, 232, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 144, 24, 190, 190, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 720, 240, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 144, 24, 190, 190, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 720, 240, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 512, 256, 180, 320, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 512, 256, 200, 272, 1, 1, 2, 2, 2, 2, 0, False, 1], + [1, 512, 256, 180, 320, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 512, 256, 200, 272, 1, 1, 2, 2, 0, 0, 1, False, 1], [1, 819, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 256, 256, 75, 75, 3, 3, 1, 1, 1, 1, 1, False, 1], [1, 256, 256, 75, 75, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 2904, 264, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 264, 2904, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 726, 2904, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 2904, 2904, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 7392, 2904, 24, 24, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 3, 224, 224, 16, 16, 16, 16, 16, 16, 0, True, 1], - [1, 1024, 3, 224, 224, 32, 32, 32, 32, 32, 32, 0, True, 1], - [1, 768, 3, 224, 224, 16, 16, 16, 16, 16, 16, 0, True, 1], - [1, 768, 3, 224, 224, 32, 32, 32, 32, 32, 32, 0, False, 1], - [1, 768, 3, 224, 224, 32, 32, 32, 32, 32, 32, 0, True, 1], - [1, 32, 3, 299, 299, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 32, 3, 381, 381, 3, 3, 2, 2, 2, 2, 0, False, 1], - [1, 768, 3, 384, 512, 32, 32, 32, 32, 32, 32, 0, True, 1], - [1, 64, 3, 480, 640, 7, 7, 4, 4, 4, 4, 3, True, 1], + [1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 2904, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 7392, 2904, 24, 24, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1024, 3, 224, 224, 16, 16, 16, 16, 0, 0, 1, True, 1], + [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], + [1, 768, 3, 224, 224, 16, 16, 16, 16, 0, 0, 1, True, 1], + [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], + [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], + [1, 32, 3, 299, 299, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 32, 3, 381, 381, 3, 3, 2, 2, 0, 0, 1, False, 1], + [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], + [1, 64, 3, 480, 640, 7, 7, 4, 4, 3, 3, 1, True, 1], [1, 32, 3, 512, 512, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 32, 3, 512, 512, 6, 6, 2, 2, 2, 2, 2, False, 1], - [1, 32, 3, 512, 512, 7, 7, 4, 4, 4, 4, 3, True, 1], - [1, 192, 3, 512, 672, 16, 16, 16, 16, 16, 16, 0, True, 1], - [1, 1280, 3, 518, 518, 14, 14, 14, 14, 14, 14, 0, True, 1], - [1, 64, 3, 720, 1280, 7, 7, 2, 2, 2, 2, 3, False, 1], - [1, 64, 3, 800, 1088, 7, 7, 2, 2, 2, 2, 3, False, 1], - [1, 308, 3024, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 3024, 308, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 528, 32, 192, 192, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 64, 32, 512, 512, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 336, 336, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 888, 336, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 336, 336, 48, 48, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 3712, 348, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 348, 3712, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 912, 408, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 1008, 432, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 128, 512, 100, 136, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 512, 100, 136, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 256, 512, 100, 136, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 32, 3, 512, 512, 6, 6, 2, 2, 2, 2, 1, False, 1], + [1, 32, 3, 512, 512, 7, 7, 4, 4, 3, 3, 1, True, 1], + [1, 192, 3, 512, 672, 16, 16, 16, 16, 0, 0, 1, True, 1], + [1, 1280, 3, 518, 518, 14, 14, 14, 14, 0, 0, 1, True, 1], + [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], + [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], + [1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 528, 32, 192, 192, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 64, 32, 512, 512, 3, 3, 2, 2, 1, 1, 1, False, 1], + [1, 888, 336, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 912, 408, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1008, 432, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 128, 512, 100, 136, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 512, 100, 136, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 256, 512, 100, 136, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 364, 512, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 512, 512, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 256, 512, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1024, 512, 90, 160, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 528, 528, 17, 17, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 528, 528, 192, 192, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1056, 528, 96, 96, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1056, 528, 96, 96, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 64, 64, 128, 128, 2, 2, 2, 2, 2, 2, 0, True, 1], - [1, 256, 64, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1], + [1, 1024, 512, 90, 160, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1056, 528, 96, 96, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 64, 64, 128, 128, 2, 2, 2, 2, 0, 0, 1, True, 1], + [1, 256, 64, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1], [1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1392, 696, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1392, 696, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 696, 696, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1920, 720, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1920, 720, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 720, 720, 17, 17, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 720, 720, 21, 21, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 2904, 726, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 7392, 726, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 1024, 728, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1024, 728, 19, 19, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 1, False, 1], - [1, 728, 728, 38, 38, 1, 1, 2, 2, 2, 2, 0, False, 1], - [1, 726, 7392, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 2, 2, 1, False, 1], - [1, 1024, 782, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 816, 816, 19, 19, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 816, 816, 23, 23, 5, 5, 2, 2, 2, 2, 0, False, 1], - [1, 912, 912, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1], - [1, 1280, 960, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1], - [1, 960, 960, 24, 24, 5, 5, 1, 1, 1, 1, 2, False, 1], - [1, 1280, 1280, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1392, 696, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1920, 720, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 728, 19, 19, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 728, 728, 38, 38, 1, 1, 2, 2, 0, 0, 1, False, 1], + [1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], + [1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], + [1, 1280, 1280, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1280, 1280, 16, 16, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 640, 1280, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 640, 1280, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 640, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1280, 1280, 8, 8, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1280, 1280, 8, 8, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1280, 1920, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1280, 1920, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1280, 1920, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 640, 1920, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 640, 1920, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1280, 2560, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1280, 2560, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1280, 2560, 8, 8, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1280, 2560, 8, 8, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 640, 320, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 640, 320, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 640, 320, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 320, 320, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 320, 320, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 320, 320, 64, 64, 3, 3, 2, 2, 2, 2, 1, True, 1], + [1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, 1, True, 1], [1, 4, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 320, 4, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 1280, 640, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 1280, 640, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 1280, 640, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 640, 640, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 640, 640, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 640, 640, 32, 32, 3, 3, 2, 2, 2, 2, 1, True, 1], - [1, 320, 640, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, 1, True, 1], + [1, 320, 640, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], [1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 640, 960, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 640, 960, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], - [1, 320, 960, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1], + [1, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1], [1, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], ], }, @@ -1599,21 +1599,59 @@ def test_conv2d_localrun(device, input_spec): failing_parameters = [ # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation] - # Input is 32MB maps to MM 64 cores, we neeed to avoid sharding this tensor and use dram intrelaved directly with MM - [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 5 - [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 14 - [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 170 - [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 172 - [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 181 - [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 182 - [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 198 - [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 203 - [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 204 - [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 292 - [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 373 - [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 374 - [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 394 - [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 395 + [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 0 + [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 5 + [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 19 + [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 20 + [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127 + [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220 + [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294 + [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1407 + [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1408 + [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1417 + [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1418 + [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1420 + [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421 + [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1423 + [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1425 + [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1427 + [1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1431 + [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1442 + [1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443 + [1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447 + [1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1451 + [1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1452 + [1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1453 + [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1455 + [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458 + [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 1460 + [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1461 + [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1464 + [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1471 + [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1472 + [1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1473 + [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1474 + [1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1475 + [1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1479 + [1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1480 + [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1481 + [1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1491 + [1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1495 + [1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1496 + [1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1497 + [1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1499 + [1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1501 + [1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1502 + [1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1503 + [1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1506 + [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1507 + [1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1508 + [1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1509 + [1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1510 + [1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1522 + [1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1530 + [1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1540 + [1, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1545 ] diff --git a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py index c66e5548834..0b6c9031553 100644 --- a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py +++ b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py @@ -13,6 +13,10 @@ TIMEOUT = 70 +# params contains the shape of the first tensor followed by the second tensor +# Note: the shape of the second tensor starts at int(count / 2). It's easiest +# to reason about if both tensors are the same rank, although some other +# combinations may be valid. parameters = { "default": { "params": [ @@ -111,25 +115,269 @@ (9, 768, 768, 640), (920, 256, 256, 256), ], - } + "core_grid": [False], + }, + "gpt": { + "params": [ + (1, 1, 1, 1, 1, 1, 1, 1), + (1, 1, 1, 1, 1, 1, 1, 2304), + (1, 1, 1, 1, 1, 1, 1, 3072), + (1, 1, 1, 1, 1, 1, 1, 65536), + (1, 1, 1, 1, 1, 1, 1, 768), + (1, 1, 1, 1, 1, 1, 1, 96), + (1, 1, 1, 2304, 1, 1, 2304, 1), + (1, 1, 1, 2304, 1, 1, 2304, 65536), + (1, 1, 1, 2304, 1, 1, 2304, 768), + (1, 1, 1, 3072, 1, 1, 3072, 1), + (1, 1, 1, 3072, 1, 1, 3072, 65536), + (1, 1, 1, 3072, 1, 1, 3072, 768), + (1, 1, 1, 65536, 1, 1, 65536, 2304), + (1, 1, 1, 65536, 1, 1, 65536, 3072), + (1, 1, 1, 65536, 1, 1, 65536, 768), + (1, 1, 1, 65536, 1, 1, 65536, 96), + (1, 1, 1, 768, 1, 1, 768, 1), + (1, 1, 1, 768, 1, 1, 768, 1024), + (1, 1, 1, 768, 1, 1, 768, 2304), + (1, 1, 1, 768, 1, 1, 768, 3072), + (1, 1, 1, 768, 1, 1, 768, 65536), + (1, 1, 1, 768, 1, 1, 768, 768), + (1, 1, 1, 768, 1, 1, 768, 96), + (1, 1, 1, 96, 1, 1, 96, 1), + (1, 1, 1, 96, 1, 1, 96, 65536), + (1, 1, 1, 96, 1, 1, 96, 768), + (1, 1, 1024, 768, 1, 1, 768, 1), + (1, 1, 1024, 768, 1, 1, 768, 1024), + (1, 1, 1024, 768, 1, 1, 768, 2304), + (1, 1, 1024, 768, 1, 1, 768, 3072), + (1, 1, 1024, 768, 1, 1, 768, 65536), + (1, 1, 1024, 768, 1, 1, 768, 768), + (1, 1, 1024, 768, 1, 1, 768, 96), + (1, 1, 2304, 1, 1, 1, 1, 1), + (1, 1, 2304, 1, 1, 1, 1, 2304), + (1, 1, 2304, 1, 1, 1, 1, 3072), + (1, 1, 2304, 1, 1, 1, 1, 65536), + (1, 1, 2304, 1, 1, 1, 1, 768), + (1, 1, 2304, 1, 1, 1, 1, 96), + (1, 1, 2304, 65536, 1, 1, 65536, 1), + (1, 1, 2304, 65536, 1, 1, 65536, 2304), + (1, 1, 2304, 65536, 1, 1, 65536, 3072), + (1, 1, 2304, 65536, 1, 1, 65536, 768), + (1, 1, 2304, 65536, 1, 1, 65536, 96), + (1, 1, 2304, 768, 1, 1, 768, 1), + (1, 1, 2304, 768, 1, 1, 768, 1024), + (1, 1, 2304, 768, 1, 1, 768, 2304), + (1, 1, 2304, 768, 1, 1, 768, 3072), + (1, 1, 2304, 768, 1, 1, 768, 65536), + (1, 1, 2304, 768, 1, 1, 768, 768), + (1, 1, 2304, 768, 1, 1, 768, 96), + (1, 1, 3072, 1, 1, 1, 1, 1), + (1, 1, 3072, 1, 1, 1, 1, 2304), + (1, 1, 3072, 1, 1, 1, 1, 3072), + (1, 1, 3072, 1, 1, 1, 1, 65536), + (1, 1, 3072, 1, 1, 1, 1, 768), + (1, 1, 3072, 1, 1, 1, 1, 96), + (1, 1, 3072, 65536, 1, 1, 65536, 1), + (1, 1, 3072, 65536, 1, 1, 65536, 2304), + (1, 1, 3072, 65536, 1, 1, 65536, 3072), + (1, 1, 3072, 65536, 1, 1, 65536, 768), + (1, 1, 3072, 65536, 1, 1, 65536, 96), + (1, 1, 3072, 768, 1, 1, 768, 1), + (1, 1, 3072, 768, 1, 1, 768, 1024), + (1, 1, 3072, 768, 1, 1, 768, 2304), + (1, 1, 3072, 768, 1, 1, 768, 3072), + (1, 1, 3072, 768, 1, 1, 768, 65536), + (1, 1, 3072, 768, 1, 1, 768, 768), + (1, 1, 3072, 768, 1, 1, 768, 96), + (1, 1, 65536, 1, 1, 1, 1, 1), + (1, 1, 65536, 1, 1, 1, 1, 2304), + (1, 1, 65536, 1, 1, 1, 1, 3072), + (1, 1, 65536, 1, 1, 1, 1, 65536), + (1, 1, 65536, 1, 1, 1, 1, 768), + (1, 1, 65536, 1, 1, 1, 1, 96), + (1, 1, 65536, 2304, 1, 1, 2304, 1), + (1, 1, 65536, 2304, 1, 1, 2304, 65536), + (1, 1, 65536, 2304, 1, 1, 2304, 768), + (1, 1, 65536, 3072, 1, 1, 3072, 1), + (1, 1, 65536, 3072, 1, 1, 3072, 65536), + (1, 1, 65536, 3072, 1, 1, 3072, 768), + (1, 1, 65536, 768, 1, 1, 768, 1), + (1, 1, 65536, 768, 1, 1, 768, 1024), + (1, 1, 65536, 768, 1, 1, 768, 2304), + (1, 1, 65536, 768, 1, 1, 768, 3072), + (1, 1, 65536, 768, 1, 1, 768, 65536), + (1, 1, 65536, 768, 1, 1, 768, 768), + (1, 1, 65536, 768, 1, 1, 768, 96), + (1, 1, 65536, 96, 1, 1, 96, 65536), + (1, 1, 65536, 96, 1, 1, 96, 768), + (1, 1, 768, 1, 1, 1, 1, 1), + (1, 1, 768, 1, 1, 1, 1, 2304), + (1, 1, 768, 1, 1, 1, 1, 3072), + (1, 1, 768, 1, 1, 1, 1, 65536), + (1, 1, 768, 1, 1, 1, 1, 768), + (1, 1, 768, 1, 1, 1, 1, 96), + (1, 1, 768, 1024, 1, 1, 1024, 768), + (1, 1, 768, 2304, 1, 1, 2304, 1), + (1, 1, 768, 2304, 1, 1, 2304, 65536), + (1, 1, 768, 2304, 1, 1, 2304, 768), + (1, 1, 768, 3072, 1, 1, 3072, 1), + (1, 1, 768, 3072, 1, 1, 3072, 65536), + (1, 1, 768, 3072, 1, 1, 3072, 768), + (1, 1, 768, 65536, 1, 1, 65536, 1), + (1, 1, 768, 65536, 1, 1, 65536, 2304), + (1, 1, 768, 65536, 1, 1, 65536, 3072), + (1, 1, 768, 65536, 1, 1, 65536, 768), + (1, 1, 768, 65536, 1, 1, 65536, 96), + (1, 1, 768, 768, 1, 1, 768, 1), + (1, 1, 768, 768, 1, 1, 768, 1024), + (1, 1, 768, 768, 1, 1, 768, 2304), + (1, 1, 768, 768, 1, 1, 768, 3072), + (1, 1, 768, 768, 1, 1, 768, 65536), + (1, 1, 768, 768, 1, 1, 768, 768), + (1, 1, 768, 768, 1, 1, 768, 96), + (1, 1, 768, 96, 1, 1, 96, 1), + (1, 1, 768, 96, 1, 1, 96, 65536), + (1, 1, 768, 96, 1, 1, 96, 768), + (1, 1, 96, 1, 1, 1, 1, 1), + (1, 1, 96, 1, 1, 1, 1, 2304), + (1, 1, 96, 1, 1, 1, 1, 3072), + (1, 1, 96, 1, 1, 1, 1, 65536), + (1, 1, 96, 1, 1, 1, 1, 768), + (1, 1, 96, 1, 1, 1, 1, 96), + (1, 1, 96, 65536, 1, 1, 65536, 1), + (1, 1, 96, 65536, 1, 1, 65536, 2304), + (1, 1, 96, 65536, 1, 1, 65536, 3072), + (1, 1, 96, 65536, 1, 1, 65536, 768), + (1, 1, 96, 65536, 1, 1, 65536, 96), + (1, 1, 96, 768, 1, 1, 768, 1), + (1, 1, 96, 768, 1, 1, 768, 1024), + (1, 1, 96, 768, 1, 1, 768, 2304), + (1, 1, 96, 768, 1, 1, 768, 3072), + (1, 1, 96, 768, 1, 1, 768, 65536), + (1, 1, 96, 768, 1, 1, 768, 768), + (1, 1, 96, 768, 1, 1, 768, 96), + (1, 64, 1024, 768, 1, 1, 768, 1), + (1, 64, 1024, 768, 1, 1, 768, 2304), + (1, 64, 1024, 768, 1, 1, 768, 3072), + (1, 64, 1024, 768, 1, 1, 768, 65536), + (1, 64, 1024, 768, 1, 1, 768, 768), + (1, 64, 1024, 768, 1, 1, 768, 96), + (1, 64, 768, 1024, 1, 1, 1024, 768), + (1, 64, 768, 1024, 1, 64, 1024, 768), + (64, 1, 1, 1024, 1, 1, 1024, 768), + (64, 1, 1, 1024, 64, 1, 1024, 1), + (64, 1, 1, 1024, 64, 1, 1024, 2304), + (64, 1, 1, 1024, 64, 1, 1024, 3072), + (64, 1, 1, 1024, 64, 1, 1024, 768), + (64, 1, 1, 1024, 64, 1, 1024, 96), + (64, 1, 1, 768, 1, 1, 768, 1), + (64, 1, 1, 768, 1, 1, 768, 1024), + (64, 1, 1, 768, 1, 1, 768, 2304), + (64, 1, 1, 768, 1, 1, 768, 3072), + (64, 1, 1, 768, 1, 1, 768, 65536), + (64, 1, 1, 768, 1, 1, 768, 768), + (64, 1, 1, 768, 1, 1, 768, 96), + (64, 1, 1, 768, 64, 1, 768, 1), + (64, 1, 1, 768, 64, 1, 768, 1024), + (64, 1, 1024, 1, 1, 1, 1, 2304), + (64, 1, 1024, 1, 1, 1, 1, 3072), + (64, 1, 1024, 1, 1, 1, 1, 768), + (64, 1, 1024, 1, 1, 1, 1, 96), + (64, 1, 1024, 1, 64, 1, 1, 1024), + (64, 1, 1024, 1, 64, 1, 1, 768), + (64, 1, 1024, 2304, 1, 1, 2304, 65536), + (64, 1, 1024, 2304, 1, 1, 2304, 768), + (64, 1, 1024, 2304, 64, 1, 2304, 1024), + (64, 1, 1024, 3072, 1, 1, 3072, 1), + (64, 1, 1024, 3072, 1, 1, 3072, 65536), + (64, 1, 1024, 3072, 1, 1, 3072, 768), + (64, 1, 1024, 768, 1, 1, 768, 1), + (64, 1, 1024, 768, 1, 1, 768, 1024), + (64, 1, 1024, 768, 1, 1, 768, 2304), + (64, 1, 1024, 768, 1, 1, 768, 3072), + (64, 1, 1024, 768, 1, 1, 768, 65536), + (64, 1, 1024, 768, 1, 1, 768, 768), + (64, 1, 1024, 768, 1, 1, 768, 96), + (64, 1, 1024, 768, 64, 1, 768, 1024), + (64, 1, 1024, 96, 1, 1, 96, 65536), + (64, 1, 1024, 96, 1, 1, 96, 768), + (64, 1, 1024, 96, 64, 1, 96, 1024), + (64, 1, 2304, 1024, 1, 1, 1024, 768), + (64, 1, 2304, 1024, 64, 1, 1024, 1), + (64, 1, 2304, 1024, 64, 1, 1024, 2304), + (64, 1, 2304, 1024, 64, 1, 1024, 3072), + (64, 1, 2304, 1024, 64, 1, 1024, 768), + (64, 1, 2304, 1024, 64, 1, 1024, 96), + (64, 1, 3072, 1024, 1, 1, 1024, 768), + (64, 1, 3072, 1024, 64, 1, 1024, 1), + (64, 1, 3072, 1024, 64, 1, 1024, 2304), + (64, 1, 3072, 1024, 64, 1, 1024, 3072), + (64, 1, 3072, 1024, 64, 1, 1024, 768), + (64, 1, 3072, 1024, 64, 1, 1024, 96), + (64, 1, 768, 1, 1, 1, 1, 2304), + (64, 1, 768, 1, 1, 1, 1, 3072), + (64, 1, 768, 1, 1, 1, 1, 768), + (64, 1, 768, 1, 1, 1, 1, 96), + (64, 1, 768, 1, 64, 1, 1, 768), + (64, 1, 768, 1024, 1, 1, 1024, 768), + (64, 1, 768, 1024, 64, 1, 1024, 1), + (64, 1, 768, 1024, 64, 1, 1024, 2304), + (64, 1, 768, 1024, 64, 1, 1024, 3072), + (64, 1, 768, 1024, 64, 1, 1024, 768), + (64, 1, 768, 1024, 64, 1, 1024, 96), + (64, 1, 96, 1024, 1, 1, 1024, 768), + (64, 1, 96, 1024, 64, 1, 1024, 1), + (64, 1, 96, 1024, 64, 1, 1024, 2304), + (64, 1, 96, 1024, 64, 1, 1024, 3072), + (64, 1, 96, 1024, 64, 1, 1024, 768), + (64, 1, 96, 1024, 64, 1, 1024, 96), + (64, 12, 1, 1024, 1, 1, 1024, 768), + (64, 12, 1, 1024, 64, 12, 1024, 1), + (64, 12, 1, 1024, 64, 12, 1024, 1024), + (64, 12, 1, 1024, 64, 12, 1024, 64), + (64, 12, 1024, 1, 1, 1, 1, 1), + (64, 12, 1024, 1, 1, 1, 1, 2304), + (64, 12, 1024, 1, 1, 1, 1, 3072), + (64, 12, 1024, 1, 1, 1, 1, 768), + (64, 12, 1024, 1, 1, 1, 1, 96), + (64, 12, 1024, 1, 64, 12, 1, 1024), + (64, 12, 1024, 1024, 1, 1, 1024, 768), + (64, 12, 1024, 1024, 64, 12, 1024, 1), + (64, 12, 1024, 1024, 64, 12, 1024, 1024), + (64, 12, 1024, 1024, 64, 12, 1024, 64), + (64, 12, 1024, 64, 64, 12, 64, 1024), + (64, 12, 64, 1024, 1, 1, 1024, 768), + (64, 12, 64, 1024, 64, 12, 1024, 1), + (64, 12, 64, 1024, 64, 12, 1024, 1024), + (64, 12, 64, 1024, 64, 12, 1024, 64), + ], + "core_grid": [True, False], + }, } def run( params, + core_grid, *, device, ) -> list: - [in0_h, in0_w, in1_h, in1_w] = params - torch_input_tensor0 = torch.rand([in0_h, in0_w], dtype=torch.float32) - torch_input_tensor1 = torch.rand([in1_h, in1_w], dtype=torch.float32) + if core_grid == False: + grid = None + else: + grid = device.core_grid + count = len(params) + half = int(count / 2) + shape0 = params[0:half] + shape1 = params[half:count] + torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32) + torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32) torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1) input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) start_time = start_measuring_time() - output_tensor = ttnn.matmul(input_tensor0, input_tensor1) + output_tensor = ttnn.matmul(input_tensor0, input_tensor1, core_grid=grid) output_tensor = ttnn.to_torch(output_tensor) e2e_perf = stop_measuring_time(start_time) expected_pcc = 0.99 diff --git a/tests/sweep_framework/sweeps/transformer/rotary_embedding/rotary_embedding.py b/tests/sweep_framework/sweeps/transformer/rotary_embedding/rotary_embedding.py new file mode 100644 index 00000000000..9c39ab1ae1f --- /dev/null +++ b/tests/sweep_framework/sweeps/transformer/rotary_embedding/rotary_embedding.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes, gen_rotary_embedding_spec +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_spec": gen_rotary_embedding_spec( + input_shape_list=gen_shapes([1, 1, 32, 64], [6, 12, 256, 512], [1, 1, 32, 64], 16), + cache_size_list=[random.randint(1, 2048) for i in range(8)], + ), + "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], + "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["input_dtype"] == ttnn.bfloat8_b: + return True, "bfloat8_b/bfloat4_b requires TILE_LAYOUT!" + if test_vector["input_spec"]["input_shape"][-1] % 64 != 0: + return True, "Input X dimension (133) must be divisible by 64 for tiling" + if test_vector["input_spec"]["token_idx"] and test_vector["input_spec"]["input_shape"][0] != 1: + return True, "When passing token_idx, sequence length must be 1" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_spec, + input_dtype, + input_layout, + input_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + input_shape, cache_size, token_idx = input_spec.values() + seq_length, batch_size, num_heads, head_dim = input_shape + + sin_cos_cache_shape = [1, 1, cache_size, head_dim] + + torch_input_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype + )(input_shape) + torch_cos_cache_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype + )(sin_cos_cache_shape) + torch_sin_cache_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype + )(sin_cos_cache_shape) + + if token_idx: + golden_function = partial(ttnn.get_golden_function(ttnn.experimental.rotary_embedding), token_idx=token_idx) + else: + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(x, cos_cached, sin_cached, token_idx=None): + seq_len = x.shape[-2] + if token_idx is None: + cos = cos_cached[:, :, :seq_len, ...] + sin = sin_cached[:, :, :seq_len, ...] + else: + cos = cos_cached[:, :, token_idx : token_idx + 1, ...] + sin = sin_cached[:, :, token_idx : token_idx + 1, ...] + + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + golden_function = apply_rotary_pos_emb + + torch_output_tensor = golden_function(torch_input_tensor, torch_cos_cache_tensor, torch_sin_cache_tensor) + + input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + cos_cache_tensor = ttnn.from_torch( + torch_cos_cache_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + sin_cache_tensor = ttnn.from_torch( + torch_sin_cache_tensor, + dtype=input_dtype, + layout=input_layout, + device=device, + memory_config=input_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.experimental.rotary_embedding( + input_tensor, cos_cache_tensor, sin_cache_tensor, token_idx, memory_config=output_memory_config + ) + e2e_perf = stop_measuring_time(start_time) + + output_tensor = ttnn.to_torch(output_tensor) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/tt_eager/ops/test_bcast_op.cpp b/tests/tt_eager/ops/test_bcast_op.cpp index a761c1eba51..05be3303c06 100644 --- a/tests/tt_eager/ops/test_bcast_op.cpp +++ b/tests/tt_eager/ops/test_bcast_op.cpp @@ -3,16 +3,13 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_metal/host_api.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/operations/data_movement/bcast/bcast.hpp" #include "common/constants.hpp" #include #include -#include -#include -#include - using namespace tt; using namespace tt_metal; using namespace constants; @@ -53,9 +50,8 @@ int main(int argc, char** argv) { } Tensor a = ttnn::numpy::random::random(input_shape_a).to(Layout::TILE).to(device); - Tensor b = ttnn::numpy::zeros({1, 1, TILE_HEIGHT, TILE_WIDTH}, DataType::BFLOAT16) - .to(Layout::TILE) - .to(device); + Tensor b = ttnn::zeros( + ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), DataType::BFLOAT16, Layout::TILE, *device); for (auto bcast_math : magic_enum::enum_values()) { Tensor c = ttnn::bcast(0, a, b, bcast_math, bcast_dim); @@ -72,28 +68,28 @@ int main(int argc, char** argv) { { Tensor a = ttnn::numpy::random::random({1, 1, 32, 4544}).to(Layout::TILE).to(device); - Tensor b = ttnn::numpy::zeros({1, 1, 32, 4544}, DataType::BFLOAT16).to(Layout::TILE).to(device); + Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 4544}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::H); Tensor d = c.cpu(); } { Tensor a = ttnn::numpy::random::random({1, 1, 32, 4544}).to(Layout::TILE).to(device); - Tensor b = ttnn::numpy::zeros({1, 1, 32, 4544}, DataType::BFLOAT16).to(Layout::TILE).to(device); + Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 4544}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::ADD, ttnn::BcastOpDim::H); Tensor d = c.cpu(); } { Tensor a = ttnn::numpy::random::random({1, 71, 32, 32}).to(Layout::TILE).to(device); - Tensor b = ttnn::numpy::zeros({1, 1, 32, 32}, DataType::BFLOAT16).to(Layout::TILE).to(device); + Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 32}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::HW); Tensor d = c.cpu(); } { Tensor a = ttnn::numpy::random::random({1, 71, 32, 64}).to(Layout::TILE).to(device); - Tensor b = ttnn::numpy::zeros({1, 1, 32, 32}, DataType::BFLOAT16).to(Layout::TILE).to(device); + Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 32}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::HW); Tensor d = c.cpu(); } diff --git a/tests/tt_eager/ops/test_bmm_op.cpp b/tests/tt_eager/ops/test_bmm_op.cpp index c7760e67354..f769870b595 100644 --- a/tests/tt_eager/ops/test_bmm_op.cpp +++ b/tests/tt_eager/ops/test_bmm_op.cpp @@ -3,15 +3,13 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_metal/host_api.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" #include "ttnn/operations/matmul/device/matmul_op.hpp" #include "common/constants.hpp" #include "ttnn/operations/numpy/functions.hpp" -#include -#include -#include - using namespace tt; using namespace tt_metal; using namespace constants; @@ -37,14 +35,14 @@ int main(int argc, char** argv) { uint32_t Kt = 2; uint32_t Nt = 4; uint32_t B = 5; - tt::tt_metal::LegacyShape shapea = {B, 1, Mt * TILE_HEIGHT, Kt * TILE_WIDTH}; - tt::tt_metal::LegacyShape shapeb = {B, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH}; - tt::tt_metal::LegacyShape shapeb1 = {1, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH}; + ttnn::Shape shapea({B, 1, Mt * TILE_HEIGHT, Kt * TILE_WIDTH}); + ttnn::Shape shapeb({B, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH}); + ttnn::Shape shapeb1({1, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH}); // Allocates a DRAM buffer on device populated with values specified by initialize - Tensor a = ttnn::numpy::random::random(shapea).to(Layout::TILE).to(device); - Tensor b = ttnn::numpy::zeros(shapeb, DataType::BFLOAT16).to(Layout::TILE).to(device); - Tensor b1 = ttnn::numpy::zeros(shapeb1, DataType::BFLOAT16).to(Layout::TILE).to(device); + Tensor a = ttnn::numpy::random::random(shapea.value).to(Layout::TILE).to(device); + Tensor b = ttnn::zeros(shapeb, DataType::BFLOAT16, Layout::TILE, *device); + Tensor b1 = ttnn::zeros(shapeb1, DataType::BFLOAT16, Layout::TILE, *device); Tensor mm = ttnn::operations::matmul::matmul( a, diff --git a/tests/tt_eager/ops/test_tensor_utils.cpp b/tests/tt_eager/ops/test_tensor_utils.cpp index 1121b455b2a..0b6c2e3d376 100644 --- a/tests/tt_eager/ops/test_tensor_utils.cpp +++ b/tests/tt_eager/ops/test_tensor_utils.cpp @@ -11,14 +11,11 @@ #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/tensor.hpp" -#include "ttnn/operations/numpy/functions.hpp" +#include "ttnn/operations/creation.hpp" #include "ttnn/tensor/types.hpp" #include "ttnn/tensor/tensor_utils.hpp" -using std::vector; -using tt::tt_metal::Tensor; -using namespace tt::tt_metal; -static vector> ref_weight_in = { +static std::vector> ref_weight_in = { { 16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16159, 16165, 16068, 16096, 16024, 16228, 15720, 16246, 16011, 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16214, 15761, 16044, 15794, 16165, @@ -246,7 +243,7 @@ static vector> ref_weight_in = { } }; -static vector> ref_weight_out = { +static std::vector> ref_weight_out = { {16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16156, 15971, 16157, 16069, 16241, 16231, 16174, 16102, 16056, 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 15972, 16228, 16243, 16174, 16100, 16101, 16216, 16250, 16179, 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 15912, 16128, 16159, @@ -419,9 +416,11 @@ static vector> ref_weight_out = { 15832, 15895, 16234, 16062, 16231, 16173, 16122, 16016, 16187, 15560, 16229, 16046, 16243, 16219, 15849, 16135, }}; -static vector weight_tensor_shape = {{8, 8, 3, 3}, {10, 10, 3, 3}, {12, 8, 3, 3}, {8, 15, 3, 3}}; -static vector bias_tensor_shape = {{1, 1, 1, 32}, {1, 1, 1, 60}, {12, 1, 1, 320}, {8, 1, 1, 48}}; -static vector shards = {8, 3, 5, 4}; +static std::vector weight_tensor_shape = { + {8, 8, 3, 3}, {10, 10, 3, 3}, {12, 8, 3, 3}, {8, 15, 3, 3}}; +static std::vector bias_tensor_shape = { + {1, 1, 1, 32}, {1, 1, 1, 60}, {12, 1, 1, 320}, {8, 1, 1, 48}}; +static std::vector shards = {8, 3, 5, 4}; template static uint32_t compare_out_with_ref(const owned_buffer::Buffer& out_buf, T& ref) { @@ -447,7 +446,7 @@ static uint32_t compare_out_with_ref(const owned_buffer::Buffer& out_b static void test_convert_conv_weight_tensor_to_tiled_layout_block_sharded() { tt::log_info(tt::LogTest, "Running {}", __func__); for (auto i = 0; i < weight_tensor_shape.size(); i++) { - auto input_tensor = ttnn::numpy::zeros(weight_tensor_shape[i]); + auto input_tensor = ttnn::zeros(weight_tensor_shape[i]); auto input_buffer = owned_buffer::get_as(input_tensor); for (auto j = 0; j < input_buffer.size(); j++) { input_buffer[j] = ref_weight_in[i][j]; diff --git a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp index 1a0d13092ad..a94093f8ac6 100644 --- a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp +++ b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp @@ -7,12 +7,12 @@ #include #include "common/constants.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp" #include "tt_metal/host_api.hpp" -#include "ttnn/operations/numpy/functions.hpp" using namespace tt; using namespace tt_metal; @@ -37,7 +37,8 @@ int main(int argc, char** argv) { //////////////////////////////////////////////////////////////////////////// ttnn::SimpleShape shape{1, 32, 61, 32}; // Allocates a DRAM buffer on device populated with values specified by initialize - Tensor a = ttnn::numpy::arange(0, shape.volume(), 1).reshape(shape).to(device); + Tensor a = ttnn::arange(/*start=*/0, /*stop=*/shape.volume(), /*step=*/1, DataType::BFLOAT16, std::ref(*device)) + .reshape(shape); Tensor b = ttnn::tilize_with_zero_padding(a); Tensor c = b.cpu(); //////////////////////////////////////////////////////////////////////////// diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py new file mode 100644 index 00000000000..0840e55513d --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import pytest + + +@pytest.mark.parametrize( + "start_core, num_cores, sub_core_grids, row_wise, expected_core_range_set", + [ + # Test Case 1: Basic row-wise scenario with enough cores in sub_core_grids + ( + ttnn.CoreCoord(1, 0), + 32, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)), + ] + ), + True, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 0)), + ] + ), + ), + # Test Case 2: Basic Column-wise processing + ( + ttnn.CoreCoord(1, 0), + 32, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)), + ] + ), + False, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(5, 1)), + ] + ), + ), + # Test Case 3: row-wise scenario with small target cores and start offset + ( + ttnn.CoreCoord(3, 2), + 8, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)), + ] + ), + True, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(3, 2), ttnn.CoreCoord(3, 2)), + ttnn.CoreRange(ttnn.CoreCoord(1, 3), ttnn.CoreCoord(3, 4)), + ttnn.CoreRange(ttnn.CoreCoord(1, 5), ttnn.CoreCoord(1, 5)), + ] + ), + ), + # Test Case 4: col-wise scenario with small target cores and start offset + ( + ttnn.CoreCoord(1, 8), + 8, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)), + ] + ), + False, + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 8), ttnn.CoreCoord(1, 9)), + ttnn.CoreRange(ttnn.CoreCoord(2, 0), ttnn.CoreCoord(2, 5)), + ] + ), + ), + ], +) +def test_numcores_to_corerangeset_in_subcoregrids( + start_core, num_cores, sub_core_grids, row_wise, expected_core_range_set +): + output_corerangeset = ttnn.num_cores_to_corerangeset_in_subcoregrids( + start_core, num_cores, sub_core_grids, row_wise=row_wise + ) + assert output_corerangeset.to_json() == expected_core_range_set.to_json() diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py index 98bca8dc1a3..736b2f2db82 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py @@ -183,8 +183,10 @@ def test_sdpa_tt_with_program_cache(device, b, nh, nkv, s, d, q_chunk_size, k_ch assert device.num_program_cache_entries() == 1 -def run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype): +def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=None): torch.manual_seed(1234) + if sk is None: + sk = sq program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=device.compute_with_storage_grid_size(), @@ -200,16 +202,16 @@ def run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dty packer_l1_acc=False, ) - Q = fa_rand(b, nh, s, d) - K = fa_rand(b, nkv, s, d) - V = fa_rand(b, nkv, s, d) + Q = fa_rand(b, nh, sq, d) + K = fa_rand(b, nkv, sk, d) + V = fa_rand(b, nkv, sk, d) # Generate random non-causal attention mask mask = torch.bernoulli( torch.full( ( b, - s, - s, + sq, + sk, ), 0.25, ) @@ -240,8 +242,8 @@ def run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dty if nkv > 1 and nkv != nh: assert nh % nkv == 0 - K = K.reshape(b, nkv, 1, s, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, s, d) - V = V.reshape(b, nkv, 1, s, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, s, d) + K = K.reshape(b, nkv, 1, sk, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, sk, d) + V = V.reshape(b, nkv, 1, sk, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, sk, d) gt = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=False, attn_mask=mask) @@ -274,3 +276,24 @@ def test_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dt pytest.skip("Bad PCC for small chunks") ttnn.device.DisablePersistentKernelCache() run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype) + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"]) +@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"]) +@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"]) +@pytest.mark.parametrize( + "b, nh, nkv, sq, sk, d", + ( + [1, 8, 1, 4096, 2048, 128], + # [1, 4, 4, 128*1024, 6528, 128], # Llama-Vision long seq + [1, 4, 1, 2048, 6528, 128], # Llama-Vision + ), +) +def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_size, k_chunk_size, dtype): + if (sq % q_chunk_size != 0) or (sk % k_chunk_size != 0): + pytest.skip("s must be divisible by q_chunk_size and k_chunk_size") + ttnn.device.DisablePersistentKernelCache() + run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=sk) diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index b762f10acf7..3ef44800178 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -2,13 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 -#include #include -#include -#include #include "common/bfloat16.hpp" #include "common/constants.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/tensor.hpp" @@ -16,16 +14,16 @@ #include "ttnn/tensor/types.hpp" #include "tests/tt_metal/tt_metal/common/dispatch_fixture.hpp" #include "tt_metal/host_api.hpp" -#include "ttnn/operations/numpy/functions.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" -using namespace tt; -using namespace tt_metal; -using namespace constants; - +namespace tt::tt_metal { namespace { + +using ::tt::constants::TILE_HEIGHT; +using ::tt::constants::TILE_WIDTH; + uint32_t get_device_buffer_address(const Tensor& tensor) { TT_FATAL(std::holds_alternative(tensor.get_storage()), "Tensor storage is not DeviceStorage"); auto buffer = std::get(tensor.get_storage()).buffer; @@ -33,13 +31,12 @@ uint32_t get_device_buffer_address(const Tensor& tensor) { buffer->device()->push_work([&]() { result = buffer->address(); }, true); return result; } -} // namespace TEST_F(DispatchFixture, TestTensorOwnershipSanity) { // Sanity test tensor read, write and update paths with synchronous // Ensure that tensor data is copied and owned as expected Device* device = this->devices_[0]; - Tensor host_tensor = ttnn::numpy::arange(0, 32 * 32 * 4, 1); + Tensor host_tensor = ttnn::arange(/*start=*/0, /*stop=*/32 * 32 * 4, /*step=*/1, DataType::FLOAT32); Tensor readback_tensor(1); auto func = [device, host_tensor, readback_tensor]() mutable { @@ -122,18 +119,12 @@ TEST_F(DispatchFixture, TestAsyncEltwiseBinary) { for (int i = 0; i < 5; i++) { // Initialize tensors and move them to DRAM - Tensor input_tensor_a = - ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16, Layout::TILE) - .to(device); - Tensor input_tensor_b = - ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16, Layout::TILE) - .to(device); - Tensor input_tensor_c = - ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16, Layout::TILE) - .to(device); + Tensor input_tensor_a = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16, Layout::TILE, *device); + Tensor input_tensor_b = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16, Layout::TILE, *device); + Tensor input_tensor_c = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16, Layout::TILE, *device); Tensor output_tensor_device = ttnn::multiply(ttnn::add(input_tensor_a, input_tensor_b), input_tensor_c); Tensor output_tensor_device_2 = ttnn::neg(ttnn::subtract(output_tensor_device, input_tensor_c)); @@ -181,12 +172,18 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) { for (int i = 0; i < 5; i++) { // Run for multiple loops to ensure deterministic behaviour with device addresses // Initialize 2 tensors on device - Tensor tensor1 = ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16) - .to(device); - Tensor tensor2 = ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16) - .to(device); + Tensor tensor1 = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), + static_cast(i), + DataType::BFLOAT16, + /*layout=*/std::nullopt, + *device); + Tensor tensor2 = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), + static_cast(i), + DataType::BFLOAT16, + /*layout=*/std::nullopt, + *device); uint32_t tensor2_device_buf_addr = get_device_buffer_address(tensor2); // Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2 // deallocated @@ -195,18 +192,23 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) { EXPECT_EQ(tensor1.tensor_attributes->main_thread_ref_count, 2); // To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the // prev addr for tensor2 - Tensor tensor3 = ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16) - .to(device); + Tensor tensor3 = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), + static_cast(i), + DataType::BFLOAT16, + /*layout=*/std::nullopt, + *device); EXPECT_EQ(get_device_buffer_address(tensor3), tensor2_device_buf_addr); EXPECT_EQ(get_device_buffer_address(tensor1), get_device_buffer_address(tensor2)); } log_info(LogTest, "Testing Device tensor self-assignment through function"); for (int i = 0; i < 5; i++) { - Tensor device_tensor = - ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16) - .to(device); + Tensor device_tensor = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), + static_cast(i), + DataType::BFLOAT16, + /*layout=*/std::nullopt, + *device); uint32_t device_tensor_address = get_device_buffer_address(device_tensor); // This step will copy the tensor to a temp rval and std::move it back to the caller's instance of device_tensor // Ensure ref count and address remain unchanged @@ -217,18 +219,19 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) { log_info(LogTest, "Testing Device tensor move assignment"); for (int i = 0; i < 5; i++) { - Tensor tensor1 = ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16) - .to(device); + Tensor tensor1 = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), + static_cast(i), + DataType::BFLOAT16, + /*layout=*/std::nullopt, + *device); Tensor tensor2 = std::move(tensor1); EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 1); } log_info(LogTest, "Testing Device tensor self-assignment"); - Tensor tensor_to_self_assign = - ttnn::numpy::full( - tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast(0), DataType::BFLOAT16) - .to(device); + Tensor tensor_to_self_assign = ttnn::full( + ttnn::Shape({1, 1, 1024, 1024}), static_cast(0), DataType::BFLOAT16, /*layout=*/std::nullopt, *device); uint32_t tensor_to_self_assign_address = get_device_buffer_address(tensor_to_self_assign); tensor_to_self_assign = tensor_to_self_assign; EXPECT_EQ(tensor_to_self_assign.tensor_attributes->main_thread_ref_count, 1); @@ -255,7 +258,7 @@ TEST_F(DispatchFixture, TestTensorAsyncDataMovement) { { // host_tensor only lives in this scope - Tensor host_tensor = ttnn::numpy::arange(tensor_start, tensor_stop, 1); + Tensor host_tensor = ttnn::arange(tensor_start, tensor_stop, /*step=*/1, DataType::FLOAT32); log_info(LogTest, "Spawning worker thread"); worker = std::thread([tensor_stop, host_tensor, readback_tensor, device]() mutable { // Sleep for 3 seconds to ensure that main thread deallocates host_tensor @@ -338,3 +341,5 @@ TEST_F(DispatchFixture, TestTensorAsyncDataMovement) { EXPECT_EQ(readback_tensor.get_layout(), Layout::ROW_MAJOR); EXPECT_EQ(readback_tensor.get_shape(), ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 32, tensor_stop / 32}))); } +} // namespace +} // namespace tt::tt_metal diff --git a/tests/tt_eager/tensors/test_copy_and_move.cpp b/tests/tt_eager/tensors/test_copy_and_move.cpp index f735791ad4c..656585f3351 100644 --- a/tests/tt_eager/tensors/test_copy_and_move.cpp +++ b/tests/tt_eager/tensors/test_copy_and_move.cpp @@ -2,12 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 -#include -#include -#include - #include "common/bfloat16.hpp" #include "common/constants.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/tensor/tensor.hpp" @@ -40,7 +37,7 @@ bool test_tensor_copy_semantics(Device* device) { pass &= dev_a_data == dev_a_copy_data; // host tensor updated with host tensor copy assignment - Tensor host_c = ttnn::numpy::arange(0, tt_metal::compute_volume(single_tile_shape), 1) + Tensor host_c = ttnn::arange(/*start=*/0, /*stop=*/tt_metal::compute_volume(single_tile_shape), /*step=*/1) .reshape(single_tile_shape) .to(Layout::TILE); Tensor host_c_copy = ttnn::numpy::random::random(single_tile_shape).to(Layout::TILE); @@ -58,7 +55,7 @@ bool test_tensor_copy_semantics(Device* device) { pass &= dev_a_data == host_d_copy_data; // dev tensor updated with host tensor copy assignment - Tensor host_e = ttnn::numpy::ones(single_tile_shape).to(Layout::TILE); + Tensor host_e = ttnn::ones(single_tile_shape, DataType::BFLOAT16, Layout::TILE); Tensor dev_e_copy = ttnn::numpy::random::random(single_tile_shape).to(Layout::TILE).to(device); dev_e_copy = host_e; pass &= (dev_e_copy.storage_type() == StorageType::OWNED); @@ -67,8 +64,8 @@ bool test_tensor_copy_semantics(Device* device) { pass &= host_e_data == dev_e_copy_data; // dev tensor updated with dev tensor copy assignment - Tensor dev_b = ttnn::numpy::ones(single_tile_shape).to(Layout::TILE).to(device); - Tensor dev_b_copy = ttnn::numpy::zeros(single_tile_shape).to(Layout::TILE).to(device); + Tensor dev_b = ttnn::ones(single_tile_shape, DataType::BFLOAT16, Layout::TILE, *device); + Tensor dev_b_copy = ttnn::zeros(single_tile_shape, DataType::BFLOAT16, Layout::TILE, *device); dev_b_copy = dev_b; pass &= (dev_b_copy.storage_type() == StorageType::DEVICE); auto dev_b_on_host = dev_b.cpu(); diff --git a/tests/tt_metal/distributed/test_distributed.cpp b/tests/tt_metal/distributed/test_distributed.cpp index 26df7dbcc78..45afeece3be 100644 --- a/tests/tt_metal/distributed/test_distributed.cpp +++ b/tests/tt_metal/distributed/test_distributed.cpp @@ -46,7 +46,7 @@ TEST(MeshDeviceSuite, Test1x1SystemMeshInitialize) { auto& sys = tt::tt_metal::distributed::SystemMesh::instance(); auto config = - tt::tt_metal::distributed::MeshDeviceConfig({1, 1}, std::pair(0, 0), {}, MeshType::RowMajor); + tt::tt_metal::distributed::MeshDeviceConfig(MeshShape(1, 1), MeshOffset(0, 0), {}, MeshType::RowMajor); EXPECT_NO_THROW({ auto mesh = tt::tt_metal::distributed::MeshDevice::create( diff --git a/tests/tt_metal/tools/profiler/test_device_profiler.py b/tests/tt_metal/tools/profiler/test_device_profiler.py index 01d71f987aa..1c2ab3823cc 100644 --- a/tests/tt_metal/tools/profiler/test_device_profiler.py +++ b/tests/tt_metal/tools/profiler/test_device_profiler.py @@ -7,6 +7,7 @@ import re import inspect import pytest +import subprocess import pandas as pd @@ -24,6 +25,30 @@ PROG_EXMP_DIR = "programming_examples/profiler" +def get_device_data(setupStr=""): + postProcessRun = os.system( + f"cd {PROFILER_SCRIPTS_ROOT} && " f"./process_device_log.py {setupStr} --no-artifacts --no-print-stats" + ) + + assert postProcessRun == 0, f"Log process script crashed with exit code {postProcessRun}" + + devicesData = {} + with open(f"{PROFILER_ARTIFACTS_DIR}/output/device/device_analysis_data.json", "r") as devicesDataJson: + devicesData = json.load(devicesDataJson) + + return devicesData + + +def run_gtest_profiler_test(testbin, testname): + clear_profiler_runtime_artifacts() + output = subprocess.check_output( + f"cd {TT_METAL_HOME} && {testbin} --gtest_filter={testname}", stderr=subprocess.STDOUT, shell=True + ).decode("UTF-8") + print(output) + if "SKIPPED" not in output: + get_device_data() + + def run_device_profiler_test(testName=None, setup=False, slowDispatch=False): name = inspect.stack()[1].function testCommand = f"build/{PROG_EXMP_DIR}/{name}" @@ -41,17 +66,7 @@ def run_device_profiler_test(testName=None, setup=False, slowDispatch=False): if setup: setupStr = f"-s {name}" - postProcessRun = os.system( - f"cd {PROFILER_SCRIPTS_ROOT} && " f"./process_device_log.py {setupStr} --no-artifacts --no-print-stats" - ) - - assert postProcessRun == 0, f"Log process script crashed with exit code {postProcessRun}" - - devicesData = {} - with open(f"{PROFILER_ARTIFACTS_DIR}/output/device/device_analysis_data.json", "r") as devicesDataJson: - devicesData = json.load(devicesDataJson) - - return devicesData + return get_device_data(setupStr) def get_function_name(): @@ -231,6 +246,8 @@ def test_profiler_host_device_sync(): assert freq < (reportedFreq * (1 + TOLERANCE)), f"Frequency too large on device {device}" assert freq > (reportedFreq * (1 - TOLERANCE)), f"Frequency too small on device {device}" + os.environ["TT_METAL_PROFILER_SYNC"] = "0" + def test_timestamped_events(): OP_COUNT = 2 @@ -268,3 +285,19 @@ def test_timestamped_events(): devicesData["data"]["devices"]["0"]["cores"]["DEVICE"]["riscs"]["TENSIX"]["events"]["all_events"] ) assert eventCount in REF_COUNT_DICT[ENV_VAR_ARCH_NAME], "Wrong event count" + + +def test_sub_device_profiler(): + run_gtest_profiler_test( + "./build/test/tt_metal/unit_tests_dispatch", "CommandQueueSingleCardFixture.TensixTestSubDeviceBasicPrograms" + ) + os.environ["TT_METAL_PROFILER_SYNC"] = "1" + run_gtest_profiler_test( + "./build/test/tt_metal/unit_tests_dispatch", + "CommandQueueSingleCardFixture.TensixActiveEthTestSubDeviceBasicEthPrograms", + ) + os.environ["TT_METAL_PROFILER_SYNC"] = "0" + run_gtest_profiler_test( + "./build/test/tt_metal/unit_tests_dispatch_trace", + "CommandQueueSingleCardTraceFixture.TensixTestSubDeviceTraceBasicPrograms", + ) diff --git a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp index 624226798ec..8c332185840 100644 --- a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp +++ b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp @@ -41,8 +41,7 @@ void validate_cb_address( for (const auto& [buffer_index, expected_address] : address_per_buffer_index) { auto base_index = UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG * buffer_index; - EXPECT_EQ( - expected_address >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, cb_config_vector.at(base_index)); + EXPECT_EQ(expected_address, cb_config_vector.at(base_index)); } } } @@ -358,9 +357,8 @@ TEST_F(DeviceFixture, TensixTestUpdateCircularBufferPageSize) { for (const auto& [buffer_index, expected_address] : address_per_buffer_index) { auto base_index = UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG * buffer_index; - EXPECT_EQ( - expected_address >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config_vector.at(base_index)); // address validation + EXPECT_EQ(expected_address, + cb_config_vector.at(base_index)); // address validation EXPECT_EQ( num_pages_per_buffer_index.at(buffer_index), cb_config_vector.at(base_index + 2)); // num pages validation @@ -391,9 +389,8 @@ TEST_F(DeviceFixture, TensixTestUpdateCircularBufferPageSize) { for (const auto& [buffer_index, expected_address] : address_per_buffer_index) { auto base_index = UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG * buffer_index; - EXPECT_EQ( - expected_address >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config_vector.at(base_index)); // address validation + EXPECT_EQ(expected_address, + cb_config_vector.at(base_index)); // address validation EXPECT_EQ( num_pages_per_buffer_index.at(buffer_index), cb_config_vector.at(base_index + 2)); // num pages validation diff --git a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp index b9d1c369973..1278c8abb7d 100644 --- a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp +++ b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp @@ -65,22 +65,10 @@ TEST_F(DeviceFixture, TensixTestCreateCircularBufferAtValidIndices) { uint32_t l1_unreserved_base = devices_.at(0)->get_base_allocator_addr(HalMemType::L1); std::map> golden_cb_config = { - {0, - {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.num_pages}}, - {2, - {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.num_pages}}, - {16, - {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.num_pages}}, - {24, - {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, - cb_config.num_pages}}}; + {0, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}}, + {2, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}}, + {16, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}}, + {24, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}}}; std::map data_format_spec = { {0, cb_config.data_format}, {2, cb_config.data_format}, diff --git a/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp b/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp index ff143198763..b51338d1d0f 100644 --- a/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp +++ b/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp @@ -8,7 +8,7 @@ #include "dispatch_fixture.hpp" #include "hostdevcommon/common_values.hpp" #include "impl/device/device.hpp" -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/test_utils/env_vars.hpp" diff --git a/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp b/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp index 7656ac8c147..57bfbcdb934 100644 --- a/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp +++ b/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp @@ -46,7 +46,7 @@ class DispatchFixture : public ::testing::Test { } void ReadBuffer( tt::tt_metal::Device* device, - std::shared_ptr out_buffer, + const std::shared_ptr& out_buffer, std::vector& dst_vec) { if (this->slow_dispatch_) { tt::tt_metal::detail::ReadFromBuffer(out_buffer, dst_vec); diff --git a/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp b/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp index 7e1626ba7fe..21cf4dc2943 100644 --- a/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp +++ b/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp @@ -8,7 +8,7 @@ #include "host_api.hpp" #include "dispatch_fixture.hpp" -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/test_utils/env_vars.hpp" #include "tt_metal/impl/device/device_pool.hpp" diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp index 356f7766820..5dd7eea0042 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp @@ -101,7 +101,7 @@ bool cb_config_successful(Device* device, Program& program, const DummyProgramMu tt::tt_metal::detail::ReadFromDeviceL1( device, core_coord, - program.get_sem_base_addr(device, core_coord, CoreType::WORKER), + program.get_cb_base_addr(device, core_coord, CoreType::WORKER), cb_config_buffer_size, cb_config_vector); @@ -110,8 +110,8 @@ bool cb_config_successful(Device* device, Program& program, const DummyProgramMu const uint32_t index = program_config.cb_config_vector[i].cb_id * sizeof(uint32_t); const uint32_t cb_num_pages = program_config.cb_config_vector[i].num_pages; const uint32_t cb_size = cb_num_pages * program_config.cb_config_vector[i].page_size; - const bool addr_match = cb_config_vector.at(index) == ((cb_addr) >> 4); - const bool size_match = cb_config_vector.at(index + 1) == (cb_size >> 4); + const bool addr_match = cb_config_vector.at(index) == cb_addr; + const bool size_match = cb_config_vector.at(index + 1) == cb_size; const bool num_pages_match = cb_config_vector.at(index + 2) == cb_num_pages; pass &= (addr_match and size_match and num_pages_match); @@ -860,15 +860,15 @@ TEST_F(CommandQueueSingleCardProgramFixture, TensixTestMultiCBSharedAddressSpace uint32_t cb_addr = device->get_base_allocator_addr(HalMemType::L1); uint32_t intermediate_index = intermediate_cb * sizeof(uint32_t); - bool addr_match_intermediate = cb_config_vector.at(intermediate_index) == ((cb_addr) >> 4); - bool size_match_intermediate = cb_config_vector.at(intermediate_index + 1) == (cb_size >> 4); + bool addr_match_intermediate = cb_config_vector.at(intermediate_index) == (cb_addr); + bool size_match_intermediate = cb_config_vector.at(intermediate_index + 1) == (cb_size); bool num_pages_match_intermediate = cb_config_vector.at(intermediate_index + 2) == num_tiles; bool pass_intermediate = (addr_match_intermediate and size_match_intermediate and num_pages_match_intermediate); EXPECT_TRUE(pass_intermediate); uint32_t out_index = out_cb * sizeof(uint32_t); - bool addr_match_out = cb_config_vector.at(out_index) == ((cb_addr) >> 4); - bool size_match_out = cb_config_vector.at(out_index + 1) == (cb_size >> 4); + bool addr_match_out = cb_config_vector.at(out_index) == cb_addr; + bool size_match_out = cb_config_vector.at(out_index + 1) == cb_size; bool num_pages_match_out = cb_config_vector.at(out_index + 2) == num_tiles; bool pass_out = (addr_match_out and size_match_out and num_pages_match_out); EXPECT_TRUE(pass_out); diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp index d54453532aa..d0c969bee1b 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp @@ -11,7 +11,7 @@ #include "impl/kernels/data_types.hpp" #include "impl/kernels/kernel_types.hpp" #include "impl/program/program.hpp" -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "program_with_kernel_created_from_string_fixture.hpp" TEST_F(ProgramWithKernelCreatedFromStringFixture, TensixDataMovementKernel) { diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp index 6016433f556..f140433f3a9 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp @@ -108,6 +108,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceBasicPrograms) { EnqueueProgram(device->command_queue(), incrementer_program, false); } Synchronize(device); + detail::DumpDeviceProfileResults(device); } } @@ -136,5 +137,6 @@ TEST_F(CommandQueueSingleCardFixture, TensixActiveEthTestSubDeviceBasicEthProgra EnqueueProgram(device->command_queue(), incrementer_program, false); } Synchronize(device); + detail::DumpDeviceProfileResults(device); } } diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp index 74fa8256ab8..5caff9052aa 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp @@ -63,6 +63,7 @@ TEST_F(CommandQueueSingleCardTraceFixture, TensixTestSubDeviceTraceBasicPrograms ReplayTrace(device, device->command_queue().id(), tid_2, false); } Synchronize(device); + detail::DumpDeviceProfileResults(device); } } diff --git a/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp b/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp index 17a4da1cd7b..75c4c3f5cca 100644 --- a/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp +++ b/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp @@ -9,7 +9,7 @@ #include "hostdevcommon/common_values.hpp" #include "impl/device/device.hpp" #include "llrt/hal.hpp" -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/test_utils/env_vars.hpp" diff --git a/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp b/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp index 55c9d3d40a4..fc87c2b58df 100644 --- a/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp +++ b/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp @@ -9,6 +9,7 @@ #include "llrt/hal.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/hw/inc/circular_buffer_constants.h" #include "tt_metal/impl/kernels/kernel.hpp" #include "tt_metal/common/tt_backend_api_types.hpp" #include "dispatch_test_utils.hpp" @@ -141,13 +142,14 @@ class RandomProgramFixture : virtual public CommandQueueSingleCardProgramFixture const uint32_t num_cbs = this->generate_random_num(min, max); std::vector cb_page_sizes; for (uint32_t cb_idx = 0; cb_idx < num_cbs; cb_idx++) { - const uint32_t cb_page_size = this->generate_random_num(MIN_CB_PAGE_SIZE, MAX_CB_PAGE_SIZE, 16); + const uint32_t cb_page_size = + this->generate_random_num(MIN_CB_PAGE_SIZE, MAX_CB_PAGE_SIZE, CIRCULAR_BUFFER_COMPUTE_WORD_SIZE); const uint32_t cb_total_size = this->generate_random_num(MIN_CB_TOTAL_SIZE, MAX_CB_TOTAL_SIZE, cb_page_size); CircularBufferConfig config = CircularBufferConfig(cb_total_size, {{cb_idx, tt::DataFormat::Float16_b}}) .set_page_size(cb_idx, cb_page_size); CreateCircularBuffer(program, cores, config); - cb_page_sizes.push_back(cb_page_size / 16); + cb_page_sizes.push_back(cb_page_size); } return cb_page_sizes; } diff --git a/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp b/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp index 1a2618ddc09..4c4077c4adc 100644 --- a/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp +++ b/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp @@ -11,7 +11,7 @@ #include #include "gtest/gtest.h" -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "command_queue_fixture.hpp" #include "tt_metal/common/logger.hpp" #include "impl/device/device.hpp" diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp index 12ad5aa9c30..ff9b279d489 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp @@ -17,7 +17,7 @@ void kernel_main() { volatile tt_l1_ptr uint32_t* done_address_ptr = reinterpret_cast(done_address); - uint64_t pcie_noc_xy_encoding = (uint64_t)NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX); + uint64_t pcie_noc_xy_encoding = (uint64_t)NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y); while (done_address_ptr[0] == 0) { uint64_t host_src_addr = pcie_noc_xy_encoding | pcie_read_ptr; noc_async_read(host_src_addr, done_address, read_sizeB); diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp index 600483ffe02..22127a8fb8a 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/impl/dispatch/command_queue.hpp" diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp index f5353a53a35..f351b0a75d7 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp @@ -9,7 +9,7 @@ #include #include -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "impl/device/device.hpp" #include "impl/kernels/kernel_types.hpp" #include "tt_backend_api_types.hpp" diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp index fbbdcfaf5c7..3123ea1736a 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp @@ -10,7 +10,7 @@ #include "tt_metal/distributed/mesh_device_view.hpp" #include "tt_metal/common/logger.hpp" -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "impl/device/device.hpp" #include "impl/kernels/data_types.hpp" #include "impl/kernels/kernel_types.hpp" diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp index 1cc7f4e6c7e..d2544f39271 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp @@ -10,7 +10,7 @@ #include #include -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "impl/device/device.hpp" #include "impl/kernels/kernel_types.hpp" #include "tt_backend_api_types.hpp" diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_workers_and_erisc_datamover_unidirectional.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_workers_and_erisc_datamover_unidirectional.cpp index c7e0da23d94..d8d4896badf 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_workers_and_erisc_datamover_unidirectional.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_workers_and_erisc_datamover_unidirectional.cpp @@ -8,7 +8,7 @@ #include #include -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "tt_backend_api_types.hpp" #include "tt_metal/common/core_coord.hpp" #include "tt_metal/common/math.hpp" diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp index 9867f4c7341..173519b7ca2 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp @@ -11,7 +11,7 @@ void kernel_main() { constexpr uint32_t base_pcie_dst_address = get_compile_time_arg_val(1); constexpr uint32_t num_16b_writes = get_compile_time_arg_val(2); - uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX)); + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)); uint32_t l1_src_address = base_l1_src_address; uint32_t pcie_dst_address = base_pcie_dst_address; diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/random_program.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/random_program.cpp index f3e64ae1b3e..3d5ee14d71b 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/random_program.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/random_program.cpp @@ -45,7 +45,7 @@ void kernel_main() { (uint32_t tt_l1_ptr*)(kernel_config_base + mailboxes->launch[mailboxes->launch_msg_rd_ptr].kernel_config.local_cb_offset); uint32_t cb_val = reinterpret_cast(cb_l1_base + i * 4)[3]; - uint32_t expected = ((i + 1) * page_size) >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + uint32_t expected = ((i + 1) * page_size); if (cb_val != expected) { DPRINT << "Problem with CB idx: " << i << " Expected: " << expected << " Got: " << cb_val << ENDL(); while (true); // Purposefully hang the kernel if CBs did not arrive correctly diff --git a/tests/ttnn/distributed/test_distributed.cpp b/tests/ttnn/distributed/test_distributed.cpp index 94ffe01bd3a..fb5f53988c5 100644 --- a/tests/ttnn/distributed/test_distributed.cpp +++ b/tests/ttnn/distributed/test_distributed.cpp @@ -26,4 +26,20 @@ TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) { EXPECT_GT(cols, 0); } +TEST_F(DistributedTest, TestMemoryAllocationStatistics) { + auto mesh = ttnn::distributed::open_mesh_device( + {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + auto stats = mesh->get_memory_allocation_statistics(tt::tt_metal::BufferType::DRAM); + for (auto* device : mesh->get_devices()) { + auto device_stats = device->get_memory_allocation_statistics(tt::tt_metal::BufferType::DRAM); + EXPECT_EQ(stats.total_allocatable_size_bytes, device_stats.total_allocatable_size_bytes); + } +} + +TEST_F(DistributedTest, TestNumDramChannels) { + auto mesh = ttnn::distributed::open_mesh_device( + {2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); + EXPECT_EQ(mesh->num_dram_channels(), 96); // 8 devices * 12 channels +} + } // namespace ttnn::distributed::test diff --git a/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py b/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py new file mode 100644 index 00000000000..65197936dbf --- /dev/null +++ b/tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py @@ -0,0 +1,361 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +import pytest +import transformers +from models.utility_functions import torch_random, is_grayskull +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.squeezebert.tt import ttnn_functional_squeezebert + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_attention(device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.SqueezeBertSelfAttention( + config, cin=config.hidden_size, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups + ).eval() + state_dict = model.state_dict() + model = model.to(torch_dtype) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_attention_mask = torch.ones(batch_size, sequence_size, dtype=torch_dtype) + torch_attention_mask = torch_attention_mask[:, None, None, :] + + torch_output = model(torch_hidden_states, attention_mask=torch_attention_mask, output_attentions=False) + + ttnn_attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device) + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_squeezebert.squeezebert_attention( + config, + hidden_states, + attention_mask=ttnn_attention_mask, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output["context_layer"], output, 0.99) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_intermediate(device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.ConvActivation( + cin=config.hidden_size, cout=config.intermediate_size, groups=config.intermediate_groups, act=config.hidden_act + ).eval() + state_dict = model.state_dict() + model = model.to(torch_dtype) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_output = model(torch_hidden_states) + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_squeezebert.squeezebert_intermediate( + config=config, + hidden_states=hidden_states, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=device, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.99) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_output(device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.ConvDropoutLayerNorm( + cin=config.intermediate_size, + cout=config.hidden_size, + groups=config.output_groups, + dropout_prob=config.hidden_dropout_prob, + ).eval() + state_dict = model.state_dict() + model = model.to(torch_dtype) + + torch_hidden_states = torch_random( + (batch_size, sequence_size, config.intermediate_size), -0.1, 0.1, dtype=torch_dtype + ) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_residual = torch_residual.permute(0, 2, 1) + + torch_output = model(torch_hidden_states, torch_residual) + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + residual = ttnn.from_torch(torch_residual, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_squeezebert.squeezebert_conv_layernorm( + config=config, + hidden_states=hidden_states, + input_tensor=residual, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=device, + cin=config.intermediate_size, + cout=config.hidden_size, + groups=config.output_groups, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.99) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_layer(device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.SqueezeBertModule(config).eval() + state_dict = model.state_dict() + model = model.to(torch_dtype) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_hidden_states = torch_hidden_states.permute(0, 2, 1) + + torch_attention_mask = torch.ones(batch_size, sequence_size, dtype=torch_dtype) + torch_attention_mask = torch_attention_mask[:, None, None, :] + + torch_output = model(torch_hidden_states, attention_mask=torch_attention_mask, output_attentions=False) + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + ttnn_attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_squeezebert.squeezebert_layer( + config, + hidden_states, + attention_mask=ttnn_attention_mask, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output["feature_map"], output, 0.99) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_squeezebert_encoder(device, model_name, batch_size, sequence_size, torch_dtype, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.models.squeezebert.modeling_squeezebert.SqueezeBertEncoder(config).eval() + state_dict = model.state_dict() + model = model.to(torch_dtype) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch_dtype) + torch_attention_mask = torch.ones(batch_size, sequence_size, dtype=torch_dtype) + torch_attention_mask = torch_attention_mask[:, None, None, :] + + torch_output = model(torch_hidden_states, attention_mask=torch_attention_mask).last_hidden_state + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + ttnn_attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_squeezebert.squeezebert_encoder( + config, + hidden_states, + attention_mask=ttnn_attention_mask, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.99) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +def test_squeezebert_model(device, model_name, batch_size, sequence_size, reset_seeds): + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + model = transformers.SqueezeBertModel.from_pretrained(model_name, config=config).eval() + state_dict = model.state_dict() + model = model.to(torch.bfloat16) + + torch_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.ones((batch_size, sequence_size), dtype=torch.int32) + torch_position_ids = torch.ones((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(1, sequence_size, dtype=torch.bfloat16) + + torch_output = model( + torch_input_ids, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + attention_mask=torch_attention_mask, + ).last_hidden_state + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + ) + + ttnn_bert_inputs = ttnn_functional_squeezebert.preprocess_inputs( + torch_input_ids, + torch_token_type_ids, + torch_position_ids, + torch_attention_mask, + device=device, + ) + + output = ttnn_functional_squeezebert.squeezebert( + config, + *ttnn_bert_inputs, + state_dict=state_dict, + base_addr=f"", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.98 if is_grayskull() else 0.99) + + +@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [384]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +def test_squeezebert_for_question_answering(device, model_name, batch_size, sequence_size, reset_seeds): + rf_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained(model_name) + config = transformers.SqueezeBertConfig.from_pretrained(model_name) + state_dict = rf_model.state_dict() + + torch_squeezebert_input = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32) + torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) + torch_attention_mask = torch.ones(batch_size, sequence_size) + + torch_output = rf_model( + input_ids=torch_squeezebert_input, + token_type_ids=torch_token_type_ids, + position_ids=torch_position_ids, + attention_mask=torch_attention_mask, + ) + + tt_model_name = f"ttnn_{model_name}_optimized" + + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: rf_model, + custom_preprocessor=ttnn_functional_squeezebert.custom_preprocessor, + device=device, + ) + + ttnn_squeezebert_inputs = ttnn_functional_squeezebert.preprocess_inputs( + torch_squeezebert_input, + torch_token_type_ids, + torch_position_ids, + torch_attention_mask, + device=device, + ) + + tt_output = ttnn_functional_squeezebert.squeezebert_for_question_answering( + config, + *ttnn_squeezebert_inputs, + state_dict=state_dict, + base_addr=f"transformer.", + parameters=parameters, + device=device, + reader_patterns_cache={}, + ) + + tt_output = ttnn.to_torch(tt_output) + + tt_start_logits = tt_output[..., :, 0] + tt_end_logits = tt_output[..., :, 1] + + assert_with_pcc(torch_output.start_logits, tt_start_logits, 0.84 if is_grayskull() else 0.88) + assert_with_pcc(torch_output.end_logits, tt_end_logits, 0.85 if is_grayskull() else 0.93) diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py b/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py deleted file mode 100644 index 1b45761e11c..00000000000 --- a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_attention.py +++ /dev/null @@ -1,966 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch -import math -import pytest -import ttnn - -from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import ( - comp_pcc, - tt2torch_tensor, - torch2tt_tensor, - is_wormhole_b0, - skip_for_grayskull, -) -from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_utility_functions import ( - determine_largest_subblock_size, - determine_blocking, -) - - -# Test matmul attention sequence with InterleavedToShardedPartialOp -@skip_for_grayskull() -@pytest.mark.parametrize("seq_len", [4096, 1024]) -@pytest.mark.parametrize("num_slices", [16]) -@pytest.mark.parametrize("num_cores", [64]) -@pytest.mark.parametrize("num_heads", [16]) -@pytest.mark.parametrize("data_format", [ttnn.bfloat8_b]) -def test_time_sharded_attnention_hwb( - device, - seq_len, - num_slices, - num_cores, - num_heads, - data_format, - function_level_defaults, -): - pytest.skip() - compute_grid_size = device.compute_with_storage_grid_size() - if num_cores > (compute_grid_size.x * compute_grid_size.y): - pytest.skip(f"Need {num_cores} cores to run this test but core grid is {compute_grid_size}") - grid_size = (8, 8) - - M = seq_len - K = 64 - N = seq_len - - query_layer_shape = [1, num_heads, seq_len, 64] - key_layer_transposed_shape = [1, num_heads, 64, seq_len] - value_layer_shape = [1, num_heads, seq_len, 64] - output_shape = [1, num_heads, seq_len, 64] - - torch_query_layer = torch.randn(query_layer_shape).bfloat16().float() - torch_key_layer_transposed = torch.randn(key_layer_transposed_shape).bfloat16().float() - torch_value_layer = torch.randn(value_layer_shape).bfloat16().float() - torch_output = torch.randn(output_shape).bfloat16().float() - - dram_interleaved_memory_config = ttnn.DRAM_MEMORY_CONFIG - - height_sharded_mem_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, buffer_type=ttnn.BufferType.L1 - ) - block_sharded_mem_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, - buffer_type=ttnn.BufferType.L1, - ) - - # compare output to regular case - reference_query_layer = torch2tt_tensor( - torch_query_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_key_layer_transposed = torch2tt_tensor( - torch_key_layer_transposed, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_value_layer = torch2tt_tensor( - torch_value_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - - attn_weights_qkt = torch_query_layer @ torch_key_layer_transposed - attn_weights_torch_sm = torch.nn.functional.softmax(attn_weights_qkt, dim=-1) - attn_weights_torch = attn_weights_torch_sm @ torch_value_layer - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - passing = True - output = None - - mm_out = torch2tt_tensor( - torch_output, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - - tiles_per_shard = math.ceil((((num_heads * seq_len) / num_cores) / num_slices) / 32) - mm_output_block_shard_spec = [seq_len // 8, seq_len // 8] - tiles_per_shard = math.ceil((((num_heads * seq_len) / num_cores) / num_slices) / 32) - mm_output_height_shard_spec = [tiles_per_shard * 32, seq_len] - - heads_per_slice = num_heads // num_slices - for i in range(num_slices): - q_slice = ttnn.interleaved_to_sharded_partial( - reference_query_layer, - ttnn.CoreCoord(1, grid_size[0]), - [M // grid_size[0], K], - num_slices, - i, - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - k_slice = ttnn.interleaved_to_sharded_partial( - reference_key_layer_transposed, - ttnn.CoreCoord(grid_size[1], 1), - [K, N // grid_size[1]], - num_slices, - i, - ttnn.TensorMemoryLayout.WIDTH_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - - program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=K // 32, - out_subblock_h=1, - out_subblock_w=1, - per_core_M=M // (32 * grid_size[0]), - per_core_N=N // (32 * grid_size[1]), - transpose_mcast=False, - fused_activation=None, - ) - - mm_slice = ttnn.matmul( - q_slice, - k_slice, - program_config=program_config, - memory_config=block_sharded_mem_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - # mmt = tt2torch_tensor(mm_slice) - # passed, message = comp_pcc(mmt, attn_weights_qkt[:, i * heads_per_slice : (i + 1) * heads_per_slice, :, :]) - # print(message) - # assert passed - k_slice.deallocate() - q_slice.deallocate() - - height_per_core = seq_len // 64 - output_shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}) - output_shard_spec = ttnn.ShardSpec( - output_shard_grid, [height_per_core, seq_len], ttnn.ShardOrientation.ROW_MAJOR, False - ) - output_mem_config = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, output_shard_spec - ) - mm_slice = ttnn.reshard( - mm_slice, - output_mem_config, - ) - mm_slice = ttnn.move(mm_slice) - - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=grid_size, - subblock_w=1, - block_h=mm_output_height_shard_spec[0] // 32, - block_w=mm_output_height_shard_spec[1] // 32, - ) - # print(program_config) - - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - # mmt = tt2torch_tensor(mm_slice) - # passed, message = comp_pcc(mmt, attn_weights_torch_sm[:, i * heads_per_slice : (i + 1) * heads_per_slice, :, :]) - # print(message) - # assert passed - - program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=seq_len // 32, - per_core_M=tiles_per_shard, - per_core_N=2, - out_subblock_h=1, - out_subblock_w=1, - fuse_batch=True, - fused_activation=None, - mcast_in0=False, - ) - v_slice = ttnn.slice( - reference_value_layer, - (0, (i * heads_per_slice), 0, 0), - (1, (i * heads_per_slice) + (heads_per_slice), seq_len, 64), - memory_config=dram_interleaved_memory_config, - ) - - mm_slice = ttnn.matmul( - mm_slice, - v_slice, - program_config=program_config, - memory_config=height_sharded_mem_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - v_slice.deallocate() - - ttnn.sharded_to_interleaved_partial( - mm_slice, - mm_out, - num_slices, - i, - memory_config=dram_interleaved_memory_config, - ) - - mm_slice.deallocate() - - mm_out_torch = tt2torch_tensor(mm_out) - - passing, output = comp_pcc(mm_out_torch, attn_weights_torch) - - print(output) - assert passing - - -# Test matmul attention sequence with InterleavedToShardedPartialOp -@skip_for_grayskull() -@pytest.mark.parametrize("seq_len", [4096, 1024]) -@pytest.mark.parametrize("num_slices", [16]) -@pytest.mark.parametrize("num_cores", [64]) -@pytest.mark.parametrize("num_heads", [16]) -@pytest.mark.parametrize("data_format", [ttnn.bfloat8_b]) -def test_time_sharded_attnention( - device, - seq_len, - num_slices, - num_cores, - num_heads, - data_format, - function_level_defaults, -): - pytest.skip() # ND hang on CI - compute_grid_size = device.compute_with_storage_grid_size() - if num_cores > (compute_grid_size.x * compute_grid_size.y): - pytest.skip(f"Need {num_cores} cores to run this test but core grid is {compute_grid_size}") - grid_size = (8, 8) - - query_layer_shape = [1, num_heads, seq_len, 64] - key_layer_transposed_shape = [1, num_heads, 64, seq_len] - value_layer_shape = [1, num_heads, seq_len, 64] - output_shape = [1, num_heads, seq_len, 64] - - torch_query_layer = torch.randn(query_layer_shape).bfloat16().float() - torch_key_layer_transposed = torch.randn(key_layer_transposed_shape).bfloat16().float() - torch_value_layer = torch.randn(value_layer_shape).bfloat16().float() - torch_output = torch.randn(output_shape).bfloat16().float() - - dram_interleaved_memory_config = ttnn.DRAM_MEMORY_CONFIG - l1_interleaved_memory_config = ttnn.L1_MEMORY_CONFIG - - height_sharded_memory_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, buffer_type=ttnn.BufferType.L1 - ) - - # compare output to regular case - reference_query_layer = torch2tt_tensor( - torch_query_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_key_layer_transposed = torch2tt_tensor( - torch_key_layer_transposed, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_value_layer = torch2tt_tensor( - torch_value_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=True, - ) - - passing = True - output = None - - mm_out = torch2tt_tensor( - torch_output, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - tiles_per_shard = math.ceil((((num_heads * seq_len) / num_cores) / num_slices) / 32) - mm_activations_height_shard_spec = [tiles_per_shard * 32, 2 * 32] - mm_output_height_shard_spec = [tiles_per_shard * 32, seq_len] - - heads_per_slice = num_heads // num_slices - for i in range(num_slices): - slice = ttnn.interleaved_to_sharded_partial( - reference_query_layer, - grid_size, - mm_activations_height_shard_spec, - num_slices, - i, - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=2, - per_core_M=tiles_per_shard, - per_core_N=seq_len // 32, - out_subblock_h=1, - out_subblock_w=1, - fuse_batch=True, - fused_activation=None, - mcast_in0=False, - ) - - k_slice = ttnn.slice( - reference_key_layer_transposed, - (0, (i * heads_per_slice), 0, 0), - (1, (i * heads_per_slice) + (heads_per_slice), 64, seq_len), - memory_config=l1_interleaved_memory_config, - ) - mm_slice = ttnn.matmul( - slice, - k_slice, - program_config=program_config, - memory_config=height_sharded_memory_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - k_slice.deallocate() - slice.deallocate() - - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=grid_size, - subblock_w=1, - block_h=mm_output_height_shard_spec[0] // 32, - block_w=mm_output_height_shard_spec[1] // 32, - ) - - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - - program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=seq_len // 32, - per_core_M=tiles_per_shard, - per_core_N=2, - out_subblock_h=1, - out_subblock_w=1, - fuse_batch=True, - fused_activation=None, - mcast_in0=False, - ) - v_slice = ttnn.slice( - reference_value_layer, - (0, (i * heads_per_slice), 0, 0), - (1, (i * heads_per_slice) + (heads_per_slice), seq_len, 64), - memory_config=l1_interleaved_memory_config, - ) - mm_slice = ttnn.matmul( - mm_slice, - v_slice, - program_config=program_config, - memory_config=height_sharded_memory_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - v_slice.deallocate() - - ttnn.sharded_to_interleaved_partial( - mm_slice, - mm_out, - num_slices, - i, - memory_config=dram_interleaved_memory_config, - ) - - mm_slice.deallocate() - - return - - mm_out_torch = tt2torch_tensor(mm_out) - - attn_weights = ttnn.matmul( - reference_query_layer, reference_key_layer_transposed, memory_config=dram_interleaved_memory_config - ) - attn_weights = ttnn.softmax_in_place(attn_weights) - attn_weights = ttnn.matmul(attn_weights, reference_value_layer, memory_config=dram_interleaved_memory_config) - - attn_weights_torch = tt2torch_tensor(attn_weights) - passing, output = comp_pcc(mm_out_torch, attn_weights_torch) - - print(output) - assert passing - - -# Test matmul attention sequence with InterleavedToShardedPartialOp -@skip_for_grayskull() -@pytest.mark.parametrize("seq_len", [4096, 1024, 256, 64]) -@pytest.mark.parametrize("kv_len", [96]) -@pytest.mark.parametrize("num_heads", [16]) -@pytest.mark.parametrize("data_format", [ttnn.bfloat8_b]) -@pytest.mark.parametrize("reshard_for_softmax", [True, False]) -def test_cross_attnention( - device, - seq_len, - kv_len, - num_heads, - data_format, - reshard_for_softmax, - function_level_defaults, -): - if seq_len == 64 and reshard_for_softmax: - pytest.skip() - compute_grid_size = device.compute_with_storage_grid_size() - grid_size = (8, 2) - num_cores = grid_size[0] * grid_size[1] - if num_cores > (compute_grid_size.x * compute_grid_size.y): - pytest.skip(f"Need {num_cores} cores to run this test but core grid is {compute_grid_size}") - - query_layer_shape = [1, num_heads, seq_len, 64] - key_layer_transposed_shape = [1, num_heads, 64, kv_len] - value_layer_shape = [1, num_heads, kv_len, 64] - output_shape = [1, num_heads, seq_len, 64] - - torch_query_layer = torch.randn(query_layer_shape).bfloat16().float() - torch_key_layer_transposed = torch.randn(key_layer_transposed_shape).bfloat16().float() - torch_value_layer = torch.randn(value_layer_shape).bfloat16().float() - torch_output = torch.randn(output_shape).bfloat16().float() - - dram_interleaved_memory_config = ttnn.DRAM_MEMORY_CONFIG - l1_interleaved_memory_config = ttnn.L1_MEMORY_CONFIG - - height_sharded_memory_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, buffer_type=ttnn.BufferType.L1 - ) - - # compare output to regular case - reference_query_layer = torch2tt_tensor( - torch_query_layer, - device, - tt_memory_config=l1_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_key_layer_transposed = torch2tt_tensor( - torch_key_layer_transposed, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_value_layer = torch2tt_tensor( - torch_value_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - passing = True - output = None - - q_sharded = ttnn.interleaved_to_sharded( - reference_query_layer, - grid_size, - [num_heads * seq_len // num_cores, 64], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.COL_MAJOR, - ) - - program_config = ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=2, - out_subblock_h=1, - out_subblock_w=1, - per_core_M=num_heads * seq_len // num_cores // 32, - per_core_N=kv_len // 32, - ) - print(program_config) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - mm_slice = ttnn.matmul( - q_sharded, - reference_key_layer_transposed, - program_config=program_config, - memory_config=height_sharded_memory_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - q_sharded.deallocate() - - if reshard_for_softmax: - height_per_core = num_heads * seq_len // 64 - orig_mem_config = mm_slice.memory_config() - if seq_len == 1024: - mm_slice = ttnn.sharded_to_interleaved(mm_slice, dram_interleaved_memory_config) - mm_slice = ttnn.interleaved_to_sharded( - mm_slice, - (8, 8), - [height_per_core, kv_len], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.COL_MAJOR, - ) - else: - output_shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}) - output_shard_spec = ttnn.ShardSpec( - output_shard_grid, [height_per_core, kv_len], ttnn.ShardOrientation.COL_MAJOR, False - ) - output_mem_config = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, output_shard_spec - ) - mm_slice = ttnn.reshard( - mm_slice, - output_mem_config, - ) - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=(8, 8), - subblock_w=1, - block_h=32, - block_w=3, - ) - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - mm_slice = ttnn.reshard(mm_slice, orig_mem_config) - - else: - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=grid_size, - subblock_w=1, - block_h=seq_len // 32, - block_w=kv_len // 32, - ) - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - - v_sharded = ttnn.interleaved_to_sharded( - reference_value_layer, - grid_size, - [num_heads * kv_len // num_cores, 64], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.COL_MAJOR, - ) - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - program_config = ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=kv_len // 32, - out_subblock_h=1, - out_subblock_w=1, - per_core_M=num_heads * seq_len // num_cores // 32, - per_core_N=2, - ) - mm_slice = ttnn.matmul( - mm_slice, - v_sharded, - program_config=program_config, - memory_config=height_sharded_memory_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - v_sharded.deallocate() - - mm_out_torch = tt2torch_tensor(mm_slice) - - attn_weights_torch = torch_query_layer @ torch_key_layer_transposed - attn_weights_torch = torch.nn.functional.softmax(attn_weights_torch, dim=-1) - attn_weights_torch = attn_weights_torch @ torch_value_layer - - passing, output = comp_pcc(mm_out_torch, attn_weights_torch) - - print(output) - assert passing - - -# Test matmul attention sequence with InterleavedToShardedPartialOp -@skip_for_grayskull() -@pytest.mark.parametrize("seq_len", [1024, 256, 64]) -@pytest.mark.parametrize("num_heads", [16]) -@pytest.mark.parametrize("data_format", [ttnn.bfloat8_b]) -@pytest.mark.parametrize("reshard_for_softmax", [True, False]) -def test_attention( - device, - seq_len, - num_heads, - data_format, - reshard_for_softmax, - function_level_defaults, -): - if (seq_len == 64 or seq_len == 1024) and reshard_for_softmax: - pytest.skip() - compute_grid_size = device.compute_with_storage_grid_size() - grid_size = (2, 8) - num_cores = grid_size[0] * grid_size[1] - if num_cores > (compute_grid_size.x * compute_grid_size.y): - pytest.skip(f"Need {num_cores} cores to run this test but core grid is {compute_grid_size}") - - query_layer_shape = [1, num_heads, seq_len, 64] - key_layer_transposed_shape = [1, num_heads, 64, seq_len] - value_layer_shape = [1, num_heads, seq_len, 64] - output_shape = [1, num_heads, seq_len, 64] - - torch_query_layer = torch.randn(query_layer_shape).bfloat16().float() - torch_key_layer_transposed = torch.randn(key_layer_transposed_shape).bfloat16().float() - torch_value_layer = torch.randn(value_layer_shape).bfloat16().float() - torch_output = torch.randn(output_shape).bfloat16().float() - - dram_interleaved_memory_config = ttnn.DRAM_MEMORY_CONFIG - l1_interleaved_memory_config = ttnn.L1_MEMORY_CONFIG - - height_sharded_memory_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, buffer_type=ttnn.BufferType.L1 - ) - - # compare output to regular case - reference_query_layer = torch2tt_tensor( - torch_query_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_key_layer_transposed = torch2tt_tensor( - torch_key_layer_transposed, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - reference_value_layer = torch2tt_tensor( - torch_value_layer, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - passing = True - output = None - - q_sharded = ttnn.interleaved_to_sharded( - reference_query_layer, - grid_size, - [num_heads * seq_len // num_cores, 64], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - M = num_heads * seq_len - K = 64 - N = seq_len - program_config = ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=K // 32, - out_subblock_h=1, - out_subblock_w=1, - per_core_M=M // num_cores // 32, - per_core_N=N // 32, - ) - print(program_config) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - mm_slice = ttnn.matmul( - q_sharded, - reference_key_layer_transposed, - program_config=program_config, - memory_config=height_sharded_memory_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - q_sharded.deallocate() - - if reshard_for_softmax: - height_per_core = num_heads * seq_len // 64 - orig_mem_config = mm_slice.memory_config() - if seq_len == 1024: - mm_slice = ttnn.sharded_to_interleaved(mm_slice, l1_interleaved_memory_config) - mm_slice = ttnn.interleaved_to_sharded( - mm_slice, - (8, 8), - [height_per_core, seq_len], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=(8, 8), - subblock_w=1, - block_h=height_per_core // 32, - block_w=seq_len // 32, - ) - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - mm_slice = ttnn.sharded_to_interleaved(mm_slice, l1_interleaved_memory_config) - mm_slice = ttnn.interleaved_to_sharded( - mm_slice, - (8, 2), - [num_heads * seq_len // 16, seq_len], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.COL_MAJOR, - ) - - else: - output_shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}) - output_shard_spec = ttnn.ShardSpec( - output_shard_grid, [height_per_core, seq_len], ttnn.ShardOrientation.COL_MAJOR, False - ) - output_mem_config = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, output_shard_spec - ) - mm_slice = ttnn.reshard( - mm_slice, - output_mem_config, - ) - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=(8, 8), - subblock_w=1, - block_h=height_per_core // 32, - block_w=seq_len // 32, - ) - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - mm_slice = ttnn.reshard(mm_slice, orig_mem_config) - else: - softmax_program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig( - compute_with_storage_grid_size=grid_size, - subblock_w=1, - block_h=seq_len // 32, - block_w=seq_len // 32, - ) - print(softmax_program_config) - mm_slice = ttnn.softmax_in_place(mm_slice, program_config=softmax_program_config) - - v_sharded = ttnn.interleaved_to_sharded( - reference_value_layer, - grid_size, - [num_heads * seq_len // num_cores, 64], - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - program_config = ttnn.MatmulMultiCoreReuseProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=seq_len // 32, - out_subblock_h=1, - out_subblock_w=1, - per_core_M=num_heads * seq_len // num_cores // 32, - per_core_N=2, - ) - print(program_config) - mm_slice = ttnn.matmul( - mm_slice, - v_sharded, - program_config=program_config, - memory_config=height_sharded_memory_config, - dtype=data_format, - compute_kernel_config=compute_kernel_config, - ) - v_sharded.deallocate() - - mm_out_torch = tt2torch_tensor(mm_slice) - - attn_weights_torch = torch_query_layer @ torch_key_layer_transposed - attn_weights_torch = torch.nn.functional.softmax(attn_weights_torch, dim=-1) - attn_weights_torch = attn_weights_torch @ torch_value_layer - - passing, output = comp_pcc(mm_out_torch, attn_weights_torch) - - print(output) - assert passing - - -@skip_for_grayskull() -@pytest.mark.parametrize("size", [4096, 1024, 256, 64]) -@pytest.mark.parametrize("is_qkv", [1, 2, 3]) -@pytest.mark.parametrize("data_format", [ttnn.bfloat8_b]) -def test_q_and_kv( - device, - size, - data_format, - is_qkv, - function_level_defaults, -): - # Test matmul attention sequence with InterleavedToShardedPartialOp - sizes = {4096: [1, 8192, 320, 512], 1024: [1, 2048, 640, 768], 256: [1, 512, 1280, 1280], 64: [1, 128, 1280, 1280]} - grid_sizes = {4096: (5, 8), 1024: (5, 8), 256: (8, 8), 64: (8, 4)} - B, M, K, N = sizes[size] - N = N * is_qkv - grid_size = grid_sizes[size] - compute_grid_size = device.compute_with_storage_grid_size() - num_cores = grid_size[0] * grid_size[1] - if num_cores > (compute_grid_size.x * compute_grid_size.y): - pytest.skip(f"Need {num_cores} cores to run this test but core grid is {compute_grid_size}") - - in_0_shape = [1, B, M, K] - in_1_shape = [1, B, K, N] - in_2_shape = [1, B, 192, K] - in_3_shape = [1, B, K, 2 * N] - - in_0_torch = torch.randn(in_0_shape).bfloat16().float() - in_1_torch = torch.randn(in_1_shape).bfloat16().float() - in_2_torch = torch.randn(in_2_shape).bfloat16().float() - in_3_torch = torch.randn(in_3_shape).bfloat16().float() - - dram_interleaved_memory_config = ttnn.DRAM_MEMORY_CONFIG - l1_interleaved_memory_config = ttnn.L1_MEMORY_CONFIG - - height_sharded_memory_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, buffer_type=ttnn.BufferType.L1 - ) - - block_sharded_memory_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, buffer_type=ttnn.BufferType.L1 - ) - - # compare output to regular case - in_0 = torch2tt_tensor( - in_0_torch, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - in_1 = torch2tt_tensor( - in_1_torch, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - in_2 = torch2tt_tensor( - in_2_torch, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - in_3 = torch2tt_tensor( - in_3_torch, - device, - tt_memory_config=dram_interleaved_memory_config, - tt_dtype=data_format, - ) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - passing = True - output = None - - in_0_sharded = ttnn.interleaved_to_sharded( - in_0, - grid_size, - [M // grid_size[1], K // grid_size[0]], - ttnn.TensorMemoryLayout.BLOCK_SHARDED, - ttnn.ShardOrientation.ROW_MAJOR, - ) - M, K = in_0.shape[-2], in_0.shape[-1] - N = in_1.shape[-1] - in0_block_h, in0_block_w, out_subblock_h, out_subblock_w, out_block_h, out_block_w = determine_blocking( - M, K, N, grid_size - ) - program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=in0_block_w, - out_subblock_h=out_subblock_h, - out_subblock_w=out_subblock_w, - per_core_M=out_block_h, - per_core_N=out_block_w, - transpose_mcast=False, - fused_activation=None, - ) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - mm = ttnn.matmul( - in_0_sharded if size != 4096 else in_0, - in_1, - program_config=program_config, - memory_config=block_sharded_memory_config, - dtype=ttnn.bfloat8_b, - compute_kernel_config=compute_kernel_config, - ) - in_0_sharded.deallocate() - - M, K, N = in_2.shape[-2], in_2.shape[-1], in_3.shape[-1] - in0_block_h = M // grid_size[1] // 32 - in0_block_w = K // grid_size[0] // 32 - out_block_h = math.ceil(M / grid_size[1] / 32) - out_block_w = math.ceil(N / grid_size[0] / 32) - out_subblock_h, out_subblock_w = determine_largest_subblock_size(out_block_h, out_block_w) - program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=in0_block_w, - out_subblock_h=out_subblock_h, - out_subblock_w=out_subblock_w, - per_core_M=out_block_h, - per_core_N=out_block_w, - transpose_mcast=False, - fused_activation=None, - ) - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.LoFi, - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=False, - ) - - mm_out_torch = tt2torch_tensor(mm) - - out_torch = in_0_torch @ in_1_torch - - passing, output = comp_pcc(mm_out_torch, out_torch) - - print(output) - assert passing diff --git a/tests/ttnn/sweep_tests/sweeps/sweeps/upsample.py b/tests/ttnn/sweep_tests/sweeps/sweeps/upsample.py index 0b752e7cef2..88ae18ccd9c 100644 --- a/tests/ttnn/sweep_tests/sweeps/sweeps/upsample.py +++ b/tests/ttnn/sweep_tests/sweeps/sweeps/upsample.py @@ -37,8 +37,7 @@ def run( torch_result = m(tt_input) torch_result = torch_result.permute(0, 2, 3, 1) - ## ttnn uses NHWC, so need to set scale_factor_c = 1 - scale_factor = (scale_h, scale_w, 1) + scale_factor = (scale_h, scale_w) input_tensor = ttnn.from_torch(input, device=device) output_tensor = ttnn.upsample(input_tensor, scale_factor) output_tensor = ttnn.to_torch(output_tensor) diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp index d16bb46ba2e..bcef0d0e182 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp @@ -10,7 +10,7 @@ #include "gtest/gtest.h" -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" // #include "tt_backend_api_types.hpp" #include "tt_metal/common/core_coord.hpp" #include "tt_metal/common/math.hpp" diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index 53c94f0702d..db2910cb801 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -8,7 +8,7 @@ #include #include -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "gtest/gtest.h" // #include "tt_backend_api_types.hpp" #include "tt_metal/common/core_coord.hpp" diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp index 585326afc8b..f4279cc8753 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp @@ -170,6 +170,25 @@ TEST_P(MultiDeviceTensorCreationTest, FullLikeWithOptTensor) { EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); } +TEST_P(MultiDeviceTensorCreationTest, Arange) { + MeshDevice* mesh_device = this->mesh_device_.get(); + mesh_device->enable_async(GetParam()); + + Tensor tensor = ttnn::arange( + /*start=*/0, + /*end=*/1024, + /*step=*/1, + ttnn::DataType::BFLOAT16, + std::ref(*mesh_device)); + + EXPECT_EQ(tensor.storage_type(), StorageType::MULTI_DEVICE); + EXPECT_EQ(tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_EQ(tensor.shape(), ttnn::SimpleShape({1, 1, 1, 1024})); + + const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(tensor); + EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); +} + INSTANTIATE_TEST_SUITE_P(AllTests, MultiDeviceTensorCreationTest, ::testing::Bool()); } // namespace diff --git a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp index 7e1ab23115e..b5495a324db 100644 --- a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp +++ b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/cpp/ttnn/operations/creation.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/layout/tensor_layout.hpp" #include "ttnn_multi_command_queue_fixture.hpp" @@ -10,14 +11,13 @@ #include "ttnn/operations/moreh/moreh_sum/moreh_sum.hpp" #include "common/bfloat16.hpp" #include "ttnn/async_runtime.hpp" -#include "ttnn/operations/numpy/functions.hpp" #include "tt_metal/impl/event/event.hpp" #include -using namespace tt; -using namespace tt_metal; -using MultiCommandQueueSingleDeviceFixture = ttnn::MultiCommandQueueSingleDeviceFixture; -using namespace constants; +namespace tt::tt_metal { +namespace { + +using MultiCommandQueueSingleDeviceFixture = ::ttnn::MultiCommandQueueSingleDeviceFixture; TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { Device* device = this->device_; @@ -40,16 +40,14 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { host_data[i] = bfloat16(static_cast(1)); } // Create golden data using tt_eager APIs - Tensor np_tensor = ttnn::numpy::full(input_shape.value, static_cast(1), DataType::BFLOAT16) - .to(Layout::TILE) - .to(device); + Tensor np_tensor = ttnn::full(input_shape, static_cast(1), DataType::BFLOAT16, Layout::TILE, *device_); ttnn::SmallVector reduce_dims = {3}; Tensor np_out = ttnn::moreh_sum(np_tensor, reduce_dims, false, std::nullopt, std::nullopt, std::nullopt); Tensor np_out_host = np_out.cpu(); const bfloat16* golden_output = std::get>(std::get(np_out_host.get_storage()).buffer).begin(); // Enable Asynchronous Execution and test ttnn runtime APIs - device->enable_async(true); + device_->enable_async(true); // Events for host - device synchronization auto write_event = std::make_shared(); auto workload_event = std::make_shared(); @@ -63,9 +61,9 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { output_buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(np_out.get_padded_shape())); auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device( - device, TensorSpec(input_shape.padded_shape(), tensor_layout)); + device_, TensorSpec(input_shape.padded_shape(), tensor_layout)); auto output_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device( - device, TensorSpec(np_out.get_padded_shape(), tensor_layout)); + device_, TensorSpec(np_out.get_padded_shape(), tensor_layout)); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; auto output_storage = tt::tt_metal::DeviceStorage{output_buffer}; Tensor input_tensor = Tensor(input_storage, input_shape, DataType::BFLOAT16, Layout::TILE); @@ -73,13 +71,13 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { // Populate input_tensor with data ttnn::write_buffer(io_cq, input_tensor, {host_data}); // Record the completion of the write event - ttnn::record_event(device->command_queue(io_cq), write_event); + ttnn::record_event(device_->command_queue(io_cq), write_event); // Host stalls until write is completed, before sending workload ttnn::event_synchronize(write_event); // Dispatch workload. Preallocated output_tensor is populated by op/ ttnn::moreh_sum(input_tensor, /*dim*/ 3, false, output_tensor, std::nullopt, std::nullopt); // Record completion of workload - ttnn::record_event(device->command_queue(workload_dispatch_cq), workload_event); + ttnn::record_event(device_->command_queue(workload_dispatch_cq), workload_event); ttnn::event_synchronize(workload_event); // Read output back, once workload is complete ttnn::read_buffer(io_cq, output_tensor, {readback_data}); @@ -93,7 +91,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { // Deallocate tensors (tensor gives up buffer). Done asynchronously, so sync on queue after. input_tensor.deallocate(); output_tensor.deallocate(); - ttnn::queue_synchronize(device->command_queue(io_cq)); + ttnn::queue_synchronize(device_->command_queue(io_cq)); // Buffer only has 2 owners in main thread. EXPECT_EQ(input_buffer.use_count(), 2); EXPECT_EQ(output_buffer.use_count(), 2); @@ -103,8 +101,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { } TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) { - Device* device = this->device_; - device->enable_async(true); + device_->enable_async(true); MemoryConfig mem_cfg = MemoryConfig{ .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::DRAM, @@ -131,26 +128,26 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) { TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape)); auto input_buffer = - tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(shape, tensor_layout)); + tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(shape, tensor_layout)); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE); - ttnn::write_buffer(io_cq, input_tensor, {host_data}); // Write using cq 1 - ttnn::record_event(device->command_queue(io_cq), write_event); // Record write on cq 1 + ttnn::write_buffer(io_cq, input_tensor, {host_data}); // Write using cq 1 + ttnn::record_event(device_->command_queue(io_cq), write_event); // Record write on cq 1 // Wait until cq 1 write is complete - ttnn::wait_for_event(device->command_queue(workload_dispatch_cq), write_event); + ttnn::wait_for_event(device_->command_queue(workload_dispatch_cq), write_event); // Run operation on cq 0 Tensor output_tensor = ttnn::sqrt(workload_dispatch_cq, input_tensor); auto dummy_buffer_0 = - tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(shape, tensor_layout)); + tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(shape, tensor_layout)); output_tensor = ttnn::neg(workload_dispatch_cq, output_tensor); // Allocate this buffer to stress test async allocation across op execution and explicit allocation auto dummy_buffer_1 = - tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, TensorSpec(shape, tensor_layout)); + tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(shape, tensor_layout)); // Record cq 0 prog execution - ttnn::record_event(device->command_queue(workload_dispatch_cq), workload_event); + ttnn::record_event(device_->command_queue(workload_dispatch_cq), workload_event); // Wait until cq 0 prog execution is done - ttnn::wait_for_event(device->command_queue(io_cq), workload_event); + ttnn::wait_for_event(device_->command_queue(io_cq), workload_event); // Read using cq 1 ttnn::read_buffer(io_cq, output_tensor, {readback_data}); for (int i = 0; i < buf_size_datums; i++) { @@ -166,8 +163,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeBufferDestructor) { // Test functionality for the buffer destructor, which will call deallocate asynchronously // We must ensure that the deallocate step, which can run after the buffer has been destroyed // does not rely on stale buffer state, after the buffer has been destroyed on host - Device* device = this->device_; - device->enable_async(true); + device_->enable_async(true); MemoryConfig mem_cfg = MemoryConfig{ .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::DRAM, @@ -182,9 +178,9 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeBufferDestructor) { TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg); TensorSpec tensor_spec(shape, tensor_layout); for (int loop = 0; loop < 100000; loop++) { - { - auto input_buffer_dummy = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec); - device->synchronize(); - } + auto input_buffer_dummy = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, tensor_spec); + device_->synchronize(); } } +} // namespace +} // namespace tt::tt_metal diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 6ab86fc8c8e..abecf9445ce 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -216,7 +216,7 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { auto view = ttnn::MeshDeviceView(*mesh); std::vector ring_devices = view.get_devices_on_row(0); // Tunnel 0 std::vector ring_devices_1 = - view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks + view.get_devices_on_column(mesh_shape.num_cols - 1); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); std::vector ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py index 41c1d80edeb..f46b1b4f76a 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_t3000_frequent.py @@ -244,3 +244,69 @@ def test_ring_all_reduce_post_commit( num_iters=num_iters, enable_async=enable_async, ) + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + "num_devices, num_links", + [ + (2, 1), + ], +) +@pytest.mark.parametrize( + "per_chip_output_shape", + [ + ([2, 2, 64, 64]), + ([1, 1, 64, 64]), + ], +) +@pytest.mark.parametrize( + "layout", + [ + ttnn.TILE_LAYOUT, + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + ], +) +@pytest.mark.parametrize( + "mem_config", + [ + ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), + ], +) +@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum]) +@pytest.mark.parametrize("enable_async", [True]) +def test_ring_all_reduce_post_commit_2chip( + pcie_mesh_device, + num_devices, + per_chip_output_shape, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + enable_async, + num_iters=2, +): + run_all_reduce_test( + pcie_mesh_device, + num_devices, + per_chip_output_shape, + num_links, + math_op, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + num_iters=num_iters, + enable_async=enable_async, + topology=ttnn.Topology.Linear, + ) diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py index 2bb028ee6cf..da305202e9c 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py @@ -428,3 +428,22 @@ def run_unary_test_bitwise_not(device, h, w, fill_value, ttnn_function, pcc=0.99 @pytest.mark.parametrize("fill_value", [-2147483647, 2147483648, 7534, 225, 97, 3]) def test_bitwise_not(device, h, w, fill_value): run_unary_test_bitwise_not(device, h, w, fill_value, ttnn.bitwise_not) + + +@skip_for_grayskull() +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_unary_floor(input_shapes, device): + in_data1 = torch.empty(input_shapes, dtype=torch.float32).uniform_(-43566, 43565) + input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.floor(input_tensor1) + golden_function = ttnn.get_golden_function(ttnn.floor) + golden_tensor = golden_function(in_data1) + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(golden_tensor, output_tensor, 0.999) diff --git a/tests/ttnn/unit_tests/operations/test_conv1d.py b/tests/ttnn/unit_tests/operations/test_conv1d.py index 3e7a1496c63..7013ef6b2db 100644 --- a/tests/ttnn/unit_tests/operations/test_conv1d.py +++ b/tests/ttnn/unit_tests/operations/test_conv1d.py @@ -88,12 +88,15 @@ def run_conv( conv_config = ttnn.Conv1dConfig( dtype=output_dtype, weights_dtype=weights_dtype, - math_fidelity=math_fidelity, shard_layout=shard_layout, input_channels_alignment=(16 if use_shallow_conv_variant else 32), deallocate_activation=deallocate_activation, - fp32_dest_acc_enabled=fp32_accum, - packer_l1_accum_enabled=packer_l1_acc, + ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=math_fidelity, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, ) if config_override and "act_block_h" in config_override: conv_config.act_block_h_override = config_override["act_block_h"] @@ -104,7 +107,7 @@ def run_conv( conv_config.override_sharding_config = True print("Setting num_cores_nhw to 98") - [tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d( + [tt_output_tensor_on_device, out_length, [weights_device, bias_device]] = ttnn.Conv1d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -117,9 +120,12 @@ def run_conv( batch_size=batch_size, input_length=input_length, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py b/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py index 699caa49e54..63942ef0f8f 100644 --- a/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py +++ b/tests/ttnn/unit_tests/operations/test_conv_transpose2d.py @@ -104,19 +104,22 @@ def run_conv_transpose2d( conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - math_fidelity=math_fidelity, shard_layout=shard_layout, input_channels_alignment=( 16 if use_shallow_conv_variant or (input_channels == 16 and input_height == 115) else 32 ), deallocate_activation=deallocate_activation, - fp32_dest_acc_enabled=fp32_accum, - packer_l1_accum_enabled=packer_l1_acc, enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, output_layout=ttnn.ROW_MAJOR_LAYOUT, ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=math_fidelity, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, + ) if config_override and "act_block_h" in config_override: conv_config.act_block_h_override = config_override["act_block_h"] @@ -139,6 +142,7 @@ def run_conv_transpose2d( input_height=input_height, input_width=input_width, conv_config=conv_config, + compute_config=compute_config, groups=groups, ) logger.info(f"Conv2d Transpose Input = {(input_height, input_width)} Output = {out_height, out_width}") diff --git a/tests/ttnn/unit_tests/operations/test_creation.py b/tests/ttnn/unit_tests/operations/test_creation.py index f6f6773dc81..79f09ca122d 100644 --- a/tests/ttnn/unit_tests/operations/test_creation.py +++ b/tests/ttnn/unit_tests/operations/test_creation.py @@ -297,6 +297,39 @@ def test_arange(device, start, end, step): assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) +@pytest.mark.parametrize( + "start", + [4, 8, 16, 32], +) +@pytest.mark.parametrize( + "end", + [100, 200, 300], +) +@pytest.mark.parametrize( + "step", + [1, 2, 3, 4, 5], +) +def test_arange_multi_device(mesh_device, start, end, step): + torch_input_tensor = torch.rand((start, end, step), dtype=torch.bfloat16) + torch_output_tensor = torch.arange(start, end, step) + + output_tensor = ttnn.arange( + torch_input_tensor.shape[0], + torch_input_tensor.shape[1], + torch_input_tensor.shape[2], + ttnn.bfloat16, + mesh_device, + ) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensors = ttnn.to_torch(output_tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device)) + for output_tensor in output_tensors: + output_tensor = output_tensor[-1, -1, -1, :] + if divup((end - start), step) % 2 != 0: + output_tensor = output_tensor[:-1] + assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) + + @pytest.mark.parametrize( "input_shapes", [ diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 3e5f5f857f9..d41c5deae4f 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -137,19 +137,22 @@ def run_conv( conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - math_fidelity=math_fidelity, shard_layout=shard_layout, input_channels_alignment=( 16 if use_shallow_conv_variant or (input_channels == 16 and input_height == 115) else 32 ), deallocate_activation=deallocate_activation, - fp32_dest_acc_enabled=fp32_accum, - packer_l1_accum_enabled=packer_l1_acc, enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, output_layout=output_layout, ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=math_fidelity, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, + ) if config_override and "act_block_h" in config_override and not auto_shard: conv_config.act_block_h_override = config_override["act_block_h"] @@ -162,7 +165,7 @@ def run_conv( conv_config.override_sharding_config = True print("Setting num_cores_nhw to 98") - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -177,10 +180,13 @@ def run_conv( input_height=input_height, input_width=input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, memory_config=memory_config, + return_weights_and_bias=True, + return_output_dim=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -280,12 +286,15 @@ def run_conv_with_split( conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - math_fidelity=math_fidelity, shard_layout=shard_layout if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED, - fp32_dest_acc_enabled=fp32_accum, - packer_l1_accum_enabled=packer_l1_acc, # input_channels_alignment=(16 if use_shallow_conv_variant else 32), ) + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=math_fidelity, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, + ) if config_override and "act_block_h" in config_override: conv_config.act_block_h_override = config_override["act_block_h"] print("Setting Act Block H to ", conv_config.act_block_h_override) @@ -306,7 +315,7 @@ def run_conv_with_split( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) # tt_input_tensor_on_device = convs[i].copy_input_to_device(tt_input_tensor) # tt_output_tensor_on_device = convs[i](tt_input_tensor_on_device) - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=split_input_channels, @@ -320,7 +329,10 @@ def run_conv_with_split( input_height=input_height, input_width=input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=reader_patterns_cache, + return_output_dim=True, + return_weights_and_bias=True, ) tt_conv_output_tensor = ttnn.from_device(tt_output_tensor_on_device) torch_conv_output_tensor = ttnn.to_torch(tt_conv_output_tensor) @@ -625,12 +637,9 @@ def test_conv_ws( conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - math_fidelity=ttnn.MathFidelity.HiFi4, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED if not auto_shard else None, input_channels_alignment=32, deallocate_activation=deallocate_activation, - fp32_dest_acc_enabled=fp32_accum, - packer_l1_accum_enabled=packer_l1_acc, enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, @@ -638,7 +647,13 @@ def test_conv_ws( act_block_w_div=act_block_w_div if not auto_shard else 1, act_block_h_override=32, ) - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=fp32_accum, + packer_l1_acc=packer_l1_acc, + ) + [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -652,9 +667,12 @@ def test_conv_ws( input_height=input_height, input_width=input_width, conv_config=conv_config, + compute_config=compute_config, conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -2730,7 +2748,7 @@ def test_shallow_conv_with_tiled_input(device): tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels)) - tt_out, out_height, out_width, _, _ = ttnn.conv2d( + [tt_out, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( input_tensor=tt_input, weight_tensor=tt_kernel, in_channels=in_channels, @@ -2745,7 +2763,12 @@ def test_shallow_conv_with_tiled_input(device): input_height=img_h, input_width=img_w, groups=1, + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + ), memory_config=ttnn.DRAM_MEMORY_CONFIG, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_out) diff --git a/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py b/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py index 23f28658fc3..09cafdd0aca 100644 --- a/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py +++ b/tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py @@ -127,12 +127,11 @@ def test_prepare_conv_weights( dtype=ttnn.bfloat16, weights_dtype=ttnn.bfloat16, input_channels_alignment=(16 if input_channels == 16 and input_height == 115 else 32), - packer_l1_accum_enabled=packer_l1_acc, enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, ) - + compute_config = ttnn.init_device_compute_kernel_config(device.arch(), packer_l1_acc=packer_l1_acc) if config_override and "act_block_h" in config_override: conv_config.act_block_h_override = config_override["act_block_h"] @@ -179,11 +178,12 @@ def test_prepare_conv_weights( tt_weight_tensor_formatted = ttnn.to_device(tt_weight_tensor_formatted, device) tt_bias_tensor_formatted = ttnn.to_device(tt_bias_tensor_formatted, device) if has_bias else None (k := next(iter(conv_kwargs)), conv_kwargs.pop(k)) ##removing 1st element from dict - tt_output_tensor_on_device, _, _, _, _ = ttnn.conv2d( + tt_output_tensor_on_device = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor_formatted, bias_tensor=tt_bias_tensor_formatted, **conv_kwargs, + compute_config=compute_config, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py b/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py index 84ee4d5d972..638251f20da 100644 --- a/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py +++ b/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py @@ -103,7 +103,7 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac # logger.info("This module input shape - ", self.module_input_shape) # conv1 is 1x1 conv # print("Running conv1") - x, input_height, input_width, self.identity_conv_weight_tensor, _ = ttnn.conv2d( + x, [input_height, input_width], [self.identity_conv_weight_tensor, _] = ttnn.conv2d( input_tensor=x, weight_tensor=self.identity_conv_weight_tensor, in_channels=self.conv1_input_channels, @@ -118,12 +118,17 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], + ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -139,14 +144,19 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -162,16 +172,21 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], + ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) else: ds_out = x # print("Running conv2") - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -187,15 +202,20 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], - math_fidelity=self.model_config["MATH_FIDELITY"], activation="relu", ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=self.model_config["MATH_FIDELITY"], + ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -211,9 +231,14 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac conv_config=ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], + ), + compute_config=ttnn.init_device_compute_kernel_config( + device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) # underscore version is in_place = True diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py index 86047a86581..e4a8846e3fc 100644 --- a/tests/ttnn/unit_tests/operations/test_upsample.py +++ b/tests/ttnn/unit_tests/operations/test_upsample.py @@ -83,7 +83,7 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w): torch_result = torch_result.permute(0, 2, 3, 1) ## ttnn uses NHWC, so need to set scale_factor_c = 1 - scale_factor = (scale_h, scale_w, 1) + scale_factor = (scale_h, scale_w) input_tensor = ttnn.from_torch(input, device=device) output_tensor = ttnn.upsample(input_tensor, scale_factor) output_tensor = ttnn.to_torch(output_tensor) @@ -204,8 +204,7 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate print(f"in_shard_mem_config: {in_sharded_mem_config}") print(f"out_shard_mem_config: {out_sharded_mem_config}") - ## ttnn uses NHWC, so need to set scale_factor_c = 1 - scale_factor = (scale_h, scale_w, 1) + scale_factor = (scale_h, scale_w) input_tensor = ttnn.from_torch(tt_input, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config) output_tensor = ttnn.upsample(input_tensor, scale_factor, memory_config=out_sharded_mem_config) @@ -337,8 +336,7 @@ def test_bilinear_multi_core( logger.debug(f"in_shard_mem_config: {in_sharded_mem_config}") logger.debug(f"out_shard_mem_config: {out_sharded_mem_config}") - ## ttnn uses NHWC, so need to set scale_factor_c = 1 - scale_factor = (scale_h, scale_w, 1) + scale_factor = (scale_h, scale_w) input_tensor = ttnn.from_torch(tt_input, device=device) input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config) output_tensor = ttnn.upsample( diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py b/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py index a8418376d26..68df7937879 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py @@ -11,39 +11,76 @@ from models.utility_functions import is_grayskull +@pytest.mark.parametrize("shape", [(1, 10, 64, 96), (32, 1, 64, 64), (32, 3, 256, 256), (16, 1, 1024, 1024)]) @pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) @pytest.mark.parametrize("mem_layout", [ttnn.TensorMemoryLayout.INTERLEAVED]) @pytest.mark.parametrize("memory_location", [ttnn.BufferType.L1, ttnn.BufferType.DRAM]) @pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) -@pytest.mark.parametrize( - "enable_async, num_loops", - ((True, 5), (False, 5)), -) +@pytest.mark.parametrize("enable_async", (False, True)) def test_tensor_preallocation_and_write_apis( - num_loops, enable_async, in_dtype, mem_layout, memory_location, tensor_layout, device + enable_async, shape, in_dtype, mem_layout, memory_location, tensor_layout, device ): if in_dtype == ttnn.bfloat8_b and tensor_layout == ttnn.ROW_MAJOR_LAYOUT: pytest.skip("Row Major Layout not supported for Bfp8") torch.manual_seed(0) device.enable_async(enable_async) + + # Preallocate tensor on device + preallocated_tensor = ttnn.allocate_tensor_on_device( + ttnn.Shape(shape), + in_dtype, + tensor_layout, + device, + ttnn.MemoryConfig(memory_layout=mem_layout, buffer_type=memory_location), + ) + for loop in range(5): + # Write to prreallocated tensor multiple times + input_tensor_a = torch.randn(shape).bfloat16() + tt_input_tensor_a = ttnn.Tensor(input_tensor_a, in_dtype).to(tensor_layout) + ttnn.copy_host_to_device_tensor(tt_input_tensor_a, preallocated_tensor) + readback = preallocated_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + allclose, output = comp_pcc(readback, input_tensor_a) + assert allclose, f"FAILED: {output}" + + +@pytest.mark.parametrize("shape", [(1, 10, 64, 96), (32, 3, 256, 256)]) +@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) +@pytest.mark.parametrize("mem_layout", [ttnn.TensorMemoryLayout.INTERLEAVED]) +@pytest.mark.parametrize("memory_location", [ttnn.BufferType.L1, ttnn.BufferType.DRAM]) +@pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("enable_async", (False, True)) +@pytest.mark.parametrize("mesh_device", ((1, 1), 4), indirect=True) +def test_tensor_preallocation_and_write_apis( + enable_async, shape, in_dtype, mem_layout, memory_location, tensor_layout, mesh_device +): + if in_dtype == ttnn.bfloat8_b and tensor_layout == ttnn.ROW_MAJOR_LAYOUT: + pytest.skip("Row Major Layout not supported for Bfp8") + torch.manual_seed(0) + mesh_device.enable_async(enable_async) shapes = [(1, 10, 64, 96), (32, 1, 64, 64), (32, 3, 256, 256), (16, 1, 1024, 1024)] - for tensor_shape in shapes: - # Preallocate tensor on device - preallocated_tensor = ttnn.allocate_tensor_on_device( - ttnn.Shape(tensor_shape), - in_dtype, - tensor_layout, - device, - ttnn.MemoryConfig(memory_layout=mem_layout, buffer_type=memory_location), + # Preallocate tensor on device + preallocated_tensor = ttnn.allocate_tensor_on_device( + ttnn.Shape(shape), + in_dtype, + tensor_layout, + mesh_device, + ttnn.MemoryConfig(memory_layout=mem_layout, buffer_type=memory_location), + ) + for loop in range(5): + # Write to prreallocated tensor multiple times + input_tensor_a = torch.randn(shape).bfloat16() + tt_input_tensor_a = ttnn.from_torch( + input_tensor_a, + dtype=in_dtype, + layout=tensor_layout, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + ttnn.copy_host_to_device_tensor(tt_input_tensor_a, preallocated_tensor) + readback_tensors = ttnn.to_torch( + preallocated_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT), + mesh_composer=ttnn.ListMeshToTensor(mesh_device), ) - for loop in range(num_loops): - # Write to prreallocated tensor multiple times - input_tensor_a = torch.randn(tensor_shape).bfloat16() - tt_input_tensor_a = ttnn.Tensor(input_tensor_a, in_dtype).to(tensor_layout) - ttnn.copy_host_to_device_tensor(tt_input_tensor_a, preallocated_tensor) - readback = preallocated_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() - allclose, output = comp_pcc(readback, input_tensor_a) + for readback_tensor in readback_tensors: + allclose, output = comp_pcc(readback_tensor, input_tensor_a) assert allclose, f"FAILED: {output}" - - device.enable_async(False) diff --git a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py index 7d3f93797a7..f7bfb20401a 100644 --- a/tests/ttnn/unit_tests/test_sub_device.py +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -7,7 +7,7 @@ import ttnn -def run_sub_devices(device): +def run_sub_devices(device, replicate_sub_devices=False): tensix_cores0 = ttnn.CoreRangeSet( { ttnn.CoreRange( @@ -26,16 +26,26 @@ def run_sub_devices(device): ) sub_device_1 = ttnn.SubDevice([tensix_cores0]) sub_device_2 = ttnn.SubDevice([tensix_cores1]) - sub_device_manager1 = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) - sub_device_manager2 = device.create_sub_device_manager([sub_device_2], 3200) + sub_devices_1 = [sub_device_1, sub_device_2] + sub_devices_2 = [sub_device_2] + if replicate_sub_devices: + num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() + sub_devices_1 = [sub_devices_1] * num_devices + sub_devices_2 = [sub_devices_2] * num_devices + sub_device_manager1 = device.create_sub_device_manager(sub_devices_1, 3200) + sub_device_manager2 = device.create_sub_device_manager(sub_devices_2, 3200) device.load_sub_device_manager(sub_device_manager1) + ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(1)]) + ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0), ttnn.SubDeviceId(1)]) + ttnn.synchronize_devices(device) device.load_sub_device_manager(sub_device_manager2) + ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0)]) device.clear_loaded_sub_device_manager() device.remove_sub_device_manager(sub_device_manager1) device.remove_sub_device_manager(sub_device_manager2) -def run_sub_devices_program(device): +def run_sub_devices_program(device, replicate_sub_devices=False): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) @@ -48,22 +58,26 @@ def run_sub_devices_program(device): tensix_cores0 = ttnn.CoreRangeSet( { ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(3, 3), + ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(4, 4), ), } ) tensix_cores1 = ttnn.CoreRangeSet( { ttnn.CoreRange( - ttnn.CoreCoord(4, 4), - ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(3, 3), ), } ) sub_device_1 = ttnn.SubDevice([tensix_cores0]) sub_device_2 = ttnn.SubDevice([tensix_cores1]) - sub_device_manager = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) + sub_devices = [sub_device_1, sub_device_2] + if replicate_sub_devices: + num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() + sub_devices = [sub_devices] * num_devices + sub_device_manager = device.create_sub_device_manager(sub_devices, 3200) device.load_sub_device_manager(sub_device_manager) x = torch.randn(num_devices, 1, 64, 64, dtype=torch.bfloat16) @@ -74,8 +88,19 @@ def run_sub_devices_program(device): device=device, memory_config=ttnn.L1_MEMORY_CONFIG, mesh_mapper=inputs_mesh_mapper, + sub_device_ids=[ttnn.SubDeviceId(0)], ) + xt_host = ttnn.from_torch( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + sub_device_ids=[ttnn.SubDeviceId(1)], + ) + + ttnn.copy_host_to_device_tensor(xt_host, xt, sub_device_ids=[ttnn.SubDeviceId(1)]) + grid_size = device.compute_with_storage_grid_size() shard_size = [32, 64] shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED @@ -83,11 +108,28 @@ def run_sub_devices_program(device): yt = ttnn.interleaved_to_sharded( xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16 ) - y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer) + y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer, sub_device_ids=[ttnn.SubDeviceId(1)]) + + eq = torch.equal(x, y) + assert eq + + y = ttnn.to_torch(yt.cpu(sub_device_ids=[ttnn.SubDeviceId(0)]), mesh_composer=output_mesh_composer) eq = torch.equal(x, y) assert eq + event = ttnn.create_event(device) + + yt2 = ttnn.interleaved_to_sharded( + xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16 + ) + ttnn.record_event(0, event, [ttnn.SubDeviceId(1)]) + ttnn.wait_for_event(0, event) + y2 = ttnn.to_torch(yt2, device=device, mesh_composer=output_mesh_composer, sub_device_ids=[ttnn.SubDeviceId(0)]) + + eq = torch.equal(x, y2) + assert eq + device.clear_loaded_sub_device_manager() device.remove_sub_device_manager(sub_device_manager) @@ -98,8 +140,9 @@ def test_sub_devices(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) -def test_sub_devices_mesh(mesh_device, enable_async_mode): - run_sub_devices(mesh_device) +@pytest.mark.parametrize("replicate_sub_devices", (False, True)) +def test_sub_devices_mesh(mesh_device, replicate_sub_devices, enable_async_mode): + run_sub_devices(mesh_device, replicate_sub_devices) @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) @@ -108,5 +151,6 @@ def test_sub_device_program(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) -def test_sub_device_program_mesh(mesh_device, enable_async_mode): - run_sub_devices_program(mesh_device) +@pytest.mark.parametrize("replicate_sub_devices", (False, True)) +def test_sub_device_program_mesh(mesh_device, replicate_sub_devices, enable_async_mode): + run_sub_devices_program(mesh_device, replicate_sub_devices) diff --git a/tests/ttnn/unit_tests/test_unsqueeze.py b/tests/ttnn/unit_tests/test_unsqueeze.py index cbd10b3cc66..40bf576cdf7 100644 --- a/tests/ttnn/unit_tests/test_unsqueeze.py +++ b/tests/ttnn/unit_tests/test_unsqueeze.py @@ -10,20 +10,40 @@ @pytest.mark.parametrize( - "input_shape, dim", + "input_shape, dim, layout", [ - ((1, 1, 256), 2), - ((1, 1, 256), -2), - ((1, 256), 1), - ((1, 1, 30), 2), - ((1, 1, 30), -2), - ((1, 30), 1), + ((1, 1, 253), 2, ttnn.ROW_MAJOR_LAYOUT), + ((1, 1, 253), -2, ttnn.ROW_MAJOR_LAYOUT), + ((1, 253), 1, ttnn.ROW_MAJOR_LAYOUT), + ((1, 1, 253), -2, ttnn.TILE_LAYOUT), + ((1, 253), 1, ttnn.TILE_LAYOUT), + ((57, 83), 1, ttnn.TILE_LAYOUT), + ((123, 259), -2, ttnn.TILE_LAYOUT), + ((57, 83), 1, ttnn.ROW_MAJOR_LAYOUT), + ((123, 259), -2, ttnn.ROW_MAJOR_LAYOUT), + ((8732,), 1, ttnn.ROW_MAJOR_LAYOUT), + ((8732,), -1, ttnn.ROW_MAJOR_LAYOUT), + ((8732,), 0, ttnn.ROW_MAJOR_LAYOUT), ], ) -def test_unsqueeze(device, input_shape, dim): +def test_unsqueeze(device, input_shape, dim, layout): torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16) torch_unsqueeze_tensor = torch.unsqueeze(torch_input_tensor, dim) - input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device) ttnn_output = ttnn.unsqueeze(input_tensor, dim) torch_output_tensor = ttnn.to_torch(ttnn_output) assert torch.allclose(torch_output_tensor, torch_unsqueeze_tensor) + + +@pytest.mark.parametrize( + "input_shape, dim, layout", + [ + ((1, 1, 253), 4, ttnn.ROW_MAJOR_LAYOUT), + ((1, 1, 253), -5, ttnn.ROW_MAJOR_LAYOUT), + ], +) +def test_invalid_cases(device, input_shape, dim, layout): + torch_input_tensor = torch.rand(input_shape, dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device) + with pytest.raises(RuntimeError): + ttnn.unsqueeze(input_tensor, dim) diff --git a/tt-train/cmake/dependencies.cmake b/tt-train/cmake/dependencies.cmake index 8972da32891..2f6102a8c55 100644 --- a/tt-train/cmake/dependencies.cmake +++ b/tt-train/cmake/dependencies.cmake @@ -58,6 +58,14 @@ CPMAddPackage(NAME xtl GITHUB_REPOSITORY xtensor-stack/xtl GIT_TAG 0.7.7 OPTIONS CPMAddPackage(NAME xtensor GITHUB_REPOSITORY xtensor-stack/xtensor GIT_TAG 0.25.0 OPTIONS "XTENSOR_ENABLE_TESTS OFF") +CPMAddPackage( + NAME xtensor-blas + GITHUB_REPOSITORY xtensor-stack/xtensor-blas + GIT_TAG 0.21.0 + OPTIONS + "XTENSOR_ENABLE_TESTS OFF" +) + include(${PROJECT_SOURCE_DIR}/cmake/fetch_msgpack.cmake) include(${PROJECT_SOURCE_DIR}/cmake/fetch_cli11.cmake) diff --git a/tt-train/sources/ttml/CMakeLists.txt b/tt-train/sources/ttml/CMakeLists.txt index 9919e85f89c..0e241cd7bb6 100644 --- a/tt-train/sources/ttml/CMakeLists.txt +++ b/tt-train/sources/ttml/CMakeLists.txt @@ -96,6 +96,7 @@ target_link_libraries( magic_enum yaml-cpp::yaml-cpp xtensor + xtensor-blas xtl tokenizers_cpp wandbcpp diff --git a/tt-train/sources/ttml/autograd/auto_context.cpp b/tt-train/sources/ttml/autograd/auto_context.cpp index dbe16758b81..ea0e27e269b 100644 --- a/tt-train/sources/ttml/autograd/auto_context.cpp +++ b/tt-train/sources/ttml/autograd/auto_context.cpp @@ -22,8 +22,8 @@ uint32_t AutoContext::get_seed() const { } AutoContext& AutoContext::get_instance() { - static AutoContext instance; - return instance; + static core::Indestructible instance{}; + return instance.get(); } std::optional AutoContext::add_backward_node(GradFunction&& grad_function, std::span links) { if (m_grads_mode == GradMode::DISABLED) { @@ -42,10 +42,36 @@ void AutoContext::reset_graph() { m_graph.reset(); } +void AutoContext::open_device() { + if (m_device) { + throw std::runtime_error("open_device was called after the device was created."); + } + m_device = std::make_unique(m_mesh_shape); +} + +void AutoContext::close_device() { + m_device = nullptr; +} + ttnn::distributed::MeshDevice& AutoContext::get_device() { - return device.get_device(); + if (!m_device) { + open_device(); + } + + return m_device->get_device(); } AutoContext::AutoContext() : m_generator(m_seed) { } + +void AutoContext::set_mesh_shape(tt::tt_metal::distributed::MeshShape shape) { + if (m_device) { + throw std::runtime_error("set_mesh_shape was called after the device was created."); + } + m_mesh_shape = shape; +} + +tt::tt_metal::distributed::MeshShape AutoContext::get_mesh_shape() const { + return m_mesh_shape; +} } // namespace ttml::autograd diff --git a/tt-train/sources/ttml/autograd/auto_context.hpp b/tt-train/sources/ttml/autograd/auto_context.hpp index 92002025cd7..a4124862ed3 100644 --- a/tt-train/sources/ttml/autograd/auto_context.hpp +++ b/tt-train/sources/ttml/autograd/auto_context.hpp @@ -4,8 +4,10 @@ #pragma once +#include #include +#include "core/indestructible.hpp" #include "core/mesh_device.hpp" #include "graph.hpp" @@ -40,6 +42,14 @@ class AutoContext { ~AutoContext() = default; // to make it work with unique_ptr. ttnn::distributed::MeshDevice& get_device(); + + void set_mesh_shape(tt::tt_metal::distributed::MeshShape shape); + [[nodiscard]] tt::tt_metal::distributed::MeshShape get_mesh_shape() const; + + void open_device(); + + void close_device(); + private: AutoContext(); uint32_t m_seed = 5489U; @@ -48,8 +58,10 @@ class AutoContext { GradMode m_grads_mode = GradMode::ENABLED; Graph m_graph; + tt::tt_metal::distributed::MeshShape m_mesh_shape = {1, 1}; + std::unique_ptr m_device; - core::MeshDevice device{0}; + friend class core::Indestructible; }; inline auto& ctx() { diff --git a/tt-train/sources/ttml/core/distributed_mapping.hpp b/tt-train/sources/ttml/core/distributed_mapping.hpp new file mode 100644 index 00000000000..d40644486da --- /dev/null +++ b/tt-train/sources/ttml/core/distributed_mapping.hpp @@ -0,0 +1,283 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include +#include +#include + +#include "core/xtensor_utils.hpp" + +namespace ttml::core { +template +std::vector> chunk(const xt::xarray& tensor, int num_chunks, int dim) { + if (num_chunks <= 0) { + throw std::invalid_argument("num_chunks must be > 0"); + } + if (dim < 0 || static_cast(dim) >= tensor.dimension()) { + throw std::invalid_argument("invalid dimension index"); + } + + int size_along_dim = static_cast(tensor.shape()[dim]); + if (num_chunks > size_along_dim) { + throw std::invalid_argument("num_chunks cannot exceed the size of the tensor along the given dimension."); + } + + if (num_chunks == 1) { + return {tensor}; + } + + int chunk_size = (size_along_dim + num_chunks - 1) / num_chunks; + int remaining_size = size_along_dim; + + std::vector> chunks; + chunks.reserve(static_cast(num_chunks)); + + int start = 0; + int end = 0; + for (int i = 0; i < num_chunks && end < size_along_dim; ++i) { + int current_chunk_size = std::min(chunk_size, remaining_size); + remaining_size -= current_chunk_size; + end = start + current_chunk_size; + + // Build indices for slicing + xt::xstrided_slice_vector indices(tensor.dimension(), xt::all()); + indices[dim] = xt::range(start, end); + + auto chunk_view = xt::strided_view(tensor, indices); + + // Construct xarray from the view + // This forces a copy of that slice into a new xarray + chunks.push_back(xt::xarray(chunk_view)); + start = end; + } + + return chunks; +} + +template +class XTensorToMesh { +public: + XTensorToMesh(tt::tt_metal::distributed::MeshShape mesh_shape) : m_mesh_shape(std::move(mesh_shape)) { + } + + std::vector> map(const xt::xarray& tensor) const { + return static_cast(this)->map_impl(tensor); + } + + std::unordered_map config() const { + return static_cast(this)->config_impl(); + } + +protected: + tt::tt_metal::distributed::MeshShape m_mesh_shape; + + size_t get_num_devices() const { + return m_mesh_shape.num_rows * m_mesh_shape.num_cols; + } +}; + +template +class MeshToXTensor { +public: + MeshToXTensor(tt::tt_metal::distributed::MeshShape mesh_shape) : m_mesh_shape(std::move(mesh_shape)) { + } + + std::vector> compose(const std::vector>& tensors) const { + return static_cast(this)->compose_impl(tensors); + } + +protected: + tt::tt_metal::distributed::MeshShape m_mesh_shape; +}; + +template +class ShardXTensorToMesh : public XTensorToMesh, T> { +public: + using Base = XTensorToMesh, T>; + ShardXTensorToMesh(tt::tt_metal::distributed::MeshShape mesh_shape, int dim) : + Base(std::move(mesh_shape)), m_shard_dim(dim) { + } + + std::vector> map_impl(const xt::xarray& tensor) const { + int num_devices = Base::get_num_devices(); + auto sliced_tensors = chunk(tensor, num_devices, m_shard_dim); + return sliced_tensors; + } + + std::unordered_map config_impl() const { + return {{"strategy", "shard"}, {"shard_dim", std::to_string(m_shard_dim)}}; + } + +private: + int m_shard_dim = 0; +}; + +template +class ShardTensor2dMesh : public XTensorToMesh, T> { +public: + using Base = XTensorToMesh, T>; + ShardTensor2dMesh( + tt::tt_metal::distributed::MeshShape mesh_shape, + const std::pair, std::optional>& dims) : + Base(std::move(mesh_shape)), m_dims(dims) { + // We trust the provided mesh shape and do not validate against a MeshDevice. + } + + std::vector> map_impl(const xt::xarray& tensor) const { + if (!m_dims.first.has_value() && !m_dims.second.has_value()) { + throw std::invalid_argument("ShardTensor2dMesh requires at least one dimension to shard"); + } + + int rows = Base::m_mesh_shape.num_rows; + int cols = Base::m_mesh_shape.num_cols; + auto row_dim = m_dims.first; + auto col_dim = m_dims.second; + + std::vector> row_tensors; + + // Shard along rows + if (!row_dim.has_value()) { + row_tensors.reserve(rows); + for (int i = 0; i < rows; ++i) { + row_tensors.push_back(tensor); + } + } else { + row_tensors = chunk(tensor, rows, row_dim.value()); + } + + std::vector> tensor_shards; + tensor_shards.reserve(static_cast(rows * cols)); + // Shard along columns + if (!col_dim.has_value()) { + for (const auto& t : row_tensors) { + for (int i = 0; i < cols; ++i) { + tensor_shards.push_back(t); + } + } + } else { + for (const auto& t : row_tensors) { + auto col_chunks = chunk(t, cols, col_dim.value()); + tensor_shards.insert(tensor_shards.end(), col_chunks.begin(), col_chunks.end()); + } + } + + if (static_cast(tensor_shards.size()) != rows * cols) { + throw std::runtime_error(fmt::format( + "ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "dimensions. Size: {}, rows: {}, cols: {}", + tensor_shards.size(), + rows, + cols)); + } + + return tensor_shards; + } + + std::unordered_map config_impl() const { + return { + {"strategy", "shard_2d"}, + {"mesh_shape_y", std::to_string(Base::m_mesh_shape.num_rows)}, + {"mesh_shape_x", std::to_string(Base::m_mesh_shape.num_cols)}}; + } + +private: + std::pair, std::optional> m_dims; +}; + +template +class ConcatMesh2dToTensor : public MeshToXTensor, T> { +public: + using Base = MeshToXTensor, T>; + ConcatMesh2dToTensor( + tt::tt_metal::distributed::MeshShape mesh_shape, const tt::tt_metal::distributed::MeshShape& dims) : + Base(std::move(mesh_shape)), m_dims(dims) { + if (m_dims.num_rows == m_dims.num_cols) { + throw std::invalid_argument("Dimensions in 'dims' must be different"); + } + } + + std::vector> compose_impl(const std::vector>& tensors) const { + int rows = Base::m_mesh_shape.num_rows; + int cols = Base::m_mesh_shape.num_cols; + size_t row_dim = m_dims.num_rows; + size_t col_dim = m_dims.num_cols; + + std::vector> row_concatenated; + row_concatenated.reserve(static_cast(rows)); + + for (int i = 0; i < rows; ++i) { + auto row_start = tensors.begin() + i * cols; + auto row_end = row_start + cols; + std::vector> row_tensors(row_start, row_end); + + auto concatenated_row = core::concatenate(row_tensors, col_dim); + row_concatenated.push_back(std::move(concatenated_row)); + } + + auto result = core::concatenate(row_concatenated, row_dim); + return {result}; + } + +private: + tt::tt_metal::distributed::MeshShape m_dims; +}; + +template +class ReplicateXTensorToMesh : public XTensorToMesh, T> { +public: + using Base = XTensorToMesh, T>; + ReplicateXTensorToMesh(tt::tt_metal::distributed::MeshShape mesh_shape) : Base(std::move(mesh_shape)) { + } + + std::vector> map_impl(const xt::xarray& tensor) const { + int num_devices = Base::get_num_devices(); + std::vector> tensors; + tensors.reserve(static_cast(num_devices)); + for (int i = 0; i < num_devices; ++i) { + tensors.push_back(tensor); // Note: this copies the tensor + } + return tensors; + } + + std::unordered_map config_impl() const { + int num_devices = Base::get_num_devices(); + return {{"strategy", "replicate"}, {"replication_factor", std::to_string(num_devices)}}; + } +}; + +template +class ConcatMeshToXTensor : public MeshToXTensor, T> { +public: + using Base = MeshToXTensor, T>; + ConcatMeshToXTensor(tt::tt_metal::distributed::MeshShape mesh_shape, int dim) : + Base(std::move(mesh_shape)), m_concat_dim(dim) { + } + + std::vector> compose_impl(const std::vector>& tensors) const { + return {core::concatenate(tensors, m_concat_dim)}; + } + +private: + int m_concat_dim = 0; +}; + +template +class VectorMeshToXTensor : public MeshToXTensor, T> { +public: + using Base = MeshToXTensor, T>; + VectorMeshToXTensor([[maybe_unused]] tt::tt_metal::distributed::MeshShape mesh_shape) : Base(mesh_shape) { + } + std::vector> compose_impl(const std::vector>& tensors) const { + return tensors; + } +}; + +template +using XTensorToMeshVariant = std::variant, ShardTensor2dMesh, ReplicateXTensorToMesh>; + +template +using MeshToXTensorVariant = std::variant, ConcatMesh2dToTensor, VectorMeshToXTensor>; + +} // namespace ttml::core diff --git a/tt-train/sources/ttml/core/indestructible.hpp b/tt-train/sources/ttml/core/indestructible.hpp new file mode 100644 index 00000000000..eb30d101bd2 --- /dev/null +++ b/tt-train/sources/ttml/core/indestructible.hpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +namespace ttml::core { + +template +class Indestructible { +public: + template + explicit Indestructible(Args&&... args) { + // Construct T in our aligned storage + new (&storage) T(std::forward(args)...); + } + + T& get() { + return *reinterpret_cast(&storage); + } + + const T& get() const { + return *reinterpret_cast(&storage); + } + + // Disable copy and assignment + Indestructible(const Indestructible&) = delete; + Indestructible& operator=(const Indestructible&) = delete; + + // Destructor does NOT call T's destructor. + // This leaves the object "indestructible." + ~Indestructible() = default; + +private: + // A buffer of unsigned char with alignment of T and size of T + alignas(T) unsigned char storage[sizeof(T)]; +}; + +} // namespace ttml::core diff --git a/tt-train/sources/ttml/core/mesh_device.cpp b/tt-train/sources/ttml/core/mesh_device.cpp index f30bbf9b884..33f3d026556 100644 --- a/tt-train/sources/ttml/core/mesh_device.cpp +++ b/tt-train/sources/ttml/core/mesh_device.cpp @@ -4,13 +4,11 @@ #include "mesh_device.hpp" -#include - namespace ttml::core { -MeshDevice::MeshDevice([[maybe_unused]] int device_index) : +MeshDevice::MeshDevice(tt::tt_metal::distributed::MeshShape shape) : m_mesh_device(ttnn::distributed::api::open_mesh_device( - ttnn::distributed::MeshShape(1, 1), + shape, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, /* num_command_queues*/ 1, @@ -25,6 +23,7 @@ MeshDevice::MeshDevice([[maybe_unused]] int device_index) : } MeshDevice::~MeshDevice() { + assert(m_mesh_device); ttnn::distributed::api::close_mesh_device(m_mesh_device); } diff --git a/tt-train/sources/ttml/core/mesh_device.hpp b/tt-train/sources/ttml/core/mesh_device.hpp index 1d38bbfe3bb..490f9d5b45c 100644 --- a/tt-train/sources/ttml/core/mesh_device.hpp +++ b/tt-train/sources/ttml/core/mesh_device.hpp @@ -4,14 +4,14 @@ #pragma once -#include #include +#include namespace ttml::core { // should I implement pimpl or its fine class MeshDevice { public: - explicit MeshDevice(int device_index); + explicit MeshDevice(tt::tt_metal::distributed::MeshShape shape); MeshDevice(MeshDevice&& device) = default; MeshDevice(const MeshDevice&) = delete; diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp index d9f20c55ff1..706c8d98dfc 100644 --- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp +++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp @@ -8,13 +8,14 @@ #include #include -#include #include #include #include #include #include +#include "core/xtensor_utils.hpp" + namespace { template @@ -180,6 +181,55 @@ tt::tt_metal::Tensor ones(const ttnn::Shape& shape, ttnn::distributed::MeshDevic return core::full(shape, 1.F, device, dtype); } +template +[[nodiscard]] tt::tt_metal::Tensor from_xtensors_to_host( + const std::vector>& buffers, const std::unordered_map& config) { + std::vector host_owned_buffers; + std::vector host_owned_shapes; + host_owned_buffers.reserve(buffers.size()); + host_owned_shapes.reserve(buffers.size()); + if (buffers.empty()) { + throw std::runtime_error("Cannot create a host buffer from an empty vector of xtensors!"); + } + auto first_shape = buffers.front().shape(); + for (int i = 0; i < buffers.size(); ++i) { + if (buffers[i].shape() != first_shape) { + throw std::runtime_error(fmt::format( + "Cannot create a host buffer from xtensors with different shapes: {} vs {}!", + get_shape_4d(buffers[0]), + get_shape_4d(buffers[i]))); + } + } + for (const auto& buffer : buffers) { + auto shape = create_shape(get_shape_4d(buffer)); + + if constexpr (std::is_same_v) { + auto owned_buffer = + create_owned_buffer_from_vector_of_floats(std::vector(buffer.begin(), buffer.end()), TensorType); + host_owned_buffers.push_back(owned_buffer); + } else { + auto owned_buffer = tt::tt_metal::owned_buffer::create(std::vector(buffer.begin(), buffer.end())); + host_owned_buffers.push_back(owned_buffer); + } + + host_owned_shapes.push_back(shape); + } + auto distributed_tensor_config = get_distributed_tensor_config(config); + auto storage = tt::tt_metal::MultiDeviceHostStorage( + distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes); + + // remove possible paddings from the shape (it conflicts with ROW MAJOR) + auto output = Tensor(std::move(storage), host_owned_shapes[0], TensorType, Layout::ROW_MAJOR); + return output; +} + +template tt::tt_metal::Tensor from_xtensors_to_host( + const std::vector>& buffers, const std::unordered_map& config); +template tt::tt_metal::Tensor from_xtensors_to_host( + const std::vector>& buffers, const std::unordered_map& config); +template tt::tt_metal::Tensor from_xtensors_to_host( + const std::vector>& buffers, const std::unordered_map& config); + template <> tt::tt_metal::Tensor from_vector( const std::vector& buffer, const ttnn::Shape& shape, ttnn::distributed::MeshDevice* device, Layout layout) { @@ -195,17 +245,10 @@ tt::tt_metal::Tensor from_vector( auto owned_buffer = create_owned_buffer_from_vector_of_floats(buffer, data_type); // remove possible paddings from the shape (it conflicts with ROW MAJOR) auto output = tt::tt_metal::Tensor(OwnedStorage{owned_buffer}, logical_shape, data_type, Layout::ROW_MAJOR); - - auto to_device_even_fast = [&]() { - output = ttnn::to_device(output, device, output_mem_config); - if (layout == Layout::TILE) { - output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ true); - } - - return output; - }; - - output = to_device_even_fast(); + output = ttnn::to_device(output, device, output_mem_config); + if (layout == Layout::TILE) { + output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ true); + } return output; } diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.hpp b/tt-train/sources/ttml/core/tt_tensor_utils.hpp index 5d809935ea9..6775bde4e6c 100644 --- a/tt-train/sources/ttml/core/tt_tensor_utils.hpp +++ b/tt-train/sources/ttml/core/tt_tensor_utils.hpp @@ -5,9 +5,11 @@ #pragma once #include -#include +#include #include +#include "core/distributed_mapping.hpp" + namespace ttml::core { void print_tensor_stats(const tt::tt_metal::Tensor& tensor, const std::string& name); @@ -31,6 +33,10 @@ template ttnn::distributed::MeshDevice* device, Layout layout = Layout::TILE); +template +[[nodiscard]] tt::tt_metal::Tensor from_xtensors_to_host( + const std::vector>& buffers, const std::unordered_map& config); + template [[nodiscard]] std::vector to_vector(const tt::tt_metal::Tensor& tensor); @@ -38,4 +44,49 @@ template [[nodiscard]] ttnn::Shape create_shape(const std::array& args); +template +[[nodiscard]] tt::tt_metal::Tensor from_xtensor( + const xt::xarray& buffer, ttnn::distributed::MeshDevice* device, Layout layout = Layout::TILE) { + auto shape = create_shape(get_shape_4d(buffer)); + auto buffer_view = xtensor_to_span(buffer); + return from_vector(std::vector(buffer_view.begin(), buffer_view.end()), shape, device, layout); +} + +template +[[nodiscard]] xt::xarray to_xtensor(const tt::tt_metal::Tensor& tensor) { + auto vec = to_vector(tensor); + auto shape = tensor.get_shape().logical_shape(); + return span_to_xtensor(std::span(vec.data(), vec.size()), shape); +} + +template +auto to_xtensor(const tt::tt_metal::Tensor& tensor, const MeshToXTensorVariant& composer) { + auto cpu_tensor = tensor.cpu(); + cpu_tensor = cpu_tensor.to(Layout::ROW_MAJOR); + auto cpu_tensors = ttnn::distributed::api::get_device_tensors(cpu_tensor); + std::vector> res; + res.reserve(cpu_tensors.size()); + for (const auto& shard : cpu_tensors) { + res.push_back(to_xtensor(shard)); + } + return std::visit([&res](auto&& arg) { return arg.compose(res); }, composer); +} + +template +tt::tt_metal::Tensor from_xtensor( + const xt::xarray& tensor, + ttnn::distributed::MeshDevice* device, + const XTensorToMeshVariant& composer, + Layout layout = Layout::TILE) { + auto sharded_tensors = std::visit([&tensor](auto&& arg) { return arg.map(tensor); }, composer); + auto config = std::visit([](auto&& arg) { return arg.config(); }, composer); + auto output = from_xtensors_to_host(sharded_tensors, config); + MemoryConfig output_mem_config{}; + output = ttnn::to_device(output, device, output_mem_config); + if (layout == Layout::TILE) { + output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ true); + } + return output; +} + } // namespace ttml::core diff --git a/tt-train/sources/ttml/core/ttnn_all_includes.hpp b/tt-train/sources/ttml/core/ttnn_all_includes.hpp index c01c7b804c2..d41cf6eea2f 100644 --- a/tt-train/sources/ttml/core/ttnn_all_includes.hpp +++ b/tt-train/sources/ttml/core/ttnn_all_includes.hpp @@ -9,7 +9,8 @@ #pragma GCC diagnostic ignored "-Wdeprecated-volatile" #pragma GCC diagnostic ignored "-Wdeprecated-this-capture" -#include // NOLINT +#include // NOLINT +#include #include // NOLINT #include // NOLINT #include // NOLINT @@ -54,8 +55,10 @@ #include // NOLINT #include // NOLINT #include // NOLINT +#include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT + #pragma GCC diagnostic pop diff --git a/tt-train/sources/ttml/core/xtensor_all_includes.hpp b/tt-train/sources/ttml/core/xtensor_all_includes.hpp new file mode 100644 index 00000000000..12bdd2addb8 --- /dev/null +++ b/tt-train/sources/ttml/core/xtensor_all_includes.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/tt-train/sources/ttml/core/xtensor_utils.cpp b/tt-train/sources/ttml/core/xtensor_utils.cpp new file mode 100644 index 00000000000..96c0d0a7c1f --- /dev/null +++ b/tt-train/sources/ttml/core/xtensor_utils.cpp @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "xtensor_utils.hpp" + +namespace ttml::core { +namespace detail { +template +auto vector_to_tuple_helper(const std::vector& v, std::index_sequence) { + return std::make_tuple(v[Indices]...); +} + +template +auto vector_to_tuple(const std::vector& buffer) { + assert(buffer.size() >= N); + return vector_to_tuple_helper(buffer, std::make_index_sequence()); +} + +template +xt::xarray concat_helper(const std::vector>& v, size_t axis = 0) { + constexpr int FIXED_N = N < 2 ? 2 : N; + if (N < 2) { + throw std::runtime_error("Tuple size in concatenate must be greater than 1"); + } + auto tuple = detail::vector_to_tuple(v); + return xt::concatenate(std::move(tuple), axis); +} + +template +consteval auto create_array_impl(std::index_sequence) { + return std::array (*)(const std::vector>& v, size_t axis), sizeof...(I)>{ + concat_helper...}; +} + +template +consteval auto create_array() { + return create_array_impl(std::make_index_sequence()); +} + +} // namespace detail + +template +xt::xarray concatenate(const std::vector>& v, size_t axis) { + constexpr size_t MAX_TUPLE_SIZE = 64; + + if (v.empty()) { + return {}; + } + if (v.size() == 1) { + return v.front(); + } + if (v.size() > MAX_TUPLE_SIZE) { + throw std::runtime_error( + fmt::format("Number of tensors to concatenate exceeds the maximum supported size {}", MAX_TUPLE_SIZE)); + } + constexpr auto table = detail::create_array(); + return (*table[v.size()])(v, axis); +} + +template xt::xarray concatenate(const std::vector>& v, size_t axis); +template xt::xarray concatenate(const std::vector>& v, size_t axis); +template xt::xarray concatenate(const std::vector>& v, size_t axis); +template xt::xarray concatenate(const std::vector>& v, size_t axis); +} // namespace ttml::core diff --git a/tt-train/sources/ttml/core/xtensor_utils.hpp b/tt-train/sources/ttml/core/xtensor_utils.hpp new file mode 100644 index 00000000000..153323f3e32 --- /dev/null +++ b/tt-train/sources/ttml/core/xtensor_utils.hpp @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +// TODO: decide if we want to use xarray everwhere or xtensor is ok +/* +Difference between xtensor and xarray: + +xarray : tensor that can be reshaped to any number of dimensions. xtensor : tensor with a number of dimensions +set to N at compile time. xtensor_fixed : tensor whose shape is fixed at compile time. +*/ + +namespace ttml::core { +template +xt::xarray span_to_xtensor(std::span vec, const ttnn::SimpleShape& shape) { + std::vector shape_vec(shape.cbegin(), shape.cend()); + return xt::adapt(vec.data(), vec.size(), xt::no_ownership(), shape_vec); +} +template +auto xtensor_to_span(const xt::xarray& xtensor) { + auto adaptor = xt::adapt(xtensor.data(), xtensor.size(), xt::no_ownership()); + return std::span(adaptor.data(), adaptor.size()); +} + +// TODO: decide if we want to keep this function with E or use the xtensor type directly +template +std::array get_shape_4d(const E& expr) { + const int max_dims = 4; + // TODO: Ensure that E is an xtensor expression + + // Retrieve the shape of the tensor + auto& expr_shape = expr.shape(); + std::array shape4d = {1, 1, 1, 1}; + + size_t dims = expr_shape.size(); + + if (dims > max_dims) { + throw std::runtime_error(fmt::format("Number of dimensions {} greater than max_shape {}", dims, max_dims)); + } + + // Copy the dimensions into the shape array + for (size_t i = 0; i < dims; ++i) { + shape4d[i + max_dims - dims] = static_cast(expr_shape[i]); + } + + return shape4d; +} + +template +xt::xarray concatenate(const std::vector>& v, size_t axis = 0); + +} // namespace ttml::core diff --git a/tt-train/tests/3rd_party/xtensor_test.cpp b/tt-train/tests/3rd_party/xtensor_test.cpp index ddd5c3b63fd..6a5b6317c17 100644 --- a/tt-train/tests/3rd_party/xtensor_test.cpp +++ b/tt-train/tests/3rd_party/xtensor_test.cpp @@ -4,9 +4,9 @@ #include -#include -#include -#include +#include + +#include "core/xtensor_utils.hpp" TEST(XTensorTest, BasicOperations) { // Create an xtensor array @@ -27,3 +27,71 @@ TEST(XTensorTest, BasicOperations) { // Verify the result EXPECT_TRUE(xt::allclose(arr2, expected)); } + +TEST(XTensorTest, SpanToXtensor) { + std::vector data = {1, 2, 3, 4, 5, 6}; + std::span data_span(data.data(), data.size()); + ttnn::SimpleShape shape({2, 3}); + + auto result = ttml::core::span_to_xtensor(data_span, shape); + + // Check shape + EXPECT_EQ(result.shape().size(), 2); + EXPECT_EQ(result.shape()[0], 2); + EXPECT_EQ(result.shape()[1], 3); + + // Check data + int expected_val = 1; + for (size_t i = 0; i < result.shape()[0]; ++i) { + for (size_t j = 0; j < result.shape()[1]; ++j) { + EXPECT_EQ(result(i, j), expected_val++); + } + } +} + +// Test xtensor_to_span +TEST(XTensorTest, XtensorToSpan) { + xt::xarray arr = {{1.0f, 2.0f}, {3.0f, 4.0f}}; + auto span_result = ttml::core::xtensor_to_span(arr); + + EXPECT_EQ(span_result.size(), arr.size()); + + // Check data + size_t index = 0; + for (float val : arr) { + EXPECT_FLOAT_EQ(span_result[index++], val); + } +} + +// Test get_shape_4d +TEST(XTensorTest, GetShape4D) { + // Test a 4D shape + xt::xarray arr_4d = xt::xarray::from_shape({2, 3, 4, 5}); + auto shape4d = ttml::core::get_shape_4d(arr_4d); + EXPECT_EQ(shape4d[0], 2); + EXPECT_EQ(shape4d[1], 3); + EXPECT_EQ(shape4d[2], 4); + EXPECT_EQ(shape4d[3], 5); + + // Test a 2D shape, should zero-pad to the left (or right) as per logic + xt::xarray arr_2d = xt::xarray::from_shape({10, 20}); + auto shape2d = ttml::core::get_shape_4d(arr_2d); + // dims=2, so shape4d = {1, 1, 10, 20} + EXPECT_EQ(shape2d[0], 1); + EXPECT_EQ(shape2d[1], 1); + EXPECT_EQ(shape2d[2], 10); + EXPECT_EQ(shape2d[3], 20); + + // Test a 1D shape + xt::xarray arr_1d = xt::xarray::from_shape({7}); + auto shape1d = ttml::core::get_shape_4d(arr_1d); + // dims=1, so shape4d = {1, 1, 1, 7} + EXPECT_EQ(shape1d[0], 1); + EXPECT_EQ(shape1d[1], 1); + EXPECT_EQ(shape1d[2], 1); + EXPECT_EQ(shape1d[3], 7); + + // Test throwing an exception for >4D + xt::xarray arr_5d = xt::xarray::from_shape({2, 2, 2, 2, 2}); + EXPECT_THROW(ttml::core::get_shape_4d(arr_5d), std::runtime_error); +} diff --git a/tt-train/tests/core/distributed_test.cpp b/tt-train/tests/core/distributed_test.cpp new file mode 100644 index 00000000000..0f304788ca3 --- /dev/null +++ b/tt-train/tests/core/distributed_test.cpp @@ -0,0 +1,245 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +#include "core/distributed_mapping.hpp" + +template +class MeshOpsTest : public ::testing::Test { +protected: + // Common setup could go here if needed +}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(MeshOpsTest, TestTypes); + +TYPED_TEST(MeshOpsTest, ChunkBasicNonDivisible3) { + // Create a 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + // Using TypeParam ensures we test both uint32_t and float. + xt::xarray tensor = xt::arange(10); + + // Chunk into 3 parts along dimension 0 + auto chunks = ttml::core::chunk(tensor, 3, 0); + + ASSERT_EQ(chunks.size(), 3u); + EXPECT_EQ(chunks[0].shape()[0], 4u); // first chunk size 4 + EXPECT_EQ(chunks[1].shape()[0], 4u); // next chunk size 4 + EXPECT_EQ(chunks[2].shape()[0], 2u); // last chunk size 2 +} + +TYPED_TEST(MeshOpsTest, ChunkBasicLessChunksThanProvided) { + // Create a 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12] + xt::xarray tensor = xt::arange(13); + + // Chunk into 6 parts along dimension 0 + auto chunks = ttml::core::chunk(tensor, 6, 0); + + ASSERT_EQ(chunks.size(), 5u); + EXPECT_EQ(chunks[0].shape()[0], 3u); // first chunk size 3 + EXPECT_EQ(chunks[1].shape()[0], 3u); // next chunk size 3 + EXPECT_EQ(chunks[2].shape()[0], 3u); // next chunk size 3 + EXPECT_EQ(chunks[3].shape()[0], 3u); // next chunk size 3 + EXPECT_EQ(chunks[4].shape()[0], 1u); // last chunk size 1 +} + +TYPED_TEST(MeshOpsTest, ShardXTensorToMeshBasicShard) { + tt::tt_metal::distributed::MeshShape mesh_shape = {1, 4}; + + // A simple 1D tensor to shard across 4 devices + auto tensor = xt::arange(8); // [0,...,7] + + ttml::core::ShardXTensorToMesh sharder(mesh_shape, 0); + auto shards = sharder.map(tensor); + + // With 4 shards, each shard should have size 2 + ASSERT_EQ(shards.size(), 4u); + for (auto& s : shards) { + EXPECT_EQ(s.size(), 2u); + } +} + +TYPED_TEST(MeshOpsTest, ShardTensor2dMeshTwoDimSharding) { + // Mesh shape: 2x2, total 4 devices + tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2}; + + // Create a 2D tensor shape: (4,4) + auto tensor = xt::arange(16).reshape({4, 4}); + + // Shard along row_dim=0 and col_dim=1 + ttml::core::ShardTensor2dMesh sharder(mesh_shape, {0, 1}); + auto shards = sharder.map(tensor); + + ASSERT_EQ(shards.size(), 4u); + // Check shapes of shards + for (auto& shard : shards) { + EXPECT_EQ(shard.shape()[0], 2u); + EXPECT_EQ(shard.shape()[1], 2u); + } +} + +TYPED_TEST(MeshOpsTest, ReplicateXTensorToMeshReplication) { + tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2}; + int num_devices = mesh_shape.num_rows * mesh_shape.num_cols; // 4 + + auto tensor = xt::arange(4); // [0,1,2,3] + + ttml::core::ReplicateXTensorToMesh replicator(mesh_shape); + auto replicas = replicator.map(tensor); + + ASSERT_EQ(static_cast(replicas.size()), num_devices); + for (const auto& t : replicas) { + EXPECT_TRUE(xt::allclose(t, tensor)); + } +} + +TYPED_TEST(MeshOpsTest, ConcatMesh2dToTensorRecomposition) { + tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2}; + + // Create shards that would come from a 4x4 tensor: + // Expected final tensor: + // [[0,1,2,3], + // [4,5,6,7], + // [8,9,10,11], + // [12,13,14,15]] + // + // Shards (2x2 each): + xt::xarray top_left = {{TypeParam(0), TypeParam(1)}, {TypeParam(4), TypeParam(5)}}; + xt::xarray top_right = {{TypeParam(2), TypeParam(3)}, {TypeParam(6), TypeParam(7)}}; + xt::xarray bot_left = {{TypeParam(8), TypeParam(9)}, {TypeParam(12), TypeParam(13)}}; + xt::xarray bot_right = {{TypeParam(10), TypeParam(11)}, {TypeParam(14), TypeParam(15)}}; + + std::vector> shards = {top_left, top_right, bot_left, bot_right}; + + ttml::core::ConcatMesh2dToTensor composer(mesh_shape, {0, 1}); + auto composed = composer.compose(shards); + + xt::xarray expected = { + {TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3)}, + {TypeParam(4), TypeParam(5), TypeParam(6), TypeParam(7)}, + {TypeParam(8), TypeParam(9), TypeParam(10), TypeParam(11)}, + {TypeParam(12), TypeParam(13), TypeParam(14), TypeParam(15)}}; + + EXPECT_TRUE(xt::allclose(composed[0], expected)); +} + +TYPED_TEST(MeshOpsTest, ConcatMeshToXTensorOneDimConcatenation) { + tt::tt_metal::distributed::MeshShape mesh_shape = {1, 3}; + + // Create a few shards: [0,1], [2,3], [4,5] + xt::xarray s1 = {TypeParam(0), TypeParam(1)}; + xt::xarray s2 = {TypeParam(2), TypeParam(3)}; + xt::xarray s3 = {TypeParam(4), TypeParam(5)}; + + std::vector> shards = {s1, s2, s3}; + ttml::core::ConcatMeshToXTensor composer(mesh_shape, 0); + auto composed = composer.compose(shards); + + xt::xarray expected = { + TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3), TypeParam(4), TypeParam(5)}; + EXPECT_TRUE(xt::allclose(composed[0], expected)); +} + +TYPED_TEST(MeshOpsTest, VectorMeshToXTensorVectorReturn) { + tt::tt_metal::distributed::MeshShape mesh_shape = {2, 2}; + ttml::core::VectorMeshToXTensor vectorComposer(mesh_shape); + + std::vector> shards = { + xt::xarray({TypeParam(0), TypeParam(1)}), xt::xarray({TypeParam(2), TypeParam(3)})}; + + auto result = vectorComposer.compose(shards); + ASSERT_EQ(result.size(), shards.size()); + for (size_t i = 0; i < shards.size(); ++i) { + EXPECT_TRUE(xt::allclose(result[i], shards[i])); + } +} + +TEST(ConcatenateTest, DefaultAxis) { + xt::xarray a = {{1.0, 2.0}, {3.0, 4.0}}; + xt::xarray b = {{5.0, 6.0}, {7.0, 8.0}}; + std::vector> input = {a, b}; + + xt::xarray result = ttml::core::concatenate(input); // axis=0 by default + xt::xarray expected = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}}; + + xt::allclose(result, expected); +} + +TEST(ConcatenateTest, AxisOne) { + xt::xarray x = {{1, 2, 3}, {4, 5, 6}}; + xt::xarray y = {{7, 8}, {9, 10}}; + std::vector> input = {x, y}; + + xt::xarray result = ttml::core::concatenate(input, 1); + xt::xarray expected = {{1, 2, 3, 7, 8}, {4, 5, 6, 9, 10}}; + + xt::allclose(result, expected); +} + +TEST(ConcatenateTest, MultipleArraysAxis0) { + xt::xarray a = {1.0f, 2.0f}; + xt::xarray b = {3.0f, 4.0f}; + xt::xarray c = {5.0f, 6.0f}; + std::vector> input = {a, b, c}; + + xt::xarray result = ttml::core::concatenate(input, 0); + xt::xarray expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + xt::allclose(result, expected); +} + +TEST(ConcatenateTest, EmptyArray) { + xt::xarray a = {{1, 2}, {3, 4}}; + xt::xarray b; // Empty + std::vector> input = {a, b}; + + EXPECT_ANY_THROW({ xt::xarray result = ttml::core::concatenate(input, 0); }); +} + +TEST(ConcatenateTest, HigherDimensions) { + xt::xarray arr1 = xt::arange(1, 9); // 1 to 8 + arr1.reshape({2, 2, 2}); + xt::xarray arr2 = xt::arange(9, 17); // 9 to 16 + arr2.reshape({2, 2, 2}); + + std::vector> input = {arr1, arr2}; + xt::xarray result = ttml::core::concatenate(input, 0); + + // Expected: shape (4,2,2) with arr1 stacked over arr2 along axis 0 + xt::xarray expected = xt::concatenate(xt::xtuple(arr1, arr2), 0); + + xt::allclose(result, expected); +} + +TEST(ConcatenateTest, HigherAxis) { + xt::xarray arr1 = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; + xt::xarray arr2 = {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}; + // Both have shape (2,2,2) + + std::vector> input = {arr1, arr2}; + xt::xarray result = ttml::core::concatenate(input, 2); + // Expected shape: (2,2,4) + xt::xarray expected = {{{1, 2, 9, 10}, {3, 4, 11, 12}}, {{5, 6, 13, 14}, {7, 8, 15, 16}}}; + + xt::allclose(result, expected); +} + +TYPED_TEST(MeshOpsTest, ConcatenateSameParametersAsCompose) { + tt::tt_metal::distributed::MeshShape mesh_shape = {1, 3}; + + // Create a few shards: [0,1], [2,3], [4,5] + xt::xarray s1 = {TypeParam(0), TypeParam(1)}; + xt::xarray s2 = {TypeParam(2), TypeParam(3)}; + xt::xarray s3 = {TypeParam(4), TypeParam(5)}; + + std::vector> shards = {s1, s2, s3}; + ttml::core::ConcatMeshToXTensor composer(mesh_shape, 0); + auto composed = ttml::core::concatenate(shards); + + xt::xarray expected = { + TypeParam(0), TypeParam(1), TypeParam(2), TypeParam(3), TypeParam(4), TypeParam(5)}; + EXPECT_TRUE(xt::allclose(composed, expected)); +} diff --git a/tt-train/tests/core/n300_utils_test.cpp b/tt-train/tests/core/n300_utils_test.cpp new file mode 100644 index 00000000000..7b376356b76 --- /dev/null +++ b/tt-train/tests/core/n300_utils_test.cpp @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include +#include +#include + +#include "autograd/auto_context.hpp" +#include "core/compute_kernel_config.hpp" +#include "core/distributed_mapping.hpp" +#include "core/tt_tensor_utils.hpp" +#include "ttnn/operations/ccl/all_gather/all_gather.hpp" +#include "ttnn/operations/experimental/ccl/all_reduce/all_reduce.hpp" + +auto check_board_is_n300() { + return tt::Cluster::instance().get_board_type(0) == BoardType::N300; +} +class N300UtilsTest : public ::testing::Test { +protected: + void SetUp() override { + if (!check_board_is_n300()) { + GTEST_SKIP() << "Skipping N300 specific tests"; + } + ttml::autograd::ctx().set_mesh_shape({1, 2}); + ttml::autograd::ctx().open_device(); + } + + void TearDown() override { + ttml::autograd::ctx().close_device(); + } +}; + +TEST_F(N300UtilsTest, TestXTensorReplicate) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + xt::xarray test_data = {30.F, 20.F, 2.F}; + xt::xarray xtensor = test_data.reshape({1, 1, 1, 3}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tensor = ttml::core::from_xtensor(xtensor, device, replicate_composer); + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto xtensors_back = ttml::core::to_xtensor(tensor, identity_composer); + + EXPECT_TRUE(xt::allclose(xtensor, xtensors_back[0])); + EXPECT_TRUE(xt::allclose(xtensor, xtensors_back[1])); +} + +TEST_F(N300UtilsTest, TestXTensorShardAxis3) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::arange(8); + xt::xarray xtensor = test_data.reshape({1, 1, 2, 4}); + + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tensor = ttml::core::from_xtensor(xtensor, device, replicate_composer); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto xtensors_back = ttml::core::to_xtensor(tensor, identity_composer); + + xt::xarray chunk0 = xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(0, 2)); + xt::xarray chunk1 = xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(2, 4)); + + EXPECT_TRUE(xt::allclose(chunk0, xtensors_back[0])); + EXPECT_TRUE(xt::allclose(chunk1, xtensors_back[1])); +} + +TEST_F(N300UtilsTest, TestXTensorShardAxis2) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::arange(8); + xt::xarray xtensor = test_data.reshape({1, 1, 2, 4}); + + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 2); + auto tensor = ttml::core::from_xtensor(xtensor, device, replicate_composer); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto xtensors_back = ttml::core::to_xtensor(tensor, identity_composer); + + xt::xarray chunk0 = xt::view(xtensor, xt::all(), xt::all(), xt::range(0, 1), xt::all()); + xt::xarray chunk1 = xt::view(xtensor, xt::all(), xt::all(), xt::range(1, 2), xt::all()); + + EXPECT_TRUE(xt::allclose(chunk0, xtensors_back[0])); + EXPECT_TRUE(xt::allclose(chunk1, xtensors_back[1])); +} + +TEST_F(N300UtilsTest, TestXTensorReplicateAllReduce) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray xtensor = xt::random::rand({32 * 32}, -0.05, 0.05).reshape({1, 1, 32, 32}); + + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tensor = ttml::core::from_xtensor(xtensor, device, replicate_composer); + + auto sum_tensor = ttnn::experimental::all_reduce( + tensor, ttnn::operations::reduction::ReduceType::Sum, 1, std::nullopt, ttnn::ccl::Topology::Ring); + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + + auto xtensors_back = ttml::core::to_xtensor(sum_tensor, identity_composer); + auto reduced_tensor = xtensor + xtensor; + + std::cout << "xtensors_back[0]: " << xtensors_back[0] << std::endl; + std::cout << "xtensors_back[1]: " << xtensors_back[1] << std::endl; + std::cout << "reduced_tensor: " << reduced_tensor << std::endl; + EXPECT_TRUE(xt::allclose(reduced_tensor, xtensors_back[0], /*rtol=*/1e-3, /*atol=*/1e-2)); + EXPECT_TRUE(xt::allclose(reduced_tensor, xtensors_back[1], /*rtol=*/1e-3, /*atol=*/1e-2)); +} + +TEST_F(N300UtilsTest, TestXTensorShardAxis2AddScalar) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + float scalar = 10.F; + xt::xarray test_data = xt::arange(8); + xt::xarray xtensor = test_data.reshape({1, 1, 2, 4}); + + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 2); + auto tensor = ttml::core::from_xtensor(xtensor, device, shard_composer); + auto out_tensor = ttnn::add(tensor, scalar); + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto xtensors_back = ttml::core::to_xtensor(out_tensor, identity_composer); + + xt::xarray chunk0 = xt::view(xtensor, xt::all(), xt::all(), xt::range(0, 1), xt::all()); + xt::xarray chunk1 = xt::view(xtensor, xt::all(), xt::all(), xt::range(1, 2), xt::all()); + + EXPECT_TRUE(xt::allclose(chunk0 + scalar, xtensors_back[0])); + EXPECT_TRUE(xt::allclose(chunk1 + scalar, xtensors_back[1])); +} + +TEST_F(N300UtilsTest, TestXTensorShardAxis3Matmul) { + xt::random::seed(42); + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray xtensor_a = xt::random::rand({128 * 64}, -0.005, 0.005).reshape({1, 1, 128, 64}); + xt::xarray xtensor_b = xt::random::rand({256 * 64}, -0.005, 0.005).reshape({1, 1, 64, 256}); + + ttml::core::XTensorToMeshVariant replicate_composer2 = ttml::core::ShardXTensorToMesh(mesh_shape, 2); + ttml::core::XTensorToMeshVariant replicate_composer3 = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tensor_a = ttml::core::from_xtensor(xtensor_a, device, replicate_composer3); + auto tensor_b = ttml::core::from_xtensor(xtensor_b, device, replicate_composer3); + + auto gathered_ta = + ttnn::all_gather(tensor_a, 3 /*, {0, 4}, 1 ,std::nullopt, std::nullopt, std::nullopt, std::nullopt*/); + fmt::print("gathered_ta shape: {}\n", gathered_ta.get_shape().logical_shape()); + auto mul_tensor = ttnn::matmul( + gathered_ta, + tensor_b, + false, + false, + /* memory_config */ std::nullopt, + /* dtype */ std::nullopt, + /* program_config */ std::nullopt, + /* activation */ std::nullopt, + /* compute_kernel_config */ ttml::core::ComputeKernelConfig::precise(), + /* core_grid */ ttnn::CoreGrid{7, 8}, + /* output_tile */ std::nullopt); + ttml::core::MeshToXTensorVariant composer = ttml::core::ConcatMeshToXTensor(mesh_shape, 3); + auto xtensors_back = ttml::core::to_xtensor(mul_tensor, composer); + xt::xarray mul_res = xt::linalg::dot(xtensor_a, xtensor_b); + + // (128, 64) X (64, 256) => (128, 256) + EXPECT_TRUE(xt::allclose(mul_res, xtensors_back[0], /*rtol=*/1e-3, /*atol=*/1e-2)); +} diff --git a/tt-train/tests/core/tensor_utils_test.cpp b/tt-train/tests/core/tensor_utils_test.cpp index 196cfb8fff2..72e518de091 100644 --- a/tt-train/tests/core/tensor_utils_test.cpp +++ b/tt-train/tests/core/tensor_utils_test.cpp @@ -6,12 +6,11 @@ #include #include -#include #include #include "autograd/auto_context.hpp" -#include "core/device.hpp" #include "core/tt_tensor_utils.hpp" +#include "core/xtensor_utils.hpp" TEST(TensorUtilsTest, TestFloatToFromTensorEven) { auto* device = &ttml::autograd::ctx().get_device(); @@ -212,3 +211,32 @@ TEST(TensorUtilsTest, TestZerosLike) { EXPECT_EQ(val, 0.F); } } + +TEST(TensorUtilsTest, TestFloatXtensor) { + auto* device = &ttml::autograd::ctx().get_device(); + std::vector test_data = {30.F, 20.F, 2.F}; + + auto shape = ttml::core::create_shape({1, 1, 1, 3}); + + xt::xarray xtensor = + ttml::core::span_to_xtensor(std::span{test_data.data(), test_data.size()}, shape.logical_shape()); + auto tensor = ttml::core::from_xtensor(xtensor, device); + + auto xtensor_back = ttml::core::to_xtensor(tensor); + + EXPECT_TRUE(xt::allclose(xtensor, xtensor_back)); +} + +TEST(TensorUtilsTest, TestUint32XTensor) { + auto* device = &ttml::autograd::ctx().get_device(); + std::vector test_data = {30, 20, 2}; + + auto shape = ttml::core::create_shape({1, 1, 1, 3}); + xt::xarray xtensor = + ttml::core::span_to_xtensor(std::span{test_data.data(), test_data.size()}, shape.logical_shape()); + auto tensor = ttml::core::from_xtensor(xtensor, device); + + auto xtensor_back = ttml::core::to_xtensor(tensor); + + EXPECT_TRUE(xt::allclose(xtensor, xtensor_back)); +} diff --git a/tt_metal/common/tt_backend_api_types.hpp b/tt_metal/common/tt_backend_api_types.hpp index a629744eab9..cd815f14530 100644 --- a/tt_metal/common/tt_backend_api_types.hpp +++ b/tt_metal/common/tt_backend_api_types.hpp @@ -12,7 +12,7 @@ #include #include "fmt/base.h" -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" namespace tt { diff --git a/tt_metal/common/work_split.cpp b/tt_metal/common/work_split.cpp index f2a213a1721..ba687d9d3da 100644 --- a/tt_metal/common/work_split.cpp +++ b/tt_metal/common/work_split.cpp @@ -148,6 +148,123 @@ CoreRangeSet num_cores_to_corerangeset( return num_cores_to_corerangeset({0, 0}, target_num_cores, grid_size, row_wise); } +CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( + const CoreCoord start_core, + const uint32_t target_num_cores, + const CoreRangeSet& sub_core_grids, + const bool row_wise = false) { + // If target_num_cores is 0 or input_corerangeset is empty, return empty CoreRangeSet + TT_FATAL(target_num_cores > 0, "Target number of cores must be greater than 0"); + TT_FATAL( + target_num_cores <= sub_core_grids.num_cores(), + "Target number of cores {} is greater than total number of available cores {}", + target_num_cores, + sub_core_grids.num_cores()); + + // Validate that the start core is contained within the entire CoreRangeSet + TT_FATAL(sub_core_grids.contains(start_core), "Start core must be contained within the input CoreRangeSet"); + + std::vector result_coreranges; + bool start_core_found = false; + CoreCoord current_start_core = start_core; + CoreCoord current_end_core = start_core; + uint32_t remaining_cores = target_num_cores; + + auto process_row_wise = [&](const CoreRange& subcoregrid) { + uint32_t subcoregrid_width = subcoregrid.grid_size().x; + + for (uint32_t y = current_start_core.y; y <= subcoregrid.end_coord.y; ++y) { + if (remaining_cores == 0) { + break; + } + + uint32_t current_width = + std::min(static_cast(subcoregrid.end_coord.x - current_start_core.x + 1), remaining_cores); + + if (current_width < subcoregrid_width) { + if (current_start_core != current_end_core) { + result_coreranges.push_back(CoreRange(current_start_core, current_end_core)); + } + + current_end_core = CoreCoord(current_start_core.x + current_width - 1, y); + remaining_cores -= current_width; + + result_coreranges.push_back( + CoreRange(CoreCoord(current_start_core.x, y), CoreCoord(current_end_core.x, y))); + + current_start_core = CoreCoord(subcoregrid.start_coord.x, y + 1); + current_end_core = current_start_core; + } else { + current_end_core = CoreCoord(subcoregrid.end_coord.x, y); + remaining_cores -= current_width; + } + } + + if (current_start_core != current_end_core) { + result_coreranges.push_back(CoreRange(current_start_core, current_end_core)); + } + }; + + auto process_col_wise = [&](const CoreRange& subcoregrid) { + uint32_t subcoregrid_height = subcoregrid.grid_size().y; + + for (uint32_t x = current_start_core.x; x <= subcoregrid.end_coord.x; ++x) { + if (remaining_cores == 0) { + break; + } + + uint32_t current_height = + std::min(static_cast(subcoregrid.end_coord.y - current_start_core.y + 1), remaining_cores); + + if (current_height < subcoregrid_height) { + if (current_start_core != current_end_core) { + result_coreranges.push_back(CoreRange(current_start_core, current_end_core)); + } + + current_end_core = CoreCoord(x, current_start_core.y + current_height - 1); + remaining_cores -= current_height; + + result_coreranges.push_back( + CoreRange(CoreCoord(x, current_start_core.y), CoreCoord(x, current_end_core.y))); + + current_start_core = CoreCoord(x + 1, subcoregrid.start_coord.y); + current_end_core = current_start_core; + } else { + current_end_core = CoreCoord(x, subcoregrid.end_coord.y); + remaining_cores -= current_height; + } + } + + if (current_start_core != current_end_core) { + result_coreranges.push_back(CoreRange(current_start_core, current_end_core)); + } + }; + + // Iterate over subcoregrids and process based on row_wise + for (const auto& subcoregrid : sub_core_grids.ranges()) { + if (subcoregrid.contains(start_core)) { + start_core_found = true; + } else { + if (!start_core_found) { + continue; + } else { + current_start_core = subcoregrid.start_coord; + current_end_core = current_start_core; + } + } + + if (row_wise) { + process_row_wise(subcoregrid); + } else { + process_col_wise(subcoregrid); + } + } + + TT_FATAL(remaining_cores == 0, "Failed to split target number of cores into CoreRangeSet"); + + return CoreRangeSet(std::move(result_coreranges)); +} + std::tuple split_work_to_cores( const CoreCoord grid_size, const uint32_t units_to_divide, const bool row_wise) { ZoneScoped; diff --git a/tt_metal/common/work_split.hpp b/tt_metal/common/work_split.hpp index 39cdec9bf21..2b5ae0ecb9d 100644 --- a/tt_metal/common/work_split.hpp +++ b/tt_metal/common/work_split.hpp @@ -40,6 +40,11 @@ CoreRangeSet num_cores_to_corerangeset( CoreRangeSet num_cores_to_corerangeset( const uint32_t target_num_cores, const CoreCoord grid_size, const bool row_wise = false); +CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( + const CoreCoord start_core, + const uint32_t target_num_cores, + const CoreRangeSet& sub_core_grids, + const bool row_wise = false); // This function takes in the core grid size, as well as the number of units of work to divide between the cores // This function returns the number of cores, the CoreRangeSet of all cores, and then the CoreRangeSet that does // the greater amount of work, and the CoreRangeSet that does less work if work cannot be evenly divided diff --git a/tt_metal/detail/tt_metal.hpp b/tt_metal/detail/tt_metal.hpp index e0bd1d543cf..c5e3c021ede 100644 --- a/tt_metal/detail/tt_metal.hpp +++ b/tt_metal/detail/tt_metal.hpp @@ -6,7 +6,7 @@ #include #include -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "umd/device/tt_soc_descriptor.h" #include "tt_metal/hostdevcommon/common_values.hpp" #include "tt_metal/common/core_coord.hpp" diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index 62eaa78186b..6971abd948e 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -9,7 +9,7 @@ #include #include -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/common/logger.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/host_api.hpp" @@ -105,7 +105,7 @@ MeshShape SystemMesh::Impl::get_system_mesh_shape(size_t system_num_devices) { TT_FATAL( system_mesh_to_shape.contains(system_num_devices), "Unsupported number of devices: {}", system_num_devices); auto shape = system_mesh_to_shape.at(system_num_devices); - log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.first, shape.second); + log_debug(LogMetal, "Logical SystemMesh Shape: {}x{}", shape.num_rows, shape.num_cols); return shape; } @@ -269,6 +269,10 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } +Device* MeshDevice::reference_device() const { + return this->devices.at(0); +} + MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr parent_mesh) : mesh_device_shape(mesh_device_shape), type(type), @@ -289,32 +293,32 @@ std::shared_ptr MeshDevice::create( std::shared_ptr MeshDevice::create_submesh( const MeshShape& submesh_shape, const MeshOffset& offset, MeshType type) { - if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { + if (submesh_shape.num_rows <= 0 || submesh_shape.num_cols <= 0) { TT_THROW( "Invalid submesh shape: ({}, {}). Both dimensions must be positive.", - submesh_shape.first, - submesh_shape.second); + submesh_shape.num_rows, + submesh_shape.num_cols); } - if (offset.first < 0 || offset.second < 0) { - TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second); + if (offset.row < 0 || offset.col < 0) { + TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.row, offset.col); } - if (offset.first + submesh_shape.first > this->mesh_device_shape.first || - offset.second + submesh_shape.second > this->mesh_device_shape.second) { + if (offset.row + submesh_shape.num_rows > this->mesh_device_shape.num_rows || + offset.col + submesh_shape.num_cols > this->mesh_device_shape.num_cols) { TT_THROW( "Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", - submesh_shape.first, - submesh_shape.second, - offset.first, - offset.second, - this->mesh_device_shape.first, - this->mesh_device_shape.second); + submesh_shape.num_rows, + submesh_shape.num_cols, + offset.row, + offset.col, + this->mesh_device_shape.num_rows, + this->mesh_device_shape.num_cols); } auto submesh = std::make_shared(submesh_shape, type, shared_from_this()); - auto start_coordinate = Coordinate{offset.first, offset.second}; - auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1}; + auto start_coordinate = Coordinate{offset.row, offset.col}; + auto end_coordinate = Coordinate{offset.row + submesh_shape.num_rows - 1, offset.col + submesh_shape.num_cols - 1}; submesh->primary_view = std::make_shared(*this, start_coordinate, end_coordinate); submesh->devices = submesh->primary_view->get_devices(); SystemMesh::instance().register_mesh_device(submesh, submesh->devices); @@ -323,10 +327,10 @@ std::shared_ptr MeshDevice::create_submesh( LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), - submesh_shape.first, - submesh_shape.second, - offset.first, - offset.second); + submesh_shape.num_rows, + submesh_shape.num_cols, + offset.row, + offset.col); log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices); return submesh; @@ -334,8 +338,8 @@ std::shared_ptr MeshDevice::create_submesh( std::vector> MeshDevice::create_submeshes(const MeshShape& submesh_shape, MeshType type) { std::vector> submeshes; - for (int row = 0; row < this->num_rows(); row += submesh_shape.first) { - for (int col = 0; col < this->num_cols(); col += submesh_shape.second) { + for (int row = 0; row < this->num_rows(); row += submesh_shape.num_rows) { + for (int col = 0; col < this->num_cols(); col += submesh_shape.num_cols) { auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, type); submeshes.push_back(submesh); } @@ -403,17 +407,15 @@ const DeviceIds MeshDevice::get_device_ids() const { size_t MeshDevice::num_devices() const { return this->devices.size(); } -CoreCoord MeshDevice::compute_with_storage_grid_size() const { - return get_device_index(0)->compute_with_storage_grid_size(); -} +CoreCoord MeshDevice::compute_with_storage_grid_size() const { return this->reference_device()->compute_with_storage_grid_size(); } -CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_grid_size(); } +CoreCoord MeshDevice::dram_grid_size() const { return this->reference_device()->dram_grid_size(); } -tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); } +tt::ARCH MeshDevice::arch() const { return this->reference_device()->arch(); } -size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; } +size_t MeshDevice::num_rows() const { return this->mesh_device_shape.num_rows; } -size_t MeshDevice::num_cols() const { return this->mesh_device_shape.second; } +size_t MeshDevice::num_cols() const { return this->mesh_device_shape.num_cols; } MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } @@ -487,6 +489,24 @@ MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span>& mesh_sub_devices, DeviceAddr local_l1_size) { + MeshSubDeviceManagerId mesh_sub_device_manager_id(*this); + TT_FATAL(mesh_sub_devices.size() == this->num_devices(), "Number of devices does not match number of sub-device configurations"); + for (uint32_t i = 0; i < this->num_devices(); i++) { + auto* device = this->devices[i]; + auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; + tt::stl::Span sub_devices(mesh_sub_devices[i]); + device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id]() { + sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size); + }); + } + for (auto* device : this->devices) { + device->synchronize(); + } + return mesh_sub_device_manager_id; +} + void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) { for (uint32_t i = 0; i < this->num_devices(); i++) { auto* device = this->devices[i]; @@ -517,4 +537,15 @@ MeshSubDeviceManagerId::MeshSubDeviceManagerId(const MeshDevice& mesh_device) { this->sub_device_manager_ids.resize(mesh_device.num_devices()); } +int MeshDevice::num_dram_channels() const { + return this->reference_device()->num_dram_channels() * this->num_devices(); +} + +allocator::Statistics MeshDevice::get_memory_allocation_statistics(const BufferType &buffer_type, SubDeviceId sub_device_id) const { + // With current implementation, we assume that all devices have the same memory allocation statistics. + // This will be made more explicit in the future to have lock-step allocation across devices. + // Right now, we just return the statistics of the first device. + return this->reference_device()->get_memory_allocation_statistics(buffer_type, sub_device_id); +} + } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index 81fa6f45be6..a7727fb97bd 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -18,7 +18,10 @@ namespace tt::tt_metal::distributed { using DeviceIds = std::vector; using MeshDeviceID = size_t; -using MeshOffset = std::pair; +struct MeshOffset { + size_t row = 0; + size_t col = 0; +}; class MeshDeviceView; struct MeshSubDeviceManagerId; @@ -89,6 +92,9 @@ class MeshDevice : public std::enable_shared_from_this { const DispatchCoreConfig& dispatch_core_config, const MeshDeviceConfig& config); + // This is a reference device used to query properties that are the same for all devices in the mesh. + Device* reference_device() const; + public: MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr parent_mesh = {}); ~MeshDevice(); @@ -111,15 +117,6 @@ class MeshDevice : public std::enable_shared_from_this { size_t num_cols() const; MeshShape shape() const; - CoreCoord compute_with_storage_grid_size() const; - - CoreCoord dram_grid_size() const; - - tt::ARCH arch() const; - void enable_async(bool enable); - void enable_program_cache(); - void disable_and_clear_program_cache(); - void close_devices(); std::shared_ptr get_view() const; std::shared_ptr get_view(); @@ -138,10 +135,10 @@ class MeshDevice : public std::enable_shared_from_this { std::vector> create_submeshes( const MeshShape& submesh_shape, MeshType type = MeshType::RowMajor); - size_t num_program_cache_entries() const; - MeshSubDeviceManagerId create_sub_device_manager( tt::stl::Span sub_devices, DeviceAddr local_l1_size); + MeshSubDeviceManagerId create_sub_device_manager( + const std::vector>& mesh_sub_devices, DeviceAddr local_l1_size); void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); void clear_loaded_sub_device_manager(); void remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); @@ -152,6 +149,20 @@ class MeshDevice : public std::enable_shared_from_this { size_t trace_region_size = DEFAULT_TRACE_REGION_SIZE, size_t num_command_queues = 1, const DispatchCoreConfig& dispatch_core_config = DispatchCoreConfig{}); + + // Device API Queries (API contract with Device class to be supported in future) + CoreCoord compute_with_storage_grid_size() const; + CoreCoord dram_grid_size() const; + + tt::ARCH arch() const; + void enable_async(bool enable); + void enable_program_cache(); + void disable_and_clear_program_cache(); + + size_t num_program_cache_entries() const; + + int num_dram_channels() const; + allocator::Statistics get_memory_allocation_statistics(const BufferType &buffer_type, SubDeviceId sub_device_id = SubDeviceId{0}) const; }; std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device); @@ -164,4 +175,44 @@ struct MeshSubDeviceManagerId { std::vector sub_device_manager_ids; }; +namespace detail { +template +concept HasMethodsForArchitectureQueries = requires(T& device) { + { device.compute_with_storage_grid_size() } -> std::same_as; + { device.dram_grid_size() } -> std::same_as; + { device.arch() } -> std::same_as; + { device.num_dram_channels() } -> std::same_as; +}; + +template +concept HasMethodsForAllocator = requires(T& device) { + { device.get_memory_allocation_statistics(std::declval(), std::declval()) } -> std::same_as; +}; + +template +concept HasMethodsForProgramCache = requires(T& device) { + { device.num_program_cache_entries() } -> std::same_as; + { device.enable_program_cache() } -> std::same_as; + { device.disable_and_clear_program_cache() } -> std::same_as; +}; + +template +concept HasMethodsForAsync = requires(T& device) { + { device.enable_async(std::declval()) } -> std::same_as; +}; + +template +concept DeviceInterfaceContract = + HasMethodsForArchitectureQueries && + HasMethodsForAllocator && + HasMethodsForProgramCache && + HasMethodsForAsync; + +} // namespace detail + +// For now static_asserts are used to ensure that the concepts are satisfied. +// This is a temporary compile-time check to make sure that Device/MeshDevice don't deviate from the expected interface. +static_assert(detail::DeviceInterfaceContract, "Device must satisfy the DeviceInterfaceContract concept."); +static_assert(detail::DeviceInterfaceContract, "MeshDevice must satisfy the DeviceInterfaceContract concept."); + } // namespace tt::tt_metal::distributed diff --git a/tt_metal/distributed/mesh_device_view.cpp b/tt_metal/distributed/mesh_device_view.cpp index b0c2ef050be..f9e115f0437 100644 --- a/tt_metal/distributed/mesh_device_view.cpp +++ b/tt_metal/distributed/mesh_device_view.cpp @@ -83,7 +83,7 @@ MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, } MeshDeviceView::DeviceView MeshDeviceView::get_devices(const MeshShape& shape) { - return get_devices({0, 0}, {shape.first - 1, shape.second - 1}); + return get_devices({0, 0}, {shape.num_rows - 1, shape.num_cols - 1}); } std::vector MeshDeviceView::get_devices_on_row(size_t row) const { @@ -128,7 +128,7 @@ bool MeshDeviceView::empty() const noexcept { return devices_.empty(); } size_t MeshDeviceView::size() const noexcept { return devices_.size(); } -std::pair MeshDeviceView::shape() const noexcept { return {num_rows(), num_cols()}; } +MeshShape MeshDeviceView::shape() const noexcept { return {num_rows(), num_cols()}; } bool MeshDeviceView::contains(const Coordinate& coord) const noexcept { return coord.row >= top_left_.row && coord.row <= bottom_right_.row && coord.col >= top_left_.col && diff --git a/tt_metal/distributed/mesh_device_view.hpp b/tt_metal/distributed/mesh_device_view.hpp index 67bda684ebb..31af7aba376 100644 --- a/tt_metal/distributed/mesh_device_view.hpp +++ b/tt_metal/distributed/mesh_device_view.hpp @@ -17,7 +17,10 @@ namespace tt::tt_metal::distributed { // Forward declaration of MeshDevice class MeshDevice; -using MeshShape = std::pair; +struct MeshShape { + size_t num_rows = 0; + size_t num_cols = 0; +}; struct Coordinate { size_t row; diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h index d57a3db8c3e..8a7d6543876 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h @@ -1,75 +1,8 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 #pragma once #include "llk_math_common_api.h" -#include "llk_math_eltwise_binary_sfpu.h" - -/************************************************************************* - * LLK ELTWISE BINARY SFPU - *************************************************************************/ - -template -inline void llk_math_eltwise_binary_sfpu( - const uint operand, - uint dst_index_a, - uint dst_index_b, - int vector_mode = (int)VectorMode::RC, - uint param0 = 0, - uint param1 = 0, - uint param2 = 0, - uint param3 = 0, - uint param4 = 0, - uint param5 = 0) { - const std::uint32_t operand_id = get_operand_id(operand); - const std::uint32_t num_faces = get_operand_num_faces(operand_id); - const std::uint32_t face_r_dim = get_operand_face_r_dim(operand_id); - - _llk_math_eltwise_binary_sfpu_( - face_r_dim, num_faces, dst_index_a, dst_index_b, vector_mode, param0, param1, param2, param3, param4, param5); -} - -template -inline void llk_math_eltwise_binary_sfpu_init( - const uint param0 = 0, - const uint param1 = 0, - const uint param2 = 0, - const uint param3 = 0, - const uint param4 = 0, - const uint param5 = 0) { - _llk_math_eltwise_binary_sfpu_init_(param0, param1, param2, param3, param4, param5); -} - -template -inline void llk_math_eltwise_binary_sfpu_quant_int32( - const uint operand, uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); -} - -template -inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) { - llk_math_eltwise_binary_sfpu_init(zero_point); -} - -template -inline void llk_math_eltwise_binary_sfpu_requant_int32( - const uint operand, uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); -} - -template -inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) { - llk_math_eltwise_binary_sfpu_init(zero_point); -} - -template -inline void llk_math_eltwise_binary_sfpu_dequant_int32( - const uint operand, uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); -} - -template -inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) { - llk_math_eltwise_binary_sfpu_init(zero_point); -} +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_binop.h" diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_add_int32.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_add_int32.h new file mode 100644 index 00000000000..fff976fbf0b --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_add_int32.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_add_int32(const uint dst_offset) { + _add_int32_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h new file mode 100644 index 00000000000..6c23abe0a26 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_sfpu_binary(const uint dst_offset) +{ + _calculate_sfpu_binary_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary_bitwise.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary_bitwise.h new file mode 100644 index 00000000000..9648858daac --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary_bitwise.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_sfpu_binary_bitwise(const uint dst_offset) { + _calculate_sfpu_binary_bitwise_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h index 30d18596ef4..ad167758a24 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h @@ -15,12 +15,40 @@ using namespace sfpi; namespace ckernel { namespace sfpu { +inline vInt float_to_int32(vFloat in) +{ + vInt result; + vInt exp = exexp(in); // extract exponent + v_if (exp < 0) { + result = 0; + } v_elseif (exp > 30) { + // set to int32 max value in case of overflow + result = std::numeric_limits::max(); + // check sign + v_if (in < 0) { + result = reinterpret(setsgn(reinterpret(result), 1)); + } v_endif + } v_else { + // extract mantissa + vInt man = exman8(in); + // shift the mantissa by (23-exponent) to the right + vInt shift = exp - 23; + man = shft(reinterpret(man), shift); + // check sign + v_if (in < 0) { + man = reinterpret(setsgn(reinterpret(man), 1)); + } v_endif + result = man; + } v_endif + return result; +} + template inline void calculate_floor() { for (int d = 0; d < ITERATIONS; d++) { vFloat result = dst_reg[0]; vFloat v = result; - vInt tmp = float_to_int16(result, 0); // TODO: Replace float_to_int16 to float_to_int32 once it is available + vInt tmp = float_to_int16(result, 0); result = int32_to_float(tmp, 0); v_if(result > v) { result = result - 1; } v_endif; @@ -31,5 +59,19 @@ inline void calculate_floor() { } } +template +inline void calculate_floor_float32() { + for (int d = 0; d < ITERATIONS; d++) { + vFloat result = dst_reg[0]; + vFloat v = result; + vInt tmp = float_to_int32(result); + result = int32_to_float(tmp, 0); + v_if(result > v) { result = result - 1; } + v_endif; + dst_reg[0] = result; + dst_reg++; + } +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_quant.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_quant.h new file mode 100644 index 00000000000..851b71671d2 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_quant.h @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_quant_int32(const uint dst_offset) +{ + _quant_int32_(dst_offset); +} + +template +inline void calculate_requant_int32(const uint dst_offset) +{ + _requant_int32_(dst_offset); +} + +template +inline void calculate_dequant_int32(const uint dst_offset) +{ + _dequant_int32_(dst_offset); +} + +template +void quant_init(const uint zero_point) { + _init_quant_zero_point_(zero_point); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_add_int32.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_add_int32.h new file mode 100644 index 00000000000..db9d2579956 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_add_int32.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_add_int32.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_add_int32_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_add_int32( + uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_add_int32, dst_index0, dst_index1, vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h new file mode 100644 index 00000000000..09fcb4a530d --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_binary.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_binop_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_binop(uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_sfpu_binary, + dst_index0, + dst_index1, + vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_bitwise.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_bitwise.h new file mode 100644 index 00000000000..de2b20a8c70 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_bitwise.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_binary_bitwise.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_bitwise_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_bitwise( + uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_sfpu_binary_bitwise, dst_index0, dst_index1, vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h new file mode 100644 index 00000000000..23f90869e89 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_sfpu_types.h" +#include "llk_math_eltwise_binary_sfpu.h" + +namespace ckernel { + +template +inline void llk_math_eltwise_binary_sfpu_init() { + _llk_math_eltwise_binary_sfpu_init_(); +} + +template +inline void llk_math_eltwise_binary_sfpu_init(F&& init_func, ARGS&& ... args) { + _llk_math_eltwise_binary_sfpu_init_(); + init_func(static_cast(args)...); +} + +} diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h new file mode 100644 index 00000000000..3d54229b9d7 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "llk_sfpu_types.h" +#include "llk_math_eltwise_binary_sfpu.h" + +template +inline void llk_math_eltwise_binary_sfpu_params( + F&& sfpu_func, + uint dst_index0, + uint dst_index1, + int vector_mode = (int)VectorMode::RC, + ARGS&& ... args) { + + uint dst_index = (dst_index0 <= dst_index1) ? dst_index0 : dst_index1; + uint dst_offset = (dst_index0 > dst_index1) ? dst_index0 - dst_index1 : dst_index1 - dst_index0; + _llk_math_eltwise_binary_sfpu_start_(dst_index); + + if (vector_mode == (int)VectorMode::R) { + // Do a row vector, Face0 + Face1 -- first iteration (first row) + const int ITERATIONS = 1; +#pragma GCC unroll 0 + for (int face = 0; face < 2; face++) { + sfpu_func(dst_offset, static_cast(args)...); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } + // Skip the next 2 faces + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } else if (vector_mode == (int)VectorMode::C) { + // Do a column vector, Face0 + Face2 -- All iterations for full face +#pragma GCC unroll 0 + for (int face = 0; face < 2; face++) { + sfpu_func(dst_offset, static_cast(args)...); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } + } else if (vector_mode == (int)VectorMode::RC) { + // Do all four faces, and iterate through all 4 blocks of 4 rows each +#pragma GCC unroll 0 + for (int face = 0; face < 4; face++) { + sfpu_func(dst_offset, static_cast(args)...); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } + } else { + sfpu_func(dst_offset, static_cast(args)...); + } + _llk_math_eltwise_binary_sfpu_done_(); +} diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h new file mode 100644 index 00000000000..ffa3ad4451f --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_quant.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) { + llk_math_eltwise_binary_sfpu_init( + sfpu::quant_init, + zero_point); +} + +template +inline void llk_math_eltwise_binary_sfpu_quant_int32(uint dst_index0, uint dst_index1, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_quant_int32, + dst_index0, + dst_index1, + vector_mode); +} + +template +inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) { + llk_math_eltwise_binary_sfpu_init( + sfpu::quant_init, + zero_point); +} + +template +inline void llk_math_eltwise_binary_sfpu_requant_int32(uint dst_index0, uint dst_index1, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_requant_int32, + dst_index0, + dst_index1, + vector_mode); +} + +template +inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) { + llk_math_eltwise_binary_sfpu_init( + sfpu::quant_init, + zero_point); +} + +template +inline void llk_math_eltwise_binary_sfpu_dequant_int32(uint dst_index0, uint dst_index1, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_dequant_int32, + dst_index0, + dst_index1, + vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h index ff0e6e96daf..26325252a0f 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h @@ -23,4 +23,9 @@ inline void llk_math_eltwise_unary_sfpu_floor(uint dst_index, int vector_mode = ckernel::sfpu::calculate_floor, dst_index, vector_mode); } +template +inline void llk_math_eltwise_unary_sfpu_floor_float32(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_floor_float32, dst_index, vector_mode); +} } // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h index 6a4874b733e..e681d0940ec 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h index bdca47da10d..8a7d6543876 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h @@ -1,70 +1,8 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 #pragma once #include "llk_math_common_api.h" -#include "llk_math_eltwise_binary_sfpu.h" - -/************************************************************************* - * LLK ELTWISE BINARY SFPU - *************************************************************************/ - -template -inline void llk_math_eltwise_binary_sfpu( - const uint operand, - uint dst_index_a, - uint dst_index_b, - int vector_mode = (int)VectorMode::RC, - uint param0 = 0, - uint param1 = 0, - uint param2 = 0, - uint param3 = 0, - uint param4 = 0, - uint param5 = 0) { - const std::uint32_t operand_id = get_operand_id(0); - const std::uint32_t num_faces = get_operand_num_faces(operand_id); - const std::uint32_t face_r_dim = get_operand_face_r_dim(operand_id); - - _llk_math_eltwise_binary_sfpu_( - face_r_dim, num_faces, dst_index_a, dst_index_b, vector_mode, param0, param1, param2, param3, param4, param5); -} - -template -inline void llk_math_eltwise_binary_sfpu_init( - uint param0 = 0, uint param1 = 0, uint param2 = 0, uint param3 = 0, uint param4 = 0, uint param5 = 0) { - _llk_math_eltwise_binary_sfpu_init_(param0, param1, param2, param3, param4, param5); -} - -template -inline void llk_math_eltwise_binary_sfpu_quant_int32( - uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); -} - -template -inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) { - llk_math_eltwise_binary_sfpu_init(zero_point); -} - -template -inline void llk_math_eltwise_binary_sfpu_requant_int32( - uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); -} - -template -inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) { - llk_math_eltwise_binary_sfpu_init(zero_point); -} - -template -inline void llk_math_eltwise_binary_sfpu_dequant_int32( - uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); -} - -template -inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) { - llk_math_eltwise_binary_sfpu_init(zero_point); -} +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_binop.h" diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_add_int32.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_add_int32.h new file mode 100644 index 00000000000..fff976fbf0b --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_add_int32.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_add_int32(const uint dst_offset) { + _add_int32_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h new file mode 100644 index 00000000000..6c23abe0a26 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_sfpu_binary(const uint dst_offset) +{ + _calculate_sfpu_binary_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary_bitwise.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary_bitwise.h new file mode 100644 index 00000000000..9648858daac --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary_bitwise.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_sfpu_binary_bitwise(const uint dst_offset) { + _calculate_sfpu_binary_bitwise_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h index 30d18596ef4..ad167758a24 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h @@ -15,12 +15,40 @@ using namespace sfpi; namespace ckernel { namespace sfpu { +inline vInt float_to_int32(vFloat in) +{ + vInt result; + vInt exp = exexp(in); // extract exponent + v_if (exp < 0) { + result = 0; + } v_elseif (exp > 30) { + // set to int32 max value in case of overflow + result = std::numeric_limits::max(); + // check sign + v_if (in < 0) { + result = reinterpret(setsgn(reinterpret(result), 1)); + } v_endif + } v_else { + // extract mantissa + vInt man = exman8(in); + // shift the mantissa by (23-exponent) to the right + vInt shift = exp - 23; + man = shft(reinterpret(man), shift); + // check sign + v_if (in < 0) { + man = reinterpret(setsgn(reinterpret(man), 1)); + } v_endif + result = man; + } v_endif + return result; +} + template inline void calculate_floor() { for (int d = 0; d < ITERATIONS; d++) { vFloat result = dst_reg[0]; vFloat v = result; - vInt tmp = float_to_int16(result, 0); // TODO: Replace float_to_int16 to float_to_int32 once it is available + vInt tmp = float_to_int16(result, 0); result = int32_to_float(tmp, 0); v_if(result > v) { result = result - 1; } v_endif; @@ -31,5 +59,19 @@ inline void calculate_floor() { } } +template +inline void calculate_floor_float32() { + for (int d = 0; d < ITERATIONS; d++) { + vFloat result = dst_reg[0]; + vFloat v = result; + vInt tmp = float_to_int32(result); + result = int32_to_float(tmp, 0); + v_if(result > v) { result = result - 1; } + v_endif; + dst_reg[0] = result; + dst_reg++; + } +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_quant.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_quant.h new file mode 100644 index 00000000000..851b71671d2 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_quant.h @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_quant_int32(const uint dst_offset) +{ + _quant_int32_(dst_offset); +} + +template +inline void calculate_requant_int32(const uint dst_offset) +{ + _requant_int32_(dst_offset); +} + +template +inline void calculate_dequant_int32(const uint dst_offset) +{ + _dequant_int32_(dst_offset); +} + +template +void quant_init(const uint zero_point) { + _init_quant_zero_point_(zero_point); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_add_int32.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_add_int32.h new file mode 100644 index 00000000000..db9d2579956 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_add_int32.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_add_int32.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_add_int32_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_add_int32( + uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_add_int32, dst_index0, dst_index1, vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h new file mode 100644 index 00000000000..09fcb4a530d --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_binary.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_binop_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_binop(uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_sfpu_binary, + dst_index0, + dst_index1, + vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_bitwise.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_bitwise.h new file mode 100644 index 00000000000..de2b20a8c70 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_bitwise.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_binary_bitwise.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_bitwise_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_bitwise( + uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_sfpu_binary_bitwise, dst_index0, dst_index1, vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h new file mode 100644 index 00000000000..23f90869e89 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_init.h @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_sfpu_types.h" +#include "llk_math_eltwise_binary_sfpu.h" + +namespace ckernel { + +template +inline void llk_math_eltwise_binary_sfpu_init() { + _llk_math_eltwise_binary_sfpu_init_(); +} + +template +inline void llk_math_eltwise_binary_sfpu_init(F&& init_func, ARGS&& ... args) { + _llk_math_eltwise_binary_sfpu_init_(); + init_func(static_cast(args)...); +} + +} diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h new file mode 100644 index 00000000000..3d54229b9d7 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_params.h @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "llk_sfpu_types.h" +#include "llk_math_eltwise_binary_sfpu.h" + +template +inline void llk_math_eltwise_binary_sfpu_params( + F&& sfpu_func, + uint dst_index0, + uint dst_index1, + int vector_mode = (int)VectorMode::RC, + ARGS&& ... args) { + + uint dst_index = (dst_index0 <= dst_index1) ? dst_index0 : dst_index1; + uint dst_offset = (dst_index0 > dst_index1) ? dst_index0 - dst_index1 : dst_index1 - dst_index0; + _llk_math_eltwise_binary_sfpu_start_(dst_index); + + if (vector_mode == (int)VectorMode::R) { + // Do a row vector, Face0 + Face1 -- first iteration (first row) + const int ITERATIONS = 1; +#pragma GCC unroll 0 + for (int face = 0; face < 2; face++) { + sfpu_func(dst_offset, static_cast(args)...); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } + // Skip the next 2 faces + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } else if (vector_mode == (int)VectorMode::C) { + // Do a column vector, Face0 + Face2 -- All iterations for full face +#pragma GCC unroll 0 + for (int face = 0; face < 2; face++) { + sfpu_func(dst_offset, static_cast(args)...); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } + } else if (vector_mode == (int)VectorMode::RC) { + // Do all four faces, and iterate through all 4 blocks of 4 rows each +#pragma GCC unroll 0 + for (int face = 0; face < 4; face++) { + sfpu_func(dst_offset, static_cast(args)...); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D); + } + } else { + sfpu_func(dst_offset, static_cast(args)...); + } + _llk_math_eltwise_binary_sfpu_done_(); +} diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h new file mode 100644 index 00000000000..ffa3ad4451f --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_quant.h @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_quant.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) { + llk_math_eltwise_binary_sfpu_init( + sfpu::quant_init, + zero_point); +} + +template +inline void llk_math_eltwise_binary_sfpu_quant_int32(uint dst_index0, uint dst_index1, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_quant_int32, + dst_index0, + dst_index1, + vector_mode); +} + +template +inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_point) { + llk_math_eltwise_binary_sfpu_init( + sfpu::quant_init, + zero_point); +} + +template +inline void llk_math_eltwise_binary_sfpu_requant_int32(uint dst_index0, uint dst_index1, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_requant_int32, + dst_index0, + dst_index1, + vector_mode); +} + +template +inline void llk_math_eltwise_binary_sfpu_dequant_int32_init(const uint zero_point) { + llk_math_eltwise_binary_sfpu_init( + sfpu::quant_init, + zero_point); +} + +template +inline void llk_math_eltwise_binary_sfpu_dequant_int32(uint dst_index0, uint dst_index1, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_dequant_int32, + dst_index0, + dst_index1, + vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h index ff0e6e96daf..26325252a0f 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h @@ -23,4 +23,9 @@ inline void llk_math_eltwise_unary_sfpu_floor(uint dst_index, int vector_mode = ckernel::sfpu::calculate_floor, dst_index, vector_mode); } +template +inline void llk_math_eltwise_unary_sfpu_floor_float32(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_floor_float32, dst_index, vector_mode); +} } // namespace ckernel diff --git a/tt_metal/hw/firmware/src/brisc.cc b/tt_metal/hw/firmware/src/brisc.cc index 320b779d936..5554f2edcf3 100644 --- a/tt_metal/hw/firmware/src/brisc.cc +++ b/tt_metal/hw/firmware/src/brisc.cc @@ -20,7 +20,6 @@ #include "tools/profiler/kernel_profiler.hpp" #include "dev_msgs.h" #include "risc_attribs.h" -#include "generated_bank_to_noc_coord_mapping.h" #include "circular_buffer.h" #include "circular_buffer_init.h" #include "dataflow_api.h" @@ -67,6 +66,13 @@ uint32_t tt_l1_ptr *rta_l1_base __attribute__((used)); uint32_t tt_l1_ptr *crta_l1_base __attribute__((used)); uint32_t tt_l1_ptr *sem_l1_base[ProgrammableCoreType::COUNT] __attribute__((used)); +// These arrays are stored in local memory of FW, but primarily used by the kernel which shares +// FW symbols. Hence mark these as 'used' so that FW compiler doesn't optimize it out. +uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS] __attribute__((used)); +uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS] __attribute__((used)); +int32_t bank_to_dram_offset[NUM_DRAM_BANKS] __attribute__((used)); +int32_t bank_to_l1_offset[NUM_L1_BANKS] __attribute__((used)); + #define MEM_MOVER_VIEW_IRAM_BASE_ADDR (0x4 << 12) #if defined(PROFILE_KERNEL) @@ -343,6 +349,8 @@ int main() { do_crt1((uint32_t*)MEM_BRISC_INIT_LOCAL_L1_BASE_SCRATCH); + noc_bank_table_init(MEM_BANK_TO_NOC_SCRATCH); + mailboxes->launch_msg_rd_ptr = 0; // Initialize the rdptr to 0 noc_index = 0; risc_init(); diff --git a/tt_metal/hw/firmware/src/erisc.cc b/tt_metal/hw/firmware/src/erisc.cc index dcf1ffc60a7..44d760a069c 100644 --- a/tt_metal/hw/firmware/src/erisc.cc +++ b/tt_metal/hw/firmware/src/erisc.cc @@ -5,7 +5,6 @@ #include "ethernet/dataflow_api.h" #include "ethernet/tunneling.h" #include "firmware_common.h" -#include "generated_bank_to_noc_coord_mapping.h" #include "noc_parameters.h" #include "risc_attribs.h" #include "tools/profiler/kernel_profiler.hpp" @@ -34,6 +33,13 @@ uint32_t tt_l1_ptr *rta_l1_base __attribute__((used)); uint32_t tt_l1_ptr *crta_l1_base __attribute__((used)); uint32_t tt_l1_ptr *sem_l1_base[ProgrammableCoreType::COUNT] __attribute__((used)); +// These arrays are stored in local memory of FW, but primarily used by the kernel which shares +// FW symbols. Hence mark these as 'used' so that FW compiler doesn't optimize it out. +uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS] __attribute__((used)); +uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS] __attribute__((used)); +int32_t bank_to_dram_offset[NUM_DRAM_BANKS] __attribute__((used)); +int32_t bank_to_l1_offset[NUM_L1_BANKS] __attribute__((used)); + void __attribute__((noinline)) Application(void) { WAYPOINT("I"); @@ -43,6 +49,8 @@ void __attribute__((noinline)) Application(void) { rtos_context_switch_ptr = (void (*)())RtosTable[0]; + noc_bank_table_init(eth_l1_mem::address_map::ERISC_MEM_BANK_TO_NOC_SCRATCH); + risc_init(); noc_init(MEM_NOC_ATOMIC_RET_VAL_ADDR); diff --git a/tt_metal/hw/firmware/src/idle_erisc.cc b/tt_metal/hw/firmware/src/idle_erisc.cc index 4e027e0dd7f..455629e95c7 100644 --- a/tt_metal/hw/firmware/src/idle_erisc.cc +++ b/tt_metal/hw/firmware/src/idle_erisc.cc @@ -19,7 +19,6 @@ #include "tools/profiler/kernel_profiler.hpp" #include "dev_msgs.h" #include "risc_attribs.h" -#include "generated_bank_to_noc_coord_mapping.h" #include "circular_buffer.h" #include "dataflow_api.h" @@ -42,6 +41,13 @@ uint32_t tt_l1_ptr *sem_l1_base[ProgrammableCoreType::COUNT] __attribute__((used uint8_t my_x[NUM_NOCS] __attribute__((used)); uint8_t my_y[NUM_NOCS] __attribute__((used)); +// These arrays are stored in local memory of FW, but primarily used by the kernel which shares +// FW symbols. Hence mark these as 'used' so that FW compiler doesn't optimize it out. +uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS] __attribute__((used)); +uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS] __attribute__((used)); +int32_t bank_to_dram_offset[NUM_DRAM_BANKS] __attribute__((used)); +int32_t bank_to_l1_offset[NUM_L1_BANKS] __attribute__((used)); + //c_tensix_core core; tt_l1_ptr mailboxes_t * const mailboxes = (tt_l1_ptr mailboxes_t *)(MEM_IERISC_MAILBOX_BASE); @@ -101,6 +107,8 @@ int main() { do_crt1((uint32_t *)MEM_IERISC_INIT_LOCAL_L1_BASE_SCRATCH); uint32_t heartbeat = 0; + noc_bank_table_init(MEM_IERISC_BANK_TO_NOC_SCRATCH); + risc_init(); mailboxes->slave_sync.all = RUN_SYNC_MSG_ALL_SLAVES_DONE; diff --git a/tt_metal/hw/firmware/src/ncrisc.cc b/tt_metal/hw/firmware/src/ncrisc.cc index fb3c6e566b3..ba91c04713b 100644 --- a/tt_metal/hw/firmware/src/ncrisc.cc +++ b/tt_metal/hw/firmware/src/ncrisc.cc @@ -11,7 +11,6 @@ #include "firmware_common.h" #include "tools/profiler/kernel_profiler.hpp" #include "risc_attribs.h" -#include "generated_bank_to_noc_coord_mapping.h" #include "circular_buffer.h" #include "circular_buffer_init.h" @@ -40,6 +39,13 @@ uint32_t tt_l1_ptr *rta_l1_base __attribute__((used)); uint32_t tt_l1_ptr *crta_l1_base __attribute__((used)); uint32_t tt_l1_ptr *sem_l1_base[ProgrammableCoreType::COUNT] __attribute__((used)); +// These arrays are stored in local memory of FW, but primarily used by the kernel which shares +// FW symbols. Hence mark these as 'used' so that FW compiler doesn't optimize it out. +uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS] __attribute__((used)); +int32_t bank_to_dram_offset[NUM_DRAM_BANKS] __attribute__((used)); +uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS] __attribute__((used)); +int32_t bank_to_l1_offset[NUM_L1_BANKS] __attribute__((used)); + #if defined(PROFILE_KERNEL) namespace kernel_profiler { uint32_t wIndex __attribute__((used)); @@ -79,6 +85,8 @@ int main(int argc, char *argv[]) { do_crt1((uint32_t tt_l1_ptr *)MEM_NCRISC_INIT_LOCAL_L1_BASE_SCRATCH); + noc_bank_table_init(MEM_BANK_TO_NOC_SCRATCH); + risc_init(); // If NCRISC has IRAM it needs to halt before BRISC copies data from L1 to IRAM diff --git a/tt_metal/hw/firmware/src/slave_idle_erisc.cc b/tt_metal/hw/firmware/src/slave_idle_erisc.cc index 164313f27df..8e0b4500a7a 100644 --- a/tt_metal/hw/firmware/src/slave_idle_erisc.cc +++ b/tt_metal/hw/firmware/src/slave_idle_erisc.cc @@ -11,7 +11,6 @@ #include "firmware_common.h" #include "tools/profiler/kernel_profiler.hpp" #include "risc_attribs.h" -#include "generated_bank_to_noc_coord_mapping.h" #include "circular_buffer.h" #include "debug/waypoint.h" diff --git a/tt_metal/hw/inc/blackhole/dev_mem_map.h b/tt_metal/hw/inc/blackhole/dev_mem_map.h index 4f68f18e9af..3ef1012727a 100644 --- a/tt_metal/hw/inc/blackhole/dev_mem_map.h +++ b/tt_metal/hw/inc/blackhole/dev_mem_map.h @@ -41,6 +41,11 @@ #define MEM_NCRISC_LOCAL_SIZE (8 * 1024) #define MEM_TRISC_LOCAL_SIZE (4 * 1024) +// Memory for (dram/l1)_bank_to_noc_xy arrays, size needs to be atleast 2 * NUM_NOCS * (NUM_DRAM_BANKS + NUM_L1_BANKS) +#define MEM_BANK_TO_NOC_XY_SIZE 1024 +// Memory for bank_to_dram_offset and bank_to_l1_offset arrays, size needs to be atleast 4 * (NUM_DRAM_BANKS + NUM_L1_BANKS) +#define MEM_BANK_OFFSET_SIZE 1024 + ///////////// // Firmware/kernel code holes #define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 128) @@ -91,6 +96,9 @@ #define MEM_TRISC1_INIT_LOCAL_L1_BASE_SCRATCH (MEM_TRISC0_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) #define MEM_TRISC2_INIT_LOCAL_L1_BASE_SCRATCH (MEM_TRISC1_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) +#define MEM_BANK_TO_NOC_SCRATCH (MEM_TRISC2_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) +#define MEM_BANK_TO_NOC_SIZE (MEM_BANK_TO_NOC_XY_SIZE + MEM_BANK_OFFSET_SIZE) + ///////////// // Stack info // Increasing the stack size comes at the expense of less local memory for globals @@ -130,6 +138,9 @@ #define MEM_IERISC_STACK_BASE (MEM_LOCAL_BASE + MEM_IERISC_LOCAL_SIZE - MEM_IERISC_STACK_SIZE) #define MEM_SLAVE_IERISC_STACK_BASE (MEM_LOCAL_BASE + MEM_SLAVE_IERISC_LOCAL_SIZE - MEM_SLAVE_IERISC_STACK_SIZE) +#define MEM_IERISC_BANK_TO_NOC_SCRATCH (MEM_SLAVE_IERISC_INIT_LOCAL_L1_BASE_SCRATCH + MEM_SLAVE_IERISC_LOCAL_SIZE) +#define MEM_IERISC_BANK_TO_NOC_SIZE (MEM_BANK_TO_NOC_XY_SIZE + MEM_BANK_OFFSET_SIZE) + ///////////// // Padding/alignment restriction needed in linker scripts for erisc #define MEM_IERISC_KERNEL_PAD 32 diff --git a/tt_metal/hw/inc/blackhole/eth_l1_address_map.h b/tt_metal/hw/inc/blackhole/eth_l1_address_map.h index e99d13af3d4..6cfe5eadaf8 100644 --- a/tt_metal/hw/inc/blackhole/eth_l1_address_map.h +++ b/tt_metal/hw/inc/blackhole/eth_l1_address_map.h @@ -26,6 +26,13 @@ struct address_map { static constexpr std::int32_t DATA_BUFFER_SIZE_ETH = 4 * 1024; static constexpr std::int32_t DATA_BUFFER_SIZE_NOC = 16 * 1024; static constexpr std::int32_t DATA_BUFFER_SIZE = 24 * 1024; + // Memory for (dram/l1)_bank_to_noc_xy arrays, size needs to be atleast 2 * NUM_NOCS * (NUM_DRAM_BANKS + + // NUM_L1_BANKS) + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_XY_SIZE = 1024; + // Memory for bank_to_dram_offset and bank_to_l1_offset arrays, size needs to be atleast 4 * (NUM_DRAM_BANKS + + // NUM_L1_BANKS) + static constexpr std::int32_t ERISC_MEM_BANK_OFFSET_SIZE = 1024; + // Kernel config buffer is WIP // Size is presently based on the old sizes of the RTAs + CB config + Sems static constexpr std::int32_t ERISC_L1_KERNEL_CONFIG_SIZE = 96 * 4 + 8 * 16; @@ -51,10 +58,7 @@ struct address_map { static constexpr std::int32_t ERISC_APP_ROUTING_INFO_BASE = TILE_HEADER_BUFFER_BASE; static constexpr std::int32_t ERISC_APP_SYNC_INFO_BASE = ERISC_APP_ROUTING_INFO_BASE + ERISC_APP_ROUTING_INFO_SIZE; - static constexpr uint32_t ISSUE_CQ_CB_BASE = ERISC_APP_SYNC_INFO_BASE + ERISC_APP_SYNC_INFO_SIZE; - static constexpr uint32_t COMPLETION_CQ_CB_BASE = ISSUE_CQ_CB_BASE + 7 * L1_ALIGNMENT; - - static constexpr std::int32_t ERISC_MEM_MAILBOX_BASE = COMPLETION_CQ_CB_BASE + 7 * L1_ALIGNMENT; + static constexpr std::uint32_t ERISC_MEM_MAILBOX_BASE = ERISC_APP_SYNC_INFO_BASE + ERISC_APP_SYNC_INFO_SIZE; static constexpr std::uint32_t ERISC_MEM_MAILBOX_SIZE = 3344; static constexpr std::uint32_t ERISC_MEM_MAILBOX_END = ERISC_MEM_MAILBOX_BASE + ERISC_MEM_MAILBOX_SIZE; @@ -65,10 +69,13 @@ struct address_map { static_assert((ERISC_L1_UNRESERVED_BASE % 32) == 0); - static constexpr std::int32_t LAUNCH_ERISC_APP_FLAG = L1_EPOCH_Q_BASE + 4; + // This scratch address is same as ERISC_L1_UNRESERVED_BASE, as the scratch space is used to copy data during + // runtime build, and is unused once FW copies the data to local memory during FW initialization. + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_SCRATCH = + (ERISC_L1_KERNEL_CONFIG_BASE + ERISC_L1_KERNEL_CONFIG_SIZE + 31) & ~31; + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_SIZE = ERISC_MEM_BANK_TO_NOC_XY_SIZE + ERISC_MEM_BANK_OFFSET_SIZE; - // BIDIR Tunneling Kernel Space - static constexpr std::int32_t ERISC_L1_TUNNEL_BUFFER_SIZE = ERISC_L1_UNRESERVED_SIZE / 2; + static constexpr std::int32_t LAUNCH_ERISC_APP_FLAG = L1_EPOCH_Q_BASE + 4; template struct TAssertEquality { diff --git a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h index 7618d83fe99..265466d2f28 100644 --- a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h +++ b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h @@ -355,10 +355,8 @@ #define NOC_XY_ENCODING(x, y) ((((uint32_t)(y)) << (NOC_ADDR_NODE_ID_BITS)) | (((uint32_t)(x)))) // Base address pulled from tt::umd::Cluster::get_pcie_base_addr_from_device -#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ - ((uint64_t(NOC_XY_ENCODING(x, y)) << (NOC_ADDR_LOCAL_BITS - NOC_COORD_REG_OFFSET))) | \ - ((noc_index ? (x == PCIE_NOC1_X and y == PCIE_NOC1_Y) : (x == PCIE_NOC_X and y == PCIE_NOC_Y)) * \ - 0x1000000000000000) +#define NOC_XY_PCIE_ENCODING(x, y) \ + ((uint64_t(NOC_XY_ENCODING(x, y)) << (NOC_ADDR_LOCAL_BITS - NOC_COORD_REG_OFFSET)) | 0x1000000000000000) #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ ((((uint32_t)(x_start)) << (2 * NOC_ADDR_NODE_ID_BITS)) | (((uint32_t)(y_start)) << (3 * NOC_ADDR_NODE_ID_BITS)) | \ diff --git a/tt_metal/hw/inc/circular_buffer.h b/tt_metal/hw/inc/circular_buffer.h index 35942bb49d7..68a2b3436cc 100644 --- a/tt_metal/hw/inc/circular_buffer.h +++ b/tt_metal/hw/inc/circular_buffer.h @@ -103,3 +103,9 @@ FORCE_INLINE RemoteSenderCBInterface& get_remote_sender_cb_interface(uint32_t cb FORCE_INLINE RemoteReceiverCBInterface& get_remote_receiver_cb_interface(uint32_t cb_id) { return cb_interface[cb_id].remote_receiver_cb_interface; } + +#if defined(COMPILE_FOR_TRISC) +constexpr uint32_t cb_addr_shift = CIRCULAR_BUFFER_COMPUTE_ADDR_SHIFT; +#else +constexpr uint32_t cb_addr_shift = 0; +#endif diff --git a/tt_metal/hw/inc/circular_buffer_constants.h b/tt_metal/hw/inc/circular_buffer_constants.h index 8ff3fda763b..3b80937753c 100644 --- a/tt_metal/hw/inc/circular_buffer_constants.h +++ b/tt_metal/hw/inc/circular_buffer_constants.h @@ -9,5 +9,5 @@ constexpr static std::uint32_t NUM_CIRCULAR_BUFFERS = 32; constexpr static std::uint32_t UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG = 4; constexpr static std::uint32_t UINT32_WORDS_PER_REMOTE_CIRCULAR_BUFFER_CONFIG = 2; -constexpr static std::uint32_t CIRCULAR_BUFFER_WORD_SIZE_BYTES = 16; -constexpr static std::uint32_t CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES = 4; +constexpr static std::uint32_t CIRCULAR_BUFFER_COMPUTE_WORD_SIZE = 16; +constexpr static std::uint32_t CIRCULAR_BUFFER_COMPUTE_ADDR_SHIFT = 4; diff --git a/tt_metal/hw/inc/circular_buffer_init.h b/tt_metal/hw/inc/circular_buffer_init.h index 29f2af20cb4..b4402508022 100644 --- a/tt_metal/hw/inc/circular_buffer_init.h +++ b/tt_metal/hw/inc/circular_buffer_init.h @@ -25,10 +25,10 @@ inline void setup_local_cb_read_write_interfaces( for (uint32_t cb_id = start_cb_index; cb_id < max_cb_index; cb_id++) { // NOTE: fifo_addr, fifo_size and fifo_limit in 16B words! - uint32_t fifo_addr = circular_buffer_config_addr[0]; - uint32_t fifo_size = circular_buffer_config_addr[1]; + uint32_t fifo_addr = circular_buffer_config_addr[0] >> cb_addr_shift; + uint32_t fifo_size = circular_buffer_config_addr[1] >> cb_addr_shift; uint32_t fifo_num_pages = circular_buffer_config_addr[2]; - uint32_t fifo_page_size = circular_buffer_config_addr[3]; + uint32_t fifo_page_size = circular_buffer_config_addr[3] >> cb_addr_shift; uint32_t fifo_limit = fifo_addr + fifo_size; LocalCBInterface& local_interface = get_local_cb_interface(cb_id); diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index 7ddf16b8ac4..c6877db3e0a 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -10,9 +10,7 @@ #include "chlkc_unpack_tile_dims.h" #define DATA_FORMATS_DEFINED #endif -#if __has_include("generated_bank_to_noc_coord_mapping.h") -#include "generated_bank_to_noc_coord_mapping.h" -#endif +#include #include @@ -37,9 +35,15 @@ constexpr uint8_t proc_type = static_cast @@ -410,7 +418,7 @@ constexpr inline DataFormat get_dataformat(const std::int32_t operand) { FORCE_INLINE uint32_t get_write_ptr(uint32_t operand) { // return byte address (fifo_wr_ptr is 16B address) - uint32_t wr_ptr_bytes = get_local_cb_interface(operand).fifo_wr_ptr << 4; + uint32_t wr_ptr_bytes = get_local_cb_interface(operand).fifo_wr_ptr; return wr_ptr_bytes; } @@ -429,7 +437,7 @@ uint32_t get_write_ptr(uint32_t operand) { FORCE_INLINE uint32_t get_read_ptr(uint32_t operand) { // return byte address (fifo_rd_ptr is 16B address) - uint32_t rd_ptr_bytes = get_local_cb_interface(operand).fifo_rd_ptr << 4; + uint32_t rd_ptr_bytes = get_local_cb_interface(operand).fifo_rd_ptr; return rd_ptr_bytes; } @@ -694,7 +702,7 @@ uint64_t get_system_memory_noc_addr( const uint32_t offset = 0, uint8_t noc = noc_index) { uint64_t pcie_core_noc_encoding = - uint64_t(NOC_XY_PCIE_ENCODING(DYNAMIC_NOC_X(noc, PCIE_NOC_X), DYNAMIC_NOC_Y(noc, PCIE_NOC_Y), noc)); + uint64_t(NOC_XY_PCIE_ENCODING(DYNAMIC_NOC_X(noc, PCIE_NOC_X), DYNAMIC_NOC_Y(noc, PCIE_NOC_Y))); uint32_t addr = base_addr + page_size * id + offset; uint64_t noc_addr = pcie_core_noc_encoding | addr; return noc_addr; diff --git a/tt_metal/hw/inc/debug/dprint_tile.h b/tt_metal/hw/inc/debug/dprint_tile.h index 85aa838d8e5..1e737f66cf1 100644 --- a/tt_metal/hw/inc/debug/dprint_tile.h +++ b/tt_metal/hw/inc/debug/dprint_tile.h @@ -17,13 +17,12 @@ #endif // Macros for printing circular buffer internals -#define CB_RD_PTR(id) (get_local_cb_interface(id).fifo_rd_ptr << 4) // only valid in unpacker thread -#define CB_RD_LIM(id) ((get_local_cb_interface(id).fifo_limit_plus_1 - 1) << 4) -#define CB_RD_SZ(id) (get_local_cb_interface(id).fifo_size << 4) +#define CB_RD_PTR(id) (get_local_cb_interface(id).fifo_rd_ptr << cb_addr_shift) // only valid in unpacker thread +#define CB_RD_SZ(id) (get_local_cb_interface(id).fifo_size << cb_addr_shift) -#define CB_WR_PTR(id) (get_local_cb_interface(id).fifo_wr_ptr << 4) // only valid in packer thread +#define CB_WR_PTR(id) (get_local_cb_interface(id).fifo_wr_ptr << cb_addr_shift) // only valid in packer thread #define CB_PAGE_COUNT(id) (get_local_cb_interface(id).fifo_num_pages) -#define CB_PAGE_SIZE(id) (get_local_cb_interface(id).fifo_page_size << 4) +#define CB_PAGE_SIZE(id) (get_local_cb_interface(id).fifo_page_size << cb_addr_shift) // // Slices/samples elements of a tile 'itile' from cb using a given numpy style slice object SliceRange. diff --git a/tt_metal/hw/inc/firmware_common.h b/tt_metal/hw/inc/firmware_common.h index c292a7261a8..9f051b32abb 100644 --- a/tt_metal/hw/inc/firmware_common.h +++ b/tt_metal/hw/inc/firmware_common.h @@ -13,39 +13,17 @@ #include "dev_mem_map.h" #include "hostdevcommon/kernel_structs.h" #include "dev_msgs.h" +#include "noc/noc_parameters.h" +#include "debug/dprint.h" + +extern uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS]; +extern int32_t bank_to_dram_offset[NUM_DRAM_BANKS]; +extern uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS]; +extern int32_t bank_to_l1_offset[NUM_L1_BANKS]; extern void kernel_init(uint32_t kernel_init); extern void kernel_launch(uint32_t kernel_base_addr); - -inline void l1_to_local_mem_copy(uint32_t* dst, uint32_t tt_l1_ptr* src, int32_t len) { -#pragma GCC unroll 0 - while (len >= 3) { - auto v0 = src[0], v1 = src[1], v2 = src[2]; - // 1) Make sure the optimizer does not think this is memcpy by - // hiding the pointer bookkeeping in an asm. - // 2) The scheduler doesn't know the above loads have 6 cycle - // latency. We emit the 3 bookkeeping adds as a single block - // in the load shadow before the stores. The optimizer will - // not be able to move these. - // 3) We don't need early clobbers here because of the +r - // constraint -- early clobbers would pessimize. - asm inline( - "addi %0,%0,3*%3\n\t" - "addi %1,%1,3*%3\n\t" - "addi %2,%2,-3" - : "+r"(src), "+r"(dst), "+r"(len) - : "i"(sizeof(v0))); - dst[-3] = v0, dst[-2] = v1, dst[-1] = v2; - } - // There are 0, 1 or 2 words of residue. This is smaller than a loop. - // We get smaller code layout by expecting the conditions to be true. - if (__builtin_expect(len >= 1, true)) { - dst[0] = src[0]; - if (__builtin_expect(len >= 2, true)) { - dst[1] = src[1]; - } - } -} +void l1_to_local_mem_copy(uint32_t* dst, uint32_t tt_l1_ptr* src, int32_t len); inline void do_crt1(uint32_t tt_l1_ptr* data_image) { // Clear bss. @@ -59,6 +37,18 @@ inline void do_crt1(uint32_t tt_l1_ptr* data_image) { l1_to_local_mem_copy(__ldm_data_start, data_image, __ldm_data_end - __ldm_data_start); } +inline void noc_bank_table_init(uint64_t mem_bank_to_noc_addr) { + int32_t dram_to_noc_size_bytes = sizeof(dram_bank_to_noc_xy); + l1_to_local_mem_copy((uint*)dram_bank_to_noc_xy, (uint tt_l1_ptr*)mem_bank_to_noc_addr, dram_to_noc_size_bytes >> 2); + int32_t l1_to_noc_size_bytes = sizeof(l1_bank_to_noc_xy); + l1_to_local_mem_copy((uint*)l1_bank_to_noc_xy, (uint tt_l1_ptr*)(mem_bank_to_noc_addr + dram_to_noc_size_bytes), l1_to_noc_size_bytes >> 2); + + int32_t dram_offsets_size_bytes = sizeof(bank_to_dram_offset); + l1_to_local_mem_copy((uint*)bank_to_dram_offset, (uint tt_l1_ptr*)(mem_bank_to_noc_addr + dram_to_noc_size_bytes + l1_to_noc_size_bytes), dram_offsets_size_bytes >> 2); + int32_t l1_offsets_size_bytes = sizeof(bank_to_l1_offset); + l1_to_local_mem_copy((uint*)bank_to_l1_offset, (uint tt_l1_ptr*)(mem_bank_to_noc_addr + dram_to_noc_size_bytes + l1_to_noc_size_bytes + dram_offsets_size_bytes), l1_offsets_size_bytes >> 2); +} + FORCE_INLINE uint32_t firmware_config_init( tt_l1_ptr mailboxes_t* const mailboxes, uint32_t core_type_index, uint32_t dispatch_class) { diff --git a/tt_metal/hw/inc/grayskull/dev_mem_map.h b/tt_metal/hw/inc/grayskull/dev_mem_map.h index ba2077838c2..d7d829e7392 100644 --- a/tt_metal/hw/inc/grayskull/dev_mem_map.h +++ b/tt_metal/hw/inc/grayskull/dev_mem_map.h @@ -40,15 +40,20 @@ #define MEM_NCRISC_LOCAL_SIZE (4 * 1024) #define MEM_TRISC_LOCAL_SIZE (2 * 1024) +// Memory for (dram/l1)_bank_to_noc_xy arrays, size needs to be atleast 2 * NUM_NOCS * (NUM_DRAM_BANKS + NUM_L1_BANKS) +#define MEM_BANK_TO_NOC_XY_SIZE 1024 +// Memory for bank_to_dram_offset and bank_to_l1_offset arrays, size needs to be atleast 4 * (NUM_DRAM_BANKS + NUM_L1_BANKS) +#define MEM_BANK_OFFSET_SIZE 1024 + #define NCRISC_HAS_IRAM 1 #define MEM_NCRISC_IRAM_BASE 0xFFC00000 #define MEM_NCRISC_IRAM_SIZE (16 * 1024) ///////////// // Firmware/kernel code holes -#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 416) +#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 624) // TODO: perhaps put NCRISC FW in the scratch area and free 1.5K after init (GS/WH) -#define MEM_NCRISC_FIRMWARE_SIZE 1616 +#define MEM_NCRISC_FIRMWARE_SIZE 1824 #define MEM_TRISC0_FIRMWARE_SIZE 1536 #define MEM_TRISC1_FIRMWARE_SIZE 1536 #define MEM_TRISC2_FIRMWARE_SIZE 1536 @@ -100,6 +105,9 @@ #define MEM_TRISC1_INIT_LOCAL_L1_BASE_SCRATCH (MEM_TRISC0_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) #define MEM_TRISC2_INIT_LOCAL_L1_BASE_SCRATCH (MEM_TRISC1_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) +#define MEM_BANK_TO_NOC_SCRATCH (MEM_TRISC2_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) +#define MEM_BANK_TO_NOC_SIZE (MEM_BANK_TO_NOC_XY_SIZE + MEM_BANK_OFFSET_SIZE) + ///////////// // Stack info // Increasing the stack size comes at the expense of less local memory for globals @@ -125,5 +133,7 @@ #define MEM_IERISC_MAP_END 0 #define MEM_IERISC_INIT_LOCAL_L1_BASE_SCRATCH 0 #define MEM_IERISC_STACK_SIZE 0 +#define MEM_IERISC_BANK_TO_NOC_SCRATCH 0 +#define MEM_IERISC_BANK_TO_NOC_SIZE 0 #define MEM_IERISC_KERNEL_PAD 0 diff --git a/tt_metal/hw/inc/grayskull/eth_l1_address_map.h b/tt_metal/hw/inc/grayskull/eth_l1_address_map.h index 0ad8580b15b..edec6f63c30 100644 --- a/tt_metal/hw/inc/grayskull/eth_l1_address_map.h +++ b/tt_metal/hw/inc/grayskull/eth_l1_address_map.h @@ -27,8 +27,6 @@ struct address_map { static constexpr std::int32_t ERISC_FIRMWARE_SIZE = 16; static constexpr std::int32_t ERISC_L1_UNRESERVED_BASE = 0; - static constexpr std::uint32_t ISSUE_CQ_CB_BASE = 0; - static constexpr std::uint32_t COMPLETION_CQ_CB_BASE = 0; static constexpr std::int32_t LAUNCH_ERISC_APP_FLAG = 0; static constexpr std::uint32_t FW_VERSION_ADDR = 0; @@ -36,7 +34,8 @@ struct address_map { static constexpr std::int32_t MAX_L1_LOADING_SIZE = 1; static constexpr std::int32_t ERISC_L1_UNRESERVED_SIZE = 0; - static constexpr std::int32_t ERISC_L1_TUNNEL_BUFFER_SIZE = 0; + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_SCRATCH = 0; + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_SIZE = 0; static constexpr std::uint32_t RETRAIN_COUNT_ADDR = 0x1EDC; static constexpr std::uint32_t RETRAIN_FORCE_ADDR = 0x1EFC; diff --git a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h index 46ebaadd638..7eff21e6dbd 100644 --- a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h +++ b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h @@ -253,8 +253,7 @@ // Address formats #define NOC_XY_ENCODING(x, y) ((((uint32_t)(y)) << (NOC_ADDR_NODE_ID_BITS)) | (((uint32_t)(x)))) -#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ - ((uint64_t(NOC_XY_ENCODING(x, y)) << (NOC_ADDR_LOCAL_BITS - NOC_COORD_REG_OFFSET))) +#define NOC_XY_PCIE_ENCODING(x, y) ((uint64_t(NOC_XY_ENCODING(x, y)) << (NOC_ADDR_LOCAL_BITS - NOC_COORD_REG_OFFSET))) #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ ((x_start) << (2 * NOC_ADDR_NODE_ID_BITS)) | ((y_start) << (3 * NOC_ADDR_NODE_ID_BITS)) | (x_end) | \ diff --git a/tt_metal/hw/inc/remote_circular_buffer_api.h b/tt_metal/hw/inc/remote_circular_buffer_api.h index 712458d62b1..044e3705f93 100644 --- a/tt_metal/hw/inc/remote_circular_buffer_api.h +++ b/tt_metal/hw/inc/remote_circular_buffer_api.h @@ -4,6 +4,7 @@ #pragma once +#include "tt_metal/hw/inc/circular_buffer.h" #include "tt_metal/hw/inc/debug/assert.h" #include "utils/utils.h" #ifndef COMPILE_FOR_TRISC @@ -242,11 +243,12 @@ FORCE_INLINE void align_local_cbs_to_remote_cb( // We assert that the offset of sender and receiver common attributes are the same // so we can use either interface here const RemoteReceiverCBInterface& remote_cb = get_remote_receiver_cb_interface(remote_cb_index); - uint32_t fifo_limit = remote_cb.fifo_limit_page_aligned >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; - uint32_t fifo_size = fifo_limit - (remote_cb.fifo_start_addr >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES); - uint32_t fifo_ptr = remote_cb.fifo_rd_ptr >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + uint32_t fifo_limit = remote_cb.fifo_limit_page_aligned >> cb_addr_shift; + uint32_t fifo_size = fifo_limit - (remote_cb.fifo_start_addr >> cb_addr_shift); + uint32_t fifo_ptr = remote_cb.fifo_rd_ptr >> cb_addr_shift; for (uint32_t i = 0; i < num_local_cbs; i++) { LocalCBInterface& local_cb = get_local_cb_interface(local_cb_indices[i]); + ASSERT(fifo_size % local_cb.fifo_page_size == 0); uint32_t fifo_num_pages = fifo_size / local_cb.fifo_page_size; local_cb.fifo_limit = fifo_limit; local_cb.fifo_size = fifo_size; diff --git a/tt_metal/hw/inc/wormhole/dev_mem_map.h b/tt_metal/hw/inc/wormhole/dev_mem_map.h index c107c20d4b9..0d9e1dd932c 100644 --- a/tt_metal/hw/inc/wormhole/dev_mem_map.h +++ b/tt_metal/hw/inc/wormhole/dev_mem_map.h @@ -41,13 +41,18 @@ #define MEM_NCRISC_LOCAL_SIZE (4 * 1024) #define MEM_TRISC_LOCAL_SIZE (2 * 1024) +// Memory for (dram/l1)_bank_to_noc_xy arrays, size needs to be atleast 2 * NUM_NOCS * (NUM_DRAM_BANKS + NUM_L1_BANKS) +#define MEM_BANK_TO_NOC_XY_SIZE 1024 +// Memory for bank_to_dram_offset and bank_to_l1_offset arrays, size needs to be atleast 4 * (NUM_DRAM_BANKS + NUM_L1_BANKS) +#define MEM_BANK_OFFSET_SIZE 1024 + #define NCRISC_HAS_IRAM 1 #define MEM_NCRISC_IRAM_BASE 0xFFC00000 #define MEM_NCRISC_IRAM_SIZE (16 * 1024) ///////////// // Firmware/kernel code holes -#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 64) +#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 256) // TODO: perhaps put NCRISC FW in the scratch area and free 1.5K after init (GS/WH) #define MEM_NCRISC_FIRMWARE_SIZE 1536 #define MEM_TRISC0_FIRMWARE_SIZE 1536 @@ -102,6 +107,9 @@ #define MEM_TRISC1_INIT_LOCAL_L1_BASE_SCRATCH (MEM_TRISC0_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) #define MEM_TRISC2_INIT_LOCAL_L1_BASE_SCRATCH (MEM_TRISC1_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) +#define MEM_BANK_TO_NOC_SCRATCH (MEM_TRISC2_INIT_LOCAL_L1_BASE_SCRATCH + MEM_TRISC_LOCAL_SIZE) +#define MEM_BANK_TO_NOC_SIZE (MEM_BANK_TO_NOC_XY_SIZE + MEM_BANK_OFFSET_SIZE) + ///////////// // Stack info // Increasing the stack size comes at the expense of less local memory for globals @@ -137,6 +145,10 @@ #define MEM_IERISC_STACK_SIZE 1024 #define MEM_IERISC_STACK_BASE (MEM_LOCAL_BASE + MEM_IERISC_LOCAL_SIZE - MEM_IERISC_STACK_SIZE) +#define MEM_IERISC_BANK_TO_NOC_SCRATCH (MEM_IERISC_INIT_LOCAL_L1_BASE_SCRATCH + MEM_IERISC_LOCAL_SIZE) +#define MEM_IERISC_BANK_TO_NOC_SIZE (MEM_BANK_TO_NOC_XY_SIZE + MEM_BANK_OFFSET_SIZE) + + ///////////// // Padding/alignment restriction needed in linker scripts for erisc #define MEM_IERISC_KERNEL_PAD 32 diff --git a/tt_metal/hw/inc/wormhole/eth_l1_address_map.h b/tt_metal/hw/inc/wormhole/eth_l1_address_map.h index 68e67eb9248..3c87023d855 100644 --- a/tt_metal/hw/inc/wormhole/eth_l1_address_map.h +++ b/tt_metal/hw/inc/wormhole/eth_l1_address_map.h @@ -26,6 +26,11 @@ struct address_map { static constexpr std::int32_t DATA_BUFFER_SIZE_ETH = 4 * 1024; static constexpr std::int32_t DATA_BUFFER_SIZE_NOC = 16 * 1024; static constexpr std::int32_t DATA_BUFFER_SIZE = 24 * 1024; + // Memory for (dram/l1)_bank_to_noc_xy arrays, size needs to be atleast 2 * NUM_NOCS * (NUM_DRAM_BANKS + NUM_L1_BANKS) + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_XY_SIZE = 1024; + // Memory for bank_to_dram_offset and bank_to_l1_offset arrays, size needs to be atleast 4 * (NUM_DRAM_BANKS + NUM_L1_BANKS) + static constexpr std::int32_t ERISC_MEM_BANK_OFFSET_SIZE = 1024; + // Kernel config buffer is WIP // Size is presently based on the old sizes of the RTAs + CB config + Sems static constexpr std::int32_t ERISC_L1_KERNEL_CONFIG_SIZE = 96 * 4 + 8 * 16; @@ -51,10 +56,7 @@ struct address_map { static constexpr std::int32_t ERISC_APP_ROUTING_INFO_BASE = TILE_HEADER_BUFFER_BASE; static constexpr std::int32_t ERISC_APP_SYNC_INFO_BASE = ERISC_APP_ROUTING_INFO_BASE + ERISC_APP_ROUTING_INFO_SIZE; - static constexpr uint32_t ISSUE_CQ_CB_BASE = ERISC_APP_SYNC_INFO_BASE + ERISC_APP_SYNC_INFO_SIZE; - static constexpr uint32_t COMPLETION_CQ_CB_BASE = ISSUE_CQ_CB_BASE + 7 * L1_ALIGNMENT; - - static constexpr std::int32_t ERISC_MEM_MAILBOX_BASE = COMPLETION_CQ_CB_BASE + 7 * L1_ALIGNMENT; + static constexpr std::int32_t ERISC_MEM_MAILBOX_BASE = ERISC_APP_SYNC_INFO_BASE + ERISC_APP_SYNC_INFO_SIZE; static constexpr std::uint32_t ERISC_MEM_MAILBOX_SIZE = 3232; static constexpr std::uint32_t ERISC_MEM_MAILBOX_END = ERISC_MEM_MAILBOX_BASE + ERISC_MEM_MAILBOX_SIZE; @@ -65,10 +67,13 @@ struct address_map { static_assert((ERISC_L1_UNRESERVED_BASE % 32) == 0); - static constexpr std::int32_t LAUNCH_ERISC_APP_FLAG = L1_EPOCH_Q_BASE + 4; + // This scratch address is same as ERISC_L1_UNRESERVED_BASE, as the scratch space is used to copy data during + // runtime build, and is unused once FW copies the data to local memory during FW initialization. + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_SCRATCH = + (ERISC_L1_KERNEL_CONFIG_BASE + ERISC_L1_KERNEL_CONFIG_SIZE + 31) & ~31; + static constexpr std::int32_t ERISC_MEM_BANK_TO_NOC_SIZE = ERISC_MEM_BANK_TO_NOC_XY_SIZE + ERISC_MEM_BANK_OFFSET_SIZE; - // BIDIR Tunneling Kernel Space - static constexpr std::int32_t ERISC_L1_TUNNEL_BUFFER_SIZE = ERISC_L1_UNRESERVED_SIZE / 2; + static constexpr std::int32_t LAUNCH_ERISC_APP_FLAG = L1_EPOCH_Q_BASE + 4; template struct TAssertEquality { diff --git a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h index 248b82bb6ee..34c899447cf 100644 --- a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h +++ b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h @@ -267,9 +267,8 @@ (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) // Address formats -#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ - ((uint64_t(NOC_XY_ENCODING(x, y)) << (NOC_ADDR_LOCAL_BITS - NOC_COORD_REG_OFFSET))) | \ - ((noc_index ? (x == PCIE_NOC1_X and y == PCIE_NOC1_Y) : (x == PCIE_NOC_X and y == PCIE_NOC_Y)) * 0x800000000) +#define NOC_XY_PCIE_ENCODING(x, y) \ + ((uint64_t(NOC_XY_ENCODING(x, y)) << (NOC_ADDR_LOCAL_BITS - NOC_COORD_REG_OFFSET)) | 0x800000000) #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32) + 2 * NOC_ADDR_NODE_ID_BITS)) | \ diff --git a/tt_metal/hw/toolchain/substitutes.cpp b/tt_metal/hw/toolchain/substitutes.cpp index a4e5feb40a0..45764316f8c 100644 --- a/tt_metal/hw/toolchain/substitutes.cpp +++ b/tt_metal/hw/toolchain/substitutes.cpp @@ -37,3 +37,34 @@ extern "C" void wzerorange(uint32_t* start, uint32_t* end) { start[-1] = 0; } } + +// Let the LTO decide if this needs to be inline. +void l1_to_local_mem_copy(uint32_t* dst, uint32_t __attribute__((rvtt_l1_ptr))* src, int32_t len) { +#pragma GCC unroll 0 + while (len >= 3) { + auto v0 = src[0], v1 = src[1], v2 = src[2]; + // 1) Make sure the optimizer does not think this is memcpy by + // hiding the pointer bookkeeping in an asm. + // 2) The scheduler doesn't know the above loads have 6 cycle + // latency. We emit the 3 bookkeeping adds as a single block + // in the load shadow before the stores. The optimizer will + // not be able to move these. + // 3) We don't need early clobbers here because of the +r + // constraint -- early clobbers would pessimize. + asm inline( + "addi %0,%0,3*%3\n\t" + "addi %1,%1,3*%3\n\t" + "addi %2,%2,-3" + : "+r"(src), "+r"(dst), "+r"(len) + : "i"(sizeof(v0))); + dst[-3] = v0, dst[-2] = v1, dst[-1] = v2; + } + // There are 0, 1 or 2 words of residue. This is smaller than a loop. + // We get smaller code layout by expecting the conditions to be true. + if (__builtin_expect(len >= 1, true)) { + dst[0] = src[0]; + if (__builtin_expect(len >= 2, true)) { + dst[1] = src[1]; + } + } +} diff --git a/tt_metal/impl/allocator/allocator.cpp b/tt_metal/impl/allocator/allocator.cpp index aaeec83d553..bdbbe8dd3f9 100644 --- a/tt_metal/impl/allocator/allocator.cpp +++ b/tt_metal/impl/allocator/allocator.cpp @@ -437,9 +437,10 @@ void reset_allocator_size(Allocator& allocator, const BufferType& buffer_type) { } } -DeviceAddr allocate_buffer(Allocator& allocator, DeviceAddr size, Buffer* buffer) { +DeviceAddr allocate_buffer(Allocator& allocator, Buffer* buffer) { DeviceAddr address = 0; - auto page_size = buffer->page_size(); + auto size = buffer->aligned_size(); + auto page_size = buffer->aligned_page_size(); auto buffer_type = buffer->buffer_type(); auto bottom_up = buffer->bottom_up(); auto num_shards = buffer->num_cores(); diff --git a/tt_metal/impl/allocator/allocator.hpp b/tt_metal/impl/allocator/allocator.hpp index 0d62f8cd5de..1852e959766 100644 --- a/tt_metal/impl/allocator/allocator.hpp +++ b/tt_metal/impl/allocator/allocator.hpp @@ -136,7 +136,7 @@ void shrink_allocator_size( Allocator& allocator, const BufferType& buffer_type, DeviceAddr shrink_size, bool bottom_up = true); void reset_allocator_size(Allocator& allocator, const BufferType& buffer_type); -DeviceAddr allocate_buffer(Allocator& allocator, DeviceAddr size, Buffer* buffer); +DeviceAddr allocate_buffer(Allocator& allocator, Buffer* buffer); void mark_allocations_unsafe(Allocator& allocator); diff --git a/tt_metal/impl/allocator/l1_banking_allocator.cpp b/tt_metal/impl/allocator/l1_banking_allocator.cpp index 604416a143c..0cd0c38d984 100644 --- a/tt_metal/impl/allocator/l1_banking_allocator.cpp +++ b/tt_metal/impl/allocator/l1_banking_allocator.cpp @@ -18,7 +18,7 @@ #include "tt_metal/impl/buffers/buffer_constants.hpp" #include "tt_metal/common/assert.hpp" #include "tt_metal/common/core_coord.hpp" -#include "umd/device/xy_pair.h" +#include "umd/device/types/xy_pair.h" #include #include "llrt/hal.hpp" diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp index 34975b0442b..b1e5ec3e337 100644 --- a/tt_metal/impl/buffers/buffer.cpp +++ b/tt_metal/impl/buffers/buffer.cpp @@ -467,7 +467,7 @@ DeviceAddr Buffer::aligned_size() const { DeviceAddr Buffer::aligned_size_per_bank() const { uint32_t num_banks = is_sharded(this->buffer_layout_) ? this->num_cores().value() : this->device_->num_banks(this->buffer_type()); - return tt::tt_metal::detail::SizeBytesPerBank(this->size_, this->page_size_, num_banks, this->alignment()); + return tt::tt_metal::detail::SizeBytesPerBank(this->aligned_size(), this->aligned_page_size(), num_banks, this->alignment()); } DeviceAddr Buffer::sharded_page_address(uint32_t bank_id, uint32_t page_index) const { diff --git a/tt_metal/impl/buffers/buffer.hpp b/tt_metal/impl/buffers/buffer.hpp index 4539da12fd8..d3fdc9f60aa 100644 --- a/tt_metal/impl/buffers/buffer.hpp +++ b/tt_metal/impl/buffers/buffer.hpp @@ -22,7 +22,7 @@ #include "tt_metal/impl/buffers/buffer_constants.hpp" #include "tt_metal/impl/sub_device/sub_device_types.hpp" #include "umd/device/tt_soc_descriptor.h" -#include "umd/device/xy_pair.h" +#include "umd/device/types/xy_pair.h" #include "tt_metal/tt_stl/concepts.hpp" #include "tt_metal/common/assert.hpp" #include "third_party/json/json.hpp" diff --git a/tt_metal/impl/buffers/circular_buffer_types.cpp b/tt_metal/impl/buffers/circular_buffer_types.cpp index a14738e0edb..7877c264820 100644 --- a/tt_metal/impl/buffers/circular_buffer_types.cpp +++ b/tt_metal/impl/buffers/circular_buffer_types.cpp @@ -43,10 +43,6 @@ CircularBufferConfig& CircularBufferConfig::set_page_size(uint8_t buffer_index, if (this->total_size_ % page_size != 0) { TT_THROW("Total circular buffer size {} B must be divisible by page size {} B", this->total_size_, page_size); } - // TODO: Should use CIRCULAR_BUFFER_WORD_SIZE_BYTES here - if (page_size % sizeof(uint32_t) != 0) { - TT_THROW("Page size must be divisible by sizeof(uint32_t) because buffers holds uint32_t values"); - } this->page_sizes_[buffer_index] = page_size; return *this; diff --git a/tt_metal/impl/buffers/global_circular_buffer.cpp b/tt_metal/impl/buffers/global_circular_buffer.cpp index 4d438e91fcc..df8df656ac3 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.cpp +++ b/tt_metal/impl/buffers/global_circular_buffer.cpp @@ -159,3 +159,12 @@ uint32_t GlobalCircularBuffer::size() const { return this->size_; } } // namespace v1 } // namespace tt::tt_metal + +namespace std { + +std::size_t hash::operator()( + const tt::tt_metal::v1::experimental::GlobalCircularBuffer& global_circular_buffer) const { + return tt::stl::hash::hash_objects_with_default_seed(global_circular_buffer.attribute_values()); +} + +} // namespace std diff --git a/tt_metal/impl/buffers/global_circular_buffer.hpp b/tt_metal/impl/buffers/global_circular_buffer.hpp index c263fe47d00..d18ed91e0c4 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.hpp +++ b/tt_metal/impl/buffers/global_circular_buffer.hpp @@ -76,3 +76,12 @@ class GlobalCircularBuffer { } // namespace v1 } // namespace tt::tt_metal + +namespace std { + +template <> +struct hash { + std::size_t operator()(const tt::tt_metal::v1::experimental::GlobalCircularBuffer& global_circular_buffer) const; +}; + +} // namespace std diff --git a/tt_metal/impl/buffers/global_semaphore.cpp b/tt_metal/impl/buffers/global_semaphore.cpp index 807e74a8e10..f080ab23b06 100644 --- a/tt_metal/impl/buffers/global_semaphore.cpp +++ b/tt_metal/impl/buffers/global_semaphore.cpp @@ -77,3 +77,12 @@ void GlobalSemaphore::reset_semaphore_value() { } } // namespace tt::tt_metal + +namespace std { + +std::size_t hash::operator()( + const tt::tt_metal::GlobalSemaphore& global_semaphore) const { + return tt::stl::hash::hash_objects_with_default_seed(global_semaphore.attribute_values()); +} + +} // namespace std diff --git a/tt_metal/impl/buffers/global_semaphore.hpp b/tt_metal/impl/buffers/global_semaphore.hpp index 6c2f8d17947..f6d657998f8 100644 --- a/tt_metal/impl/buffers/global_semaphore.hpp +++ b/tt_metal/impl/buffers/global_semaphore.hpp @@ -44,6 +44,9 @@ class GlobalSemaphore { void reset_semaphore_value(); + static constexpr auto attribute_names = std::forward_as_tuple("cores", "initial_value"); + const auto attribute_values() const { return std::make_tuple(this->cores_, this->initial_value_); } + private: void setup_buffer(BufferType buffer_type); @@ -59,3 +62,12 @@ class GlobalSemaphore { } // namespace v0 } // namespace tt::tt_metal + +namespace std { + +template <> +struct hash { + std::size_t operator()(const tt::tt_metal::GlobalSemaphore& global_semaphore) const; +}; + +} // namespace std diff --git a/tt_metal/impl/debug/watcher_device_reader.cpp b/tt_metal/impl/debug/watcher_device_reader.cpp index 3e9799910cb..d764bf8a2a6 100644 --- a/tt_metal/impl/debug/watcher_device_reader.cpp +++ b/tt_metal/impl/debug/watcher_device_reader.cpp @@ -18,8 +18,8 @@ #include "eth_l1_address_map.h" // for address_map #include "hw/inc/dev_msgs.h" -#include "umd/device/tt_arch_types.h" -#include "umd/device/xy_pair.h" +#include "umd/device/types/arch.h" +#include "umd/device/types/xy_pair.h" #include #include "llrt/llrt.hpp" #include "llrt/tt_cluster.hpp" diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 909b3fbcf25..2f8196d7c05 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -7,7 +7,6 @@ #include "tt_metal/device.hpp" #include "common/core_coord.hpp" #include "tt_metal/host_api.hpp" -#include "tt_metal/jit_build/genfiles.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/trace/trace.hpp" #include "tt_metal/common/core_descriptor.hpp" @@ -28,6 +27,7 @@ #include "tt_metal/impl/sub_device/sub_device_types.hpp" #include "tt_metal/tt_stl/span.hpp" #include "tt_metal/types.hpp" +#include "noc/noc_parameters.h" // FIXME: ARCH_NAME specific #include "eth_l1_address_map.h" @@ -324,15 +324,10 @@ void Device::initialize_device_kernel_defines() auto pcie_cores = soc_d.get_pcie_cores(); auto grid_size = this->grid_size(); - // Workaround for Simulator integration as they use a 2x2 grid which would underflow PCIE_NOC1* CoreCoord pcie_core = pcie_cores.empty() ? grid_size : pcie_cores[0]; - auto pcie_noc1_x = pcie_cores.empty() ? 14 : tt::tt_metal::hal.noc_coordinate(NOC::NOC_1, grid_size.x, pcie_cores[0].x); - auto pcie_noc1_y = pcie_cores.empty() ? 11 : tt::tt_metal::hal.noc_coordinate(NOC::NOC_1, grid_size.x, pcie_cores[0].y); this->device_kernel_defines_.emplace("PCIE_NOC_X", std::to_string(pcie_core.x)); this->device_kernel_defines_.emplace("PCIE_NOC_Y", std::to_string(pcie_core.y)); - this->device_kernel_defines_.emplace("PCIE_NOC1_X", std::to_string(pcie_noc1_x)); - this->device_kernel_defines_.emplace("PCIE_NOC1_Y", std::to_string(pcie_noc1_x)); } void Device::initialize_build() { @@ -412,13 +407,36 @@ void Device::build_firmware() { log_debug(tt::LogMetal, "Building base firmware for device {}", this->id_); ZoneScoped; - this->generate_device_headers(this->build_env_.get_out_firmware_root_path()); jit_build_set(this->firmware_build_states_, nullptr); } +void Device::initialize_device_bank_to_noc_tables(const HalProgrammableCoreType &core_type, CoreCoord phys_core) +{ + const uint32_t dram_to_noc_sz_in_bytes = dram_bank_to_noc_xy_.size() * sizeof(uint16_t); + const uint32_t l1_to_noc_sz_in_bytes = l1_bank_to_noc_xy_.size() * sizeof(uint16_t); + const uint32_t dram_offset_sz_in_bytes = dram_bank_offset_map_.size() * sizeof(int32_t); + const uint32_t l1_offset_sz_in_bytes = l1_bank_offset_map_.size() * sizeof(int32_t); + + const uint64_t mem_bank_to_noc_addr = hal.get_dev_addr(core_type, HalL1MemAddrType::BANK_TO_NOC_SCRATCH); + const uint32_t mem_bank_to_noc_size = hal.get_dev_size(core_type, HalL1MemAddrType::BANK_TO_NOC_SCRATCH); + + TT_ASSERT((dram_to_noc_sz_in_bytes + l1_to_noc_sz_in_bytes + dram_offset_sz_in_bytes + l1_offset_sz_in_bytes) <= mem_bank_to_noc_size, + "Size of bank_to_noc table is greater than available space"); + + tt::Cluster::instance().write_core(&dram_bank_to_noc_xy_[0], dram_to_noc_sz_in_bytes, tt_cxy_pair(this->id(), phys_core), mem_bank_to_noc_addr); + uint64_t l1_noc_addr = mem_bank_to_noc_addr + dram_to_noc_sz_in_bytes; + tt::Cluster::instance().write_core(&l1_bank_to_noc_xy_[0], l1_to_noc_sz_in_bytes, tt_cxy_pair(this->id(), phys_core), l1_noc_addr); + + uint64_t dram_offset_addr = l1_noc_addr + l1_to_noc_sz_in_bytes; + tt::Cluster::instance().write_core(&dram_bank_offset_map_[0], dram_offset_sz_in_bytes, tt_cxy_pair(this->id(), phys_core), dram_offset_addr); + uint64_t l1_offset_addr = dram_offset_addr + dram_offset_sz_in_bytes; + tt::Cluster::instance().write_core(&l1_bank_offset_map_[0], l1_offset_sz_in_bytes, tt_cxy_pair(this->id(), phys_core), l1_offset_addr); +} + void Device::initialize_firmware(const HalProgrammableCoreType &core_type, CoreCoord phys_core, launch_msg_t *launch_msg, go_msg_t* go_msg) { ZoneScoped; + this->initialize_device_bank_to_noc_tables(core_type, phys_core); uint32_t core_type_idx = hal.get_programmable_core_type_index(core_type); uint32_t processor_class_count = hal.get_processor_classes_count(core_type); @@ -761,7 +779,7 @@ void Device::clear_l1_state() { // These L1 ranges are restricted becase UMD base routing FW uses L1 below FIRMWARE_BASE and // between TILE_HEADER_BUFFER_BASE to COMMAND_Q_BASE std::vector zero_vec_above_tile_header_buffer( - (eth_l1_mem::address_map::ISSUE_CQ_CB_BASE - eth_l1_mem::address_map::TILE_HEADER_BUFFER_BASE) / sizeof(uint32_t), + (eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::TILE_HEADER_BUFFER_BASE) / sizeof(uint32_t), 0); // Clear erisc sync info @@ -2953,6 +2971,7 @@ bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t t this->initialize_cluster(); this->initialize_default_sub_device_state(l1_small_size, trace_region_size, l1_bank_remap); this->initialize_build(); + this->generate_device_bank_to_noc_tables(); // For minimal setup, don't initialize FW, watcher, dprint. They won't work if we're attaching to a hung chip. if (minimal) @@ -3558,37 +3577,48 @@ void Device::MarkAllocationsSafe() { tt::tt_metal::allocator::mark_allocations_safe(*this->get_initialized_allocator()); } -void Device::generate_device_headers(const std::string &path) const +void Device::generate_device_bank_to_noc_tables() { const size_t num_dram_banks = this->num_banks(BufferType::DRAM); - const size_t num_dram_banks_pow2 = std::pow(2, std::ceil(std::log2(num_dram_banks))); std::vector dram_noc_coord_per_bank(num_dram_banks); - std::vector dram_offsets_per_bank(num_dram_banks); + dram_bank_offset_map_.clear(); + dram_bank_offset_map_.resize(num_dram_banks); for (unsigned bank_id = 0; bank_id < num_dram_banks; bank_id++) { dram_noc_coord_per_bank[bank_id] = this->dram_core_from_dram_channel(this->dram_channel_from_bank_id(bank_id)); - dram_offsets_per_bank[bank_id] = this->bank_offset(BufferType::DRAM, bank_id); + dram_bank_offset_map_[bank_id] = this->bank_offset(BufferType::DRAM, bank_id); } const size_t num_l1_banks = this->num_banks(BufferType::L1); - const size_t num_l1_banks_pow2 = std::pow(2, std::ceil(std::log2(num_l1_banks))); std::vector l1_noc_coord_per_bank(num_l1_banks); - std::vector l1_offset_per_bank(num_l1_banks); + l1_bank_offset_map_.clear(); + l1_bank_offset_map_.resize(num_l1_banks); for (unsigned bank_id = 0; bank_id < num_l1_banks; bank_id++) { l1_noc_coord_per_bank[bank_id] = this->worker_core_from_logical_core(this->logical_core_from_bank_id(bank_id)); - l1_offset_per_bank[bank_id] = this->bank_offset(BufferType::L1, bank_id); + l1_bank_offset_map_[bank_id] = this->bank_offset(BufferType::L1, bank_id); } const metal_SocDescriptor& soc_d = tt::Cluster::instance().get_soc_desc(this->id()); - // Generate header file in proper location - jit_build_genfiles_bank_to_noc_coord_descriptor ( - path, - soc_d.grid_size, - dram_noc_coord_per_bank, - dram_offsets_per_bank, - l1_noc_coord_per_bank, - l1_offset_per_bank, - this->get_allocator_alignment() - ); + dram_bank_to_noc_xy_.clear(); + dram_bank_to_noc_xy_.reserve(tt::tt_metal::hal.get_num_nocs() * dram_noc_coord_per_bank.size()); + for (unsigned int noc = 0; noc < tt::tt_metal::hal.get_num_nocs(); noc++) { + for (unsigned int bank_id = 0; bank_id < dram_noc_coord_per_bank.size(); bank_id++) { + uint16_t noc_x = tt::tt_metal::hal.noc_coordinate(noc, soc_d.grid_size.x, dram_noc_coord_per_bank[bank_id].x); + uint16_t noc_y = tt::tt_metal::hal.noc_coordinate(noc, soc_d.grid_size.y, dram_noc_coord_per_bank[bank_id].y); + uint16_t xy = ((noc_y << NOC_ADDR_NODE_ID_BITS) | noc_x) << NOC_COORD_REG_OFFSET; + dram_bank_to_noc_xy_.push_back(xy); + } + } + + l1_bank_to_noc_xy_.clear(); + l1_bank_to_noc_xy_.reserve(tt::tt_metal::hal.get_num_nocs() * l1_noc_coord_per_bank.size()); + for (unsigned int noc = 0; noc < tt::tt_metal::hal.get_num_nocs(); noc++) { + for (unsigned int bank_id = 0; bank_id < l1_noc_coord_per_bank.size(); bank_id++) { + uint16_t noc_x = tt::tt_metal::hal.noc_coordinate(noc, soc_d.grid_size.x, l1_noc_coord_per_bank[bank_id].x); + uint16_t noc_y = tt::tt_metal::hal.noc_coordinate(noc, soc_d.grid_size.y, l1_noc_coord_per_bank[bank_id].y); + uint16_t xy = ((noc_y << NOC_ADDR_NODE_ID_BITS) | noc_x) << NOC_COORD_REG_OFFSET; + l1_bank_to_noc_xy_.push_back(xy); + } + } } size_t Device::get_device_kernel_defines_hash() { diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 045a1097aac..616a831e046 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -231,7 +231,7 @@ class Device { // machine inf float sfpu_inf() const; - void generate_device_headers(const std::string &path) const; + void generate_device_bank_to_noc_tables(); const JitBuildEnv& build_env() const { return this->build_env_; } const string build_firmware_target_path(uint32_t programmable_core, uint32_t processor_class, int i) const; const string build_kernel_target_path(uint32_t programmable_core, uint32_t processor_class, int i, const string& kernel_name) const; @@ -259,6 +259,7 @@ class Device { void initialize_build(); void initialize_device_kernel_defines(); void build_firmware(); + void initialize_device_bank_to_noc_tables(const HalProgrammableCoreType &core_type, CoreCoord phys_core); void initialize_firmware(const HalProgrammableCoreType &core_type, CoreCoord phys_core, launch_msg_t *launch_msg, go_msg_t* go_msg); void reset_cores(); void initialize_and_launch_firmware(); @@ -396,6 +397,11 @@ class Device { SubDeviceManagerId next_sub_device_manager_id_ = {0}; SubDeviceManagerId default_sub_device_manager_id_ = {0}; detail::SubDeviceManager *default_sub_device_manager_ = nullptr; + + std::vector dram_bank_offset_map_; + std::vector l1_bank_offset_map_; + std::vector dram_bank_to_noc_xy_; + std::vector l1_bank_to_noc_xy_; }; } // namespace v0 diff --git a/tt_metal/impl/device/device_pool.hpp b/tt_metal/impl/device/device_pool.hpp index 2b2c64d087a..6d771aba858 100644 --- a/tt_metal/impl/device/device_pool.hpp +++ b/tt_metal/impl/device/device_pool.hpp @@ -4,7 +4,7 @@ #pragma once -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "tt_metal/host_api.hpp" #include "impl/debug/dprint_server.hpp" #include "tt_metal/impl/device/device.hpp" diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 7b90f313dbe..e0ef8b96cfc 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -901,8 +901,8 @@ void EnqueueProgramCommand::assemble_device_commands( circular_buffers_on_corerange.size()); for (const std::shared_ptr& cb : circular_buffers_on_corerange) { program_command_sequence.circular_buffers_on_core_ranges[i].emplace_back(cb); - const uint32_t cb_address = cb->address() >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; - const uint32_t cb_size = cb->size() >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + const uint32_t cb_address = cb->address(); + const uint32_t cb_size = cb->size(); for (const auto& buffer_index : cb->local_buffer_indices()) { // 1 cmd for all 32 buffer indices, populate with real data for specified indices // cb config payload @@ -910,7 +910,7 @@ void EnqueueProgramCommand::assemble_device_commands( cb_config_payload[base_index] = cb_address; cb_config_payload[base_index + 1] = cb_size; cb_config_payload[base_index + 2] = cb->num_pages(buffer_index); - cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + cb_config_payload[base_index + 3] = cb->page_size(buffer_index); max_index = std::max(max_index, base_index + UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG); } for (const auto& buffer_index : cb->remote_buffer_indices()) { @@ -1363,8 +1363,8 @@ void EnqueueProgramCommand::update_device_commands( for (const auto& cbs_on_core_range : cached_program_command_sequence.circular_buffers_on_core_ranges) { uint32_t* cb_config_payload = cached_program_command_sequence.cb_configs_payloads[i]; for (const std::shared_ptr& cb : cbs_on_core_range) { - const uint32_t cb_address = cb->address() >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; - const uint32_t cb_size = cb->size() >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + const uint32_t cb_address = cb->address(); + const uint32_t cb_size = cb->size(); for (const auto& buffer_index : cb->local_buffer_indices()) { // 1 cmd for all 32 buffer indices, populate with real data for specified indices @@ -1373,7 +1373,7 @@ void EnqueueProgramCommand::update_device_commands( cb_config_payload[base_index] = cb_address; cb_config_payload[base_index + 1] = cb_size; cb_config_payload[base_index + 2] = cb->num_pages(buffer_index); - cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + cb_config_payload[base_index + 3] = cb->page_size(buffer_index); } for (const auto& buffer_index : cb->remote_buffer_indices()) { const uint32_t base_index = remote_offset_index + (NUM_CIRCULAR_BUFFERS - 1 - buffer_index) * diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp index 0fe1f4f5992..49877dcf9ae 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp @@ -60,10 +60,7 @@ constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X constexpr uint32_t dispatch_s_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_SLAVE_NOC_X, DOWNSTREAM_SLAVE_NOC_Y)); constexpr uint8_t my_noc_index = NOC_INDEX; constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); -constexpr uint64_t pcie_noc_xy = uint64_t(NOC_XY_PCIE_ENCODING( - NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), - NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y), - NOC_INDEX)); +constexpr uint64_t pcie_noc_xy = uint64_t(NOC_XY_PCIE_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y))); constexpr uint32_t dispatch_cb_page_size = 1 << dispatch_cb_log_page_size; constexpr uint32_t completion_queue_end_addr = completion_queue_base_addr + completion_queue_size; diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp index 711b7140ed6..4e4d7ce297c 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp @@ -69,10 +69,7 @@ constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); constexpr uint32_t dispatch_s_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_SLAVE_NOC_X, DOWNSTREAM_SLAVE_NOC_Y)); -constexpr uint64_t pcie_noc_xy = uint64_t(NOC_XY_PCIE_ENCODING( - NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), - NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y), - NOC_INDEX)); +constexpr uint64_t pcie_noc_xy = uint64_t(NOC_XY_PCIE_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y))); constexpr uint32_t downstream_cb_page_size = 1 << downstream_cb_log_page_size; constexpr uint32_t dispatch_s_cb_page_size = 1 << dispatch_s_cb_log_page_size; constexpr uint32_t downstream_cb_end = downstream_cb_base + (1 << downstream_cb_log_page_size) * downstream_cb_pages; diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index d21d2c1735d..a3f67470d21 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -335,7 +335,6 @@ void ComputeKernel::set_build_options(JitBuildOptions &build_options) const { void DataMovementKernel::generate_binaries(Device *device, JitBuildOptions &build_options) const { jit_build_genfiles_kernel_include(device->build_env(), *this, this->kernel_src_); - device->generate_device_headers(build_options.path); uint32_t tensix_core_type = hal.get_programmable_core_type_index(this->get_kernel_programmable_core_type()); uint32_t dm_class_idx = magic_enum::enum_integer(HalProcessorClassType::DM); int riscv_id = static_cast::type>(this->config_.processor); @@ -344,7 +343,6 @@ void DataMovementKernel::generate_binaries(Device *device, JitBuildOptions &buil void EthernetKernel::generate_binaries(Device *device, JitBuildOptions &build_options) const { jit_build_genfiles_kernel_include(device->build_env(), *this, this->kernel_src_); - device->generate_device_headers(build_options.path); uint32_t erisc_core_type = hal.get_programmable_core_type_index(this->get_kernel_programmable_core_type()); uint32_t dm_class_idx = magic_enum::enum_integer(HalProcessorClassType::DM); int erisc_id = magic_enum::enum_integer(this->config_.processor); diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 09f90b8017b..216ffcae5b3 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -776,7 +776,7 @@ void detail::Program_::allocate_circular_buffers(const Device *device) { } } } - + computed_addr = align(computed_addr, device->get_allocator_alignment()); for (const CoreRange &core_range : circular_buffer->core_ranges().ranges()) { for (CircularBufferAllocator &cb_allocator : this->cb_allocators_) { if (cb_allocator.core_range.intersects(core_range)) { diff --git a/tt_metal/include/compute_kernel_api/add_int32_sfpu.h b/tt_metal/include/compute_kernel_api/add_int32_sfpu.h new file mode 100644 index 00000000000..89103555821 --- /dev/null +++ b/tt_metal/include/compute_kernel_api/add_int32_sfpu.h @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_binary_sfpu_add_int32.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + +namespace ckernel { + +/** + * Performs an elementwise add operation with the two integer inputs: y = add(x0,x1) + * Output overwrites first operand in DST. + * + * A maximum of 4 tiles from each operand can be loaded into DST at once, for a total of 8 tiles, + * when using 16 bit formats. This gets reduced to 2 tiles from each operand for 32 bit formats. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | + * Required | + * |----------------|-----------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst0 | The index of the tile in DST register buffer to use as first operand | uint32_t | Must be less + * than the size of the DST register buffer | True | | idst1 | The index of the tile in DST register buffer + * to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True | + */ +ALWI void add_int32_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_add_int32(idst0, idst1))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void add_int32_tile_init() { MATH((llk_math_eltwise_binary_sfpu_add_int32_init())); } + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/binary_bitwise_sfpu.h b/tt_metal/include/compute_kernel_api/binary_bitwise_sfpu.h new file mode 100644 index 00000000000..cf2a20d0090 --- /dev/null +++ b/tt_metal/include/compute_kernel_api/binary_bitwise_sfpu.h @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_binary_sfpu_bitwise.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + +namespace ckernel { + +/** + * Performs an elementwise binary bitwise operation with the two inputs: y = bitwise(x0,x1) + * Output overwrites first operand in DST. + * + * A maximum of 4 tiles from each operand can be loaded into DST at once, for a total of 8 tiles, + * when using 16 bit formats. This gets reduced to 2 tiles from each operand for 32 bit formats. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | + * Required | + * |----------------|-----------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst0 | The index of the tile in DST register buffer to use as first operand | uint32_t | Must be less + * than the size of the DST register buffer | True | | idst1 | The index of the tile in DST register buffer + * to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True | + */ +enum { AND_BINARY = 0, OR_BINARY = 1, XOR_BINARY = 2 }; +ALWI void and_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_bitwise(idst0, idst1))); +} + +ALWI void or_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_bitwise(idst0, idst1))); +} + +ALWI void xor_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_bitwise(idst0, idst1))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void binary_bitwise_tile_init() { MATH((llk_math_eltwise_binary_sfpu_bitwise_init())); } + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h b/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h new file mode 100644 index 00000000000..22fc4c13fcf --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_binary_sfpu_binop.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + +namespace ckernel { + +/** + * Performs an elementwise binop operation with the two floating point inputs: y = binop(x0,x1) + * Output overwrites first operand in DST. + * + * A maximum of 4 tiles from each operand can be loaded into DST at once, for a total of 8 tiles, + * when using 16 bit formats. This gets reduced to 2 tiles from each operand for 32 bit formats. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | + * Required | + * |----------------|-----------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst0 | The index of the tile in DST register buffer to use as first operand | uint32_t | Must be less + * than the size of the DST register buffer | True | | idst1 | The index of the tile in DST register buffer + * to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True | + */ +enum { ADD_BINARY = 0, SUB_BINARY = 1, MUL_BINARY = 2, DIV_BINARY = 3, RSUB_BINARY = 4, POW_BINARY = 5 }; +ALWI void add_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_binop(idst0, idst1))); +} + +ALWI void sub_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_binop(idst0, idst1))); +} + +ALWI void mul_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_binop(idst0, idst1))); +} + +ALWI void div_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_binop(idst0, idst1))); +} + +ALWI void rsub_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_binop(idst0, idst1))); +} + +ALWI void power_binary_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_binop(idst0, idst1))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void eltwise_binop_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h b/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h index ecad592ee52..fe45132ff0a 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h @@ -14,7 +14,6 @@ #endif namespace ckernel { - /** * Please refer to documentation for any_init. */ @@ -31,9 +30,25 @@ ALWI void floor_tile_init() { MATH((llk_math_eltwise_unary_sfpu_floor_init(idst))); } +/** + * Performs floor operation on each row of a tile. + * in DST register at index tile_index. The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking and is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid + * Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to perform floor operation | uint32_t | Must be + * less than the size of the DST register buffer | True | + */ +ALWI void floor_tile_float32(uint32_t idst) { MATH((llk_math_eltwise_unary_sfpu_floor_float32(idst))); } + } // namespace ckernel diff --git a/tt_metal/jit_build/build.hpp b/tt_metal/jit_build/build.hpp index 45c153439f0..ccd4a7860d2 100644 --- a/tt_metal/jit_build/build.hpp +++ b/tt_metal/jit_build/build.hpp @@ -50,7 +50,6 @@ class JitBuildEnv { tt::ARCH get_arch() const { return arch_; } const string& get_root_path() const { return root_; } const string& get_out_root_path() const { return out_root_; } - const string& get_out_firmware_root_path() const { return out_firmware_root_; } const string& get_out_kernel_root_path() const { return out_kernel_root_; } private: diff --git a/tt_metal/jit_build/data_format.hpp b/tt_metal/jit_build/data_format.hpp index ddd532ebb77..6664e6dccd8 100644 --- a/tt_metal/jit_build/data_format.hpp +++ b/tt_metal/jit_build/data_format.hpp @@ -6,7 +6,7 @@ #include #include #include "common/tt_backend_api_types.hpp" // for DataFormat -#include "umd/device/tt_arch_types.h" // for ARCH +#include "umd/device/types/arch.h" // for ARCH #include "tt_metal/hw/inc/circular_buffer_constants.h" // for NUM_CIRCULAR_BUFFERS enum class UnpackToDestMode : std::uint8_t; diff --git a/tt_metal/jit_build/genfiles.cpp b/tt_metal/jit_build/genfiles.cpp index a008db74e1e..ab920c1d1b0 100644 --- a/tt_metal/jit_build/genfiles.cpp +++ b/tt_metal/jit_build/genfiles.cpp @@ -451,128 +451,4 @@ void jit_build_genfiles_descriptors(const JitBuildEnv& env, JitBuildOptions& opt } } -std::string generate_bank_to_noc_coord_descriptor_string( - tt_xy_pair grid_size, - std::vector& dram_bank_map, - std::vector& dram_bank_offset_map, - std::vector& l1_bank_map, - std::vector& l1_bank_offset_map, - uint32_t allocator_alignment) { - stringstream ss; - - ss << "// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc." << endl; - ss << "//" << endl; - ss << "// SPDX-License-Identifier: Apache-2.0" << endl; - ss << endl; - ss << "/*" << endl; - ss << " * This file is autogenerated by tt-metal runtime" << endl; - ss << " * DO NOT EDIT" << endl; - ss << " * This file contains values that are visible to the device compiled code." << endl; - ss << " * CAREFUL: when included in the FW_BUILD, it defines global variables." << endl; - ss << " * When included in KERNEL_BUILD, it declares global variables." << endl; - ss << " */" << endl; - ss << endl; - ss << "#pragma once" << endl; - ss << endl; - ss << "#include " << endl; - ss << endl; - - ss << "static_assert(NUM_NOCS == 2);" << endl; - ss << endl; - - ss << "#ifdef KERNEL_BUILD" << endl; - ss << endl; - ss << "extern uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS];" << endl; - ss << "extern int32_t bank_to_dram_offset[NUM_DRAM_BANKS];" << endl; - ss << "extern uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS];" << endl; - ss << "extern int32_t bank_to_l1_offset[NUM_L1_BANKS];" << endl; - - ss << endl; - ss << "#else // !KERNEL_BUILD (FW_BUILD)" << endl; - ss << endl; - - ss << "uint16_t dram_bank_to_noc_xy[NUM_NOCS][NUM_DRAM_BANKS] __attribute__((used)) = {" << endl; - for (unsigned int noc = 0; noc < 2; noc++) { - ss << " {" - << "\t// noc=" << noc << endl; - for (unsigned int bank_id = 0; bank_id < dram_bank_map.size(); bank_id++) { - uint16_t noc_x = tt::tt_metal::hal.noc_coordinate(noc, grid_size.x, dram_bank_map[bank_id].x); - uint16_t noc_y = tt::tt_metal::hal.noc_coordinate(noc, grid_size.y, dram_bank_map[bank_id].y); - ss << " (((" << noc_y << " << NOC_ADDR_NODE_ID_BITS) | " << noc_x << ") << NOC_COORD_REG_OFFSET)," - << "\t// NOC_X=" << noc_x << " NOC_Y=" << noc_y << endl; - } - ss << " }," << endl; - } - ss << "};" << endl; - ss << endl; - ss << "int32_t bank_to_dram_offset[NUM_DRAM_BANKS] __attribute__((used)) = {" << endl; - for (unsigned int bank_id = 0; bank_id < dram_bank_map.size(); bank_id++) { - ss << " " << dram_bank_offset_map[bank_id] << "," << endl; - } - ss << "};" << endl; - ss << endl; - - ss << "uint16_t l1_bank_to_noc_xy[NUM_NOCS][NUM_L1_BANKS] __attribute__((used)) = {" << endl; - for (unsigned int noc = 0; noc < 2; noc++) { - ss << " {" - << "\t// noc=" << noc << endl; - for (unsigned int bank_id = 0; bank_id < l1_bank_map.size(); bank_id++) { - uint16_t noc_x = tt::tt_metal::hal.noc_coordinate(noc, grid_size.x, l1_bank_map[bank_id].x); - uint16_t noc_y = tt::tt_metal::hal.noc_coordinate(noc, grid_size.y, l1_bank_map[bank_id].y); - ss << " (((" << noc_y << " << NOC_ADDR_NODE_ID_BITS) | " << noc_x << ") << NOC_COORD_REG_OFFSET)," - << "\t// NOC_X=" << noc_x << " NOC_Y=" << noc_y << endl; - } - ss << " }," << endl; - } - ss << "};" << endl; - ss << endl; - ss << "int32_t bank_to_l1_offset[NUM_L1_BANKS] __attribute__((used)) = {" << endl; - for (unsigned int bank_id = 0; bank_id < l1_bank_map.size(); bank_id++) { - ss << " " << l1_bank_offset_map[bank_id] << "," << endl; - } - ss << "};" << endl; - ss << endl; - - ss << "#endif // FW_BUILD" << endl; - - return ss.str(); -} -void jit_build_genfiles_bank_to_noc_coord_descriptor( - const string& path, - tt_xy_pair grid_size, - std::vector& dram_bank_map, - std::vector& dram_bank_offset_map, - std::vector& l1_bank_map, - std::vector& l1_bank_offset_map, - uint32_t allocator_alignment) { - string output_string = generate_bank_to_noc_coord_descriptor_string( - grid_size, - dram_bank_map, - dram_bank_offset_map, - l1_bank_map, - l1_bank_offset_map, - allocator_alignment); - - fs::create_directories(path + "/brisc"); - ofstream file_stream_br(path + "/brisc/generated_bank_to_noc_coord_mapping.h"); - file_stream_br << output_string; - file_stream_br.close(); - fs::create_directories(path + "/ncrisc"); - ofstream file_stream_nc(path + "/ncrisc/generated_bank_to_noc_coord_mapping.h"); - file_stream_nc << output_string; - file_stream_nc.close(); - fs::create_directories(path + "/erisc"); - ofstream file_stream_ec(path + "/erisc/generated_bank_to_noc_coord_mapping.h"); - file_stream_ec << output_string; - file_stream_ec.close(); - fs::create_directories(path + "/idle_erisc"); - ofstream file_stream_iec(path + "/idle_erisc/generated_bank_to_noc_coord_mapping.h"); - file_stream_iec << output_string; - file_stream_iec.close(); - fs::create_directories(path + "/slave_idle_erisc"); - ofstream file_stream_siec(path + "/slave_idle_erisc/generated_bank_to_noc_coord_mapping.h"); - file_stream_siec << output_string; - file_stream_siec.close(); -} - } // namespace tt::tt_metal diff --git a/tt_metal/jit_build/genfiles.hpp b/tt_metal/jit_build/genfiles.hpp index 4dee07a44ab..c21459daabd 100644 --- a/tt_metal/jit_build/genfiles.hpp +++ b/tt_metal/jit_build/genfiles.hpp @@ -21,15 +21,6 @@ void jit_build_genfiles_kernel_include( void jit_build_genfiles_triscs_src( const JitBuildEnv& env, const JitBuildSettings& settings, const KernelSource& kernel_src); -void jit_build_genfiles_bank_to_noc_coord_descriptor( - const std::string& path, - tt_xy_pair grid_size, - std::vector& dram_bank_map, - std::vector& dram_bank_offset_map, - std::vector& l1_bank_map, - std::vector& l1_bank_offset_map, - uint32_t allocator_alignment); - void jit_build_genfiles_descriptors(const JitBuildEnv& env, JitBuildOptions& options); } // namespace tt::tt_metal diff --git a/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp b/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp index 021f58f1075..2fe01d1cd57 100644 --- a/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp +++ b/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp @@ -46,6 +46,8 @@ HalCoreInfoType create_active_eth_mem_map() { GET_ETH_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); mem_map_bases[static_cast(HalL1MemAddrType::FW_VERSION_ADDR)] = eth_l1_mem::address_map::FW_VERSION_ADDR; + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = + eth_l1_mem::address_map::ERISC_MEM_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -65,6 +67,8 @@ HalCoreInfoType create_active_eth_mem_map() { mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(std::uint32_t); mem_map_sizes[static_cast(HalL1MemAddrType::FW_VERSION_ADDR)] = sizeof(std::uint32_t); + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = + eth_l1_mem::address_map::ERISC_MEM_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumEthDispatchClasses - 1); std::vector processor_types(1); diff --git a/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp b/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp index f7f91ed7f44..72ba9e91a22 100644 --- a/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp +++ b/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp @@ -49,6 +49,7 @@ HalCoreInfoType create_idle_eth_mem_map() { mem_map_bases[static_cast(HalL1MemAddrType::GO_MSG)] = GET_IERISC_MAILBOX_ADDRESS_HOST(go_message); mem_map_bases[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_IERISC_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -66,6 +67,7 @@ HalCoreInfoType create_idle_eth_mem_map() { ; mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(std::uint32_t); + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_IERISC_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumEthDispatchClasses); std::vector processor_types(1); diff --git a/tt_metal/llrt/blackhole/bh_hal_tensix.cpp b/tt_metal/llrt/blackhole/bh_hal_tensix.cpp index d0414dcfbc0..eb17f10bf11 100644 --- a/tt_metal/llrt/blackhole/bh_hal_tensix.cpp +++ b/tt_metal/llrt/blackhole/bh_hal_tensix.cpp @@ -46,6 +46,7 @@ HalCoreInfoType create_tensix_mem_map() { mem_map_bases[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); mem_map_bases[static_cast(HalL1MemAddrType::LOCAL)] = MEM_LOCAL_BASE; + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -62,6 +63,7 @@ HalCoreInfoType create_tensix_mem_map() { mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); mem_map_sizes[static_cast(HalL1MemAddrType::LOCAL)] = MEM_TRISC_LOCAL_SIZE; // TRISC, BRISC, or NCRISC? + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumTensixDispatchClasses); std::vector processor_types; diff --git a/tt_metal/llrt/get_platform_architecture.hpp b/tt_metal/llrt/get_platform_architecture.hpp index 67ca8f6c6eb..3aa08fd4ee4 100644 --- a/tt_metal/llrt/get_platform_architecture.hpp +++ b/tt_metal/llrt/get_platform_architecture.hpp @@ -43,7 +43,7 @@ namespace tt::tt_metal { * if (arch == tt::ARCH::Invalid) { * std::cerr << "Failed to detect architecture!" << std::endl; * } else { - * std::cout << "Detected architecture: " << tt::get_arch_str(arch) << std::endl; + * std::cout << "Detected architecture: " << tt::arch_to_str(arch) << std::endl; * } * @endcode * @@ -68,9 +68,9 @@ inline tt::ARCH get_platform_architecture() { TT_FATAL( arch == detected_arch, "Expected all devices to be {} but device {} is {}", - get_arch_str(arch), + tt::arch_to_str(arch), device_id, - get_arch_str(detected_arch)); + tt::arch_to_str(detected_arch)); } } } diff --git a/tt_metal/llrt/grayskull/gs_hal.cpp b/tt_metal/llrt/grayskull/gs_hal.cpp index 5477beeec65..71a889179b8 100644 --- a/tt_metal/llrt/grayskull/gs_hal.cpp +++ b/tt_metal/llrt/grayskull/gs_hal.cpp @@ -61,6 +61,7 @@ void Hal::initialize_gs() { mem_map_bases[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); mem_map_bases[static_cast(HalL1MemAddrType::LOCAL)] = MEM_LOCAL_BASE; + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -77,6 +78,7 @@ void Hal::initialize_gs() { mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); mem_map_sizes[static_cast(HalL1MemAddrType::LOCAL)] = MEM_TRISC_LOCAL_SIZE; // TRISC, BRISC, or NCRISC? + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumTensixDispatchClasses); std::vector processor_types; diff --git a/tt_metal/llrt/hal.hpp b/tt_metal/llrt/hal.hpp index f7da19e2f97..80e88002696 100644 --- a/tt_metal/llrt/hal.hpp +++ b/tt_metal/llrt/hal.hpp @@ -51,6 +51,7 @@ enum class HalL1MemAddrType : uint8_t { LAUNCH_MSG_BUFFER_RD_PTR, FW_VERSION_ADDR, // Really only applicable to active eth core right now LOCAL, + BANK_TO_NOC_SCRATCH, COUNT // Keep this last so it always indicates number of enum options }; diff --git a/tt_metal/llrt/tt_cluster.cpp b/tt_metal/llrt/tt_cluster.cpp index 78ee0871795..d49d2ae4d5e 100644 --- a/tt_metal/llrt/tt_cluster.cpp +++ b/tt_metal/llrt/tt_cluster.cpp @@ -26,13 +26,13 @@ #include "tt_metal/common/metal_soc_descriptor.h" #include "tt_metal/common/test_common.hpp" #include "tt_metal/common/tt_backend_api_types.hpp" -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "umd/device/tt_cluster_descriptor.h" -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "umd/device/cluster.h" #include "umd/device/tt_soc_descriptor.h" #include "umd/device/tt_xy_pair.h" -#include "umd/device/xy_pair.h" +#include "umd/device/types/xy_pair.h" #include "umd/device/hugepage.h" // TODO: ARCH_NAME specific, must remove diff --git a/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp b/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp index 0d1241020c5..c0af4cc0bd7 100644 --- a/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp +++ b/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp @@ -43,6 +43,8 @@ HalCoreInfoType create_active_eth_mem_map() { GET_ETH_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); mem_map_bases[static_cast(HalL1MemAddrType::FW_VERSION_ADDR)] = eth_l1_mem::address_map::FW_VERSION_ADDR; + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = + eth_l1_mem::address_map::ERISC_MEM_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -62,6 +64,7 @@ HalCoreInfoType create_active_eth_mem_map() { mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); mem_map_sizes[static_cast(HalL1MemAddrType::FW_VERSION_ADDR)] = sizeof(std::uint32_t); + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = eth_l1_mem::address_map::ERISC_MEM_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumEthDispatchClasses); std::vector processor_types(1); diff --git a/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp b/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp index a2ce00faf43..6a5b617a3d2 100644 --- a/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp +++ b/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp @@ -49,6 +49,7 @@ HalCoreInfoType create_idle_eth_mem_map() { mem_map_bases[static_cast(HalL1MemAddrType::GO_MSG)] = GET_IERISC_MAILBOX_ADDRESS_HOST(go_message); mem_map_bases[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_IERISC_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -66,6 +67,7 @@ HalCoreInfoType create_idle_eth_mem_map() { ; mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(std::uint32_t); + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_IERISC_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumEthDispatchClasses); std::vector processor_types(1); diff --git a/tt_metal/llrt/wormhole/wh_hal_tensix.cpp b/tt_metal/llrt/wormhole/wh_hal_tensix.cpp index 7de8185bacb..e4d6c42981e 100644 --- a/tt_metal/llrt/wormhole/wh_hal_tensix.cpp +++ b/tt_metal/llrt/wormhole/wh_hal_tensix.cpp @@ -47,6 +47,7 @@ HalCoreInfoType create_tensix_mem_map() { mem_map_bases[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); mem_map_bases[static_cast(HalL1MemAddrType::LOCAL)] = MEM_LOCAL_BASE; + mem_map_bases[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_BANK_TO_NOC_SCRATCH; std::vector mem_map_sizes; mem_map_sizes.resize(static_cast(HalL1MemAddrType::COUNT)); @@ -63,6 +64,7 @@ HalCoreInfoType create_tensix_mem_map() { mem_map_sizes[static_cast(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); mem_map_sizes[static_cast(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(std::uint32_t); mem_map_sizes[static_cast(HalL1MemAddrType::LOCAL)] = MEM_TRISC_LOCAL_SIZE; // TRISC, BRISC, or NCRISC? + mem_map_sizes[static_cast(HalL1MemAddrType::BANK_TO_NOC_SCRATCH)] = MEM_BANK_TO_NOC_SIZE; std::vector> processor_classes(NumTensixDispatchClasses); std::vector processor_types; diff --git a/tt_metal/third_party/pybind11 b/tt_metal/third_party/pybind11 deleted file mode 160000 index b8f28551cc3..00000000000 --- a/tt_metal/third_party/pybind11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b8f28551cc3a98ea9fbfc15c05b513c8f2d23e84 diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index 5686dd5f006..8b5afa5b0f9 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit 5686dd5f006bfa537de74ce7ce0428be99b5cc9a +Subproject commit 8b5afa5b0f92841f13d49263482bdde6aaeef4ca diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index eadfdb466dc..ed02df9eb4b 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit eadfdb466dcc738803969f9eda76aacbb33b5ad6 +Subproject commit ed02df9eb4bbfb37da1b9d9a8a129f1f6842a6cd diff --git a/tt_metal/third_party/umd b/tt_metal/third_party/umd index a98ddd2f1f3..e92ac1c44ea 160000 --- a/tt_metal/third_party/umd +++ b/tt_metal/third_party/umd @@ -1 +1 @@ -Subproject commit a98ddd2f1f302885c352154203d96e940b22f71b +Subproject commit e92ac1c44eab7a1bd06e6da321fa9309e5a73159 diff --git a/tt_metal/tools/profiler/kernel_profiler.hpp b/tt_metal/tools/profiler/kernel_profiler.hpp index 78844333dcd..81d4f5a3d1a 100644 --- a/tt_metal/tools/profiler/kernel_profiler.hpp +++ b/tt_metal/tools/profiler/kernel_profiler.hpp @@ -446,4 +446,8 @@ inline __attribute__((always_inline)) void recordEvent(uint16_t event_id) { #define DeviceZoneSetCounter(counter) +#define DeviceTimestampedData(data_id, data) + +#define DeviceRecordEvent(event_id) + #endif diff --git a/tt_metal/tools/profiler/process_device_log.py b/tt_metal/tools/profiler/process_device_log.py index 4fbba965402..ecdd1396cbf 100755 --- a/tt_metal/tools/profiler/process_device_log.py +++ b/tt_metal/tools/profiler/process_device_log.py @@ -309,6 +309,7 @@ def get_ops(timeseries): opsDict[opID].append(ts) ordered_ops = list(opsDict.keys()) + # sort over timestamps ordered_ops.sort(key=lambda x: opsDict[x][0][1]) ops = [] @@ -327,9 +328,7 @@ def get_ops(timeseries): if (risc == "BRISC" and timerID["zone_name"] == "BRISC-FW" and timerID["type"] == "ZONE_START") or ( risc == "ERISC" and timerID["zone_name"] == "ERISC-FW" and timerID["type"] == "ZONE_START" ): - for opDuration in coresOp.values(): - assert len(opDuration) == 2, "Unexpected FW start" - + assert len(coresOp[core]) == 2, "Unexpected FW end" ops.append({"timeseries": []}) coresOp = {} elif (risc == "BRISC" and timerID["zone_name"] == "BRISC-FW" and timerID["type"] == "ZONE_END") or ( diff --git a/tt_metal/tools/profiler/tt_metal_profiler.cpp b/tt_metal/tools/profiler/tt_metal_profiler.cpp index d42d379b635..e90e9caa236 100644 --- a/tt_metal/tools/profiler/tt_metal_profiler.cpp +++ b/tt_metal/tools/profiler/tt_metal_profiler.cpp @@ -99,22 +99,22 @@ void syncDeviceHost( smallestHostime.emplace(device_id, 0); constexpr uint16_t sampleCount = 249; - if (sync_program == nullptr) { - sync_program = std::make_shared(); - - std::map kernel_defines = { - {"SAMPLE_COUNT", std::to_string(sampleCount)}, - }; - - tt_metal::KernelHandle brisc_kernel = tt_metal::CreateKernel( - *sync_program, - "tt_metal/tools/profiler/sync/sync_kernel.cpp", - logical_core, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_0, - .noc = tt_metal::NOC::RISCV_0_default, - .defines = kernel_defines}); - } + // TODO(MO): Always recreate a new program until subdevice + // allows using the first program generated by default manager + sync_program = std::make_shared(); + + std::map kernel_defines = { + {"SAMPLE_COUNT", std::to_string(sampleCount)}, + }; + + tt_metal::KernelHandle brisc_kernel = tt_metal::CreateKernel( + *sync_program, + "tt_metal/tools/profiler/sync/sync_kernel.cpp", + logical_core, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .defines = kernel_defines}); EnqueueProgram(device->command_queue(), *sync_program, false); diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index ff1987983e9..e59f14430cd 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -745,13 +745,10 @@ bool ConfigureDeviceWithProgram(Device* device, Program& program, bool fd_bootlo uint32_t size_in_bytes = circular_buffer->size(); uint32_t num_pages = circular_buffer->num_pages(buffer_index); uint32_t page_size = size_in_bytes / num_pages; - circular_buffer_config_vec[base_index] = - addr_in_bytes >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; // convert to addr in 16B words - circular_buffer_config_vec[base_index + 1] = - size_in_bytes >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; // convert to addr in 16B words + circular_buffer_config_vec[base_index] = addr_in_bytes; // convert to addr in 16B words + circular_buffer_config_vec[base_index + 1] = size_in_bytes; // convert to addr in 16B words circular_buffer_config_vec[base_index + 2] = num_pages; - circular_buffer_config_vec[base_index + 3] = - page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES; + circular_buffer_config_vec[base_index + 3] = page_size; } for (uint32_t buffer_index : circular_buffer->remote_buffer_indices()) { uint32_t base_index = @@ -856,15 +853,9 @@ DeviceAddr AllocateBuffer(Buffer* buffer) { *buffer->sub_device_manager_id(), buffer->device()->get_active_sub_device_manager_id()); } - auto allocator = buffer->allocator(); - DeviceAddr allocated_addr; - if (is_sharded(buffer->buffer_layout())) { - allocated_addr = allocator::allocate_buffer( - *allocator, buffer->shard_spec().size() * buffer->num_cores().value() * buffer->page_size(), buffer); - } else { - allocated_addr = allocator::allocate_buffer(*allocator, buffer->size(), buffer); - } + DeviceAddr allocated_addr = allocator::allocate_buffer(*buffer->allocator(), buffer); + // Assertion here because buffer class returns a u32 when address is queried // Requires updating all use cases of buffer address to accept a u64 to remove TT_ASSERT(allocated_addr <= std::numeric_limits::max()); diff --git a/tt_metal/tt_stl/reflection.hpp b/tt_metal/tt_stl/reflection.hpp index 42c3aecb6a4..e0c7cfd5199 100644 --- a/tt_metal/tt_stl/reflection.hpp +++ b/tt_metal/tt_stl/reflection.hpp @@ -6,7 +6,10 @@ #include +#include +#include #include +#include #include #include #include @@ -14,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -448,6 +452,32 @@ std::ostream& operator<<(std::ostream& os, const std::set& set) { return os; } +template +std::ostream& operator<<(std::ostream& os, const std::map& map) { + os << "{"; + for (auto it = map.begin(); it != map.end(); ++it) { + os << it->first << ": " << it->second; + if (it != map.end()) { + os << ", "; + } + } + os << "}"; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const std::unordered_map& map) { + os << "{"; + for (auto it = map.begin(); it != map.end(); ++it) { + os << it->first << ": " << it->second; + if (it != map.end()) { + os << ", "; + } + } + os << "}"; + return os; +} + template requires(tt::stl::concepts::Reflectable and not(std::integral or std::is_array::value)) std::ostream& operator<<(std::ostream& os, const T& object) { @@ -978,6 +1008,30 @@ struct fmt::formatter> { } }; +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } + + auto format(const std::map& map, format_context& ctx) const -> format_context::iterator { + using tt::stl::reflection::operator<<; + std::stringstream ss; + ss << map; + return fmt::format_to(ctx.out(), "{}", ss.str()); + } +}; + +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } + + auto format(const std::unordered_map& map, format_context& ctx) const -> format_context::iterator { + using tt::stl::reflection::operator<<; + std::stringstream ss; + ss << map; + return fmt::format_to(ctx.out(), "{}", ss.str()); + } +}; + template requires( tt::stl::concepts::Reflectable and not(std::integral or std::is_array::value or @@ -1063,7 +1117,7 @@ inline hash_t hash_object(const T& object) noexcept { fmt::print("Hashing struct {} using compile-time attributes: {}\n", get_type_name(), object); } constexpr auto num_attributes = reflection::detail::get_num_attributes(); - std::size_t hash = 0; + hash_t hash = 0; const auto attribute_values = object.attribute_values(); [&object, &hash, &attribute_values](std::index_sequence) { ( @@ -1074,11 +1128,26 @@ inline hash_t hash_object(const T& object) noexcept { ...); }(std::make_index_sequence{}); return hash; + } else if constexpr (is_specialization_v) { + if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { + fmt::print("Hashing std::tuple of type {}: {}\n", get_type_name(), object); + } + constexpr auto num_elements = std::tuple_size_v; + hash_t hash = 0; + [&object, &hash](std::index_sequence) { + ( + [&object, &hash] { + const auto& element = std::get(object); + hash = hash_objects(hash, element); + }(), + ...); + }(std::make_index_sequence{}); + return hash; } else if constexpr (is_specialization_v) { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::vector of type {}: {}\n", get_type_name(), object); } - auto hash = 0; + hash_t hash = 0; for (const auto& element : object) { hash = hash_objects(hash, element); } @@ -1087,11 +1156,37 @@ inline hash_t hash_object(const T& object) noexcept { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::set of type {}: {}\n", get_type_name(), object); } - auto hash = 0; + hash_t hash = 0; for (const auto& element : object) { hash = hash_objects(hash, element); } return hash; + } else if constexpr (is_specialization_v) { + if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { + fmt::print("Hashing std::map of type {}: {}\n", get_type_name(), object); + } + hash_t hash = 0; + for (const auto& [key, value] : object) { + hash = hash_objects(hash, key, value); + } + return hash; + } else if constexpr (is_specialization_v) { + if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { + fmt::print("Hashing std::unordered_map of type {}: {}\n", get_type_name(), object); + } + // Sort the unordered map by key to make the hash order invariant + std::vector iterators; + iterators.reserve(object.size()); + for (auto it = object.begin(); it != object.end(); ++it) { + iterators.push_back(it); + } + std::sort(iterators.begin(), iterators.end(), [](const auto& a, const auto& b) { return a->first < b->first; }); + + hash_t hash = 0; + for (const auto& it : iterators) { + hash = hash_objects(hash, it->first, it->second); + } + return hash; } else if constexpr (is_specialization_v) { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::optional of type {}: {}\n", get_type_name(), object); @@ -1105,7 +1200,7 @@ inline hash_t hash_object(const T& object) noexcept { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing struct {} using reflect library: {}\n", get_type_name(), object); } - std::size_t hash = 0; + hash_t hash = 0; reflect::for_each([&hash, &object](auto I) { hash = hash_objects(hash, reflect::get(object)); }, object); return hash; } else { @@ -1335,7 +1430,7 @@ struct to_json_t> { nlohmann::json operator()(const std::map& object) { nlohmann::json json_object = nlohmann::json::object(); for (const auto& [key, value] : object) { - json_object[to_json(key)] = to_json(value); + json_object[to_json(key).dump()] = to_json(value); } return json_object; } @@ -1346,7 +1441,29 @@ struct from_json_t> { std::map operator()(const nlohmann::json& json_object) { std::map object; for (const auto& [key, value] : json_object.items()) { - object[from_json(key)] = from_json(value); + object[from_json(nlohmann::json::parse(key))] = from_json(value); + } + return object; + } +}; + +template +struct to_json_t> { + nlohmann::json operator()(const std::unordered_map& object) { + nlohmann::json json_object = nlohmann::json::object(); + for (const auto& [key, value] : object) { + json_object[to_json(key).dump()] = to_json(value); + } + return json_object; + } +}; + +template +struct from_json_t> { + std::map operator()(const nlohmann::json& json_object) { + std::unordered_map object; + for (const auto& [key, value] : json_object.items()) { + object[from_json(nlohmann::json::parse(key))] = from_json(value); } return object; } diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 723bf1d4833..1da988236ab 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -8,38 +8,12 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/global_semaphore.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/run_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/api.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_tensor_config.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/distributed/distributed_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_processor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_trace_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_pybind.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce/all_reduce.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce/all_reduce_pybind.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/uops/ccl_command.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_common.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/barrier/device/host/barrier_full_worker_grid.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/barrier/device/barrier_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/barrier/barrier.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/barrier/barrier_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -612,11 +586,15 @@ endforeach() ### Setup TTNN as a shared library with optional Python bindings add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/tensor) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations//experimental/ccl) add_subdirectory(cpp/ttnn/deprecated) set(TTNN_FINAL_SRC ${TTNN_SRC} ${QUEUE_SRCS} ${TENSOR_SRCS} + ${CCL_TTNN_SRCS} + ${CCL_EXPERIMENTAL_TTNN_SRCS} ${TT_DNN_SRCS} ) @@ -671,7 +649,7 @@ if(WITH_PYTHON_BINDINGS) list( APPEND TTNN_PUBLIC_LINK_LIBRARIES - pch_pybinds + pybind11::module ${Python3_LIBRARIES} ) endif() diff --git a/ttnn/cpp/pybind11/device.cpp b/ttnn/cpp/pybind11/device.cpp index b60a36ed7ad..1196c1f08e0 100644 --- a/ttnn/cpp/pybind11/device.cpp +++ b/ttnn/cpp/pybind11/device.cpp @@ -100,6 +100,8 @@ void py_device_module_types(py::module& m_device) { py::class_(m_device, "SubDevice", "Class describing a sub-device of a Tenstorrent accelerator device."); + py::class_(m_device, "SubDeviceId", "ID of a sub-device."); + py::class_(m_device, "SubDeviceManagerId", "ID of a sub-device manager."); } @@ -114,6 +116,14 @@ void device_module(py::module& m_device) { The order of cores is Tensix, then Ethernet. )doc"); + auto pySubDeviceId = static_cast>(m_device.attr("SubDeviceId")); + pySubDeviceId.def( + py::init(), + py::arg("id"), + R"doc( + Creates a SubDeviceId object with the given ID. + )doc"); + auto pyDevice = static_cast>>(m_device.attr("Device")); pyDevice .def( @@ -482,10 +492,11 @@ void device_module(py::module& m_device) { m_device.def( "synchronize_device", - [](Device* device, const std::optional cq_id) { + [](Device* device, const std::optional cq_id, const std::vector& sub_device_ids) { // Send finish command to issue queue through worker thread // Worker thread will stall until the device is flushed. - device->push_work([device, cq_id]() mutable { Synchronize(device, cq_id); }); + device->push_work( + [device, cq_id, &sub_device_ids]() mutable { Synchronize(device, cq_id, sub_device_ids); }); // Main thread stalls until worker is complete (full device and worker queue flush). device->synchronize(); }, @@ -493,10 +504,13 @@ void device_module(py::module& m_device) { Synchronize the device with host by waiting for all operations to complete. If cq_id is provided then only the operations associated with that cq_id are waited for, otherwise operations for all command queues are waited on. + If the device has been configured with sub-devices, then sub_device_ids can be provided to only wait + for the operations that ran on the specified sub-devices, otherwise all sub-devices (the entire chip) are waited on. Args: device (ttnn.device.Device): The device to synchronize with. cq_id (int, optional): The command queue ID to synchronize. Defaults to `None`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to synchronize. Defaults to all sub-devices. Returns: `None`: The op ensures that all operations are completed. @@ -508,7 +522,8 @@ void device_module(py::module& m_device) { >>> ttnn.synchronize_device(device) )doc", py::arg("device"), - py::arg("cq_id") = std::nullopt); + py::arg("cq_id") = std::nullopt, + py::arg("sub_device_ids") = std::vector()); m_device.def("SetLazyCommandQueueMode", &tt::tt_metal::detail::SetLazyCommandQueueMode, R"doc( If set to true, the host does not notify the device that there are commands available other than the FinishCommand. Once set to false, all subsequent commands will immediately notify the device @@ -527,6 +542,8 @@ void device_module(py::module& m_device) { m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE); m_device.attr("DEFAULT_TRACE_REGION_SIZE") = py::int_(DEFAULT_TRACE_REGION_SIZE); + + m_device.attr("DefaultQueueId") = ttnn::DefaultQueueId; } void py_device_module(py::module& module) { diff --git a/ttnn/cpp/pybind11/events.cpp b/ttnn/cpp/pybind11/events.cpp index fdb12668f63..4ce6d41e644 100644 --- a/ttnn/cpp/pybind11/events.cpp +++ b/ttnn/cpp/pybind11/events.cpp @@ -31,15 +31,17 @@ void py_module(py::module& module) { module.def( "record_event", - py::overload_cast&>(&record_event), + py::overload_cast&, const std::vector&>(&record_event), py::arg("cq_id"), py::arg("event"), + py::arg("sub_device_ids") = std::vector(), R"doc( Record the completion of commands on this CQ, preceeding this call. Args: cq_id (int): The Command Queue on which event completion will be recorded. event (event): The event used to record completion of preceeding commands. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to all sub-devices. )doc"); module.def( @@ -69,9 +71,10 @@ void py_module(py::module& module) { module.def( "record_event", - py::overload_cast(&record_event), + py::overload_cast&>(&record_event), py::arg("cq_id"), py::arg("multi_device_event"), + py::arg("sub_device_ids") = std::vector(), R"doc( Record the completion of commands on this CQ, preceeding this call. @@ -91,6 +94,7 @@ void py_module(py::module& module) { Args: cq_id (int): The Command Queue on which event completion will be recorded. event (event): The event used to record completion of preceeding commands. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to all sub-devices. )doc"); } diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index eaf0014cf52..db8a7a1970c 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -6,7 +6,9 @@ #include #include +#include +#include "pybind11/cast.h" #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/core/core.hpp" #include "tt_metal/common/work_split.hpp" @@ -22,12 +24,14 @@ void py_module_types(py::module& module) { py::class_(module, "GrayskullComputeKernelConfig") .def( - py::init(), + py::init(), py::kw_only(), py::arg("math_fidelity") = MathFidelity::Invalid, - py::arg("math_approx_mode") = true) + py::arg("math_approx_mode") = true, + py::arg("dst_full_sync_en") = false) .def_readwrite("math_fidelity", &GrayskullComputeKernelConfig::math_fidelity) - .def_readwrite("math_approx_mode", &GrayskullComputeKernelConfig::math_approx_mode); + .def_readwrite("math_approx_mode", &GrayskullComputeKernelConfig::math_approx_mode) + .def_readwrite("dst_full_sync_en", &GrayskullComputeKernelConfig::dst_full_sync_en); py::class_(module, "WormholeComputeKernelConfig") .def( @@ -46,23 +50,46 @@ void py_module_types(py::module& module) { } void py_module(py::module& module) { + + module.def("init_device_compute_kernel_config", &ttnn::init_device_compute_kernel_config, + py::arg("arch"), + py::arg("device_kernel_config") = std::nullopt, + py::kw_only(), + py::arg("math_fidelity") = MathFidelity::LoFi, + py::arg("math_approx_mode") = true, + py::arg("fp32_dest_acc_en") = false, + py::arg("packer_l1_acc") = false, + py::arg("dst_full_sync_en") = false + ); module.def("unsqueeze_to_4D", &ttnn::unsqueeze_to_4D, py::arg("tensor")); module.def( "to_device", - py::overload_cast&>( - &ttnn::operations::core::to_device), + py::overload_cast< + const ttnn::Tensor&, + Device*, + const std::optional&, + uint8_t, + const std::vector&>(&ttnn::operations::core::to_device), py::arg("tensor"), py::arg("device"), - py::arg("memory_config") = std::nullopt); + py::arg("memory_config") = std::nullopt, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector()); module.def( "to_device", - py::overload_cast&>( - &ttnn::operations::core::to_device), + py::overload_cast< + const ttnn::Tensor&, + MeshDevice*, + const std::optional&, + uint8_t, + const std::vector&>(&ttnn::operations::core::to_device), py::arg("tensor"), py::arg("device"), py::arg("memory_config") = std::nullopt, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), R"doc( Copy tensor from host to device. @@ -70,6 +97,9 @@ void py_module(py::module& module) { tensor (ttnn.Tensor): The tensor to be copied from host to device. device (ttnn.Device | ttnn.MeshDevice): The target device where the tensor will be copied. memory_config (ttnn.MemoryConfig, optional): The memory configuration to use. Defaults to `None`. + cq_id (int, optional): The command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on before writing the tensor to device memory. + If it is not provided, device will stall for all programs of the specified cq to finish before writing the tensor to device memory. Returns: ttnn.Tensor: The device tensor copy. @@ -88,6 +118,7 @@ void py_module(py::module& module) { py::arg("blocking") = true, py::kw_only(), py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), R"doc( Copy tensor from device to host. @@ -97,6 +128,8 @@ void py_module(py::module& module) { Keyword args: cq_id (int, optional): the command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): the sub-device IDs to wait on before reading the tensor from device memory. + If it is not provided, device will stall for all programs of the specified cq to finish before reading the tensor from device memory. Returns: ttnn.Tensor: the host tensor copy. @@ -228,7 +261,8 @@ void py_module(py::module& module) { &ttnn::operations::core::copy_host_to_device_tensor, py::arg("host_tensor"), py::arg("device_tensor"), - py::arg("cq_id") = ttnn::DefaultQueueId); + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector()); module.def( "begin_trace_capture", @@ -348,6 +382,12 @@ void py_module(py::module& module) { "num_cores_to_corerangeset", py::overload_cast(&tt::tt_metal::num_cores_to_corerangeset), R"doc(Create a CoreRangeSet containing the specified number of cores)doc"); + + module.def( + "num_cores_to_corerangeset_in_subcoregrids", + py::overload_cast( + &tt::tt_metal::num_cores_to_corerangeset_in_subcoregrids), + R"doc(Create a CoreRangeSet containing the specified number of cores starting from start_core in given subcoregrids)doc"); } } // namespace core diff --git a/ttnn/cpp/pybind11/operations/creation.hpp b/ttnn/cpp/pybind11/operations/creation.hpp index 6581766ccb0..26a5e1778c5 100644 --- a/ttnn/cpp/pybind11/operations/creation.hpp +++ b/ttnn/cpp/pybind11/operations/creation.hpp @@ -138,7 +138,7 @@ auto create_pybind_empty_like_overload() { const ttnn::Tensor& reference, const std::optional& dtype, const std::optional& layout, - const std::optional> device, + const std::optional> device, const std::optional& memory_config) -> ttnn::Tensor { return self(reference, dtype, layout, device, memory_config); }, @@ -150,6 +150,26 @@ auto create_pybind_empty_like_overload() { py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}; } +template +auto create_pybind_arange_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const int64_t start, + const int64_t end, + const int64_t step, + const DataType dtype, + const std::optional> device, + const MemoryConfig& memory_config) -> ttnn::Tensor { + return self(start, end, step, dtype, device, memory_config); + }, + py::arg("start") = 0, + py::arg("end"), + py::arg("step") = 1, + py::arg("dtype") = DataType::BFLOAT16, + py::arg("device") = std::nullopt, + py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}; +} + template void bind_full_operation(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( @@ -350,22 +370,8 @@ void bind_arange_operation(py::module& module, const creation_operation_t& opera module, operation, doc, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const int64_t start, - const int64_t end, - const int64_t step, - const DataType dtype, - const std::optional>& device, - const MemoryConfig& memory_config) -> ttnn::Tensor { - return self(start, end, step, dtype, device, memory_config); - }, - py::arg("start") = 0, - py::arg("end"), - py::arg("step") = 1, - py::arg("dtype") = DataType::BFLOAT16, - py::arg("device") = std::nullopt, - py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}); + create_pybind_arange_overload(), + create_pybind_arange_overload()); } template diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 8cd2e3da094..48a360fb3cb 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -919,15 +919,22 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast&>( + &Tensor::to, py::const_), py::arg("device").noconvert(), py::arg("mem_config").noconvert() = MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED}, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), py::keep_alive<0, 2>(), R"doc( Move TT Tensor from host device to TT accelerator device. Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device. + ``sub_device_ids`` can be used to specify which specific sub devices to wait on before writing the tensor to device memory. + + If it is not provided, device will stall for all programs of the specified cq to finish before writing the tensor to device memory. + If ``arg1`` is not supplied, default ``MemoryConfig`` with ``interleaved`` set to ``True``. +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ @@ -937,6 +944,10 @@ void pytensor_module(py::module& m_tensor) { +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ | arg1 | MemoryConfig of tensor of TT accelerator device | ttnn.MemoryConfig | | No | +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg2 | CQ ID of TT accelerator device to use | uint8_t | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg3 | Sub device IDs to wait on before writing tensor | List[ttnn.SubDeviceId] | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ .. code-block:: python @@ -950,15 +961,22 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast&>( + &Tensor::to, py::const_), py::arg("mesh_device").noconvert(), py::arg("mem_config").noconvert() = MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED}, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), py::keep_alive<0, 2>(), R"doc( Move TT Tensor from host device to TT accelerator device. Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device. + ``sub_device_ids`` can be used to specify which specific sub devices to wait on before writing the tensor to device memory. + + If it is not provided, device will stall for all programs of the specified cq to finish before writing the tensor to device memory. + If ``arg1`` is not supplied, default ``MemoryConfig`` with ``interleaved`` set to ``True``. +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ @@ -968,6 +986,10 @@ void pytensor_module(py::module& m_tensor) { +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ | arg1 | MemoryConfig of tensor of TT accelerator device | ttnn.MemoryConfig | | No | +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg2 | CQ ID of TT accelerator device to use | uint8_t | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg3 | Sub device IDs to wait before writing tensor | List[ttnn.SubDeviceId] | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ .. code-block:: python @@ -1022,12 +1044,19 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "cpu", - [](const Tensor& self, bool blocking, uint8_t cq_id) { return self.cpu(blocking, cq_id); }, + [](const Tensor& self, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) { + return self.cpu(blocking, cq_id, sub_device_ids); + }, py::arg("blocking") = true, py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), R"doc( Move TT Tensor from TT accelerator device to host device. + ``sub_device_ids`` can be used to specify which specific sub devices to wait on before reading the tensor from device memory. + + If it is not provided, device will stall waiting for all programs of the specified cq to finish before reading the tensor from device memory. + .. code-block:: python tt_tensor = tt_tensor.cpu() diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index e4ab3a5ece1..14aa7085ff5 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -6,8 +6,10 @@ #include +#include "tt_metal/tt_stl/overloaded.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" #include "tt_metal/distributed/mesh_device.hpp" using namespace tt::tt_metal; @@ -21,7 +23,7 @@ std::shared_ptr open_mesh_device( size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, MeshType mesh_type, - const std::pair& offset, + const MeshOffset& offset, const std::vector& physical_device_ids) { auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type); return MeshDevice::create(config, l1_small_size, trace_region_size, num_command_queues, dispatch_core_config); @@ -58,18 +60,20 @@ std::vector get_device_tensors(const ttnn::Tensor& tensor) { TT_THROW("Expected tensor to be on MultiDeviceHostStorage type!"); } -Tensor aggregate_as_tensor(std::vector& tensor_shards) { +Tensor aggregate_as_tensor( + const std::vector& tensor_shards, const tt::tt_metal::DistributedTensorConfig& config) { TT_ASSERT(tensor_shards.size() > 0, "At least one tensor shard must be provided"); + const auto& reference_shard = tensor_shards.at(0); for (const auto& shard : tensor_shards) { - if (shard.storage_type() != tensor_shards.at(0).storage_type()) { + if (shard.storage_type() != reference_shard.storage_type()) { TT_THROW("All tensor shards must have the same storage type"); } } // Based whether the first tensor shard has OwnedBuffer or Device buffer, // we want to use MultiDeviceHostStorage or MultiDeviceStorage - StorageType storage_type = tensor_shards.at(0).storage_type(); - Tile tile = tensor_shards.at(0).get_tensor_spec().tile(); + StorageType storage_type = reference_shard.storage_type(); + Tile tile = reference_shard.get_tensor_spec().tile(); if (storage_type == StorageType::OWNED) { std::vector shapes; std::vector host_owned_buffers; @@ -81,7 +85,7 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) { TT_THROW( "Error aggregating multichip tensors: Attempting to aggregate tensors with different tiling " "configurations. Device {} has tiling ({}x{}) while device {} has tiling {}x{}.", - tensor_shards.at(0).device()->id(), + reference_shard.device()->id(), tile.get_height(), tile.get_width(), shard.device()->id(), @@ -89,12 +93,12 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) { shard_tile.get_width()); } } - auto storage = MultiDeviceHostStorage{AllGatherTensor(), std::move(host_owned_buffers), shapes}; + auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), shapes}; return Tensor( std::move(storage), - tensor_shards.at(0).get_legacy_shape(), - tensor_shards.at(0).get_dtype(), - tensor_shards.at(0).get_layout(), + reference_shard.get_legacy_shape(), + reference_shard.get_dtype(), + reference_shard.get_layout(), tile); } else { std::vector ordered_device_ids; @@ -111,7 +115,7 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) { TT_THROW( "Error aggregating multichip tensors: Attempting to aggregate tensors with different tiling " "configurations. Device {} has tiling ({}x{}) while device {} has tiling {}x{}.", - tensor_shards.at(0).device()->id(), + reference_shard.device()->id(), tile.get_height(), tile.get_width(), shard.device()->id(), @@ -119,12 +123,12 @@ Tensor aggregate_as_tensor(std::vector& tensor_shards) { shard_tile.get_width()); } } - auto storage = MultiDeviceStorage{AllGatherTensor(), ordered_device_ids, std::move(device_buffers), shapes}; + auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), shapes}; return Tensor( std::move(storage), - tensor_shards.at(0).get_legacy_shape(), - tensor_shards.at(0).get_dtype(), - tensor_shards.at(0).get_layout(), + reference_shard.get_legacy_shape(), + reference_shard.get_dtype(), + reference_shard.get_layout(), tile); } } @@ -140,7 +144,7 @@ std::vector get_t3k_physical_device_ids_ring() { return physical_device_ids; } -std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& mesh_device) { +std::vector get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_device) { // For multi-device tensors, returns the number of workers capped by the number of buffers // Otherwise, returns all available workes from mesh_device. auto get_workers_for_tensor = [&tensor, &mesh_device]() { @@ -151,19 +155,15 @@ std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& } return workers; }; - if (mesh_device.get_view() != nullptr and std::holds_alternative(tensor.get_storage())) { const auto& host_storage = std::get(tensor.get_storage()); return std::visit( - [&](const auto& strategy) { - using StrategyType = std::decay_t; - if constexpr (std::is_same_v) { - return mesh_device.get_view()->get_devices(strategy.shard_mesh); - } else { - return get_workers_for_tensor(); - } - }, + tt::stl::overloaded{ + [&](const ShardTensor2D& s) { + return mesh_device.get_view()->get_devices(MeshShape{s.shard_mesh.y, s.shard_mesh.x}); + }, + [&](const auto&) { return get_workers_for_tensor(); }}, host_storage.strategy); } else if (std::holds_alternative(tensor.get_storage())) { return tensor.workers; diff --git a/ttnn/cpp/ttnn/distributed/api.hpp b/ttnn/cpp/ttnn/distributed/api.hpp index cfdd86bba50..23a914a02c9 100644 --- a/ttnn/cpp/ttnn/distributed/api.hpp +++ b/ttnn/cpp/ttnn/distributed/api.hpp @@ -7,6 +7,7 @@ #include #include "ttnn/tensor/tensor.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/distributed/types.hpp" namespace ttnn::distributed::api { @@ -18,19 +19,22 @@ std::shared_ptr open_mesh_device( size_t num_command_queues, const tt::tt_metal::DispatchCoreConfig& dispatch_core_config, MeshType mesh_type = MeshType::RowMajor, - const std::pair& offset = std::pair(0, 0), + const MeshOffset& offset = MeshOffset(0, 0), const std::vector& physical_device_ids = {}); void close_mesh_device(const std::shared_ptr& mesh_device); +// Given a multi-device tensor, returns a list of individual per-device tensors. std::vector get_device_tensors(const ttnn::Tensor& tensor); -Tensor aggregate_as_tensor(std::vector& tensor_shards); +// Given a list of per-device shards, returns multi-device tensor. +Tensor aggregate_as_tensor( + const std::vector& tensor_shards, const tt::tt_metal::DistributedTensorConfig& config); std::vector get_t3k_physical_device_ids_ring(); // Maps a tensor to the set of devices in the device-mesh that the shards will be distributed across. -std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& mesh_device); +std::vector get_mapped_devices(const Tensor& tensor, MeshDevice& mesh_device); // Get the distributed tensor config from a tensor. tt::tt_metal::DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor); diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index ec7e8e4691f..ed946f23d9b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/distributed/distributed_pybind.hpp" +#include #include "ttnn/distributed/api.hpp" #include "ttnn/tensor/tensor_utils.hpp" @@ -20,6 +21,8 @@ namespace py = pybind11; void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); + py::class_(module, "MeshShape", "Struct representing the shape of a mesh device."); + py::class_(module, "MeshOffset", "Struct representing the offset of a mesh device."); } void py_module(py::module& module) { @@ -28,6 +31,37 @@ void py_module(py::module& module) { .value("Ring", MeshType::Ring) .value("Line", MeshType::Line) .export_values(); + + static_cast>(module.attr("MeshShape")) + .def( + py::init([](size_t num_rows, size_t num_cols) { return MeshShape(num_rows, num_cols); }), + "Constructor with specified number of rows and columns.", + py::arg("num_rows"), + py::arg("num_cols")) + .def_readwrite("num_rows", &MeshShape::num_rows, "Number of rows in the mesh.") + .def_readwrite("num_cols", &MeshShape::num_cols, "Number of columns in the mesh.") + .def( + "__repr__", + [](const MeshShape& ms) { + return ""; + }) + .def("__iter__", [](const MeshShape& ms) { return py::iter(py::make_tuple(ms.num_rows, ms.num_cols)); }); + static_cast>(module.attr("MeshOffset")) + .def( + py::init([](size_t row, size_t col) { return MeshOffset(row, col); }), + "Constructor with specified row and column offsets.", + py::arg("row"), + py::arg("col")) + .def_readwrite("row", &MeshOffset::row, "Row offset in the mesh.") + .def_readwrite("col", &MeshOffset::col, "Column offset in the mesh.") + .def( + "__repr__", + [](const MeshOffset& mo) { + return ""; + }) + .def("__iter__", [](const MeshOffset& mo) { return py::iter(py::make_tuple(mo.row, mo.col)); }); + auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( @@ -36,7 +70,7 @@ void py_module(py::module& module) { size_t trace_region_size, size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, - const std::pair& offset, + const MeshOffset& offset, const std::vector& physical_device_ids, MeshType mesh_type) { return MeshDevice::create( @@ -134,7 +168,10 @@ void py_module(py::module& module) { R"doc( Disable program cache across all devices in the mesh. )doc") - .def_property_readonly("shape", &MeshDevice::shape, R"doc( + .def_property_readonly( + "shape", + &MeshDevice::shape, + R"doc( Get the shape of the device mesh. Returns: @@ -153,6 +190,26 @@ void py_module(py::module& module) { Args: sub_devices (List[ttnn.SubDevice]): The sub-devices to include in the sub-device manager. + This configuration will be used for each device in the MeshDevice. + local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. + + Returns: + MeshSubDeviceManagerId: The ID of the created sub-device manager. + )doc") + .def( + "create_sub_device_manager", + [](MeshDevice& self, + const std::vector>& mesh_sub_devices, + DeviceAddr local_l1_size) { return self.create_sub_device_manager(mesh_sub_devices, local_l1_size); }, + py::arg("sub_devices"), + py::arg("local_l1_size"), + R"doc( + Creates a sub-device manager for the given mesh device. + + Args: + mesh_sub_devices (List[List[ttnn.SubDevice]]): The sub-devices to include in the sub-device manager. + Each element of the outer list will be used to configure the corresponding device in the MeshDevice. + This means that the individual devices in the MeshDevice may have different configurations. local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. Returns: @@ -193,7 +250,6 @@ void py_module(py::module& module) { py::arg("l1_small_size"), py::arg("trace_region_size"), py::arg("num_command_queues"), - py::arg("offset"), py::arg("physical_device_ids"), py::arg("mesh_type"), @@ -233,7 +289,11 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); - module.def("aggregate_as_tensor", &aggregate_as_tensor, py::arg("tensors"), py::kw_only()); + module.def( + "aggregate_as_tensor", + [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, + py::arg("tensors"), + py::kw_only()); module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp new file mode 100644 index 00000000000..9ae22852fd5 --- /dev/null +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.cpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "common/assert.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" + +namespace tt::tt_metal { +namespace { + +DistributedTensorConfig create_shard_distributed_tensor_config( + const std::unordered_map& metadata) { + return ShardTensor(std::stoi(metadata.at("shard_dim"))); +} +DistributedTensorConfig create_shard_2d_distributed_tensor_config( + const std::unordered_map& metadata) { + return ShardTensor2D(ShardMesh(std::stoi(metadata.at("mesh_shape_y")), std::stoi(metadata.at("mesh_shape_x")))); +} +DistributedTensorConfig create_replicate_distributed_tensor_config( + const std::unordered_map& metadata) { + if (auto it = metadata.find("replication_factor"); it != metadata.end()) { + return ReplicateTensor(std::stoi(it->second)); + } + TT_THROW("Unsupported Replication strategy:"); +} +} // namespace + +DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map& metadata) { + if (auto it = metadata.find("strategy"); it != metadata.end()) { + const std::string& strategy = it->second; + if (strategy == "shard") { + return create_shard_distributed_tensor_config(metadata); + } else if (strategy == "shard_2d") { + return create_shard_2d_distributed_tensor_config(metadata); + } else if (strategy == "replicate") { + return create_replicate_distributed_tensor_config(metadata); + } + } + TT_THROW("Unsupported DistributedTensorConfig strategy:"); +} + +bool operator==(const ReplicateTensor& a, const ReplicateTensor& b) { + return a.replication_factor == b.replication_factor; +} +bool operator==(const AllGatherTensor&, const AllGatherTensor&) { + // All instances are considered equal because there are no data members. + return true; +} +bool operator==(const ShardTensor& lhs, const ShardTensor& rhs) { return lhs.shard_dimension == rhs.shard_dimension; } +bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs) { + return lhs.shard_mesh.x == rhs.shard_mesh.x && lhs.shard_mesh.y == rhs.shard_mesh.y; +} + +} // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp new file mode 100644 index 00000000000..5f67262028e --- /dev/null +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor_config.hpp @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +namespace tt::tt_metal { + +struct ReplicateTensor { + int replication_factor = 1; + ReplicateTensor() = default; + ReplicateTensor(int replication_factor) : replication_factor(replication_factor) {} +}; +bool operator==(const ReplicateTensor&, const ReplicateTensor&); +struct ShardTensor { + int shard_dimension; + ShardTensor(int shard_dimension) : shard_dimension(shard_dimension) {} +}; +bool operator==(const ShardTensor& lhs, const ShardTensor& rhs); + +struct ShardMesh { + std::uint16_t y = 0; + std::uint16_t x = 0; +}; + +struct ShardTensor2D { + ShardMesh shard_mesh; // logic 2D grid that defines the mapping of shards to devices + ShardTensor2D(ShardMesh mesh) : shard_mesh(std::move(mesh)) {} +}; +bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs); + +struct AllGatherTensor {}; +bool operator==(const AllGatherTensor&, const AllGatherTensor&); + +// DistributedTensorConfig is a variant of different ways in which a tensor can be distributed across devices. +using DistributedTensorConfig = std::variant; +DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map& metadata); + +} // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/distributed/types.hpp b/ttnn/cpp/ttnn/distributed/types.hpp index 557d10c90ec..bdb17e71ac1 100644 --- a/ttnn/cpp/ttnn/distributed/types.hpp +++ b/ttnn/cpp/ttnn/distributed/types.hpp @@ -13,6 +13,7 @@ namespace ttnn::distributed { using MeshShape = tt::tt_metal::distributed::MeshShape; +using MeshOffset = tt::tt_metal::distributed::MeshOffset; using DeviceIds = tt::tt_metal::distributed::DeviceIds; using MeshDevice = tt::tt_metal::distributed::MeshDevice; using MeshDeviceView = tt::tt_metal::distributed::MeshDeviceView; @@ -29,6 +30,7 @@ using ttnn::distributed::DeviceIds; using ttnn::distributed::MeshDevice; using ttnn::distributed::MeshDeviceConfig; using ttnn::distributed::MeshDeviceView; +using ttnn::distributed::MeshOffset; using ttnn::distributed::MeshShape; using ttnn::distributed::MeshSubDeviceManagerId; using ttnn::distributed::MeshType; diff --git a/ttnn/cpp/ttnn/events.cpp b/ttnn/cpp/ttnn/events.cpp index 789cdd36c6e..a38da21fccb 100644 --- a/ttnn/cpp/ttnn/events.cpp +++ b/ttnn/cpp/ttnn/events.cpp @@ -29,9 +29,11 @@ std::shared_ptr create_event(Device* device) { return event; } -void record_event(uint8_t cq_id, const std::shared_ptr& event) { +void record_event(uint8_t cq_id, const std::shared_ptr& event, const std::vector& sub_device_ids) { Device* device = event->device; - device->push_work([device, event, cq_id] { EnqueueRecordEvent(device->command_queue(cq_id), event); }); + device->push_work([device, event, cq_id, sub_device_ids] { + EnqueueRecordEvent(device->command_queue(cq_id), event, sub_device_ids); + }); } void wait_for_event(uint8_t cq_id, const std::shared_ptr& event) { @@ -41,9 +43,10 @@ void wait_for_event(uint8_t cq_id, const std::shared_ptr& event) { MultiDeviceEvent create_event(MeshDevice* mesh_device) { return MultiDeviceEvent(mesh_device); } -void record_event(uint8_t cq_id, const MultiDeviceEvent& multi_device_event) { +void record_event( + uint8_t cq_id, const MultiDeviceEvent& multi_device_event, const std::vector& sub_device_ids) { for (auto& event : multi_device_event.events) { - record_event(cq_id, event); + record_event(cq_id, event, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/events.hpp b/ttnn/cpp/ttnn/events.hpp index 57405fa9526..d4c409338c6 100644 --- a/ttnn/cpp/ttnn/events.hpp +++ b/ttnn/cpp/ttnn/events.hpp @@ -16,11 +16,12 @@ struct MultiDeviceEvent { }; // Single Device APIs std::shared_ptr create_event(Device* device); -void record_event(uint8_t cq_id, const std::shared_ptr& event); +void record_event( + uint8_t cq_id, const std::shared_ptr& event, const std::vector& sub_device_ids = {}); void wait_for_event(uint8_t cq_id, const std::shared_ptr& event); // Multi Device APIs MultiDeviceEvent create_event(MeshDevice* mesh_device); -void record_event(uint8_t cq_id, const MultiDeviceEvent& event); +void record_event(uint8_t cq_id, const MultiDeviceEvent& event, const std::vector& sub_device_ids = {}); void wait_for_event(uint8_t cq_id, const MultiDeviceEvent& event); } // namespace ttnn::events diff --git a/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt new file mode 100644 index 00000000000..148d928be91 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt @@ -0,0 +1,26 @@ +set(CCL_TTNN_SRCS + # Common + ${CMAKE_CURRENT_SOURCE_DIR}/erisc_datamover_builder.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl_op_fusion.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl_host_datastructures.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/types/ccl_types_args_emitters.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_command.cpp + # CCL Ops + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/all_gather.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/all_gather_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/device/all_gather_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather/device/multi_core/all_gather_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter/device/reduce_scatter_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter/reduce_scatter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter/reduce_scatter_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter/host/reduce_scatter_worker_builder.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter/host/reduce_scatter_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/barrier/device/host/barrier_full_worker_grid.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/barrier/device/barrier_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/barrier/barrier.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/barrier/barrier_pybind.cpp + CACHE INTERNAL + "CCL sources to reuse in ttnn build" +) diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp index f69424cdd0d..bf330aca910 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -9,7 +9,7 @@ #include #include "eth_l1_address_map.h" -#include "umd/device/tt_cluster_descriptor_types.h" +#include "umd/device/types/cluster_descriptor_types.h" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 5f9ba6f0ea9..d6d06ec490f 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -17,10 +17,6 @@ #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/operations/pool/downsample/device/downsample_op.hpp" -#include "tt_metal/detail/reports/memory_reporter.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" -#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" #include "ttnn/operations/sliding_window/sliding_window.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" @@ -54,6 +50,7 @@ Result conv2d( uint32_t groups, std::optional bias_tensor, const std::optional& conv_config_, + const std::optional& compute_config_, const std::optional& memory_config) { const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; @@ -89,6 +86,14 @@ Result conv2d( (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; + DeviceComputeKernelConfig compute_config = compute_config_.value_or( init_device_compute_kernel_config( + device->arch(), + std::nullopt, + MathFidelity::HiFi4, + true, + false, + false + )); auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( device, input_tensor, conv_config, batch_size, output_height, output_width, in_channels, out_channels, mm_conv, is_non_tile_mul_width); if (tensor_manipulated) { @@ -138,7 +143,7 @@ Result conv2d( conv_config.act_block_w_div, kernel_size[0], kernel_size[1], - conv_config.fp32_dest_acc_enabled, + get_fp32_dest_acc_en(compute_config), conv_config.enable_split_reader); bool weight_is_on_device = ttnn::is_tensor_on_device_or_multidevice(weight_tensor); ttnn::Tensor weight_tensor_on_device = weight_tensor; @@ -173,13 +178,6 @@ Result conv2d( // call optimized conv op or matmul micro op bool input_is_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_post_tm); TT_ASSERT(input_is_on_device); - DeviceComputeKernelConfig compute_kernel_config = ttnn::init_device_compute_kernel_config( - device->arch(), - std::nullopt, - conv_config.math_fidelity, - conv_config.math_approx_mode_enabled, - conv_config.fp32_dest_acc_enabled, - conv_config.packer_l1_accum_enabled); if (!mm_conv) { // call halo op @@ -238,14 +236,13 @@ Result conv2d( groups, conv_config.output_layout == Layout::ROW_MAJOR, conv_config.activation == "relu", - conv_config.math_fidelity, opt_conv_op_parallel_config, opt_conv_op_block_config, conv_out_memory_config, conv_config.dtype, {batch_size, input_height, input_width, in_channels}, conv_config.input_channels_alignment == 16, - compute_kernel_config, + compute_config, conv_config.enable_act_double_buffer, conv_config.enable_weights_double_buffer, conv_config.enable_split_reader, @@ -284,7 +281,7 @@ Result conv2d( /*bcast_batch=*/std::nullopt, conv_out_memory_config, conv_config.dtype, - compute_kernel_config}); + compute_config}); if (conv_config.deallocate_activation) { ttnn::operations::core::deallocate(matmul_input); } @@ -314,8 +311,9 @@ Result Conv2dOperation::invoke( uint32_t groups, std::optional bias_tensor, const std::optional& conv_config_, + const std::optional& compute_config_, const std::optional& memory_config){ - return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), memory_config); + return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), memory_config); } Result Conv2dOperation::invoke( @@ -335,10 +333,12 @@ Result Conv2dOperation::invoke( uint32_t groups, std::optional bias_tensor, const std::optional& conv_config_, + const std::optional& compute_config_, const std::optional& memory_config){ - return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), memory_config); + return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), memory_config); } + } // namespace conv2d } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index d15023abb86..e8310c0dbdc 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -47,6 +47,7 @@ Result conv2d( uint32_t groups, std::optional bias_tensor = std::nullopt, const std::optional& conv_config_ = std::nullopt, + const std::optional& compute_config_ = std::nullopt, const std::optional& memory_config = std::nullopt); @@ -68,6 +69,7 @@ struct Conv2dOperation{ uint32_t groups, std::optional bias_tensor = std::nullopt, const std::optional& conv_config_ = std::nullopt, + const std::optional& compute_config_ = std::nullopt, const std::optional& memory_config = std::nullopt); static Result invoke( @@ -87,6 +89,7 @@ struct Conv2dOperation{ uint32_t groups, std::optional bias_tensor = std::nullopt, const std::optional& conv_config_ = std::nullopt, + const std::optional& compute_config_ = std::nullopt, const std::optional& memory_config = std::nullopt); }; } // namespace conv2d diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index 6ac28cf56ca..c3356447cab 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -60,10 +60,11 @@ void py_bind_conv2d(py::module& module) { std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config, + const std::optional& conv_config, + const std::optional& compute_config, const std::optional& memory_config, const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, memory_config); + return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config); }, py::kw_only(), py::arg("input_tensor"), @@ -81,6 +82,7 @@ void py_bind_conv2d(py::module& module) { py::arg("groups"), py::arg("bias_tensor") = std::nullopt, py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("queue_id") = 0}, @@ -99,10 +101,11 @@ void py_bind_conv2d(py::module& module) { std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config, + const std::optional& conv_config, + const std::optional& compute_config, const std::optional& memory_config, const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, memory_config); + return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config); }, py::kw_only(), py::arg("input_tensor"), @@ -120,6 +123,7 @@ void py_bind_conv2d(py::module& module) { py::arg("groups"), py::arg("bias_tensor") = std::nullopt, py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("queue_id") = 0} ); @@ -143,7 +147,8 @@ void py_bind_conv2d(py::module& module) { py::arg("dilation"), py::arg("groups"), py::arg("device"), - py::arg("conv_config") = std::nullopt); + py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt); module.def( @@ -165,7 +170,8 @@ void py_bind_conv2d(py::module& module) { py::arg("dilation"), py::arg("groups"), py::arg("device"), - py::arg("conv_config") = std::nullopt); + py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt); module.def( "prepare_conv_bias", @@ -185,7 +191,8 @@ void py_bind_conv2d(py::module& module) { py::arg("dilation"), py::arg("groups"), py::arg("device"), - py::arg("conv_config") = std::nullopt); + py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt); module.def( "prepare_conv_bias", @@ -205,7 +212,8 @@ void py_bind_conv2d(py::module& module) { py::arg("dilation"), py::arg("groups"), py::arg("device"), - py::arg("conv_config") = std::nullopt); + py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt); module.def( "convert_conv_weight_tensor_to_tiled_layout", @@ -266,14 +274,10 @@ void py_bind_conv2d(py::module& module) { auto py_conv_config = py::class_(module, "Conv2dConfig"); py_conv_config.def( - py::init, std::optional, bool, Layout, bool, bool, bool, bool>(), + py::init, std::optional, bool, Layout, bool, bool, bool, bool>(), py::kw_only(), - py::arg("math_fidelity") = MathFidelity::HiFi4, py::arg("dtype") = DataType::BFLOAT16, py::arg("weights_dtype") = DataType::BFLOAT16, - py::arg("math_approx_mode_enabled") = true, - py::arg("fp32_dest_acc_enabled") = false, - py::arg("packer_l1_accum_enabled") = false, py::arg("activation") = "", py::arg("input_channels_alignment") = 32, py::arg("deallocate_activation") = false, @@ -291,12 +295,8 @@ void py_bind_conv2d(py::module& module) { py::arg("enable_split_reader") = false, py::arg("enable_subblock_padding") = false ); - py_conv_config.def_readwrite("math_fidelity", &Conv2dConfig::math_fidelity); py_conv_config.def_readwrite("dtype", &Conv2dConfig::dtype); py_conv_config.def_readwrite("weights_dtype", &Conv2dConfig::weights_dtype); - py_conv_config.def_readwrite("math_approx_mode_enabled", &Conv2dConfig::math_approx_mode_enabled); - py_conv_config.def_readwrite("fp32_dest_acc_enabled", &Conv2dConfig::fp32_dest_acc_enabled); - py_conv_config.def_readwrite("packer_l1_accum_enabled", &Conv2dConfig::packer_l1_accum_enabled); py_conv_config.def_readwrite("activation", &Conv2dConfig::activation); py_conv_config.def_readwrite("input_channels_alignment", &Conv2dConfig::input_channels_alignment); py_conv_config.def_readwrite("deallocate_activation", &Conv2dConfig::deallocate_activation); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index 9b9645f821f..349e3837329 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -30,12 +30,8 @@ using OutputWidth = uint32_t; using Result = std::tuple>; struct Conv2dConfig { - MathFidelity math_fidelity = MathFidelity::HiFi4; DataType dtype = DataType::BFLOAT16; DataType weights_dtype = DataType::BFLOAT16; - bool math_approx_mode_enabled = true; - bool fp32_dest_acc_enabled = false; - bool packer_l1_accum_enabled = false; string activation = ""; uint32_t input_channels_alignment = 32; bool deallocate_activation = false; @@ -54,12 +50,8 @@ struct Conv2dConfig { bool enable_split_reader = false; bool enable_subblock_padding = false; static constexpr auto attribute_names = std::make_tuple( - "math_fidelity", "dtype", "weights_dtype", - "math_approx_mode_enabled", - "fp32_dest_acc_enabled", - "packer_l1_accum_enabled", "activation", "input_channels_alignment", "deallocate_activation", @@ -78,12 +70,8 @@ struct Conv2dConfig { "enable_subblock_padding"); const auto attribute_values() const { return std::make_tuple( - std::cref(this->math_fidelity), std::cref(this->dtype), std::cref(this->weights_dtype), - std::cref(this->math_approx_mode_enabled), - std::cref(this->fp32_dest_acc_enabled), - std::cref(this->packer_l1_accum_enabled), std::cref(this->activation), std::cref(this->input_channels_alignment), std::cref(this->deallocate_activation), diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index 9d57c98db84..e09aa621dd5 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -17,6 +17,7 @@ #include "tt_metal/tt_stl/reflection.hpp" #include "tt_metal/common/work_split.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/sharding_utilities.hpp" #include "ttnn/operations/experimental/auto_format/auto_format.hpp" @@ -57,14 +58,14 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional input_tensor_shape, bool use_shallow_conv_variant, - std::optional compute_kernel_config, + const DeviceComputeKernelConfig& compute_kernel_config, bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, @@ -73,7 +74,7 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))}; operation::launch_op( - [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height] + [sliding_window_config, output_channels, groups, untilize_out, fuse_relu, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { using ttnn::operations::experimental::auto_format::FormatParams; auto& a = input_tensors.at(0); @@ -91,9 +92,8 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optionalarch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); bool fp32_accum = a.device()->arch() == tt::ARCH::WORMHOLE_B0; // && compute_kernel_config.has_value()) ? compute_kernel_config.value().fp32_dest_acc_en : false; - auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::LoFi, true, fp32_accum, false); return operation::run_without_autoformat( - OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, math_fidelity, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, kernel_config_val, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height + OptimizedConvNew(sliding_window_config, output_channels, groups, untilize_out, bias.has_value(), fuse_relu, parallelization_config, block_config, memory_config, dtype, input_tensor_shape, use_shallow_conv_variant, compute_kernel_config, enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, enable_subblock_padding, use_non_tile_height ), input_tensors, optional_input_tensors); @@ -219,7 +219,7 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vect sliding_window_config, output_channels, groups, - untilize_out, fuse_relu, math_fidelity, + untilize_out, fuse_relu, parallelization_config, block_config, dtype, @@ -265,7 +265,7 @@ operation::OpPerformanceModel OptimizedConvNew::create_op_performance_model(cons int64_t num_mul_adds_per_elem = conv_activation_c * filter_h * filter_w * 2; // 1 multiply and 1 add per element int64_t num_mul_adds = num_mul_adds_per_elem * output_height * output_width * this->output_channels * batch_size; - int ideal_dev_clock_cycles = std::ceil(((float)num_mul_adds / (float)(num_cores * tensix_mul_adds_per_cycle_lofi)) * (float)operation::OpPerformanceModel::fidelity_multiplier(this->math_fidelity)); + int ideal_dev_clock_cycles = std::ceil(((float)num_mul_adds / (float)(num_cores * tensix_mul_adds_per_cycle_lofi)) * (float)operation::OpPerformanceModel::fidelity_multiplier(get_math_fidelity(this->compute_kernel_config))); operation::OpPerformanceModel result(input_tensors, output_tensors, ideal_dev_clock_cycles); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp index 830ca917e33..a39e97f4fac 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp @@ -47,7 +47,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new(const T const sliding_window::SlidingWindowConfig& sliding_window_config, uint32_t output_channels, uint32_t groups, - bool untilize_out, bool fuse_relu, MathFidelity math_fidelity, + bool untilize_out, bool fuse_relu, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, DataType dtype, @@ -69,7 +69,6 @@ struct OptimizedConvNew { const uint32_t output_channels; const uint32_t groups; bool untilize_out, has_bias, fuse_relu; - MathFidelity math_fidelity; MemoryConfig memory_config; const DataType dtype; std::array input_tensor_shape; // For sharded input, input tensor shape is nonsense @@ -84,7 +83,7 @@ struct OptimizedConvNew { uint32_t output_channels, uint32_t groups, bool untile_out, bool has_bias, bool fuse_relu, - MathFidelity mfidelity, const OptimizedConvParallelizationConfig& p_config, + const OptimizedConvParallelizationConfig& p_config, const OptimizedConvBlockConfig& b_config, MemoryConfig memory_config, DataType dtype, @@ -96,7 +95,6 @@ struct OptimizedConvNew { untilize_out(untile_out), has_bias(has_bias), fuse_relu(fuse_relu), - math_fidelity(mfidelity), parallelization_config(p_config), block_config(b_config), memory_config(memory_config), @@ -124,7 +122,6 @@ struct OptimizedConvNew { "untilize_out", "has_bias", "fuse_relu", - "math_fidelity", "dtype", "input_tensor_shape", "use_shallow_conv_variant", @@ -141,7 +138,6 @@ struct OptimizedConvNew { std::cref(this->untilize_out), std::cref(this->has_bias), std::cref(this->fuse_relu), - std::cref(this->math_fidelity), std::cref(this->dtype), std::cref(this->input_tensor_shape), std::cref(this->use_shallow_conv_variant), @@ -156,14 +152,14 @@ Tensor optimized_conv_new(const Tensor& a, const Tensor &b, std::optional input_tensor_shape, bool use_shallow_conv_variant, - std::optional compute_kernel_config = std::nullopt, + const DeviceComputeKernelConfig& compute_kernel_config, bool enable_act_double_buffer = false, bool enable_weights_double_buffer = false, bool enable_split_reader = false, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index 0b452a583df..7c0544a8c69 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -1793,7 +1793,6 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( uint32_t groups, bool untilize_out, bool fuse_relu, - MathFidelity math_fidelity, const OptimizedConvParallelizationConfig& parallelization_config, const OptimizedConvBlockConfig& block_config, DataType output_dtype, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 668372c49a4..1009ed7a87b 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -4,6 +4,7 @@ #include "prepare_conv2d_weights.hpp" #include "conv2d_utils.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include #include @@ -67,6 +68,7 @@ OptimizedConvBlockConfig get_opt_block_config( T *device, Conv2dConfig& conv_config, Layout input_tensor_layout, + const DeviceComputeKernelConfig& compute_config, const MemoryConfig& input_memory_config) { auto compute_grid_size = device->compute_with_storage_grid_size(); @@ -138,7 +140,7 @@ OptimizedConvBlockConfig get_opt_block_config( conv_config.act_block_w_div, kernel_size[0], kernel_size[1], - conv_config.fp32_dest_acc_enabled, + get_fp32_dest_acc_en(compute_config), conv_config.enable_split_reader); } @@ -289,9 +291,11 @@ ttnn::Tensor prepare_conv_weights( std::array dilation, uint32_t groups, T *device, - const std::optional& conv_config_) { + const std::optional& conv_config_, + const std::optional& compute_config_) { TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(weight_tensor), "Error: weight tensor must be on host for preparation."); Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); + DeviceComputeKernelConfig compute_config = compute_config_.value_or(DeviceComputeKernelConfig()); const bool mm_conv = use_matmul_for_1x1_conv(kernel_size, stride, padding, dilation, groups); const uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1; const uint32_t output_width = @@ -309,6 +313,7 @@ ttnn::Tensor prepare_conv_weights( device, conv_config, input_tensor_layout, + compute_config, input_memory_config ); @@ -366,7 +371,8 @@ ttnn::Tensor prepare_conv_bias( std::array dilation, uint32_t groups, T *device, - const std::optional& conv_config_) { + const std::optional& conv_config_, + const std::optional& compute_config_) { TT_FATAL(!ttnn::is_tensor_on_device_or_multidevice(bias_tensor), "Error: bias tensor must be on host for preparation."); @@ -376,6 +382,7 @@ ttnn::Tensor prepare_conv_bias( ((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1; Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); + DeviceComputeKernelConfig compute_config = compute_config_.value_or(DeviceComputeKernelConfig()); auto opt_conv_op_block_config = get_opt_block_config( mm_conv, in_channels, @@ -389,6 +396,7 @@ ttnn::Tensor prepare_conv_bias( device, conv_config, input_tensor_layout, + compute_config, input_memory_config ); @@ -423,6 +431,7 @@ template OptimizedConvBlockConfig get_opt_block_config( Device *device, Conv2dConfig& conv_config, Layout input_tensor_layout, + const DeviceComputeKernelConfig& compute_config, const ttnn::MemoryConfig& input_memory_config); template OptimizedConvBlockConfig get_opt_block_config( @@ -438,6 +447,7 @@ template OptimizedConvBlockConfig get_opt_block_config( MeshDevice *device, Conv2dConfig& conv_config, Layout input_tensor_layout, + const DeviceComputeKernelConfig& compute_config, const ttnn::MemoryConfig& input_memory_config); template ttnn::Tensor prepare_conv_weights( @@ -456,7 +466,8 @@ template ttnn::Tensor prepare_conv_weights( std::array dilation, uint32_t groups, Device *device, - const std::optional& conv_config_); + const std::optional& conv_config_, + const std::optional& compute_config_); template ttnn::Tensor prepare_conv_weights( const ttnn::Tensor& weight_tensor, @@ -474,7 +485,8 @@ template ttnn::Tensor prepare_conv_weights( std::array dilation, uint32_t groups, MeshDevice *device, - const std::optional& conv_config_); + const std::optional& conv_config_, + const std::optional& compute_config_); template std::pair> prepare_conv_weights_biases_and_move_to_device( const ttnn::Tensor& weight_tensor, @@ -521,7 +533,8 @@ template ttnn::Tensor prepare_conv_bias( std::array dilation, uint32_t groups, Device *device, - const std::optional& conv_config_); + const std::optional& conv_config_, + const std::optional& compute_config_); template ttnn::Tensor prepare_conv_bias( const ttnn::Tensor& bias_tensor, @@ -538,7 +551,8 @@ template ttnn::Tensor prepare_conv_bias( std::array dilation, uint32_t groups, MeshDevice *device, - const std::optional& conv_config_); + const std::optional& conv_config_, + const std::optional& compute_config_); } // namespace conv2d } // namespace operations diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp index 18e654ad37c..35b80dac824 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp @@ -43,7 +43,8 @@ ttnn::Tensor prepare_conv_weights( std::array dilation, uint32_t groups, T *device, - const std::optional& conv_config_); + const std::optional& conv_config_, + const std::optional& compute_config_); template ttnn::Tensor prepare_conv_bias( @@ -61,7 +62,8 @@ ttnn::Tensor prepare_conv_bias( std::array dilation, uint32_t groups, T *device, - const std::optional& conv_config_); + const std::optional& conv_config_, + const std::optional& compute_config_); template std::pair> prepare_conv_weights_biases_and_move_to_device( diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp index e2c54193bb0..21af1f921fb 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp @@ -107,8 +107,11 @@ Result conv_transpose2d( uint32_t groups, std::optional bias_tensor, const std::optional& conv_config_, + const std::optional& compute_config_, const std::optional& memory_config ) { Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig()); + DeviceComputeKernelConfig compute_config = compute_config_.value_or(DeviceComputeKernelConfig()); + //Inverse of sliding_window.get_output_shape() SlidingWindowConfig sliding_window_config = SlidingWindowConfig{ @@ -174,32 +177,6 @@ Result conv_transpose2d( ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config()) : std::nullopt); } - DeviceComputeKernelConfig compute_kernel_config; - switch (device->arch()) { - case tt::ARCH::WORMHOLE_B0: - compute_kernel_config = WormholeComputeKernelConfig( - {.math_fidelity = conv_config.math_fidelity, - .math_approx_mode = conv_config.math_approx_mode_enabled, - .fp32_dest_acc_en = conv_config.fp32_dest_acc_enabled, - .packer_l1_acc = conv_config.packer_l1_accum_enabled}); - break; - - case tt::ARCH::GRAYSKULL: - compute_kernel_config = GrayskullComputeKernelConfig( - {.math_fidelity = conv_config.math_fidelity, .math_approx_mode = conv_config.math_approx_mode_enabled}); - break; - - case tt::ARCH::BLACKHOLE: - compute_kernel_config = BlackholeComputeKernelConfig( - {.math_fidelity = conv_config.math_fidelity, - .math_approx_mode = conv_config.math_approx_mode_enabled, - .fp32_dest_acc_en = conv_config.fp32_dest_acc_enabled, - .packer_l1_acc = conv_config.packer_l1_accum_enabled}); - break; - - default: - TT_THROW("Invalid Device Arch, Got {}",device->arch()); - } //Call Halo Transpose auto [input_tensor_post_tm, parallel_config, output_parallel_config, tensor_manipulated, use_non_tile_height] = shard_or_reshard_tensor_if_required( @@ -239,6 +216,16 @@ Result conv_transpose2d( 0, input_tensor_post_tm.memory_config()); + if(conv_config.deallocate_activation) { + input_tensor_post_tm.deallocate(); + log_debug(tt::LogOp, "Deallocate Input Tensor"); + } + if (conv_config.reallocate_halo_output) { + auto move_output = ttnn::operations::core::reallocate(halo_output, halo_output.memory_config()); + halo_output = move_output; + log_debug(tt::LogOp, "Reallocate Halo Output"); + } + //Call Conv2d u_op with Stride = 1, Padding = 0. auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape(std::array{1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}), @@ -266,7 +253,7 @@ Result conv_transpose2d( conv_config.act_block_w_div, kernel_size[0], kernel_size[1], - conv_config.fp32_dest_acc_enabled, + get_fp32_dest_acc_en(compute_config), conv_config.enable_split_reader); //TODO: Flip the Weights @@ -300,7 +287,7 @@ Result conv_transpose2d( parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, num_cores_c); Tensor matmul_input = ttnn::to_layout( - input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device + halo_output, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device ); auto matmul_output = ttnn::operations::matmul::matmul( matmul_input, @@ -311,7 +298,7 @@ Result conv_transpose2d( /*bcast_batch=*/std::nullopt, conv_out_memory_config, conv_config.dtype, - compute_kernel_config}); + compute_config}); if (conv_config.deallocate_activation) { ttnn::operations::core::deallocate(matmul_input); } @@ -332,14 +319,13 @@ Result conv_transpose2d( groups, conv_config.output_layout == Layout::ROW_MAJOR, conv_config.activation == "relu", - conv_config.math_fidelity, opt_conv_op_parallel_config, opt_conv_op_block_config, conv_out_memory_config, conv_config.dtype, {batch_size, input_height, input_width, in_channels}, conv_config.input_channels_alignment == 16, - compute_kernel_config, + compute_config, conv_config.enable_act_double_buffer, conv_config.enable_split_reader, conv_config.enable_subblock_padding); @@ -367,8 +353,9 @@ Result ConvTranpose2dOperation::invoke( uint32_t groups, std::optional bias_tensor, const std::optional& conv_config_, - const std::optional& memory_config ) { - return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(memory_config)); + const std::optional& compute_config_, + const std::optional& memory_config){ + return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), std::move(memory_config)); } Result ConvTranpose2dOperation::invoke( @@ -389,8 +376,9 @@ Result ConvTranpose2dOperation::invoke( uint32_t groups, std::optional bias_tensor, const std::optional& conv_config_, - const std::optional& memory_config ) { - return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(memory_config)); + const std::optional& compute_config_, + const std::optional& memory_config){ + return conv_transpose2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, std::move(bias_tensor), std::move(conv_config_), std::move(compute_config_), std::move(memory_config)); } } diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp index 119db2cf842..fc23a6f52d6 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.hpp @@ -34,6 +34,7 @@ struct ConvTranpose2dOperation{ uint32_t groups, std::optional bias_tensor = std::nullopt, const std::optional& conv_config_ = std::nullopt, + const std::optional& compute_config_ = std::nullopt, const std::optional& memory_config = std::nullopt); static Result invoke( @@ -54,6 +55,7 @@ struct ConvTranpose2dOperation{ uint32_t groups, std::optional bias_tensor = std::nullopt, const std::optional& conv_config_ = std::nullopt, + const std::optional& compute_config_ = std::nullopt, const std::optional& memory_config = std::nullopt); }; diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp index 3cea2a187f9..1e07c21eb42 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d_pybind.cpp @@ -62,6 +62,7 @@ void py_bind_conv_transpose2d(py::module& module) { groups (int): the number of groups for grouped convolution. bias_tensor (ttnn.Tensor, optional): the bias tensor. Defaults to `None`. conv_config (ttnn.Conv2dConfig, optional): the configuration for the convolution operation. Defaults to `None`. + compute_config (ttnn.DeviceComputeKernelConfig, optional): the configuration for the compute kernel. Defaults to `None`. queue_id (int): the queue id to use for the operation. Defaults to `0`. Returns: @@ -84,6 +85,7 @@ void py_bind_conv_transpose2d(py::module& module) { input_height=input_height, input_width=input_width, conv_config=conv_config, + compute_config=compute_config, groups=groups, ) )doc", @@ -103,10 +105,12 @@ void py_bind_conv_transpose2d(py::module& module) { std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config, + const std::optional& conv_config, + const std::optional& compute_config, const std::optional& memory_config, const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, memory_config); + return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config); + }, py::kw_only(), py::arg("input_tensor"), @@ -125,6 +129,7 @@ void py_bind_conv_transpose2d(py::module& module) { py::arg("groups"), py::arg("bias_tensor") = std::nullopt, py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("queue_id") = 0}, @@ -144,10 +149,12 @@ void py_bind_conv_transpose2d(py::module& module) { std::array dilation, uint32_t groups, std::optional bias_tensor, - std::optional conv_config, + const std::optional& conv_config, + const std::optional& compute_config, const std::optional& memory_config, const uint8_t& queue_id) -> Result { - return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, memory_config); + return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, output_padding, dilation, groups, bias_tensor, conv_config, compute_config, memory_config); + }, py::kw_only(), py::arg("input_tensor"), @@ -166,6 +173,7 @@ void py_bind_conv_transpose2d(py::module& module) { py::arg("groups"), py::arg("bias_tensor") = std::nullopt, py::arg("conv_config") = std::nullopt, + py::arg("compute_config") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("queue_id") = 0} ); diff --git a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp index 7a249b45264..90cf3942767 100644 --- a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp +++ b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp @@ -8,8 +8,9 @@ #define DATUMS_PER_ROW 16 -// FIXME: ARCH_NAME specific include -#include "tensix_types.h" // DEST_REGISTER_FULL_SIZE +// This parameter is the same for all supported architectures +// Check this invariant when adding new architectures +#define DEST_REGISTER_FULL_SIZE 64 * 16 namespace ttnn { diff --git a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp index a07175adcf3..f8e0f8869bd 100644 --- a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp +++ b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp @@ -8,7 +8,7 @@ #include #include #include -#include "umd/device/tt_arch_types.h" +#include "umd/device/types/arch.h" #include "tt_metal/common/base_types.hpp" namespace ttnn { diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index 184f6e139f1..90fc3f34908 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -58,25 +58,34 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) { return ttnn::reshape(tensor, shape.to_rank(rank)); } -ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional& memory_config) { +ttnn::Tensor to_device( + const ttnn::Tensor& tensor, + Device* device, + const std::optional& memory_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); if (mem_config.is_sharded() and (device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG); + auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG, cq_id, sub_device_ids); return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); } else { - return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); + return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG), cq_id, sub_device_ids); } } ttnn::Tensor to_device( - const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional& memory_config) { + const ttnn::Tensor& tensor, + MeshDevice* mesh_device, + const std::optional& memory_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); // Currently no direct sharded write support in BLACKHOLE due to alignment issue if (mem_config.is_sharded() and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG); + auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG, cq_id, sub_device_ids); return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); } else { - return tensor.to(mesh_device, mem_config); + return tensor.to(mesh_device, mem_config, cq_id, sub_device_ids); } } @@ -100,17 +109,22 @@ ttnn::Tensor allocate_tensor_on_device( shape, data_type, layout, mesh_device->get_devices(), memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } -void copy_host_to_device_tensor(const ttnn::Tensor& host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id) { - tt::tt_metal::write_tensor(std::move(host_tensor), std::move(device_tensor), cq_id); +void copy_host_to_device_tensor( + const ttnn::Tensor& host_tensor, + ttnn::Tensor device_tensor, + uint8_t cq_id, + const std::vector& sub_device_ids) { + tt::tt_metal::write_tensor(std::move(host_tensor), std::move(device_tensor), cq_id, sub_device_ids); } -ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { +ttnn::Tensor from_device( + const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) { // Currently no direct sharded read support in BLACKHOLE due to alignment issue if (tensor.is_sharded() and (tensor.device()->arch() == tt::ARCH::BLACKHOLE)) { auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt); - return interleaved_tensor.cpu(blocking, cq_id); + return interleaved_tensor.cpu(blocking, cq_id, sub_device_ids); } else { - return tensor.cpu(blocking, cq_id); + return tensor.cpu(blocking, cq_id, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/operations/core/core.hpp b/ttnn/cpp/ttnn/operations/core/core.hpp index e269e7030b6..d3ce90a4e24 100644 --- a/ttnn/cpp/ttnn/operations/core/core.hpp +++ b/ttnn/cpp/ttnn/operations/core/core.hpp @@ -24,10 +24,19 @@ ttnn::Tensor unsqueeze_to_4D(const ttnn::Tensor& tensor); ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank); -ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional& memory_config); +ttnn::Tensor to_device( + const ttnn::Tensor& tensor, + Device* device, + const std::optional& memory_config, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& = {}); ttnn::Tensor to_device( - const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional& memory_config); + const ttnn::Tensor& tensor, + MeshDevice* mesh_device, + const std::optional& memory_config, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& = {}); ttnn::Tensor allocate_tensor_on_device( const Shape& shape, @@ -44,9 +53,16 @@ ttnn::Tensor allocate_tensor_on_device( const std::optional& memory_config); void copy_host_to_device_tensor( - const ttnn::Tensor& host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId); - -ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId); + const ttnn::Tensor& host_tensor, + ttnn::Tensor device_tensor, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}); + +ttnn::Tensor from_device( + const ttnn::Tensor& tensor, + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}); void deallocate(Tensor& tensor, bool force = true); diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp index acd2914c98f..3267e2dab29 100644 --- a/ttnn/cpp/ttnn/operations/creation.hpp +++ b/ttnn/cpp/ttnn/operations/creation.hpp @@ -69,6 +69,89 @@ inline std::vector get_workers_from_device(OptionalAnyDevice device) { return device.has_value() ? device->get_devices() : std::vector{}; } +template +static Tensor arange_impl( + const int64_t start, + const int64_t stop, + const int64_t step, + const Layout layout = Layout::ROW_MAJOR, + OptionalAnyDevice device = std::nullopt, + const MemoryConfig& output_mem_config = MemoryConfig{ + .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { + constexpr DataType data_type = tt::tt_metal::convert_to_data_type(); + // Current implementation restrictions + TT_ASSERT(step > 0, "Step must be greater than 0"); + TT_ASSERT(start < stop, "Start must be less than step"); + auto size = tt::div_up((stop - start), step); + if (size % 2 != 0) { + size++; + } + auto owned_buffer = tt::tt_metal::owned_buffer::create(size); + + auto index = 0; + for (auto value = start; value < stop; value += step) { + if constexpr (std::is_same_v) { + owned_buffer[index++] = T(static_cast(value)); + } else { + owned_buffer[index++] = static_cast(value); + } + } + auto output = Tensor( + OwnedStorage{owned_buffer}, + ttnn::SimpleShape{1, 1, 1, static_cast(size)}, + data_type, + Layout::ROW_MAJOR) + .to(layout); + if (device.has_value()) { + output = output.to(device->get_devices(), output_mem_config); + } + return output; +} + +template +static Tensor full_impl( + uint8_t queue_id, + const tt::tt_metal::LegacyShape& shape, + T value, + const Layout layout, + const std::vector& devices, + const MemoryConfig& output_mem_config, + std::optional optional_output_tensor) { + constexpr DataType data_type = tt::tt_metal::convert_to_data_type(); + TensorSpec tensor_spec( + shape.logical_shape(), + TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape)); + auto owned_buffer = tt::tt_metal::owned_buffer::create(tensor_spec.padded_shape().volume()); + // TODO: 15061 - Generalize the header to support generic vector / view types. + std::fill(std::begin(owned_buffer), std::end(owned_buffer), value); + + if (!optional_output_tensor.has_value()) { + auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); + if (!devices.empty()) { + output = output.to(devices, output_mem_config); + } + return output; + } else { + const auto buffers = optional_output_tensor->buffers(); + const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr); + + for (auto* buffer : buffers) { + if (using_fast_dispatch) { + auto& cmd_queue = buffer->device()->command_queue(queue_id); + if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) { + tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false); + } else { + tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false); + } + } else { + tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get()); + } + } + + return *optional_output_tensor; + } +} + } // namespace detail template @@ -122,8 +205,19 @@ inline ttnn::Tensor full_impl( MemoryConfig mem_cfg = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); - return numpy::full_impl( - queue_id, shape_value, fill_value, dtype_value, layout_value, workers, mem_cfg, optional_output_tensor); + auto concrete_full = [&](BufferType fill_value) { + return detail::full_impl( + queue_id, shape_value, fill_value, layout_value, workers, mem_cfg, optional_output_tensor); + }; + + switch (dtype_value) { + case DataType::UINT8: return concrete_full.template operator()(fill_value); + case DataType::UINT16: return concrete_full.template operator()(fill_value); + case DataType::UINT32: return concrete_full.template operator()(fill_value); + case DataType::FLOAT32: return concrete_full.template operator()(fill_value); + case DataType::BFLOAT16: return concrete_full.template operator()<::bfloat16>(static_cast(fill_value)); + default: TT_THROW("Unsupported DataType!"); + } } template @@ -287,10 +381,12 @@ struct EmptyLike { }; struct Full { + template + requires std::is_same_v or std::is_same_v static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Shape& shape, - const float fill_value, + const FillValueType fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, detail::OptionalAnyDevice device = std::nullopt, @@ -307,48 +403,11 @@ struct Full { optional_output_tensor); } - static ttnn::Tensor invoke( - uint8_t queue_id, - const ttnn::Shape& shape, - const int fill_value, - const std::optional& dtype = std::nullopt, - const std::optional& layout = std::nullopt, - detail::OptionalAnyDevice device = std::nullopt, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt) { - return full_impl( - queue_id, - shape, - fill_value, - dtype, - layout, - detail::get_workers_from_device(device), - memory_config, - optional_output_tensor); - } - - static ttnn::Tensor invoke( - const ttnn::Shape& shape, - const float fill_value, - const std::optional& dtype = std::nullopt, - const std::optional& layout = std::nullopt, - detail::OptionalAnyDevice device = std::nullopt, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt) { - return full_impl( - ttnn::DefaultQueueId, - shape, - fill_value, - dtype, - layout, - detail::get_workers_from_device(device), - memory_config, - optional_output_tensor); - } - + template + requires std::is_same_v or std::is_same_v static ttnn::Tensor invoke( const ttnn::Shape& shape, - const int fill_value, + const FillValueType fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, detail::OptionalAnyDevice device = std::nullopt, @@ -367,10 +426,12 @@ struct Full { }; struct FullLike { + template + requires std::is_same_v or std::is_same_v static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& tensor, - const float fill_value, + const FillValueType fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, detail::OptionalAnyDevice device = std::nullopt, @@ -380,34 +441,11 @@ struct FullLike { queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); } + template + requires std::is_same_v or std::is_same_v static ttnn::Tensor invoke( - uint8_t queue_id, const ttnn::Tensor& tensor, - const int fill_value, - const std::optional& dtype = std::nullopt, - const std::optional& layout = std::nullopt, - detail::OptionalAnyDevice device = std::nullopt, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt) { - return full_like_impl( - queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); - } - - static ttnn::Tensor invoke( - const ttnn::Tensor& tensor, - const float fill_value, - const std::optional& dtype = std::nullopt, - const std::optional& layout = std::nullopt, - detail::OptionalAnyDevice device = std::nullopt, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt) { - return full_like_impl( - ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); - } - - static ttnn::Tensor invoke( - const ttnn::Tensor& tensor, - const int fill_value, + const FillValueType fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, detail::OptionalAnyDevice device = std::nullopt, @@ -418,12 +456,11 @@ struct FullLike { } }; -// TODO: #14974 - Onboard this API onto AnyDevice. struct Arange { static ttnn::Tensor invoke( const int64_t stop, const DataType dtype = DataType::BFLOAT16, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG) { return Arange::invoke(0, stop, 1, dtype, device, memory_config); } @@ -433,20 +470,18 @@ struct Arange { const int64_t stop, const int64_t step = 1, const DataType dtype = ttnn::DataType::BFLOAT16, - const std::optional>& device_arg = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG) { - Device* device = device_arg.has_value() ? &(device_arg.value().get()) : nullptr; + auto concrete_arange = [&]() { + return detail::arange_impl(start, stop, step, ttnn::ROW_MAJOR_LAYOUT, device, memory_config); + }; + switch (dtype) { - case DataType::BFLOAT16: - return numpy::arange<::bfloat16>(start, stop, step, ttnn::ROW_MAJOR_LAYOUT, device, memory_config); - case DataType::FLOAT32: - return numpy::arange(start, stop, step, ttnn::ROW_MAJOR_LAYOUT, device, memory_config); - case DataType::UINT16: - return numpy::arange(start, stop, step, ttnn::ROW_MAJOR_LAYOUT, device, memory_config); - case DataType::UINT32: - return numpy::arange(start, stop, step, ttnn::ROW_MAJOR_LAYOUT, device, memory_config); - case DataType::INT32: - return numpy::arange(start, stop, step, ttnn::ROW_MAJOR_LAYOUT, device, memory_config); + case DataType::BFLOAT16: return concrete_arange.template operator()<::bfloat16>(); + case DataType::FLOAT32: return concrete_arange.template operator()(); + case DataType::UINT16: return concrete_arange.template operator()(); + case DataType::UINT32: return concrete_arange.template operator()(); + case DataType::INT32: return concrete_arange.template operator()(); default: TT_THROW("Unsupported dtype"); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp b/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp index 1ab05bd0bb9..636b544ae2b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/unsqueeze/unsqueeze.cpp @@ -9,17 +9,24 @@ namespace ttnn::operations::data_movement { ttnn::Tensor UnsqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const int dim) { const auto tensor_shape = input_tensor.get_shape(); - const auto rank = tensor_shape.rank(); - SmallVector output_shape_vector; + const uint32_t rank = tensor_shape.rank(); + const int32_t max_dim = (int)(rank); + const int32_t min_dim = -(max_dim)-1; - TT_FATAL( - input_tensor.get_layout() == Layout::ROW_MAJOR or (!tensor_shape.has_tile_padding()), - "Currently supporing ROW-MAJOR tensors or TILE tensors with no padding"); + SmallVector output_shape_vector; - int normal_dim = dim; + int normal_dim; // Handle negative dimension by converting it to positive + TT_FATAL( + (dim >= min_dim) && (dim <= max_dim), + "Dimension out of range (expected to be in range of [{},{}], but got {})", + min_dim, + max_dim, + dim); if (dim < 0) { - normal_dim += rank + 1; + normal_dim = rank + 1 + dim; + } else { + normal_dim = dim; } // Insert new dimension @@ -31,11 +38,11 @@ ttnn::Tensor UnsqueezeOperation::invoke(const ttnn::Tensor& input_tensor, const } // If the dimension is at the end, append it - if (normal_dim >= tensor_shape.size()) { + if (normal_dim == rank) { output_shape_vector.push_back(1); } - return ttnn::reshape(input_tensor, ttnn::SimpleShape(std::move(output_shape_vector))); + return ttnn::reshape(input_tensor, output_shape_vector); } } // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp index 394647cdb68..da0ef55c936 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp @@ -80,6 +80,7 @@ enum class UnaryOpType { BITWISE_OR, RIGHT_SHIFT, FLOOR, + FLOOR_FLOAT32, CEIL, LEFT_SHIFT, REMAINDER, diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp index f18aa992748..0732e967602 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp @@ -49,6 +49,8 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_default(UnaryOpType op_type, std: case UnaryOpType::SIGNBIT: op_init_and_name = {"signbit_tile_init();", fmt::format("signbit_tile({});", idst)}; break; - case UnaryOpType::FLOOR: op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)}; break; case UnaryOpType::CEIL: op_init_and_name = {"ceil_tile_init();", fmt::format("ceil_tile({});", idst)}; break; case UnaryOpType::SIN: op_init_and_name = {"sin_tile_init();", fmt::format("sin_tile({});", idst)}; break; case UnaryOpType::COS: op_init_and_name = {"cos_tile_init();", fmt::format("cos_tile({});", idst)}; break; @@ -340,6 +340,12 @@ std::pair get_op_init_and_func_default(UnaryOpType op_type, std: case UnaryOpType::IDENTITY_UINT32: op_init_and_name = {"identity_tile_init();", fmt::format("identity_tile_uint32({});", idst)}; break; + case UnaryOpType::FLOOR: + op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)}; + break; + case UnaryOpType::FLOOR_FLOAT32: + op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile_float32({});", idst)}; break; + break; case UnaryOpType::RELU6: op_init_and_name = {"relu_max_tile_init();", fmt::format("relu_max_tile({}, 0x40c00000u);", idst)}; break; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp index 81eba41d570..c87dae81384 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp @@ -99,7 +99,6 @@ template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; -template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; @@ -337,6 +336,32 @@ Tensor Identity::invoke( DefaultQueueId, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor); } +Tensor Floor::invoke( + uint8_t queue_id, + const Tensor& input_tensor, + const std::optional& memory_config, + const std::optional& optional_output_tensor) { + UnaryOpType op_type = UnaryOpType::FLOOR; + if (input_tensor.get_dtype() == DataType::FLOAT32) { + op_type = UnaryOpType::FLOOR_FLOAT32; + } + + return detail::unary_impl(queue_id, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor); +} + +Tensor Floor::invoke( + const Tensor& input_tensor, + const std::optional& memory_config, + const std::optional& optional_output_tensor) { + UnaryOpType op_type = UnaryOpType::FLOOR; + if (input_tensor.get_dtype() == DataType::FLOAT32) { + op_type = UnaryOpType::FLOOR_FLOAT32; + } + + return detail::unary_impl( + DefaultQueueId, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor); +} + Tensor Dropout::invoke( const Tensor& input, const uint32_t seed, diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp index 79d5d22eb5a..7c034bfe66e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp @@ -148,6 +148,19 @@ struct Identity { const std::optional& optional_output_tensor = std::nullopt); }; +struct Floor { + static Tensor invoke( + uint8_t queue_id, + const Tensor& input_tensor, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + const Tensor& input_tensor, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); +}; + struct Dropout { static Tensor invoke( const Tensor& input, @@ -281,7 +294,6 @@ REGISTER_UNARY_OPERATION(erfinv, ERFINV); REGISTER_UNARY_OPERATION(exp2, EXP2); REGISTER_UNARY_OPERATION(expm1, EXPM1); REGISTER_UNARY_OPERATION(eqz, EQZ); -REGISTER_UNARY_OPERATION(floor, FLOOR); REGISTER_UNARY_OPERATION(ceil, CEIL); REGISTER_UNARY_OPERATION(gez, GEZ); REGISTER_UNARY_OPERATION(gtz, GTZ); @@ -354,6 +366,8 @@ constexpr auto dropout = ttnn::register_operation_with_auto_launch_op<"ttnn::dropout", ttnn::operations::unary::Dropout>(); constexpr auto identity = ttnn::register_operation_with_auto_launch_op<"ttnn::identity", ttnn::operations::unary::Identity>(); +constexpr auto floor = + ttnn::register_operation_with_auto_launch_op<"ttnn::floor", ttnn::operations::unary::Floor>(); constexpr auto softplus = ttnn::register_operation_with_auto_launch_op<"ttnn::softplus", ttnn::operations::unary::Softplus>(); constexpr auto prelu_sfpu = diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt new file mode 100644 index 00000000000..82767c44a09 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt @@ -0,0 +1,12 @@ +set(CCL_EXPERIMENTAL_TTNN_SRCS + #Experimental Ops + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/all_gather_matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/all_gather_matmul_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/device/all_gather_matmul_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce/all_reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce/all_reduce_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce/device/all_reduce_op.cpp + CACHE INTERNAL + "CCL Experimental sources to reuse in ttnn build" +) diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp index 29fd7967763..cb3647c98c1 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.hpp" +#include "ttnn/operations/ccl/ccl_host_types.hpp" #include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp" #include "ttnn/operations/ccl/all_gather/all_gather.hpp" #include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" @@ -66,13 +67,18 @@ namespace experimental { namespace ccl { static AllReduceStrategy choose_all_reduce_strategy( - const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links) { + const Tensor& input_tensor, uint32_t num_devices, uint32_t num_links, ttnn::ccl::Topology topology) { auto shape = input_tensor.get_logical_shape(); auto rank = shape.rank(); uint32_t all_reduce_dim = -1; bool optimized_version = false; + if (num_devices == 2) { + // 2 devices == n300 == linear topology + topology = ttnn::ccl::Topology::Linear; + } + for (uint32_t i = 0; i < rank; ++i) { if (shape[i] % num_devices == 0) { all_reduce_dim = i; @@ -80,6 +86,11 @@ static AllReduceStrategy choose_all_reduce_strategy( } } + if (topology == ttnn::ccl::Topology::Linear) { + // reduce scatter doesn't reliably support line topology yet + optimized_version = false; + } + if (optimized_version) { if (shape[2] == tt::constants::TILE_HEIGHT || shape[3] == tt::constants::TILE_WIDTH) { optimized_version = false; // Reduce scatter hangs for this shape @@ -110,7 +121,7 @@ static Tensor all_gather_local_reduce( const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel, const std::vector& devices, - const ttnn::ccl::Topology& topology) { + ttnn::ccl::Topology topology) { auto shape = input_tensor.get_logical_shape(); auto rank = shape.rank(); log_warning( @@ -119,6 +130,11 @@ static Tensor all_gather_local_reduce( "by optimized version", shape); + if (num_devices == 2) { + // 2 devices == n300 == linear topology + topology = ttnn::ccl::Topology::Linear; + } + TT_FATAL(rank == 4, "Tensor rank must be 4, but has {} ", rank); uint32_t merged_dim_size = 1; for (uint32_t i = 2; i < rank; ++i) { @@ -247,7 +263,6 @@ Tensor all_reduce( ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); TT_FATAL( std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "All Reduce op is only supported for Fast Dispatch"); - TT_FATAL(topology == ttnn::ccl::Topology::Ring, "All Reduce op is currently supported only on Ring topology"); auto devices = input_tensor.get_workers(); uint32_t num_devices = devices.size(); @@ -272,7 +287,7 @@ Tensor all_reduce( const auto& input_tensor = input_tensors.at(0); // Choose the appropriate strategy - AllReduceStrategy strategy = choose_all_reduce_strategy(input_tensor, num_devices, num_links); + AllReduceStrategy strategy = choose_all_reduce_strategy(input_tensor, num_devices, num_links, topology); // Run the selected all-reduce operation Tensor result = run_all_reduce( diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads/device/create_qkv_heads_program_factory.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads/device/create_qkv_heads_program_factory.cpp index 8b0ce2c5053..b37b6ae1686 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads/device/create_qkv_heads_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads/device/create_qkv_heads_program_factory.cpp @@ -7,9 +7,6 @@ #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" -// FIXME: ARCH_NAME specific include -#include "tensix_types.h" // L1_SIZE - using namespace tt::constants; using namespace tt; @@ -92,14 +89,15 @@ static inline operation::ProgramWithCallbacks create_heads_combined_qkv_sharded( block_ht * TILE_HEIGHT); uint32_t per_core_tiles = block_ht * block_wt; + const uint32_t l1_size = input_tensor.device()->l1_size_per_core(); auto data_format = tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); uint32_t single_tile_size = tile_size(data_format); TT_FATAL( - L1_SIZE >= 2 * per_core_tiles * single_tile_size, + l1_size >= 2 * per_core_tiles * single_tile_size, "Workload of Tiles {} at Tile Size {} (times 2 for output) exceeds L1 capacity {}", per_core_tiles, single_tile_size, - L1_SIZE); + l1_size); std::vector num_tiles_per_group; num_tiles_per_group.reserve(output.size()); diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads_from_separate_tensors/device/create_qkv_heads_from_separate_tensors_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads_from_separate_tensors/device/create_qkv_heads_from_separate_tensors_device_operation.cpp index a0fcbe427b6..fa69d508b16 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads_from_separate_tensors/device/create_qkv_heads_from_separate_tensors_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads_from_separate_tensors/device/create_qkv_heads_from_separate_tensors_device_operation.cpp @@ -7,9 +7,6 @@ #include "tt_metal/host_api.hpp" -// FIXME: ARCH_NAME specific include -#include "tensix_types.h" // L1_SIZE - namespace ttnn::operations::experimental::transformer { void CreateQKVHeadsSeparateTensorsDeviceOperation::validate(const std::vector& input_tensors) const { @@ -122,10 +119,11 @@ void CreateQKVHeadsSeparateTensorsDeviceOperation::validate(const std::vectorl1_size_per_core(); const uint32_t single_tile_size = tt::tile_size(tt::tt_metal::datatype_to_dataformat_converter(q_input_tensor.get_dtype())); TT_FATAL( - L1_SIZE >= 2 * (per_core_q_tiles + 2 * per_core_k_tiles) * single_tile_size, "Workload exceeds L1 capacity"); + l1_size >= 2 * (per_core_q_tiles + 2 * per_core_k_tiles) * single_tile_size, "Workload exceeds L1 capacity"); // TODO: Add this back when output is HEIGHT sharded only! // TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp index 31f1ec32efe..51a3668eed1 100644 --- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp +++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp @@ -26,195 +26,6 @@ using tt::tt_metal::MemoryConfig; using tt::tt_metal::OwnedStorage; using tt::tt_metal::StorageType; using tt::tt_metal::Tensor; -namespace detail { - -template -constexpr static DataType get_data_type() { - if constexpr (std::is_same_v) { - return DataType::UINT8; - } else if constexpr (std::is_same_v) { - return DataType::UINT16; - } else if constexpr (std::is_same_v) { - return DataType::INT32; - } else if constexpr (std::is_same_v) { - return DataType::UINT32; - } else if constexpr (std::is_same_v) { - return DataType::FLOAT32; - } else if constexpr (std::is_same_v) { - return DataType::BFLOAT16; - } else { - TT_THROW("Unsupported DataType!"); - } -} - -template -static Tensor full( - uint8_t queue_id, - const tt::tt_metal::LegacyShape& shape, - T value, - const Layout layout, - const std::vector& devices, - const MemoryConfig& output_mem_config, - std::optional optional_output_tensor) { - constexpr DataType data_type = detail::get_data_type(); - TensorSpec tensor_spec( - shape.logical_shape(), - TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape)); - auto owned_buffer = tt::tt_metal::owned_buffer::create(tensor_spec.padded_shape().volume()); - // TODO: 15061 - Generalize the header to support generic vector / view types. - std::fill(std::begin(owned_buffer), std::end(owned_buffer), value); - - if (!optional_output_tensor.has_value()) { - auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); - if (!devices.empty()) { - output = output.to(devices, output_mem_config); - } - return output; - } else { - const auto buffers = optional_output_tensor->buffers(); - const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr); - - for (auto* buffer : buffers) { - if (using_fast_dispatch) { - auto& cmd_queue = buffer->device()->command_queue(queue_id); - if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) { - tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false); - } else { - tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false); - } - } else { - tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get()); - } - } - - return *optional_output_tensor; - } -} - -} // namespace detail - -template -static Tensor full_impl( - uint8_t queue_id, - const tt::tt_metal::LegacyShape& shape, - const T value, - const DataType data_type, - const Layout layout, - const std::vector& devices, - const MemoryConfig& output_mem_config, - std::optional optional_output_tensor) { - switch (data_type) { - case DataType::UINT8: { - return detail::full( - queue_id, shape, uint8_t(value), layout, devices, output_mem_config, optional_output_tensor); - } - case DataType::UINT16: { - return detail::full( - queue_id, shape, uint16_t(value), layout, devices, output_mem_config, optional_output_tensor); - } - case DataType::UINT32: { - return detail::full( - queue_id, shape, uint32_t(value), layout, devices, output_mem_config, optional_output_tensor); - } - case DataType::FLOAT32: { - return detail::full( - queue_id, shape, float(value), layout, devices, output_mem_config, optional_output_tensor); - } - case DataType::BFLOAT16: { - return detail::full<::bfloat16>( - queue_id, - shape, - ::bfloat16(static_cast(value)), - layout, - devices, - output_mem_config, - optional_output_tensor); - } - default: TT_THROW("Unsupported DataType!"); - } -} - -// TODO: #14974 - Can this be deleted, as it is only used in tests? -template -static Tensor full( - const tt::tt_metal::LegacyShape& shape, - const T value, - const DataType data_type, - const Layout layout = Layout::ROW_MAJOR, - Device* device = nullptr, - const MemoryConfig& output_mem_config = MemoryConfig{ - .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { - return full_impl( - ttnn::DefaultQueueId, - shape, - value, - data_type, - layout, - device ? std::vector{device} : std::vector{}, - output_mem_config, - std::nullopt); -} - -// TODO: #14974 - Can this be deleted, as it is only used in tests? -static Tensor zeros( - const tt::tt_metal::LegacyShape& shape, - const DataType data_type = DataType::BFLOAT16, - const Layout layout = Layout::ROW_MAJOR, - Device* device = nullptr, - const MemoryConfig& output_mem_config = MemoryConfig{ - .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { - return full(shape, 0.0f, data_type, layout, device, output_mem_config); -} - -// TODO: #14974 - Can this be deleted, as it is only used in tests? -static Tensor ones( - const tt::tt_metal::LegacyShape& shape, - const DataType data_type = DataType::BFLOAT16, - const Layout layout = Layout::ROW_MAJOR, - Device* device = nullptr, - const MemoryConfig& output_mem_config = MemoryConfig{ - .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { - return full(shape, 1.0f, data_type, layout, device, output_mem_config); -} - -template -static Tensor arange( - const int64_t start, - const int64_t stop, - const int64_t step, - const Layout layout = Layout::ROW_MAJOR, - Device* device = nullptr, - const MemoryConfig& output_mem_config = MemoryConfig{ - .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { - constexpr DataType data_type = detail::get_data_type(); - // Current implementation restrictions - TT_ASSERT(step > 0, "Step must be greater than 0"); - TT_ASSERT(start < stop, "Start must be less than step"); - auto size = tt::div_up((stop - start), step); - if (size % 2 != 0) { - size++; - } - auto owned_buffer = tt::tt_metal::owned_buffer::create(size); - - auto index = 0; - for (auto value = start; value < stop; value += step) { - if constexpr (std::is_same_v) { - owned_buffer[index++] = T(static_cast(value)); - } else { - owned_buffer[index++] = static_cast(value); - } - } - auto output = Tensor( - OwnedStorage{owned_buffer}, - ttnn::SimpleShape{1, 1, 1, static_cast(size)}, - data_type, - Layout::ROW_MAJOR) - .to(layout); - if (device != nullptr) { - output = output.to(device, output_mem_config); - } - return output; -} template static Tensor index_trilu( @@ -671,7 +482,7 @@ static void seed(std::size_t seed) { RANDOM_GENERATOR = std::mt19937(seed); } template static Tensor uniform(T low, T high, const tt::tt_metal::LegacyShape& shape, const Layout layout = Layout::ROW_MAJOR) { - constexpr DataType data_type = detail::get_data_type(); + constexpr DataType data_type = tt::tt_metal::convert_to_data_type(); auto owned_buffer = tt::tt_metal::owned_buffer::create(tt::tt_metal::compute_volume(shape)); diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp index dbd111f0358..576a237a7db 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp @@ -10,7 +10,7 @@ namespace ttnn::operations::upsample { ttnn::Tensor ExecuteUpSample::invoke( const ttnn::Tensor& input_tensor, - std::variant scale_factor, + std::variant scale_factor, const std::string& mode, const std::optional& output_mem_config, const std::optional& compute_kernel_config) { @@ -27,21 +27,8 @@ ttnn::Tensor ExecuteUpSample::invoke( scale_h = sf; scale_w = sf; } else if constexpr (std::is_same_v) { - scale_w = sf.at(0); - int scale_c = sf.at(1); - TT_FATAL(scale_c == 1, "Error"); - } else if constexpr (std::is_same_v) { scale_h = sf.at(0); scale_w = sf.at(1); - int scale_c = sf.at(2); - TT_FATAL(scale_c == 1, "Error"); - } else if constexpr (std::is_same_v) { - int scale_n = sf.at(0); - scale_h = sf.at(1); - scale_w = sf.at(2); - int scale_c = sf.at(3); - TT_FATAL(scale_n == 1, "Error"); - TT_FATAL(scale_c == 1, "Error"); } else { // static_assert(false, "Unsupported scale factor"); static_assert(sizeof(T) != 0, "Type check failed."); diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp index e8bd68e634a..0a012304548 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp @@ -15,7 +15,7 @@ namespace upsample { struct ExecuteUpSample { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - std::variant scale_factor, + std::variant scale_factor, const std::string& mode = std::string("nearest"), const std::optional& output_mem_config = std::nullopt, const std::optional& compute_kernel_config = std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp index 93d4137cd70..06c72788e1f 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp @@ -25,7 +25,7 @@ void bind_upsample(py::module& module) { Args: input_tensor (ttnn.Tensor): the input tensor. - scale_factor (int or tt::tt_metal::Array2D or tt::tt_metal::Array3D or tt::tt_metal::Array4D): multiplier for spatial size. Has to match input size if it is a tuple. + scale_factor (int or tt::tt_metal::Array2D): multiplier for spatial size. Keyword args: diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp b/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp index fcfd1f35e60..dbf98519483 100644 --- a/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp @@ -39,9 +39,8 @@ auto MoeOperation::invoke( const uint16_t k, const std::optional& memory_config, std::optional optional_output_tensor) { - constexpr uint8_t DefaultQueueId = 0; return invoke( - DefaultQueueId, + ttnn::DefaultQueueId, input_tensor, expert_mask_tensor, topk_mask_tensor, diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp index 8d96b25ffec..7ceb4d61916 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp @@ -5,9 +5,6 @@ #include "topk_op.hpp" #include "topk_program_factory.hpp" -// FIXME: ARCH_NAME specific include -#include "tensix_types.h" // L1_SIZE - namespace topk_utils { static inline bool verify_available_cores( @@ -16,6 +13,7 @@ static inline bool verify_available_cores( uint16_t max_dim, CoreCoord grid, uint16_t k, + const uint32_t l1_size, const uint32_t value_tile_size, const uint32_t index_tile_size) { const auto max_cores = grid.y - 1; // reserve one core for the gather - switch to grid.x as it allows for more @@ -30,7 +28,7 @@ static inline bool verify_available_cores( (split_size / tt::constants::TILE_WIDTH) * (value_tile_size + index_tile_size); // we divide the width into split_size chunks and each chunk, as well // as a matching set of indices, is processed by a core - if (num_cores <= max_cores && (memory_cost_gather + memory_cost_local) < L1_SIZE && num_cores > 1) { + if (num_cores <= max_cores && (memory_cost_gather + memory_cost_local) < l1_size && num_cores > 1) { return true; } } @@ -79,6 +77,7 @@ void TopK::validate_with_output_tensors( input_shape[this->dim] / 2, device->compute_with_storage_grid_size(), this->k, + device->l1_size_per_core(), value_tile_size, index_tile_size), "Not enough cores available to run topk operation"); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp index ab854db3620..1996aacd555 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp @@ -8,9 +8,6 @@ #include "tt_metal/host_api.hpp" #include "tt_log.h" -// FIXME: ARCH_NAME specific include -#include "tensix_types.h" // L1_SIZE - namespace ttnn::operations::reduction::detail { operation::ProgramWithCallbacks topk_single_core_interleaved( @@ -179,6 +176,7 @@ static inline std::tuple cores_utilized( uint16_t max_dim, CoreCoord grid, uint16_t k, + const uint32_t l1_size, const uint32_t value_tile_size, const uint32_t index_tile_size) { const auto max_cores = grid.y - 1; // reserve one core for the gather - switch to grid.x as it allows for more @@ -193,7 +191,7 @@ static inline std::tuple cores_utilized( (split_size / tt::constants::TILE_WIDTH) * (value_tile_size + index_tile_size); // we divide the width into split_size chunks and each chunk, as well // as a matching set of indices, is processed by a core - if (num_cores <= max_cores && (memory_cost_gather + memory_cost_local) < L1_SIZE && num_cores > 1) { + if (num_cores <= max_cores && (memory_cost_gather + memory_cost_local) < l1_size && num_cores > 1) { return {num_cores + 1, split_size, rem, num_cores * k}; } } @@ -237,6 +235,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved( input_shape[dim] / 2, device->compute_with_storage_grid_size(), k, + device->l1_size_per_core(), value_tile_size, index_tile_size); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 54e2bd1b403..976364d32cb 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -336,7 +336,7 @@ void MAIN { constexpr uint32_t B = get_compile_time_arg_val(0); constexpr uint32_t NQH = get_compile_time_arg_val(1); constexpr uint32_t NKH = get_compile_time_arg_val(2); - constexpr uint32_t St = get_compile_time_arg_val(3); + constexpr uint32_t Skt = get_compile_time_arg_val(3); constexpr uint32_t DHt = get_compile_time_arg_val(4); constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(5); constexpr uint32_t q_num_chunks = get_compile_time_arg_val(6); @@ -419,7 +419,7 @@ void MAIN { if constexpr (is_causal) { q_high_idx = q_low_idx + Sq_chunk_t; } else { - q_high_idx = St; + q_high_idx = Skt; } cb_wait_front(cb_q_in, q_chunk_tiles); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp index 3309205fa34..8b945b404e8 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp @@ -14,15 +14,16 @@ void kernel_main() { constexpr uint32_t B = get_compile_time_arg_val(0); constexpr uint32_t NQH = get_compile_time_arg_val(1); constexpr uint32_t NKH = get_compile_time_arg_val(2); - constexpr uint32_t St = get_compile_time_arg_val(3); - constexpr uint32_t DHt = get_compile_time_arg_val(4); - constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(5); - constexpr uint32_t q_num_chunks = get_compile_time_arg_val(6); - constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(7); - constexpr uint32_t k_num_chunks = get_compile_time_arg_val(8); - constexpr uint32_t num_cores = get_compile_time_arg_val(9); - constexpr uint32_t is_causal = get_compile_time_arg_val(10) == 1; - constexpr uint32_t use_provided_mask = get_compile_time_arg_val(11) == 1; + constexpr uint32_t Sqt = get_compile_time_arg_val(3); + constexpr uint32_t Skt = get_compile_time_arg_val(4); + constexpr uint32_t DHt = get_compile_time_arg_val(5); + constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(6); + constexpr uint32_t q_num_chunks = get_compile_time_arg_val(7); + constexpr uint32_t Sk_chunk_t = get_compile_time_arg_val(8); + constexpr uint32_t k_num_chunks = get_compile_time_arg_val(9); + constexpr uint32_t num_cores = get_compile_time_arg_val(10); + constexpr uint32_t is_causal = get_compile_time_arg_val(11) == 1; + constexpr uint32_t use_provided_mask = get_compile_time_arg_val(12) == 1; const uint32_t q_addr = get_arg_val(0); const uint32_t k_addr = get_arg_val(1); @@ -82,9 +83,9 @@ void kernel_main() { uint32_t barrier_count = 0; for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) { - const uint32_t q_batch_offset = nb * NQH * St * DHt; - const uint32_t kv_batch_offset = nb * NKH * St * DHt; - const uint32_t mask_batch_offset = nb * St * St; + const uint32_t q_batch_offset = nb * NQH * Sqt * DHt; + const uint32_t kv_batch_offset = nb * NKH * Skt * DHt; + const uint32_t mask_batch_offset = nb * Sqt * Skt; for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) { for (uint32_t q_iter = 0; q_iter < q_chunks_per_core; ++q_iter) { uint32_t q_chunk; @@ -100,7 +101,7 @@ void kernel_main() { q_chunk = local_q_start + q_iter; #endif - uint32_t q_head_offset = nq * St * DHt; + uint32_t q_head_offset = nq * Sqt * DHt; uint32_t q_chunk_offset = q_chunk * Sq_chunk_t * DHt; q_tile_id = q_batch_offset + q_head_offset + q_chunk_offset; @@ -129,11 +130,11 @@ void kernel_main() { if constexpr (is_causal) { q_high_idx = q_low_idx + Sq_chunk_t; } else { - q_high_idx = St; + q_high_idx = Skt; } const uint32_t kv_head = nq / q_heads_per_kv; - const uint32_t kv_head_offset = kv_head * St * DHt; + const uint32_t kv_head_offset = kv_head * Skt * DHt; // loop while k_low < q_high for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) { @@ -171,8 +172,7 @@ void kernel_main() { cb_reserve_back(cb_mask_in, mask_chunk_tiles); uint32_t mask_write_ptr = get_write_ptr(cb_mask_in); barrier_count = 0; - mask_tile_id = mask_batch_offset + q_chunk * Sq_chunk_t * St /*row_offset*/ + - k_chunk * Sk_chunk_t /*col_offset*/; + mask_tile_id = mask_batch_offset + q_chunk * Sq_chunk_t * Skt /*row_offset*/ + k_chunk * Sk_chunk_t /*col_offset*/; for (uint32_t row = 0; row < Sq_chunk_t; ++row) { for (uint32_t col = 0; col < Sk_chunk_t; ++col) { noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr); @@ -185,7 +185,7 @@ void kernel_main() { } // Strid along columns to get to next row mask_tile_id -= Sk_chunk_t; - mask_tile_id += St; + mask_tile_id += Skt; } noc_async_read_barrier(); cb_push_back(cb_mask_in, mask_chunk_tiles); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp index 8d5d7d4f673..5cf07e576e2 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp @@ -136,7 +136,7 @@ void kernel_main() { constexpr uint32_t B = get_compile_time_arg_val(0); constexpr uint32_t NQH = get_compile_time_arg_val(1); constexpr uint32_t NKH = get_compile_time_arg_val(2); - constexpr uint32_t St = get_compile_time_arg_val(3); + constexpr uint32_t Sqt = get_compile_time_arg_val(3); constexpr uint32_t DHt = get_compile_time_arg_val(4); constexpr uint32_t Sq_chunk_t = get_compile_time_arg_val(5); constexpr uint32_t q_num_chunks = get_compile_time_arg_val(6); @@ -184,7 +184,7 @@ void kernel_main() { uint32_t out_tile_id = 0; for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) { - const uint32_t q_batch_offset = nb * NQH * St * DHt; + const uint32_t q_batch_offset = nb * NQH * Sqt * DHt; for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) { for (uint32_t q_iter = 0; q_iter < q_chunks_per_core; ++q_iter) { uint32_t q_chunk; @@ -200,7 +200,7 @@ void kernel_main() { q_chunk = local_q_start + q_iter; #endif - uint32_t q_head_offset = nq * St * DHt; + uint32_t q_head_offset = nq * Sqt * DHt; uint32_t q_chunk_offset = q_chunk * Sq_chunk_t * DHt; out_tile_id = q_batch_offset + q_head_offset + q_chunk_offset; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp index 0708eb6645d..5b3981bedfe 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp @@ -69,50 +69,35 @@ void ScaledDotProductAttention::validate( const auto B = q_shape[0]; const auto nqh = q_shape[1]; const auto nkv = k_shape[1]; - const auto S = q_shape[2]; + const auto Sq = q_shape[2]; const auto DH = q_shape[3]; + const auto Sk = k_shape[2]; + if (this->is_causal) { + TT_FATAL(Sq == Sk, "Causal SDPA requires Q and K to have the same sequence length. Got Q: {}, K: {}", Sq, Sk); + } TT_FATAL(k_shape[0] == B && v_shape[0] == B, "K and V batch must match. Got K: {}, V: {}", k_shape[0], v_shape[0]); TT_FATAL(v_shape[1] == nkv, "K and V num_heads must match. Got K: {}, V: {}", k_shape[1], v_shape[1]); - TT_FATAL( - k_shape[2] == S && v_shape[2] == S, - "K and V sequence length must match. Got K: {}, V: {}", - k_shape[2], - v_shape[2]); - TT_FATAL( - k_shape[3] == DH && v_shape[3] == DH, - "K and V hidden dim must match. Got K: {}, V: {}", - k_shape[3], - v_shape[3]); - TT_FATAL( - nqh >= nkv && nqh % nkv == 0, - "Q num_heads must be >= K num_heads and divisible by K num_heads. Got Q: {}, K: {}", - nqh, - nkv); + TT_FATAL(v_shape[2] == Sk, "K and V sequence length must match. Got K: {}, V: {}", k_shape[2], v_shape[2]); + TT_FATAL(k_shape[3] == DH && v_shape[3] == DH, "K and V hidden dim must match. Got K: {}, V: {}", k_shape[3], v_shape[3]); + TT_FATAL(nqh >= nkv && nqh % nkv == 0, "Q num_heads must be >= K num_heads and divisible by K num_heads. Got Q: {}, K: {}", nqh, nkv); if (mask_option.has_value()) { const auto mask_shape = mask_option.value().get_legacy_shape(); TT_FATAL(mask_shape[0] == B, "Mask batch dim must match Q batch dim"); TT_FATAL(mask_shape[1] == 1, "Mask num_heads must be 1 to be broadcasted across all heads"); - TT_FATAL(mask_shape[2] == S, "Mask sequence length must match Q sequence length"); - TT_FATAL(mask_shape[3] == S, "Mask sequence length must match Q sequence length"); + TT_FATAL(mask_shape[2] == Sq, "Mask sequence length must match Q sequence length"); + TT_FATAL(mask_shape[3] == Sk, "Mask sequence length must match K sequence length"); } if (this->program_config.has_value()) { auto q_chunk_size = program_config->q_chunk_size; auto k_chunk_size = program_config->k_chunk_size; - TT_FATAL( - q_shape[-2] % q_chunk_size == 0, - "q_chunk_size must divide q_shape[-2]. Got q_chunk_size: {}, q_shape[-2]: {}", - q_chunk_size, - q_shape[-2]); - TT_FATAL( - k_shape[-2] % k_chunk_size == 0, - "k_chunk_size must divide k_shape[-2]. Got k_chunk_size: {}, k_shape[-2]: {}", - k_chunk_size, - k_shape[-2]); + TT_FATAL(Sq % q_chunk_size == 0, "q_chunk_size must divide q_shape[-2]. Got q_chunk_size: {}, q_shape[-2]: {}", q_chunk_size, q_shape[-2]); + TT_FATAL(Sk % k_chunk_size == 0, "k_chunk_size must divide k_shape[-2]. Got k_chunk_size: {}, k_shape[-2]: {}", k_chunk_size, k_shape[-2]); + } } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index 70eede0127c..9278d02c812 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -41,24 +41,28 @@ operation::ProgramWithCallbacks sdpa_multi_core( const auto q_shape = input_tensor_q.get_legacy_shape(); const auto k_shape = input_tensor_k.get_legacy_shape(); - const uint32_t B = q_shape[0], NQH = q_shape[1], S = q_shape[2], DH = q_shape[3]; + const uint32_t B = q_shape[0], NQH = q_shape[1], Sq = q_shape[2], DH = q_shape[3]; + const uint32_t Sk = k_shape[2]; const uint32_t NKH = k_shape[1]; - const uint32_t St = S / TILE_HEIGHT; + const uint32_t Sqt = Sq / TILE_HEIGHT; + const uint32_t Skt = Sk / TILE_HEIGHT; const uint32_t DHt = DH / TILE_WIDTH; const uint32_t Sq_chunk_t = q_chunk_size / TILE_HEIGHT; const uint32_t Sk_chunk_t = k_chunk_size / TILE_HEIGHT; - const uint32_t q_num_chunks = S / q_chunk_size; - const uint32_t k_num_chunks = S / k_chunk_size; + const uint32_t q_num_chunks = Sq / q_chunk_size; + const uint32_t k_num_chunks = Sk / k_chunk_size; const bool use_provided_mask = attn_mask.has_value(); // log_debug all of the above tt::log_debug("B: {}", B); tt::log_debug("NQH: {}", NQH); - tt::log_debug("S: {}", S); + tt::log_debug("Sq: {}", Sq); + tt::log_debug("Sk: {}", Sk); tt::log_debug("DH: {}", DH); - tt::log_debug("St: {}", St); + tt::log_debug("Sqt: {}", Sqt); + tt::log_debug("Skt: {}", Skt); tt::log_debug("DHt: {}", DHt); tt::log_debug("Sq_chunk_t: {}", Sq_chunk_t); tt::log_debug("Sk_chunk_t: {}", Sk_chunk_t); @@ -216,60 +220,64 @@ operation::ProgramWithCallbacks sdpa_multi_core( scale_union.f = scale.value_or(1.0f); std::vector reader_compile_time_args = {// interleaved accessor args - B, - NQH, - NKH, - St, - DHt, - Sq_chunk_t, - q_num_chunks, - Sk_chunk_t, - k_num_chunks, - num_cores, - (std::uint32_t)is_causal, - (std::uint32_t)use_provided_mask}; + B, + NQH, + NKH, + Sqt, + Skt, + DHt, + Sq_chunk_t, + q_num_chunks, + Sk_chunk_t, + k_num_chunks, + num_cores, + (std::uint32_t)is_causal, + (std::uint32_t)use_provided_mask + }; std::vector writer_compile_time_args = {// interleaved accessor args - B, - NQH, - NKH, - St, - DHt, - Sq_chunk_t, - q_num_chunks, - Sk_chunk_t, - k_num_chunks, - packed_identity_scalar, - scale_union.u, - num_cores, - (std::uint32_t)is_causal, - (std::uint32_t)use_provided_mask}; + B, + NQH, + NKH, + Sqt, + DHt, + Sq_chunk_t, + q_num_chunks, + Sk_chunk_t, + k_num_chunks, + packed_identity_scalar, + scale_union.u, + num_cores, + (std::uint32_t)is_causal, + (std::uint32_t)use_provided_mask + }; std::vector compute_compile_time_args = {// matmul args - B, - NQH, - NKH, - St, - DHt, - Sq_chunk_t, - q_num_chunks, - Sk_chunk_t, - k_num_chunks, - qk_in0_block_w, - qk_out_subblock_w, - qk_out_subblock_h, - qk_in0_num_subblocks, - qk_in1_num_subblocks, - qk_num_blocks, - out_in0_block_w, - out_out_subblock_w, - out_out_subblock_h, - out_in0_num_subblocks, - out_in1_num_subblocks, - out_num_blocks, - num_cores, - (std::uint32_t)is_causal, - (std::uint32_t)use_provided_mask}; + B, + NQH, + NKH, + Skt, + DHt, + Sq_chunk_t, + q_num_chunks, + Sk_chunk_t, + k_num_chunks, + qk_in0_block_w, + qk_out_subblock_w, + qk_out_subblock_h, + qk_in0_num_subblocks, + qk_in1_num_subblocks, + qk_num_blocks, + out_in0_block_w, + out_out_subblock_w, + out_out_subblock_h, + out_in0_num_subblocks, + out_in1_num_subblocks, + out_num_blocks, + num_cores, + (std::uint32_t)is_causal, + (std::uint32_t)use_provided_mask + }; std::map defines; defines["STATS_GRANULARITY"] = std::to_string(stats_granularity); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp index 93f918ff092..7c09d0e4de0 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp @@ -132,12 +132,14 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( // balance the number of cores to use based on batch uint32_t max_num_cores_for_compute = program_config->max_cores_per_head_batch * B * num_kv_heads; uint32_t num_cores_per_batch = std::min(num_cores_available, max_num_cores_for_compute) / B; - uint32_t num_active_cores = num_cores_per_batch * B; //// for core assignment, it is the same whether there's 1 core for head or 1 core for many heads uint32_t num_cores_per_head = std::max((uint32_t)1, num_cores_per_batch / num_kv_heads); - uint32_t num_heads_per_core = std::max((uint32_t)1, num_kv_heads / num_cores_per_batch); + uint32_t num_heads_per_core = std::max((uint32_t)1, (uint32_t)std::ceil((float)num_kv_heads / num_cores_per_batch)); uint32_t num_reducer_cores = num_kv_heads * B / num_heads_per_core; uint32_t num_output_cores = B; + uint32_t num_active_cores = num_cores_per_head * num_kv_heads * B / num_heads_per_core; + //// recalculate num_cores_per_batch based on num_active_cores + num_cores_per_batch = num_active_cores / B; TT_FATAL( ((num_cores_per_head >= 1) && (num_heads_per_core == 1)) || @@ -146,10 +148,10 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( // create core group, assume n batch and k_heads: // this is a 1D list of cores sorted by batch_output1, worker, ..., batch_output2, worker, ..., batch_output n, - // worker, ... Within each batch, we will assign head reducers. e.g. the following mapping: (batch_output1, worker1, - // worker2), (worker3, worker4, worker5), ..., (... worker3*k-1, worker3*k) (head_reducer1, h_worker1, - // h_worker2), (head_reducer2, h_worker1, h_worker2), ..., (head_reducerk, h_worker1, h_worker2) head_reducer2 to - // head_reducerk then send the result to head_reducer1, which is also the batch_output1 + // worker, ... Within each batch, we will assign head reducers. e.g. the following mapping: + // (batch_output1, worker1, worker2), (worker3, worker4, worker5), ..., (... worker3*k-1, worker3*k) + // (head_reducer1, h_worker1, h_worker2), (head_reducer2, h_worker1, h_worker2), ..., (head_reducerk, h_worker1, + // h_worker2) head_reducer2 to head_reducerk then send the result to head_reducer1, which is also the batch_output1 std::vector core_group; std::vector core_group_idle; if (is_q_sharded || is_output_sharded) { diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index d6c80e1e92d..f4304d33c6a 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -559,25 +559,29 @@ const Storage& Tensor::get_storage() const { return this->tensor_attributes->storage; } -Tensor Tensor::to(CommandQueue& queue, const MemoryConfig& mem_config) const { - return tensor_ops::tensor_to(*this, queue.device(), mem_config); +Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config,uint8_t cq_id, + const std::vector& sub_device_ids) const { + return tensor_ops::tensor_to(*this, target_device, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config) const { - return tensor_ops::tensor_to(*this, target_device, mem_config); +Tensor Tensor::to(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config,uint8_t cq_id, + const std::vector& sub_device_ids) const { + std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); + return tensor_ops::tensor_to(*this, workers_to_use, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::to(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config) const { - std::vector workers_to_use = ttnn::distributed::distribute_tensor_to_mesh(*this, *mesh_device); - return tensor_ops::tensor_to(*this, workers_to_use, mem_config); +Tensor Tensor::to( + const std::vector& workers, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids) const { + return tensor_ops::tensor_to(*this, workers, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::to(const std::vector& workers, const MemoryConfig& mem_config) const { - return tensor_ops::tensor_to(*this, workers, mem_config); +Tensor Tensor::cpu(bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) const { + return tensor_ops::tensor_cpu(*this, blocking, cq_id, sub_device_ids); } -Tensor Tensor::cpu(bool blocking, uint8_t cq_id) const { return tensor_ops::tensor_cpu(*this, blocking, cq_id); } - Tensor Tensor::cpu_sharded() const { return tensor_ops::tensor_cpu_sharded(*this); } Tensor Tensor::extract_shard(const CoreCoord& core) const { @@ -656,21 +660,12 @@ std::vector Tensor::host_page_ordering() { StorageType Tensor::storage_type() const { return std::visit( - [](auto&& storage) -> StorageType { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return StorageType::OWNED; - } else if constexpr (std::is_same_v) { - return StorageType::DEVICE; - } else if constexpr (std::is_same_v) { - return StorageType::BORROWED; - } else if constexpr (std::is_same_v) { - return StorageType::MULTI_DEVICE; - } else if constexpr (std::is_same_v) { - return StorageType::MULTI_DEVICE_HOST; - } else { - raise_unsupported_storage(); - } + tt::stl::overloaded{ + [](const OwnedStorage&) { return StorageType::OWNED; }, + [](const DeviceStorage&) { return StorageType::DEVICE; }, + [](const BorrowedStorage&) { return StorageType::BORROWED; }, + [](const MultiDeviceStorage& s) { return StorageType::MULTI_DEVICE; }, + [](const MultiDeviceHostStorage&) { return StorageType::MULTI_DEVICE_HOST; }, }, this->get_storage()); } @@ -870,21 +865,24 @@ Tensor allocate_tensor_on_devices( return device_tensor; } -void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id) { +void write_tensor( + const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id, const std::vector& sub_device_ids) { // Top level wrapper to copy a host tensor to a preallocated device tensor TT_ASSERT(device_tensor.workers.size(), "Workers must be specified for device_tensor in write_tensor"); + Tensor async_safe_tensor = copy_borrowed_tensor_in_async_mode(device_tensor.workers.at(0), host_tensor); + TT_FATAL( + async_safe_tensor.storage_type() == StorageType::BORROWED or + async_safe_tensor.storage_type() == StorageType::OWNED or + async_safe_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST, + "write_tensor only supports host_tensor to device_tensor data transfer"); + uint32_t host_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); for (int worker_index = 0; worker_index < device_tensor.workers.size(); ++worker_index) { auto& worker = device_tensor.workers[worker_index]; - worker->push_work([cq_id, worker, worker_index, async_safe_tensor, device_tensor]() mutable { - TT_FATAL( - async_safe_tensor.storage_type() == StorageType::BORROWED or - async_safe_tensor.storage_type() == StorageType::OWNED or - async_safe_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST, - "write_tensor only supports host_tensor to device_tensor data transfer"); + worker->push_work([cq_id, worker, worker_index, async_safe_tensor, device_tensor, sub_device_ids]() mutable { TT_FATAL( device_tensor.storage_type() == StorageType::DEVICE or device_tensor.storage_type() == StorageType::MULTI_DEVICE, @@ -895,33 +893,51 @@ void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id async_safe_tensor.get_tensor_spec().page_config() == device_tensor.get_tensor_spec().page_config(), "Error"); std::visit( - [worker_index, worker, cq_id, &async_safe_tensor](auto&& s) { - void* host_data = nullptr; - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - if (std::holds_alternative(async_safe_tensor.get_storage())) { - // Handle case when writing borrowed tensor single device tensor (only allowed for sync - // mode) - auto host_storage = std::get(async_safe_tensor.get_storage()); - std::visit([&host_data](auto&& b) { host_data = b.data(); }, host_storage.buffer); - } else { - TT_ASSERT( - std::holds_alternative(async_safe_tensor.get_storage()), - "Unexpected type {}", - tt::stl::get_active_type_name_in_variant(async_safe_tensor.get_storage())); - auto host_storage = std::get(async_safe_tensor.get_storage()); - std::visit([&host_data](auto&& b) { host_data = b.begin(); }, host_storage.get_buffer()); - } - EnqueueWriteBuffer(worker->command_queue(cq_id), s.get_buffer(), host_data, false); - } else if constexpr (std::is_same_v) { + tt::stl::overloaded{ + [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids](const DeviceStorage& device_storage) { + // Copying from host to a single device. + void* host_data = std::visit( + tt::stl::overloaded{ + [](BorrowedStorage s) { + return std::visit([](auto&& b) { return b.data(); }, s.buffer); + }, + [](OwnedStorage s) { + return std::visit([](auto&& b) { return static_cast(b.begin()); }, s.buffer); + }, + [](const MultiDeviceHostStorage& host_storage) { + TT_ASSERT( + host_storage.num_buffers() == 1, + "Cannot copy multi-buffer host storage to a single device"); + return std::visit( + [](auto&& b) -> void* { return b.begin(); }, host_storage.get_buffer(0)); + }, + [](auto&&) -> void* { TT_THROW("Unreachable"); }, + }, + async_safe_tensor.get_storage()); + EnqueueWriteBuffer( + worker->command_queue(cq_id), + device_storage.get_buffer(), + host_data, + /*blocking=*/false, + sub_device_ids); + }, + [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids](const MultiDeviceStorage& device_storage) { + // Copying from host to multi-device. + TT_ASSERT( + std::holds_alternative(async_safe_tensor.get_storage()), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(async_safe_tensor.get_storage())); auto host_storage = std::get(async_safe_tensor.get_storage()); - std::visit( - [worker_index, &host_data](auto&& b) { host_data = b.begin(); }, - host_storage.get_buffer(worker_index)); + void* host_data = std::visit( + [](auto&& b) -> void* { return b.begin(); }, host_storage.get_buffer(worker_index)); EnqueueWriteBuffer( - worker->command_queue(cq_id), s.get_buffer_for_device(worker), host_data, false); - } - }, + worker->command_queue(cq_id), + device_storage.get_buffer_for_device(worker), + host_data, + /*blocking=*/false, + sub_device_ids); + }, + [](auto&& s) { TT_THROW("Unreachable"); }}, device_tensor.get_storage()); }); } diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 7a2976ac8f2..b8b7a993b8a 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -16,6 +16,7 @@ #include "common/test_tiles.hpp" #include "common/tt_backend_api_types.hpp" #include "ttnn/common/constants.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/tensor/types.hpp" #include "ttnn/tensor/tensor_spec.hpp" #include "ttnn/tensor/layout/tensor_layout.hpp" @@ -140,19 +141,21 @@ struct Tensor { Tensor to( Device* target_device, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; + const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor to( distributed::MeshDevice* mesh_device, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; - - Tensor to( - CommandQueue& queue, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; + const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor to( const std::vector& workers, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; + const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor to(Layout target_layout, Device* worker = nullptr) const; @@ -163,7 +166,10 @@ struct Tensor { const ttnn::SimpleShape& input_tensor_start, float pad_value) const; - Tensor cpu(bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId) const; + Tensor cpu( + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor cpu_sharded() const; @@ -373,7 +379,11 @@ Tensor allocate_tensor_on_devices( const std::vector& devices, const MemoryConfig& memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, const std::optional& tile = std::nullopt); -void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId); +void write_tensor( + const Tensor& host_tensor, + Tensor device_tensor, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}); Tensor set_tensor_id(const Tensor& tensor); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 0386f6e353c..dc7545ac0e5 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -565,7 +565,11 @@ std::string to_string(const Tensor& tensor, std::optional o // ====================================================================================== template -Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId) { +Tensor to_host_helper( + const Tensor& tensor, + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}) { TT_ASSERT(tensor.is_allocated(), "Buffer must be allocated on device!"); auto device_buffer = tensor.device_buffer(); auto device = tensor.device(); @@ -575,7 +579,8 @@ Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { data_vec.resize(size_in_bytes / sizeof(T)); - read_data_from_device_buffer(device->command_queue(cq_id), device_buffer, data_vec.data(), blocking); + read_data_from_device_buffer( + device->command_queue(cq_id), device_buffer, data_vec.data(), blocking, sub_device_ids); } else { read_data_from_device_buffer(device_buffer, data_vec); } @@ -584,9 +589,9 @@ Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id } template -Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { +Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids) { if (tensor.storage_type() == StorageType::DEVICE) { - return to_host_helper(tensor, blocking, cq_id); + return to_host_helper(tensor, blocking, cq_id, sub_device_ids); } else if (tensor.storage_type() == StorageType::MULTI_DEVICE) { auto devices = get_devices(tensor); Tensor host_tensor(devices.size()); @@ -594,7 +599,7 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { for (int device_index = 0; device_index < devices.size(); ++device_index) { const auto& device = devices[device_index]; auto shard = get_shard_for_device(tensor, device); - shard = to_host_helper(shard, blocking, cq_id); + shard = to_host_helper(shard, blocking, cq_id, sub_device_ids); insert_buffer_and_shape_for_device(device, shard, host_tensor, device_index); } return host_tensor; @@ -603,21 +608,29 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { } } -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); template <> -Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { - return to_host(tensor, blocking, cq_id); +Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids) { + return to_host(tensor, blocking, cq_id, sub_device_ids); } template <> -Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { - return to_host(tensor, blocking, cq_id); +Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids) { + return to_host(tensor, blocking, cq_id, sub_device_ids); } // ====================================================================================== @@ -662,7 +675,11 @@ Tensor to_host_sharded(const Tensor& tensor) { // ====================================================================================== template typename BufferType> -void write_data_to_device_buffer(CommandQueue& cq, const BufferType& host_buffer, DeviceBuffer device_buffer) { +void write_data_to_device_buffer( + CommandQueue& cq, + const BufferType& host_buffer, + DeviceBuffer device_buffer, + tt::stl::Span sub_device_ids) { ZoneScoped; // TODO(arakhmati): can we use generators in this function to go from `data_to_write` to `uint32_data`? // And effectively get rid of any additional allocation @@ -676,12 +693,12 @@ void write_data_to_device_buffer(CommandQueue& cq, const BufferType& host_buf const uint32_t* borrowed_buf_base = static_cast(host_buffer.data()); std::vector owned_copy_vec(borrowed_buf_base, borrowed_buf_base + borrowed_buf_size_words); owned_buffer::Buffer owned_copy(std::make_shared>(owned_copy_vec)); - EnqueueWriteBuffer(cq, device_buffer, owned_copy.get_ptr(), false); + EnqueueWriteBuffer(cq, device_buffer, owned_copy.get_ptr(), false, sub_device_ids); } else if constexpr (std::is_same_v, owned_buffer::Buffer>) { - EnqueueWriteBuffer(cq, device_buffer, host_buffer.get_ptr(), false); + EnqueueWriteBuffer(cq, device_buffer, host_buffer.get_ptr(), false, sub_device_ids); } } else { - EnqueueWriteBuffer(cq, device_buffer, host_buffer.data(), false); + EnqueueWriteBuffer(cq, device_buffer, host_buffer.data(), false, sub_device_ids); } } @@ -699,7 +716,8 @@ DeviceBuffer initialize_data_on_device( BufferType& data_to_write, Device* device, const TensorSpec& tensor_spec, - std::optional> queue = std::nullopt) { + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}) { ZoneScoped; TT_ASSERT(device != nullptr); @@ -707,8 +725,7 @@ DeviceBuffer initialize_data_on_device( const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { - write_data_to_device_buffer( - queue.has_value() ? queue.value().get() : device->command_queue(), data_to_write, device_buffer); + write_data_to_device_buffer(device->command_queue(cq_id), data_to_write, device_buffer, sub_device_ids); } else { write_data_to_device_buffer(data_to_write, *device_buffer); } @@ -720,13 +737,14 @@ DeviceBuffer to_device_buffer( const Storage& storage, Device* device, const TensorSpec& tensor_spec, - std::optional> queue) { + uint8_t cq_id, + tt::stl::Span sub_device_ids) { return std::visit( - [&device, &tensor_spec, &queue](auto&& storage) -> DeviceBuffer { + [&device, &tensor_spec, cq_id, sub_device_ids](auto&& storage) -> DeviceBuffer { using StorageType = std::decay_t; if constexpr (std::is_same_v or std::is_same_v) { auto data_to_write = host_buffer::get_as(storage.buffer); - return initialize_data_on_device(data_to_write, device, tensor_spec, queue); + return initialize_data_on_device(data_to_write, device, tensor_spec, cq_id, sub_device_ids); } else if constexpr (std::is_same_v) { TT_THROW("Device storage doesn't support to_device_buffer"); } else if constexpr (std::is_same_v) { @@ -749,7 +767,8 @@ Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue) { + uint8_t cq_id, + tt::stl::Span sub_device_ids) { TT_FATAL(tensor.storage_type() != StorageType::DEVICE, "Tensor is already on device!"); if (tensor.storage_type() == StorageType::OWNED) { TT_FATAL(tensor.is_allocated(), "Need host buffer on device to exist to copy data to device!"); @@ -759,7 +778,8 @@ Tensor to_device( TensorSpec tensor_spec( tensor.get_logical_shape(), tensor.get_tensor_spec().tensor_layout().with_memory_config(memory_config)); - auto device_buffer = tensor_impl::to_device_buffer(tensor.get_storage(), target_device, tensor_spec, queue); + auto device_buffer = + tensor_impl::to_device_buffer(tensor.get_storage(), target_device, tensor_spec, cq_id, sub_device_ids); return Tensor(DeviceStorage{device_buffer}, tensor_spec); } @@ -767,40 +787,47 @@ template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template <> Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue) { - return to_device(tensor, target_device, memory_config, queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids) { + return to_device(tensor, target_device, memory_config, cq_id, sub_device_ids); } template <> @@ -808,8 +835,9 @@ Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue) { - return to_device(tensor, target_device, memory_config, queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids) { + return to_device(tensor, target_device, memory_config, cq_id, sub_device_ids); } // ====================================================================================== diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index 5a0ec30ecdd..87c34bdb199 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -167,8 +167,12 @@ DeviceBuffer allocate_buffer_on_device(Device* device, const TensorSpec& tensor_ template inline void read_data_from_device_buffer( - CommandQueue& cq, DeviceBuffer device_buffer, void* host_buffer_data, bool blocking) { - EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking); + CommandQueue& cq, + DeviceBuffer device_buffer, + void* host_buffer_data, + bool blocking, + tt::stl::Span sub_device_ids = {}) { + EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking, sub_device_ids); } template @@ -181,7 +185,11 @@ inline void read_data_from_device_buffer(DeviceBuffer device_buffer, std::vector // ====================================================================================== template -Tensor to_host(const Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId); +Tensor to_host( + const Tensor& tensor, + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}); template Tensor to_host_sharded(const Tensor& tensor); @@ -191,7 +199,8 @@ Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}); template Tensor to_layout(const Tensor& tensor, Layout target_layout); diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 8a46d676dc7..f40690d2a44 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -24,7 +24,12 @@ namespace tt::tt_metal::tensor_ops { -Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const MemoryConfig& mem_config) { +Tensor tensor_to( + const Tensor& input_tensor, + Device* target_device, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, target_device, mem_config); // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. @@ -35,7 +40,12 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory // Record main thread ref count for tensors before pushing to queue. uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); - target_device->push_work([async_safe_tensor, device_tensor, mem_config, target_device]() mutable { + target_device->push_work([async_safe_tensor, + device_tensor, + mem_config, + target_device, + cq_id, + sub_device_ids]() mutable { if (async_safe_tensor.storage_type() == StorageType::DEVICE) { TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); device_tensor.populate_buffers_and_metadata(async_safe_tensor); @@ -46,7 +56,7 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory async_safe_tensor.get_dtype(), async_safe_tensor.get_layout()); auto local_tensor = - tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, std::nullopt); + tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, cq_id, sub_device_ids); // Populate device tensor device_tensor.populate_buffers_and_metadata(local_tensor); } @@ -61,7 +71,12 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory return device_tensor; } -Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers, const MemoryConfig& mem_config) { +Tensor tensor_to( + const Tensor& input_tensor, + const std::vector& workers, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, workers, mem_config); TT_FATAL( @@ -72,10 +87,17 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers uint32_t num_workers = workers.size(); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work([worker, input_tensor, device_tensor, mem_config, num_workers, worker_index]() mutable { + worker->push_work([worker, + input_tensor, + device_tensor, + mem_config, + num_workers, + worker_index, + cq_id, + sub_device_ids]() mutable { auto shard = get_shard_for_device(input_tensor, worker, worker_index); if (shard.storage_type() == StorageType::OWNED) { - shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); + shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, cq_id, sub_device_ids); } insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; @@ -93,7 +115,8 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers return device_tensor; } -Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { +Tensor tensor_cpu( + const Tensor& input_tensor, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::cpu", input_tensor, blocking); auto workers = input_tensor.get_workers(blocking); @@ -111,19 +134,20 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { uint32_t original_tensor_ref_count = input_tensor.tensor_attributes->record_main_thread_ref_count(); for (int worker_index = 0; worker_index < workers.size(); worker_index++) { auto target_device = workers[worker_index]; - target_device->push_work([host_tensor, blocking, target_device, input_tensor, worker_index, cq_id]() mutable { - TT_ASSERT( - input_tensor.storage_type() == StorageType::DEVICE or - input_tensor.storage_type() == StorageType::MULTI_DEVICE, - "Can only use worker queue for cpu call if tensor is on device."); - auto shard = get_shard_for_device(input_tensor, target_device); - shard = tensor_impl::to_host_wrapper(shard, blocking, cq_id); - insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); - uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { - host_tensor.set_tensor_spec(input_tensor.get_tensor_spec()); - } - }); + target_device->push_work( + [host_tensor, blocking, target_device, input_tensor, worker_index, cq_id, sub_device_ids]() mutable { + TT_ASSERT( + input_tensor.storage_type() == StorageType::DEVICE or + input_tensor.storage_type() == StorageType::MULTI_DEVICE, + "Can only use worker queue for cpu call if tensor is on device."); + auto shard = get_shard_for_device(input_tensor, target_device); + shard = tensor_impl::to_host_wrapper(shard, blocking, cq_id, sub_device_ids); + insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); + uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { + host_tensor.set_tensor_spec(input_tensor.get_tensor_spec()); + } + }); } if (blocking) { @@ -181,15 +205,15 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed:: ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, target_layout, mesh_device); if (mesh_device) { - auto workers = ttnn::distributed::distribute_tensor_to_mesh(input_tensor, *mesh_device); + auto workers = ttnn::distributed::get_mapped_devices(input_tensor, *mesh_device); TT_FATAL( validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); std::optional distributed_config = std::nullopt; - if (std::holds_alternative(input_tensor.get_storage())) { - auto& host_storage = std::get(input_tensor.get_storage()); - distributed_config = host_storage.strategy; + if (auto* host_storage = std::get_if(&input_tensor.get_storage()); + host_storage != nullptr) { + distributed_config = host_storage->strategy; } Tensor tensor_modified_layout = Tensor(workers.size(), distributed_config); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 98f8103c151..b8edff425f8 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -20,15 +20,26 @@ class Device; namespace tt::tt_metal::tensor_ops { -Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const MemoryConfig& mem_config); +Tensor tensor_to( + const Tensor& input_tensor, + Device* target_device, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids); -Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers, const MemoryConfig& mem_config); +Tensor tensor_to( + const Tensor& input_tensor, + const std::vector& workers, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids); Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, Device* worker); Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed::MeshDevice* mesh_device); -Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id); +Tensor tensor_cpu( + const Tensor& input_tensor, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids); Tensor tensor_cpu_sharded(const Tensor& input_tensor); diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 96ce34431b9..3c2565299b9 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -142,40 +142,6 @@ void insert_buffer_and_shape_for_device( Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor); -template -auto get_device_tensors(Device* device, const TensorContainer& input_tensors) { - // Could be Tensor, const Tensor, std::optional, or std::optional - using ValueType = typename TensorContainer::value_type; - - // We need a way to extract the underlying Tensor type (const or non-const) from ValueType - // and to decide whether we are dealing with an optional type. - using IsOptional = std::conditional_t< - std::is_same_v> || std::is_same_v>, - std::true_type, - std::false_type>; - using TensorType = std::conditional_t< - std::is_same_v> || std::is_same_v, - Tensor, - const Tensor>; - - // Result container type adjustment based on input type - using ResultType = std::conditional_t, TensorType>; - std::vector transformed_tensors; - - for (const auto& tensor : input_tensors) { - if constexpr (IsOptional::value) { - if (tensor.has_value()) { - transformed_tensors.emplace_back(get_device_tensor(tensor.value(), device)); - } else { - transformed_tensors.emplace_back(std::nullopt); - } - } else { - transformed_tensors.emplace_back(get_device_tensor(tensor, device)); - } - } - return transformed_tensors; -} - inline bool is_tensor_on_device(const ttnn::Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } inline bool is_tensor_on_multi_device(const ttnn::Tensor& tensor) { @@ -196,5 +162,4 @@ inline uint32_t get_batch_size(const T& shape) { } } // namespace tt_metal - } // namespace tt diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp index ccb86100718..6ba3893c8a1 100644 --- a/ttnn/cpp/ttnn/tensor/types.cpp +++ b/ttnn/cpp/ttnn/tensor/types.cpp @@ -6,9 +6,7 @@ #include "ttnn/tensor/types.hpp" #include "ttnn/tensor/tensor_impl.hpp" -namespace ttnn { - -namespace types { +namespace ttnn::types { const Shape Shape::to_rank(size_t new_rank) const { auto padded_shape = value; @@ -31,42 +29,10 @@ const Shape Shape::to_rank(size_t new_rank) const { return Shape(std::move(new_shape), std::move(new_padded_shape)); } -} // namespace types - -} // namespace ttnn +} // namespace ttnn::types namespace tt::tt_metal { -static DistributedTensorConfig create_shard_distributed_tensor_config( - const std::unordered_map& metadata) { - return ShardTensor(std::stoi(metadata.at("shard_dim"))); -} -static DistributedTensorConfig create_shard_2d_distributed_tensor_config( - const std::unordered_map& metadata) { - return ShardTensor2D(ShardMesh(std::stoi(metadata.at("mesh_shape_y")), std::stoi(metadata.at("mesh_shape_x")))); -} -static DistributedTensorConfig create_replicate_distributed_tensor_config( - const std::unordered_map& metadata) { - if (auto it = metadata.find("replication_factor"); it != metadata.end()) { - return ReplicateTensor(std::stoi(it->second)); - } - TT_THROW("Unsupported Replication strategy:"); -} - -DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map& metadata) { - if (auto it = metadata.find("strategy"); it != metadata.end()) { - const std::string& strategy = it->second; - if (strategy == "shard") { - return create_shard_distributed_tensor_config(metadata); - } else if (strategy == "shard_2d") { - return create_shard_2d_distributed_tensor_config(metadata); - } else if (strategy == "replicate") { - return create_replicate_distributed_tensor_config(metadata); - } - } - TT_THROW("Unsupported DistributedTensorConfig strategy:"); -} - tt::DataFormat datatype_to_dataformat_converter(tt::tt_metal::DataType datatype) { switch (datatype) { case tt::tt_metal::DataType::BFLOAT16: return tt::DataFormat::Float16_b; @@ -218,20 +184,6 @@ Array4D LegacyShape::to_array_4D() const { return ret_array; } -bool operator==(const ReplicateTensor& a, const ReplicateTensor& b) { - return a.replication_factor == - b.replication_factor; // All instances are considered equal because there are no data members. -} -bool operator==(const AllGatherTensor&, const AllGatherTensor&) { - return true; // All instances are considered equal because there are no data members. -} -bool operator==(const ShardTensor& lhs, const ShardTensor& rhs) { - return lhs.shard_dimension == rhs.shard_dimension; // Equal if they have the same shard_dimension. -} -bool operator==(const ShardTensor2D& lhs, const ShardTensor2D& rhs) { - return lhs.shard_mesh == rhs.shard_mesh; // Equal if they have the same shard_mesh. -} - bool operator==(const tt::tt_metal::LegacyShape& shape_a, const tt::tt_metal::LegacyShape& shape_b) { if (shape_a.rank() != shape_b.rank()) { return false; diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 3666c710113..58299ece07b 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -18,6 +18,7 @@ #include "tt_metal/tt_stl/concepts.hpp" #include "tt_metal/tt_stl/reflection.hpp" #include "tt_metal/tt_stl/span.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/tensor/host_buffer/types.hpp" #include "ttnn/cpp/ttnn/tensor/enum_types.hpp" @@ -41,6 +42,25 @@ enum class DataType { INVALID = 8, }; +template +consteval inline DataType convert_to_data_type() { + if constexpr (std::is_same_v) { + return DataType::UINT8; + } else if constexpr (std::is_same_v) { + return DataType::UINT16; + } else if constexpr (std::is_same_v) { + return DataType::INT32; + } else if constexpr (std::is_same_v) { + return DataType::UINT32; + } else if constexpr (std::is_same_v) { + return DataType::FLOAT32; + } else if constexpr (std::is_same_v) { + return DataType::BFLOAT16; + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported DataType!"); + } +} + inline bool is_floating_point(DataType dtype) { switch (dtype) { case DataType::BFLOAT16: @@ -59,31 +79,6 @@ enum class StorageType { MULTI_DEVICE_HOST, // host storage for multi-device context }; -struct AllGatherTensor {}; -bool operator==(const AllGatherTensor &, const AllGatherTensor &); -struct ReplicateTensor { - int replication_factor = 1; - ReplicateTensor() = default; - ReplicateTensor(int replication_factor) : replication_factor(replication_factor) {} -}; -bool operator==(const ReplicateTensor &, const ReplicateTensor &); -struct ShardTensor { - int shard_dimension; - ShardTensor(int shard_dimension) : shard_dimension(shard_dimension) {} -}; -bool operator==(const ShardTensor &lhs, const ShardTensor &rhs); - -using ShardMesh = std::pair; // (y,x) -struct ShardTensor2D { - ShardMesh shard_mesh; // logic 2D grid that defines the mapping of shards to devices - ShardTensor2D(ShardMesh mesh) : shard_mesh(std::move(mesh)) {} -}; -bool operator==(const ShardTensor2D &lhs, const ShardTensor2D &rhs); - -// DistributedTensorConfig is a variant of different ways in which a tensor can be distributed across devices. -using DistributedTensorConfig = std::variant; -DistributedTensorConfig get_distributed_tensor_config(const std::unordered_map &metadata); - tt::DataFormat datatype_to_dataformat_converter(DataType datatype); static constexpr std::size_t MAX_NUM_DIMENSIONS = 8; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 71f4f748660..4f613ca11ef 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -153,6 +153,7 @@ def manage_config(name, value): WormholeComputeKernelConfig, GrayskullComputeKernelConfig, MeshShape, + MeshOffset, UnaryWithParam, UnaryOpType, BinaryOpType, @@ -186,7 +187,10 @@ def manage_config(name, value): format_output_tensor, pad_to_tile_shape, SubDevice, + SubDeviceId, SubDeviceManagerId, + DefaultQueueId, + init_device_compute_kernel_config, ) from ttnn.profiler import start_tracy_zone, stop_tracy_zone, tracy_message, tracy_frame @@ -207,6 +211,7 @@ def manage_config(name, value): load_memory_config, dump_stack_trace_on_segfault, num_cores_to_corerangeset, + num_cores_to_corerangeset_in_subcoregrids, ) import ttnn.reflection diff --git a/ttnn/ttnn/core.py b/ttnn/ttnn/core.py index 92d840f7868..a1879be24f9 100644 --- a/ttnn/ttnn/core.py +++ b/ttnn/ttnn/core.py @@ -59,6 +59,23 @@ def num_cores_to_corerangeset( ) +def num_cores_to_corerangeset_in_subcoregrids( + start_core: ttnn.CoreCoord, + target_num_cores: int, + sub_core_grids: ttnn.CoreRangeSet, + row_wise: bool = False, +): + """ + Create a CoreRangeSet containing the specified number of cores starting from start_core in given subcoregrids + """ + return ttnn._ttnn.operations.core.num_cores_to_corerangeset_in_subcoregrids( + start_core, + target_num_cores, + sub_core_grids, + row_wise, + ) + + def has_tile_padding(tensor, *, dim=None): if dim is not None: rank = tensor.shape.rank diff --git a/ttnn/ttnn/device.py b/ttnn/ttnn/device.py index e620c800a6c..6cbfaa85ead 100644 --- a/ttnn/ttnn/device.py +++ b/ttnn/ttnn/device.py @@ -6,6 +6,7 @@ from typing import Optional, List import ttnn +import os def get_device_core_grid(device): @@ -27,6 +28,7 @@ def get_device_core_grid(device): DEFAULT_TRACE_REGION_SIZE = ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE open_device = ttnn._ttnn.device.open_device +init_device_compute_kernel_config = ttnn._ttnn.operations.core.init_device_compute_kernel_config def close_device(device: "ttnn.device.Device"): @@ -132,12 +134,25 @@ def dump_device_memory_state(device, prefix=""): ttnn._ttnn.device.DumpDeviceMemoryState(device, prefix) -def is_wormhole_b0(device): - return device.arch() == ttnn._ttnn.device.Arch.WORMHOLE_B0 +def is_wormhole_b0(device=None): + if device is not None: + return device.arch() == ttnn._ttnn.device.Arch.WORMHOLE_B0 + ARCH_NAME = os.environ.get("ARCH_NAME", os.environ.get("TT_ARCH_NAME", "")).lower() + return "wormhole_b0" in ARCH_NAME -def is_grayskull(device): - return device.arch() == ttnn._ttnn.device.Arch.GRAYSKULL +def is_grayskull(device=None): + if device is not None: + return device.arch() == ttnn._ttnn.device.Arch.GRAYSKULL + ARCH_NAME = os.environ.get("ARCH_NAME", os.environ.get("TT_ARCH_NAME", "")).lower() + return "grayskull" in ARCH_NAME + + +def is_blackhole(device=None): + if device is not None: + return device.arch() == ttnn._ttnn.device.Arch.BLACKHOLE + ARCH_NAME = os.environ.get("ARCH_NAME", os.environ.get("TT_ARCH_NAME", "")).lower() + return "blackhole" in ARCH_NAME SetDefaultDevice = ttnn._ttnn.device.SetDefaultDevice @@ -147,6 +162,9 @@ def is_grayskull(device): pad_to_tile_shape = ttnn._ttnn.device.pad_to_tile_shape SubDevice = ttnn._ttnn.device.SubDevice +SubDeviceId = ttnn._ttnn.device.SubDeviceId SubDeviceManagerId = ttnn._ttnn.device.SubDeviceManagerId +DefaultQueueId = ttnn._ttnn.device.DefaultQueueId + __all__ = [] diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index a4f329eb58a..65a902d11cf 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -139,7 +139,7 @@ def open_mesh_device( trace_region_size: int = ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE, num_command_queues: int = 1, dispatch_core_config: ttnn.DispatchCoreConfig = ttnn.DispatchCoreConfig(), - offset: Tuple[int, int] = (0, 0), + offset: ttnn.MeshOffset = ttnn.MeshOffset(row=0, col=0), physical_device_ids: List[int] = [], mesh_type: "MeshType" = MeshType.RowMajor, ): @@ -152,7 +152,8 @@ def open_mesh_device( trace_region_size (int, optional): Size of the trace region. Defaults to ttnn._ttnn.device.DEFAULT_TRACE_REGION_SIZE. num_command_queues (int, optional): Number of command queues. Defaults to 1. dispatch_core_type (int, optional): Type of dispatch core. Defaults to DispatchCoreType.WORKER. - offset (Tuple[int, int], optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0). + offset (ttnn.MeshOffset, optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0). + physical_device_ids (List[int], optional): List of physical device IDs to use. Defaults to []. mesh_type (MeshType, optional): Defines type of mesh requested. Type imposes connectivity constraints and defines device iteration order. Returns: @@ -160,7 +161,7 @@ def open_mesh_device( """ return ttnn._ttnn.multi_device.MeshDevice( - mesh_shape=mesh_shape.as_tuple(), + mesh_shape=mesh_shape, l1_small_size=l1_small_size, trace_region_size=trace_region_size, num_command_queues=num_command_queues, @@ -208,19 +209,23 @@ def create_mesh_device( close_mesh_device(mesh_device) -def synchronize_devices(devices: Union["ttnn.Device", "ttnn.MeshDevice"], queue_id: Optional[int] = None) -> None: +def synchronize_devices( + devices: Union["ttnn.Device", "ttnn.MeshDevice"], + queue_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], +) -> None: """ - synchronize_devices(devices: Union[ttnn.Device, ttnn.MeshDevice], queue_id: Optional[int] = None) -> None: + synchronize_devices(devices: Union[ttnn.Device, ttnn.MeshDevice], queue_id: Optional[int] = None, sub_device_ids: List[ttnn.SubDeviceId] = []) -> None: Synchronize the devices with host by waiting for all operations to complete. If queue_id is provided then only the operations associated with that queue_id are waited for, otherwise operations for all command queues are waited on. """ if isinstance(devices, ttnn.Device): - ttnn._ttnn.device.synchronize_device(devices, queue_id) + ttnn._ttnn.device.synchronize_device(devices, queue_id, sub_device_ids) else: for device in devices.get_device_ids(): - ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id) + ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) class TensorToMesh: @@ -455,6 +460,8 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) +# TODO: #15061 - Remove this function, as it does not abide to the MeshToTensor interface. +# Instead, lift this implementation to the caller. class ListMeshToTensor(MeshToTensor): def __init__(self, mesh_device: MeshDevice): self.mesh_device = mesh_device diff --git a/ttnn/ttnn/operations/conv1d.py b/ttnn/ttnn/operations/conv1d.py index e979a12b21d..ef8187cbd78 100644 --- a/ttnn/ttnn/operations/conv1d.py +++ b/ttnn/ttnn/operations/conv1d.py @@ -28,8 +28,11 @@ def Conv1d( groups: int = 1, bias_tensor: ttnn.Tensor = None, conv_config: Conv1dConfig = None, # config overrides by user + compute_config: ttnn.DeviceComputeKernelConfig = None, conv_op_cache={}, # basic conv object caching in python needed for intermediate refactoring. Not needed after full op refactoring in C++. debug=False, + return_output_dim=False, + return_weights_and_bias=False, ) -> Tuple[ttnn.Tensor, int, int, ttnn.Tensor, ttnn.Tensor]: # Reshape the input and weight tensors to 4D for conv2d operation # Should be no-op as input_tensor is in RM layout @@ -60,14 +63,17 @@ def Conv1d( groups=groups, bias_tensor=bias_tensor, conv_config=conv_config, + compute_config=compute_config, ) - return ( - output_tensor_new, - output_length_new, - weight_tensor_on_dev_new, - bias_tensor_on_dev_new, - ) + if return_output_dim and return_weights_and_bias: + return output_tensor_new, output_length_new, [weight_tensor_on_dev_new, bias_tensor_on_dev_new] + elif return_weights_and_bias: + return output_tensor_new, [weight_tensor_on_dev_new, bias_tensor_on_dev_new] + elif return_output_dim: + return output_tensor_new, output_length_new + else: + return output_tensor_new __all__ = [] diff --git a/ttnn/ttnn/operations/conv2d.py b/ttnn/ttnn/operations/conv2d.py index ef2859c43a2..84079a56653 100644 --- a/ttnn/ttnn/operations/conv2d.py +++ b/ttnn/ttnn/operations/conv2d.py @@ -176,11 +176,20 @@ def conv2d( groups: int = 1, bias_tensor: ttnn.Tensor = None, conv_config: Conv2dConfig = None, # config overrides by user + compute_config=None, # compute config overrides by user memory_config: ttnn.MemoryConfig = None, # memory config overrides by user conv_op_cache={}, # basic conv object caching in python needed for intermediate refactoring. Not needed after full op refactoring in C++. debug=False, # ignored + return_output_dim=False, + return_weights_and_bias=False, ) -> Tuple[ttnn.Tensor, int, int, ttnn.Tensor, ttnn.Tensor]: - return ttnn._ttnn.operations.conv.conv2d( + ( + conv_output, + output_height, + output_width, + prepared_device_weight, + prepared_device_bias, + ) = ttnn._ttnn.operations.conv.conv2d( input_tensor=input_tensor, weight_tensor=weight_tensor, device=device, @@ -196,8 +205,18 @@ def conv2d( groups=groups, bias_tensor=bias_tensor, conv_config=conv_config, + compute_config=compute_config, memory_config=memory_config, ) + if return_output_dim and return_weights_and_bias: + return conv_output, [output_height, output_width], [prepared_device_weight, prepared_device_bias] + elif return_weights_and_bias: + return conv_output, [prepared_device_weight, prepared_device_bias] + elif return_output_dim: + return conv_output, [output_height, output_width] + else: + return conv_output + __all__ = [] diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 3eeda3a90b6..24480037a3f 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -4,7 +4,7 @@ import math import pathlib -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import ttnn.decorators @@ -158,6 +158,8 @@ def from_torch( device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, mesh_mapper: Optional[ttnn.TensorToMesh] = None, + cq_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], ) -> ttnn.Tensor: """ Converts the `torch.Tensor` tensor into a `ttnn.Tensor`. For bfloat8_b or bfloat4_b format, the function itself is called twice, @@ -176,6 +178,8 @@ def from_torch( device (ttnn.Device, optional): the desired `ttnn` device. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The desired `ttnn` memory configuration. Defaults to `None`. mesh_mapper (ttnn.TensorToMesh, optional): The desired `ttnn` mesh mapper. Defaults to `None`. + cq_id (int, optional): The command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on. Defaults to all sub-devices. Returns: ttnn.Tensor: The resulting `ttnn` tensor. @@ -225,7 +229,7 @@ def from_torch( if device is not None: if memory_config is None: memory_config = ttnn.DRAM_MEMORY_CONFIG - tensor = ttnn.to_device(tensor, device, memory_config=memory_config) + tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id, sub_device_ids=sub_device_ids) if shape_with_padding is not None and shape_with_padding != tensor.shape and mesh_mapper is None: tensor = ttnn.reshape(tensor, shape_with_padding) @@ -262,7 +266,8 @@ def to_torch( torch_rank: Optional[int] = None, mesh_composer: Optional[ttnn.MeshToTensor] = None, device: Optional[ttnn.Device] = None, - cq_id: Optional[int] = 0, + cq_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], ) -> "torch.Tensor": """ Converts the `ttnn.Tensor` tensor into a `torch.Tensor`. It does not call to_layout for bfloat8_b or bfloat4_b as we now convert @@ -278,6 +283,7 @@ def to_torch( mesh_composer (ttnn.MeshToTensor, optional): The desired `ttnn` mesh composer. Defaults to `None`. device (ttnn.Device, optional): The `ttnn` device of the input tensor. Defaults to `None`. cq_id (int, optional): The command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on. Defaults to all sub-devices. Returns: torch.Tensor: The converted `torch` tensor. @@ -290,7 +296,7 @@ def to_torch( [ 0.9023, -0.5820, 0.5312]], dtype=torch.bfloat16) """ if ttnn.is_tensor_storage_on_device(tensor): - tensor = ttnn.from_device(tensor, cq_id=cq_id) + tensor = ttnn.from_device(tensor, cq_id=cq_id, sub_device_ids=sub_device_ids) if (tensor.layout != ttnn.ROW_MAJOR_LAYOUT) and not ( tensor.dtype == ttnn.bfloat8_b or tensor.dtype == ttnn.bfloat4_b diff --git a/ttnn/ttnn/types.py b/ttnn/ttnn/types.py index 0fd3f775313..b210fe90f5f 100644 --- a/ttnn/ttnn/types.py +++ b/ttnn/ttnn/types.py @@ -58,25 +58,14 @@ class CoreRange: end: CoreGrid -@dataclasses.dataclass -class MeshShape: - y: int - x: int - - @property - def num_devices(self): - return self.y * self.x - - def as_tuple(self): - return (self.y, self.x) - - class ShardStrategy(Enum): HEIGHT = 1 WIDTH = 2 BLOCK = 3 +MeshShape = ttnn._ttnn.multi_device.MeshShape +MeshOffset = ttnn._ttnn.multi_device.MeshOffset ShardOrientation = ttnn._ttnn.tensor.ShardOrientation ShardMode = ttnn._ttnn.tensor.ShardMode ShardSpec = ttnn._ttnn.tensor.ShardSpec