diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 7f50c8327c9..43299fa0421 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -13,4 +13,5 @@ Summarize the changes made and their impact.
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
+- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
diff --git a/.github/workflows/_produce-data.yaml b/.github/workflows/_produce-data.yaml
index fd547b44aa9..c66c5bb5702 100644
--- a/.github/workflows/_produce-data.yaml
+++ b/.github/workflows/_produce-data.yaml
@@ -22,7 +22,8 @@ on:
- "(Single-card) Model perf tests"
- "(Single-card) Device perf tests"
- "(Single-card) Demo tests"
+ - "(Single-card) Tests for new models"
- "Nightly fast dispatch tests"
+ - "(Single-card) Tests for new models"
- "(T3K) T3000 demo tests"
- "(T3K) T3000 model perf tests"
- "(T3K) T3000 perplexity tests"
@@ -39,6 +40,7 @@ on:
- "(TGG) TGG frequent tests"
- "ttnn - Run sweeps"
- "Blackhole post-commit tests"
+ - "Custom test dispatch"
types:
- completed
diff --git a/.github/workflows/all-post-commit-workflows.yaml b/.github/workflows/all-post-commit-workflows.yaml
index dcaf74380c1..57bae427f2f 100644
--- a/.github/workflows/all-post-commit-workflows.yaml
+++ b/.github/workflows/all-post-commit-workflows.yaml
@@ -2,6 +2,11 @@ name: "All post-commit tests"
on:
workflow_call:
+ inputs:
+ build-type:
+ required: false
+ default: Release
+ type: string
workflow_dispatch:
inputs:
build-type:
diff --git a/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml b/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml
index 762324fb3a3..0af646345b1 100644
--- a/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml
+++ b/.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml
@@ -149,8 +149,8 @@ jobs:
fail-fast: false
matrix:
test-config:
- - model: "wh_b0_unstable"
- cmd: ./tests/scripts/single_card/nightly/run_wh_b0_unstable.sh
+ - model: "stable_diffusion"
+ cmd: pytest --timeout 900 -n auto tests/nightly/single_card/stable_diffusion
- model: "mamba 1"
cmd: pytest --timeout 900 -n auto tests/nightly/single_card/mamba --splits 6 --group 1
- model: "mamba 2"
diff --git a/.github/workflows/full-new-models-suite.yaml b/.github/workflows/full-new-models-suite.yaml
new file mode 100644
index 00000000000..15ecd80c104
--- /dev/null
+++ b/.github/workflows/full-new-models-suite.yaml
@@ -0,0 +1,60 @@
+name: "(Single-card) Tests for new models"
+
+on:
+ workflow_dispatch:
+ inputs:
+ build-type:
+ required: false
+ default: Release
+ type: choice
+ options:
+ - Release
+ - Debug
+ - RelWithDebInfo
+ - CI
+
+permissions:
+ actions: read
+ contents: write
+ pull-requests: write
+ pages: write
+ id-token: write
+ packages: write
+
+jobs:
+ build-docker-image-2004:
+ uses: ./.github/workflows/build-docker-artifact.yaml
+ secrets: inherit
+ with:
+ os: ubuntu-20.04-amd64
+ build-artifact:
+ needs: build-docker-image-2004
+ uses: ./.github/workflows/build-artifact.yaml
+ secrets: inherit
+ with:
+ build-docker: false
+ build-type: ${{ inputs.build-type || 'Release' }}
+ build-artifact-profiler:
+ needs: build-docker-image-2004
+ uses: ./.github/workflows/build-artifact.yaml
+ with:
+ tracy: true
+ build-docker: false
+ build-type: ${{ inputs.build-type || 'Release' }}
+ secrets: inherit
+ device-perf-single-card:
+ needs: build-artifact-profiler
+ uses: ./.github/workflows/perf-device-models-impl.yaml
+ secrets: inherit
+ e2e-model-perf-single-card:
+ needs: build-artifact
+ uses: ./.github/workflows/perf-models-impl.yaml
+ secrets: inherit
+ nightly-single-card:
+ needs: build-artifact
+ uses: ./.github/workflows/fast-dispatch-full-regressions-and-models-impl.yaml
+ secrets: inherit
+ demos-single-card:
+ needs: build-artifact
+ uses: ./.github/workflows/single-card-demo-tests-impl.yaml
+ secrets: inherit
diff --git a/.github/workflows/t3000-demo-tests-impl.yaml b/.github/workflows/t3000-demo-tests-impl.yaml
index f71636bdb15..9ad4ab1b818 100644
--- a/.github/workflows/t3000-demo-tests-impl.yaml
+++ b/.github/workflows/t3000-demo-tests-impl.yaml
@@ -16,7 +16,7 @@ jobs:
test-group: [
{ name: "t3k_falcon40b_tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 50, owner_id: U053W15B6JF}, #Djordje Ivanovic
{ name: "t3k_llama3_tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 30, owner_id: U03PUAKE719}, # Miguel Tairum
- # { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
+ { name: "t3k_llama3_vision_tests", arch: wormhole_b0, cmd: run_t3000_llama3_vision_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
{ name: "t3k_llama3_70b_tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
{ name: "t3k_falcon7b_tests", arch: wormhole_b0, cmd: run_t3000_falcon7b_tests, timeout: 90, owner_id: U05RWH3QUPM}, #Salar Hosseini
{ name: "t3k_mixtral_tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 50, owner_id: U03PUAKE719}, # Miguel Tairum
diff --git a/.github/workflows/t3000-frequent-tests-impl.yaml b/.github/workflows/t3000-frequent-tests-impl.yaml
index 542e85187c6..11a2df7b146 100644
--- a/.github/workflows/t3000-frequent-tests-impl.yaml
+++ b/.github/workflows/t3000-frequent-tests-impl.yaml
@@ -18,9 +18,10 @@ jobs:
{ name: "t3k ethernet tests", arch: wormhole_b0, cmd: run_t3000_ethernet_tests, timeout: 60, owner_id: ULMEPM2MA}, #Sean Nijjar
{ name: "t3k trace stress tests", arch: wormhole_b0, cmd: run_t3000_trace_stress_tests, timeout: 120, owner_id: U03NG0A5ND7}, #Aditya Saigal
{ name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 120, owner_id: U04S2UV6L8N}, #Sofija Jovic
- # { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich
- # { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich
+ { name: "t3k llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich
+ { name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich
{ name: "t3k llama3 tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 45, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
+ { name: "t3k llama3 accuracy tests", arch: wormhole_b0, cmd: run_t3000_llama3_accuracy_tests, timeout: 45, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
{ name: "t3k llama2_70b tests", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich
# { name: "t3k llama3_70b tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich # FIXME issue #14934
{ name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml
index f05ee8e7810..303de478fd7 100644
--- a/.github/workflows/t3000-unit-tests-impl.yaml
+++ b/.github/workflows/t3000-unit-tests-impl.yaml
@@ -20,8 +20,8 @@ jobs:
{ name: "t3k falcon40b tests", arch: wormhole_b0, cmd: run_t3000_falcon40b_tests, timeout: 30, owner_id: U053W15B6JF}, #Djordje Ivanovic
{ name: "t3k llama3-small tests", arch: wormhole_b0, cmd: run_t3000_llama3-small_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
{ name: "t3k llama3.2-11b tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
- # { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
- # { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
+ { name: "t3k llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
+ { name: "t3k n300 mesh llama3.2-11b-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests, timeout: 30, owner_id: U03FJB5TM5Y}, #Colman Glagovich
{ name: "t3k llama3.1-70b tests", arch: wormhole_b0, cmd: run_t3000_llama3.1-70b_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
{ name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 30, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
{ name: "t3k grok tests", arch: wormhole_b0, cmd: run_t3000_grok_tests, timeout: 30, owner_id: U03HY7MK4BT}, #Mark O'Connor
diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml
index 99ea50f9860..22ef67bb730 100644
--- a/.github/workflows/ttnn-run-sweeps.yaml
+++ b/.github/workflows/ttnn-run-sweeps.yaml
@@ -348,6 +348,7 @@ on:
- transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads_kv_input
- transformer.attention_softmax.attention_softmax
- transformer.attention_softmax.attention_softmax_
+ - transformer.rotary_embedding.rotary_embedding
- data_movement.stack.stack_pytorch2
- data_movement.repeat.repeat_pytorch2
- data_movement.split.split_pytorch2
diff --git a/.gitmodules b/.gitmodules
index d20bc574cc3..4029f9918d6 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,3 @@
-[submodule "third_party/pybind11"]
- path = tt_metal/third_party/pybind11
- url = https://github.com/pybind/pybind11.git
[submodule "third_party/lfs"]
path = tt_metal/third_party/lfs
url = https://github.com/tenstorrent-metal/lfs.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a4146a61cb9..f929b62c989 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -259,18 +259,6 @@ if(ENABLE_TRACY)
add_link_options(-rdynamic)
endif()
-if(WITH_PYTHON_BINDINGS)
- # Can't use the `REUSE_FROM` option bc tt_lib and ttnn have different build flags :(
- add_library(pch_pybinds INTERFACE)
- target_precompile_headers(
- pch_pybinds
- INTERFACE
- ${PROJECT_SOURCE_DIR}/tt_metal/third_party/pybind11/include/pybind11/operators.h
- ${PROJECT_SOURCE_DIR}/tt_metal/third_party/pybind11/include/pybind11/pybind11.h
- ${PROJECT_SOURCE_DIR}/tt_metal/third_party/pybind11/include/pybind11/stl.h
- )
-endif()
-
############################################################################################################################
# Build subdirectories
############################################################################################################################
diff --git a/CODEOWNERS b/CODEOWNERS
index aa80b7671c4..3b74d00a047 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -173,10 +173,6 @@ tests/**/dtx/ @mywoodstock @sankarmanoj-tt
tests/**/*test*conv*.py @mywoodstock @sankarmanoj-tt
tests/python_api_testing/conv/ @mywoodstock @sankarmanoj-tt
tests/python_api_testing/unit_testing/fallback_ops @tt-aho
-tests/ttnn/integration_tests/stable_diffusion @esmalTT @uaydonat @mywoodstock
-tests/device_perf_tests/stable_diffusion/test_perf_stable_diffusion.py @esmalTT @uaydonat @mywoodstock
-tests/ttnn/integration_tests/unet @esmalTT @uaydonat @mywoodstock
-tests/nightly/wh_b0_only_eth/experimental/functional_unet @esmalTT @uaydonat @mywoodstock
scripts/profiler/ @mo-tenstorrent
scripts/docker @tenstorrent/metalium-developers-infra
diff --git a/README.md b/README.md
index 2765715fea1..d06621a08fa 100644
--- a/README.md
+++ b/README.md
@@ -21,29 +21,34 @@
---
## LLMs
-| Model | Batch | Hardware | ttft (ms) | t/s/u | Target t/s/u | t/s | Release |
-|---------------------------------------------------------------|-------|----------------------------------------------------------|----------|-------|-----------------|--------|---------------------------------------------------------------------------|
-| [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | |
-| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 17.6 | 26 | 563.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) |
-| [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) |
-| [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 48 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) |
-| [LLaMA-3.1-8B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 202 | 28.6 | 23 | 28.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) |
-| [LLaMA-3.2-1B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 90.8 | 160 | 90.8 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) |
-| [LLaMA-3.2-3B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 112 | 49.1 | 60 | 49.1 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) |
-| [Falcon7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 97 | 14.6 | 26 | 3737.6 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) |
-| [LLaMA-3.1-70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 190 | 15.1 | 20 | 483.2 | [v0.53.0-rc36](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc36) |
-| [Falcon40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) |
-| [Mixtral7Bx8 (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 230 | 14.6 | 33 | 467.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) |
-| [Falcon7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 242 | 4.4 | 26 | 4505.6 | [v0.53.0-rc33](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc33) |
-| [LLaMA-3.1-70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 190 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) |
-> **Last Update:** December 2, 2024
+| Model | Batch | Hardware | ttft (ms) | t/s/u | Target t/s/u | t/s | TT-Metalium Release | vLLM Tenstorrent Repo Release |
+|---------------------------------------------------------------|-------|----------------------------------------------------------|-----------|-------|-----------------|--------|---------------------------------------------------|---------------------------------------------------------------------------------------------------|
+| [Falcon 7B (decode only)](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | | |
+| [Falcon 7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 17.6 | 26 | 563.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | |
+| [Mistral 7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | |
+| [Mamba 2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 48 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | |
+| [Llama 3.1 8B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 202 | 28.6 | 23 | 28.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | |
+| [Llama 3.2 1B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 71 | 90.8 | 160 | 90.8 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | |
+| [Llama 3.2 3B](./models/demos/llama3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 112 | 49.1 | 60 | 49.1 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | |
+| [Falcon 7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 97 | 14.6 | 26 | 3737.6 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | |
+| [Llama 3.1 70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 190 | 15.1 | 20 | 483.2 | [v0.53.0-rc36](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc36) | [384f179](https://github.com/tenstorrent/vllm/tree/384f1790c3be16e1d1b10de07252be2e66d00935) |
+| [Falcon 40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.53.1-rc7](https://github.com/tenstorrent/tt-metal/tree/v0.53.1-rc7) | |
+| [Mixtral 8x7B (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 230 | 14.6 | 33 | 467.2 | [v0.53.0-rc44](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc44) | |
+| [Falcon 7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 242 | 4.4 | 26 | 4505.6 | [v0.53.0-rc33](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc33) | |
+| [Llama 3.1 70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 190 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | |
+
+> **Last Update:** December 7, 2024
+>
> **Notes:**
+>
+> - ttft = time to first token | t/s/u = tokens/second/user | t/s = tokens/second; where t/s = t/s/u * batch.
> - TP = Tensor Parallel, DP = Data Parallel; Defines parallelization factors across multiple devices.
> - The reported LLM performance is for an input sequence length (number of rows filled in the KV cache) of 128 for all models except Mamba (which can accept any sequence length).
> - The t/s/u reported is the throughput of the first token generated after prefill, i.e. 1 / inter token latency.
## CNNs
+
| Model | Batch | Hardware | fps | Target fps | Release |
|-----------------------------------------------------------------------------|-------|----------------------------------------------------------|---------|------------|-------------|
| [ResNet-50 (224x224)](./models/demos/grayskull/resnet50) | 20 | [e150](https://tenstorrent.com/hardware/grayskull) | 5,100 | 10,000 | |
@@ -55,11 +60,11 @@
| [ViT (224x224)](./models/demos/grayskull/vit) | 9 | [e150](https://tenstorrent.com/hardware/grayskull) | 1,360 | 2,000 | |
| [ViT (224x224)](./models/demos/wormhole/vit) | 8 | [n150](https://tenstorrent.com/hardware/wormhole) | 912 | 1,600 | |
| [Stable Diffusion 1.4 (512x512)](./models/demos/wormhole/stable_diffusion) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.167 | 0.3 | |
-| [Yolo V4 (320x320)](./models/demos/yolov4) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 95 | 300 | |
-| [Segformer Semantic Segmentation (512x512)](./models/demos/segformer) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 90 | 300 | |
-
+| [YOLOv4 (320x320)](./models/demos/yolov4) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 95 | 300 | |
+| [SegFormer Semantic Segmentation (512x512)](./models/demos/segformer) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 90 | 300 | |
## NLPs
+
| Model | Batch | Hardware | sen/sec | Target sen/sec | Release |
|-----------------------------------------------------|-------|----------------------------------------------------|---------|----------------|---------|
| [BERT-Large](./models/demos/metal_BERT_large_11/) | 12 | [e150](https://tenstorrent.com/hardware/grayskull) | 370 | 410 | |
@@ -68,9 +73,11 @@
| [Bloom](./models/demos/grayskull/functional_bloom) | | [e150](https://tenstorrent.com/hardware/grayskull) | 70 | | |
## Model Updates
+
For the latest model updates and features, please see [MODEL_UPDATES.md](models/MODEL_UPDATES.md)
## TT-NN Tech Reports
+
- [Advanced Performance Optimizations for Models](./tech_reports/AdvancedPerformanceOptimizationsForModels/AdvancedPerformanceOptimizationsForModels.md) (updated Dec 4th)
- [Programming Mesh of Devices](./tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md) (updated Sept 9th)
- [ViT Implementation in TT-NN on GS](./tech_reports/ViT-TTNN/vit.md) (updated Sept 22nd)
@@ -78,8 +85,8 @@ For the latest model updates and features, please see [MODEL_UPDATES.md](models/
- [YOLOv4 Implementation in TT-NN on WH](./tech_reports/YoloV4-TTNN/yolov4.md) (updated November 8th)
## Benchmarks
-- [Matrix Multiply FLOPS on WH](./tech_reports/GEMM_FLOPS/GEMM_FLOPS.md) (updated November 13th)
+- [Matrix Multiply FLOPS on WH](./tech_reports/GEMM_FLOPS/GEMM_FLOPS.md) (updated November 13th)
---
@@ -89,7 +96,6 @@ For the latest model updates and features, please see [MODEL_UPDATES.md](models/
**TT-Metalium** is our low-level programming model, enabling kernel development for Tenstorrent hardware.
-
[Programming Guide](./METALIUM_GUIDE.md) | [API Reference](https://docs.tenstorrent.com/tt-metalium/latest/tt_metal/apis/index.html)
@@ -102,6 +108,7 @@ For the latest model updates and features, please see [MODEL_UPDATES.md](models/
Get started with [simple kernels](https://docs.tenstorrent.com/tt-metalium/latest/tt_metal/examples/index.html).
## TT-Metalium Tech Reports
+
- [Matrix Engine](./tech_reports/matrix_engine/matrix_engine.md) (updated Sept 6th)
- [Data Formats](./tech_reports/data_formats/data_formats.md) (updated Sept 7th)
- [Reconfiguring Data Formats](./tech_reports/data_formats/reconfig_data_format.md) (updated Oct 17th)
@@ -113,24 +120,36 @@ Get started with [simple kernels](https://docs.tenstorrent.com/tt-metalium/lates
- [CNNs on TT Architectures](./tech_reports/CNNs/ttcnn.md) (updated Sept 6th)
- [Ethernet and Multichip Basics](./tech_reports/EthernetMultichip/BasicEthernetGuide.md) (Updated Sept 20th)
- [Collective Communication Library (CCL)](./tech_reports/EthernetMultichip/CclDeveloperGuide.md) (Updated Sept 20th)
-- [Blackhole Bring-Up Prgramming Guide](./tech_reports/Blackhole/BlackholeBringUpProgrammingGuide.md) (Updated Oct 30th)
+- [Blackhole Bring-Up Programming Guide](./tech_reports/Blackhole/BlackholeBringUpProgrammingGuide.md) (Updated Oct 30th)
## TT-Metalium Programming Examples
+
### Hello World
+
- [Hello World! Compute Kernel](./tech_reports/prog_examples/hello_world_compute/hello_world_compute.md)
- [Hello World! Data Movement Kernel](./tech_reports/prog_examples/hello_world_data_movement/hello_world_data_movement.md)
+
### Add Integers
+
- [Add 2 Integers in Baby RiscV](./tech_reports/prog_examples/add_2_integers_in_riscv/add_2_integers_in_riscv.md)
- [Add 2 Integers in Compute Kernel](./tech_reports/prog_examples/add_2_integers_in_compute/add_2_integers_in_compute.md)
+
### Simple Tensor Manipulation
+
- [Sharding](./tech_reports/prog_examples/shard_data_rm/shard_data_rm.md)
- [Padding](./tech_reports/prog_examples/pad_multi_core/pad_multi_core.md)
+
### DRAM Data Movement
+
- [Dram Loopback Data Movement](./tech_reports/prog_examples/dram_loopback/dram_loopback.md)
+
### Eltwise
+
- [Eltwise Unary OP in Vector Engine (SFPU)](./tech_reports/prog_examples/eltwise_sfpu/eltwise_sfpu.md)
- [Eltwise Binary OP in Matrix Engine (FPU)](./tech_reports/prog_examples/eltwise_binary/eltwise_binary.md)
+
### Matmul
+
- [Matmul OP on a Single_core](./tech_reports/prog_examples/matmul_single_core/matmul_single_core.md)
- [Matmul OP on Multi_core (Basic)](./tech_reports/prog_examples/matmul_multi_core/matmul_multi_core.md)
- [Matmul Multi_core Reuse (Optimized)](./tech_reports/prog_examples/matmul_multi_core_optimized/data_reuse.md)
diff --git a/conftest.py b/conftest.py
index f1a753ca243..6e43d1a6499 100644
--- a/conftest.py
+++ b/conftest.py
@@ -257,7 +257,7 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic
mesh_shape=ttnn.MeshShape(2, 2),
dispatch_core_config=dispatch_core_config,
**device_params,
- offset=(0, 1),
+ offset=ttnn.MeshOffset(0, 1),
mesh_type=ttnn.MeshType.Ring,
)
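
The fixture change above replaces the bare `(0, 1)` tuple with a typed `ttnn.MeshOffset`. A minimal sketch of the updated call pattern, assuming a ttnn build exposing `open_mesh_device`/`close_mesh_device` with the keyword arguments used by the fixture:

```python
import ttnn

# Open a 2x2 ring mesh whose origin is shifted by one column, mirroring the
# pcie_mesh_device fixture. MeshOffset replaces the old (row, col) tuple.
mesh_device = ttnn.open_mesh_device(
    mesh_shape=ttnn.MeshShape(2, 2),
    offset=ttnn.MeshOffset(0, 1),  # row 0, column 1
    mesh_type=ttnn.MeshType.Ring,
)
try:
    ...  # run tests against mesh_device here
finally:
    ttnn.close_mesh_device(mesh_device)
```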
diff --git a/dependencies/CMakeLists.txt b/dependencies/CMakeLists.txt
index 7369f655c1c..e14310435a4 100644
--- a/dependencies/CMakeLists.txt
+++ b/dependencies/CMakeLists.txt
@@ -85,3 +85,9 @@ CPMAddPackage(NAME fmt GITHUB_REPOSITORY fmtlib/fmt GIT_TAG 11.0.1)
############################################################################################################################
CPMAddPackage(NAME range-v3 GITHUB_REPOSITORY ericniebler/range-v3 GIT_TAG 0.12.0)
+
+############################################################################################################################
+# pybind11 : https://github.com/pybind/pybind11
+############################################################################################################################
+
+CPMAddPackage(NAME pybind11 GITHUB_REPOSITORY pybind/pybind11 GIT_TAG b8f28551cc3a98ea9fbfc15c05b513c8f2d23e84)
diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py
index a38aa60a770..1d9ac8acba0 100644
--- a/models/demos/convnet_mnist/tt/convnet_mnist.py
+++ b/models/demos/convnet_mnist/tt/convnet_mnist.py
@@ -19,21 +19,23 @@ def convnet_mnist(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=False,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=True,
deallocate_activation=True,
reallocate_halo_output=True,
)
-
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=False,
+ packer_l1_acc=False,
+ )
x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
- [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+ x = ttnn.conv2d(
input_tensor=x,
weight_tensor=parameters.conv1.weight,
in_channels=1,
@@ -47,9 +49,12 @@ def convnet_mnist(
input_height=input_tensor.shape[1],
input_width=input_tensor.shape[2],
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache={},
debug=True,
groups=1,
+ return_output_dim=False,
+ return_weights_and_bias=False,
)
x = ttnn.relu(x)
@@ -76,7 +81,7 @@ def convnet_mnist(
dilation=[1, 1],
)
- [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+ x, [out_height, out_width] = ttnn.conv2d(
input_tensor=x,
weight_tensor=parameters.conv2.weight,
in_channels=32,
@@ -93,6 +98,8 @@ def convnet_mnist(
conv_op_cache={},
debug=False,
groups=1,
+ return_output_dim=True,
+ return_weights_and_bias=False,
)
x = ttnn.relu(x)
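
In the conv changes above, per-op compute settings move out of `Conv2dConfig` into a device compute-kernel config, and the extra `ttnn.conv2d` return values become opt-in flags. A minimal sketch of the two call shapes, with the geometry arguments elided (`...`) since they are unchanged from the demo:

```python
# Compute settings now live in a compute-kernel config, not Conv2dConfig.
compute_config = ttnn.init_device_compute_kernel_config(
    device.arch(),
    math_fidelity=ttnn.MathFidelity.LoFi,
    math_approx_mode=True,   # was Conv2dConfig.math_approx_mode_enabled
    fp32_dest_acc_en=False,  # was Conv2dConfig.fp32_dest_acc_enabled
    packer_l1_acc=False,     # was Conv2dConfig.packer_l1_accum_enabled
)

# With both flags False, conv2d returns only the output tensor...
x = ttnn.conv2d(..., compute_config=compute_config,
                return_output_dim=False, return_weights_and_bias=False)

# ...and with return_output_dim=True it also returns the spatial dims.
x, [out_height, out_width] = ttnn.conv2d(..., compute_config=compute_config,
                                         return_output_dim=True, return_weights_and_bias=False)
```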
diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md
index dd060a14c1c..f0dbf00ec4b 100644
--- a/models/demos/llama3/PERF.md
+++ b/models/demos/llama3/PERF.md
@@ -12,16 +12,16 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.
|-------|--------|-----------|-----------|---------------|
| 1b | N150 | 79 | 98 | 90.5 |
| 1b | N300 | 81 | 98 | 101.7 |
-| 1b | T3K | 81 | 98 | 97.5 |
+| 1b | T3K | 81 | 98 | 96.8 |
| 3b | N150 | 85 | 96 | 49.0 |
| 3b | N300 | 88 | 97 | 56.9 |
| 3b | T3K | 88 | 97 | 54.5 |
| 8b | N150 | 86 | 98 | 28.4 |
| 8b | N300 | 84 | 98 | 38.6 |
-| 8b | T3K | 84 | 98 | 52.6 |
+| 8b | T3K | 84 | 97 | 52.6 |
| 11b | N300 | 86 | 97 | 38.6 |
| 11b | T3K | 84 | 98 | 52.6 |
-| 70b | T3K | 95 | 100 | 14.3 |
+| 70b | T3K | 94 | 100 | 14.3 |
## LlamaOptimizations.accuracy
@@ -40,4 +40,4 @@ This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.
| 8b | T3K | 88 | 97 | 49.9 |
| 11b | N300 | 90 | 97 | 33.8 |
| 11b | T3K | 88 | 97 | 52.6 |
-| 70b | T3K | 95 | 100 | 14.5 |
+| 70b | T3K | 94 | 100 | 14.5 |
diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py
index a828830d330..f3b5b998fcb 100644
--- a/models/demos/llama3/demo/demo.py
+++ b/models/demos/llama3/demo/demo.py
@@ -26,7 +26,6 @@
)
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding
-from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
from models.demos.llama3.tt.model_config import TtModelArgs
@@ -227,6 +226,7 @@ def run_llama3_demo(
optimizations=optimizations,
max_seq_len=max_seq_len,
)
+
tokenizer = Tokenizer(model_args.tokenizer_path)
# Check max sequence length compatibility with model and architecture. Refer to README for more information
@@ -259,34 +259,13 @@ def run_llama3_demo(
), "T3K only supports a max context length of 128k tokens for Llama3.1-8B and Llama3.2-11B"
if llama_model_name == "3.1-70B":
assert tt_device_name in ["T3K", "TG"], "Llama3.1-70B is only supported on T3K or TG"
+ assert max_seq_len <= 64 * 1024, "T3K only supports a max context length of 64k tokens for Llama3.1-70B"
logger.info("Loading weights...")
profiler.start("weight_loading")
state_dict = model_args.load_state_dict()
profiler.end("weight_loading")
- # Setup RoPE transformation matrices
- rope_setup = TtLlamaRotarySetup(
- mesh_device,
- batch_size,
- model_args.head_dim,
- model_args.max_seq_len,
- model_args.rope_theta,
- model_args.use_scaled_rope,
- )
- transformation_mats_decode = rope_setup.get_trans_mats()
-
- transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim)
- transformation_mats_prefill = ttnn.from_torch(
- transformation_mats_prefill_torch,
- dtype=ttnn.bfloat16,
- layout=ttnn.TILE_LAYOUT,
- device=mesh_device,
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
- )
- transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill}
-
page_table_tt = None
if paged_attention:
@@ -314,7 +293,6 @@ def run_llama3_demo(
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype),
- transformation_mats=transformation_mats,
paged_attention_config=paged_attention_config,
)
tt_embd = TtLlamaEmbedding(
@@ -384,7 +362,7 @@ def run_llama3_demo(
:, decoding_pos[batch_id] :, :
] = 0 # Zero out the tokens after the prefill length
- prefill_input = model_args.prepare_inputs_ttnn_prefill(
+ prefill_input = model_args.prepare_residual_tensor_prefill(
pt_prefill_input[batch_id],
)
@@ -476,7 +454,7 @@ def run_llama3_demo(
)
# Get cos/sin matrices for the current position of each user
- rot_mats, rot_mat_idxs = rope_setup.get_rot_mats(current_pos, return_rot_idxs=True)
+ rot_mats, rot_mat_idxs = tt_model.rope_setup.get_rot_mats(current_pos, return_rot_idxs=True)
# Compile
logger.info(f"Compiling model trace...")
decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
@@ -519,7 +497,7 @@ def run_llama3_demo(
decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
- rot_mats = rope_setup.get_rot_mats(rot_mat_idxs)
+ rot_mats = tt_model.rope_setup.get_rot_mats(rot_mat_idxs)
tt_out = tt_model(
decode_input,
current_pos_tensor,
@@ -562,7 +540,7 @@ def run_llama3_demo(
# Reset the current position and output token tensors for the real decode run
ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor)
ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok)
- rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True)
+ rot_mat_idxs_reset = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True)
ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs)
profiler.end(f"capture_trace_{batch_idx}")
@@ -591,7 +569,7 @@ def run_llama3_demo(
# TODO This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it being uint32.
# If this tensor is int32, it won't be supported by ttnn.embedding
current_pos += 1
- rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True)
+ rot_mat_idxs_updated = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True)
ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs)
# Write to host
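
With the demo changes above, the rotary-embedding setup (decode tables and prefill transformation matrices) is owned by `TtTransformer` rather than built by the caller, so everything goes through `tt_model.rope_setup`. A minimal sketch of the decode-side pattern, assuming a constructed `tt_model` and a `current_pos` tensor as in the demo:

```python
# Rotary setup now lives on the model; no standalone TtLlamaRotarySetup.
rot_mats, rot_mat_idxs = tt_model.rope_setup.get_rot_mats(current_pos, return_rot_idxs=True)

# Each decode step: advance positions on host, refresh the on-device indices,
# then fetch cos/sin matrices for the new positions.
current_pos += 1
rot_mat_idxs_updated = tt_model.rope_setup.get_rot_idxs(current_pos, on_host=True)
ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs)
rot_mats = tt_model.rope_setup.get_rot_mats(rot_mat_idxs)
```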
diff --git a/models/demos/llama3/demo/multimodal_demo_chat.py b/models/demos/llama3/demo/multimodal_demo_chat.py
index ca3d5b498e3..ac7c5a60b2e 100644
--- a/models/demos/llama3/demo/multimodal_demo_chat.py
+++ b/models/demos/llama3/demo/multimodal_demo_chat.py
@@ -21,7 +21,7 @@
IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))
-from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
+from models.demos.llama3.tt.generator import LlamaGenerator
from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
@@ -67,7 +67,7 @@ def test_llama_multimodal_demo_chat(
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
- generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
+ generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
# image understanding
dialogs = []
diff --git a/models/demos/llama3/demo/multimodal_demo_text.py b/models/demos/llama3/demo/multimodal_demo_text.py
index 2029c43458b..4bea26c781b 100644
--- a/models/demos/llama3/demo/multimodal_demo_text.py
+++ b/models/demos/llama3/demo/multimodal_demo_text.py
@@ -23,7 +23,7 @@
IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))
from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
-from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
+from models.demos.llama3.tt.generator import LlamaGenerator
@pytest.mark.parametrize(
@@ -73,7 +73,7 @@ def test_llama_multimodal_demo_text(
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
- generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
+ generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
with open(IMG_PATH / "dog.jpg", "rb") as f:
img = PIL_Image.open(f).convert("RGB")
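
Both multimodal demos swap `LlamaVision` for the consolidated `LlamaGenerator`; per the hunks above, only the import path and class name change while the constructor arguments stay the same:

```python
# Old: from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
from models.demos.llama3.tt.generator import LlamaGenerator

generator = LlamaGenerator(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
```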
diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py
index b4946c3eecf..cda3c2ed957 100644
--- a/models/demos/llama3/demo/simple_vision_demo.py
+++ b/models/demos/llama3/demo/simple_vision_demo.py
@@ -23,10 +23,10 @@
import ttnn
import time
-from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
+from models.demos.llama3.tt.generator import LlamaGenerator
-def get_sampler(temperature, top_p, tokenizer):
+def get_batch_sampler(temperature, top_p, tokenizer):
def sample(logits):
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -34,15 +34,14 @@ def sample(logits):
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
- next_token = next_token.reshape(-1)
- token = next_token[0].item()
- text = tokenizer.decode(next_token.tolist())
- return token, text
+ next_tokens = next_token.reshape(-1)
+ texts = [tokenizer.decode([next_tokens[i].item()]) for i in range(len(next_tokens))]
+ return next_tokens, texts
return sample
-def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16):
+def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16, use_paged_kv_cache=False):
from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer
from models.demos.llama3.tt.model_config import TtModelArgs
@@ -56,6 +55,7 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn
weight_cache_path=tt_model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=tt_model_args,
+ use_paged_kv_cache=use_paged_kv_cache,
)
return tt_model_args, model
@@ -70,32 +70,30 @@ def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn
indirect=True,
)
@pytest.mark.parametrize(
- "warmup_iters",
- (0, 1),
- ids=["cold", "warm"],
+ "test_type,max_seq_len",
+ (("normal", 512),),
+ ids=["normal"],
)
@pytest.mark.parametrize(
- "test_case",
+ "warmup_iters, enable_trace, max_batch_size",
[
- "normal",
+ (0, False, 1), # batch1-notrace
+ (0, True, 1), # batch1-trace
+ (0, True, 32), # batch32-trace
],
-)
-@pytest.mark.parametrize(
- "enable_trace",
- (False, True),
- ids=["no_trace", "yes_trace"],
+ ids=["batch1-notrace", "batch1-trace", "batch32-trace"],
)
@pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True)
def test_llama_multimodal_demo_text(
mesh_device,
warmup_iters,
- test_case,
enable_trace,
+ max_batch_size,
+ test_type,
+ max_seq_len,
temperature: float = 0,
top_p: float = 0.9,
- max_seq_len: int = 512,
- max_batch_size: int = 1,
- max_gen_len: Optional[int] = 200,
+ max_gen_len: Optional[int] = 500,
model_parallel_size: Optional[int] = None,
):
"""
@@ -107,7 +105,7 @@ def test_llama_multimodal_demo_text(
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
- generator = LlamaVision(model, model_args, mesh_device)
+ generator = LlamaGenerator(model, model_args, mesh_device)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
@@ -132,96 +130,106 @@ def test_llama_multimodal_demo_text(
[UserMessage(content=[ImageMedia(image=ocr_image), "What is the full text of this image? Do OCR"])],
[UserMessage(content=[ImageMedia(image=clutter), "What objects are in this image?"])],
]
+ if len(dialogs) < max_batch_size:
+ dialogs *= max_batch_size // len(dialogs)
- sampler = get_sampler(temperature, top_p, tokenizer)
-
- for iter_num in range(warmup_iters + 1):
- for dialog in dialogs:
- for msg in dialog:
- print(f"{msg.role.capitalize()}: {msg.content}\n")
+ assert len(dialogs) % max_batch_size == 0
+ num_batches = len(dialogs) // max_batch_size
- if iter_num <= warmup_iters:
- logger.info(f"Warmup iteration {iter_num}")
+ sampler = get_batch_sampler(temperature, top_p, tokenizer)
- model_input = formatter.encode_dialog_prompt(dialog, tool_prompt_format=False)
+ for iter_num in range(warmup_iters + 1):
+ logger.info(f"Iteration {iter_num}")
+ for batch_idx in range(num_batches):
+ batch_dialogs = dialogs[batch_idx * max_batch_size : (batch_idx + 1) * max_batch_size]
+ for dialog in batch_dialogs:
+ for msg in dialog:
+ print(f"{msg.role.capitalize()}: {msg.content}\n")
+ batch_model_input = [
+ formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs
+ ]
# Do initial prefill
- vision_images = model_input.vision.images
- vision_mask = model_input.vision.mask
- prompt_tokens = model_input.tokens
- prefill_len = len(prompt_tokens)
- total_len = prefill_len + max_gen_len # Prepares mask for full length of output
- # Create tokens tensor
+ vision_images = [model_input.vision.images for model_input in batch_model_input]
+ vision_mask = [model_input.vision.mask for model_input in batch_model_input]
+ prompt_tokens = [model_input.tokens for model_input in batch_model_input]
+ # Get max length of prompts in batch
+ prefill_lens = torch.tensor([len(tokens) for tokens in prompt_tokens], dtype=torch.long)
+ total_lens = prefill_lens + max_gen_len
+
+ # Create padded tokens tensor for batch
pad_id = tokenizer.pad_id
- bsz = 1
- tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
- tokens[0, : len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long)
+ bsz = len(prompt_tokens)
+ tokens = torch.full((bsz, max(total_lens)), pad_id, dtype=torch.long)
+
+ # Fill in actual tokens for each sequence in batch
+ for i, seq in enumerate(prompt_tokens):
+ tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
+
prefill_start = time.perf_counter()
- prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long).reshape(1, -1) # B, S
- (
- xattn_caches,
- cross_attention_masks,
- full_text_row_masked_out_mask,
- logits,
- ) = generator.prefill_forward_single_user(
+ batch_logits, batch_xattn_masks, batch_text_masks = generator.prefill_forward(
vision_images,
vision_mask,
- prompt_tokens_tensor,
+ tokens,
xattn_caches,
- user_id=0,
- total_len=total_len,
- prefill_len=prefill_len,
+ total_lens,
+ prefill_lens,
)
- prefill_end = time.perf_counter()
-
- next_token, text = sampler(logits)
- tokens[0, prefill_len] = next_token
+ prefill_end = time.perf_counter()
+ next_tokens, next_texts = sampler(batch_logits)
+ for i, (next_token, next_text) in enumerate(zip(next_tokens, next_texts)):
+ tokens[i, prefill_lens[i]] = next_token
+ print(f"Next tokens: {next_tokens}")
+ print(f"Next texts: {next_texts}")
decode_times = []
for gen_idx in range(max_gen_len - 1):
decode_start = time.perf_counter()
- position_id = prefill_len + gen_idx
- next_token_tensor = torch.tensor([next_token], dtype=torch.long).reshape(1, 1) # B, S
+ position_id = prefill_lens + gen_idx
+ next_token_tensor = next_tokens.reshape(max_batch_size, 1)
if enable_trace:
logits = generator.easy_trace(
position_id,
next_token_tensor,
- cross_attention_masks,
- full_text_row_masked_out_mask,
+ batch_xattn_masks,
+ batch_text_masks,
xattn_caches,
)
else:
logits = generator.decode_forward(
position_id,
next_token_tensor,
- cross_attention_masks,
- full_text_row_masked_out_mask,
+ batch_xattn_masks,
+ batch_text_masks,
xattn_caches,
)
- next_token, text = sampler(logits)
+ next_tokens, next_texts = sampler(logits)
# Update next token
- tokens[0, position_id + 1] = next_token
+ tokens[torch.arange(max_batch_size), position_id + 1] = next_tokens
decode_end = time.perf_counter()
decode_times.append(decode_end - decode_start)
- if text in ["<|eot_id|>", "<|eom_id|>"]:
- break
-
- # Log full text output
+ # Disable checking for eot until I have more robust code for batch > 1
+ # if text in ["<|eot_id|>", "<|eom_id|>"]:
+ # break
+ # Log full text output for each user in batch
vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256]
- # Remove <|image|> tokens since they break the tokenizer
- tokens_out = [
- t if t not in vision_tokens else tokenizer.pad_id for t in tokens[0].tolist()[: position_id + 2]
- ]
- text = tokenizer.decode(tokens_out)
- logger.info(f"Full text: {text}")
+
+ for user_id in range(max_batch_size):
+ # Remove <|image|> tokens since they break the tokenizer
+ tokens_out = [
+ t if t not in vision_tokens else tokenizer.pad_id
+ for t in tokens[user_id].tolist()[: position_id[user_id] + 2]
+ ]
+ text = tokenizer.decode(tokens_out)
+ logger.info(f"User {user_id} full text: {text}")
prefill_time_ms = (prefill_end - prefill_start) * 1000
logger.info(f"Prefill time: {prefill_time_ms:.2f} ms")
decode_time_ms = sum(decode_times) / (gen_idx + 1) * 1000
- logger.info(f"Decode time: {decode_time_ms:.2f} ms")
+ logger.info(f"Average decode time per token: {decode_time_ms:.2f} ms")
# ttnn.release_trace(generator.mesh_device, trace_id)
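
The sampler rework above makes sampling batch-aware: `get_batch_sampler` returns a closure that yields one next token and one decoded string per user. A minimal usage sketch, assuming `logits` of shape `(batch, seq, vocab)` and the padded `tokens` buffer built in the demo:

```python
sampler = get_batch_sampler(temperature=0, top_p=0.9, tokenizer=tokenizer)

# The closure samples the last position for every user at once.
next_tokens, next_texts = sampler(logits)  # (batch,) tensor, list[str]
for i, tok in enumerate(next_tokens):
    tokens[i, prefill_lens[i]] = tok  # scatter each user's first generated token
```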
diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt
index 594568609ba..8f68983a2b6 100755
--- a/models/demos/llama3/lt
+++ b/models/demos/llama3/lt
@@ -733,7 +733,7 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
command_shortcuts = {
"demo": "pytest models/demos/llama3/demo/demo.py -k performance-batch-1",
"demo-32": "pytest models/demos/llama3/demo/demo.py -k performance-batch-32",
- "demo-long": "pytest models/demos/llama3/demo/demo.py -k long",
+ "demo-long": "pytest models/demos/llama3/demo/demo.py -k performance-long",
"attention": "pytest models/demos/llama3/tests/test_llama_attention.py",
"attention-prefill": "pytest models/demos/llama3/tests/test_llama_attention_prefill.py",
"mlp": "pytest models/demos/llama3/tests/test_llama_mlp.py",
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
index cc34f091e17..462a004b133 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
@@ -34,9 +34,10 @@
)
@pytest.mark.parametrize(
"batch",
- (1,),
+ (1, 2),
ids=[
"batch_1",
+ "batch_2",
],
)
def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset_seeds, ensure_gc):
@@ -46,6 +47,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
mesh_device.enable_async(True)
model_args = TtModelArgs(mesh_device)
+ model_args.max_seq_len = text_seq_len
state_dict = torch.load(model_args.consolidated_weights_path, map_location=torch.device("cpu"))
# Ref model needs partial state dict, but our models use full state dict keys as cached weight names
@@ -91,12 +93,15 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
"""
pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens)
pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0)
- pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache]
+ # slice out repeated KV heads
+ pt_xattn_cache_chunks = [
+ x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache
+ ]
# Preallocate K and V caches
tt_xattn_cache = [
ttnn.from_torch(
- torch.zeros(batch, n_heads, vision_seq_len, head_dim),
+ torch.zeros(batch, n_kv_heads, vision_seq_len, head_dim),
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
@@ -109,9 +114,10 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
"""
Test forward, prefill and decode!
"""
- for i in range(10):
- seq_len = text_seq_len if i == 0 else 1
+ n_iter = 10
+ for i in range(n_iter):
mode = "prefill" if i == 0 else "decode"
+ seq_len = text_seq_len if mode == "prefill" else 1
pt_x = (torch.rand(batch, seq_len, dim) * 2) - 1
tt_x = pt_x.clone()
@@ -150,18 +156,18 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
if mode == "prefill":
outputs = []
for b in range(batch):
- tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
+ tt_tensor_xattn_tokens = model_args.prepare_residual_tensor_prefill(
tt_xattn_tokens[b : b + 1],
force_replicated=True,
)
- tt_tensor_x = model_args.prepare_inputs_ttnn_prefill(
+ tt_tensor_x = model_args.prepare_residual_tensor_prefill(
tt_x[b : b + 1],
force_replicated=True,
)
tt_xattn_mask = ttnn.from_torch(
- xattn_mask_expand[b : b + 1],
+ xattn_mask[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -169,7 +175,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
tt_full_text_mask = ttnn.from_torch(
full_text_mask_expand[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -190,18 +196,17 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim)
else:
- tt_x = model_args.prepare_inputs_ttnn_decode(
+ tt_x = model_args.prepare_residual_tensor_decode(
tt_x,
- ttnn.DRAM_MEMORY_CONFIG,
+ model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"],
force_replicated=True,
)
- tt_x = ttnn.interleaved_to_sharded(tt_x, model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"])
xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous()
tt_xattn_mask = ttnn.from_torch(
xattn_mask_expand,
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -218,7 +223,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
tt_full_text_mask = ttnn.from_torch(
full_text_mask_expand,
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -239,7 +244,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
mode=mode,
)
- tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))
+ tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
tt_output_torch = tt_output_torch[:, :, :batch, :].reshape(batch, seq_len, dim)
passing, pcc_message = comp_pcc(pt_out, tt_output_torch, pcc_required)
@@ -251,12 +256,13 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
- n_heads,
+ n_kv_heads,
vision_seq_len,
head_dim,
)
for x in tt_xattn_cache
]
+
for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, pcc_required)
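
The cache-comparison fix above accounts for the reference model storing K/V with the query-head count (each of the `n_kv_heads` true heads repeated `n_heads // n_kv_heads` times), while the TT cache holds only `n_kv_heads`. A small self-contained illustration of the stride slice, with made-up shapes:

```python
import torch

batch, n_heads, n_kv_heads, seq_len, head_dim = 2, 32, 8, 16, 64

# Emulate the reference cache: each true KV head repeated 4x along dim 1.
kv = torch.randn(batch, n_kv_heads, seq_len, head_dim)
repeated = kv.repeat_interleave(n_heads // n_kv_heads, dim=1)

# Keep every (n_heads // n_kv_heads)-th head: one row per true KV head.
deduped = repeated[:, :: n_heads // n_kv_heads]
assert deduped.shape == (batch, n_kv_heads, seq_len, head_dim)
assert torch.equal(deduped, kv)
```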
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
index 7448601b8ce..1d9da2fbcca 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -75,6 +75,7 @@ def test_llama_cross_attention_transformer_text_inference(
dim = model_args.dim
head_dim = model_args.head_dim
n_heads = model_args.n_heads
+ n_kv_heads = model_args.n_kv_heads
reference_model = llama_reference_mod.CrossAttentionTransformerText(args=model_args)
reference_model.setup_cache(model_args.max_batch_size, torch.float32)
reference_model.load_state_dict(partial_state_dict)
@@ -107,15 +108,18 @@ def test_llama_cross_attention_transformer_text_inference(
# unstack k/v
pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks]
pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx]
- pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks]
+ pt_xattn_cache_chunks = [
+ x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache_chunks
+ ]
# Iterate over batch
# Preallocate K and V caches
tt_xattn_cache = tt_model.setup_cache(max_batch_size=batch)
# Test forward pass of the model
- n_iter = 10
+
prev_pos = 0
+ n_iter = 10
# tokens = torch.randint(100, 1000, (batch, text_seq_len+n_iter), dtype=torch.long)#, device="cuda"
tokens = torch.randint(0, model_args.vocab_size, (batch, text_seq_len + n_iter), dtype=torch.long)
for i in range(n_iter):
@@ -177,17 +181,17 @@ def test_llama_cross_attention_transformer_text_inference(
if mode == "prefill":
outputs = []
for b in range(batch):
- tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill(
+ tt_tensor_vision_tokens = model_args.prepare_residual_tensor_prefill(
tt_vision_tokens[b : b + 1],
force_replicated=True,
)
- tt_h = model_args.prepare_inputs_ttnn_prefill(
+ tt_h = model_args.prepare_residual_tensor_prefill(
h[b : b + 1],
)
tt_xattn_mask = ttnn.from_torch(
- xattn_mask_expand[b : b + 1],
+ xattn_mask[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -195,7 +199,7 @@ def test_llama_cross_attention_transformer_text_inference(
tt_full_text_mask_expand_1NSH = ttnn.from_torch(
full_text_mask_expand_1NSH[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -203,7 +207,7 @@ def test_llama_cross_attention_transformer_text_inference(
tt_full_text_mask_expand_11SD = ttnn.from_torch(
full_text_mask_expand_11SD[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
@@ -212,16 +216,6 @@ def test_llama_cross_attention_transformer_text_inference(
rot_mats = get_prefill_rot_mat(
model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len
)
- transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim)
- transformation_mats = ttnn.as_tensor(
- transformation_mat_torch,
- dtype=ttnn.bfloat16,
- layout=ttnn.TILE_LAYOUT,
- device=mesh_device,
- mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- )
-
tt_out = tt_model(
tt_h,
xattn_mask=tt_xattn_mask,
@@ -229,8 +223,7 @@ def test_llama_cross_attention_transformer_text_inference(
full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD,
xattn_caches=tt_xattn_cache,
current_pos=None,
- rot_mat=rot_mats,
- transformation_mats=transformation_mats,
+ rot_mats=rot_mats,
user_id=b,
mode=mode,
text_only_inference=TEXT_ONLY,
@@ -245,7 +238,7 @@ def test_llama_cross_attention_transformer_text_inference(
pcc_required = prefill_pcc_required
else:
- tt_h = model_args.prepare_inputs_ttnn_decode(
+ tt_h = model_args.prepare_residual_tensor_decode(
h,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
@@ -265,14 +258,14 @@ def test_llama_cross_attention_transformer_text_inference(
model_args.num_devices,
start_pos=cur_pos - 1,
)
-
- transformation_mats = None
+ tt_rope_id = tt_model.rope_setup.get_rot_idxs(position_ids)
+ rot_mats = tt_model.rope_setup.get_rot_mats(tt_rope_id)
xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous()
tt_xattn_mask = ttnn.from_torch(
xattn_mask_expand,
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -289,7 +282,7 @@ def test_llama_cross_attention_transformer_text_inference(
tt_full_text_mask_expand_1NSH = ttnn.from_torch(
full_text_mask_expand_1NSH,
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -311,8 +304,7 @@ def test_llama_cross_attention_transformer_text_inference(
full_text_row_masked_out_mask_11SD=None,
xattn_caches=tt_xattn_cache,
current_pos=tt_position_id,
- rot_mat=rot_mats,
- transformation_mats=transformation_mats,
+ rot_mats=rot_mats,
mode=mode,
text_only_inference=TEXT_ONLY,
)
@@ -332,7 +324,7 @@ def test_llama_cross_attention_transformer_text_inference(
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
- n_heads,
+ n_kv_heads,
vision_seq_len,
head_dim,
)
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
index 96637e5090c..3f6a9253e5d 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py
@@ -30,9 +30,10 @@
)
@pytest.mark.parametrize(
"batch",
- (1,),
+ (1, 2),
ids=[
"batch_1",
+ "batch_2",
],
)
def test_llama_cross_attention_transformer_block_inference(
@@ -57,6 +58,7 @@ def test_llama_cross_attention_transformer_block_inference(
dim = model_args.dim
head_dim = model_args.head_dim
n_heads = model_args.n_heads
+ n_kv_heads = model_args.n_kv_heads
reference_model = llama_reference_mod.CrossAttentionTransformerBlock(args=model_args, layer_id=0, no_ffn=False)
reference_model.load_state_dict(partial_state_dict)
@@ -83,12 +85,14 @@ def test_llama_cross_attention_transformer_block_inference(
"""
pt_xattn_cache = reference_model.compute_xattn_kv_cache(pt_xattn_tokens)
pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0)
- pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache]
+ pt_xattn_cache_chunks = [
+ x.view(batch, n_heads, vision_seq_len, head_dim)[:, :: n_heads // n_kv_heads] for x in pt_xattn_cache
+ ]
# Preallocate K and V caches
tt_xattn_cache = [
ttnn.from_torch(
- torch.zeros(batch, n_heads, vision_seq_len, head_dim),
+ torch.zeros(batch, n_kv_heads, vision_seq_len, head_dim),
device=mesh_device,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
@@ -145,17 +149,17 @@ def test_llama_cross_attention_transformer_block_inference(
if mode == "prefill":
outputs = []
for b in range(batch):
- tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
+ tt_tensor_xattn_tokens = model_args.prepare_residual_tensor_prefill(
tt_xattn_tokens[b : b + 1],
force_replicated=True,
)
- tt_tensor_x = model_args.prepare_inputs_ttnn_prefill(
+ tt_tensor_x = model_args.prepare_residual_tensor_prefill(
tt_x[b : b + 1],
)
tt_xattn_mask = ttnn.from_torch(
- xattn_mask_expand[b : b + 1],
+ xattn_mask[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -163,7 +167,7 @@ def test_llama_cross_attention_transformer_block_inference(
tt_full_text_mask_expand_1NSH = ttnn.from_torch(
full_text_mask_expand_1NSH[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -171,7 +175,7 @@ def test_llama_cross_attention_transformer_block_inference(
tt_full_text_mask_expand_11SD = ttnn.from_torch(
full_text_mask_expand_11SD[b : b + 1],
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
@@ -193,15 +197,15 @@ def test_llama_cross_attention_transformer_block_inference(
tt_output_torch = torch.cat(outputs, dim=0).view(batch, seq_len, dim)
else:
- tt_x = model_args.prepare_inputs_ttnn_decode(
+ tt_x = model_args.prepare_residual_tensor_decode(
tt_x,
- ttnn.DRAM_MEMORY_CONFIG,
+ model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
xattn_mask_expand = xattn_mask_expand.permute(2, 0, 1, 3).contiguous()
tt_xattn_mask = ttnn.from_torch(
xattn_mask_expand,
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -218,7 +222,7 @@ def test_llama_cross_attention_transformer_block_inference(
tt_full_text_mask_expand_1NSH = ttnn.from_torch(
full_text_mask_expand_1NSH,
device=mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
@@ -252,7 +256,7 @@ def test_llama_cross_attention_transformer_block_inference(
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
- n_heads,
+ n_kv_heads,
vision_seq_len,
head_dim,
)
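The `[:, :: n_heads // n_kv_heads]` slice above exists because the reference model repeats each KV head up to `n_heads`, while the TT cache now stores only the `n_kv_heads` originals; taking every `(n_heads // n_kv_heads)`-th head recovers one copy of each. A small torch check with made-up sizes:

```python
import torch

batch, n_heads, n_kv_heads, seq, head_dim = 1, 8, 2, 4, 16
kv = torch.randn(batch, n_kv_heads, seq, head_dim)

# Reference layout: each KV head repeated n_heads // n_kv_heads times.
repeated = kv.repeat_interleave(n_heads // n_kv_heads, dim=1)  # [1, 8, 4, 16]

# The strided slice undoes the repeat, matching the TT cache layout.
recovered = repeated[:, :: n_heads // n_kv_heads]              # [1, 2, 4, 16]
assert torch.equal(recovered, kv)
```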
diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py
index 844937a518b..3d9e6977145 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py
@@ -73,7 +73,7 @@ def test_llama_attention_inference(batch, num_chunks, mesh_device, use_program_c
mask = encoder_utils.build_encoder_attention_mask(pt_block_input, ar, ntok, num_chunks, 1)
pt_block_input = pt_block_input.reshape(batch, -1, dim)
- attention_input = model_args.prepare_inputs_ttnn_prefill(
+ attention_input = model_args.prepare_residual_tensor_prefill(
tt_attention_input.view(num_chunks, ntok, dim),
force_replicated=True,
)
diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py
index 001aa518828..23096202e29 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py
@@ -84,7 +84,7 @@ def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_progra
mask = encoder_utils.build_encoder_attention_mask(pt_block_input, ar, ntok, num_chunks, 1)
pt_block_input = pt_block_input.reshape(batch, -1, dim)
- attention_input = model_args.prepare_inputs_ttnn_prefill(
+ attention_input = model_args.prepare_residual_tensor_prefill(
tt_attention_input.view(num_chunks, ntok, dim),
force_replicated=True,
)
diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py
index 03f3310a0e3..502736ac790 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py
@@ -113,7 +113,7 @@ def test_llama_image_transformer_inference(
mask = encoder_utils.build_encoder_attention_mask(pt_block_input, ar, ntok, num_chunks, 1)
pt_block_input = pt_block_input.reshape(batch, -1, dim)
- attention_input = model_args.prepare_inputs_ttnn_prefill(
+ attention_input = model_args.prepare_residual_tensor_prefill(
tt_attention_input.view(num_chunks, ntok, dim),
force_replicated=True,
)
diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py b/models/demos/llama3/tests/multimodal/test_llama_vision_model.py
deleted file mode 100644
index f55a47891ac..00000000000
--- a/models/demos/llama3/tests/multimodal/test_llama_vision_model.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
-
-# SPDX-License-Identifier: Apache-2.0
-from pathlib import Path
-from typing import Optional
-from loguru import logger
-
-from PIL import Image as PIL_Image
-from termcolor import cprint
-
-import llama_models.llama3.reference_impl.generation as llama_reference_generation
-
-from llama_models.llama3.api.datatypes import ImageMedia
-
-from models.utility_functions import (
- comp_pcc,
- comp_allclose,
-)
-
-THIS_DIR = Path(__file__).parent.parent.parent.resolve() / "reference/llama_models/models/scripts/"
-
-import torch
-import pytest
-import os
-import ttnn
-
-
-def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16):
- from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer
- from models.demos.llama3.tt.model_config import TtModelArgs
-
- tt_model_args = TtModelArgs(mesh_device)
- checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True)
- model = CrossAttentionTransformer(
- model_args,
- mesh_device,
- checkpoint,
- weight_cache_path=tt_model_args.weight_cache_path(dtype),
- dtype=dtype,
- configuration=tt_model_args,
- )
- model.setup_cache(model_args.max_batch_size, torch.float32)
- return model
-
-
-@pytest.mark.parametrize(
- "mesh_device",
- [
- {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
- os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
- )
- ],
- indirect=True,
-)
-def test_llama_vision_model(
- mesh_device,
- temperature: float = 0,
- max_seq_len: int = 512,
- max_batch_size: int = 4,
- max_gen_len: Optional[int] = 50,
- model_parallel_size: Optional[int] = None,
-):
- """
- This test runs the Llama3.2 vision model on CPU and TT concurrently.
- It does not use teacher forcing and compares output logits at every token.
- """
- mesh_device.enable_program_cache()
- mesh_device.enable_async(True)
- ckpt_dir = os.environ["LLAMA_DIR"]
- tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")
-
- logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'")
- generator_pt = llama_reference_generation.Llama.build(
- ckpt_dir,
- tokenizer_path=tokenizer_path,
- max_seq_len=max_seq_len,
- max_batch_size=max_batch_size,
- model_parallel_size=model_parallel_size,
- )
-
- generator_tt = llama_reference_generation.Llama(generator_pt.model, generator_pt.tokenizer, generator_pt.args)
- logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
- model = create_multimodal_model(generator_tt.args, mesh_device)
- generator_tt.model = model
-
- # with open(THIS_DIR / "resources/dog.jpg", "rb") as f:
- # img = PIL_Image.open(f).convert("RGB")
-
- # with open(THIS_DIR / "resources/pasta.jpeg", "rb") as f:
- # img2 = PIL_Image.open(f).convert("RGB")
-
- with open(THIS_DIR / "resources/ocr_image.jpeg", "rb") as f:
- ocr_image = PIL_Image.open(f).convert("RGB")
-
- # with open(THIS_DIR / "resources/clutter.jpeg", "rb") as f:
- # clutter = PIL_Image.open(f).convert("RGB")
-
- interleaved_contents = [
- # text only
- # "The color of the sky is blue but sometimes it can also be",
- # image understanding
- # [ImageMedia(image=img), "If I had to write a haiku for this one"],
- # [ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"],
- [ImageMedia(image=ocr_image), "The full text in this image is as follows"],
- # [ImageMedia(image=clutter), "The count of vases, books, and miscellaneous items in this image is"],
- ]
-
- for content in interleaved_contents:
- logger.info(f"Generating text for content: {content}")
- model_input = generator_pt.formatter.encode_content(content)
- gen_pt = generator_pt.generate(
- model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True
- )
- gen_tt = generator_tt.generate(
- model_input, max_gen_len=max_gen_len, temperature=temperature, return_logits=True
- )
-
- for out_idx, (token_pt, token_tt) in enumerate(zip(gen_pt, gen_tt)):
- logger.info(f"Comparing output token {out_idx}")
- out_pt, out_tt = token_pt[1], token_tt[1]
- out_pt = out_pt[0, -1]
- out_tt = out_tt[0, -1]
- passing, pcc_message = comp_pcc(out_pt, out_tt, 0.90)
- print(f"PCC: {pcc_message}")
- # Check shapes of logprobs
-
- ref_argmax = torch.argmax(out_pt).item()
- ref_logprob = out_pt[ref_argmax].item()
- ref_token = generator_pt.tokenizer.decode([ref_argmax])
-
- # Reference model: top-5 tokens
- ref_top5_vals, ref_top5_idxs = torch.topk(out_pt, 5)
- ref_top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in ref_top5_idxs]
- ref_top5_logprobs = ref_top5_vals.tolist()
-
- # Test model: top-5 tokens
- top5_vals, top5_idxs = torch.topk(out_tt, 5)
- top5_tokens = [generator_pt.tokenizer.decode([idx.item()]) for idx in top5_idxs]
- top5_logprobs = top5_vals.tolist()
-
- def entropy(logits):
- probs = torch.softmax(logits, dim=-1)
- return -(probs * torch.log(probs)).sum().item()
-
- # Print the information
- print(f"Token Position {out_idx}:")
- print(f" Reference | Test")
- print(f" Entropy: {entropy(out_pt):.4f} | {entropy(out_tt):.4f}")
- print(f" Top-5 Tokens:")
- for rank in range(5):
- print(
- f" {rank+1}. Token='{ref_top5_tokens[rank]}' @ {ref_top5_logprobs[rank]:.2f} | '{top5_tokens[rank]}' @ {top5_logprobs[rank]:.2f}"
- )
- print()
diff --git a/models/demos/llama3/tests/test_interleaved_to_sharded.py b/models/demos/llama3/tests/test_interleaved_to_sharded.py
index b69d7d2459b..9edc9a89dd0 100644
--- a/models/demos/llama3/tests/test_interleaved_to_sharded.py
+++ b/models/demos/llama3/tests/test_interleaved_to_sharded.py
@@ -80,7 +80,7 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds):
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
)
- decode_input = model_args.prepare_inputs_ttnn_decode(
+ decode_input = model_args.prepare_residual_tensor_decode(
tt_decode_input,
ttnn.L1_MEMORY_CONFIG,
)
diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py
index 2ae973a907d..b19cb086066 100644
--- a/models/demos/llama3/tests/test_llama_accuracy.py
+++ b/models/demos/llama3/tests/test_llama_accuracy.py
@@ -9,13 +9,11 @@
import ttnn
from models.demos.llama3.tt.llama_common import (
get_prefill_rot_mat,
- get_rot_transformation_mat,
HostEmbedding,
PagedAttentionConfig,
)
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
-from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
from models.demos.llama3.demo.demo import preprocess_inputs_prefill
from pathlib import Path
@@ -141,28 +139,6 @@ def test_tt_model_accuracy(
N = prefill_len + decode_len
input_ids = reference_tokens[:, : N + 1] # Shape [1, N+1]
- # Setup RoPE transformation matrices
- rope_setup = TtLlamaRotarySetup(
- mesh_device,
- model_args.max_batch_size,
- model_args.head_dim,
- model_args.max_seq_len,
- model_args.rope_theta,
- model_args.use_scaled_rope,
- )
- transformation_mats_decode = rope_setup.get_trans_mats()
-
- transformation_mats_prefill_torch = get_rot_transformation_mat(model_args.head_dim)
- transformation_mats_prefill = ttnn.from_torch(
- transformation_mats_prefill_torch,
- dtype=ttnn.bfloat16,
- layout=ttnn.TILE_LAYOUT,
- device=mesh_device,
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
- )
- transformation_mats = {"decode": transformation_mats_decode, "prefill": transformation_mats_prefill}
-
page_table_tt = None
paged_attention_config = None
@@ -193,7 +169,6 @@ def test_tt_model_accuracy(
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype),
- transformation_mats=transformation_mats,
paged_attention_config=paged_attention_config,
)
# Initialize embedding
@@ -226,7 +201,7 @@ def test_tt_model_accuracy(
model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_lens[0]
)
- prefill_input = model_args.prepare_inputs_ttnn_prefill(
+ prefill_input = model_args.prepare_residual_tensor_prefill(
pt_prefill_input[batch_id],
)
@@ -256,7 +231,7 @@ def test_tt_model_accuracy(
)
# Get cos/sin matrices for the current position of each user
- rot_mats = rope_setup.get_rot_mats(current_pos)
+ rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)
# Print table header
logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}")
@@ -276,7 +251,7 @@ def test_tt_model_accuracy(
# Get embedding
pt_decode_input = embd(ref_token).view(1, 1, -1)
# Prepare input for TT model
- decode_input = model_args.prepare_inputs_ttnn_decode(
+ decode_input = model_args.prepare_residual_tensor_decode(
pt_decode_input,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
@@ -309,7 +284,7 @@ def test_tt_model_accuracy(
# Update rot_mats for next iteration
current_pos += 1
- rot_mats = rope_setup.get_rot_mats(current_pos)
+ rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)
# Get reference top5 tokens and probabilities for this position
ref_top5_tokens = top5_tokens[prefill_len + i]
diff --git a/models/demos/llama3/tests/test_llama_attention.py b/models/demos/llama3/tests/test_llama_attention.py
index 8690b91d3b9..edb9ac99a43 100644
--- a/models/demos/llama3/tests/test_llama_attention.py
+++ b/models/demos/llama3/tests/test_llama_attention.py
@@ -100,8 +100,7 @@ def test_llama_attention_inference(
model_args.use_scaled_rope,
)
- transformation_mats = rope_setup.get_trans_mats()
- transformation_mats = {"decode": transformation_mats}
+ transformation_mats = rope_setup.get_both_trans_mats()
page_table_tt = None
paged_attention_config = None
@@ -158,7 +157,7 @@ def test_llama_attention_inference(
tt_attention_input = pt_attention_input.clone()
- attention_input = model_args.prepare_inputs_ttnn_decode(
+ attention_input = model_args.prepare_residual_tensor_decode(
tt_attention_input,
model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"],
force_replicated=True,
diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py
index ef33adc4481..4335bdb4ee1 100644
--- a/models/demos/llama3/tests/test_llama_attention_prefill.py
+++ b/models/demos/llama3/tests/test_llama_attention_prefill.py
@@ -141,7 +141,7 @@ def test_llama_attention_inference(
pt_attention_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1
tt_attention_input = pt_attention_input.clone()
- attention_input = model_args.prepare_inputs_ttnn_prefill(
+ attention_input = model_args.prepare_residual_tensor_prefill(
tt_attention_input,
force_replicated=True,
)
diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py
index 5d24d3b4298..316c811aaf3 100644
--- a/models/demos/llama3/tests/test_llama_decoder.py
+++ b/models/demos/llama3/tests/test_llama_decoder.py
@@ -94,8 +94,7 @@ def test_llama_decoder_inference(
model_args.rope_theta,
model_args.use_scaled_rope,
)
- transformation_mats = rope_setup.get_trans_mats()
- transformation_mats = {"decode": transformation_mats}
+ transformation_mats = rope_setup.get_both_trans_mats()
# Prepare page table for paged attention
page_table_tt = None
@@ -155,7 +154,7 @@ def test_llama_decoder_inference(
pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1
tt_decode_input = pt_decode_input.clone()
- decode_input = model_args.prepare_inputs_ttnn_decode(
+ decode_input = model_args.prepare_residual_tensor_decode(
tt_decode_input,
# ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py
index 0c40e21b773..622e67f91b4 100644
--- a/models/demos/llama3/tests/test_llama_decoder_prefill.py
+++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py
@@ -142,7 +142,7 @@ def test_llama_decoder_inference(
logger.info(f"[Decoder] Generating token {i}")
pt_decode_input = (torch.rand(batch_size, max_seq_len, model_args.dim) * 2) - 1
tt_decode_input = pt_decode_input.clone()
- decode_input = model_args.prepare_inputs_ttnn_prefill(
+ decode_input = model_args.prepare_residual_tensor_prefill(
tt_decode_input,
)
positions = torch.LongTensor(range(max_seq_len))
diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py
index cd425579a23..37e0e438419 100644
--- a/models/demos/llama3/tests/test_llama_model.py
+++ b/models/demos/llama3/tests/test_llama_model.py
@@ -15,7 +15,6 @@
)
from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
from models.demos.llama3.tt.llama_model import TtTransformer
-from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
from models.utility_functions import (
@@ -191,18 +190,6 @@ def test_llama_model_inference(
generation_start_pos = 0
generation_length = iterations
- # Setup RoPE transformation matrices
- rope_setup = TtLlamaRotarySetup(
- mesh_device,
- model_args.max_batch_size,
- model_args.head_dim,
- model_args.max_seq_len,
- model_args.rope_theta,
- model_args.use_scaled_rope,
- )
- transformation_mats = rope_setup.get_trans_mats()
- transformation_mats = {"decode": transformation_mats}
-
page_table_tt = None
paged_attention_config = None
@@ -234,7 +221,6 @@ def test_llama_model_inference(
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype),
- transformation_mats=transformation_mats,
paged_attention_config=paged_attention_config,
)
logger.info("Model and caches loaded.")
@@ -269,13 +255,13 @@ def test_llama_model_inference(
for i in range(generation_length):
logger.info(f"[Llama3 Model] Generating token {i}")
- decode_input = model_args.prepare_inputs_ttnn_decode(
+ decode_input = model_args.prepare_residual_tensor_decode(
tt_decode_input,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
# Get cos/sin matrices for the current position of each user
- rot_mats = rope_setup.get_rot_mats(current_pos)
+ rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)
# Run TT model
tt_out = tt_model(
diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py
index 934c91d5746..e30c25cc8f4 100644
--- a/models/demos/llama3/tests/test_llama_model_prefill.py
+++ b/models/demos/llama3/tests/test_llama_model_prefill.py
@@ -93,7 +93,7 @@ def test_llama_model_inference(
pcc = 0.91 # TODO Look on improving PCC
else: # performance mode
assert optimizations == LlamaOptimizations.performance
- pcc = 0.87 # TODO Look on improving PCC
+ pcc = 0.869 # TODO Look on improving PCC
mesh_device.enable_async(True)
@@ -143,17 +143,6 @@ def test_llama_model_inference(
# pre-compute the rotational embedding matrix and send to device
rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len)
- transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim)
- transformation_mats_prefill = ttnn.as_tensor(
- transformation_mat_torch,
- dtype=ttnn.bfloat16,
- layout=ttnn.TILE_LAYOUT,
- device=mesh_device,
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
- )
- transformation_mats = {"prefill": transformation_mats_prefill}
-
# Setup page table
page_table_tt = None
paged_attention_config = None
@@ -185,7 +174,6 @@ def test_llama_model_inference(
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args.weight_cache_path(dtype),
- transformation_mats=transformation_mats,
paged_attention_config=paged_attention_config,
)
@@ -200,7 +188,7 @@ def test_llama_model_inference(
tt_prefill_input = pt_prefill_input
- tt_prefill_input = model_args.prepare_inputs_ttnn_prefill(
+ tt_prefill_input = model_args.prepare_residual_tensor_prefill(
pt_prefill_input,
)
for i in range(1):
diff --git a/models/demos/llama3/tt/multimodal/vision_generator.py b/models/demos/llama3/tt/generator.py
similarity index 59%
rename from models/demos/llama3/tt/multimodal/vision_generator.py
rename to models/demos/llama3/tt/generator.py
index b00fbf3ff73..c42450e48d3 100644
--- a/models/demos/llama3/tt/multimodal/vision_generator.py
+++ b/models/demos/llama3/tt/generator.py
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
+
import ttnn
import torch
+from loguru import logger
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
@@ -15,10 +17,16 @@
TokenResult,
sample_top_p,
)
+from models.demos.llama3.tt.llama_common import (
+ copy_host_to_device,
+ get_padded_prefill_len,
+ num_blocks_in_seq,
+ get_block_size,
+)
-class LlamaVision:
- def __init__(self, model, model_args, mesh_device, vllm=False, tokenizer=None, formatter=None):
+class LlamaGenerator:
+ def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=None):
"""
Creating a LlamaVision wrapper requires only a mesh_device and model_args.
With model_args you have the checkpoint location, can specify max batch size
@@ -32,10 +40,133 @@ def __init__(self, model, model_args, mesh_device, vllm=False, tokenizer=None, f
self.model = model
self.model_args = model_args
self.mesh_device = mesh_device
- self.vllm = vllm
self.tokenizer = tokenizer
self.formatter = formatter
+ def prefill_forward_text(self, tokens: torch.Tensor, page_table=None, kv_cache=None, prompt_lens=None):
+ batch, batch_seq_len = tokens.shape
+ output_logits = torch.zeros(batch, 1, self.model_args.vocab_size)
+ prompt_lens = prompt_lens if prompt_lens is not None else torch.tensor([batch_seq_len] * batch)
+
+ if page_table is not None:
+ assert isinstance(
+ page_table, torch.Tensor
+ ), "page_table must be a torch.Tensor when passing into prefill_forward"
+
+ for user_id in range(batch):
+ seq_len = prompt_lens[user_id]
+ last_token_idx = seq_len - 1
+
+ prefill_seq_len = get_padded_prefill_len(seq_len)
+ prefill_ids = torch.cat(
+ [tokens[user_id : user_id + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1
+ )
+ if page_table is not None:
+ page_table_user = self._get_prefill_user_page_table(page_table, kv_cache, seq_len)
+
+ logits = self.prefill_forward_single_user_text(
+ prefill_ids,
+ page_table=page_table_user if page_table is not None else None,
+ user_id=user_id,
+ last_token_idx=last_token_idx,
+ kv_cache=kv_cache,
+ )
+
+ # Since we give unpadded_seq_len, only the tile containing the last token is returned
+ output_logits[user_id] = logits
+
+ return output_logits
+
+ def prefill_forward_single_user_text(self, tokens, page_table, user_id, last_token_idx, kv_cache=None):
+ prefill_input, rot_mats_prefill, page_table_tt = self.model.prepare_inputs_prefill(
+ tokens,
+ page_table=page_table,
+ )
+
+ tt_logits = self.model.ttnn_prefill_forward(
+ prefill_input,
+ rot_mats=rot_mats_prefill,
+ user_id=user_id,
+ page_table=page_table_tt,
+ get_last_token=(last_token_idx // 32) * 32,
+ )
+
+ logits = self.model.process_output_prefill(tt_logits, last_token_idx=(last_token_idx % 32))
+
+ return logits
+
+ def decode_forward_text(
+ self,
+ tokens,
+ current_pos,
+ page_table=None,
+ ):
+ """
+ Performs text decode step.
+ Returns logits
+ """
+ tt_tokens, tt_current_pos, tt_rot_mats, tt_page_table = self.model.prepare_inputs_decode(
+ tokens, current_pos, page_table
+ )
+
+ tt_logits = self.model.ttnn_decode_forward(
+ tt_tokens,
+ tt_current_pos,
+ rot_mats=tt_rot_mats,
+ page_table=tt_page_table,
+ )
+
+ logits = self.model.process_output_decode(tt_logits)
+ return logits
+
+ def capture_trace_text(
+ self,
+ tokens,
+ current_pos,
+ page_table=None,
+ ):
+ """
+ Captures a trace for the decode_forward method.
+ """
+
+ # Compile run
+ self.decode_forward_text(tokens, current_pos, page_table)
+
+ # Get inputs ready for trace run
+ host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table)
+
+ device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device)
+
+ trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
+ transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs)
+ tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs)
+
+ ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
+
+ return trace_id, tt_out_trace, *device_inputs
+
+ def decode_forward_trace_text(
+ self,
+ trace_id,
+ device_inputs,
+ tt_out_trace,
+ tokens,
+ current_pos,
+ page_table=None,
+ ):
+ host_inputs = self.model.prepare_decode_inputs_host(tokens, current_pos, page_table)
+
+ device_inputs = copy_host_to_device(
+ host_tensors=host_inputs,
+ device_tensors=device_inputs,
+ )
+
+ ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)
+
+ logits = self.model.process_output_decode(tt_out_trace)
+
+ return logits
+
def prefill_forward_single_user(
self,
vision_images,
@@ -45,28 +176,37 @@ def prefill_forward_single_user(
user_id,
total_len,
prefill_len,
+ page_table=None,
+ kv_cache=None,
):
"""
Performs vision encode step then text prefill.
Returns (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits)
"""
B = tokens.shape[0]
+ last_token_idx = prefill_len - 1
vision_tokens, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[vision_images],
batch_masks=[vision_mask],
total_len=total_len,
)
+ if page_table is not None:
+ page_table = self._get_prefill_user_page_table(page_table, kv_cache, prefill_len)
+
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_full_text_mask_expand_11SD,
- tt_position_id,
rot_mats,
- transformation_mats,
+ tt_page_table,
) = self.model.prepare_inputs_prefill(
- tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len=prefill_len
+ tokens,
+ cross_attention_masks,
+ full_text_row_masked_out_mask,
+ prefill_len=prefill_len,
+ page_table=page_table,
)
tt_logits = self.model.ttnn_prefill_forward(
@@ -75,24 +215,76 @@ def prefill_forward_single_user(
tt_full_text_mask_expand_1NSH,
tt_full_text_mask_expand_11SD,
xattn_caches,
- tt_position_id,
rot_mats,
- transformation_mats,
user_id,
vision_tokens,
+ page_table=tt_page_table,
+ kv_cache=kv_cache,
+ get_last_token=(last_token_idx // 32) * 32,
)
- logits = self.model.process_output_prefill(tt_logits, B, prefill_len)
+ del tt_page_table
+
+ logits = self.model.process_output_prefill(tt_logits, B, last_token_idx=(last_token_idx % 32))
return xattn_caches, cross_attention_masks, full_text_row_masked_out_mask, logits
+ def prefill_forward(
+ self,
+ vision_images,
+ vision_masks,
+ tokens: torch.Tensor,
+ xattn_caches,
+ total_lens,
+ prompt_lens,
+ page_table=None,
+ kv_cache=None,
+ ):
+ """
+ Batched version of prefill_forward_single_user for the vision model.
+ """
+ batch, batch_seq_len = tokens.shape
+ output_logits = torch.zeros(batch, 1, self.model_args.vocab_size)
+ output_xattn_masks = []
+ output_full_text_row_masked_out_masks = []
+
+ for user_id in range(batch):
+ print(f"Prefilling User {user_id}")
+ seq_len = prompt_lens[user_id]
+ (
+ xattn_caches,
+ cross_attention_masks,
+ full_text_row_masked_out_mask,
+ logits,
+ ) = self.prefill_forward_single_user(
+ vision_images=vision_images[user_id],
+ vision_mask=vision_masks[user_id],
+ tokens=tokens[user_id : user_id + 1, :seq_len], # Keep batch dimension
+ xattn_caches=xattn_caches,
+ user_id=user_id,
+ total_len=total_lens[user_id],
+ prefill_len=seq_len,
+ page_table=page_table,
+ kv_cache=kv_cache,
+ )
+ output_logits[user_id] = logits
+ output_xattn_masks.append(cross_attention_masks)
+ output_full_text_row_masked_out_masks.append(full_text_row_masked_out_mask)
+
+ logger.info(f"Finished prefill for all users up to {batch_seq_len} tokens, Starting decode...")
+
+ return output_logits, output_xattn_masks, output_full_text_row_masked_out_masks
+
def decode_forward(
self,
- position_id,
+ start_pos,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
+ page_table=None,
+ kv_cache=None,
+ prompt_lens=None,
):
"""
Performs text decode step.
@@ -101,19 +293,18 @@ def decode_forward(
# forward_decode should be traced callable
# decorator does compilation, capture, execute
- # B = 1 # TODO: Only supports batch=1 right now! Might make tokens input a tensor.
B, S = tokens.shape
+ assert S == 1
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- _,
tt_position_id,
- rot_mats,
- _,
+ tt_rot_mats,
+ tt_page_table,
) = self.model.prepare_inputs_decode(
- tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
+ tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=start_pos, page_table=page_table
)
tt_logits = self.model.ttnn_decode_forward(
@@ -122,7 +313,9 @@ def decode_forward(
tt_full_text_mask_expand_1NSH,
xattn_caches,
tt_position_id,
- rot_mats,
+ tt_rot_mats,
+ page_table=tt_page_table,
+ kv_cache=kv_cache,
)
logits = self.model.process_output_decode(tt_logits, B, S)
@@ -143,10 +336,9 @@ def capture_trace(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- _,
tt_position_id,
- rot_mats,
- _,
+ tt_rot_mats,
+ tt_page_table,
) = self.model.prepare_inputs_decode(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
)
@@ -158,7 +350,7 @@ def capture_trace(
tt_full_text_mask_expand_1NSH,
xattn_caches,
tt_position_id,
- rot_mats,
+ tt_rot_mats,
)
# Get inputs ready for trace run
@@ -166,9 +358,8 @@ def capture_trace(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- _,
tt_position_id,
- rot_mats,
+ tt_rope_id,
_,
) = self.model.prepare_decode_inputs_host(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id
@@ -179,9 +370,10 @@ def capture_trace(
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_position_id,
- rot_mats,
- ) = self.model.copy_host_to_device(
- (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats)
+ tt_rope_id,
+ ) = copy_host_to_device(
+ (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id),
+ mesh_device=self.mesh_device,
)
trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
@@ -189,36 +381,30 @@ def capture_trace(
B = tokens.shape[0]
# Do on-device transformations of inputs before forward
(
- tt_h,
+ tt_h_transform,
+ tt_rot_mats,
tt_xattn_mask_transform,
tt_full_text_mask_expand_1NSH_transform,
) = self.model.transform_decode_inputs_device(
tt_h,
+ tt_rope_id,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
B=B,
)
tt_logits_rm = self.model.ttnn_decode_forward(
- tt_h,
+ tt_h_transform,
tt_xattn_mask_transform,
tt_full_text_mask_expand_1NSH_transform,
xattn_caches,
tt_position_id,
- rot_mats,
+ tt_rot_mats,
)
ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
- return (
- trace_id,
- tt_logits_rm,
- tt_h_trace_input,
- tt_xattn_mask,
- tt_full_text_mask_expand_1NSH,
- tt_position_id,
- rot_mats,
- )
+ return trace_id, tt_logits_rm, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id
def decode_forward_trace(
self,
@@ -233,28 +419,27 @@ def decode_forward_trace(
trace_xattn_mask,
trace_full_text_mask_expand_1NSH,
trace_position_id,
- trace_rot_mats,
+ trace_rope_id,
):
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- _,
tt_position_id,
- rot_mats,
+ tt_rope_id,
_,
) = self.model.prepare_decode_inputs_host(
tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
)
- self.model.copy_host_to_device(
- host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats),
+ copy_host_to_device(
+ host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id),
device_tensors=(
trace_h,
trace_xattn_mask,
trace_full_text_mask_expand_1NSH,
trace_position_id,
- trace_rot_mats,
+ trace_rope_id,
),
)
@@ -284,7 +469,7 @@ def easy_trace(
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_position_id,
- rot_mats,
+ tt_rope_id,
) = self.capture_trace(
position_id,
tokens,
@@ -298,7 +483,7 @@ def easy_trace(
"tt_xattn_mask": tt_xattn_mask,
"tt_full_text_mask_expand_1NSH": tt_full_text_mask_expand_1NSH,
"tt_position_id": tt_position_id,
- "rot_mats": rot_mats,
+ "tt_rope_id": tt_rope_id,
}
self.trace_outputs = {
"tt_logits_rm": tt_logits_rm,
@@ -316,7 +501,7 @@ def easy_trace(
self.trace_inputs["tt_xattn_mask"],
self.trace_inputs["tt_full_text_mask_expand_1NSH"],
self.trace_inputs["tt_position_id"],
- self.trace_inputs["rot_mats"],
+ self.trace_inputs["tt_rope_id"],
)
def generate(
@@ -351,6 +536,8 @@ def generate(
prefill_len=prefill_len,
)
+ logits = logits.view(1, 1, self.model_args.max_vocab_size)
+
def sample(logits):
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
@@ -368,14 +555,14 @@ def sample(logits):
)
for gen_idx in range(max_gen_len - 1):
- position_id = prefill_len + gen_idx
+ position_id = torch.tensor([prefill_len + gen_idx])
next_token_tensor = next_token.reshape(1, 1) # B, S
logits = self.decode_forward(
position_id,
next_token_tensor,
- cross_attention_masks,
- full_text_row_masked_out_mask,
+ [cross_attention_masks],
+ [full_text_row_masked_out_mask],
xattn_caches,
)
@@ -442,3 +629,9 @@ def text_completion(
generation = self.tokenizer.decode(tokens)
return CompletionPrediction(generation=generation)
+
+ def _get_prefill_user_page_table(self, page_table, kv_cache, prefill_len):
+ # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly
+ block_size = get_block_size(kv_cache)
+ num_blocks = num_blocks_in_seq(prefill_len, block_size)
+ return page_table[:, :num_blocks]
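The generator's prefill methods rely on TILE_LAYOUT tensors being 32-token granular: the device is asked for the tile that contains the last real token (`get_last_token=(last_token_idx // 32) * 32`) and the host then indexes within that tile (`last_token_idx % 32`). A quick numeric sanity check of that split:

```python
# Pure-python check of the tile/offset decomposition used above.
for last_token_idx in (0, 31, 32, 100, 511):
    tile_start = (last_token_idx // 32) * 32  # device returns from this tile
    offset = last_token_idx % 32              # host indexes into the tile
    assert tile_start + offset == last_token_idx
    print(f"idx={last_token_idx:4d} -> tile_start={tile_start:4d}, offset={offset:2d}")
```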
diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py
new file mode 100644
index 00000000000..f962b2801b1
--- /dev/null
+++ b/models/demos/llama3/tt/generator_vllm.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Union
+import torch
+import PIL
+from llama_models.llama3.api.chat_format import create_vision_mask
+
+from models.demos.llama3.tt.generator import LlamaGenerator
+from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
+
+from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext
+
+
+def input_processor_for_mllama(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
+ """
+ Based on vllm.model_executor.models.mllama.py::input_processor_for_mllama().
+ Note that vLLM's input_processor_for_mllama performs additional processing to handle chunking which we do not yet support.
+ """
+
+ # Move encoder_prompt to prompt. If the user does not explicitly provide separate
+ # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt.
+ # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt
+ # must contain the full text prompt.
+ if inputs.get("prompt") is None:
+ inputs["prompt"] = inputs["encoder_prompt"]
+ inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
+
+ return inputs
+
+
+@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
+class TtMllamaForConditionalGeneration(LlamaGenerator):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.MLLAMA_IMAGE_TOKEN_ID = 128256
+ self.max_gen_len = self.model_args.max_seq_len - 1 # TODO: double check what this should be
+
+ @classmethod
+ def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size):
+ max_seq_len = 512 # TODO: Increase to 131072 once it's verified to work
+ model_args, model = create_multimodal_model(mesh_device, max_batch_size, max_seq_len, use_paged_kv_cache=True)
+ return cls(model, model_args, mesh_device)
+
+ @property
+ def cache_path(self):
+ return self.model_args.model_cache_path
+
+ def prefill_forward(
+ self,
+ tokens: torch.Tensor,
+ images: List[PIL.Image.Image],
+ xattn_caches,
+ start_pos,
+ page_table: torch.Tensor = None,
+ kv_cache=None,
+ prompt_lens=None,
+ ):
+ """
+ Replaces prefill_forward from LlamaGenerator with a version that supports mask creation.
+ """
+ batch = tokens.shape[0]
+
+ vision_images = []
+ vision_masks = []
+ total_lens = []
+ for user_id in range(batch):
+ vision_images.append([images[user_id]])
+ prompt_tokens = [int(tokens[user_id, i]) for i in range(prompt_lens[user_id])]
+ vision_masks.append(create_vision_mask(prompt_tokens, self.MLLAMA_IMAGE_TOKEN_ID))
+ total_lens.append(prompt_lens[user_id] + self.max_gen_len)
+
+ return super().prefill_forward(
+ vision_images, vision_masks, tokens, xattn_caches, total_lens, prompt_lens, page_table, kv_cache
+ )
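`input_processor_for_mllama` only mirrors the encoder prompt into the decoder prompt when the latter is missing, so the block manager sizes the KV block table from the full text. A minimal sketch with a plain dict standing in for vLLM's input struct (field names are assumptions mirroring the code above):

```python
inputs = {
    "encoder_prompt": "<|image|> Describe this picture",
    "encoder_prompt_token_ids": [128256, 75885, 420, 6945],
    "prompt": None,
    "prompt_token_ids": None,
}
if inputs.get("prompt") is None:
    # The decoder prompt must carry the full text so enough KV blocks
    # are allocated and added to the block table.
    inputs["prompt"] = inputs["encoder_prompt"]
    inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
print(inputs["prompt"])  # "<|image|> Describe this picture"
```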
diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py
index a925044554a..b93438e469d 100644
--- a/models/demos/llama3/tt/llama_attention.py
+++ b/models/demos/llama3/tt/llama_attention.py
@@ -23,6 +23,7 @@ def __init__(
transformation_mats,
configuration,
paged_attention_config=None,
+ use_paged_kv_cache=False,
):
super().__init__()
@@ -56,6 +57,7 @@ def __init__(
self.ccl_topology = configuration.ccl_topology()
self.is_multichip = configuration.is_multichip
+ self.layer_num = layer_num
layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num)
if configuration.dummy_weights or (weight_cache_path is None):
cache_name = lambda _: None
@@ -144,6 +146,17 @@ def __init__(
cache_file_name=cache_name("wo_height_sharded"),
)
+ if not use_paged_kv_cache:
+ # vLLM provides its own kv cache
+ self.init_kv_cache(configuration, weight_cache_path)
+
+ self.scale = self.head_dim**-0.5
+
+ def init_kv_cache(self, configuration, weight_cache_path):
+ """
+ Generates an empty KV cache and pushes it to device memory.
+ """
+
if self.paged_attention_config:
cache_k = torch.zeros(
(
@@ -194,14 +207,13 @@ def __init__(
for k_or_v in [cache_k, cache_v]
]
- self.scale = self.head_dim**-0.5
-
def forward_decode(
self,
x: ttnn.Tensor,
current_pos,
rot_mats=None,
page_table=None,
+ kv_cache=None,
) -> ttnn.Tensor:
"""
x: (seq_len, 1, batch, dim)
@@ -262,8 +274,12 @@ def forward_decode(
###
# KV update
###
- keys = self.layer_past[0]
- values = self.layer_past[1]
+ if kv_cache:
+ keys = kv_cache[self.layer_num][0]
+ values = kv_cache[self.layer_num][1]
+ else:
+ keys = self.layer_past[0]
+ values = self.layer_past[1]
# k_heads, [seqlen, n_kv_heads, bsz, head_dim]
# v_heads [seqlen, n_kv_heads, bsz, head_dim]
# keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim]
@@ -272,9 +288,6 @@ def forward_decode(
values, v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table
)
- self.layer_past[0] = keys
- self.layer_past[1] = values
-
ttnn.deallocate(k_heads_1BKD)
ttnn.deallocate(v_heads_1BKD)
@@ -362,7 +375,7 @@ def forward_decode(
dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"])
return dense_out_sharded
- def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None):
+ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None, kv_cache=None):
seq_len = x_11SH.shape[-2]
assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128"
###
@@ -425,7 +438,10 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None):
ttnn.deallocate(k_heads_1KSD_pre_rot)
# Fill KV-Cache
- keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1]
+ if kv_cache:
+ keys_BKSD, values_BKSD = kv_cache[self.layer_num][0], kv_cache[self.layer_num][1]
+ else:
+ keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1]
k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=ttnn.bfloat8_b)
v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=ttnn.bfloat8_b)
@@ -451,8 +467,14 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None):
ttnn.deallocate(v_heads_1VSD)
if page_table:
- ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill, page_table, batch_idx=user_id)
- ttnn.experimental.paged_fill_cache(values_BKSD, v_fill, page_table, batch_idx=user_id)
+ # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values.
+ # Assume that the page table does not have padding, so we can use it to get the unpadded page len.
+ block_size = keys_BKSD.shape[2]
+ page_len = page_table.shape[1] * block_size
+ k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill
+ v_fill_sliced = v_fill[:, :, :page_len, :] if page_len < v_fill.shape[2] else v_fill
+ ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill_sliced, page_table, batch_idx=user_id)
+ ttnn.experimental.paged_fill_cache(values_BKSD, v_fill_sliced, page_table, batch_idx=user_id)
else:
ttnn.fill_cache(
keys_BKSD,
@@ -469,8 +491,6 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None):
ttnn.deallocate(k_fill)
ttnn.deallocate(v_fill)
- self.layer_past = [keys_BKSD, values_BKSD]
-
# SDPA
# reshaping to put group in batch dim to do sdpa on 8 MQAs in parallel
@@ -550,8 +570,17 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None):
else:
return output_11SH
- def forward(self, x, current_pos, rot_mats=None, user_id=0, mode="decode", page_table=None):
+ def forward(
+ self,
+ x,
+ current_pos,
+ rot_mats=None,
+ user_id=0,
+ mode="decode",
+ page_table=None,
+ kv_cache=None,
+ ):
if mode == "prefill":
- return self.forward_prefill(x, rot_mats, user_id, page_table)
+ return self.forward_prefill(x, rot_mats, user_id, page_table=page_table, kv_cache=kv_cache)
else:
- return self.forward_decode(x, current_pos, rot_mats, page_table)
+ return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache)
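The new `paged_fill_cache` guard trims padded k/v fill tensors back to the sequence length the page table actually covers. A numeric sketch of that slicing with illustrative sizes:

```python
import torch

block_size = 64              # keys_BKSD.shape[2] in the real code
num_blocks_for_user = 8      # page_table.shape[1], assumed unpadded
page_len = num_blocks_for_user * block_size  # 512 tokens of real cache
padded_seq_len = 1024                        # k_fill was padded for the matmuls

k_fill = torch.randn(1, 8, padded_seq_len, 128)
k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill
print(k_fill_sliced.shape)   # torch.Size([1, 8, 512, 128])
```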
diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py
index b9b5484cb89..fd7f368557f 100644
--- a/models/demos/llama3/tt/llama_common.py
+++ b/models/demos/llama3/tt/llama_common.py
@@ -209,6 +209,27 @@ def num_to_core_range_set(x):
)
+def copy_host_to_device(host_tensors, device_tensors=None, mesh_device=None):
+ """
+ Helper function which copies host tensors to device tensors.
+ If no device_tensors are provided, it creates new device tensors and returns them.
+ """
+ if device_tensors is None:
+ assert mesh_device is not None, "mesh_device is required when device_tensors is None"
+ ret = []
+ for i in range(len(host_tensors)):
+ on_device = ttnn.to_device(host_tensors[i], device=mesh_device) if host_tensors[i] else None
+ ret.append(on_device)
+ return ret
+ else:
+ for i in range(len(host_tensors)):
+ if host_tensors[i] is None:
+ assert device_tensors[i] is None
+ continue
+ ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i])
+ return device_tensors
+
+
def calculate_hidden_dim(dim, ffn_dim_multiplier, multiple_of):
"""Helper function based on logic used in reference model:
https://github.com/meta-llama/llama-models/blob/e4a6ed52a142bb9b5106dcbf48e41f97f8e7378e/models/llama3/reference_impl/model.py#L227C7-L231C83
@@ -295,3 +316,29 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True
),
pt_out,
)
+
+
+def get_padded_prefill_len(seq_len):
+ """
+ If seq_len is at most 32, pad to 32
+ If seq_len is more than 32, pad to whichever is smaller: the next power of 2 or the next multiple of 1024
+ TODO: Generalize for max_mm_seq_len different from 1024
+ """
+ if seq_len <= 32:
+ return 32
+ pow_2_pad = nearest_pow_2(seq_len)
+ mult_1024_pad = 1024 * math.ceil(seq_len / 1024)
+ min_extended_pad = min(pow_2_pad, mult_1024_pad)
+ return min_extended_pad
+
+
+def get_block_size(kv_cache):
+ return kv_cache[0][0].shape[2]
+
+
+def num_blocks_in_seq(seq_len, block_size):
+ return math.ceil(seq_len / block_size)
+
+
+def nearest_pow_2(x):
+ return 2 ** math.ceil(math.log2(x))
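Worked examples for the new padding helpers, recomputed in plain Python from the definitions above:

```python
import math

def nearest_pow_2(x):
    return 2 ** math.ceil(math.log2(x))

def get_padded_prefill_len(seq_len):
    if seq_len <= 32:
        return 32
    return min(nearest_pow_2(seq_len), 1024 * math.ceil(seq_len / 1024))

for s in (7, 33, 600, 1500, 3000):
    print(s, "->", get_padded_prefill_len(s))
# 7 -> 32, 33 -> 64, 600 -> 1024, 1500 -> 2048, 3000 -> 3072
```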
diff --git a/models/demos/llama3/tt/llama_decoder.py b/models/demos/llama3/tt/llama_decoder.py
index e5edfce889a..ad1bdf9b59a 100644
--- a/models/demos/llama3/tt/llama_decoder.py
+++ b/models/demos/llama3/tt/llama_decoder.py
@@ -20,6 +20,7 @@ def __init__(
weight_cache_path,
transformation_mats,
paged_attention_config=None,
+ use_paged_kv_cache=False,
):
super().__init__()
@@ -48,6 +49,7 @@ def __init__(
transformation_mats=transformation_mats,
configuration=args,
paged_attention_config=paged_attention_config,
+ use_paged_kv_cache=use_paged_kv_cache,
)
self.feed_forward = TtLlamaMLP(
mesh_device=mesh_device,
@@ -97,6 +99,7 @@ def forward(
user_id=0,
mode="decode",
page_table=None,
+ kv_cache=None,
) -> ttnn.Tensor:
# x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode)
skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
@@ -112,7 +115,8 @@ def forward(
rot_mats,
user_id,
mode,
- page_table,
+ page_table=page_table,
+ kv_cache=kv_cache,
)
# Here x and attn_out are both fractured across devices
h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg)
diff --git a/models/demos/llama3/tt/llama_embedding.py b/models/demos/llama3/tt/llama_embedding.py
index 89b6fb1b3f0..6259c17619f 100644
--- a/models/demos/llama3/tt/llama_embedding.py
+++ b/models/demos/llama3/tt/llama_embedding.py
@@ -23,7 +23,7 @@ def __init__(
base_name = args.get_state_dict_prefix("", None) + "tok_embeddings.weight"
torch_weight = self.state_dict[base_name].unsqueeze(0).unsqueeze(0)
- cache_name = weight_cache_path / base_name
+ cache_name = None if args.dummy_weights else weight_cache_path / base_name
self.weights = ttnn.as_tensor(
torch_weight,
dtype=dtype,
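The embedding change disables the on-disk weight cache when running with dummy weights, presumably so randomly initialized tensors are never written where a later real-weight run would reload them. A sketch of the guard in isolation (paths are illustrative):

```python
from pathlib import Path

class Args:
    dummy_weights = True

args, weight_cache_path = Args(), Path("/tmp/weight_cache")
base_name = "tok_embeddings.weight"
cache_name = None if args.dummy_weights else weight_cache_path / base_name
print(cache_name)  # None -> ttnn.as_tensor skips the on-disk cache
```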
diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py
index e04ed2c4cf8..9c55182115f 100644
--- a/models/demos/llama3/tt/llama_model.py
+++ b/models/demos/llama3/tt/llama_model.py
@@ -14,6 +14,9 @@
from models.common.lightweightmodule import LightweightModule
from models.demos.llama3.tt.distributed_norm import DistributedNorm
from models.demos.llama3.tt.lm_head import LMHead
+from models.demos.llama3.tt.llama_common import copy_host_to_device, get_prefill_rot_mat, HostEmbedding
+from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup
+from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding
class TtTransformer(LightweightModule):
@@ -24,7 +27,6 @@ def __init__(
mesh_device,
state_dict,
weight_cache_path,
- transformation_mats,
paged_attention_config=None,
):
super().__init__()
@@ -38,6 +40,24 @@ def __init__(
self.grid_size = self.args.max_grid_size
state_dict_prefix = args.get_state_dict_prefix("", None)
+ self.embd = TtLlamaEmbedding(
+ mesh_device=mesh_device,
+ args=args,
+ weight_cache_path=args.weight_cache_path(dtype),
+ state_dict=state_dict,
+ dtype=ttnn.bfloat16, # Row major layout requires bfloat16
+ )
+
+ self.rope_setup = TtLlamaRotarySetup(
+ mesh_device,
+ args.max_batch_size,
+ args.head_dim,
+ args.max_seq_len,
+ args.rope_theta,
+ args.use_scaled_rope,
+ )
+ self.trans_mats_dict = self.rope_setup.get_both_trans_mats()
+
self.layers = [
TtTransformerBlock(
args=args,
@@ -46,7 +66,7 @@ def __init__(
state_dict=state_dict,
weight_cache_path=weight_cache_path,
layer_num=i,
- transformation_mats=transformation_mats,
+ transformation_mats=self.trans_mats_dict,
paged_attention_config=paged_attention_config,
)
for i in range(self.n_layers)
@@ -76,6 +96,167 @@ def __init__(
weight_cache_path=weight_cache_path,
)
+ def prepare_inputs_prefill(self, tokens, page_table=None):
+ """
+ Inputs are torch tensors or python types. This function returns ttnn
+ tensors on device.
+ TODO: Debate whether this function is responsible for padding
+ """
+
+ tokens = tokens.reshape(1, 1, 1, -1)
+ S = tokens.shape[-1]
+
+ tokens = ttnn.from_torch(
+ tokens,
+ device=self.mesh_device,
+ dtype=ttnn.uint32,
+ layout=ttnn.ROW_MAJOR_LAYOUT,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
+
+ tokens_embd = self.embd(tokens)
+
+ tt_rot_mats_prefill = get_prefill_rot_mat(
+ self.args.head_dim, self.args.max_seq_len, self.mesh_device, seq_len=S
+ )
+
+ if page_table is not None:
+ tt_page_table = ttnn.from_torch(
+ page_table,
+ device=self.mesh_device,
+ dtype=ttnn.int32,
+ layout=ttnn.ROW_MAJOR_LAYOUT,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
+ else:
+ tt_page_table = None
+
+ return tokens_embd, tt_rot_mats_prefill, tt_page_table
+
+ def prepare_inputs_decode(self, *inputs):
+ """
+ Inputs are torch tensors or python types. This function returns ttnn
+ tensors on device.
+ Its implementation can take advantage of a few other functions which the
+ model must implement.
+ """
+ host_inputs = self.prepare_decode_inputs_host(*inputs)
+ device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) # Helper function
+ transformed_device_inputs = self.transform_decode_inputs_device(*device_inputs)
+ return transformed_device_inputs
+
+ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
+ """
+ Inputs are torch tensors or python types. Outputs are ttnn tensors on host.
+ NOTE: tokens and current_pos are expected to be padded to max_batch_size
+ """
+ B = tokens.shape[-1]
+ assert current_pos.shape[0] == B, "Batch size mismatch"
+ assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"
+
+ tokens = ttnn.from_torch(
+ tokens,
+ device=None,
+ dtype=ttnn.uint32,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
+
+ rope_idxs = self.rope_setup.get_rot_idxs(current_pos, on_host=True)
+ current_pos_tt = ttnn.from_torch(
+ current_pos,
+ device=None,
+ dtype=ttnn.int32,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
+
+ if page_table is not None:
+ page_table = ttnn.from_torch(
+ page_table,
+ device=None,
+ dtype=ttnn.int32,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
+ return tokens, current_pos_tt, rope_idxs, page_table
+
+ def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_table=None):
+ """
+ Inputs are ttnn tensors on device. This function applies any on-device
+ transformations which should happen before forward decode.
+ For example: tilize, reshape, shard.
+ Returns the transformed device tensors.
+
+ Steps performed: gather rope sin/cos for the current positions, then
+ embed the tokens.
+ """
+ tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs)
+ tt_tokens = self.embd(tokens)
+ tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens)
+ return tt_tokens, current_pos, tt_rot_mats, page_table
+
+ def process_output_prefill(self, tt_out, last_token_idx):
+ """
+ Input is ttnn device tensor of logits. Output is torch logits tensor.
+ NOTE: In this model, prefill always uses get_last_token
+ """
+ logits = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, -1))[
+ 0, 0, last_token_idx, :
+ ]
+ return logits
+
+ def process_output_decode(self, tt_out):
+ """
+ Input is ttnn device tensor of logits. Output is torch logits tensor
+ """
+ if self.args.num_devices > 1:
+ tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
+ tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
+ if self.args.num_devices > 1:
+ return ttnn.to_torch(ttnn.get_device_tensors(tt_out_rm)[0]).float()
+ else:
+ return ttnn.to_torch(tt_out_rm).float()
+
+ def ttnn_prefill_forward(
+ self,
+ x,
+ rot_mats,
+ user_id,
+ page_table=None,
+ get_last_token=-1,
+ ):
+ """
+ This method will take device tensors and any other args to run forward.
+ It returns ttnn device tensors.
+ """
+ return self.forward(
+ x,
+ current_pos=None,
+ rot_mats=rot_mats,
+ transformation_mats=None,
+ user_id=user_id,
+ mode="prefill",
+ page_table=page_table,
+ get_last_token=get_last_token,
+ )
+
+ def ttnn_decode_forward(
+ self,
+ x,
+ current_pos,
+ rot_mats,
+ page_table=None,
+ ):
+ """
+ This method will take device tensors and any other args to run forward.
+ It returns ttnn device tensors.
+ """
+ return self.forward(
+ x,
+ current_pos,
+ rot_mats=rot_mats,
+ mode="decode",
+ page_table=page_table,
+ )
+
def forward(
self,
x: ttnn.Tensor,
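Taken together, the new TtTransformer methods split decode into host prep, host-to-device copy, on-device transforms, forward, and output processing. A call-order sketch (model construction and tensor contents elided, so this is illustrative rather than runnable in isolation):

```python
# 1. Build host-side ttnn tensors from torch inputs.
host_inputs = model.prepare_decode_inputs_host(tokens, current_pos, page_table)
# 2. Push them to the mesh device (helper from llama_common).
device_inputs = copy_host_to_device(host_inputs, mesh_device=mesh_device)
# 3. On-device transforms: gather rope mats, embed tokens, tilize.
tt_tokens, tt_pos, tt_rot_mats, tt_page_table = model.transform_decode_inputs_device(*device_inputs)
# 4. Forward and post-process back to torch logits.
tt_logits = model.ttnn_decode_forward(tt_tokens, tt_pos, rot_mats=tt_rot_mats, page_table=tt_page_table)
logits = model.process_output_decode(tt_logits)
```

Steps 1-2 are what `capture_trace_text` runs outside the trace; step 3 onward is what gets captured, which is why the split exists.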
diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py
index 576ce982e8c..c1b982308bc 100644
--- a/models/demos/llama3/tt/llama_rope.py
+++ b/models/demos/llama3/tt/llama_rope.py
@@ -87,9 +87,21 @@ def __init__(
mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None,
)
- def get_trans_mats(self):
+ # TODO: Colman, should this be TILE_SIZE or head_dim? Why should it be different for prefill and decode?
+ prefill_trans_mat_torch = get_rot_transformation_mat(dhead=head_dim)
+ self.transformation_mat_prefill = ttnn.from_torch(
+ prefill_trans_mat_torch,
+ device=device,
+ layout=ttnn.TILE_LAYOUT,
+ dtype=datatype,
+ memory_config=ttnn.DRAM_MEMORY_CONFIG,
+ mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None,
+ )
+
+ def get_both_trans_mats(self):
assert self.transformation_mat is not None, "Transformation matrix not initialized"
- return self.transformation_mat
+ assert self.transformation_mat_prefill is not None, "Prefill Transformation matrix not initialized"
+ return {"decode": self.transformation_mat, "prefill": self.transformation_mat_prefill}
def get_rot_idxs(self, position_idxs, on_host=False):
assert isinstance(position_idxs, torch.Tensor), "Position ids must be a torch tensor"
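`get_both_trans_mats` now hands attention a single dict keyed by mode, so callers pick the matching transformation matrix instead of threading two variables around. A usage sketch assuming `rope_setup` is an initialized `TtLlamaRotarySetup`:

```python
trans_mats = rope_setup.get_both_trans_mats()  # {"decode": ..., "prefill": ...}
for mode in ("prefill", "decode"):
    trans_mat = trans_mats[mode]  # the right matrix for the current pass
```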
diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index c3f9f385e7a..4ddb684fe9c 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -563,7 +563,7 @@ def find_largest_divisor(n, max_divisor=8):
fuse_batch=False,
)
self.model_config["VISION_XATTN_DENSE_PROGCFG"] = lambda seq_len: self.matmul_config(
- m=seq_len,
+ m=min(seq_len, 1024),
k=self.dim // self.num_devices,
n=self.dim,
grid_size=(8, 8),
@@ -589,23 +589,21 @@ def find_largest_divisor(n, max_divisor=8):
fuse_batch=seq_len <= max_seq,
)
- xattn_cache_y_cores = (
- 16 // self.num_devices
- ) # Based on seqlen, this formula gives us a valid number of y cores
- xattn_cache_x_cores = 8
- self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config(
- # using n_heads since xattn repeats KV to match Q
- (
- nearest_32(
- (self.n_heads // self.num_devices) * seq_len // (xattn_cache_y_cores * xattn_cache_x_cores)
+ def _get_xattn_kv_prefill_mem_cfg(seq_len):
+ M = (self.n_kv_heads // self.num_devices) * seq_len
+ cores_x, cores_y = self.find_grid(M // self.tile_size)
+ return ttnn.create_sharded_memory_config(
+ (
+ nearest_32(M // (cores_x * cores_y)),
+ self.head_dim,
),
- self.head_dim,
- ),
- ttnn.CoreGrid(y=xattn_cache_y_cores, x=xattn_cache_x_cores),
- ttnn.ShardStrategy.HEIGHT,
- ttnn.ShardOrientation.ROW_MAJOR,
- use_height_and_width_as_shard_shape=True,
- )
+ ttnn.CoreGrid(y=cores_y, x=cores_x),
+ ttnn.ShardStrategy.HEIGHT,
+ ttnn.ShardOrientation.ROW_MAJOR,
+ use_height_and_width_as_shard_shape=True,
+ )
+
+ self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = _get_xattn_kv_prefill_mem_cfg
self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok)
# RMS NORM
@@ -648,7 +646,7 @@ def ccl_topology(self):
return ttnn.Topology.Linear
return None
- def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False):
+ def prepare_residual_tensor_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False):
"""
Prepare inputs for decode mode.
x: (batch, seq, dim)
@@ -698,7 +696,7 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, o
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
return x
- def prepare_inputs_ttnn_prefill(self, x_bsh, force_replicated=False):
+ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False):
"""
Prepare inputs for prefill mode.
x: (batch, seq, hidden_dim)
diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py
index a4d1bb59885..f5ff04f7e3e 100644
--- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py
+++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py
@@ -79,7 +79,8 @@ def __init__(
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
- self.compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+ self.compute_kernel_config = ttnn.init_device_compute_kernel_config(
+ mesh_device.arch(),
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=True,
fp32_dest_acc_en=True,
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
index d7032fd59ba..5aa338a012b 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -187,12 +187,7 @@ def compute_xattn_kv_cache(self, xattn_tokens, user_id, xattn_cache):
xk = self.k_norm(xk, mode="decode")
- # NOTE: Doing repeat in xattn_cache generation to avoid massive overhead in forward
- xk = ttnn.repeat_interleave(xk, self.n_local_heads // self.n_local_kv_heads, dim=1)
- xv = ttnn.repeat_interleave(xv, self.n_local_heads // self.n_local_kv_heads, dim=1)
-
k_cache, v_cache = xattn_cache
-
# Work around fill_cache memory constraint by making these sharded
k_fill = ttnn.interleaved_to_sharded(xk, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y))
v_fill = ttnn.interleaved_to_sharded(xv, self.model_config["XATTN_KV_PREFILL_MEM_CFG"](seqlen_y))
@@ -312,27 +307,22 @@ def forward_prefill(
xq = self.q_norm(xq, mode="prefill")
- scores = ttnn.matmul(
- xq,
- ttnn.transpose(k_cache_user, -1, -2),
- dtype=ttnn.bfloat16,
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- compute_kernel_config=self.compute_kernel_config_hifi2,
- program_config=self.model_config["VISION_XATTN_SCORE_PROGCFG"](seq_len, cache_seq_len),
+ program_config = ttnn.SDPAProgramConfig(
+ compute_with_storage_grid_size=self.mesh_device.compute_with_storage_grid_size(),
+ q_chunk_size=128,
+ k_chunk_size=128,
+ exp_approx_mode=False,
)
- scores = ttnn.multiply(scores, self.scale)
- # WARNING: This add is buggy if xattn_mask has to be broadcasted to n_local_heads. Workaround is to broadcast on host side
- scores = ttnn.add(scores, xattn_mask)
- scores = ttnn.softmax(scores, dim=-1, numeric_stable=True)
-
- output = ttnn.matmul(
- scores,
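+        # Fused SDPA computes softmax(Q @ K^T * scale + attn_mask) @ V in a single op, replacing the explicit matmul/softmax pipeline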
+ output = ttnn.transformer.scaled_dot_product_attention(
+ xq,
+ k_cache_user,
v_cache_user,
- dtype=ttnn.bfloat16,
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- compute_kernel_config=self.compute_kernel_config_hifi4,
- program_config=self.model_config["VISION_XATTN_OUTPUT_PROGCFG"](seq_len, cache_seq_len),
+ is_causal=False,
+ attn_mask=xattn_mask,
+ scale=self.scale,
+ program_config=program_config,
+ compute_kernel_config=self.compute_kernel_config_sdpa,
)
# WARNING: this broadcast is also broken, must broadcast on host
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
index f657abb8672..a7ce9def430 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
@@ -5,12 +5,15 @@
import math
import ttnn
import torch
+from tqdm import tqdm
from models.demos.llama3.tt.llama_decoder import TtTransformerBlock
from models.demos.llama3.tt.multimodal.llama_cross_block import TtLlamaCrossAttentionTransformerBlock
from models.demos.llama3.tt.distributed_norm import DistributedNorm
from models.common.rmsnorm import RMSNorm
import ttnn
from models.common.lightweightmodule import LightweightModule
+from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding
+from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup
from models.utility_functions import (
nearest_32,
@@ -41,6 +44,7 @@ def __init__(
weight_cache_path,
dtype,
configuration,
+ use_paged_kv_cache=False,
):
super().__init__()
self.vocab_size = configuration.vocab_size
@@ -116,12 +120,31 @@ def __init__(
self.num_frozen_embeddings = self.tok_embeddings.num_embeddings
self._thresh = self.num_frozen_embeddings - 1
+ self.rope_setup = TtLlamaRotarySetup(
+ mesh_device,
+ configuration.max_batch_size,
+ configuration.head_dim,
+ configuration.max_seq_len,
+ configuration.rope_theta,
+ configuration.use_scaled_rope,
+ )
+ self.trans_mats_dict = self.rope_setup.get_both_trans_mats()
+
# transformer blocks
self.layers = []
self.cross_attention_layers = []
- for i in range(configuration.n_layers):
+ for i in tqdm(range(configuration.n_layers), desc="Loading text transformer layers"):
layer_id = i
- block = TtTransformerBlock(configuration, mesh_device, dtype, state_dict, layer_id, weight_cache_path)
+ block = TtTransformerBlock(
+ configuration,
+ mesh_device,
+ dtype,
+ state_dict,
+ layer_id,
+ weight_cache_path,
+ transformation_mats=self.trans_mats_dict,
+ use_paged_kv_cache=use_paged_kv_cache,
+ )
self.layers.append(block)
if layer_id in self.fusion_schedule:
xa_layer_id = self.fusion_schedule.index(layer_id)
@@ -224,7 +247,7 @@ def setup_cache(self, max_batch_size):
[
ttnn.from_torch(
torch.zeros(
- max_batch_size, self.configuration.n_heads, vision_seq_len, self.configuration.head_dim
+ max_batch_size, self.configuration.n_kv_heads, vision_seq_len, self.configuration.head_dim
),
device=self.mesh_device,
layout=ttnn.TILE_LAYOUT,
@@ -247,14 +270,14 @@ def forward(
full_text_row_masked_out_mask_11SD: ttnn.Tensor,
xattn_caches,
current_pos,
- rot_mat=None,
- transformation_mats=None,
+ rot_mats=None,
user_id=0,
mode="decode",
page_table=None,
- # get_last_token=-1,
+ kv_cache=None,
text_only_inference=False,
vision_tokens=None,
+ get_last_token=-1,
):
for idx, (
layer,
@@ -275,12 +298,15 @@ def forward(
h = layer(
h,
current_pos,
- rot_mat=rot_mat,
- transformation_mats=transformation_mats,
+ rot_mats=rot_mats,
user_id=user_id,
mode=mode,
+ page_table=page_table,
+ kv_cache=kv_cache,
)
+ if get_last_token != -1:
+ h = ttnn.slice(h, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, h.shape[-1]))
h = self.norm(h, mode=mode)
# TODO: Switch to using dram-sharded LM head and remove this
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py
index 1761bc7ac66..9d8c3760af0 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_block.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py
@@ -126,6 +126,11 @@ def forward(
user_id=0,
vision_tokens=None,
):
+ skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
+ assert (
+ x_11SH.memory_config() == skip_mem_cfg
+ ), f"decoder input memcfg mismatch: {x_11SH.memory_config()} != {skip_mem_cfg}"
+
attn_out = self.attention(
x_11SH=self.attention_norm(x_11SH, mode=mode),
xattn_mask=xattn_mask,
diff --git a/models/demos/llama3/tt/multimodal/llama_image_transformer.py b/models/demos/llama3/tt/multimodal/llama_image_transformer.py
index ea86d302748..e9bef2377ed 100644
--- a/models/demos/llama3/tt/multimodal/llama_image_transformer.py
+++ b/models/demos/llama3/tt/multimodal/llama_image_transformer.py
@@ -2,10 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import List, Optional
-import torch
+from tqdm import tqdm
-import ttnn
from models.utility_functions import (
nearest_32,
)
@@ -41,7 +39,7 @@ def __init__(
configuration=configuration,
gated=gated,
)
- for i in range(layers)
+            for i in tqdm(range(layers), desc="Loading vision transformer layers")
]
def forward(self, x, return_intermediate=None, mask=None):
diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py
index 96149d5a0f9..0b1f36fd6f4 100644
--- a/models/demos/llama3/tt/multimodal/llama_vision_model.py
+++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -29,6 +29,7 @@
get_prefill_rot_mat,
get_rot_transformation_mat,
get_single_rot_mat,
+ copy_host_to_device,
)
from models.utility_functions import (
nearest_32,
@@ -128,6 +129,7 @@ def __init__(
weight_cache_path,
dtype,
configuration,
+ use_paged_kv_cache=False,
) -> None:
super().__init__()
@@ -159,6 +161,7 @@ def __init__(
weight_cache_path=configuration.weight_cache_path(ttnn.bfloat8_b),
dtype=ttnn.bfloat8_b,
configuration=configuration,
+ use_paged_kv_cache=use_paged_kv_cache,
)
self.image_res = configuration.vision_chunk_size
self.max_num_chunks = configuration.vision_max_num_chunks
@@ -268,7 +271,6 @@ def compute_vision_tokens_masks(
def validate_inputs(self, tokens, position_ids):
batch, seq_len = tokens.shape[:2]
- assert batch == 1, f"Only batch 1 is supported, got {batch}"
assert (
seq_len <= self.configuration.max_seq_len
), f"Sequence length {seq_len} exceeds max sequence length {self.configuration.max_seq_len}"
@@ -279,7 +281,9 @@ def prepare_inputs_common(self, position_ids, tokens):
h = self.text_model.get_partially_trainable_embedding(tokens)
return h
- def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len):
+ def prepare_inputs_prefill(
+ self, tokens, cross_attention_masks, full_text_row_masked_out_mask, prefill_len, page_table=None
+ ):
B = tokens.shape[0]
assert B == 1, f"Only batch 1 is supported, got {B}"
S = tokens.shape[1]
@@ -287,26 +291,16 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma
h = self.prepare_inputs_common(position_ids, tokens)
padded_seq_len = _get_padded_prefill_seqlen(S)
- tt_position_id = ttnn.from_torch(
- position_ids,
- device=self.mesh_device,
- dtype=ttnn.int32,
- layout=ttnn.ROW_MAJOR_LAYOUT,
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
- )
-
xattn_mask = cross_attention_masks[:, :, position_ids]
- xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1)
- xattn_mask_expand = torch.nn.functional.pad(
- xattn_mask_expand,
- (0, 0, 0, padded_seq_len - xattn_mask_expand.shape[2]),
+ xattn_mask = torch.nn.functional.pad(
+ xattn_mask,
+ (0, 0, 0, padded_seq_len - xattn_mask.shape[2]),
"constant",
get_negative_inf_value(torch.float32),
)
tt_xattn_mask = ttnn.from_torch(
- xattn_mask_expand,
+ xattn_mask,
device=self.mesh_device,
dtype=ttnn.bfloat16,
layout=ttnn.ROW_MAJOR_LAYOUT,
@@ -314,6 +308,7 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT)
+ tt_xattn_mask = ttnn.typecast(tt_xattn_mask, ttnn.bfloat4_b)
full_text_mask = full_text_row_masked_out_mask[:, :, position_ids]
full_text_mask = torch.nn.functional.pad(
@@ -331,65 +326,75 @@ def prepare_inputs_prefill(self, tokens, cross_attention_masks, full_text_row_ma
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT)
+ tt_full_text_mask_expand_1NSH = ttnn.typecast(tt_full_text_mask_expand_1NSH, ttnn.bfloat4_b)
h = torch.nn.functional.pad(h, (0, 0, 0, padded_seq_len - h.shape[1]), "constant", 0)
- tt_h = self.configuration.prepare_inputs_ttnn_prefill(
+ tt_h = self.configuration.prepare_residual_tensor_prefill(
h,
)
rot_mats = get_prefill_rot_mat(
self.configuration.head_dim, self.configuration.max_seq_len, self.mesh_device, seq_len=S
)
- transformation_mat_torch = get_rot_transformation_mat(self.configuration.head_dim)
- transformation_mats = ttnn.as_tensor(
- transformation_mat_torch,
- dtype=ttnn.bfloat16,
- layout=ttnn.TILE_LAYOUT,
- device=self.mesh_device,
- mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
- memory_config=ttnn.DRAM_MEMORY_CONFIG,
- )
full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim)
tt_full_text_mask_expand_11SD = ttnn.from_torch(
full_text_mask_expand_11SD,
device=self.mesh_device,
- dtype=ttnn.bfloat8_b,
+ dtype=ttnn.bfloat4_b,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1),
)
+ if isinstance(page_table, torch.Tensor):
+ # Support vLLM tensor page_table input
+ page_table = ttnn.as_tensor(
+ page_table,
+ device=self.mesh_device,
+ memory_config=ttnn.DRAM_MEMORY_CONFIG,
+ dtype=ttnn.int32,
+ layout=ttnn.ROW_MAJOR_LAYOUT,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
+
return (
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_full_text_mask_expand_11SD,
- tt_position_id,
rot_mats,
- transformation_mats,
+ page_table,
)
- def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id):
+ def prepare_inputs_decode(
+ self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=None
+ ):
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- _tt_full_text_mask_expand_11SD,
tt_position_id,
- rot_mats,
- _transformation_mats,
- ) = self.prepare_decode_inputs_host(tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id)
+ tt_rope_id,
+ tt_page_table,
+ ) = self.prepare_decode_inputs_host(
+ tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=page_table
+ )
(
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
tt_position_id,
- rot_mats,
- ) = self.copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, rot_mats))
+ tt_rope_id,
+ tt_page_table,
+ ) = copy_host_to_device(
+ (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, tt_page_table),
+ mesh_device=self.mesh_device,
+ )
- tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device(
+ tt_h, tt_rot_mats, tt_xattn_mask, tt_full_text_mask_expand_1NSH = self.transform_decode_inputs_device(
tt_h,
+ tt_rope_id,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
B=tokens.shape[0],
@@ -399,34 +404,35 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- _tt_full_text_mask_expand_11SD,
tt_position_id,
- rot_mats,
- _transformation_mats,
+ tt_rot_mats,
+ tt_page_table,
)
- def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id):
+ def prepare_decode_inputs_host(
+ self, tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id, page_table=None
+ ):
B = tokens.shape[0]
assert (
B == self.configuration.max_batch_size
), f"Batch size must match max batch size. Got {B}, expected {self.configuration.max_batch_size}"
- position_ids = torch.tensor([position_id], dtype=torch.long)
- h = self.prepare_inputs_common(position_ids, tokens)
- tt_h = self.configuration.prepare_inputs_ttnn_decode(
+ h = self.prepare_inputs_common(position_id, tokens)
+ tt_h = self.configuration.prepare_residual_tensor_decode(
h,
- None, # on_host tensors have no memory_config
+ None,
on_host=True,
)
tt_position_id = ttnn.from_torch(
- position_ids,
+ position_id,
device=None,
dtype=ttnn.int32,
layout=ttnn.ROW_MAJOR_LAYOUT,
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
- xattn_mask = cross_attention_masks[:, :, position_ids]
+ tt_rope_id = self.text_model.rope_setup.get_rot_idxs(position_id, on_host=True)
+ xattn_mask = torch.cat([cross_attention_masks[i][:, :, position_id[i]] for i in range(B)], dim=1).unsqueeze(0)
xattn_mask_expand = xattn_mask.expand(-1, self.configuration.n_heads // self.configuration.num_devices, -1, -1)
xattn_mask_expand = xattn_mask_expand.transpose(1, 2).contiguous()
@@ -438,7 +444,9 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
- full_text_mask = full_text_row_masked_out_mask[:, :, position_ids]
+ full_text_mask = torch.cat(
+ [full_text_row_masked_out_mask[i][:, :, position_id[i]] for i in range(B)], dim=1
+ ).unsqueeze(0)
full_text_mask_expand_1NSH = full_text_mask.expand(
-1, self.configuration.n_heads // self.configuration.num_devices, -1, self.configuration.head_dim
)
@@ -451,44 +459,25 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro
mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
- rot_mats, _ = get_single_rot_mat(
- self.configuration.head_dim,
- self.mesh_device,
- self.configuration.num_devices,
- start_pos=position_ids.item() - 1, # TODO: Change function to support decode batch > 1
- # TODO: B must match max_batch_size, be careful
- on_host=True,
- )
-
- transformation_mats = None
- tt_full_text_mask_expand_11SD = None
+ if isinstance(page_table, torch.Tensor):
+ # Support vLLM tensor page_table input
+ page_table = ttnn.as_tensor(
+ page_table,
+ dtype=ttnn.int32,
+ layout=ttnn.ROW_MAJOR_LAYOUT,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
+ )
return (
tt_h,
tt_xattn_mask,
tt_full_text_mask_expand_1NSH,
- tt_full_text_mask_expand_11SD,
tt_position_id,
- rot_mats,
- transformation_mats,
+ tt_rope_id,
+ page_table,
)
- def copy_host_to_device(self, host_tensors, device_tensors=None):
- """
- Helper function which copies host tensors to device tensors
- """
- if device_tensors is None:
- ret = []
- for i in range(len(host_tensors)):
- on_device = ttnn.to_device(host_tensors[i], device=self.mesh_device)
- ret.append(on_device)
- return ret
- else:
- for i in range(len(host_tensors)):
- ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i])
- return device_tensors
-
- def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B):
+ def transform_decode_inputs_device(self, tt_h, tt_rope_id, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B):
"""
Does any transformations on device tensors which are necessary before ttnn_decode_forward
"""
@@ -499,6 +488,8 @@ def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_
tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"])
+ tt_rot_mats = self.text_model.rope_setup.get_rot_mats(tt_rope_id)
+
tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT)
tt_xattn_mask = ttnn.reshape(
tt_xattn_mask,
@@ -531,12 +522,11 @@ def transform_decode_inputs_device(self, tt_h, tt_xattn_mask, tt_full_text_mask_
),
)
- return (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH)
+ return (tt_h, tt_rot_mats, tt_xattn_mask, tt_full_text_mask_expand_1NSH)
- def process_output_prefill(self, tt_out, B, S):
- padded_seq_len = _get_padded_prefill_seqlen(S)
+ def process_output_prefill(self, tt_out, B, last_token_idx):
tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
- tt_out = tt_out[0].reshape(B, padded_seq_len, -1)[:, :S, :]
+ tt_out = tt_out[0, 0, last_token_idx, :]
return tt_out
def process_output_decode(self, tt_out, B, S):
@@ -554,6 +544,8 @@ def forward(
text_only_inference: bool = False,
user_id=0,
vision_tokens=None,
+ page_table=None,
+ kv_cache=None,
) -> torch.Tensor:
"""
This method takes torch tensors in, returns torch tensors.
@@ -573,12 +565,13 @@ def forward(
tt_full_text_mask_expand_11SD,
tt_position_id,
rot_mats,
- transformation_mats,
+ tt_page_table,
) = prepare_fn(
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
pos_arg,
+ page_table=page_table,
)
logits = self.text_model.forward(
@@ -588,10 +581,11 @@ def forward(
full_text_row_masked_out_mask_11SD=tt_full_text_mask_expand_11SD,
xattn_caches=xattn_caches,
current_pos=tt_position_id,
- rot_mat=rot_mats,
- transformation_mats=transformation_mats,
+ rot_mats=rot_mats,
user_id=user_id,
mode=mode,
+ page_table=tt_page_table,
+ kv_cache=kv_cache,
text_only_inference=text_only_inference,
vision_tokens=vision_tokens,
)
@@ -607,11 +601,12 @@ def ttnn_prefill_forward(
full_text_mas_expand_1NSH,
full_text_mask_expand_11SD,
xattn_caches,
- position_id,
rot_mats,
- transformation_mats,
user_id,
vision_tokens,
+ page_table=None,
+ kv_cache=None,
+ get_last_token=-1,
):
"""
This method runs prefill forward. It takes ttnn tensors in, returns ttnn tensors.
@@ -622,12 +617,14 @@ def ttnn_prefill_forward(
full_text_row_masked_out_mask_1NSH=full_text_mas_expand_1NSH,
full_text_row_masked_out_mask_11SD=full_text_mask_expand_11SD,
xattn_caches=xattn_caches,
- current_pos=position_id,
- rot_mat=rot_mats,
- transformation_mats=transformation_mats,
+ current_pos=None,
+ rot_mats=rot_mats,
user_id=user_id,
mode="prefill",
+ page_table=page_table,
+ kv_cache=kv_cache,
vision_tokens=vision_tokens,
+ get_last_token=get_last_token,
)
tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT)
return tt_out
@@ -640,6 +637,8 @@ def ttnn_decode_forward(
xattn_caches,
position_id,
rot_mats,
+ page_table=None,
+ kv_cache=None,
):
"""
This method runs decode forward. It takes ttnn tensors in, returns ttnn tensors.
@@ -651,8 +650,10 @@ def ttnn_decode_forward(
full_text_row_masked_out_mask_11SD=None,
xattn_caches=xattn_caches,
current_pos=position_id,
- rot_mat=rot_mats,
+ rot_mats=rot_mats,
mode="decode",
+ page_table=page_table,
+ kv_cache=kv_cache,
)
tt_out = ttnn.to_layout(logits, ttnn.ROW_MAJOR_LAYOUT)
return tt_out
@@ -720,11 +721,11 @@ def _pad_masks(
def _get_padded_prefill_seqlen(seq_len):
"""
If seq_len is less than 128, pad to 128
- If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 1024
+    If seq_len is more than 128, pad to whichever is smaller: a power of 2 or a multiple of 2048 (the text model requires a multiple of 2048, while the vision model allows a multiple of 1024)
"""
if seq_len < 128:
return 128
else:
- mult_1024 = 1024 * math.ceil(seq_len / 1024)
+ mult_2k = 2048 * math.ceil(seq_len / 2048)
pow_2 = 2 ** math.ceil(math.log2(seq_len))
- return min(mult_1024, pow_2)
+ return min(mult_2k, pow_2)
diff --git a/models/demos/segformer/tt/common.py b/models/demos/segformer/tt/common.py
index 5f52fe0e507..d777116d232 100644
--- a/models/demos/segformer/tt/common.py
+++ b/models/demos/segformer/tt/common.py
@@ -40,12 +40,8 @@ def __call__(self, device, input_tensor):
conv_config = ttnn.Conv2dConfig(
dtype=self.dtype,
weights_dtype=ttnn.bfloat16,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation=self.activation,
shard_layout=self.shard_layout,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=False,
- packer_l1_accum_enabled=False,
input_channels_alignment=16 if input_tensor.shape[3] < 16 else 32,
transpose_shards=False,
reshard_if_not_optimal=self.reshard,
@@ -54,10 +50,17 @@ def __call__(self, device, input_tensor):
enable_act_double_buffer=True,
enable_split_reader=False,
)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=False,
+ packer_l1_acc=False,
+ )
if self.act_block_h is not None:
conv_config.act_block_h_override = self.act_block_h
- [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d(
+ [output_tensor, [_out_height, _out_width]] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.weights,
bias_tensor=self.bias,
@@ -71,7 +74,10 @@ def __call__(self, device, input_tensor):
input_height=input_tensor.shape[1],
input_width=input_tensor.shape[2],
conv_config=conv_config,
+ compute_config=compute_config,
groups=self.groups,
+ return_output_dim=True,
+ return_weights_and_bias=False,
)
return output_tensor, _out_height, _out_width
diff --git a/models/demos/segformer/tt/ttnn_segformer_decode_head.py b/models/demos/segformer/tt/ttnn_segformer_decode_head.py
index 6aed216c578..4be9a957a8c 100644
--- a/models/demos/segformer/tt/ttnn_segformer_decode_head.py
+++ b/models/demos/segformer/tt/ttnn_segformer_decode_head.py
@@ -78,7 +78,7 @@ def __call__(self, encoder_hidden_states: ttnn.bfloat8_b, parameters) -> ttnn.Te
encoder_hidden_state = ttnn.upsample(
encoder_hidden_state,
- scale_factor=(128 // encoder_hidden_state.shape[2], 128 // encoder_hidden_state.shape[2], 1),
+ scale_factor=(128 // encoder_hidden_state.shape[2], 128 // encoder_hidden_state.shape[2]),
mode="bilinear",
)
diff --git a/models/demos/squeezebert/README.md b/models/demos/squeezebert/README.md
new file mode 100644
index 00000000000..70adcab9c2e
--- /dev/null
+++ b/models/demos/squeezebert/README.md
@@ -0,0 +1,30 @@
+# SqueezeBERT demo
+
+Demo showcasing SqueezeBERT running on Grayskull (e150) and Wormhole (n150, n300) using ttnn.
+
+## Introduction
+SqueezeBERT is a bidirectional transformer similar to the BERT model. The key difference between the BERT and SqueezeBERT architectures is that SqueezeBERT uses grouped convolutions instead of fully-connected layers for the Q, K, V, and FFN layers.
+
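+As a minimal illustration (standard PyTorch; the `hidden` and `groups` values below are illustrative, not the model's actual configuration), a grouped 1x1 convolution can stand in for a fully-connected projection:
+
+```python
+import torch
+from torch import nn
+
+hidden = 768  # embedding width (illustrative)
+groups = 4    # SqueezeBERT-style grouped projection
+
+fc = nn.Linear(hidden, hidden)                                  # dense projection
+conv = nn.Conv1d(hidden, hidden, kernel_size=1, groups=groups)  # grouped 1x1 conv
+
+x = torch.randn(8, 384, hidden)                   # (batch, seq, hidden)
+y_fc = fc(x)                                      # (batch, seq, hidden)
+y_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # same shape, ~1/groups the weights
+assert y_fc.shape == y_conv.shape
+```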
+
+## Details
+The entry point to the functional_squeezebert model is `squeezebert_for_question_answering` in `models/demos/squeezebert/tt/ttnn_functional_squeezebert.py`. The model picks up certain configs and weights from the Hugging Face pretrained model; we used the `squeezebert/squeezebert-uncased` version as our reference.
+
+### Sequence Size: 384
+Sequence size determines the maximum length of input sequences processed by the model. It's recommended to set `sequence_size` to 384.
+
+### Batch size: 8
+Batch size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set `batch_size` to 8.
+
+## How to Run
+
+Use `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo[models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-models/demos/squeezebert/demo/input_data.json-8-384-device_params0]` to run the demo.
+
+If you wish to run the demo with a different input file, use `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo[models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased--8-384-device_params0]`. The input file is expected to have exactly 8 inputs.
+
+Our second demo is designed to run with the SQuADv2 dataset; run it with `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo_squadv2[3-models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-8-384-device_params0]`.
+
+If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/squeezebert/demo/demo.py::test_demo_squadv2[-models.demos.squeezebert.tt.ttnn_functional_squeezebert-squeezebert/squeezebert-uncased-8-384-device_params0]`.
+
+
+## Inputs
+The demo receives inputs from the respective `input_data.json` by default. To modify the inputs or specify a different path, adjust the `input_path` parameter in the command accordingly. It's recommended to avoid directly modifying the `input_data.json` file.
diff --git a/models/demos/squeezebert/demo/demo.py b/models/demos/squeezebert/demo/demo.py
new file mode 100644
index 00000000000..6f7dc817d1d
--- /dev/null
+++ b/models/demos/squeezebert/demo/demo.py
@@ -0,0 +1,324 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import json
+import torch
+import pytest
+import evaluate
+
+from loguru import logger
+from models.utility_functions import (
+ profiler,
+ disable_compilation_reports,
+ disable_persistent_kernel_cache,
+)
+from ttnn.model_preprocessing import preprocess_model_parameters
+from models.demos.squeezebert.tt import ttnn_functional_squeezebert
+from models.datasets.dataset_squadv2 import squadv2_1K_samples_input, squadv2_answer_decode_batch
+
+from transformers import SqueezeBertForQuestionAnswering, pipeline, SqueezeBertTokenizer
+
+
+def load_inputs(input_path, batch):
+ with open(input_path) as f:
+ input_data = json.load(f)
+ assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries."
+
+ context = []
+ question = []
+ for i in range(batch):
+ context.append(input_data[i]["context"])
+ question.append(input_data[i]["question"])
+
+ return context, question
+
+
+def positional_ids(config, input_ids, past_key_values_length=0):
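+    """Return position ids [past_key_values_length, past_key_values_length + seq_len) expanded to input_ids' shape."""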
+ seq_length = input_ids.size(1)
+ position_ids = torch.arange(config.max_position_embeddings, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0)[:, past_key_values_length : seq_length + past_key_values_length]
+ position_ids = position_ids.expand_as(input_ids)
+
+ return position_ids
+
+
+def run_squeezebert_question_and_answering_inference(
+ device,
+ use_program_cache,
+ model_name,
+ batch_size,
+ sequence_size,
+ squeezebert,
+ input_path,
+):
+ disable_persistent_kernel_cache()
+
+ hugging_face_reference_model = SqueezeBertForQuestionAnswering.from_pretrained(model_name, torchscript=False)
+ hugging_face_reference_model.eval()
+ state_dict = hugging_face_reference_model.state_dict()
+
+ tokenizer = SqueezeBertTokenizer.from_pretrained(model_name)
+ config = hugging_face_reference_model.config
+ nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer)
+
+ tt_model_name = f"ttnn_{model_name}"
+
+ def convert_to_ttnn(model, name):
+ return not isinstance(model, torch.nn.Conv1d)
+
+ profiler.start(f"preprocessing_parameter")
+ parameters = preprocess_model_parameters(
+ model_name=tt_model_name,
+ initialize_model=lambda: hugging_face_reference_model,
+ convert_to_ttnn=convert_to_ttnn,
+ custom_preprocessor=squeezebert.custom_preprocessor,
+ device=device,
+ )
+ profiler.end(f"preprocessing_parameter")
+
+ context, question = load_inputs(input_path, batch_size)
+
+ preprocess_params, _, postprocess_params = nlp._sanitize_parameters()
+ preprocess_params["max_seq_len"] = sequence_size
+ inputs = nlp._args_parser({"context": context, "question": question})
+
+ preprocessed_inputs = []
+ for i in range(batch_size):
+ model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params))
+
+ single_input = {
+ "example": model_input["example"],
+ "inputs": model_input,
+ }
+ preprocessed_inputs.append(single_input)
+
+ squeezebert_input = tokenizer.batch_encode_plus(
+        list(zip(question, context)),
+ max_length=sequence_size,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ return_tensors="pt",
+ )
+
+ profiler.start(f"preprocessing_input")
+ position_ids = positional_ids(config, squeezebert_input.input_ids)
+ ttnn_squeezebert_inputs = squeezebert.preprocess_inputs(
+ squeezebert_input["input_ids"],
+ squeezebert_input["token_type_ids"],
+ position_ids,
+ squeezebert_input["attention_mask"],
+ device=device,
+ )
+ profiler.end(f"preprocessing_input")
+
+ profiler.start(f"inference_time")
+ tt_output = squeezebert.squeezebert_for_question_answering(
+ config,
+ *ttnn_squeezebert_inputs,
+ state_dict=state_dict,
+ base_addr=f"transformer.",
+ parameters=parameters,
+ device=device,
+ reader_patterns_cache=None,
+ )
+ profiler.end(f"inference_time")
+
+ tt_output = ttnn.to_torch(ttnn.from_device(tt_output)).reshape(batch_size, 1, sequence_size, -1).to(torch.float32)
+
+ tt_start_logits = tt_output[..., :, 0].squeeze(1)
+ tt_end_logits = tt_output[..., :, 1].squeeze(1)
+
+ model_answers = {}
+ profiler.start("post_processing_output_to_string")
+ for i in range(batch_size):
+ tt_res = {
+ "start": tt_start_logits[i],
+ "end": tt_end_logits[i],
+ "example": preprocessed_inputs[i]["example"],
+ **preprocessed_inputs[i]["inputs"],
+ }
+ tt_answer = nlp.postprocess([tt_res], **postprocess_params)
+
+ logger.info(f"answer: {tt_answer['answer']}\n")
+ model_answers[i] = tt_answer["answer"]
+
+ profiler.end("post_processing_output_to_string")
+
+ measurements = {
+ "preprocessing_parameter": profiler.get("preprocessing_parameter"),
+ "preprocessing_input": profiler.get("preprocessing_input"),
+ "inference_time": profiler.get("inference_time"),
+ "post_processing": profiler.get("post_processing_output_to_string"),
+ }
+ logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s")
+ logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s")
+ logger.info(f"inference_time: {measurements['inference_time']} s")
+ logger.info(f"post_processing : {measurements['post_processing']} s")
+
+ return measurements
+
+
+def run_squeezebert_question_and_answering_inference_squad_v2(
+ device,
+ use_program_cache,
+ model_name,
+ batch_size,
+ sequence_size,
+ squeezebert,
+ n_iterations,
+):
+ disable_persistent_kernel_cache()
+
+ hugging_face_reference_model = SqueezeBertForQuestionAnswering.from_pretrained(model_name, torchscript=False)
+ hugging_face_reference_model.eval()
+ state_dict = hugging_face_reference_model.state_dict()
+
+ tokenizer = SqueezeBertTokenizer.from_pretrained(model_name)
+ config = hugging_face_reference_model.config
+    tt_model_name = f"ttnn_{model_name}"
+
+ parameters = preprocess_model_parameters(
+ model_name=tt_model_name,
+ initialize_model=lambda: hugging_face_reference_model,
+ custom_preprocessor=squeezebert.custom_preprocessor,
+ device=device,
+ )
+
+ nlp = pipeline("question-answering", model=hugging_face_reference_model, tokenizer=tokenizer)
+
+ attention_mask = True
+ token_type_ids = True
+ inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size)
+ squad_metric = evaluate.load("squad_v2")
+
+ with torch.no_grad():
+ pred_labels = []
+ cpu_pred_labels = []
+ true_labels = []
+ i = 0
+ for batch in inputs_squadv2:
+ if i < n_iterations:
+ batch_data = batch[0]
+ curr_batch_size = batch_data["input_ids"].shape[0]
+ position_ids = positional_ids(config, batch_data.input_ids)
+ ttnn_squeezebert_inputs = squeezebert.preprocess_inputs(
+ batch_data["input_ids"],
+ batch_data["token_type_ids"],
+ position_ids,
+ batch_data["attention_mask"],
+ device=device,
+ )
+
+ tt_output = squeezebert.squeezebert_for_question_answering(
+ config,
+ *ttnn_squeezebert_inputs,
+ state_dict=state_dict,
+ base_addr=f"transformer.",
+ parameters=parameters,
+ device=device,
+ reader_patterns_cache=None,
+ )
+ tt_output = (
+ ttnn.to_torch(ttnn.from_device(tt_output))
+ .reshape(batch_size, 1, sequence_size, -1)
+ .to(torch.float32)
+ )
+
+ cpu_output = hugging_face_reference_model(**batch_data)
+ references = batch[1]
+ question = batch[2]
+ context = batch[3]
+
+ cpu_predictions, tt_predictions = squadv2_answer_decode_batch(
+ hugging_face_reference_model,
+ tokenizer,
+ nlp,
+ references,
+ cpu_output,
+ tt_output,
+ curr_batch_size,
+ question,
+ context,
+ )
+ pred_labels.extend(tt_predictions)
+ cpu_pred_labels.extend(cpu_predictions)
+ true_labels.extend(references)
+
+ del tt_output
+ i += 1
+ eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels)
+ cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels)
+ logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}")
+ logger.info(f"\tCPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}")
+
+ tolerance = 0.03
+ assert (
+ abs(eval_score["exact"] - cpu_eval_score["exact"]) <= tolerance
+ and abs(eval_score["f1"] - cpu_eval_score["f1"]) <= tolerance
+ ), (
+ f"Expected Exact Match : {cpu_eval_score['exact']}, Actual Exact Match: {eval_score['exact']}; "
+ f"Expected F1 Score : {cpu_eval_score['f1']}, Actual F1 Score: {eval_score['f1']}"
+ )
+
+
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.parametrize(
+ "batch_size, sequence_size",
+ [
+ (8, 384),
+ ],
+)
+@pytest.mark.parametrize(
+ "model_name, input_loc",
+ ((["squeezebert/squeezebert-uncased", "models/demos/squeezebert/demo/input_data.json"]),),
+)
+@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert])
+def test_demo(input_loc, batch_size, sequence_size, model_name, squeezebert, device, use_program_cache, reset_seeds):
+ disable_persistent_kernel_cache()
+ disable_compilation_reports()
+
+ return run_squeezebert_question_and_answering_inference(
+ device=device,
+ use_program_cache=use_program_cache,
+ model_name=model_name,
+ batch_size=batch_size,
+ sequence_size=sequence_size,
+ squeezebert=squeezebert,
+ input_path=input_loc,
+ )
+
+
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.parametrize(
+ "batch_size, sequence_size",
+ [
+ (8, 384),
+ ],
+)
+@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"])
+@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert])
+@pytest.mark.parametrize(
+ "n_iterations",
+ ((3),),
+)
+def test_demo_squadv2(
+ batch_size, sequence_size, model_name, squeezebert, n_iterations, device, use_program_cache, reset_seeds
+):
+ disable_persistent_kernel_cache()
+ disable_compilation_reports()
+
+ return run_squeezebert_question_and_answering_inference_squad_v2(
+ device=device,
+ use_program_cache=use_program_cache,
+ model_name=model_name,
+ batch_size=batch_size,
+ sequence_size=sequence_size,
+ squeezebert=squeezebert,
+ n_iterations=n_iterations,
+ )
diff --git a/models/demos/squeezebert/demo/input_data.json b/models/demos/squeezebert/demo/input_data.json
new file mode 100644
index 00000000000..950b8d36323
--- /dev/null
+++ b/models/demos/squeezebert/demo/input_data.json
@@ -0,0 +1,50 @@
+[
+ {
+ "context" : "Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. The prophet and founding hero of modern archaeology, Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art.",
+ "question" : "What discipline did Winkelmann create?"
+ },
+ {
+ "context" : "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.",
+ "question" : "Who ruled the duchy of Normandy"
+ },
+ {
+ "context" : "In many countries, there is a Gender pay gap in favor of males in the labor market. Several factors other than discrimination may contribute to this gap. On average, women are more likely than men to consider factors other than pay when looking for work, and may be less willing to travel or relocate. Thomas Sowell, in his book Knowledge and Decisions, claims that this difference is due to women not taking jobs due to marriage or pregnancy, but income studies show that that does not explain the entire difference. A U.S. Census's report stated that in US once other factors are accounted for there is still a difference in earnings between women and men. The income gap in other countries ranges from 53% in Botswana to -40% in Bahrain.",
+ "question" : "Who does a gender pay gap tend to favor?"
+ },
+ {
+ "context" : "Most of the Huguenot congregations (or individuals) in North America eventually affiliated with other Protestant denominations with more numerous members. The Huguenots adapted quickly and often married outside their immediate French communities, which led to their assimilation. Their descendants in many families continued to use French first names and surnames for their children well into the nineteenth century. Assimilated, the French made numerous contributions to United States economic life, especially as merchants and artisans in the late Colonial and early Federal periods. For example, E.I. du Pont, a former student of Lavoisier, established the Eleutherian gunpowder mills.",
+ "question" : "How were Huguenot settlers assimilated into North American society at large?"
+ },
+ {
+ "context" : "In the laboratory, biostratigraphers analyze rock samples from outcrop and drill cores for the fossils found in them. These fossils help scientists to date the core and to understand the depositional environment in which the rock units formed. Geochronologists precisely date rocks within the stratigraphic section in order to provide better absolute bounds on the timing and rates of deposition. Magnetic stratigraphers look for signs of magnetic reversals in igneous rock units within the drill cores. Other scientists perform stable isotope studies on the rocks to gain information about past climate.",
+ "question" : "Who analyzes rock samples from drill cores in the lab?"
+ },
+ {
+ "context" : "Neutrophils and macrophages are phagocytes that travel throughout the body in pursuit of invading pathogens. Neutrophils are normally found in the bloodstream and are the most abundant type of phagocyte, normally representing 50% to 60% of the total circulating leukocytes. During the acute phase of inflammation, particularly as a result of bacterial infection, neutrophils migrate toward the site of inflammation in a process called chemotaxis, and are usually the first cells to arrive at the scene of infection. Macrophages are versatile cells that reside within tissues and produce a wide array of chemicals including enzymes, complement proteins, and regulatory factors such as interleukin 1. Macrophages also act as scavengers, ridding the body of worn-out cells and other debris, and as antigen-presenting cells that activate the adaptive immune system.",
+ "question" : "What is the process in which neutrophils move towards the site of inflammation called?"
+ },
+ {
+ "context" : "In Afghanistan, the mujahideen's victory against the Soviet Union in the 1980s did not lead to justice and prosperity, due to a vicious and destructive civil war between political and tribal warlords, making Afghanistan one of the poorest countries on earth. In 1992, the Democratic Republic of Afghanistan ruled by communist forces collapsed, and democratic Islamist elements of mujahdeen founded the Islamic State of Afghanistan. In 1996, a more conservative and anti-democratic Islamist movement known as the Taliban rose to power, defeated most of the warlords and took over roughly 80% of Afghanistan.",
+ "question" : "When did the Democratic Republic of Afghanistan collapse?"
+ },
+ {
+ "context" : "The largest single sensory feature is the aboral organ (at the opposite end from the mouth). Its main component is a statocyst, a balance sensor consisting of a statolith, a solid particle supported on four bundles of cilia, called \"balancers\", that sense its orientation. The statocyst is protected by a transparent dome made of long, immobile cilia. A ctenophore does not automatically try to keep the statolith resting equally on all the balancers. Instead its response is determined by the animal's \"mood\", in other words the overall state of the nervous system. For example, if a ctenophore with trailing tentacles captures prey, it will often put some comb rows into reverse, spinning the mouth towards the prey.",
+ "question" : "What is the main component of the aboral organ?"
+ },
+ {
+ "context": "Mark Rothko was a Latvian-born American abstract painter. He is best known for his color field paintings that depicted irregular and painterly rectangular regions of color, which he produced from 1949 to 1970. Although Rothko did not personally subscribe to any one school, he is associated with the American Abstract Expressionist movement of modern art. Originally emigrating to Portland, Oregon, from Russian Empire (Latvia) with his family, Rothko later moved to New York City where his youthful period of artistic production dealt primarily with urban scenery.",
+ "question": "what is Rothko best known for?"
+ },
+ {
+ "context": "Malignant narcissism is a psychological syndrome that could include aspects of narcissistic personality disorder (NPD) alongside a mix of antisocial, paranoid and sadistic personality disorder traits. The importance of malignant narcissism and of projection as a defense mechanism has been confirmed in paranoia, as well as the patient's vulnerability to malignant narcissistic regression. A person with malignant narcissism exhibits paranoia in addition to the symptoms of a Narcissistic Personality Disorder. Because a malignant narcissist's personality cannot tolerate any criticism, being mocked typically causes paranoia.",
+ "question": "What symptoms a malignant narcissist might exhibit in addition to the symptoms of a NPD patient?"
+ },
+ {
+ "context": "The 14 July Revolution, also known as the 1958 Iraqi military coup, was a coup d'état that took place on 14 July 1958 in Iraq which resulted in the toppling of King Faisal II and the overthrow of the Hashemite-led Kingdom of Iraq. The Iraqi Republic established in its wake ended the Hashemite Arab Federation between Iraq and Jordan that had been established just six months earlier. In July 1958, units of the Royal Iraqi Army were dispatched to Jordan in support of King Hussein. A group of Iraqi Free Officers, led by Brigadier Abd al-Karim Qasim and Colonel Abdul Salam Arif, took advantage of the opportunity and instead marched on Baghdad. On 14 July, revolutionary forces seized control of the capital and proclaimed a new republic, headed by a Revolutionary Council.",
+ "question": "When was the Hashemite Arab Federation formed?"
+ },
+ {
+ "context": "The Tasmanian devil is a carnivorous marsupial of the family Dasyuridae. It was formerly present across mainland Australia, but became extinct there around 3,500 years ago. The size of a small dog, the Tasmanian devil became the largest carnivorous marsupial in the world following the extinction of the thylacine in 1936. It is related to quolls, and distantly related to the thylacine. It is characterised by its stocky and muscular build, black fur, pungent odour, extremely loud and disturbing screech, keen sense of smell, and ferocity when feeding. The Tasmanian devil's large head and neck allow it to generate among the strongest bites per unit body mass of any extant predatory land mammal. It hunts prey and scavenges on carrion.",
+ "question": "What allows Tasmanian devil to generate strong bites?"
+ }
+]
diff --git a/models/demos/squeezebert/tests/test_perf_device_squeezebert.py b/models/demos/squeezebert/tests/test_perf_device_squeezebert.py
new file mode 100644
index 00000000000..7f8acbca401
--- /dev/null
+++ b/models/demos/squeezebert/tests/test_perf_device_squeezebert.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from models.utility_functions import is_grayskull
+from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
+
+
+@pytest.mark.models_device_performance_bare_metal
+@pytest.mark.parametrize(
+ "batch_size, test",
+ [
+ [8, "sequence_size=384-batch_size=8-model_name=squeezebert/squeezebert-uncased"],
+ ],
+)
+def test_perf_device_bare_metal(batch_size, test):
+ subdir = "ttnn_squeezebert"
+ num_iterations = 1
+ margin = 0.03
+ expected_perf = 114.8 if is_grayskull() else 284.5
+
+    command = "pytest tests/ttnn/integration_tests/squeezebert/test_ttnn_squeezebert.py::test_squeezebert_for_question_answering"
+ cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
+
+ inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
+ expected_perf_cols = {inference_time_key: expected_perf}
+
+ post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
+ expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True)
+ prep_device_perf_report(
+ model_name=f"ttnn_squeezebert_{batch_size}",
+ batch_size=batch_size,
+ post_processed_results=post_processed_results,
+ expected_results=expected_results,
+ comments=test.replace("/", "_"),
+ )
diff --git a/models/demos/squeezebert/tests/test_performance.py b/models/demos/squeezebert/tests/test_performance.py
new file mode 100644
index 00000000000..e08fcd6aa2d
--- /dev/null
+++ b/models/demos/squeezebert/tests/test_performance.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import time
+import torch
+import pytest
+import transformers
+from loguru import logger
+
+from models.utility_functions import is_grayskull
+from models.perf.perf_utils import prep_perf_report
+
+from ttnn.model_preprocessing import preprocess_model_parameters
+from models.demos.squeezebert.tt import ttnn_functional_squeezebert
+from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask
+
+from models.utility_functions import (
+ enable_persistent_kernel_cache,
+ disable_persistent_kernel_cache,
+)
+
+
+def preprocess_inputs(
+ input_ids,
+ token_type_ids,
+ position_ids,
+ attention_mask,
+):
+ batch_size, *_ = input_ids.shape
+
+ input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32)
+ token_type_ids = ttnn.from_torch(token_type_ids, dtype=ttnn.uint32)
+ position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32)
+
+ if attention_mask is not None:
+ attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape)
+ attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
+ attention_mask = torch.clamp(attention_mask, min=-100000)
+ attention_mask = ttnn.from_torch(
+ attention_mask,
+ dtype=ttnn.bfloat16,
+ layout=ttnn.TILE_LAYOUT,
+ )
+ return input_ids, token_type_ids, position_ids, attention_mask
+
+
+def get_expected_times(squeezebert):
+ return {ttnn_functional_squeezebert: (13.5, 11.5) if is_grayskull() else (16.5, 8.5)}[squeezebert]
+
+
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.models_performance_bare_metal
+@pytest.mark.models_performance_virtual_machine
+@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"])
+@pytest.mark.parametrize("sequence_size", [384])
+@pytest.mark.parametrize("squeezebert", [ttnn_functional_squeezebert])
+def test_performance(device, use_program_cache, model_name, sequence_size, squeezebert):
+ disable_persistent_kernel_cache()
+
+ num_iterations = 2
+ batch_size = 8
+
+ config = transformers.SqueezeBertConfig.from_pretrained(model_name)
+ rf_model = transformers.SqueezeBertForQuestionAnswering.from_pretrained(model_name)
+ state_dict = rf_model.state_dict()
+
+ input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_size)).to(torch.int32)
+ torch_token_type_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+ position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32)
+ torch_attention_mask = torch.ones(1, sequence_size)
+
+ if squeezebert == ttnn_functional_squeezebert:
+ tt_model_name = f"ttnn_{model_name}_optimized"
+ else:
+ raise ValueError(f"Unknown squeezebert: {squeezebert}")
+
+ parameters = preprocess_model_parameters(
+ model_name=tt_model_name,
+ initialize_model=lambda: rf_model,
+ custom_preprocessor=squeezebert.custom_preprocessor,
+ device=device,
+ )
+
+ ttnn_squeezebert_inputs_on_cpu = preprocess_inputs(
+ input_ids,
+ torch_token_type_ids,
+ position_ids,
+ torch_attention_mask,
+ )
+
+ start = time.time()
+ ttnn_squeezebert_inputs = [
+ ttnn.to_device(tensor, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) if tensor is not None else tensor
+ for tensor in ttnn_squeezebert_inputs_on_cpu
+ ]
+ tt_output = squeezebert.squeezebert_for_question_answering(
+ config,
+ *ttnn_squeezebert_inputs,
+ state_dict=state_dict,
+ base_addr=f"transformer.",
+ parameters=parameters,
+ device=device,
+ reader_patterns_cache={},
+ )
+
+ tt_output = ttnn.from_device(tt_output, blocking=False)
+ ttnn.synchronize_device(device)
+ end = time.time()
+ inference_and_compile_time = end - start
+ enable_persistent_kernel_cache()
+
+ start = time.time()
+ for _ in range(num_iterations):
+ ttnn_squeezebert_inputs = [
+ ttnn.to_device(tensor, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) if tensor is not None else tensor
+ for tensor in ttnn_squeezebert_inputs_on_cpu
+ ]
+ tt_output = squeezebert.squeezebert_for_question_answering(
+ config,
+ *ttnn_squeezebert_inputs,
+ state_dict=state_dict,
+ base_addr=f"transformer.",
+ parameters=parameters,
+ device=device,
+ reader_patterns_cache={},
+ )
+ tt_output = ttnn.from_device(tt_output, blocking=False)
+ ttnn.synchronize_device(device)
+ end = time.time()
+ average_inference_time = (end - start) / num_iterations
+
+ expected_compile_time, expected_inference_time = get_expected_times(squeezebert)
+ prep_perf_report(
+ model_name=tt_model_name,
+ batch_size=batch_size,
+ inference_and_compile_time=inference_and_compile_time,
+ inference_time=average_inference_time,
+ expected_compile_time=expected_compile_time,
+ expected_inference_time=expected_inference_time,
+ comments="",
+ inference_time_cpu=0.0,
+ )
+
+ logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
+ logger.info(f"Inference time: {average_inference_time}")
+ logger.info(f"Samples per second: {1 / average_inference_time * batch_size}")
+
+ assert (
+ average_inference_time < expected_inference_time
+ ), f"Expected inference time: {expected_inference_time} Actual inference time: {average_inference_time}"
diff --git a/models/demos/squeezebert/tt/ttnn_functional_squeezebert.py b/models/demos/squeezebert/tt/ttnn_functional_squeezebert.py
new file mode 100644
index 00000000000..b0fa2c60431
--- /dev/null
+++ b/models/demos/squeezebert/tt/ttnn_functional_squeezebert.py
@@ -0,0 +1,515 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import torch
+from torch import nn
+from models.utility_functions import is_grayskull
+from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores
+from models.experimental.functional_common.attention_mask_functions import get_extended_attention_mask
+
+
+def transpose_for_scores(config, x, device, permute_tensor: bool):
+ new_x_shape = (x.shape[0], config.num_attention_heads, config.attention_head_size, x.shape[-1])
+ x = ttnn.from_device(x)
+ x = ttnn.reshape(x, new_x_shape)
+ x = ttnn.to_device(x, device)
+
+ if permute_tensor:
+ x = ttnn.permute(x, (0, 1, 3, 2))
+
+ return x
+
+
+def transpose_output(config, x, device):
+ all_head_size = config.num_attention_heads * config.attention_head_size
+ if len(x.shape) == 4:
+ x = ttnn.permute(x, (0, 1, 3, 2))
+
+ new_x_shape = (x.shape[0], all_head_size, x.shape[3])
+ x = ttnn.reshape(x, new_x_shape)
+
+ return x
+
+
+def permute_reshape(hidden_states, shape=(0, 2, 1), reshape=True):
+ bs, *_ = hidden_states.shape
+ hidden_states = ttnn.permute(hidden_states, shape)
+ if reshape:
+ hidden_states = ttnn.reshape(hidden_states, (bs, hidden_states.shape[-2], hidden_states.shape[-1]))
+
+ return hidden_states
+
+
+def ttnn_conv1d(
+ device,
+ tt_input_tensor,
+ weights,
+ conv_params,
+ bias,
+ *,
+ output_dtype=ttnn.bfloat16,
+ weights_dtype=ttnn.bfloat8_b,
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ deallocate_activation=False,
+ act_block_h=None,
+ height_sharding=True,
+ use_shallow_conv_variant=False,
+ fp32_accum=False,
+ packer_l1_acc=False,
+ debug=False,
+ groups=4,
+ math_approx=True,
+ activation="",
+ reallocate_halo=False,
+ reshard=False,
+):
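+ # Pointwise (kernel_size=1) grouped Conv1d wrapper: weights and bias are converted
+ # from torch on every call, and the result is returned to the host with shape
+ # (batch, length, out_channels).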
+ weights = ttnn.from_torch(weights, dtype=ttnn.float32)
+ bias = ttnn.from_torch(bias.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype=ttnn.float32)
+
+ conv_config = ttnn.Conv1dConfig(
+ dtype=ttnn.bfloat16,
+ weights_dtype=ttnn.bfloat8_b,
+ activation=activation,
+ input_channels_alignment=(16 if use_shallow_conv_variant else 32),
+ deallocate_activation=deallocate_activation,
+ reallocate_halo_output=reallocate_halo,
+ act_block_h_override=32,
+ reshard_if_not_optimal=reshard,
+ shard_layout=(
+ ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
+ ),
+ core_grid=get_shard_grid_from_num_cores(56, device),
+ )
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=math_fidelity,
+ math_approx_mode=math_approx,
+ fp32_dest_acc_en=fp32_accum,
+ packer_l1_acc=packer_l1_acc,
+ )
+
+ [tt_output_tensor_on_device, out_length, [weights_device, bias_device]] = ttnn.Conv1d(
+ input_tensor=tt_input_tensor,
+ weight_tensor=weights,
+ in_channels=tt_input_tensor.shape[-1],
+ out_channels=weights.shape[0],
+ device=device,
+ bias_tensor=bias,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ batch_size=tt_input_tensor.shape[0],
+ input_length=tt_input_tensor.shape[1],
+ conv_config=conv_config,
+ compute_config=compute_config,
+ conv_op_cache={},
+ debug=debug,
+ groups=groups,
+ return_output_dim=True,
+ return_weights_and_bias=True,
+ )
+
+ tt_output_tensor_on_device = ttnn.squeeze(tt_output_tensor_on_device, 0)
+ tt_output_tensor_on_device = ttnn.reshape(
+ tt_output_tensor_on_device, (tt_input_tensor.shape[0], out_length, tt_output_tensor_on_device.shape[-1])
+ )
+
+ tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
+
+ return tt_output_tensor
+
+
+def squeezebert_conv_layernorm(
+ config,
+ hidden_states,
+ input_tensor,
+ *,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ cin,
+ cout,
+ groups,
+):
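+ # The grouped conv itself runs in torch on the host; the residual add and the
+ # layer norm run on device.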
+ torch_hidden_states = ttnn.to_torch(hidden_states).to(torch.float32)
+ self_output_conv1d_ = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
+ self_output_conv1d_.weight = nn.Parameter(state_dict[f"{base_addr}conv1d.weight"])
+ self_output_conv1d_.bias = nn.Parameter(state_dict[f"{base_addr}conv1d.bias"])
+
+ torch_self_output = self_output_conv1d_(torch_hidden_states)
+ self_output = ttnn.from_torch(torch_self_output, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+
+ self_output_layernorm = ttnn.add(self_output, input_tensor)
+ self_output_layernorm = permute_reshape(self_output_layernorm)
+
+ attention_output = ttnn.layer_norm(
+ self_output_layernorm,
+ weight=parameters.layernorm.weight,
+ bias=parameters.layernorm.bias,
+ epsilon=config.layer_norm_eps,
+ )
+ ttnn.deallocate(self_output_layernorm)
+ attention_output = permute_reshape(attention_output)
+
+ return attention_output
+
+
+def squeezebert_attention(
+ config,
+ hidden_states,
+ attention_mask,
+ *,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ reader_patterns_cache,
+ num_cores_x=12,
+):
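+ # hidden_states arrives as (batch, hidden, seq). Q/K/V are grouped pointwise convs;
+ # attention_softmax_ applies the 1/sqrt(head_size) scaling and the mask in place.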
+ num_heads = config.num_attention_heads
+ batch_size, hidden_size, _ = hidden_states.shape
+ head_size = hidden_size // num_heads
+ config.attention_head_size = head_size
+
+ hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT)
+ hidden_states = permute_reshape(hidden_states)
+ hidden_states = ttnn.from_device(hidden_states)
+ mixed_query_layer = ttnn_conv1d(
+ device,
+ hidden_states,
+ nn.Parameter(state_dict[f"{base_addr}query.weight"]),
+ conv_params=[1, 0],
+ bias=nn.Parameter(state_dict[f"{base_addr}query.bias"]),
+ )
+ mixed_query_layer = ttnn.to_device(mixed_query_layer, device)
+ mixed_query_layer = ttnn.permute(mixed_query_layer, (0, 2, 1))
+
+ mixed_key_layer = ttnn_conv1d(
+ device,
+ hidden_states,
+ nn.Parameter(state_dict[f"{base_addr}key.weight"]),
+ conv_params=[1, 0],
+ bias=nn.Parameter(state_dict[f"{base_addr}key.bias"]),
+ )
+ mixed_key_layer = ttnn.to_device(mixed_key_layer, device)
+ mixed_key_layer = ttnn.permute(mixed_key_layer, (0, 2, 1))
+
+ mixed_value_layer = ttnn_conv1d(
+ device,
+ hidden_states,
+ nn.Parameter(state_dict[f"{base_addr}value.weight"]),
+ conv_params=[1, 0],
+ bias=nn.Parameter(state_dict[f"{base_addr}value.bias"]),
+ )
+ mixed_value_layer = ttnn.to_device(mixed_value_layer, device)
+ mixed_value_layer = ttnn.permute(mixed_value_layer, (0, 2, 1))
+
+ query = transpose_for_scores(config, mixed_query_layer, device, True)
+ key = transpose_for_scores(config, mixed_key_layer, device, False)
+ value = transpose_for_scores(config, mixed_value_layer, device, True)
+
+ ttnn.deallocate(mixed_query_layer)
+ ttnn.deallocate(mixed_key_layer)
+ ttnn.deallocate(mixed_value_layer)
+
+ attention_scores = ttnn.matmul(
+ query,
+ key,
+ memory_config=ttnn.L1_MEMORY_CONFIG,
+ dtype=ttnn.bfloat16,
+ core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
+ )
+ ttnn.deallocate(query)
+ ttnn.deallocate(key)
+
+ attention_probs = ttnn.transformer.attention_softmax_(
+ attention_scores, attention_mask=attention_mask, head_size=head_size
+ )
+
+ context_layer = ttnn.matmul(
+ attention_probs,
+ value,
+ memory_config=ttnn.L1_MEMORY_CONFIG,
+ dtype=ttnn.bfloat8_b,
+ core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
+ )
+ context_layer = transpose_output(config, context_layer, device)
+
+ return context_layer
+
+
+def squeezebert_intermediate(
+ config,
+ hidden_states,
+ *,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ num_cores_x=12,
+):
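+ # Grouped conv on host, GELU on device.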
+ torch_hidden_states = ttnn.to_torch(hidden_states).to(torch.float32)
+
+ torch_conv_ = nn.Conv1d(
+ in_channels=config.hidden_size,
+ out_channels=config.intermediate_size,
+ kernel_size=1,
+ groups=config.intermediate_groups,
+ )
+ torch_conv_.weight = nn.Parameter(state_dict[f"{base_addr}conv1d.weight"])
+ torch_conv_.bias = nn.Parameter(state_dict[f"{base_addr}conv1d.bias"])
+
+ torch_conv_output = torch_conv_(torch_hidden_states)
+ ttnn_conv_output = ttnn.from_torch(torch_conv_output, device=device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+
+ output = ttnn.gelu(ttnn_conv_output)
+ return output
+
+
+def squeezebert_layer(
+ config,
+ hidden_states,
+ attention_mask,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ reader_patterns_cache,
+):
+ multi_head_attention_output = squeezebert_attention(
+ config,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ state_dict=state_dict,
+ base_addr=f"{base_addr}attention.",
+ parameters=parameters.attention,
+ device=device,
+ reader_patterns_cache=reader_patterns_cache,
+ )
+
+ attention_output = squeezebert_conv_layernorm(
+ config,
+ hidden_states=multi_head_attention_output,
+ input_tensor=hidden_states,
+ state_dict=state_dict,
+ base_addr=f"{base_addr}post_attention.",
+ parameters=parameters.post_attention,
+ device=device,
+ cin=config.hidden_size,
+ cout=config.hidden_size,
+ groups=config.post_attention_groups,
+ )
+ ttnn.deallocate(hidden_states)
+ ttnn.deallocate(multi_head_attention_output)
+
+ intermediate = squeezebert_intermediate(
+ config,
+ attention_output,
+ state_dict=state_dict,
+ base_addr=f"{base_addr}intermediate.",
+ parameters=parameters.intermediate,
+ device=device,
+ )
+
+ output = squeezebert_conv_layernorm(
+ config,
+ hidden_states=intermediate,
+ input_tensor=attention_output,
+ state_dict=state_dict,
+ base_addr=f"{base_addr}output.",
+ parameters=parameters.output,
+ device=device,
+ cin=config.intermediate_size,
+ cout=config.hidden_size,
+ groups=config.output_groups,
+ )
+
+ return output
+
+
+def squeezebert_encoder(
+ config,
+ hidden_states,
+ attention_mask,
+ *,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ reader_patterns_cache,
+):
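+ # The encoder layers work on (batch, hidden, seq), hence the permutes at entry and
+ # exit; ttnn.reallocate compacts device memory between layers.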
+ hidden_states = permute_reshape(hidden_states)
+ encoder_output = None
+
+ for layer_idx, encoder_parameters in enumerate(parameters.layers):
+ encoder_output = squeezebert_layer(
+ config,
+ hidden_states,
+ attention_mask,
+ state_dict,
+ base_addr=f"{base_addr}layers.{layer_idx}.",
+ parameters=encoder_parameters,
+ device=device,
+ reader_patterns_cache=reader_patterns_cache,
+ )
+ encoder_output = ttnn.reallocate(encoder_output)
+ hidden_states = encoder_output
+
+ hidden_states = permute_reshape(hidden_states)
+
+ return hidden_states
+
+
+def squeezebert(
+ config,
+ input_ids,
+ token_type_ids,
+ position_ids,
+ attention_mask,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ reader_patterns_cache,
+):
+ word_embeddings = ttnn.embedding(
+ input_ids,
+ parameters.embeddings.word_embeddings.weight,
+ layout=ttnn.TILE_LAYOUT,
+ padding_idx=config.pad_token_id,
+ )
+ ttnn.deallocate(input_ids)
+
+ token_type_embeddings = ttnn.embedding(
+ token_type_ids,
+ parameters.embeddings.token_type_embeddings.weight,
+ layout=ttnn.TILE_LAYOUT,
+ )
+ ttnn.deallocate(token_type_ids)
+
+ word_plus_token_type_embeddings = word_embeddings + token_type_embeddings
+ ttnn.deallocate(word_embeddings)
+ ttnn.deallocate(token_type_embeddings)
+
+ position_embeddings = ttnn.embedding(
+ position_ids,
+ parameters.embeddings.position_embeddings.weight,
+ layout=ttnn.TILE_LAYOUT,
+ )
+ ttnn.deallocate(position_ids)
+
+ embeddings = word_plus_token_type_embeddings + position_embeddings
+ ttnn.deallocate(word_plus_token_type_embeddings)
+ ttnn.deallocate(position_embeddings)
+
+ encoder_input = ttnn.layer_norm(
+ embeddings,
+ weight=parameters.embeddings.LayerNorm.weight,
+ bias=parameters.embeddings.LayerNorm.bias,
+ memory_config=ttnn.DRAM_MEMORY_CONFIG if is_grayskull() else ttnn.L1_MEMORY_CONFIG,
+ )
+ ttnn.deallocate(embeddings)
+
+ encoder_output = squeezebert_encoder(
+ config=config,
+ hidden_states=encoder_input,
+ attention_mask=attention_mask,
+ state_dict=state_dict,
+ base_addr=f"{base_addr}encoder.",
+ parameters=parameters.encoder,
+ device=device,
+ reader_patterns_cache=reader_patterns_cache,
+ )
+ ttnn.deallocate(encoder_input)
+
+ return encoder_output
+
+
+def squeezebert_for_question_answering(
+ config,
+ input_ids,
+ token_type_ids,
+ position_ids,
+ attention_mask,
+ *,
+ state_dict,
+ base_addr,
+ parameters,
+ device,
+ reader_patterns_cache,
+ name="transformer",
+):
+ squeezebert_output = squeezebert(
+ config,
+ input_ids,
+ token_type_ids,
+ position_ids,
+ attention_mask,
+ state_dict,
+ base_addr,
+ parameters=parameters.transformer,
+ device=device,
+ reader_patterns_cache=reader_patterns_cache,
+ )
+ qa_outputs = ttnn.linear(
+ squeezebert_output,
+ parameters.qa_outputs.weight,
+ bias=parameters.qa_outputs.bias,
+ memory_config=ttnn.L1_MEMORY_CONFIG,
+ )
+
+ return qa_outputs
+
+
+def preprocess_inputs(
+ input_ids,
+ token_type_ids,
+ position_ids,
+ attention_mask,
+ device,
+):
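+ # Token ids stay row-major uint32 for ttnn.embedding; the attention mask is extended
+ # to 4D, clamped to -1e5, and tiled as bfloat16 for the fused attention softmax.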
+ batch_size, _ = input_ids.shape
+
+ input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
+ token_type_ids = ttnn.from_torch(
+ token_type_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
+ )
+ position_ids = ttnn.from_torch(position_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
+
+ if attention_mask is not None:
+ attention_mask = get_extended_attention_mask(attention_mask, input_ids.shape, torch.float32)
+ attention_mask = attention_mask.expand((batch_size, -1, -1, -1))
+ attention_mask = torch.clamp(attention_mask, min=-100000)
+ attention_mask = ttnn.from_torch(
+ attention_mask,
+ dtype=ttnn.bfloat16,
+ layout=ttnn.TILE_LAYOUT,
+ device=device,
+ memory_config=ttnn.L1_MEMORY_CONFIG,
+ )
+
+ return input_ids, token_type_ids, position_ids, attention_mask
+
+
+def preprocess_conv_parameter(parameter, *, dtype):
+ parameter = ttnn.from_torch(parameter, dtype=dtype)
+ return parameter
+
+
+def custom_preprocessor(model, name):
+ parameters = {}
+ if isinstance(model, nn.Conv1d):
+ weight = model.weight
+ bias = model.bias
+
+ while bias.dim() < 4:
+ bias = bias.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+ parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.float32)
+ parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.float32)
+
+ return parameters
diff --git a/models/demos/t3000/llama3_70b/README.md b/models/demos/t3000/llama3_70b/README.md
index 6555cd36dbf..80f344040d4 100644
--- a/models/demos/t3000/llama3_70b/README.md
+++ b/models/demos/t3000/llama3_70b/README.md
@@ -1,32 +1,74 @@
-# Llama3-70B Demo
+# Llama3/3.1-70B Demo
+
+## Table of Contents
+
+- [One command run](#one-command-run)
+- [How to Run](#how-to-run)
+ - [Running the demo from TT-Metalium](#running-the-demo-from-tt-metalium)
+ - [Serving the model from vLLM](#serving-the-model-from-vllm)
+
+## One command run
+
+```bash
+chmod +x ./models/demos/t3000/llama3_70b/setup_llama.sh && ./models/demos/t3000/llama3_70b/setup_llama.sh <model_type> <TT_METAL_COMMIT_SHA_OR_TAG> <TT_VLLM_COMMIT_SHA_OR_TAG>
+```
+
+Here, `TT_METAL_COMMIT_SHA_OR_TAG` and `TT_VLLM_COMMIT_SHA_OR_TAG` are the tt-metal and vLLM "Release" versions listed in the root [README](/README.md#llms).
+
+Example:
+
+```bash
+./models/demos/t3000/llama3_70b/setup_llama.sh llama-3.1-70b-instruct v0.53.0-rc36 384f1790c3be16e1d1b10de07252be2e66d00935
+```
+
+Follow the prompts in the CLI to select the appropriate weights for Llama 3.1 70B Instruct.
+
+Prerequisites:
+
+- Submit request to access weights from Meta: [Llama Downloads](https://www.llama.com/llama-downloads)
+- Submit permissions on HuggingFace and have a HF personal access token: [Llama 3.1 70B Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct)
+
+Steps run:
+
+- Setup environment
+- Build `tt-metal`
+- Download Llama 3.1 70B Instruct weights
+- Install vLLM
+- Deploy vLLM server
## How to Run
-1. **Download the Llama3-70B weights from Meta (https://llama.meta.com/):**
+Note: This guide requires `tt-metal` to be installed and built. Follow the [installation instructions](/INSTALLING.md) for the release listed in the root [README](/README.md#llms).
+
+1. **Download the Llama3/3.1-70B weights from Meta (https://llama.meta.com/):**
2. **Repack the weights:**
+
```bash
# This concatenates the sharded checkpoints and makes it easier for us to load.
python models/demos/t3000/llama2_70b/scripts/repack_weights.py
```
+
Note: Use `5` for `chunk_size`.
Once the weights are repacked, move the `params.json` file from the `checkpoint_dir` to the `repacked_output_dir`.
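+
+For example, a minimal invocation, with `<checkpoint_dir>` and `<repacked_output_dir>` as placeholder paths for the downloaded shards and the repacked output:
+
+```bash
+python models/demos/t3000/llama2_70b/scripts/repack_weights.py <checkpoint_dir> <repacked_output_dir> 5
+mv <checkpoint_dir>/params.json <repacked_output_dir>/params.json
+```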
-### Running the Demo
+### Running the demo from TT-Metalium
After setting up the repacked weights and tokenizer, you can run the demo using the commands below:
1. **Prepare the weight cache directory:**
+
```bash
# Make a directory for us to cache weights into. This speeds up subsequent runs.
mkdir <weights_cache_directory>
```
2. **Set up environment variables:**
+
```bash
export LLAMA3_CKPT_DIR=<repacked_output_dir>
- export LLAMA3_TOKENIZER_PATH=<repacked_output_dir> # Path needs to include the tokenizer.model file
+ export LLAMA3_TOKENIZER_PATH=<repacked_output_dir>/tokenizer.model # Path needs to include the tokenizer.model file
export LLAMA3_CACHE_PATH=<weights_cache_directory>
export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml
@@ -38,13 +80,11 @@ After setting up the repacked weights and tokenizer, you can run the demo using
# export LLAMA3_CKPT_DIR="/home/llama-data-repacked/llama-3-70b/"
# export LLAMA3_TOKENIZER_PATH="/home/llama-data-repacked/tokenizer.model"
# export LLAMA3_CACHE_PATH="/home/llama-data-cache/weights-cache"
-
-
```
3. **Run the demo:**
- NOTE: Run the following comand twice.
+ Note: Run the following command twice.
1. The first run will cache the weights. This will take some time.
2. The second run will use the cached weights, thereby running much faster.
@@ -58,31 +98,74 @@ After setting up the repacked weights and tokenizer, you can run the demo using
The above demo does not achieve peak performance because we log outputs to the screen. The following perf test will print an accurate end-to-end throughput number.
For best performance, ensure that tt-metal is built in release mode (default), and ensure the host's CPU frequency governors are set to `performance` -- instructions for setting the frequency governor vary by machine.
This performance test runs with sequence length 128 and batch size 32.
+
```bash
pytest -svv models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py::test_Llama_perf_host[wormhole_b0-True-device_params0-gen128-llama3]
```
-## Details
+#### Details
Supported context lengths and batch sizes for the Llama3.1-70B demo are as follows:
| Context Length | Max Batch Size |
-|----------------|------------|
-| 2k | 32 |
-| 8k | 16 |
-| 128k | 1 |
+|----------------|----------------|
+| 2k | 32 |
+| 8k | 16 |
+| 128k | 1 |
- **Input File:** Uses `./demo/data/multi_prompt.json`.
- **Model Configuration:** Utilizes a pretrained model.
- **Hardware Requirements:** Runs on an 8-chip T3000 machine using tensor parallelism. The host machine must have at least 512 GB of memory.
- **Demo arguments:**
- - `context: [short_context, long_context, 128k_context]`: Select between short context (batch 32, sequence_length 2k) and long context (batch 16, sequence length 8k) and full context (batch 1, sequence length 128k)
- - `ground_truth: [check_disabled, check_enabled]`: Enable or disable ground truth checking, used for testing
- - `sampling: [greedy, sampling]`: Select between greedy decoding and top-k/top-p sampling
- - `implementation: [tt-70b-T3000]`: Run the 70B model on the Tenstorrent backend
- - `num_layers: [1L, 2L, 10L, 80L]`: Select 80L to run the full model
- - `decode_only: [decode_only, prefill_decode]`: Use `prefill_decode`. Alternately, `decode_only` implements prefill via decode.
- - `chat: [text_completion, chat_completion]`: Run in text_completion mode for the pretrained model or chat_completion for the finetuned model
- - `llama_version: [llama3, llama2]`: Select the Llama3 model
+ - `context: [short_context, long_context, 128k_context]`: Select between short context (batch 32, sequence_length 2k) and long context (batch 16, sequence length 8k) and full context (batch 1, sequence length 128k)
+ - `ground_truth: [check_disabled, check_enabled]`: Enable or disable ground truth checking, used for testing
+ - `sampling: [greedy, sampling]`: Select between greedy decoding and top-k/top-p sampling
+ - `implementation: [tt-70b-T3000]`: Run the 70B model on the Tenstorrent backend
+ - `num_layers: [1L, 2L, 10L, 80L]`: Select 80L to run the full model
+ - `decode_only: [decode_only, prefill_decode]`: Use `prefill_decode`. Alternately, `decode_only` implements prefill via decode.
+ - `chat: [text_completion, chat_completion]`: Run in text_completion mode for the pretrained model or chat_completion for the finetuned model
+ - `llama_version: [llama3, llama2]`: Select the Llama3 model
Ensure you follow these guidelines to successfully run the Llama3-70B demo.
+
+### Serving the model from vLLM
+
+1. Complete Step 1 and Step 2 of [Running the Demo from TT-Metalium](#running-the-demo-from-tt-metalium)
+
+2. **Install vLLM**
+
+ ```bash
+ # Installing from within `tt-metal`
+ export VLLM_TARGET_DEVICE="tt"
+ git clone https://github.com/tenstorrent/vllm.git
+ cd vllm
+ git checkout TT_VLLM_COMMIT_SHA_OR_TAG
+ pip install -e .
+ cd ..
+ ```
+
> **Note:** `TT_VLLM_COMMIT_SHA_OR_TAG` is the vLLM "Release" version from the root [README](/README.md#llms).
+
+3. **Running the server**
+
+ ```bash
+ python vllm/examples/server_example_tt.py
+ ```
+
+4. **Interact with server**
+
+ In a separate terminal window, run:
+
+ ```bash
+ curl http://localhost:8000/v1/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "meta-llama/Meta-Llama-3.1-70B",
+ "prompt": "Write a poem about RISC-V",
+ "max_tokens": 128,
+ "temperature": 1,
+ "top_p": 0.9,
+ "top_k": 10,
+ "stream": false
+ }'
+ ```
diff --git a/models/demos/t3000/llama3_70b/setup_llama.sh b/models/demos/t3000/llama3_70b/setup_llama.sh
new file mode 100644
index 00000000000..636ce070b2b
--- /dev/null
+++ b/models/demos/t3000/llama3_70b/setup_llama.sh
@@ -0,0 +1,250 @@
+#!/bin/bash
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+#
+# Purpose: Setup and deploy Llama 3.1 70B Instruct model with dependencies.
+
+set -euo pipefail
+
+# Function to display usage information
+usage() {
+ cat <<EOF
+Usage: $0 <model_type> <tt_metal_commit_sha_or_tag> <vllm_commit_sha_or_tag>
+
+Description:
+ This script sets up and deploys the Llama model along with its dependencies.
+
+Arguments:
+ <model_type> The type of model to deploy. Supported options:
+ - llama-3.1-70b-instruct
+ - llama-3.1-70b
+ - llama-3.1-8b-instruct
+ - llama-3.1-8b
+ - llama-3-70b-instruct
+ - llama-3-70b
+ - llama-3-8b-instruct
+ - llama-3-8b
+ <tt_metal_commit_sha_or_tag> The commit SHA or tag to use for TT_METAL.
+ <vllm_commit_sha_or_tag> The commit SHA or tag to use for vLLM.
+
+Options:
+ -h, --help Display this help message.
+
+Examples:
+ # Deploy the llama-3.1-70b-instruct model
+ $0 llama-3.1-70b-instruct main dev
+
+ # Deploy with specific commit SHAs
+ $0 llama-3.1-70b-instruct v0.53.0-rc36 384f1790c3be16e1d1b10de07252be2e66d00935
+
+EOF
+ exit 0
+}
+
+# Show help if requested ($1 is guarded because "set -u" is active and the script may be called with no arguments)
+if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
+ usage
+fi
+
+# Require commit SHA or tag for TT_METAL and vLLM
+TT_METAL_COMMIT_SHA_OR_TAG=${2:-""}
+TT_VLLM_COMMIT_SHA_OR_TAG=${3:-""}
+
+# Ensure required arguments are passed
+if [[ -z "${TT_METAL_COMMIT_SHA_OR_TAG}" || -z "${TT_VLLM_COMMIT_SHA_OR_TAG}" ]]; then
+ echo "❌ Error: Both TT_METAL_COMMIT_SHA_OR_TAG and TT_VLLM_COMMIT_SHA_OR_TAG are required."
+ usage
+fi
+
+# Defined variables
+DEFAULT_PERSISTENT_VOLUME_ROOT=~/persistent_volume
+DEFAULT_LLAMA_REPO=~/llama-models
+
+# functions
+error_exit() {
+ echo "⛔ Error: $1" >&2
+ exit 1
+}
+
+print_step() {
+ echo -e "\n👉 $1...\n"
+}
+
+setup_model_environment() {
+ print_step "Setting up model environment for $1"
+ case "$1" in
+ "llama-3.1-70b-instruct")
+ MODEL="llama-3.1-70b-instruct"
+ META_MODEL_NAME="Meta-Llama-3.1-70B-Instruct"
+ META_DIR_FILTER="llama3_1"
+ REPACKED=1
+ ;;
+ "llama-3.1-70b")
+ MODEL="llama-3.1-70b"
+ META_MODEL_NAME="Meta-Llama-3.1-70B"
+ META_DIR_FILTER="llama3_1"
+ REPACKED=1
+ ;;
+ "llama-3.1-8b-instruct")
+ MODEL="llama-3.1-8b-instruct"
+ META_MODEL_NAME="Meta-Llama-3.1-8B-Instruct"
+ META_DIR_FILTER="llama3_1"
+ REPACKED=0
+ ;;
+ "llama-3.1-8b")
+ MODEL_NAME="llama-3.1-8b"
+ META_MODEL_NAME="Meta-Llama-3.1-8B"
+ META_DIR_FILTER="llama3_1"
+ REPACKED=0
+ ;;
+ "llama-3-70b-instruct")
+ MODEL="llama-3-70b-instruct"
+ META_MODEL_NAME="Meta-Llama-3-70B-Instruct"
+ META_DIR_FILTER="llama3"
+ REPACKED=1
+ ;;
+ "llama-3-70b")
+ MODEL="llama-3-70b"
+ META_MODEL_NAME="Meta-Llama-3-70B"
+ META_DIR_FILTER="llama3"
+ REPACKED=1
+ ;;
+ "llama-3-8b-instruct")
+ MODEL="llama-3-8b-instruct"
+ META_MODEL_NAME="Meta-Llama-3-8B-Instruct"
+ META_DIR_FILTER="llama3"
+ REPACKED=0
+ ;;
+ "llama-3-8b")
+ MODEL="llama-3-8b"
+ META_MODEL_NAME="Meta-Llama-3-8B"
+ META_DIR_FILTER="llama3"
+ REPACKED=0
+ ;;
+ *)
+ echo "⛔ Invalid model choice."
+ usage
+ exit 1
+ ;;
+ esac
+
+ if [ "${REPACKED}" -eq 1 ]; then
+ echo "REPACKED is enabled."
+ REPACKED_STR="repacked-"
+ else
+ echo "REPACKED is disabled."
+ REPACKED_STR=""
+ fi
+}
+
+setup_environment() {
+ print_step "Setting up environment"
+ export LLAMA3_CKPT_DIR="${DEFAULT_PERSISTENT_VOLUME_ROOT}/model_weights/${REPACKED_STR}${MODEL}"
+ export LLAMA3_TOKENIZER_PATH="${LLAMA3_CKPT_DIR}/tokenizer.model"
+ export LLAMA3_CACHE_PATH="${DEFAULT_PERSISTENT_VOLUME_ROOT}/tt_metal_cache/cache_${REPACKED_STR}${MODEL}"
+ export ARCH_NAME=wormhole_b0
+ export TT_METAL_HOME=$(pwd)
+ export PYTHONPATH=$(pwd)
+ echo "Environment variables set."
+}
+
+check_and_build_tt_metal() {
+ print_step "Checking and building tt-metal"
+ pushd "${TT_METAL_HOME}" >/dev/null
+ if [[ ! -d "python_env" ]]; then
+ git checkout "${TT_METAL_COMMIT_SHA_OR_TAG}"
+ git submodule update --init --recursive
+ git submodule foreach 'git lfs fetch --all && git lfs pull'
+ ./build_metal.sh
+ ./create_venv.sh
+ source python_env/bin/activate
+ pip install -r models/demos/t3000/llama2_70b/reference/llama/requirements.txt
+ else
+ echo "🔔 tt-metal Python environment already exists. Skipping build."
+ source python_env/bin/activate
+ fi
+ popd >/dev/null
+}
+
+clone_repo() {
+ local REPO_PATH=$1
+ local REPO_URL=$2
+ local COMMIT_HASH=$3
+
+ print_step "Cloning Llama repository"
+ if [[ ! -d "${REPO_PATH}" ]]; then
+ git clone "${REPO_URL}" "${REPO_PATH}"
+ pushd "${REPO_PATH}" >/dev/null
+ git checkout "${COMMIT_HASH}"
+ popd >/dev/null
+ else
+ echo "🔔 Repository already exists at ${REPO_PATH}, skipping clone."
+ fi
+}
+
+setup_weights() {
+ print_step "Setting up weights"
+ local LLAMA_REPO=$1
+ local LLAMA_DIR="${LLAMA_REPO}/models/${META_DIR_FILTER}"
+ local LLAMA_WEIGHTS_DIR="${LLAMA_DIR}/${META_MODEL_NAME}"
+ local WEIGHTS_DIR="${LLAMA3_CKPT_DIR}"
+
+ mkdir -p "${WEIGHTS_DIR}" "${LLAMA3_CACHE_PATH}"
+
+ if [[ -d "${LLAMA_WEIGHTS_DIR}" && -n "$(ls -A "${LLAMA_WEIGHTS_DIR}")" ]]; then
+ echo "Weights already downloaded in ${LLAMA_WEIGHTS_DIR}"
+ else
+ print_step "Downloading weights"
+ pushd "${LLAMA_DIR}" >/dev/null
+ [[ -x "./download.sh" ]] && ./download.sh || error_exit "Download script not found!"
+ popd >/dev/null
+ fi
+
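+ # Authenticate with Hugging Face (the Llama model repositories are gated).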
+ huggingface-cli login
+
+ if [ "${REPACKED}" -eq 1 ]; then
+ print_step "Repacking weights"
+ source python_env/bin/activate
+ cp "${LLAMA_WEIGHTS_DIR}/tokenizer.model" "${WEIGHTS_DIR}/tokenizer.model"
+ cp "${LLAMA_WEIGHTS_DIR}/params.json" "${WEIGHTS_DIR}/params.json"
+ python models/demos/t3000/llama2_70b/scripts/repack_weights.py "${LLAMA_WEIGHTS_DIR}" "${WEIGHTS_DIR}" 5
+ else
+ cp -rf "${LLAMA_WEIGHTS_DIR}" "${WEIGHTS_DIR}"
+ fi
+
+ echo "🔔 Using weights directory ${WEIGHTS_DIR}"
+}
+
+install_vllm() {
+ print_step "Installing vLLM"
+ if [[ ! -d "vllm" ]]; then
+ source python_env/bin/activate
+ export VLLM_TARGET_DEVICE="tt"
+ git clone https://github.com/tenstorrent/vllm.git
+ pushd vllm >/dev/null
+ git checkout "${TT_VLLM_COMMIT_SHA_OR_TAG}"
+ pip install -e .
+ popd >/dev/null
+ else
+ echo "🔔 vLLM already installed. Skipping install."
+ fi
+}
+
+deploy_server() {
+ print_step "Deploying Llama server"
+ source python_env/bin/activate
+ export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml
+ python vllm/examples/server_example_tt.py
+ echo "✅ Deployment complete! Interact via http://localhost:8000."
+}
+
+# ---- MAIN ----
+MODEL_TYPE=${1:-}
+setup_model_environment "$MODEL_TYPE"
+setup_environment
+check_and_build_tt_metal
+clone_repo "${DEFAULT_LLAMA_REPO}" "https://github.com/meta-llama/llama-models.git" "685ac4c107c75ce8c291248710bf990a876e1623"
+setup_weights "${DEFAULT_LLAMA_REPO}"
+install_vllm
+deploy_server
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py
index cfe555d0367..ce49bfbfa51 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py
@@ -167,7 +167,7 @@ def run_downsample_if_req(
shard_layout = (
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
)
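+ # ttnn.conv2d now gates its extra return values behind return_output_dim and
+ # return_weights_and_bias, and math_fidelity has moved from Conv2dConfig into the
+ # compute kernel config built by ttnn.init_device_compute_kernel_config.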
- ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
+ ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -183,13 +183,17 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=shard_layout,
deallocate_activation=True,
reallocate_halo_output=True,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -214,7 +218,7 @@ def __call__(
# conv1 is 1x1 conv
# print("Running conv1")
module_input_height = input_height
- out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -230,14 +234,18 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
act_block_h_override = 0
@@ -277,7 +285,7 @@ def __call__(
)
# if ds_out_mem_config and ds_out_mem_config != ttnn.get_memory_config(out):
# out = ttnn.to_memory_config(out, ds_out_mem_config)
- out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv2_weight_tensor,
in_channels=self.conv2_input_channels,
@@ -293,7 +301,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -303,12 +310,17 @@ def __call__(
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# conv3 is 1x1 conv
# print("Running conv3")
- out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
+ out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -324,13 +336,17 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_weights_and_bias=True,
+ return_output_dim=False,
)
if not self.run_downsample_before_conv2:
@@ -546,7 +562,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config
)
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -562,13 +578,17 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
@@ -857,7 +877,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config
)
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -873,13 +893,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
index 44d90cb0f34..a8944b654c3 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py
@@ -160,7 +160,7 @@ def run_downsample_if_req(
):
if self.downsample:
logger.debug(f"Running downsample")
- ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
+ ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -176,7 +176,6 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -184,7 +183,6 @@ def run_downsample_if_req(
reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16),
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
- packer_l1_accum_enabled=packer_l1_accum_enabled,
enable_act_double_buffer=enable_act_double_buffer
if height_sharding
else True
@@ -194,7 +192,14 @@ def run_downsample_if_req(
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=self.model_config["MATH_FIDELITY"],
+ packer_l1_acc=packer_l1_accum_enabled,
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -226,7 +231,7 @@ def __call__(
# conv1 is 1x1 conv
logger.debug(f"Running conv1")
module_input_height = input_height
- out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -242,16 +247,21 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
- packer_l1_accum_enabled=packer_l1_acc,
+ ),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=self.model_config["MATH_FIDELITY"],
+ packer_l1_acc=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
act_block_h_override = 0
@@ -307,7 +317,7 @@ def __call__(
reallocate_halo_output = batch_size == 20
logger.debug(f"Running conv2")
- out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv2_weight_tensor,
in_channels=self.conv2_input_channels,
@@ -323,7 +333,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -333,13 +342,19 @@ def __call__(
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
- packer_l1_accum_enabled=packer_l1_acc,
enable_act_double_buffer=enable_act_double_buffer,
enable_weights_double_buffer=True,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=self.model_config["MATH_FIDELITY"],
+ packer_l1_acc=packer_l1_acc,
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
logger.debug(
@@ -358,7 +373,7 @@ def __call__(
# conv3 is 1x1 conv
logger.debug(f"Running conv3")
- out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
+ out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -374,15 +389,20 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
- packer_l1_accum_enabled=packer_l1_acc,
+ ),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=self.model_config["MATH_FIDELITY"],
+ packer_l1_acc=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
if not run_downsample_before_conv2:
@@ -569,19 +589,22 @@ def __init__(
self.conv1_config = ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=dealloc_input,
input_channels_alignment=input_channels_alignment,
act_block_h_override=act_block_h_override,
transpose_shards=self.transpose_shards,
- packer_l1_accum_enabled=True if whb0_and_b16 else False,
enable_act_double_buffer=True if whb0_and_b16 else False,
enable_split_reader=True if whb0_and_b16 or not is_wormhole_b0() else False,
enable_subblock_padding=False,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
reshard_if_not_optimal=False,
)
+ self.conv1_compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=self.model_config["MATH_FIDELITY"],
+ packer_l1_acc=True if whb0_and_b16 else False,
+ )
if whb0_and_b16:
# Issue #13145: Temp workaround for Galaxy to avoid hangs
if type(device) == ttnn.MeshDevice and device.get_num_devices() > 8:
@@ -719,7 +742,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
logger.debug(f"==== first conv")
# first conv
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=fold_output_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -733,7 +756,10 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
input_height=self.conv1_input_height,
input_width=self.conv1_input_width,
conv_config=self.conv1_config,
+ compute_config=self.conv1_compute_config,
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
if self.batch_size == 20:
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py
index 5c0750003c1..a5427f1fc87 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py
@@ -162,7 +162,7 @@ def run_downsample_if_req(
height_sharding=None,
):
if self.downsample:
- ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
+ ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -178,7 +178,6 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -186,7 +185,12 @@ def run_downsample_if_req(
reallocate_halo_output=True,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -209,7 +213,7 @@ def __call__(
# conv1 is 1x1 conv
# print("Running conv1")
module_input_height = input_height
- out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -225,14 +229,18 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
act_block_h_override = 0
@@ -270,7 +278,7 @@ def __call__(
# self.conv1_input_channels == 256 and
# self.downsample
)
- out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv2_weight_tensor,
in_channels=self.conv2_input_channels,
@@ -286,7 +294,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -296,12 +303,17 @@ def __call__(
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# conv3 is 1x1 conv
# print("Running conv3")
- out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
+ out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -317,13 +329,17 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
if not self.run_downsample_before_conv2:
@@ -516,7 +532,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
elif batch_size == 20:
act_block_h_override = 640
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -532,13 +548,17 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
@@ -819,7 +839,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
else:
act_block_h_override = 0
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -835,13 +855,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
index f2e266e1d8b..6bc5013bbf6 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py
@@ -164,7 +164,7 @@ def run_downsample_if_req(
height_sharding=None,
):
if self.downsample:
- ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
+ ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -180,7 +180,6 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -188,7 +187,12 @@ def run_downsample_if_req(
reallocate_halo_output=True,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -211,7 +215,7 @@ def __call__(
# conv1 is 1x1 conv
# print("Running conv1")
module_input_height = input_height
- out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -227,14 +231,18 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
act_block_h_override = 0
@@ -273,7 +281,7 @@ def __call__(
logger.info(
f"Running conv2 with reallocate_halo_output={reallocate_halo_output}, input_height={input_height}, conv2_output_channels={self.conv2_output_channels}"
)
- out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv2_weight_tensor,
in_channels=self.conv2_input_channels,
@@ -289,7 +297,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -299,12 +306,17 @@ def __call__(
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# conv3 is 1x1 conv
# print("Running conv3")
- out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
+ out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -320,13 +332,17 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
if not self.run_downsample_before_conv2:
@@ -541,7 +557,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config
)
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -557,13 +573,17 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
@@ -872,7 +892,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
elif batch_size == 20:
act_block_h_override = 640
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -888,13 +908,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py
index 45d93ebf685..d59a6c75238 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py
@@ -163,7 +163,7 @@ def run_downsample_if_req(
height_sharding=None,
):
if self.downsample:
- ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
+ ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -179,7 +179,6 @@ def run_downsample_if_req(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
@@ -188,7 +187,12 @@ def run_downsample_if_req(
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=height_sharding,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -216,7 +220,7 @@ def __call__(
# conv1 is 1x1 conv
# print("Running conv1")
module_input_height = input_height
- out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -232,7 +236,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
@@ -240,7 +243,12 @@ def __call__(
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=height_sharding,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
if is_wormhole_b0():
@@ -321,7 +329,7 @@ def __call__(
# self.conv1_input_channels == 256 and
# self.downsample
)
- out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d(
+ out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv2_weight_tensor,
in_channels=self.conv2_input_channels,
@@ -337,7 +345,6 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
@@ -348,12 +355,17 @@ def __call__(
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=height_sharding,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# conv3 is 1x1 conv
# print("Running conv3")
- out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
+ out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -369,14 +381,18 @@ def __call__(
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=height_sharding,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
if not self.run_downsample_before_conv2:
@@ -581,7 +597,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
else:
act_block_h_override = 0
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -597,14 +613,18 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
reallocate_halo_output=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
@@ -915,7 +935,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
elif batch_size == 1:
act_block_h_override = 256
- x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
+ x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
in_channels=self.conv1_input_channels,
@@ -931,13 +951,17 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
conv_config=ttnn.Conv2dConfig(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
- math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h_override=act_block_h_override,
),
+ compute_config=ttnn.init_device_compute_kernel_config(
+ device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"]
+ ),
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
# Relu is fused with conv1
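
The other half of the pattern: `math_fidelity` (and, in later files, the other math options) is deleted from `ttnn.Conv2dConfig` and supplied instead through `ttnn.init_device_compute_kernel_config`. A minimal sketch of the migration, using only calls that appear verbatim in this diff; it assumes a tt-metal install, and `device` and `model_config` are the same objects the surrounding model code already has:

```python
import ttnn

# Conv2dConfig keeps only tensor/layout options; math options move to a
# device-level compute kernel config keyed on the device architecture.
def make_conv_configs(device, model_config):
    conv_config = ttnn.Conv2dConfig(
        dtype=model_config["ACTIVATIONS_DTYPE"],
        weights_dtype=model_config["WEIGHTS_DTYPE"],
        activation="relu",
    )
    compute_config = ttnn.init_device_compute_kernel_config(
        device.arch(),
        math_fidelity=model_config["MATH_FIDELITY"],
    )
    return conv_config, compute_config

# Both configs are then passed to the op:
#   ttnn.conv2d(..., conv_config=conv_config, compute_config=compute_config)
```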
diff --git a/models/demos/vgg/tests/test_perf_vgg.py b/models/demos/vgg/tests/test_perf_vgg.py
index 4d5bdf30e06..9cc0397bb07 100644
--- a/models/demos/vgg/tests/test_perf_vgg.py
+++ b/models/demos/vgg/tests/test_perf_vgg.py
@@ -79,17 +79,6 @@ def test_vgg(
"ACTIVATIONS_DTYPE": act_dtype,
}
- conv_config = ttnn.Conv2dConfig(
- dtype=model_config["ACTIVATIONS_DTYPE"],
- weights_dtype=model_config["WEIGHTS_DTYPE"],
- math_fidelity=model_config["MATH_FIDELITY"],
- activation="relu",
- deallocate_activation=True,
- input_channels_alignment=16,
- act_block_h_override=0,
- transpose_shards=True,
- )
-
torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1)
torch_input_tensor = torch.permute(torch_batched_tensor, (0, 2, 3, 1))
tt_batched_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py
index 0748c745d16..82f5dd1c03d 100644
--- a/models/demos/vgg/tt/ttnn_vgg.py
+++ b/models/demos/vgg/tt/ttnn_vgg.py
@@ -90,10 +90,6 @@ def ttnn_vgg16(
conv_config = ttnn.Conv2dConfig(
dtype=model_config["ACTIVATIONS_DTYPE"],
weights_dtype=model_config["WEIGHTS_DTYPE"],
- math_fidelity=model_config["MATH_FIDELITY"],
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=False,
- packer_l1_accum_enabled=False,
activation="relu",
deallocate_activation=False,
input_channels_alignment=32,
@@ -107,13 +103,20 @@ def ttnn_vgg16(
)
if h_override[iter_conv_id] is not None:
conv_config.act_block_h_override = h_override[iter_conv_id]
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=model_config["MATH_FIDELITY"],
+ math_approx_mode=True,
+ fp32_dest_acc_en=False,
+ packer_l1_acc=False,
+ )
tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight
tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias
# Call ttnn.conv
conv_op_cache = {}
- [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+ [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
input_tensor=tt_x,
weight_tensor=tt_weight,
in_channels=conv_ttnn_params[iter_conv_id][0],
@@ -127,7 +130,10 @@ def ttnn_vgg16(
input_height=conv_ttnn_params[iter_conv_id][2],
input_width=conv_ttnn_params[iter_conv_id][3],
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
tt_x = ttnn.from_device(tt_output_tensor_on_device)
ttnn.deallocate(tt_output_tensor_on_device)
@@ -214,9 +220,6 @@ def ttnn_vgg11(
conv_config = ttnn.Conv2dConfig(
dtype=model_config["ACTIVATIONS_DTYPE"],
weights_dtype=model_config["WEIGHTS_DTYPE"],
- math_fidelity=model_config["MATH_FIDELITY"],
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
activation="relu",
deallocate_activation=False,
input_channels_alignment=32,
@@ -230,13 +233,19 @@ def ttnn_vgg11(
if height_override_11[iter_conv_id] is not None:
conv_config.act_block_h_override = height_override_11[iter_conv_id]
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=model_config["MATH_FIDELITY"],
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ )
tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight
tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT)
tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias
# Call ttnn.conv
conv_op_cache = {}
- [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+ [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
input_tensor=tt_x,
weight_tensor=tt_weight,
in_channels=conv_ttnn_params_2[iter_conv_id][0],
@@ -250,7 +259,10 @@ def ttnn_vgg11(
input_height=conv_ttnn_params_2[iter_conv_id][2],
input_width=conv_ttnn_params_2[iter_conv_id][3],
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_op_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
tt_x = ttnn.from_device(tt_output_tensor_on_device)
ttnn.deallocate(tt_output_tensor_on_device)
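
Note the renames that accompany the move out of `Conv2dConfig`: the kwargs lose their `_enabled` suffixes and two are abbreviated. A small illustrative helper (not part of this PR) for porting legacy kwargs dicts, with the mapping read directly off these hunks:

```python
# Legacy Conv2dConfig math kwarg -> init_device_compute_kernel_config kwarg:
#   math_approx_mode_enabled -> math_approx_mode
#   fp32_dest_acc_enabled    -> fp32_dest_acc_en
#   packer_l1_accum_enabled  -> packer_l1_acc
_RENAMES = {
    "math_approx_mode_enabled": "math_approx_mode",
    "fp32_dest_acc_enabled": "fp32_dest_acc_en",
    "packer_l1_accum_enabled": "packer_l1_acc",
}

def port_compute_kwargs(old_kwargs: dict) -> dict:
    """Rename legacy math kwargs to their compute-config names."""
    return {_RENAMES.get(k, k): v for k, v in old_kwargs.items()}

assert port_compute_kwargs({"fp32_dest_acc_enabled": True}) == {"fp32_dest_acc_en": True}
```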
diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py
index a2700198f83..c7e897ad484 100644
--- a/models/demos/wormhole/mamba/tt/mamba_conv.py
+++ b/models/demos/wormhole/mamba/tt/mamba_conv.py
@@ -54,11 +54,14 @@ def prepare_conv_config(self):
self.conv1d_config = ttnn.Conv1dConfig(
dtype=self.config.output_dtype,
weights_dtype=self.config.weights_dtype,
- math_fidelity=self.config.math_fidelity,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
input_channels_alignment=32,
deallocate_activation=True,
)
+ self.conv1d_compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=self.config.math_fidelity,
+ )
def prepare_input(self, input_tensor):
# input_tensor (1, 1, B, 2E)
@@ -87,7 +90,7 @@ def __call__(self, input_tensor):
input_tensor_splits = self.prepare_input(input_tensor)
output_tensor_splits = []
for i in range(self.config.channels_split_factor):
- [tt_output_tensor_on_device, out_length, weights_device, _] = ttnn.Conv1d(
+ [tt_output_tensor_on_device, out_length, [weights_device, _]] = ttnn.Conv1d(
input_tensor=input_tensor_splits[i],
weight_tensor=self.tt_weight_tensor_splits[i],
in_channels=self.config.input_channels // self.config.channels_split_factor,
@@ -100,9 +103,12 @@ def __call__(self, input_tensor):
batch_size=1,
input_length=self.config.input_length,
conv_config=self.conv1d_config,
+ compute_config=self.conv1d_compute_config,
conv_op_cache={},
debug=False,
groups=self.config.groups // self.config.channels_split_factor,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
self.tt_weight_tensor_splits[i] = weights_device
output_tensor_splits.append(ttnn.sharded_to_interleaved(tt_output_tensor_on_device))
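
`ttnn.Conv1d` adopts the same flags, but its output dimension is a single length rather than an `[h, w]` pair, which is why the call above unpacks `[out, out_length, [weights, _]]`. A mock of just that return shape (illustration only, same caveats as the conv2d mock earlier):

```python
# Mock of the ttnn.Conv1d return shape used above: with both flags set,
# the op yields [output, out_length, [weights, bias]].
def mock_conv1d(return_output_dim=True, return_weights_and_bias=True):
    output, out_length, weights, bias = "out", 128, "w", "b"
    result = [output]
    if return_output_dim:
        result.append(out_length)  # a single length, not an [h, w] pair
    if return_weights_and_bias:
        result.append([weights, bias])
    return result

[tt_out, out_length, [weights_device, _]] = mock_conv1d()
assert out_length == 128 and weights_device == "w"
```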
diff --git a/models/demos/wormhole/stable_diffusion/test_multiple_iterations.py b/models/demos/wormhole/stable_diffusion/test_multiple_iterations.py
deleted file mode 100644
index 8db6aee6f39..00000000000
--- a/models/demos/wormhole/stable_diffusion/test_multiple_iterations.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
-
-# SPDX-License-Identifier: Apache-2.0
-
-import ttnn
-import json
-import torch
-import pytest
-import numpy as np
-from PIL import Image
-from loguru import logger
-from tqdm.auto import tqdm
-from datasets import load_dataset
-
-from transformers import CLIPTextModel, CLIPTokenizer
-from diffusers import (
- AutoencoderKL,
- UNet2DConditionModel,
- LMSDiscreteScheduler,
-)
-from models.utility_functions import (
- comp_allclose_and_pcc,
- enable_persistent_kernel_cache,
- disable_persistent_kernel_cache,
-)
-from models.utility_functions import skip_for_wormhole_b0
-from ttnn.model_preprocessing import preprocess_model_parameters
-from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor
-from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_unet_2d_condition_model import (
- UNet2DConditionModel as UNet2D,
-)
-
-from torchvision.transforms import ToTensor
-
-
-def load_inputs(input_path):
- with open(input_path) as f:
- input_data = json.load(f)
- assert input_data, "Input data is empty."
- prompt = [item["prompt"] for item in input_data]
- return prompt
-
-
-def constant_prop_time_embeddings(timesteps, sample, time_proj):
- timesteps = timesteps[None]
- timesteps = timesteps.expand(sample.shape[0])
- t_emb = time_proj(timesteps)
- return t_emb
-
-
-def save_image_and_latents(latents, iter, vae, pre_fix="", pre_fix2=""):
- pre_fix = "" if pre_fix == "" else f"{pre_fix}_"
- pre_fix2 = "" if pre_fix2 == "" else f"{pre_fix2}_"
- _latents = 1 / 0.18215 * latents
-
- with torch.no_grad():
- image = vae.decode(_latents).sample
- # Image post-processing
- image = (image / 2 + 0.5).clamp(0, 1)
- image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
- images = (image * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images][0]
- pil_images.save(f"{pre_fix}{pre_fix2}image_iter_{iter}.png")
-
-
-def guide(noise_pred, guidance_scale, t): # will return latents
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- return noise_pred
-
-
-def latent_expansion(latents, scheduler, t):
- latent_model_input = torch.cat([latents] * 2, dim=0)
- latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
- return latent_model_input
-
-
-def calculate_fid_score(imgs_path1, imgs_path2):
- fid = FrechetInceptionDistance(normalize=True)
- fid.update(imgs_path1, real=False)
- fid.update(imgs_path2, real=True)
- return fid.compute()
-
-
-def preprocess_images(image_paths):
- images = []
- for image_path in image_paths:
- image = Image.open(image_path)
- image = image.convert("RGB")
- image = image.resize((299, 299))
- image = ToTensor()(image)
- images.append(image)
- return torch.stack(images)
-
-
-def run_demo_inference_diffusiondb(device, reset_seeds, input_path, num_inference_steps, image_size):
- disable_persistent_kernel_cache()
-
- height, width = image_size
-
- experiment_name = f"diffusiondb_{height}x{width}"
- input_prompt = [
- "oil painting frame of Breathtaking mountain range with a clear river running through it, surrounded by tall trees and misty clouds, serene, peaceful, mountain landscape, high detail"
- ]
- logger.info(f"input_prompts: {input_prompt}")
-
- # 1. Load the autoencoder model which will be used to decode the latents into image space.
- vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
-
- # 2. Load the tokenizer and text encoder to tokenize and encode the text.
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
- text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
- # 3. The UNet model for generating the latents.
- unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
-
- # 4. load the K-LMS scheduler with some fitting parameters.
- ttnn_scheduler = LMSDiscreteScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- num_train_timesteps=1000,
- )
-
- torch_device = "cpu"
- vae.to(torch_device)
- text_encoder.to(torch_device)
- unet.to(torch_device)
-
- guidance_scale = 7.5 # Scale for classifier-free guidance
- generator = torch.manual_seed(174) # 10233 Seed generator to create the inital latent noise
- batch_size = len(input_prompt)
-
- ## First, we get the text_embeddings for the prompt. These embeddings will be used to condition the UNet model.
- # Tokenizer and Text Encoder
- text_input = tokenizer(
- input_prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
- max_length = text_input.input_ids.shape[-1]
- uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-
- # For classifier-free guidance, we need to do two forward passes: one with the conditioned input (text_embeddings),
- # and another with the unconditional embeddings (uncond_embeddings).
- # In practice, we can concatenate both into a single batch to avoid doing two forward passes.
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
- ttnn_text_embeddings = ttnn.from_torch(text_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
-
- vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
- # Initial random noise
- latents = torch.randn(
- (batch_size, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor),
- generator=generator,
- )
- latents = latents.to(torch_device)
-
- ttnn_scheduler.set_timesteps(num_inference_steps)
-
- latents = latents * ttnn_scheduler.init_noise_sigma
- ttnn_latents = torch.tensor(latents)
-
- iter = 0
- config = unet.config
-
- parameters = preprocess_model_parameters(
- initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
- )
- input_height = 64
- input_width = 64
- reader_patterns_cache = {} if height == 512 and width == 512 else None
-
- model = UNet2D(device, parameters, 2, input_height, input_width, reader_patterns_cache)
- # # Denoising loop
- for t in tqdm(ttnn_scheduler.timesteps):
- # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
- ttnn_latent_model_input = latent_expansion(ttnn_latents, ttnn_scheduler, t)
- ttnn_latent_model_input = ttnn.from_torch(
- ttnn_latent_model_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
- )
-
- _t = constant_prop_time_embeddings(t, ttnn_latent_model_input, unet.time_proj)
- _t = _t.unsqueeze(0).unsqueeze(0)
- _t = ttnn.from_torch(_t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
-
- # predict the noise residual
- with torch.no_grad():
- ttnn_output = model(
- ttnn_latent_model_input, # input
- timestep=_t,
- encoder_hidden_states=ttnn_text_embeddings,
- class_labels=None,
- attention_mask=None,
- cross_attention_kwargs=None,
- return_dict=True,
- config=config,
- )
- noise_pred = ttnn.to_torch(ttnn_output)
-
- # perform guidance
- noise_pred = guide(noise_pred, guidance_scale, t)
-
- ttnn_latents = ttnn_scheduler.step(noise_pred, t, ttnn_latents).prev_sample
- save_image_and_latents(ttnn_latents, iter, vae, pre_fix=f"{experiment_name}_tt", pre_fix2="")
-
- iter += 1
- enable_persistent_kernel_cache()
-
- latents = ttnn_latents
- # scale and decode the image latents with vae
- latents = 1 / 0.18215 * latents
- with torch.no_grad():
- image = vae.decode(latents).sample
-
- # Image post-processing
- image = (image / 2 + 0.5).clamp(0, 1)
- image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
- images = (image * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images][0]
- ttnn_output_path = f"{experiment_name}_ttnn.png"
- pil_images.save(ttnn_output_path)
-
- ref_paths = [ref_img_path, ref_img_path]
- ttnn_paths = [ttnn_output_path, ttnn_output_path]
-
- ref_images = preprocess_images(ref_paths)
- ttnn_images = preprocess_images(ttnn_paths)
-
-
-def test_tt2_multiple_iteration(device, reset_seeds, input_path):
- # 30 iterations, generate 512x512 image
- return run_demo_inference_diffusiondb(device, reset_seeds, input_path, 30, (512, 512))
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_basic_transformer_block.py b/models/demos/wormhole/stable_diffusion/tests/test_basic_transformer_block.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_basic_transformer_block.py
rename to models/demos/wormhole/stable_diffusion/tests/test_basic_transformer_block.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_cross_attention.py b/models/demos/wormhole/stable_diffusion/tests/test_cross_attention.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_cross_attention.py
rename to models/demos/wormhole/stable_diffusion/tests/test_cross_attention.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_cross_attn_up_block_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_cross_attn_up_block_2d_new_conv.py
rename to models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_demo.py b/models/demos/wormhole/stable_diffusion/tests/test_demo.py
similarity index 91%
rename from tests/ttnn/integration_tests/stable_diffusion/test_demo.py
rename to models/demos/wormhole/stable_diffusion/tests/test_demo.py
index 5c8dc03b967..36b73e70e3c 100644
--- a/tests/ttnn/integration_tests/stable_diffusion/test_demo.py
+++ b/models/demos/wormhole/stable_diffusion/tests/test_demo.py
@@ -29,6 +29,8 @@
((512, 512),),
)
def test_demo_sd(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size):
+ if device.core_grid.y != 8:
+ pytest.skip("Requires an 8x8 core grid")
demo(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size)
@@ -48,4 +50,6 @@ def test_demo_sd(device, reset_seeds, input_path, num_prompts, num_inference_ste
((512, 512),),
)
def test_demo_sd_db(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size):
+ if device.core_grid.y != 8:
+ pytest.skip("Requires an 8x8 core grid")
demo_db(device, reset_seeds, input_path, num_prompts, num_inference_steps, image_size)
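
The relocated demo tests now guard on the device's core grid before running. If more tests grow the same requirement, a shared helper along these lines could replace the inline checks; `skip_unless_core_grid` is hypothetical (not in the repo), it generalizes the y-only guard above to both axes, and `device` is assumed to be the existing conftest fixture:

```python
import pytest

# Hypothetical helper mirroring the inline guard added above.
def skip_unless_core_grid(device, x: int, y: int):
    if device.core_grid.x != x or device.core_grid.y != y:
        pytest.skip(f"Requires a {x}x{y} core grid")
```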
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_embedding.py b/models/demos/wormhole/stable_diffusion/tests/test_embedding.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_embedding.py
rename to models/demos/wormhole/stable_diffusion/tests/test_embedding.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_feedforward.py b/models/demos/wormhole/stable_diffusion/tests/test_feedforward.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_feedforward.py
rename to models/demos/wormhole/stable_diffusion/tests/test_feedforward.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_geglu.py b/models/demos/wormhole/stable_diffusion/tests/test_geglu.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_geglu.py
rename to models/demos/wormhole/stable_diffusion/tests/test_geglu.py
diff --git a/tests/device_perf_tests/stable_diffusion/test_perf_stable_diffusion.py b/models/demos/wormhole/stable_diffusion/tests/test_perf.py
similarity index 100%
rename from tests/device_perf_tests/stable_diffusion/test_perf_stable_diffusion.py
rename to models/demos/wormhole/stable_diffusion/tests/test_perf.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py
similarity index 66%
rename from tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d_new_conv.py
rename to models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py
index 51afb5afd0d..91a0f3755e5 100644
--- a/tests/ttnn/integration_tests/stable_diffusion/test_resnet_block_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py
@@ -25,90 +25,6 @@ def ttnn_to_torch(input):
return input
-@skip_for_grayskull()
-@pytest.mark.parametrize(
- "batch_size, in_channels, input_height, input_width, index1,index2,block_name,out_channels",
- [
- (2, 320, 32, 32, 0, 0, "down", None),
- (2, 320, 16, 16, 0, 0, "down", None),
- (2, 640, 16, 16, 1, 1, "down", None),
- (2, 640, 8, 8, 1, 1, "down", None),
- (2, 1280, 8, 8, 2, 1, "down", None),
- (2, 1280, 4, 4, 2, 1, "down", None),
- (2, 2560, 4, 4, 0, 0, "up", 1280),
- (2, 2560, 8, 8, 0, 0, "up", 1280),
- (2, 1920, 8, 8, 2, 0, "up", 640),
- (2, 1920, 16, 16, 2, 0, "up", 640),
- (2, 1280, 16, 16, 3, 0, "down", None),
- (2, 960, 16, 16, 3, 0, "up", 320),
- (2, 960, 32, 32, 3, 0, "up", 320),
- (2, 640, 32, 32, 3, 1, "up", 320),
- ],
-)
-def test_resnet_block_2d_256x256(
- device, batch_size, in_channels, input_height, input_width, index1, index2, block_name, out_channels
-):
- pytest.skip()
- # setup pytorch model
- model_name = "CompVis/stable-diffusion-v1-4"
- pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float32)
-
- model = pipe.unet
- model.eval()
-
- parameters = preprocess_model_parameters(
- model_name=model_name, initialize_model=lambda: model, custom_preprocessor=custom_preprocessor, device=device
- )
-
- if block_name == "up":
- parameters = parameters.up_blocks[index1].resnets[index2]
- resnet = pipe.unet.up_blocks[index1].resnets[index2]
- elif block_name == "down":
- parameters = parameters.down_blocks[index1].resnets[index2]
- resnet = pipe.unet.down_blocks[index1].resnets[index2]
- else:
- parameters = parameters.mid_block.resnets[index2]
- resnet = pipe.unet.mid_block.resnets[index2]
-
- ############ start of residual block #############
- temb_channels = 1280
- groups = 32
- time_embedding_norm = "default"
- output_scale_factor = 1
- use_in_shortcut = None
- ########## end of residual block #############
- hidden_states_shape = [batch_size, in_channels, input_height, input_width]
- temb_shape = [1, 1, 2, 1280]
-
- input = torch.randn(hidden_states_shape)
- temb = torch.randn(temb_shape)
-
- torch_output = resnet(input, temb.squeeze(0).squeeze(0))
-
- input = ttnn.from_torch(input, ttnn.bfloat16)
- input = ttnn.to_layout(input, ttnn.TILE_LAYOUT)
- input = ttnn.to_device(input, device, memory_config=ttnn.L1_MEMORY_CONFIG)
-
- temb = ttnn.from_torch(temb, ttnn.bfloat16)
- temb = ttnn.to_layout(temb, ttnn.TILE_LAYOUT)
- temb = ttnn.to_device(temb, device, memory_config=ttnn.L1_MEMORY_CONFIG)
- ttnn_output = resnetBlock2D(
- input,
- temb=temb,
- temb_channels=temb_channels,
- time_embedding_norm=time_embedding_norm,
- in_channels=in_channels,
- out_channels=out_channels,
- use_in_shortcut=use_in_shortcut,
- groups=groups,
- output_scale_factor=output_scale_factor,
- parameters=parameters,
- device=device,
- )
- ttnn_output = ttnn_to_torch(ttnn_output)
- assert_with_pcc(torch_output, ttnn_output, pcc=0.99)
-
-
@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize(
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_sharded_matmuls.py b/models/demos/wormhole/stable_diffusion/tests/test_sharded_matmuls.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_sharded_matmuls.py
rename to models/demos/wormhole/stable_diffusion/tests/test_sharded_matmuls.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_transformer_2d_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_transformer_2d_model.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_transformer_2d_model_new_conv.py
rename to models/demos/wormhole/stable_diffusion/tests/test_transformer_2d_model.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py
similarity index 98%
rename from tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model_new_conv.py
rename to models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py
index 35b1253ea54..72efdb4e178 100644
--- a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py
@@ -63,7 +63,6 @@ def unsqueeze_all_params_to_4d(params):
@skip_for_grayskull()
-@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="#10923: CB / L1 buffer clash")
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 32768}], ids=["device_params=l1_small_size_24576"], indirect=True
)
@@ -204,7 +203,7 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
# print(iter)
# print(f"Time taken for 50 iterations: {total_time}")
# print(f"Samples per second: {50 / total_time}")
- passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.99)
+ passing, output = comp_pcc(torch_output, ttnn_output, pcc=0.981)
print(output)
assert passing
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_upblock_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_upblock_2d.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_upblock_2d_new_conv.py
rename to models/demos/wormhole/stable_diffusion/tests/test_upblock_2d.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_upsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tests/test_upsample_2d.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_upsample_2d_new_conv.py
rename to models/demos/wormhole/stable_diffusion/tests/test_upsample_2d.py
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_upsample_nearest_2d.py b/models/demos/wormhole/stable_diffusion/tests/test_upsample_nearest_2d.py
similarity index 100%
rename from tests/ttnn/integration_tests/stable_diffusion/test_upsample_nearest_2d.py
rename to models/demos/wormhole/stable_diffusion/tests/test_upsample_nearest_2d.py
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
index 2ad02078d71..570d2457f1a 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
@@ -126,11 +126,7 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
shard_layout=self.shard_layout,
input_channels_alignment=32,
transpose_shards=False,
@@ -140,10 +136,17 @@ def __call__(
if hidden_states.memory_config() != self.input_memory_config:
hidden_states = ttnn.to_memory_config(hidden_states, self.input_memory_config)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
if self.conv_config_override and "act_block_h" in self.conv_config_override:
conv_config.act_block_h_override = self.conv_config_override["act_block_h"]
- [hidden_states, _out_height, _out_width, self.conv_weights, self.conv_bias] = ttnn.conv2d(
+ [hidden_states, [self.conv_weights, self.conv_bias]] = ttnn.conv2d(
input_tensor=hidden_states,
in_channels=self.in_channels,
out_channels=self.out_channels,
@@ -157,7 +160,10 @@ def __call__(
weight_tensor=self.conv_weights,
bias_tensor=self.conv_bias,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
# hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting(
# self.device,
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
index 4cedbdea78c..45024d9c9d9 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
@@ -459,19 +459,22 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=self.conv1_shard_layout,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=False,
)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
if self.conv1_config_override and "act_block_h" in self.conv2_config_override:
conv_config.act_block_h_override = self.conv1_config_override["act_block_h"]
- [hidden_states, _out_height, _out_width, self.conv1s_weights[0], self.conv1s_bias[0]] = ttnn.conv2d(
+ [hidden_states, [self.conv1s_weights[0], self.conv1s_bias[0]]] = ttnn.conv2d(
input_tensor=hidden_states,
weight_tensor=self.conv1s_weights[0],
in_channels=self.conv1_in_channels,
@@ -485,7 +488,10 @@ def __call__(
input_height=self.conv1_input_height,
input_width=self.conv1_input_width,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
else:
@@ -529,26 +535,26 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=False,
)
-
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
if self.conv1_config_override and "act_block_h" in self.conv2_config_override:
conv_config.act_block_h_override = self.conv1_config_override["act_block_h"]
[
split_hidden_states[i],
- _out_height,
- _out_width,
- self.conv1s_weights[i],
- self.conv1s_bias[i],
+ [_out_height, _out_width],
+ [self.conv1s_weights[i], self.conv1s_bias[i]],
] = ttnn.conv2d(
input_tensor=split_hidden_states[i],
weight_tensor=self.conv1s_weights[i],
@@ -563,7 +569,10 @@ def __call__(
input_height=self.conv1_input_height,
input_width=self.conv1_input_width,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
if i != 0:
split_hidden_states[i] = ttnn.add(
@@ -658,19 +667,22 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=False,
)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
if self.conv2_config_override and "act_block_h" in self.conv2_config_override:
conv_config.act_block_h_override = self.conv2_config_override["act_block_h"]
- [hidden_states, _out_height, _out_width, self.conv2_weights, self.conv2_bias] = ttnn.conv2d(
+ [hidden_states, [_out_height, _out_width], [self.conv2_weights, self.conv2_bias]] = ttnn.conv2d(
input_tensor=hidden_states,
weight_tensor=self.conv2_weights,
bias_tensor=self.conv2_bias,
@@ -684,7 +696,10 @@ def __call__(
input_height=self.conv2_input_height,
input_width=self.conv2_input_width,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut
@@ -702,17 +717,24 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=False,
)
- [input_tensor, _out_height, _out_width, self.conv_shortcut_weights, self.conv_shortcut_bias] = ttnn.conv2d(
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
+ [
+ input_tensor,
+ [_out_height, _out_width],
+ [self.conv_shortcut_weights, self.conv_shortcut_bias],
+ ] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv_shortcut_weights,
in_channels=self.conv_shortcut_in_channels,
@@ -726,7 +748,10 @@ def __call__(
input_height=self.conv_shortcut_input_height,
input_width=self.conv_shortcut_input_width,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states):
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py
index 12e4d543207..e89a957357e 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py
@@ -242,14 +242,17 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
input_channels_alignment=32,
- fp32_dest_acc_enabled=self.compute_kernel_config.fp32_dest_acc_en,
transpose_shards=False,
)
- [hidden_states, _out_height, _out_width, self.proj_in_conv_weights, self.proj_in_conv_bias] = ttnn.conv2d(
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ fp32_dest_acc_en=self.compute_kernel_config.fp32_dest_acc_en,
+ )
+ [hidden_states, [self.proj_in_conv_weights, self.proj_in_conv_bias]] = ttnn.conv2d(
input_tensor=hidden_states,
in_channels=self.proj_in_in_channels,
out_channels=self.proj_in_out_channels,
@@ -263,7 +266,10 @@ def __call__(
weight_tensor=self.proj_in_conv_weights,
bias_tensor=self.proj_in_conv_bias,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
inner_dim = hidden_states.shape[-1]
@@ -293,10 +299,8 @@ def __call__(
# hidden_states = ttnn.to_memory_config(hidden_states, self.proj_out.conv.input_sharded_memory_config)
[
hidden_states,
- _out_height,
- _out_width,
- self.proj_out_conv_weights,
- self.proj_out_conv_bias,
+ [_out_height, _out_width],
+ [self.proj_out_conv_weights, self.proj_out_conv_bias],
] = ttnn.conv2d(
input_tensor=hidden_states,
in_channels=self.proj_out_in_channels,
@@ -312,6 +316,8 @@ def __call__(
bias_tensor=self.proj_out_conv_bias,
conv_config=conv_config,
conv_op_cache=conv_cache,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
if output_bfloat16:
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py
index 9cbdfff2f48..1003c1efc4e 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py
@@ -383,18 +383,21 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
shard_layout=shard_layout,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=True,
)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
- [sample, _out_height, _out_width, self.conv_in_weights, self.conv_in_bias] = ttnn.conv2d(
+ [sample, [self.conv_in_weights, self.conv_in_bias]] = ttnn.conv2d(
input_tensor=sample,
weight_tensor=self.conv_in_weights,
bias_tensor=self.conv_in_bias,
@@ -408,7 +411,10 @@ def __call__(
input_height=self.input_height,
input_width=self.input_width,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
sample = ttnn.reallocate(sample) # TODO: Test remove
@@ -646,18 +652,21 @@ def __call__(
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
act_block_h_override=64,
transpose_shards=False,
reshard_if_not_optimal=True,
)
- [sample, _out_height, _out_width, self.conv_out_weights, self.conv_out_bias] = ttnn.conv2d(
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
+ [sample, [self.conv_out_weights, self.conv_out_bias]] = ttnn.conv2d(
input_tensor=sample,
in_channels=self.conv_out_in_channels,
out_channels=self.conv_out_out_channels,
@@ -671,7 +680,10 @@ def __call__(
weight_tensor=self.conv_out_weights,
bias_tensor=self.conv_out_bias,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
sample = ttnn.to_memory_config(sample, ttnn.L1_MEMORY_CONFIG)
sample = ttnn.clone(sample, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16)
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py
index 622a63065db..52e9fb5c913 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py
@@ -91,19 +91,22 @@ def __call__(self, input, in_channels, out_channels):
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat8_b,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=True,
- packer_l1_accum_enabled=False,
input_channels_alignment=32,
transpose_shards=False,
reshard_if_not_optimal=False, # Reshard has error : 1616 Bytes unique+common runtime args targeting kernel reshard_reader on (x=0,y=0) are too large. Cannot be written as they will run into memory region reserved for result. Max allowable size is 1024 Bytes
)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ self.device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=True,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
if self.conv_config_override and "act_block_h" in self.conv_config_override:
conv_config.act_block_h_override = self.conv_config_override["act_block_h"]
- [tt_out, _out_height, _out_width, self.conv_weight_tensor, self.conv_bias_tensor] = ttnn.conv2d(
+ [tt_out, [self.conv_weight_tensor, self.conv_bias_tensor]] = ttnn.conv2d(
input_tensor=tt_out,
in_channels=self.conv_in_channels,
out_channels=self.conv_out_channels,
@@ -117,6 +120,9 @@ def __call__(self, input, in_channels, out_channels):
weight_tensor=self.conv_weight_tensor,
bias_tensor=self.conv_bias_tensor,
conv_config=conv_config,
+ compute_config=compute_config,
conv_op_cache=conv_cache,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
return tt_out
diff --git a/models/demos/yolov4/ttnn/common.py b/models/demos/yolov4/ttnn/common.py
index b293a6db751..1579f9112f9 100644
--- a/models/demos/yolov4/ttnn/common.py
+++ b/models/demos/yolov4/ttnn/common.py
@@ -80,13 +80,9 @@ def __call__(self, device, input_tensor):
conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat8_b,
- math_fidelity=ttnn.MathFidelity.LoFi,
activation=self.activation,
shard_layout=self.shard_layout,
- math_approx_mode_enabled=True,
- fp32_dest_acc_enabled=False,
act_block_w_div=1,
- packer_l1_accum_enabled=False,
input_channels_alignment=16 if self.input_params[3] < 16 else 32,
transpose_shards=False,
reshard_if_not_optimal=self.reshard,
@@ -96,10 +92,17 @@ def __call__(self, device, input_tensor):
enable_act_double_buffer=self.enable_act_double_buffer,
output_layout=self.output_layout,
)
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ math_approx_mode=False,
+ fp32_dest_acc_en=False,
+ packer_l1_acc=False,
+ )
if self.act_block_h is not None:
conv_config.act_block_h_override = self.act_block_h
- [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d(
+ output_tensor, [self.weights, self.bias] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.weights,
bias_tensor=self.bias,
@@ -113,5 +116,8 @@ def __call__(self, device, input_tensor):
input_height=self.input_params[1],
input_width=self.input_params[2],
conv_config=conv_config,
+ compute_config=compute_config,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
return output_tensor
diff --git a/models/demos/yolov4/ttnn/neck.py b/models/demos/yolov4/ttnn/neck.py
index f7e3b541278..d86d9faa527 100644
--- a/models/demos/yolov4/ttnn/neck.py
+++ b/models/demos/yolov4/ttnn/neck.py
@@ -262,7 +262,7 @@ def __call__(self, device, input_tensor):
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)
- output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config)
+ output_tensor_upsample_1 = ttnn.upsample(output_tensor, (2, 2), memory_config=out_sharded_mem_config)
output_tensor_upsample_1 = ttnn.sharded_to_interleaved(output_tensor_upsample_1, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_1 = ttnn.reshape(output_tensor_upsample_1, (1, 1, 400, 256))
output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT)
@@ -336,7 +336,7 @@ def __call__(self, device, input_tensor):
ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)
- output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2, 1), memory_config=out_sharded_mem_config)
+ output_tensor_upsample_2 = ttnn.upsample(output_tensor, (2, 2), memory_config=out_sharded_mem_config)
output_tensor_upsample_2 = ttnn.sharded_to_interleaved(output_tensor_upsample_2, ttnn.L1_MEMORY_CONFIG)
output_tensor_upsample_2 = ttnn.reshape(output_tensor_upsample_2, (1, 1, 1600, 128))
output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT)
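
Both upsample call sites drop the trailing `1` from the scale factor: `ttnn.upsample` now takes a 2D `(scale_h, scale_w)` pair, and the removed third element appears to have been a no-op channel scale. A CPU reference in torch for what `(2, 2)` means here (an equivalence sketch, not the ttnn implementation):

```python
import torch
import torch.nn.functional as F

# Nearest-neighbor upsampling of H and W only; channels are untouched,
# which is what the old trailing 1 in (2, 2, 1) expressed explicitly.
x = torch.randn(1, 128, 20, 20)  # NCHW
y = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
assert y.shape == (1, 128, 40, 40)
```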
diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
index 215399ea23b..8a5157d51dc 100644
--- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
+++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py
@@ -114,10 +114,8 @@ def __init__(
self.conv_config = ttnn.Conv2dConfig(
dtype=activation_dtype,
weights_dtype=weights_dtype,
- math_fidelity=ttnn.MathFidelity.LoFi,
shard_layout=shard_layout,
deallocate_activation=self.deallocate_activation,
- packer_l1_accum_enabled=False,
enable_act_double_buffer=(
conv.use_activation_double_buffer if "use_activation_double_buffer" in conv else False
),
@@ -128,6 +126,12 @@ def __init__(
input_channels_alignment=conv.input_channels_alignment if "input_channels_alignment" in conv else 32,
reshard_if_not_optimal=reshard_if_not_optimal,
)
+ self.compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=ttnn.MathFidelity.LoFi,
+ fp32_dest_acc_en=True,
+ packer_l1_acc=False,
+ )
config_override = conv.conv_blocking_and_parallelization_config_override
if config_override and "act_block_h" in config_override:
self.conv_config.act_block_h_override = config_override["act_block_h"]
@@ -143,7 +147,7 @@ def __init__(
self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper)
def __call__(self, x):
- x, _, _, self.weight, self.bias = ttnn.conv2d(
+ x, [self.weight, self.bias] = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.weight,
bias_tensor=self.bias,
@@ -157,8 +161,11 @@ def __call__(self, x):
stride=self.stride,
padding=self.padding,
conv_config=self.conv_config,
+ compute_config=self.compute_config,
conv_op_cache=self.cache,
groups=2,
+ return_output_dim=False,
+ return_weights_and_bias=True,
)
return x
@@ -257,7 +264,7 @@ def upsample(self, x):
else:
x = ttnn.interleaved_to_sharded(x, shardspec)
- x = ttnn.upsample(x, (2, 2, 1), memory_config=x.memory_config())
+ x = ttnn.upsample(x, (2, 2), memory_config=x.memory_config())
x = ttnn.reshape(
x, (1, 1, self.conv1.batch_size * self.conv1.input_height * self.conv1.input_width, x.shape[-1])
)
diff --git a/models/perf/benchmarking_utils.py b/models/perf/benchmarking_utils.py
index 5ca7ae269c8..8136c4ef0c1 100644
--- a/models/perf/benchmarking_utils.py
+++ b/models/perf/benchmarking_utils.py
@@ -16,6 +16,24 @@ def __init__(self):
self.start_times = dict()
self.end_times = dict()
+ def __call__(self, step_name: str, iteration: int = 0):
+ # Return a context manager for this step
+ return self.StepContext(self, step_name, iteration)
+
+ class StepContext:
+ def __init__(self, profiler, step_name: str, iteration: int):
+ self.profiler = profiler
+ self.step_name = step_name
+ self.iteration = iteration
+
+ def __enter__(self):
+ self.profiler.start(self.step_name, self.iteration)
+ return self.profiler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.profiler.end(self.step_name, self.iteration)
+ return False
+
def start(self, step_name: str, iteration: int = 0):
self.start_times[(iteration, step_name)] = datetime.now(tz=pytz.UTC)
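
The profiler now doubles as a context-manager factory, so callers can time a step with `with profiler(step):` instead of paired `start`/`end` calls. A self-contained replica of the pattern added above (using the standard library's `timezone.utc` in place of `pytz` so it runs anywhere; the real class also carries the rest of the benchmarking state):

```python
from datetime import datetime, timezone

class MiniProfiler:
    def __init__(self):
        self.start_times, self.end_times = {}, {}

    def start(self, step, iteration=0):
        self.start_times[(iteration, step)] = datetime.now(tz=timezone.utc)

    def end(self, step, iteration=0):
        self.end_times[(iteration, step)] = datetime.now(tz=timezone.utc)

    def __call__(self, step, iteration=0):
        # Calling the profiler returns a context manager for one step.
        return self.StepContext(self, step, iteration)

    class StepContext:
        def __init__(self, profiler, step, iteration):
            self.profiler, self.step, self.iteration = profiler, step, iteration

        def __enter__(self):
            self.profiler.start(self.step, self.iteration)
            return self.profiler

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.profiler.end(self.step, self.iteration)
            return False  # never swallow exceptions

profiler = MiniProfiler()
with profiler("inference", iteration=0):
    pass  # timed work goes here
```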
diff --git a/tests/nightly/single_card/stable_diffusion/test_basic_transformer_block.py b/tests/nightly/single_card/stable_diffusion/test_basic_transformer_block.py
new file mode 120000
index 00000000000..61408ffa9e7
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_basic_transformer_block.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_basic_transformer_block.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_cross_attention.py b/tests/nightly/single_card/stable_diffusion/test_cross_attention.py
new file mode 120000
index 00000000000..c161012b886
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_cross_attention.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_cross_attention.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_cross_attn_up_block_2d.py b/tests/nightly/single_card/stable_diffusion/test_cross_attn_up_block_2d.py
new file mode 120000
index 00000000000..8fce2d91ed2
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_cross_attn_up_block_2d.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_demo.py b/tests/nightly/single_card/stable_diffusion/test_demo.py
new file mode 120000
index 00000000000..c375047f633
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_demo.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_demo.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_embedding.py b/tests/nightly/single_card/stable_diffusion/test_embedding.py
new file mode 120000
index 00000000000..3e89c128424
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_embedding.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_embedding.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_feedforward.py b/tests/nightly/single_card/stable_diffusion/test_feedforward.py
new file mode 120000
index 00000000000..915332488d5
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_feedforward.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_feedforward.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_geglu.py b/tests/nightly/single_card/stable_diffusion/test_geglu.py
new file mode 120000
index 00000000000..5880ea6e17d
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_geglu.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_geglu.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_resnet_block_2d.py b/tests/nightly/single_card/stable_diffusion/test_resnet_block_2d.py
new file mode 120000
index 00000000000..1b6513e5b50
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_resnet_block_2d.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_resnet_block_2d.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_sharded_matmuls.py b/tests/nightly/single_card/stable_diffusion/test_sharded_matmuls.py
new file mode 120000
index 00000000000..d5d12d47849
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_sharded_matmuls.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_sharded_matmuls.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_transformer_2d_model.py b/tests/nightly/single_card/stable_diffusion/test_transformer_2d_model.py
new file mode 120000
index 00000000000..d82d4a899f6
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_transformer_2d_model.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_transformer_2d_model.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_unet_2d_condition_model.py b/tests/nightly/single_card/stable_diffusion/test_unet_2d_condition_model.py
new file mode 120000
index 00000000000..c25a861ed35
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_unet_2d_condition_model.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_unet_2d_condition_model.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_upblock_2d.py b/tests/nightly/single_card/stable_diffusion/test_upblock_2d.py
new file mode 120000
index 00000000000..3997b30be69
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_upblock_2d.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_upblock_2d.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_upsample_2d.py b/tests/nightly/single_card/stable_diffusion/test_upsample_2d.py
new file mode 120000
index 00000000000..88a98649844
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_upsample_2d.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_upsample_2d.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/stable_diffusion/test_upsample_nearest_2d.py b/tests/nightly/single_card/stable_diffusion/test_upsample_nearest_2d.py
new file mode 120000
index 00000000000..815ccb622b4
--- /dev/null
+++ b/tests/nightly/single_card/stable_diffusion/test_upsample_nearest_2d.py
@@ -0,0 +1 @@
+../../../../models/demos/wormhole/stable_diffusion/tests/test_upsample_nearest_2d.py
\ No newline at end of file
diff --git a/tests/nightly/single_card/wh_b0_unstable/tests/ttnn/integration_tests/stable_diffusion b/tests/nightly/single_card/wh_b0_unstable/tests/ttnn/integration_tests/stable_diffusion
deleted file mode 120000
index 608e08f48e2..00000000000
--- a/tests/nightly/single_card/wh_b0_unstable/tests/ttnn/integration_tests/stable_diffusion
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../../../tests/ttnn/integration_tests/stable_diffusion
\ No newline at end of file
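
The nightly suite now reaches the relocated Stable Diffusion tests through relative symlinks like the ones added above. A sketch for generating such a link (run from the repo root; the paths are taken from this diff, and the links are kept relative so the checkout stays relocatable):

```python
import os

src = "models/demos/wormhole/stable_diffusion/tests/test_demo.py"
dst = "tests/nightly/single_card/stable_diffusion/test_demo.py"

os.makedirs(os.path.dirname(dst), exist_ok=True)
# Compute the ../../../../... form seen in the diff bodies above.
rel = os.path.relpath(src, start=os.path.dirname(dst))
if not os.path.lexists(dst):
    os.symlink(rel, dst)
assert os.path.islink(dst)
```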
diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh
index 7956d1c7b03..7c42512474d 100755
--- a/tests/scripts/run_performance.sh
+++ b/tests/scripts/run_performance.sh
@@ -41,6 +41,8 @@ run_perf_models_other() {
env pytest -n auto models/demos/mnist/tests -m $test_marker
+ env pytest -n auto models/demos/squeezebert/tests/test_performance.py -m $test_marker
+
## Merge all the generated reports
env python models/perf/merge_perf_results.py
}
@@ -71,7 +73,7 @@ run_perf_models_cnn_javelin() {
# Run tests
env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/experimental/functional_unet/tests/test_unet_perf.py -m $test_marker
- env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto tests/device_perf_tests/stable_diffusion -m $test_marker --timeout=480
+ env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/wormhole/stable_diffusion/tests -m $test_marker --timeout=480
## Merge all the generated reports
env python models/perf/merge_perf_results.py
@@ -81,7 +83,7 @@ run_device_perf_models() {
set -eo pipefail
local test_marker=$1
- env pytest tests/device_perf_tests/stable_diffusion -m $test_marker --timeout=600
+ env pytest models/demos/wormhole/stable_diffusion/tests -m $test_marker --timeout=600
env pytest models/demos/distilbert/tests -m $test_marker
@@ -93,6 +95,8 @@ run_device_perf_models() {
env pytest models/demos/mnist/tests -m $test_marker
+ env pytest models/demos/squeezebert/tests -m $test_marker
+
if [ "$tt_arch" == "grayskull" ]; then
#TODO(MO): Until #6560 is fixed, GS device profiler tests are grouped with
#Model Device perf regression tests to make sure they run on no-soft-reset BMs
diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh
index 5f5642483f6..11a4e96a895 100755
--- a/tests/scripts/single_card/run_single_card_demo_tests.sh
+++ b/tests/scripts/single_card/run_single_card_demo_tests.sh
@@ -15,6 +15,20 @@ run_common_func_tests() {
# Qwen7B
QWEN_DIR=/mnt/MLPerf/tt_dnn-models/qwen/Qwen2-7B-Instruct WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml FAKE_DEVICE=N150 pytest -n auto models/demos/qwen/demo/demo.py -k instruct --timeout 420; fail+=$?
+ # Llama3 Accuracy tests
+ # Llama3.2-1B
+ llama1b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-1B-Instruct/
+ # Llama3.2-3B
+ llama3b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-3B-Instruct/
+ # Llama3.1-8B (11B weights are the same)
+ llama8b=/mnt/MLPerf/tt_dnn-models/llama/Meta-Llama-3.1-8B-Instruct/
+
+ # Run Llama3 accuracy tests for 1B, 3B, 8B weights
+ for llama_dir in "$llama1b" "$llama3b" "$llama8b"; do
+ LLAMA_DIR=$llama_dir WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/llama3/tests/test_llama_accuracy.py -k perf --timeout 420; fail+=$?
+ echo "LOG_METAL: Llama3 accuracy tests for $llama_dir completed"
+ done
+
#VGG11/VGG16
WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/vgg/demo/demo.py --timeout 600; fail+=$?
@@ -39,6 +53,9 @@ run_common_func_tests() {
# Mnist
pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$?
+ # SqueezeBERT
+ pytest --disable-warnings models/demos/squeezebert/demo/demo.py --timeout 600; fail+=$?
+
return $fail
}
diff --git a/tests/scripts/t3000/run_t3000_demo_tests.sh b/tests/scripts/t3000/run_t3000_demo_tests.sh
index 805fa83e97b..627769a5a9e 100755
--- a/tests/scripts/t3000/run_t3000_demo_tests.sh
+++ b/tests/scripts/t3000/run_t3000_demo_tests.sh
@@ -93,7 +93,7 @@ run_t3000_llama3_vision_tests() {
pip install -r models/demos/llama3/requirements.txt
for fake_device in "$n300" "$t3k"; do
- FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/simple_vision_demo.py -k "cold and yes_trace" --timeout 600; fail+=$?
+ FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/demo/simple_vision_demo.py -k "batch1-trace" --timeout 600; fail+=$?
echo "LOG_METAL: Llama3 vision tests for $fake_device completed"
done
diff --git a/tests/scripts/t3000/run_t3000_frequent_tests.sh b/tests/scripts/t3000/run_t3000_frequent_tests.sh
index 0058a3fc9e3..3ade2f43355 100755
--- a/tests/scripts/t3000/run_t3000_frequent_tests.sh
+++ b/tests/scripts/t3000/run_t3000_frequent_tests.sh
@@ -63,7 +63,7 @@ run_t3000_llama3_tests() {
# Run model tests for llama3 - 1B, 3B, 8B and 11B weights
for llama_dir in "$llama1b" "$llama3b" "$llama8b" "$llama11b"; do
LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_model.py -k full ; fail+=$?
- # LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_model_prefill.py ; fail+=$? # FIXME Issue #14843
+ LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_model_prefill.py ; fail+=$?
echo "LOG_METAL: Llama3 tests for $llama_dir completed"
done
@@ -96,6 +96,40 @@ run_t3000_llama3_70b_tests() {
fi
}
+run_t3000_llama3_accuracy_tests() {
+ # Record the start time
+ fail=0
+ start_time=$(date +%s)
+
+ echo "LOG_METAL: Running run_t3000_llama3_accuracy_tests"
+
+ wh_arch_yaml=wormhole_b0_80_arch_eth_dispatch.yaml
+ # Llama3.2-1B
+ llama1b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-1B-Instruct/
+ # Llama3.2-3B
+ llama3b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-3B-Instruct/
+ # Llama3.1-8B
+ llama8b=/mnt/MLPerf/tt_dnn-models/llama/Meta-Llama-3.1-8B-Instruct/
+ # Llama3.2-11B
+ llama11b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.2-11B-Vision-Instruct/
+ # Llama3.1-70B
+ llama70b=/mnt/MLPerf/tt_dnn-models/llama/Llama3.1-70B-Instruct/
+
+ # Run llama3 accuracy tests for 1B, 3B, 8B, 11B and 70B weights
+ for llama_dir in "$llama1b" "$llama3b" "$llama8b" "$llama11b" "$llama70b"; do
+ LLAMA_DIR=$llama_dir WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/test_llama_accuracy.py -k perf ; fail+=$?
+ echo "LOG_METAL: Llama3 accuracy tests for $llama_dir completed"
+ done
+
+ # Record the end time
+ end_time=$(date +%s)
+ duration=$((end_time - start_time))
+ echo "LOG_METAL: run_t3000_llama3_accuracy_tests $duration seconds to complete"
+ if [[ $fail -ne 0 ]]; then
+ exit 1
+ fi
+}
+
run_t3000_llama3.2-11b-vision_freq_tests() {
# Record the start time
fail=0
@@ -277,6 +311,9 @@ run_t3000_tests() {
# Run llama3-70b tests
run_t3000_llama3_70b_tests
+ # Run llama3 accuracy tests
+ run_t3000_llama3_accuracy_tests
+
# Run Llama3.2-11B Vision tests
run_t3000_llama3.2-11b-vision_freq_tests
diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh
index 6b33b853a07..60c671c6e83 100755
--- a/tests/scripts/t3000/run_t3000_unit_tests.sh
+++ b/tests/scripts/t3000/run_t3000_unit_tests.sh
@@ -197,8 +197,8 @@ run_t3000_llama3.2-11b-vision_unit_tests() {
LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$?
LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$?
LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$?
- LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$?
- LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$?
+ LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py -k "batch_1" ; fail+=$?
+ LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py -k "batch_1" ; fail+=$?
LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$?
LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$?
LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$?
@@ -232,8 +232,8 @@ run_t3000_spoof_n300_llama3.2-11b-vision_unit_tests() {
FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_mlp.py ; fail+=$?
FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_attention.py ; fail+=$?
FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_image_block.py ; fail+=$?
- FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py ; fail+=$?
- FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py ; fail+=$?
+ FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_attention.py -k "batch_1" ; fail+=$?
+ FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_cross_block.py -k "batch_1" ; fail+=$?
FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py ; fail+=$?
FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_class_embedding.py ; fail+=$?
FAKE_DEVICE=$fake_device LLAMA_DIR=$llama11b WH_ARCH_YAML=$wh_arch_yaml pytest -n auto models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py ; fail+=$?
diff --git a/tests/sweep_framework/sweep_utils/conv2d_common.py b/tests/sweep_framework/sweep_utils/conv2d_common.py
index 55769adb984..c7509247213 100644
--- a/tests/sweep_framework/sweep_utils/conv2d_common.py
+++ b/tests/sweep_framework/sweep_utils/conv2d_common.py
@@ -117,18 +117,20 @@ def run_full(
conv_config = ttnn.Conv2dConfig(
dtype=activations_dtype,
weights_dtype=weights_dtype,
- math_fidelity=math_fidelity,
shard_layout=None,
deallocate_activation=deallocate_activation,
- fp32_dest_acc_enabled=fp32_accum,
- packer_l1_accum_enabled=packer_l1_acc,
override_sharding_config=override_sharding_config,
output_layout=output_layout,
enable_act_double_buffer=enable_act_double_buffer,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
)
-
+ compute_config = ttnn.init_device_compute_kernel_config(
+ device.arch(),
+ math_fidelity=math_fidelity,
+ fp32_dest_acc_en=fp32_accum,
+ packer_l1_acc=packer_l1_acc,
+ )
if override_sharding_config:
if len(core_grid) == 2:
conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange(core_grid[0], core_grid[1])})
@@ -137,7 +139,7 @@ def run_full(
{ttnn.CoreRange(core_grid[0], core_grid[1]), ttnn.CoreRange(core_grid[2], core_grid[3])}
)
start_time = start_measuring_time()
- [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+ [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
in_channels=input_channels,
@@ -152,7 +154,10 @@ def run_full(
input_height=input_height,
input_width=input_width,
conv_config=conv_config,
+ compute_config=compute_config,
groups=groups,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
@@ -220,7 +225,7 @@ def run_short(
tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
start_time = start_measuring_time()
- [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
+ [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
in_channels=input_channels,
@@ -235,6 +240,8 @@ def run_short(
input_height=input_height,
input_width=input_width,
groups=groups,
+ return_output_dim=True,
+ return_weights_and_bias=True,
)
tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
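
The migration above follows a single pattern: the math-fidelity and accumulation knobs move out of `ttnn.Conv2dConfig` into a compute kernel config built with `ttnn.init_device_compute_kernel_config`, and `ttnn.conv2d` now returns a nested `[output, [h, w], [weights, bias]]` structure when `return_output_dim` and `return_weights_and_bias` are passed. A minimal sketch of the new call shape, assuming an open ttnn `device`; the shapes and fidelity values are illustrative placeholders, not taken from this diff:

```python
import torch
import ttnn

# Illustrative tensors; shapes are hypothetical.
torch_input = torch.randn(1, 56, 56, 32, dtype=torch.bfloat16)
torch_weight = torch.randn(64, 32, 3, 3, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, ttnn.bfloat16)
tt_weight = ttnn.from_torch(torch_weight, ttnn.bfloat16)

# Fidelity/accumulation flags now live here, not in Conv2dConfig.
compute_config = ttnn.init_device_compute_kernel_config(
    device.arch(),
    math_fidelity=ttnn.MathFidelity.HiFi4,
    fp32_dest_acc_en=False,
    packer_l1_acc=False,
)

# New nested return shape, matching the unpacking in run_full/run_short above.
[tt_output, [out_h, out_w], [w_dev, b_dev]] = ttnn.conv2d(
    input_tensor=tt_input,
    weight_tensor=tt_weight,
    in_channels=32,
    out_channels=64,
    device=device,
    kernel_size=(3, 3),
    stride=(1, 1),
    padding=(1, 1),
    batch_size=1,
    input_height=56,
    input_width=56,
    compute_config=compute_config,
    return_output_dim=True,
    return_weights_and_bias=True,
)
```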
diff --git a/tests/sweep_framework/sweep_utils/utils.py b/tests/sweep_framework/sweep_utils/utils.py
index 6f2199055d8..ef6349c71e9 100644
--- a/tests/sweep_framework/sweep_utils/utils.py
+++ b/tests/sweep_framework/sweep_utils/utils.py
@@ -220,6 +220,28 @@ def gen_split_qkv_heads_spec(
}
+def gen_rotary_embedding_spec(
+ input_shape_list,
+ cache_size_list,
+ use_token_idx_list=[True, False],
+):
+ for input_shape, cache_size, use_token_idx in itertools.product(
+ input_shape_list, cache_size_list, use_token_idx_list
+ ):
+ input_shape_ = input_shape.copy()
+ # When use_token_idx is set, emit a decode-style spec: batch forced to 1
+ # and a random token position inside the cache; otherwise token_idx=None.
+ if use_token_idx:
+ token_idx = random.randint(1, cache_size - 1)
+ input_shape_[0] = 1
+ else:
+ token_idx = None
+
+ yield {
+ "input_shape": input_shape_,
+ "cache_size": cache_size,
+ "token_idx": token_idx,
+ }
+
+
def gen_complex_tensor(input_shape, low, high, dtype=ttnn.bfloat16):
torch_real = gen_func_with_cast_tt(partial(torch_random, low=-100, high=100, dtype=torch.float32), dtype)(
input_shape
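
To make the new generator's behavior concrete, here is a hypothetical invocation (shapes and cache sizes are made up; note the batch dimension is forced to 1 in the token-indexed case):

```python
# Hypothetical usage of gen_rotary_embedding_spec; values are illustrative.
specs = list(
    gen_rotary_embedding_spec(
        input_shape_list=[[8, 1, 32, 64]],
        cache_size_list=[2048],
    )
)
# Yields two specs per (shape, cache_size) pair:
#   {'input_shape': [1, 1, 32, 64], 'cache_size': 2048, 'token_idx': <1..2047>}
#   {'input_shape': [8, 1, 32, 64], 'cache_size': 2048, 'token_idx': None}
for spec in specs:
    print(spec)
```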
diff --git a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
index 97fa8debcbb..c41f6be9092 100644
--- a/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
+++ b/tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -19,1551 +19,1551 @@
"input_specs": [
# Contains the following params
# [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation]
+ [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1],
+ [1, 960, 960, 3, 3, 1, 5, 1, 1, 0, 2, 960, False, 1],
+ [1, 960, 960, 3, 3, 5, 1, 1, 1, 2, 0, 960, False, 1],
+ [1, 960, 960, 7, 7, 3, 3, 1, 1, 1, 1, 960, False, 1],
+ [1, 960, 960, 7, 7, 5, 5, 1, 1, 2, 2, 960, False, 1],
+ [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1],
+ [1, 96, 96, 112, 112, 3, 3, 2, 2, 1, 1, 96, False, 1],
+ [1, 96, 96, 113, 113, 3, 3, 2, 2, 0, 0, 96, False, 1],
+ [1, 96, 96, 121, 121, 3, 3, 2, 2, 0, 0, 96, False, 1],
+ [1, 96, 96, 131, 131, 3, 3, 2, 2, 0, 0, 96, False, 1],
+ [1, 96, 96, 28, 28, 5, 5, 2, 2, 2, 2, 96, False, 1],
+ [1, 92, 92, 14, 14, 3, 3, 1, 1, 1, 1, 92, False, 1],
+ [1, 144, 144, 28, 28, 3, 3, 1, 1, 1, 1, 9, False, 1],
+ [1, 144, 144, 56, 56, 3, 3, 2, 2, 1, 1, 9, False, 1],
+ [1, 216, 216, 28, 28, 3, 3, 1, 1, 1, 1, 9, False, 1],
+ [1, 216, 216, 56, 56, 3, 3, 2, 2, 1, 1, 9, False, 1],
+ [1, 432, 432, 14, 14, 3, 3, 1, 1, 1, 1, 9, False, 1],
+ [1, 432, 432, 28, 28, 3, 3, 2, 2, 1, 1, 9, False, 1],
+ [1, 88, 88, 28, 28, 3, 3, 1, 1, 1, 1, 88, False, 1],
+ [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1],
+ [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1],
+ [1, 80, 80, 14, 14, 3, 3, 1, 1, 1, 1, 80, False, 1],
+ [1, 80, 80, 7, 7, 3, 3, 1, 1, 1, 1, 80, False, 1],
+ [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 8, False, 1],
+ [1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 8, False, 1],
+ [1, 1344, 1344, 14, 14, 3, 3, 1, 1, 1, 1, 8, False, 1],
+ [1, 1344, 1344, 28, 28, 3, 3, 2, 2, 1, 1, 8, False, 1],
+ [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 8, False, 1],
+ [1, 448, 448, 56, 56, 3, 3, 2, 2, 1, 1, 8, False, 1],
+ [1, 8, 8, 112, 112, 3, 3, 1, 1, 1, 1, 8, False, 1],
+ [1, 728, 728, 19, 19, 3, 3, 1, 1, 1, 1, 728, False, 1],
+ [1, 728, 728, 38, 38, 3, 3, 2, 2, 1, 1, 728, False, 1],
+ [1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 728, False, 1],
+ [1, 720, 720, 17, 17, 5, 5, 1, 1, 2, 2, 720, False, 1],
+ [1, 720, 720, 21, 21, 5, 5, 2, 2, 0, 0, 720, False, 1],
+ [1, 72, 72, 28, 28, 1, 5, 1, 1, 0, 2, 72, False, 1],
+ [1, 72, 72, 28, 28, 5, 1, 1, 1, 2, 0, 72, False, 1],
+ [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 72, False, 1],
+ [1, 72, 72, 56, 56, 3, 3, 2, 2, 1, 1, 72, False, 1],
+ [1, 72, 72, 56, 56, 5, 5, 2, 2, 2, 2, 72, False, 1],
+ [1, 72, 72, 80, 80, 3, 3, 1, 1, 1, 1, 72, False, 1],
+ [1, 72, 72, 80, 80, 5, 5, 2, 2, 2, 2, 72, False, 1],
+ [1, 168, 168, 28, 28, 3, 3, 1, 1, 1, 1, 7, False, 1],
+ [1, 168, 168, 56, 56, 3, 3, 2, 2, 1, 1, 7, False, 1],
+ [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 7, False, 1],
+ [1, 896, 896, 28, 28, 3, 3, 2, 2, 1, 1, 7, False, 1],
+ [1, 672, 672, 14, 14, 3, 3, 1, 1, 1, 1, 672, False, 1],
+ [1, 672, 672, 14, 14, 5, 5, 1, 1, 2, 2, 672, False, 1],
+ [1, 672, 672, 14, 14, 5, 5, 2, 2, 2, 2, 672, False, 1],
+ [1, 672, 672, 15, 15, 5, 5, 1, 1, 2, 2, 672, False, 1],
+ [1, 672, 672, 17, 17, 5, 5, 2, 2, 0, 0, 672, False, 1],
+ [1, 672, 672, 19, 19, 5, 5, 2, 2, 0, 0, 672, False, 1],
+ [1, 672, 672, 20, 20, 3, 3, 1, 1, 1, 1, 672, False, 1],
+ [1, 672, 672, 20, 20, 5, 5, 2, 2, 2, 2, 672, False, 1],
+ [1, 672, 672, 24, 24, 3, 3, 1, 1, 1, 1, 672, False, 1],
+ [1, 672, 672, 24, 24, 5, 5, 1, 1, 2, 2, 672, False, 1],
+ [1, 672, 672, 7, 7, 1, 5, 1, 1, 0, 2, 672, False, 1],
+ [1, 672, 672, 7, 7, 5, 1, 1, 1, 2, 0, 672, False, 1],
+ [1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 640, True, 1],
+ [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 64, False, 1],
+ [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 1, 1, 64, False, 1],
+ [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 1, 1, 64, False, 1],
+ [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 64, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 64, False, 1],
+ [1, 512, 512, 56, 56, 3, 3, 2, 2, 1, 1, 64, False, 1],
+ [1, 64, 64, 112, 112, 3, 3, 1, 1, 1, 1, 64, False, 1],
+ [1, 64, 64, 112, 112, 3, 3, 2, 2, 1, 1, 64, False, 1],
+ [1, 64, 64, 150, 150, 3, 3, 1, 1, 1, 1, 64, False, 1],
+ [1, 64, 64, 160, 160, 3, 3, 2, 2, 1, 1, 64, False, 1],
+ [1, 64, 64, 2, 2, 3, 3, 2, 2, 1, 1, 64, False, 1],
+ [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 64, False, 1],
+ [1, 1512, 1512, 14, 14, 3, 3, 2, 2, 1, 1, 63, False, 1],
+ [1, 60, 60, 28, 28, 3, 3, 1, 1, 1, 1, 60, False, 1],
+ [1, 1392, 1392, 14, 14, 3, 3, 1, 1, 1, 1, 6, False, 1],
+ [1, 1392, 1392, 28, 28, 3, 3, 2, 2, 1, 1, 6, False, 1],
+ [1, 48, 48, 112, 112, 3, 3, 2, 2, 1, 1, 6, False, 1],
+ [1, 720, 720, 14, 14, 3, 3, 1, 1, 1, 1, 6, False, 1],
+ [1, 720, 720, 28, 28, 3, 3, 2, 2, 1, 1, 6, False, 1],
+ [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 576, False, 1],
+ [1, 576, 576, 14, 14, 3, 3, 2, 2, 1, 1, 576, False, 1],
+ [1, 576, 576, 19, 19, 3, 3, 1, 1, 1, 1, 576, False, 1],
+ [1, 576, 576, 19, 19, 5, 5, 1, 1, 2, 2, 576, False, 1],
+ [1, 576, 576, 7, 7, 5, 5, 1, 1, 2, 2, 576, False, 1],
+ [1, 56, 56, 14, 14, 3, 3, 1, 1, 1, 1, 56, False, 1],
+ [1, 440, 440, 14, 14, 3, 3, 2, 2, 1, 1, 55, False, 1],
+ [1, 440, 440, 7, 7, 3, 3, 1, 1, 1, 1, 55, False, 1],
+ [1, 528, 528, 17, 17, 3, 3, 1, 1, 1, 1, 528, False, 1],
+ [1, 528, 528, 17, 17, 5, 5, 1, 1, 2, 2, 528, False, 1],
+ [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 512, False, 1],
+ [1, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 512, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 1, 1, 2, 2, 512, False, 1],
+ [1, 512, 512, 5, 5, 3, 3, 1, 1, 1, 1, 512, False, 1],
+ [1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, 512, True, 1],
+ [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 5, False, 1],
+ [1, 120, 120, 56, 56, 3, 3, 2, 2, 1, 1, 5, False, 1],
+ [1, 784, 784, 14, 14, 3, 3, 2, 2, 1, 1, 49, False, 1],
+ [1, 784, 784, 7, 7, 3, 3, 1, 1, 1, 1, 49, False, 1],
+ [1, 480, 480, 10, 10, 3, 3, 1, 1, 1, 1, 480, False, 1],
+ [1, 480, 480, 10, 10, 5, 5, 1, 1, 2, 2, 480, False, 1],
+ [1, 480, 480, 14, 14, 3, 3, 1, 1, 1, 1, 480, False, 1],
+ [1, 480, 480, 14, 14, 5, 5, 1, 1, 2, 2, 480, False, 1],
+ [1, 480, 480, 15, 15, 3, 3, 1, 1, 1, 1, 480, False, 1],
+ [1, 480, 480, 15, 15, 5, 5, 1, 1, 2, 2, 480, False, 1],
+ [1, 480, 480, 20, 20, 3, 3, 1, 1, 1, 1, 480, False, 1],
+ [1, 480, 480, 7, 7, 1, 5, 1, 1, 0, 2, 480, False, 1],
+ [1, 480, 480, 7, 7, 3, 3, 1, 1, 1, 1, 480, False, 1],
+ [1, 480, 480, 7, 7, 5, 1, 1, 1, 2, 0, 480, False, 1],
+ [1, 48, 48, 112, 112, 3, 3, 2, 2, 1, 1, 48, False, 1],
+ [1, 672, 672, 14, 14, 3, 3, 2, 2, 1, 1, 42, False, 1],
+ [1, 672, 672, 7, 7, 3, 3, 1, 1, 1, 1, 42, False, 1],
+ [1, 40, 40, 14, 14, 3, 3, 1, 1, 1, 1, 40, False, 1],
+ [1, 40, 40, 28, 28, 3, 3, 2, 2, 1, 1, 40, False, 1],
+ [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 192, 192, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 224, 224, 112, 112, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 448, 448, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 512, 512, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 64, 64, 112, 112, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 64, 64, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 64, 64, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 672, 672, 28, 28, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 672, 672, 56, 56, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 4, False, 1],
+ [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1],
+ [1, 384, 384, 14, 14, 3, 3, 1, 1, 1, 1, 384, False, 1],
+ [1, 912, 912, 14, 14, 3, 3, 2, 2, 1, 1, 38, False, 1],
+ [1, 912, 912, 7, 7, 3, 3, 1, 1, 1, 1, 38, False, 1],
+ [1, 888, 888, 14, 14, 3, 3, 2, 2, 1, 1, 37, False, 1],
+ [1, 888, 888, 7, 7, 3, 3, 1, 1, 1, 1, 37, False, 1],
+ [1, 2016, 2016, 14, 14, 3, 3, 2, 2, 1, 1, 36, False, 1],
+ [1, 36, 36, 56, 56, 3, 3, 1, 1, 1, 1, 36, False, 1],
+ [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 336, False, 1],
+ [1, 336, 336, 49, 49, 3, 3, 2, 2, 0, 0, 336, False, 1],
+ [1, 336, 336, 48, 48, 5, 5, 1, 1, 2, 2, 336, False, 1],
+ [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 1, 1, 32, False, 1],
+ [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 1, 1, 32, False, 1],
+ [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 1, 1, 32, False, 1],
+ [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 256, 256, 56, 56, 3, 3, 2, 2, 1, 1, 32, False, 1],
+ [1, 32, 32, 112, 112, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 32, 32, 120, 120, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 32, 32, 130, 130, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 32, 32, 150, 150, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 32, 32, 190, 190, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 2, 2, 1, 1, 32, False, 1],
+ [1, 512, 512, 56, 56, 3, 3, 2, 2, 1, 1, 32, False, 1],
+ [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 32, False, 1],
+ [1, 72, 72, 112, 112, 3, 3, 2, 2, 1, 1, 3, False, 1],
+ [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 3, False, 1],
+ [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 3, False, 1],
+ [1, 696, 696, 56, 56, 3, 3, 2, 2, 1, 1, 3, False, 1],
+ [1, 288, 288, 14, 14, 5, 5, 2, 2, 2, 2, 288, False, 1],
+ [1, 288, 288, 33, 33, 5, 5, 1, 1, 2, 2, 288, False, 1],
+ [1, 288, 288, 35, 35, 3, 3, 2, 2, 0, 0, 288, False, 1],
+ [1, 288, 288, 38, 38, 5, 5, 1, 1, 2, 2, 288, False, 1],
+ [1, 288, 288, 39, 39, 3, 3, 2, 2, 0, 0, 288, False, 1],
+ [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 1, 1, 28, False, 1],
+ [1, 3024, 3024, 14, 14, 3, 3, 2, 2, 1, 1, 27, False, 1],
+ [1, 208, 208, 14, 14, 3, 3, 1, 1, 1, 1, 26, False, 1],
+ [1, 208, 208, 28, 28, 3, 3, 2, 2, 1, 1, 26, False, 1],
+ [1, 256, 256, 10, 10, 3, 3, 2, 2, 1, 1, 256, False, 1],
+ [1, 256, 256, 2, 2, 3, 3, 1, 1, 1, 1, 256, False, 1],
+ [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 256, False, 1],
+ [1, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, False, 1],
+ [1, 256, 256, 3, 3, 3, 3, 1, 1, 1, 1, 256, False, 1],
+ [1, 256, 256, 38, 38, 3, 3, 1, 1, 1, 1, 256, False, 1],
+ [1, 256, 256, 64, 64, 3, 3, 1, 1, 1, 1, 256, True, 1],
+ [1, 256, 256, 75, 75, 3, 3, 2, 2, 1, 1, 256, False, 1],
+ [1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, 256, True, 1],
+ [1, 256, 256, 75, 75, 3, 3, 1, 1, 1, 1, 256, False, 1],
+ [1, 400, 400, 14, 14, 3, 3, 2, 2, 1, 1, 25, False, 1],
+ [1, 400, 400, 7, 7, 3, 3, 1, 1, 1, 1, 25, False, 1],
+ [1, 240, 240, 14, 14, 1, 5, 1, 1, 0, 2, 240, False, 1],
+ [1, 240, 240, 14, 14, 3, 3, 1, 1, 1, 1, 240, False, 1],
+ [1, 240, 240, 14, 14, 5, 1, 1, 1, 2, 0, 240, False, 1],
+ [1, 240, 240, 14, 14, 5, 5, 1, 1, 2, 2, 240, False, 1],
+ [1, 240, 240, 28, 28, 3, 3, 2, 2, 1, 1, 240, False, 1],
+ [1, 240, 240, 28, 28, 5, 5, 1, 1, 2, 2, 240, False, 1],
+ [1, 240, 240, 29, 29, 3, 3, 2, 2, 0, 0, 240, False, 1],
+ [1, 240, 240, 30, 30, 5, 5, 1, 1, 2, 2, 240, False, 1],
+ [1, 240, 240, 31, 31, 3, 3, 2, 2, 0, 0, 240, False, 1],
+ [1, 240, 240, 40, 40, 3, 3, 2, 2, 1, 1, 240, False, 1],
+ [1, 24, 24, 112, 112, 3, 3, 1, 1, 1, 1, 24, False, 1],
+ [1, 24, 24, 56, 56, 5, 5, 2, 2, 2, 2, 24, False, 1],
+ [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 24, False, 1],
+ [1, 576, 576, 28, 28, 3, 3, 2, 2, 1, 1, 24, False, 1],
+ [1, 224, 224, 7, 7, 3, 3, 1, 1, 1, 1, 224, False, 1],
+ [1, 1008, 1008, 14, 14, 3, 3, 2, 2, 1, 1, 21, False, 1],
+ [1, 1008, 1008, 7, 7, 3, 3, 1, 1, 1, 1, 21, False, 1],
+ [1, 2048, 2048, 15, 20, 3, 3, 1, 1, 1, 1, 2048, True, 1],
+ [1, 200, 200, 14, 14, 3, 3, 1, 1, 1, 1, 200, False, 1],
+ [1, 200, 200, 20, 20, 3, 3, 1, 1, 1, 1, 200, False, 1],
+ [1, 200, 200, 7, 7, 1, 5, 1, 1, 0, 2, 200, False, 1],
+ [1, 200, 200, 7, 7, 5, 1, 1, 1, 2, 0, 200, False, 1],
+ [1, 20, 20, 28, 28, 3, 3, 1, 1, 1, 1, 20, False, 1],
+ [1, 320, 320, 14, 14, 3, 3, 1, 1, 1, 1, 20, False, 1],
+ [1, 320, 320, 28, 28, 3, 3, 2, 2, 1, 1, 20, False, 1],
+ [1, 224, 224, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 240, 240, 28, 28, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 240, 240, 56, 56, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 32, 32, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 48, 48, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 48, 48, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 96, 96, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 96, 96, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 256, 256, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 336, 336, 112, 112, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1],
+ [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 2, False, 1],
+ [1, 192, 192, 14, 14, 3, 3, 1, 1, 1, 1, 192, False, 1],
+ [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 192, False, 1],
+ [1, 192, 192, 28, 28, 3, 3, 2, 2, 1, 1, 192, False, 1],
+ [1, 192, 192, 75, 75, 3, 3, 1, 1, 1, 1, 192, False, 1],
+ [1, 192, 192, 79, 79, 5, 5, 2, 2, 0, 0, 192, False, 1],
+ [1, 192, 192, 95, 95, 3, 3, 1, 1, 1, 1, 192, False, 1],
+ [1, 192, 192, 99, 99, 5, 5, 2, 2, 0, 0, 192, False, 1],
+ [1, 184, 184, 14, 14, 3, 3, 1, 1, 1, 1, 184, False, 1],
+ [1, 184, 184, 20, 20, 3, 3, 1, 1, 1, 1, 184, False, 1],
+ [1, 184, 184, 7, 7, 1, 5, 1, 1, 0, 2, 184, False, 1],
+ [1, 184, 184, 7, 7, 5, 1, 1, 1, 2, 0, 184, False, 1],
+ [1, 288, 288, 14, 14, 3, 3, 1, 1, 1, 1, 18, False, 1],
+ [1, 288, 288, 28, 28, 3, 3, 2, 2, 1, 1, 18, False, 1],
+ [1, 408, 408, 14, 14, 3, 3, 1, 1, 1, 1, 17, False, 1],
+ [1, 408, 408, 28, 28, 3, 3, 2, 2, 1, 1, 17, False, 1],
+ [1, 1632, 1632, 12, 12, 3, 3, 1, 1, 1, 1, 1632, False, 1],
+ [1, 1632, 1632, 12, 12, 5, 5, 1, 1, 2, 2, 1632, False, 1],
+ [1, 160, 160, 28, 28, 3, 3, 1, 1, 1, 1, 160, False, 1],
+ [1, 16, 16, 112, 112, 3, 3, 1, 1, 1, 1, 16, False, 1],
+ [1, 16, 16, 112, 112, 3, 3, 2, 2, 1, 1, 16, False, 1],
+ [1, 16, 16, 160, 160, 3, 3, 1, 1, 1, 1, 16, False, 1],
+ [1, 1920, 1920, 14, 14, 3, 3, 2, 2, 1, 1, 16, False, 1],
+ [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 1, 1, 16, False, 1],
+ [1, 3712, 3712, 14, 14, 3, 3, 2, 2, 1, 1, 16, False, 1],
+ [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 16, False, 1],
+ [1, 896, 896, 28, 28, 3, 3, 2, 2, 1, 1, 16, False, 1],
+ [1, 1536, 1536, 10, 10, 3, 3, 1, 1, 1, 1, 1536, False, 1],
+ [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 1, 1, 15, False, 1],
+ [1, 144, 144, 14, 14, 5, 5, 1, 1, 2, 2, 144, False, 1],
+ [1, 144, 144, 151, 151, 3, 3, 2, 2, 0, 0, 144, False, 1],
+ [1, 144, 144, 191, 191, 3, 3, 2, 2, 0, 0, 144, False, 1],
+ [1, 144, 144, 56, 56, 3, 3, 1, 1, 1, 1, 144, False, 1],
+ [1, 144, 144, 56, 56, 3, 3, 2, 2, 1, 1, 144, False, 1],
+ [1, 144, 144, 59, 59, 5, 5, 2, 2, 0, 0, 144, False, 1],
+ [1, 144, 144, 60, 60, 3, 3, 1, 1, 1, 1, 144, False, 1],
+ [1, 144, 144, 63, 63, 5, 5, 2, 2, 0, 0, 144, False, 1],
+ [1, 144, 144, 65, 65, 3, 3, 1, 1, 1, 1, 144, False, 1],
+ [1, 144, 144, 69, 69, 5, 5, 2, 2, 0, 0, 144, False, 1],
+ [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 14, False, 1],
+ [1, 336, 336, 28, 28, 3, 3, 2, 2, 1, 1, 14, False, 1],
+ [1, 1392, 1392, 10, 10, 3, 3, 1, 1, 1, 1, 1392, False, 1],
+ [1, 1392, 1392, 10, 10, 5, 5, 1, 1, 2, 2, 1392, False, 1],
+ [1, 104, 104, 28, 28, 3, 3, 1, 1, 1, 1, 13, False, 1],
+ [1, 104, 104, 56, 56, 3, 3, 2, 2, 1, 1, 13, False, 1],
+ [1, 1280, 1280, 30, 40, 3, 3, 1, 1, 1, 1, 1280, True, 1],
+ [1, 128, 128, 1, 1, 3, 3, 1, 1, 1, 1, 128, False, 1],
+ [1, 128, 128, 128, 128, 3, 3, 1, 1, 1, 1, 128, True, 1],
+ [1, 128, 128, 150, 150, 3, 3, 1, 1, 1, 1, 128, False, 1],
+ [1, 128, 128, 150, 150, 3, 3, 2, 2, 1, 1, 128, False, 1],
+ [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 128, False, 1],
+ [1, 128, 128, 3, 3, 3, 3, 2, 2, 1, 1, 128, False, 1],
+ [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 128, False, 1],
+ [1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, False, 1],
+ [1, 128, 128, 75, 75, 3, 3, 1, 1, 1, 1, 128, False, 1],
+ [1, 128, 128, 5, 5, 3, 3, 2, 2, 1, 1, 128, False, 1],
+ [1, 1248, 1248, 9, 9, 3, 3, 1, 1, 1, 1, 1248, False, 1],
+ [1, 1248, 1248, 9, 9, 5, 5, 1, 1, 2, 2, 1248, False, 1],
+ [1, 120, 120, 14, 14, 1, 5, 1, 1, 0, 2, 120, False, 1],
+ [1, 120, 120, 14, 14, 5, 1, 1, 1, 2, 0, 120, False, 1],
+ [1, 120, 120, 14, 14, 5, 5, 1, 1, 2, 2, 120, False, 1],
+ [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 120, False, 1],
+ [1, 120, 120, 28, 28, 5, 5, 1, 1, 2, 2, 120, False, 1],
+ [1, 120, 120, 40, 40, 5, 5, 1, 1, 2, 2, 120, False, 1],
+ [1, 12, 12, 56, 56, 3, 3, 1, 1, 1, 1, 12, False, 1],
+ [1, 1152, 1152, 7, 7, 3, 3, 1, 1, 1, 1, 1152, False, 1],
+ [1, 1152, 1152, 7, 7, 5, 5, 1, 1, 2, 2, 1152, False, 1],
+ [1, 1152, 1152, 8, 8, 3, 3, 1, 1, 1, 1, 1152, False, 1],
+ [1, 1152, 1152, 8, 8, 5, 5, 1, 1, 2, 2, 1152, False, 1],
+ [1, 112, 112, 14, 14, 5, 5, 2, 2, 2, 2, 112, False, 1],
+ [1, 1232, 1232, 14, 14, 3, 3, 1, 1, 1, 1, 11, False, 1],
+ [1, 1232, 1232, 28, 28, 3, 3, 2, 2, 1, 1, 11, False, 1],
+ [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 11, False, 1],
+ [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1],
+ [1, 1024, 1024, 10, 10, 3, 3, 1, 1, 1, 1, 1024, False, 1],
+ [1, 1024, 1024, 16, 16, 3, 3, 1, 1, 1, 1, 1024, True, 1],
+ [1, 1024, 1024, 19, 19, 3, 3, 2, 2, 1, 1, 1024, False, 1],
+ [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1024, False, 1],
+ [1, 100, 100, 14, 14, 3, 3, 1, 1, 1, 1, 100, False, 1],
+ [1, 160, 160, 14, 14, 3, 3, 1, 1, 1, 1, 10, False, 1],
+ [1, 160, 160, 28, 28, 3, 3, 2, 2, 1, 1, 10, False, 1],
[1, 16, 1, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 32, 1, 28, 28, 3, 3, 1, 1, 1, 1, 0, True, 1],
- [1, 100, 100, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1008, 1008, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 1008, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1008, 1008, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1008, 1008, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 32, 1, 28, 28, 3, 3, 1, 1, 0, 0, 1, True, 1],
+ [1, 192, 1008, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1008, 1008, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 40, 102, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1536, 1024, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1024, 1024, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1536, 1024, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 1024, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 255, 1024, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 1024, 16, 16, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 1024, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 1024, 19, 19, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 1024, 19, 19, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1024, 1024, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 128, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 255, 1024, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 1024, 16, 16, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 384, 1024, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 1024, 19, 19, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 24, 1024, 19, 19, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 1024, 19, 19, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 256, 1024, 19, 19, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 546, 1024, 19, 19, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 1024, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 1024, 45, 80, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 1024, 45, 80, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 1024, 50, 68, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 1024, 50, 68, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 1024, 50, 68, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 1024, 45, 80, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 1024, 45, 80, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 1024, 50, 68, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 1024, 50, 68, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 1024, 50, 68, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 1024, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 12, 104, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 26, 104, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 104, 104, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 104, 104, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 208, 104, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 208, 104, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 104, 104, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 132, 1056, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 264, 1056, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 1056, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1056, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1056, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1056, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 12, 104, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 26, 104, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 104, 104, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 208, 104, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 208, 104, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 132, 1056, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 264, 1056, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 1056, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1056, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1056, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1056, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 462, 1072, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 1088, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 768, 1088, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1088, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 440, 110, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 192, 1104, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1104, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1232, 112, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 448, 112, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 896, 112, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 112, 112, 14, 14, 5, 5, 2, 2, 2, 2, 2, False, 1],
+ [1, 128, 1088, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 768, 1088, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1088, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 440, 110, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 192, 1104, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1104, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1232, 112, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 448, 112, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 896, 112, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 224, 112, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 336, 112, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 112, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 112, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 112, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 112, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 112, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 112, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1120, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1120, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1152, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1152, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1152, 1152, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1152, 1152, 7, 7, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 128, 1152, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1152, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 1152, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1152, 1152, 8, 8, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1152, 1152, 8, 8, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 192, 1152, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 1152, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 336, 112, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 112, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 112, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 112, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 112, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 112, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 112, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1120, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1120, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1152, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1152, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1152, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1152, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 1152, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1152, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 1152, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 40, 116, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 34, 118, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 1184, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1184, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 104, 12, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 120, 12, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 48, 12, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 12, 12, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 12, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 30, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 32, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 480, 120, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 120, 120, 14, 14, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 120, 14, 14, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 120, 120, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 48, 120, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 720, 120, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 120, 120, 28, 28, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 120, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 120, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 20, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 120, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 40, 120, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 120, 40, 40, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 40, 120, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 120, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 1200, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1200, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1216, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1216, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 1184, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1184, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 104, 12, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 120, 12, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 48, 12, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 12, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 30, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 32, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 480, 120, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 48, 120, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 720, 120, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 20, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 336, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 336, 120, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 40, 120, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 120, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1200, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1200, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1216, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1216, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 46, 122, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 112, 1232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 308, 1232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1232, 1232, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1232, 1232, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1232, 1232, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 124, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1248, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1248, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1248, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1248, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1248, 1248, 9, 9, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1248, 1248, 9, 9, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 208, 1248, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 352, 1248, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 128, 1, 1, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 24, 128, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 546, 128, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 128, 10, 10, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 112, 1232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 308, 1232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1232, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 124, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1248, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1248, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1248, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1248, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 208, 1248, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 352, 1248, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 128, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 546, 128, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 128, 10, 10, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 128, 128, 100, 136, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 128, 100, 136, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 512, 128, 100, 136, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 128, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 128, 128, 112, 112, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 64, 128, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 128, 128, 128, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 128, 128, 128, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 64, 128, 128, 128, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 128, 128, 128, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 64, 128, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 64, 128, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 128, 128, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 128, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 128, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 128, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 32, 128, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 128, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 128, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 150, 150, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 128, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 512, 128, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 128, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 128, 150, 150, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 128, 180, 320, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 128, 2, 2, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 128, 200, 272, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 128, 128, 180, 320, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 256, 128, 2, 2, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 128, 200, 272, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 64, 128, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 128, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 2, True, 1],
- [1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 128, 128, 28, 28, 3, 3, 1, 1, 2, 2, 1, True, 1],
[1, 16, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 19, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 19, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 192, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 128, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 256, 128, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 288, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 128, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 256, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 128, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 256, 128, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 288, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 288, 128, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 32, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 38, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 128, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 128, 3, 3, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 128, 3, 3, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 128, 3, 3, 3, 3, 1, 1, 1, 1, 0, True, 1],
+ [1, 38, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 128, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 128, 3, 3, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 128, 3, 3, 3, 3, 1, 1, 0, 0, 1, True, 1],
[1, 64, 128, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 256, 128, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 128, 5, 5, 3, 3, 1, 1, 1, 1, 0, True, 1],
+ [1, 256, 128, 5, 5, 3, 3, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 128, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 128, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 128, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 256, 128, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 128, 56, 56, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 256, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 32, 128, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 128, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 128, 60, 80, 4, 4, 4, 4, 4, 4, 0, True, 1],
- [1, 320, 128, 60, 80, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 64, 128, 60, 80, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 64, 128, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 128, 60, 80, 4, 4, 4, 4, 0, 0, 1, True, 1],
+ [1, 320, 128, 60, 80, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 64, 128, 60, 80, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 64, 128, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 128, 128, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 64, 64, 2, 2, 2, 2, 2, 2, 0, True, 1],
+ [1, 128, 128, 64, 64, 2, 2, 2, 2, 0, 0, 1, True, 1],
[1, 256, 128, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 32, 128, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 75, 75, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 128, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 128, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 128, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 128, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 128, 75, 75, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 128, 128, 90, 160, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 128, 90, 160, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1280, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 640, 1280, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1280, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 1280, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1296, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1296, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1312, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1312, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1056, 132, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 528, 132, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1344, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1344, 1344, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1344, 1344, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 1344, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1344, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 816, 136, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1376, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1376, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 174, 1392, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 348, 1392, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1392, 1392, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1392, 1392, 10, 10, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 232, 1392, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 1392, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1392, 1392, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1392, 1392, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 1392, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1392, 1392, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 1392, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1408, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1408, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 512, 128, 90, 160, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1280, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 640, 1280, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1280, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 1280, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1296, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1296, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1312, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1312, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1056, 132, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 528, 132, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1344, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1344, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1344, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 816, 136, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1376, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1376, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 174, 1392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 348, 1392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 232, 1392, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 384, 1392, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1392, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1392, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1408, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1408, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 68, 142, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1512, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 16, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 36, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 40, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 576, 144, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 144, 144, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1],
+ [1, 1512, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 16, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 36, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 40, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 576, 144, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 288, 144, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 48, 144, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 144, 151, 151, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 144, 144, 191, 191, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 144, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 144, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 48, 144, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 28, 144, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 144, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 40, 144, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 40, 144, 30, 30, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 48, 144, 33, 33, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 144, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 144, 144, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 144, 144, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 144, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 144, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 144, 59, 59, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 144, 144, 60, 60, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 24, 144, 60, 60, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 144, 63, 63, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 144, 144, 65, 65, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 24, 144, 65, 65, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 144, 69, 69, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 1024, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 32, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 144, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 40, 144, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 144, 30, 30, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 48, 144, 33, 33, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 144, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 144, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 144, 60, 60, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 144, 65, 65, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 144, 144, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 18, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 36, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 144, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 144, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 144, 95, 95, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1440, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1440, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1472, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1472, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1488, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1488, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1504, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1504, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 1512, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1512, 1512, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 18, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 36, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 144, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 144, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 144, 95, 95, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1440, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1440, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1472, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1472, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1488, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1488, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1504, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1504, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 1512, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 58, 152, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1536, 1536, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 1536, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1536, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1536, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1536, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 1536, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 1536, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 1536, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1536, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1536, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1536, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 1536, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 384, 1536, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 68, 156, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 1568, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1568, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1584, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1584, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 16, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 16, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 16, 16, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 16, 16, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 16, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 8, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 16, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 16, 120, 120, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 16, 130, 130, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 16, 16, 14, 14, 2, 2, 2, 2, 2, 2, 0, True, 1],
+ [1, 128, 1568, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1568, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1584, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1584, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 16, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 16, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 16, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 8, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 16, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 16, 120, 120, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 16, 130, 130, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 16, 16, 14, 14, 2, 2, 2, 2, 0, 0, 1, True, 1],
[1, 4, 16, 14, 14, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 48, 16, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 16, 16, 160, 160, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 16, 16, 160, 160, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 16, 160, 160, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 16, 16, 160, 160, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 16, 160, 160, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 16, 16, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 16, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 32, 16, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 32, 16, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 16, 16, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 16, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 16, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 160, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 160, 160, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 16, 16, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 16, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 16, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 160, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 320, 160, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 400, 160, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 400, 160, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 960, 160, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 160, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 160, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 160, 160, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 160, 160, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 960, 160, 3, 3, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 160, 32, 32, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 256, 160, 32, 32, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 128, 160, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 400, 160, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 400, 160, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 960, 160, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 160, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 160, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 960, 160, 3, 3, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 160, 32, 32, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 256, 160, 32, 32, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 128, 160, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 320, 160, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 480, 160, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 960, 160, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 160, 73, 73, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1600, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1600, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1632, 1632, 12, 12, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1632, 1632, 12, 12, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 272, 1632, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1632, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1632, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1632, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1632, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1664, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1664, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 168, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 168, 168, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 168, 168, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 408, 168, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 408, 168, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 168, 168, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 1680, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1680, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1696, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1696, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 480, 160, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 960, 160, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 160, 73, 73, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1600, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1600, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 272, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1632, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1632, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1632, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1632, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1664, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1664, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 168, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 168, 168, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 408, 168, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 408, 168, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 192, 1680, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1680, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1696, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1696, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 46, 172, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 1728, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1728, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1728, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1728, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1392, 174, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 696, 174, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 1760, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1760, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1776, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1776, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 896, 1792, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1792, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 216, 18, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 72, 18, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 144, 18, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 18, 18, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 72, 18, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 18, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 1728, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1728, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1728, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1728, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1392, 174, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 696, 174, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 1760, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1760, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1776, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1776, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 896, 1792, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1792, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 216, 18, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 72, 18, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 144, 18, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 18, 18, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 72, 18, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 128, 18, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 18, 18, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 18, 18, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 32, 18, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 36, 18, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 1824, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1824, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1824, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 184, 184, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 40, 184, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 184, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 184, 184, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 80, 184, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 184, 184, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 184, 184, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 128, 185, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 1856, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1872, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1872, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 1888, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 192, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 18, 18, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 32, 18, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 36, 18, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 192, 1824, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1824, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1824, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 184, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 184, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 184, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 185, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 1856, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1872, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1872, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 1888, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 192, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 48, 192, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 192, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 17, 17, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 192, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1],
- [1, 224, 192, 17, 17, 1, 7, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 16, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 192, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 432, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 432, 192, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 64, 192, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 192, 17, 17, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 192, 192, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1],
+ [1, 224, 192, 17, 17, 1, 7, 1, 1, 0, 3, 1, False, 1],
+ [1, 128, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 16, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 432, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 432, 192, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 48, 192, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 192, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 64, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 192, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 224, 192, 35, 35, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 48, 192, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 56, 192, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 192, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 48, 192, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 56, 192, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 192, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 192, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 48, 192, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1152, 192, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1152, 192, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 384, 192, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 48, 192, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 192, 71, 71, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 192, 75, 75, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 192, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 79, 79, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 1152, 192, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 95, 95, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 192, 95, 95, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 192, 99, 99, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 1920, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1920, 1920, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 1920, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 784, 196, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 192, 192, 71, 71, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 32, 192, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1152, 192, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 192, 95, 95, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1920, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1920, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 784, 196, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 40, 196, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 1968, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 1968, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 20, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 20, 20, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 200, 200, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 40, 200, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 200, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 200, 200, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 80, 200, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 200, 200, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 200, 200, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 224, 2016, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 192, 2016, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2016, 2016, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 2016, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 2048, 2048, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 2048, 23, 40, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 2048, 23, 40, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 2048, 25, 34, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 2048, 25, 34, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 512, 2048, 25, 34, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 2048, 2048, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 2048, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 2064, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 2064, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 26, 208, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 52, 208, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 208, 208, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 208, 208, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 440, 208, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 440, 208, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 208, 208, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1248, 208, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 2112, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 18, 216, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 54, 216, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 216, 216, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 216, 216, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 576, 216, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 216, 216, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 2160, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 192, 1968, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 1968, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 20, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 40, 200, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 200, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 200, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 224, 2016, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 192, 2016, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 2016, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 2048, 23, 40, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 2048, 23, 40, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 2048, 25, 34, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 2048, 25, 34, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 512, 2048, 25, 34, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 2048, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 2064, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 2064, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 26, 208, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 52, 208, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 208, 208, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 440, 208, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 440, 208, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1248, 208, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 2112, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 18, 216, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 54, 216, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 216, 216, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 576, 216, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 2160, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 78, 218, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 888, 222, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 2016, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 56, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 896, 224, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 224, 224, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 224, 224, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 224, 224, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1],
- [1, 256, 224, 17, 17, 1, 7, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 224, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1],
- [1, 128, 224, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 224, 35, 35, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 224, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 224, 224, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 224, 224, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 448, 224, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 448, 224, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 224, 224, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 224, 224, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 58, 232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 232, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1392, 232, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 232, 232, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 696, 232, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 696, 232, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 888, 222, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 2016, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 56, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 896, 224, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 224, 224, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1],
+ [1, 256, 224, 17, 17, 1, 7, 1, 1, 0, 3, 1, False, 1],
+ [1, 256, 224, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1],
+ [1, 128, 224, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 224, 35, 35, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 224, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 224, 224, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 448, 224, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 448, 224, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 224, 224, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 58, 232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 232, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1392, 232, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 232, 232, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 696, 232, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 696, 232, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 68, 236, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 72, 24, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 96, 24, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 24, 24, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 72, 24, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 96, 24, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 64, 24, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 144, 24, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 40, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 88, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 24, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 144, 24, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 88, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 24, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 14, 24, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 144, 24, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 24, 56, 56, 5, 5, 2, 2, 2, 2, 2, False, 1],
- [1, 36, 24, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 24, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 24, 60, 60, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 24, 65, 65, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 24, 80, 80, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 240, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 960, 240, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 240, 240, 14, 14, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 240, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 240, 240, 14, 14, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 240, 240, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 40, 240, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 240, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 240, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 240, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 240, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 240, 240, 28, 28, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 240, 240, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 240, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 40, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 720, 240, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 240, 29, 29, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 240, 240, 30, 30, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 40, 240, 30, 30, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 240, 31, 31, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 240, 240, 40, 40, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 192, 240, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 240, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 144, 24, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 36, 24, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 24, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 24, 60, 60, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 24, 65, 65, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 24, 80, 80, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 240, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 960, 240, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 40, 240, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 240, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 240, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 240, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 720, 240, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 240, 30, 30, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 240, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 16, 256, 1, 1, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 256, 256, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 364, 256, 1, 1, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 10, 10, 3, 3, 2, 2, 2, 2, 1, False, 1],
[1, 256, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 100, 136, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 256, 256, 100, 136, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 256, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 36, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 128, 256, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 150, 256, 128, 128, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 150, 256, 128, 128, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 256, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 13, 17, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 256, 256, 13, 17, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 36, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 819, 256, 13, 17, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1024, 256, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 256, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1024, 256, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 256, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 256, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 256, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 512, 256, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 512, 256, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 256, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 512, 256, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 512, 256, 16, 16, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 17, 17, 1, 7, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 256, 17, 17, 7, 1, 1, 1, 1, 1, 3, False, 1],
- [1, 128, 256, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 256, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 256, 19, 19, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 24, 256, 2, 2, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 256, 2, 2, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 546, 256, 2, 2, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 64, 256, 2, 2, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 256, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 256, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 256, 17, 17, 1, 7, 1, 1, 0, 3, 1, False, 1],
+ [1, 320, 256, 17, 17, 7, 1, 1, 1, 3, 0, 1, False, 1],
+ [1, 128, 256, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 256, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 256, 19, 19, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 24, 256, 2, 2, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 546, 256, 2, 2, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 64, 256, 2, 2, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 256, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 256, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 25, 34, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 256, 256, 25, 34, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 36, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 20, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 256, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 32, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 512, 256, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 512, 256, 28, 28, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 64, 256, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 512, 256, 28, 28, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 64, 256, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 16, 256, 3, 3, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 24, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 256, 3, 3, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 24, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 364, 256, 3, 3, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 546, 256, 3, 3, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 256, 32, 32, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 256, 32, 32, 2, 2, 2, 2, 2, 2, 0, True, 1],
+ [1, 546, 256, 3, 3, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 256, 32, 32, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 256, 32, 32, 2, 2, 2, 2, 0, 0, 1, True, 1],
[1, 256, 256, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 512, 256, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 38, 38, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 256, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 256, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 728, 256, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 256, 45, 80, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 728, 256, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 256, 45, 80, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 45, 80, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 256, 5, 5, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 128, 256, 5, 5, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 24, 256, 5, 5, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 512, 256, 5, 5, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 512, 256, 5, 5, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 546, 256, 5, 5, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1024, 256, 50, 68, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1024, 256, 50, 68, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 36, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 18, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 56, 56, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 256, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 256, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 256, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 36, 256, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 512, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 256, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 64, 256, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 256, 64, 64, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 256, 56, 56, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 256, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 256, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 36, 256, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 512, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 64, 256, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 256, 64, 64, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 256, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 255, 256, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 256, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 512, 256, 64, 64, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 256, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 255, 256, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 256, 64, 64, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 1024, 256, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 256, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 512, 256, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 36, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 819, 256, 7, 9, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 75, 75, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 256, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 256, 75, 75, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 256, 256, 90, 160, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 104, 26, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 208, 26, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 262, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1056, 264, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1632, 272, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 256, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 256, 75, 75, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 256, 256, 90, 160, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 104, 26, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 208, 26, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 262, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1056, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1632, 272, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 160, 272, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 34, 276, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 16, 28, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 72, 288, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 288, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 288, 14, 14, 5, 5, 2, 2, 2, 2, 2, False, 1],
- [1, 288, 288, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 288, 288, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 288, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 288, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 88, 288, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 288, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 288, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 288, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 288, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 288, 288, 33, 33, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 48, 288, 33, 33, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 288, 35, 35, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 288, 288, 38, 38, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 48, 288, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 288, 39, 39, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 288, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 288, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 72, 288, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 288, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 288, 288, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 288, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 288, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 88, 288, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 288, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 288, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 288, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 48, 288, 33, 33, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 48, 288, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 288, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 288, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 134, 296, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 3, 224, 224, 4, 4, 4, 4, 4, 4, 0, True, 1],
- [1, 16, 3, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 16, 3, 224, 224, 7, 7, 1, 1, 1, 1, 3, False, 1],
- [1, 32, 3, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 128, 3, 224, 224, 4, 4, 4, 4, 0, 0, 1, True, 1],
+ [1, 16, 3, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 16, 3, 224, 224, 7, 7, 1, 1, 3, 3, 1, False, 1],
+ [1, 32, 3, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 64, 3, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 3, 224, 224, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 64, 3, 224, 224, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 64, 3, 224, 224, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 64, 3, 224, 224, 7, 7, 2, 2, 2, 2, 3, False, 1],
- [1, 96, 3, 224, 224, 4, 4, 4, 4, 4, 4, 0, True, 1],
- [1, 96, 3, 224, 224, 7, 7, 2, 2, 2, 2, 3, False, 1],
- [1, 32, 3, 225, 225, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 32, 3, 241, 241, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 3, 256, 256, 4, 4, 4, 4, 4, 4, 0, True, 1],
+ [1, 64, 3, 224, 224, 7, 7, 2, 2, 3, 3, 1, False, 1],
+ [1, 96, 3, 224, 224, 4, 4, 4, 4, 0, 0, 1, True, 1],
+ [1, 96, 3, 224, 224, 7, 7, 2, 2, 3, 3, 1, False, 1],
+ [1, 32, 3, 225, 225, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 32, 3, 241, 241, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 3, 256, 256, 4, 4, 4, 4, 0, 0, 1, True, 1],
[1, 32, 3, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 96, 3, 256, 256, 4, 4, 4, 4, 4, 4, 0, True, 1],
- [1, 32, 3, 261, 261, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 32, 3, 299, 299, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 96, 3, 256, 256, 4, 4, 4, 4, 0, 0, 1, True, 1],
+ [1, 32, 3, 261, 261, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 32, 3, 299, 299, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 64, 3, 300, 300, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 32, 3, 301, 301, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 16, 3, 320, 320, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 32, 3, 384, 384, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 120, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 336, 30, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 3024, 3024, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 32, 3, 301, 301, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 16, 3, 320, 320, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 32, 3, 384, 384, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 120, 30, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 336, 30, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 116, 304, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1232, 308, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1232, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 58, 310, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 120, 32, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 16, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 224, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 224, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 232, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 232, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 256, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 32, 32, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 32, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 32, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 336, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 48, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 48, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 64, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 120, 32, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 16, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 224, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 224, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 232, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 232, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 256, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 32, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 336, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 336, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 48, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 48, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 64, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 64, 32, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 32, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 72, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 80, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 96, 32, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 32, 112, 112, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 16, 32, 120, 120, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 32, 120, 120, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 64, 32, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 72, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 80, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 96, 32, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 32, 112, 112, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 16, 32, 120, 120, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 2, 32, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 32, 32, 128, 128, 8, 8, 8, 8, 8, 8, 0, True, 1],
+ [1, 32, 32, 128, 128, 8, 8, 8, 8, 0, 0, 1, True, 1],
[1, 64, 32, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 32, 128, 128, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 16, 32, 130, 130, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 32, 130, 130, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 64, 32, 128, 128, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 16, 32, 130, 130, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 32, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 64, 32, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 64, 32, 147, 147, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 32, 149, 149, 3, 3, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 32, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 32, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 32, 32, 149, 149, 3, 3, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 32, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 64, 32, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 24, 32, 190, 190, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 32, 190, 190, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 528, 32, 192, 192, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 1, 32, 256, 256, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 24, 32, 190, 190, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 528, 32, 192, 192, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1, 32, 256, 256, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 64, 32, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 32, 26, 26, 3, 3, 1, 1, 1, 1, 0, True, 1],
- [1, 192, 32, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 64, 32, 26, 26, 3, 3, 1, 1, 0, 0, 1, True, 1],
+ [1, 192, 32, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 96, 32, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 2, 32, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 32, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 32, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 32, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 32, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 32, 32, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 32, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 32, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 64, 32, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 32, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 2, 32, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 128, 32, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 32, 75, 75, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 32, 95, 95, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 36, 320, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 80, 320, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 320, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 320, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 320, 320, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 192, 32, 75, 75, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 32, 95, 95, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 36, 320, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 80, 320, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 320, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 320, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 40, 320, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 784, 320, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 784, 320, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 320, 320, 17, 17, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 320, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 320, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 320, 320, 30, 40, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 512, 320, 30, 40, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 64, 320, 30, 40, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1280, 320, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1280, 320, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 328, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 30, 336, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 84, 336, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 336, 336, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 336, 336, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 888, 336, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 336, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 336, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 336, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 56, 336, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 336, 49, 49, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 336, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 336, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 336, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 336, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 784, 320, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 784, 320, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 320, 320, 17, 17, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 320, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 320, 30, 40, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 512, 320, 30, 40, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 64, 320, 30, 40, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1280, 320, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1280, 320, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 328, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 30, 336, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 84, 336, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 336, 336, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 888, 336, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 336, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 336, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 56, 336, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 336, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 336, 336, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 336, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 336, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 20, 34, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1392, 348, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 352, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 352, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1280, 352, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 36, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 320, 36, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 144, 36, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 18, 36, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 36, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1392, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 352, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 352, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1280, 352, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 36, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 320, 36, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 144, 36, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 18, 36, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 36, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 36, 36, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 36, 36, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 64, 36, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 36, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 36, 36, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 36, 36, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 64, 36, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 36, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 68, 360, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 98, 368, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 3712, 3712, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1280, 384, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 384, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 384, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 384, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 384, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 384, 35, 35, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 384, 35, 35, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 64, 384, 35, 35, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 384, 35, 35, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 384, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 384, 64, 64, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 384, 8, 8, 1, 3, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 384, 8, 8, 3, 1, 1, 1, 1, 1, 1, False, 1],
- [1, 448, 384, 8, 8, 3, 1, 1, 1, 1, 1, 1, False, 1],
- [1, 4, 4, 7, 7, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 144, 40, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 120, 40, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 40, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 40, 40, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 80, 40, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 40, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 40, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 40, 40, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 60, 40, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 40, 30, 30, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 40, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 40, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1280, 384, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 384, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 384, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 384, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 384, 35, 35, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 384, 384, 35, 35, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 64, 384, 35, 35, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 384, 35, 35, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 384, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 384, 64, 64, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 384, 8, 8, 1, 3, 1, 1, 0, 1, 1, False, 1],
+ [1, 256, 384, 8, 8, 3, 1, 1, 1, 1, 0, 1, False, 1],
+ [1, 448, 384, 8, 8, 3, 1, 1, 1, 1, 0, 1, False, 1],
+ [1, 4, 4, 7, 7, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 144, 40, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 120, 40, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 40, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 40, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 40, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 40, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 60, 40, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 40, 30, 30, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 40, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 40, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 14, 40, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 400, 400, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 400, 400, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 400, 400, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 408, 408, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 408, 408, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 912, 408, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 408, 408, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 416, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 416, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 400, 400, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 408, 408, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 912, 408, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 416, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 416, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 116, 428, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1008, 432, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 432, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 432, 432, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 432, 432, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 432, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 432, 432, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 110, 440, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 52, 440, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 440, 440, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 440, 440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 440, 440, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 112, 448, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 56, 448, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1280, 448, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 448, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1232, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1232, 448, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 448, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 448, 448, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 896, 448, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 896, 448, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 256, 448, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 448, 448, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 448, 448, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 512, 448, 8, 8, 1, 3, 1, 1, 1, 1, 0, False, 1],
+ [1, 1008, 432, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 432, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 432, 432, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 432, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 110, 440, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 52, 440, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 440, 440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 448, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 56, 448, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1280, 448, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 448, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1232, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1232, 448, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 448, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 896, 448, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 896, 448, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 256, 448, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 448, 8, 8, 1, 3, 1, 1, 0, 1, 1, False, 1],
[1, 16, 46, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 168, 466, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 12, 48, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 48, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 48, 48, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 48, 48, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 48, 48, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 144, 48, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 48, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 48, 33, 33, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 288, 48, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 104, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 104, 48, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 12, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 48, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 48, 48, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 48, 48, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 12, 48, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 48, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 144, 48, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 288, 48, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 288, 48, 33, 33, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 288, 48, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 104, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 104, 48, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 12, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 48, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 48, 48, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 48, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 120, 480, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 24, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 480, 10, 10, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 480, 480, 10, 10, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 546, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 80, 480, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 16, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 480, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 480, 480, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 56, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 480, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 480, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 480, 15, 15, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 480, 480, 15, 15, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 80, 480, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 480, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 480, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 480, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 480, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 480, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 480, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 480, 480, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 1000, 512, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 512, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 512, 10, 10, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 120, 480, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 24, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 546, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 80, 480, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 16, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 56, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 480, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 480, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 480, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 480, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 480, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 480, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1000, 512, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 512, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 512, 10, 10, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 24, 512, 10, 10, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 546, 512, 10, 10, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1024, 512, 100, 136, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 1024, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1024, 512, 100, 136, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1024, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 1024, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 512, 14, 14, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 112, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 1024, 512, 14, 14, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 112, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 512, 512, 14, 14, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 64, 512, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 512, 15, 20, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 64, 512, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 512, 15, 20, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1024, 512, 16, 16, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 512, 16, 16, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 512, 16, 16, 2, 2, 2, 2, 2, 2, 0, True, 1],
+ [1, 256, 512, 16, 16, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 512, 16, 16, 2, 2, 2, 2, 0, 0, 1, True, 1],
[1, 512, 512, 16, 16, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 512, 19, 19, 3, 3, 1, 1, 1, 1, 6, True, 1],
+ [1, 1024, 512, 19, 19, 3, 3, 1, 1, 6, 6, 1, True, 1],
[1, 512, 512, 19, 19, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 2048, 512, 23, 40, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 2048, 512, 23, 40, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 512, 23, 40, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 2048, 512, 25, 34, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 2048, 512, 25, 34, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 512, 25, 34, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1024, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 512, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 19, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 38, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 2, False, 1],
- [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
+ [1, 1024, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 19, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 38, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 512, 28, 28, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 512, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 512, 512, 28, 28, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 512, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 512, 512, 28, 28, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 512, 512, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 896, 512, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 896, 512, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 1024, 512, 32, 32, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 255, 512, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 512, 32, 32, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 896, 512, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 896, 512, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1024, 512, 32, 32, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 255, 512, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 512, 32, 32, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 256, 512, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 16, 512, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 512, 512, 45, 80, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 512, 5, 5, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 512, 5, 5, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 512, 5, 5, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 546, 512, 5, 5, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 512, 512, 50, 68, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 512, 512, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 512, 512, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 512, 512, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 512, 512, 45, 80, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 128, 512, 5, 5, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 512, 5, 5, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 546, 512, 5, 5, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 512, 512, 50, 68, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 1024, 512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2048, 512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 512, 8, 8, 1, 3, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 512, 8, 8, 3, 1, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 512, 90, 160, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 512, 90, 160, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 208, 52, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 440, 52, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 132, 528, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 528, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 528, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 528, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 528, 528, 17, 17, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 88, 528, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 528, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 528, 528, 96, 96, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 216, 54, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 576, 54, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 256, 512, 8, 8, 1, 3, 1, 1, 0, 1, 1, False, 1],
+ [1, 256, 512, 8, 8, 3, 1, 1, 1, 1, 0, 1, False, 1],
+ [1, 128, 512, 90, 160, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 512, 90, 160, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 208, 52, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 440, 52, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 132, 528, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 528, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 528, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 528, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 88, 528, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 528, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 528, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 216, 54, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 576, 54, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 24, 54, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 544, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 544, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 196, 544, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 544, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 224, 56, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 448, 56, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 56, 56, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 336, 56, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 576, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 54, 576, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1512, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1512, 576, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 576, 576, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 576, 576, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 576, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 576, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 136, 576, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 576, 19, 19, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 576, 576, 19, 19, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 96, 576, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 576, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 576, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 576, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 576, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 576, 7, 7, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 96, 576, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 232, 58, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 696, 58, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 128, 544, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 224, 56, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 448, 56, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 336, 56, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 576, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 54, 576, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1512, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1512, 576, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 192, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 576, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 576, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 136, 576, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 576, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 576, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 576, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 576, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 576, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 232, 58, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 696, 58, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 20, 58, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 60, 60, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 608, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 608, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 608, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 608, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 28, 62, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 192, 624, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 624, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 64, 1, 1, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 64, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 64, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 192, 624, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 624, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 64, 1, 1, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 64, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 64, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 128, 64, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 128, 64, 112, 112, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 64, 64, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 64, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 64, 64, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 64, 64, 112, 112, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 64, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 64, 120, 160, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 64, 64, 112, 112, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 64, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 128, 64, 120, 160, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 32, 64, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 64, 64, 120, 160, 8, 8, 8, 8, 8, 8, 0, True, 1],
+ [1, 64, 64, 120, 160, 8, 8, 8, 8, 0, 0, 1, True, 1],
[1, 128, 64, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 64, 64, 128, 128, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 384, 64, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 64, 147, 147, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 64, 150, 150, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 64, 150, 150, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 384, 64, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 64, 147, 147, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 64, 150, 150, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 64, 150, 150, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 128, 64, 150, 150, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 64, 64, 150, 150, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 64, 160, 160, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 64, 64, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 64, 64, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 64, 64, 180, 320, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 64, 2, 2, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 64, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 64, 200, 272, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 256, 64, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 64, 200, 272, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 64, 64, 200, 272, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 1, 64, 224, 224, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1, 64, 224, 224, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 64, 64, 224, 224, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 64, 64, 224, 224, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 64, 256, 256, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 32, 64, 256, 256, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 64, 256, 256, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 32, 64, 256, 256, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 32, 64, 256, 256, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 64, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 256, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 64, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 64, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 128, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 64, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 256, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 64, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 64, 64, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 32, 64, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 64, 64, 300, 300, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 96, 64, 35, 35, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 64, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 64, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 128, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 64, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 64, 56, 56, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 14, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 144, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 64, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 144, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 64, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 192, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 24, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 64, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 64, 64, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 24, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 64, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 32, 64, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 128, 64, 64, 64, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 160, 64, 64, 64, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 64, 64, 64, 64, 4, 4, 4, 4, 4, 4, 0, True, 1],
- [1, 64, 64, 73, 73, 1, 7, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 64, 73, 73, 7, 1, 1, 1, 1, 1, 3, False, 1],
- [1, 96, 64, 73, 73, 3, 3, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 64, 80, 80, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 640, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 128, 640, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 160, 64, 64, 64, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 64, 64, 64, 64, 4, 4, 4, 4, 0, 0, 1, True, 1],
+ [1, 64, 64, 73, 73, 1, 7, 1, 1, 0, 3, 1, False, 1],
+ [1, 64, 64, 73, 73, 7, 1, 1, 1, 3, 0, 1, False, 1],
+ [1, 96, 64, 73, 73, 3, 3, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 64, 80, 80, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 640, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 640, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 160, 640, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 640, 654, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 168, 672, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 80, 672, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 56, 672, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 672, 672, 14, 14, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 672, 672, 14, 14, 5, 5, 2, 2, 2, 2, 2, False, 1],
- [1, 672, 672, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 112, 672, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 15, 15, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 672, 672, 17, 17, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 672, 672, 19, 19, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 112, 672, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 672, 20, 20, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 546, 672, 20, 20, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 672, 672, 20, 20, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 672, 672, 20, 20, 5, 5, 2, 2, 2, 2, 2, False, 1],
- [1, 112, 672, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 672, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 24, 24, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 672, 672, 24, 24, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 1344, 672, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1344, 672, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 192, 672, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 672, 672, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 7, 7, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 672, 672, 7, 7, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 672, 672, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 672, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 672, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 672, 8, 8, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 640, 654, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 168, 672, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 80, 672, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 56, 672, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 672, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 672, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 672, 20, 20, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 546, 672, 20, 20, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 112, 672, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 672, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1344, 672, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1344, 672, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 192, 672, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 672, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 672, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 672, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 672, 8, 8, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 40, 68, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 174, 696, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 58, 696, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 696, 696, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 704, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 704, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 18, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 20, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 24, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 288, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 72, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 72, 72, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 144, 72, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 18, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 36, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 72, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 174, 696, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 58, 696, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 696, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 704, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 704, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 18, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 20, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 24, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 288, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 8, 72, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 144, 72, 14, 14, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 18, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 36, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 72, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 72, 72, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 20, 72, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 72, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 40, 72, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 72, 28, 28, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 72, 28, 28, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 40, 72, 40, 40, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 12, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 168, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 168, 72, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 216, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 216, 72, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 24, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 72, 72, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 72, 72, 56, 56, 5, 5, 2, 2, 2, 2, 2, False, 1],
- [1, 72, 72, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 72, 72, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 72, 80, 80, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 72, 72, 80, 80, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 72, 72, 80, 80, 5, 5, 2, 2, 2, 2, 2, False, 1],
- [1, 192, 720, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 720, 720, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 720, 720, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 120, 720, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 720, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 720, 720, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 208, 720, 9, 9, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 728, 728, 19, 19, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 728, 728, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 728, 728, 38, 38, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 728, 728, 38, 38, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 736, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 512, 736, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 736, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 20, 72, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 72, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 72, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 40, 72, 40, 40, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 12, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 168, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 168, 72, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 216, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 216, 72, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 24, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 72, 72, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 72, 80, 80, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 720, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 120, 720, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 720, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 208, 720, 9, 9, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 728, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 728, 728, 38, 38, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 736, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 512, 736, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 736, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 334, 740, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 768, 768, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 768, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 768, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 768, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 768, 32, 32, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 768, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 224, 768, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 768, 768, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 768, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 768, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 384, 768, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 768, 32, 32, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 768, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 224, 768, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 16, 78, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 34, 78, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 24, 78, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 196, 784, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 80, 784, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 784, 784, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 784, 784, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 784, 784, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 224, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 232, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 48, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 528, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 64, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 72, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 8, 8, 112, 112, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 320, 80, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 784, 80, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 480, 80, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 80, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 100, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 184, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 200, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 80, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 92, 80, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 80, 15, 15, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 184, 80, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 200, 80, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 80, 20, 20, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 80, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 80, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 80, 80, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 196, 784, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 80, 784, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 784, 784, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 16, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 224, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 232, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 48, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 528, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 64, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 72, 8, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 320, 80, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 784, 80, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 480, 80, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 80, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 100, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 184, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 200, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 480, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 92, 80, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 480, 80, 15, 15, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 184, 80, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 200, 80, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 480, 80, 20, 20, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 80, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 80, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 80, 80, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 80, 80, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 184, 80, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 200, 80, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 480, 80, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 80, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 800, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 800, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 184, 80, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 200, 80, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 480, 80, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 800, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 800, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 272, 800, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 232, 816, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 816, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 136, 816, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 832, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 32, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 384, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 48, 832, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 336, 84, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 888, 84, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 864, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 864, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 864, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 528, 88, 17, 17, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 88, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 88, 88, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 222, 888, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 84, 888, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 888, 888, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 888, 888, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 888, 888, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 112, 896, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 224, 896, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 128, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2016, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2016, 896, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 2048, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 896, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 256, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 896, 896, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 896, 896, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 896, 896, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 896, 896, 28, 28, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 128, 896, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 912, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 912, 912, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 912, 912, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 92, 92, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 928, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 928, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 232, 816, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 816, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 136, 816, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 832, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 32, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 384, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 48, 832, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 336, 84, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 888, 84, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 864, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 864, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 864, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 528, 88, 17, 17, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 88, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 222, 888, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 84, 888, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 888, 888, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 112, 896, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 224, 896, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 128, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2016, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2016, 896, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 2048, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2048, 896, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 256, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 896, 896, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 896, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 912, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 928, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 928, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 28, 94, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 24, 96, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 96, 96, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 96, 96, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 96, 96, 113, 113, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 96, 96, 121, 121, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 96, 96, 131, 131, 3, 3, 2, 2, 2, 2, 0, False, 1],
+ [1, 24, 96, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 208, 96, 14, 14, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 40, 96, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 96, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 96, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 40, 96, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 576, 96, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 576, 96, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 128, 96, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 96, 96, 28, 28, 5, 5, 2, 2, 2, 2, 2, False, 1],
[1, 96, 96, 35, 35, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 96, 56, 56, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 24, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 96, 96, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 96, 96, 56, 56, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 96, 60, 60, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 24, 96, 65, 65, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 96, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 240, 960, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 272, 960, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 960, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 192, 960, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 960, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 960, 960, 27, 27, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 960, 960, 3, 3, 1, 5, 1, 1, 1, 1, 0, False, 1],
- [1, 960, 960, 3, 3, 5, 1, 1, 1, 1, 1, 2, False, 1],
- [1, 128, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 160, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 320, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 80, 960, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 960, 960, 7, 7, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 960, 960, 7, 7, 5, 5, 1, 1, 1, 1, 2, False, 1],
+ [1, 128, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 96, 56, 56, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 24, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 96, 96, 56, 56, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 96, 60, 60, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 24, 96, 65, 65, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 576, 96, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 240, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 272, 960, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 960, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 192, 960, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 960, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 160, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 320, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 80, 960, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 20, 98, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 128, 992, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 128, 992, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 256, 1024, 128, 128, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 1024, 14, 14, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 2048, 1024, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 1024, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 128, 992, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 128, 992, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 1024, 14, 14, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 2048, 1024, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 512, 1024, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 2048, 1024, 45, 80, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 2048, 1024, 50, 68, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1056, 1056, 48, 48, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2904, 1056, 48, 48, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 128, 112, 112, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 128, 128, 5, 5, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1280, 1280, 30, 40, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2520, 1344, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 3712, 1392, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1512, 1512, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1536, 1536, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 1536, 10, 10, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 448, 1632, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1920, 1920, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2016, 2016, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 2048, 15, 20, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1024, 2048, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 2048, 2048, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1056, 2112, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 576, 216, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 232, 232, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
+ [1, 2048, 1024, 45, 80, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 2048, 1024, 50, 68, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2048, 1024, 7, 7, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2904, 1056, 48, 48, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 128, 112, 112, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2520, 1344, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 3712, 1392, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1512, 1512, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1536, 1536, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2048, 1536, 10, 10, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1920, 1920, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2016, 2016, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 2048, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 2048, 2048, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1056, 2112, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 576, 216, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 232, 232, 112, 112, 3, 3, 2, 2, 1, 1, 1, False, 1],
[1, 232, 232, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 144, 24, 190, 190, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 720, 240, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 2520, 2520, 14, 14, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 144, 24, 190, 190, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 720, 240, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 512, 256, 180, 320, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 512, 256, 200, 272, 1, 1, 2, 2, 2, 2, 0, False, 1],
+ [1, 512, 256, 180, 320, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 512, 256, 200, 272, 1, 1, 2, 2, 0, 0, 1, False, 1],
[1, 819, 256, 25, 34, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 256, 256, 75, 75, 3, 3, 1, 1, 1, 1, 1, False, 1],
[1, 256, 256, 75, 75, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 2904, 264, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 264, 2904, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 726, 2904, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 2904, 2904, 24, 24, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 2904, 2904, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 7392, 2904, 24, 24, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 3, 224, 224, 16, 16, 16, 16, 16, 16, 0, True, 1],
- [1, 1024, 3, 224, 224, 32, 32, 32, 32, 32, 32, 0, True, 1],
- [1, 768, 3, 224, 224, 16, 16, 16, 16, 16, 16, 0, True, 1],
- [1, 768, 3, 224, 224, 32, 32, 32, 32, 32, 32, 0, False, 1],
- [1, 768, 3, 224, 224, 32, 32, 32, 32, 32, 32, 0, True, 1],
- [1, 32, 3, 299, 299, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 32, 3, 381, 381, 3, 3, 2, 2, 2, 2, 0, False, 1],
- [1, 768, 3, 384, 512, 32, 32, 32, 32, 32, 32, 0, True, 1],
- [1, 64, 3, 480, 640, 7, 7, 4, 4, 4, 4, 3, True, 1],
+ [1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 2904, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 7392, 2904, 24, 24, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1024, 3, 224, 224, 16, 16, 16, 16, 0, 0, 1, True, 1],
+ [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1],
+ [1, 768, 3, 224, 224, 16, 16, 16, 16, 0, 0, 1, True, 1],
+ [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1],
+ [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1],
+ [1, 32, 3, 299, 299, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 32, 3, 381, 381, 3, 3, 2, 2, 0, 0, 1, False, 1],
+ [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1],
+ [1, 64, 3, 480, 640, 7, 7, 4, 4, 3, 3, 1, True, 1],
[1, 32, 3, 512, 512, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 32, 3, 512, 512, 6, 6, 2, 2, 2, 2, 2, False, 1],
- [1, 32, 3, 512, 512, 7, 7, 4, 4, 4, 4, 3, True, 1],
- [1, 192, 3, 512, 672, 16, 16, 16, 16, 16, 16, 0, True, 1],
- [1, 1280, 3, 518, 518, 14, 14, 14, 14, 14, 14, 0, True, 1],
- [1, 64, 3, 720, 1280, 7, 7, 2, 2, 2, 2, 3, False, 1],
- [1, 64, 3, 800, 1088, 7, 7, 2, 2, 2, 2, 3, False, 1],
- [1, 308, 3024, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 3024, 308, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 528, 32, 192, 192, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 64, 32, 512, 512, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 336, 336, 112, 112, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 888, 336, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 336, 336, 48, 48, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 336, 336, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 3712, 348, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 348, 3712, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 912, 408, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 1008, 432, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 128, 512, 100, 136, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 512, 100, 136, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 256, 512, 100, 136, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 32, 3, 512, 512, 6, 6, 2, 2, 2, 2, 1, False, 1],
+ [1, 32, 3, 512, 512, 7, 7, 4, 4, 3, 3, 1, True, 1],
+ [1, 192, 3, 512, 672, 16, 16, 16, 16, 0, 0, 1, True, 1],
+ [1, 1280, 3, 518, 518, 14, 14, 14, 14, 0, 0, 1, True, 1],
+ [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1],
+ [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1],
+ [1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 528, 32, 192, 192, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 64, 32, 512, 512, 3, 3, 2, 2, 1, 1, 1, False, 1],
+ [1, 888, 336, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 912, 408, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1008, 432, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 128, 512, 100, 136, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 512, 100, 136, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 256, 512, 100, 136, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 364, 512, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 512, 512, 38, 38, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 256, 512, 56, 56, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1024, 512, 90, 160, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 528, 528, 17, 17, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 528, 528, 192, 192, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1056, 528, 96, 96, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1056, 528, 96, 96, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 528, 528, 96, 96, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 64, 64, 128, 128, 2, 2, 2, 2, 2, 2, 0, True, 1],
- [1, 256, 64, 180, 320, 1, 1, 1, 1, 1, 1, 0, False, 1],
+ [1, 1024, 512, 90, 160, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1056, 528, 96, 96, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 64, 64, 128, 128, 2, 2, 2, 2, 0, 0, 1, True, 1],
+ [1, 256, 64, 180, 320, 1, 1, 1, 1, 0, 0, 1, False, 1],
[1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1392, 696, 28, 28, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1392, 696, 28, 28, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 696, 696, 28, 28, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 696, 696, 56, 56, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1920, 720, 14, 14, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1920, 720, 14, 14, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 720, 720, 17, 17, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 720, 720, 21, 21, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 2904, 726, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 7392, 726, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 1024, 728, 19, 19, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1024, 728, 19, 19, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 728, 728, 38, 38, 3, 3, 1, 1, 1, 1, 1, False, 1],
- [1, 728, 728, 38, 38, 1, 1, 2, 2, 2, 2, 0, False, 1],
- [1, 726, 7392, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 7392, 7392, 24, 24, 3, 3, 2, 2, 2, 2, 1, False, 1],
- [1, 1024, 782, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 816, 816, 19, 19, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 816, 816, 23, 23, 5, 5, 2, 2, 2, 2, 0, False, 1],
- [1, 912, 912, 7, 7, 1, 1, 1, 1, 1, 1, 0, False, 1],
- [1, 1280, 960, 1, 1, 1, 1, 1, 1, 1, 1, 0, True, 1],
- [1, 960, 960, 24, 24, 5, 5, 1, 1, 1, 1, 2, False, 1],
- [1, 1280, 1280, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1392, 696, 28, 28, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1920, 720, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 728, 19, 19, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 728, 728, 38, 38, 1, 1, 2, 2, 0, 0, 1, False, 1],
+ [1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1],
+ [1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1],
+ [1, 1280, 1280, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1280, 1280, 16, 16, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 640, 1280, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 640, 1280, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 640, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1280, 1280, 8, 8, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1280, 1280, 8, 8, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1280, 1920, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1280, 1920, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1280, 1920, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 640, 1920, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 640, 1920, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1280, 2560, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1280, 2560, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1280, 2560, 8, 8, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1280, 2560, 8, 8, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 640, 320, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 640, 320, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 640, 320, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 320, 320, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 320, 320, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 320, 320, 64, 64, 3, 3, 2, 2, 2, 2, 1, True, 1],
+ [1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, 1, True, 1],
[1, 4, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 320, 4, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 1280, 640, 16, 16, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 1280, 640, 16, 16, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 1280, 640, 16, 16, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 640, 640, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 640, 640, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 640, 640, 32, 32, 3, 3, 2, 2, 2, 2, 1, True, 1],
- [1, 320, 640, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, 1, True, 1],
+ [1, 320, 640, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
[1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 640, 960, 32, 32, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 640, 960, 32, 32, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1],
- [1, 320, 960, 64, 64, 1, 1, 1, 1, 1, 1, 0, True, 1],
+ [1, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1],
[1, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
],
},
@@ -1599,21 +1599,59 @@ def test_conv2d_localrun(device, input_spec):
failing_parameters = [
# [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation]
- # Input is 32MB maps to MM 64 cores, we neeed to avoid sharding this tensor and use dram intrelaved directly with MM
- [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 5
- [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 14
- [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 170
- [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 172
- [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 181
- [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 182
- [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 198
- [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 203
- [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 204
- [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 292
- [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 373
- [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 374
- [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 394
- [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 395
+ [1, 960, 960, 27, 27, 5, 5, 2, 2, 0, 0, 960, False, 1], # 0
+ [1, 960, 960, 24, 24, 5, 5, 1, 1, 2, 2, 960, False, 1], # 5
+ [1, 816, 816, 19, 19, 5, 5, 1, 1, 2, 2, 816, False, 1], # 19
+ [1, 816, 816, 23, 23, 5, 5, 2, 2, 0, 0, 816, False, 1], # 20
+ [1, 1056, 1056, 96, 96, 3, 3, 2, 2, 1, 1, 4, False, 1], # 127
+ [1, 528, 528, 192, 192, 3, 3, 2, 2, 1, 1, 2, False, 1], # 220
+ [1, 2904, 2904, 48, 48, 3, 3, 2, 2, 1, 1, 11, False, 1], # 294
+ [1, 1024, 1024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1407
+ [1, 256, 1024, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1408
+ [1, 1056, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1417
+ [1, 2904, 1056, 48, 48, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1418
+ [1, 3024, 1232, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1420
+ [1, 3024, 1232, 14, 14, 1, 1, 2, 2, 0, 0, 1, False, 1], # 1421
+ [1, 2520, 1344, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1423
+ [1, 3712, 1392, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1425
+ [1, 1024, 1440, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1427
+ [1, 448, 1632, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1431
+ [1, 2520, 2520, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1442
+ [1, 819, 256, 100, 136, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1443
+ [1, 819, 256, 50, 68, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1447
+ [1, 2904, 264, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1451
+ [1, 264, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1452
+ [1, 726, 2904, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1453
+ [1, 7392, 2904, 24, 24, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1455
+ [1, 1024, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1458
+ [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, False, 1], # 1460
+ [1, 768, 3, 224, 224, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1461
+ [1, 768, 3, 384, 512, 32, 32, 32, 32, 0, 0, 1, True, 1], # 1464
+ [1, 64, 3, 720, 1280, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1471
+ [1, 64, 3, 800, 1088, 7, 7, 2, 2, 3, 3, 1, False, 1], # 1472
+ [1, 308, 3024, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1473
+ [1, 3024, 3024, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1474
+ [1, 3024, 308, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1475
+ [1, 3712, 348, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1479
+ [1, 348, 3712, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1480
+ [1, 3712, 3712, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1481
+ [1, 1056, 528, 96, 96, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1491
+ [1, 1, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1495
+ [1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1496
+ [1, 1392, 696, 28, 28, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1497
+ [1, 1920, 720, 14, 14, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1499
+ [1, 2904, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1501
+ [1, 7392, 726, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1502
+ [1, 1024, 728, 19, 19, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1503
+ [1, 726, 7392, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1506
+ [1, 7392, 7392, 12, 12, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1507
+ [1, 1024, 782, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1508
+ [1, 912, 912, 7, 7, 1, 1, 1, 1, 0, 0, 1, False, 1], # 1509
+ [1, 1280, 960, 1, 1, 1, 1, 1, 1, 0, 0, 1, True, 1], # 1510
+ [1, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1522
+ [1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1530
+ [1, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1540
+ [1, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1], # 1545
]
diff --git a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py
index c66e5548834..0b6c9031553 100644
--- a/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py
+++ b/tests/sweep_framework/sweeps/matmul/short/matmul_traces.py
@@ -13,6 +13,10 @@
TIMEOUT = 70
+# params contains the shape of the first tensor followed by the shape of the
+# second tensor; the second shape starts at index len(params) // 2. The split
+# is easiest to reason about when both tensors have the same rank, although
+# some other combinations may be valid.
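+# For example, (1, 1, 1024, 768, 1, 1, 768, 96) multiplies a [1, 1, 1024, 768]
+# tensor by a [1, 1, 768, 96] tensor.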
parameters = {
"default": {
"params": [
@@ -111,25 +115,269 @@
(9, 768, 768, 640),
(920, 256, 256, 256),
],
- }
+ "core_grid": [False],
+ },
+ "gpt": {
+ "params": [
+ (1, 1, 1, 1, 1, 1, 1, 1),
+ (1, 1, 1, 1, 1, 1, 1, 2304),
+ (1, 1, 1, 1, 1, 1, 1, 3072),
+ (1, 1, 1, 1, 1, 1, 1, 65536),
+ (1, 1, 1, 1, 1, 1, 1, 768),
+ (1, 1, 1, 1, 1, 1, 1, 96),
+ (1, 1, 1, 2304, 1, 1, 2304, 1),
+ (1, 1, 1, 2304, 1, 1, 2304, 65536),
+ (1, 1, 1, 2304, 1, 1, 2304, 768),
+ (1, 1, 1, 3072, 1, 1, 3072, 1),
+ (1, 1, 1, 3072, 1, 1, 3072, 65536),
+ (1, 1, 1, 3072, 1, 1, 3072, 768),
+ (1, 1, 1, 65536, 1, 1, 65536, 2304),
+ (1, 1, 1, 65536, 1, 1, 65536, 3072),
+ (1, 1, 1, 65536, 1, 1, 65536, 768),
+ (1, 1, 1, 65536, 1, 1, 65536, 96),
+ (1, 1, 1, 768, 1, 1, 768, 1),
+ (1, 1, 1, 768, 1, 1, 768, 1024),
+ (1, 1, 1, 768, 1, 1, 768, 2304),
+ (1, 1, 1, 768, 1, 1, 768, 3072),
+ (1, 1, 1, 768, 1, 1, 768, 65536),
+ (1, 1, 1, 768, 1, 1, 768, 768),
+ (1, 1, 1, 768, 1, 1, 768, 96),
+ (1, 1, 1, 96, 1, 1, 96, 1),
+ (1, 1, 1, 96, 1, 1, 96, 65536),
+ (1, 1, 1, 96, 1, 1, 96, 768),
+ (1, 1, 1024, 768, 1, 1, 768, 1),
+ (1, 1, 1024, 768, 1, 1, 768, 1024),
+ (1, 1, 1024, 768, 1, 1, 768, 2304),
+ (1, 1, 1024, 768, 1, 1, 768, 3072),
+ (1, 1, 1024, 768, 1, 1, 768, 65536),
+ (1, 1, 1024, 768, 1, 1, 768, 768),
+ (1, 1, 1024, 768, 1, 1, 768, 96),
+ (1, 1, 2304, 1, 1, 1, 1, 1),
+ (1, 1, 2304, 1, 1, 1, 1, 2304),
+ (1, 1, 2304, 1, 1, 1, 1, 3072),
+ (1, 1, 2304, 1, 1, 1, 1, 65536),
+ (1, 1, 2304, 1, 1, 1, 1, 768),
+ (1, 1, 2304, 1, 1, 1, 1, 96),
+ (1, 1, 2304, 65536, 1, 1, 65536, 1),
+ (1, 1, 2304, 65536, 1, 1, 65536, 2304),
+ (1, 1, 2304, 65536, 1, 1, 65536, 3072),
+ (1, 1, 2304, 65536, 1, 1, 65536, 768),
+ (1, 1, 2304, 65536, 1, 1, 65536, 96),
+ (1, 1, 2304, 768, 1, 1, 768, 1),
+ (1, 1, 2304, 768, 1, 1, 768, 1024),
+ (1, 1, 2304, 768, 1, 1, 768, 2304),
+ (1, 1, 2304, 768, 1, 1, 768, 3072),
+ (1, 1, 2304, 768, 1, 1, 768, 65536),
+ (1, 1, 2304, 768, 1, 1, 768, 768),
+ (1, 1, 2304, 768, 1, 1, 768, 96),
+ (1, 1, 3072, 1, 1, 1, 1, 1),
+ (1, 1, 3072, 1, 1, 1, 1, 2304),
+ (1, 1, 3072, 1, 1, 1, 1, 3072),
+ (1, 1, 3072, 1, 1, 1, 1, 65536),
+ (1, 1, 3072, 1, 1, 1, 1, 768),
+ (1, 1, 3072, 1, 1, 1, 1, 96),
+ (1, 1, 3072, 65536, 1, 1, 65536, 1),
+ (1, 1, 3072, 65536, 1, 1, 65536, 2304),
+ (1, 1, 3072, 65536, 1, 1, 65536, 3072),
+ (1, 1, 3072, 65536, 1, 1, 65536, 768),
+ (1, 1, 3072, 65536, 1, 1, 65536, 96),
+ (1, 1, 3072, 768, 1, 1, 768, 1),
+ (1, 1, 3072, 768, 1, 1, 768, 1024),
+ (1, 1, 3072, 768, 1, 1, 768, 2304),
+ (1, 1, 3072, 768, 1, 1, 768, 3072),
+ (1, 1, 3072, 768, 1, 1, 768, 65536),
+ (1, 1, 3072, 768, 1, 1, 768, 768),
+ (1, 1, 3072, 768, 1, 1, 768, 96),
+ (1, 1, 65536, 1, 1, 1, 1, 1),
+ (1, 1, 65536, 1, 1, 1, 1, 2304),
+ (1, 1, 65536, 1, 1, 1, 1, 3072),
+ (1, 1, 65536, 1, 1, 1, 1, 65536),
+ (1, 1, 65536, 1, 1, 1, 1, 768),
+ (1, 1, 65536, 1, 1, 1, 1, 96),
+ (1, 1, 65536, 2304, 1, 1, 2304, 1),
+ (1, 1, 65536, 2304, 1, 1, 2304, 65536),
+ (1, 1, 65536, 2304, 1, 1, 2304, 768),
+ (1, 1, 65536, 3072, 1, 1, 3072, 1),
+ (1, 1, 65536, 3072, 1, 1, 3072, 65536),
+ (1, 1, 65536, 3072, 1, 1, 3072, 768),
+ (1, 1, 65536, 768, 1, 1, 768, 1),
+ (1, 1, 65536, 768, 1, 1, 768, 1024),
+ (1, 1, 65536, 768, 1, 1, 768, 2304),
+ (1, 1, 65536, 768, 1, 1, 768, 3072),
+ (1, 1, 65536, 768, 1, 1, 768, 65536),
+ (1, 1, 65536, 768, 1, 1, 768, 768),
+ (1, 1, 65536, 768, 1, 1, 768, 96),
+ (1, 1, 65536, 96, 1, 1, 96, 65536),
+ (1, 1, 65536, 96, 1, 1, 96, 768),
+ (1, 1, 768, 1, 1, 1, 1, 1),
+ (1, 1, 768, 1, 1, 1, 1, 2304),
+ (1, 1, 768, 1, 1, 1, 1, 3072),
+ (1, 1, 768, 1, 1, 1, 1, 65536),
+ (1, 1, 768, 1, 1, 1, 1, 768),
+ (1, 1, 768, 1, 1, 1, 1, 96),
+ (1, 1, 768, 1024, 1, 1, 1024, 768),
+ (1, 1, 768, 2304, 1, 1, 2304, 1),
+ (1, 1, 768, 2304, 1, 1, 2304, 65536),
+ (1, 1, 768, 2304, 1, 1, 2304, 768),
+ (1, 1, 768, 3072, 1, 1, 3072, 1),
+ (1, 1, 768, 3072, 1, 1, 3072, 65536),
+ (1, 1, 768, 3072, 1, 1, 3072, 768),
+ (1, 1, 768, 65536, 1, 1, 65536, 1),
+ (1, 1, 768, 65536, 1, 1, 65536, 2304),
+ (1, 1, 768, 65536, 1, 1, 65536, 3072),
+ (1, 1, 768, 65536, 1, 1, 65536, 768),
+ (1, 1, 768, 65536, 1, 1, 65536, 96),
+ (1, 1, 768, 768, 1, 1, 768, 1),
+ (1, 1, 768, 768, 1, 1, 768, 1024),
+ (1, 1, 768, 768, 1, 1, 768, 2304),
+ (1, 1, 768, 768, 1, 1, 768, 3072),
+ (1, 1, 768, 768, 1, 1, 768, 65536),
+ (1, 1, 768, 768, 1, 1, 768, 768),
+ (1, 1, 768, 768, 1, 1, 768, 96),
+ (1, 1, 768, 96, 1, 1, 96, 1),
+ (1, 1, 768, 96, 1, 1, 96, 65536),
+ (1, 1, 768, 96, 1, 1, 96, 768),
+ (1, 1, 96, 1, 1, 1, 1, 1),
+ (1, 1, 96, 1, 1, 1, 1, 2304),
+ (1, 1, 96, 1, 1, 1, 1, 3072),
+ (1, 1, 96, 1, 1, 1, 1, 65536),
+ (1, 1, 96, 1, 1, 1, 1, 768),
+ (1, 1, 96, 1, 1, 1, 1, 96),
+ (1, 1, 96, 65536, 1, 1, 65536, 1),
+ (1, 1, 96, 65536, 1, 1, 65536, 2304),
+ (1, 1, 96, 65536, 1, 1, 65536, 3072),
+ (1, 1, 96, 65536, 1, 1, 65536, 768),
+ (1, 1, 96, 65536, 1, 1, 65536, 96),
+ (1, 1, 96, 768, 1, 1, 768, 1),
+ (1, 1, 96, 768, 1, 1, 768, 1024),
+ (1, 1, 96, 768, 1, 1, 768, 2304),
+ (1, 1, 96, 768, 1, 1, 768, 3072),
+ (1, 1, 96, 768, 1, 1, 768, 65536),
+ (1, 1, 96, 768, 1, 1, 768, 768),
+ (1, 1, 96, 768, 1, 1, 768, 96),
+ (1, 64, 1024, 768, 1, 1, 768, 1),
+ (1, 64, 1024, 768, 1, 1, 768, 2304),
+ (1, 64, 1024, 768, 1, 1, 768, 3072),
+ (1, 64, 1024, 768, 1, 1, 768, 65536),
+ (1, 64, 1024, 768, 1, 1, 768, 768),
+ (1, 64, 1024, 768, 1, 1, 768, 96),
+ (1, 64, 768, 1024, 1, 1, 1024, 768),
+ (1, 64, 768, 1024, 1, 64, 1024, 768),
+ (64, 1, 1, 1024, 1, 1, 1024, 768),
+ (64, 1, 1, 1024, 64, 1, 1024, 1),
+ (64, 1, 1, 1024, 64, 1, 1024, 2304),
+ (64, 1, 1, 1024, 64, 1, 1024, 3072),
+ (64, 1, 1, 1024, 64, 1, 1024, 768),
+ (64, 1, 1, 1024, 64, 1, 1024, 96),
+ (64, 1, 1, 768, 1, 1, 768, 1),
+ (64, 1, 1, 768, 1, 1, 768, 1024),
+ (64, 1, 1, 768, 1, 1, 768, 2304),
+ (64, 1, 1, 768, 1, 1, 768, 3072),
+ (64, 1, 1, 768, 1, 1, 768, 65536),
+ (64, 1, 1, 768, 1, 1, 768, 768),
+ (64, 1, 1, 768, 1, 1, 768, 96),
+ (64, 1, 1, 768, 64, 1, 768, 1),
+ (64, 1, 1, 768, 64, 1, 768, 1024),
+ (64, 1, 1024, 1, 1, 1, 1, 2304),
+ (64, 1, 1024, 1, 1, 1, 1, 3072),
+ (64, 1, 1024, 1, 1, 1, 1, 768),
+ (64, 1, 1024, 1, 1, 1, 1, 96),
+ (64, 1, 1024, 1, 64, 1, 1, 1024),
+ (64, 1, 1024, 1, 64, 1, 1, 768),
+ (64, 1, 1024, 2304, 1, 1, 2304, 65536),
+ (64, 1, 1024, 2304, 1, 1, 2304, 768),
+ (64, 1, 1024, 2304, 64, 1, 2304, 1024),
+ (64, 1, 1024, 3072, 1, 1, 3072, 1),
+ (64, 1, 1024, 3072, 1, 1, 3072, 65536),
+ (64, 1, 1024, 3072, 1, 1, 3072, 768),
+ (64, 1, 1024, 768, 1, 1, 768, 1),
+ (64, 1, 1024, 768, 1, 1, 768, 1024),
+ (64, 1, 1024, 768, 1, 1, 768, 2304),
+ (64, 1, 1024, 768, 1, 1, 768, 3072),
+ (64, 1, 1024, 768, 1, 1, 768, 65536),
+ (64, 1, 1024, 768, 1, 1, 768, 768),
+ (64, 1, 1024, 768, 1, 1, 768, 96),
+ (64, 1, 1024, 768, 64, 1, 768, 1024),
+ (64, 1, 1024, 96, 1, 1, 96, 65536),
+ (64, 1, 1024, 96, 1, 1, 96, 768),
+ (64, 1, 1024, 96, 64, 1, 96, 1024),
+ (64, 1, 2304, 1024, 1, 1, 1024, 768),
+ (64, 1, 2304, 1024, 64, 1, 1024, 1),
+ (64, 1, 2304, 1024, 64, 1, 1024, 2304),
+ (64, 1, 2304, 1024, 64, 1, 1024, 3072),
+ (64, 1, 2304, 1024, 64, 1, 1024, 768),
+ (64, 1, 2304, 1024, 64, 1, 1024, 96),
+ (64, 1, 3072, 1024, 1, 1, 1024, 768),
+ (64, 1, 3072, 1024, 64, 1, 1024, 1),
+ (64, 1, 3072, 1024, 64, 1, 1024, 2304),
+ (64, 1, 3072, 1024, 64, 1, 1024, 3072),
+ (64, 1, 3072, 1024, 64, 1, 1024, 768),
+ (64, 1, 3072, 1024, 64, 1, 1024, 96),
+ (64, 1, 768, 1, 1, 1, 1, 2304),
+ (64, 1, 768, 1, 1, 1, 1, 3072),
+ (64, 1, 768, 1, 1, 1, 1, 768),
+ (64, 1, 768, 1, 1, 1, 1, 96),
+ (64, 1, 768, 1, 64, 1, 1, 768),
+ (64, 1, 768, 1024, 1, 1, 1024, 768),
+ (64, 1, 768, 1024, 64, 1, 1024, 1),
+ (64, 1, 768, 1024, 64, 1, 1024, 2304),
+ (64, 1, 768, 1024, 64, 1, 1024, 3072),
+ (64, 1, 768, 1024, 64, 1, 1024, 768),
+ (64, 1, 768, 1024, 64, 1, 1024, 96),
+ (64, 1, 96, 1024, 1, 1, 1024, 768),
+ (64, 1, 96, 1024, 64, 1, 1024, 1),
+ (64, 1, 96, 1024, 64, 1, 1024, 2304),
+ (64, 1, 96, 1024, 64, 1, 1024, 3072),
+ (64, 1, 96, 1024, 64, 1, 1024, 768),
+ (64, 1, 96, 1024, 64, 1, 1024, 96),
+ (64, 12, 1, 1024, 1, 1, 1024, 768),
+ (64, 12, 1, 1024, 64, 12, 1024, 1),
+ (64, 12, 1, 1024, 64, 12, 1024, 1024),
+ (64, 12, 1, 1024, 64, 12, 1024, 64),
+ (64, 12, 1024, 1, 1, 1, 1, 1),
+ (64, 12, 1024, 1, 1, 1, 1, 2304),
+ (64, 12, 1024, 1, 1, 1, 1, 3072),
+ (64, 12, 1024, 1, 1, 1, 1, 768),
+ (64, 12, 1024, 1, 1, 1, 1, 96),
+ (64, 12, 1024, 1, 64, 12, 1, 1024),
+ (64, 12, 1024, 1024, 1, 1, 1024, 768),
+ (64, 12, 1024, 1024, 64, 12, 1024, 1),
+ (64, 12, 1024, 1024, 64, 12, 1024, 1024),
+ (64, 12, 1024, 1024, 64, 12, 1024, 64),
+ (64, 12, 1024, 64, 64, 12, 64, 1024),
+ (64, 12, 64, 1024, 1, 1, 1024, 768),
+ (64, 12, 64, 1024, 64, 12, 1024, 1),
+ (64, 12, 64, 1024, 64, 12, 1024, 1024),
+ (64, 12, 64, 1024, 64, 12, 1024, 64),
+ ],
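+ # When True, run() passes device.core_grid to ttnn.matmul; when False it
+ # passes core_grid=None and lets the op use its default grid.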
+ "core_grid": [True, False],
+ },
}
def run(
params,
+ core_grid,
*,
device,
) -> list:
- [in0_h, in0_w, in1_h, in1_w] = params
- torch_input_tensor0 = torch.rand([in0_h, in0_w], dtype=torch.float32)
- torch_input_tensor1 = torch.rand([in1_h, in1_w], dtype=torch.float32)
+ grid = device.core_grid if core_grid else None
+ half = len(params) // 2
+ shape0 = params[:half]
+ shape1 = params[half:]
+ torch_input_tensor0 = torch.rand(shape0, dtype=torch.float32)
+ torch_input_tensor1 = torch.rand(shape1, dtype=torch.float32)
torch_output_tensor = torch.matmul(torch_input_tensor0, torch_input_tensor1)
input_tensor0 = ttnn.from_torch(torch_input_tensor0, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
start_time = start_measuring_time()
- output_tensor = ttnn.matmul(input_tensor0, input_tensor1)
+ output_tensor = ttnn.matmul(input_tensor0, input_tensor1, core_grid=grid)
output_tensor = ttnn.to_torch(output_tensor)
e2e_perf = stop_measuring_time(start_time)
expected_pcc = 0.99
diff --git a/tests/sweep_framework/sweeps/transformer/rotary_embedding/rotary_embedding.py b/tests/sweep_framework/sweeps/transformer/rotary_embedding/rotary_embedding.py
new file mode 100644
index 00000000000..9c39ab1ae1f
--- /dev/null
+++ b/tests/sweep_framework/sweeps/transformer/rotary_embedding/rotary_embedding.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple
+from functools import partial
+
+import torch
+import random
+import ttnn
+from tests.sweep_framework.sweep_utils.utils import gen_shapes, gen_rotary_embedding_spec
+from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from models.utility_functions import torch_random
+
+# Override the default timeout in seconds for hang detection.
+TIMEOUT = 30
+
+random.seed(0)
+
+
+# Parameters provided to the test vector generator are defined here.
+# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
+# Each suite has a key name (in this case "nightly") which will associate the test vectors with this specific suite of inputs.
+# Developers can create their own generator functions and pass them to the parameters as inputs.
+parameters = {
+ "nightly": {
+ "input_spec": gen_rotary_embedding_spec(
+ input_shape_list=gen_shapes([1, 1, 32, 64], [6, 12, 256, 512], [1, 1, 32, 64], 16),
+ cache_size_list=[random.randint(1, 2048) for _ in range(8)],
+ ),
+ "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
+ "input_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT],
+ "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
+ "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
+ },
+}
+
+
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
+ if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["input_dtype"] == ttnn.bfloat8_b:
+ return True, "bfloat8_b/bfloat4_b requires TILE_LAYOUT!"
+ if test_vector["input_spec"]["input_shape"][-1] % 64 != 0:
+ return True, "Input X dimension (133) must be divisible by 64 for tiling"
+ if test_vector["input_spec"]["token_idx"] and test_vector["input_spec"]["input_shape"][0] != 1:
+ return True, "When passing token_idx, sequence length must be 1"
+ return False, None
+
+
+# These are the run instructions for the test, defined by the developer.
+# The run function must take the above-defined parameters as inputs.
+# The runner will call this run function with each test vector, and the returned results from this function will be stored.
+# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra.
+def run(
+ input_spec,
+ input_dtype,
+ input_layout,
+ input_memory_config,
+ output_memory_config,
+ *,
+ device,
+) -> list:
+ data_seed = random.randint(0, 20000000)
+ torch.manual_seed(data_seed)
+
+ input_shape, cache_size, token_idx = input_spec.values()
+ seq_length, batch_size, num_heads, head_dim = input_shape
+
+ sin_cos_cache_shape = [1, 1, cache_size, head_dim]
+
+ torch_input_tensor = gen_func_with_cast_tt(
+ partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
+ )(input_shape)
+ torch_cos_cache_tensor = gen_func_with_cast_tt(
+ partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
+ )(sin_cos_cache_shape)
+ torch_sin_cache_tensor = gen_func_with_cast_tt(
+ partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
+ )(sin_cos_cache_shape)
+
+ if token_idx:
+ golden_function = partial(ttnn.get_golden_function(ttnn.experimental.rotary_embedding), token_idx=token_idx)
+ else:
+
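+ # Golden reference for rotary embedding: x_embed = x * cos + rotate_half(x) * sin,
+ # where rotate_half negates the second half of the last dim and swaps it with the first.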
+ def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rotary_pos_emb(x, cos_cached, sin_cached, token_idx=None):
+ seq_len = x.shape[-2]
+ if token_idx is None:
+ cos = cos_cached[:, :, :seq_len, ...]
+ sin = sin_cached[:, :, :seq_len, ...]
+ else:
+ cos = cos_cached[:, :, token_idx : token_idx + 1, ...]
+ sin = sin_cached[:, :, token_idx : token_idx + 1, ...]
+
+ x_embed = (x * cos) + (rotate_half(x) * sin)
+ return x_embed
+
+ golden_function = apply_rotary_pos_emb
+
+ torch_output_tensor = golden_function(torch_input_tensor, torch_cos_cache_tensor, torch_sin_cache_tensor)
+
+ input_tensor = ttnn.from_torch(
+ torch_input_tensor,
+ dtype=input_dtype,
+ layout=input_layout,
+ device=device,
+ memory_config=input_memory_config,
+ )
+ cos_cache_tensor = ttnn.from_torch(
+ torch_cos_cache_tensor,
+ dtype=input_dtype,
+ layout=input_layout,
+ device=device,
+ memory_config=input_memory_config,
+ )
+ sin_cache_tensor = ttnn.from_torch(
+ torch_sin_cache_tensor,
+ dtype=input_dtype,
+ layout=input_layout,
+ device=device,
+ memory_config=input_memory_config,
+ )
+
+ start_time = start_measuring_time()
+ output_tensor = ttnn.experimental.rotary_embedding(
+ input_tensor, cos_cache_tensor, sin_cache_tensor, token_idx, memory_config=output_memory_config
+ )
+ e2e_perf = stop_measuring_time(start_time)
+
+ output_tensor = ttnn.to_torch(output_tensor)
+
+ return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf]
diff --git a/tests/tt_eager/ops/test_bcast_op.cpp b/tests/tt_eager/ops/test_bcast_op.cpp
index a761c1eba51..05be3303c06 100644
--- a/tests/tt_eager/ops/test_bcast_op.cpp
+++ b/tests/tt_eager/ops/test_bcast_op.cpp
@@ -3,16 +3,13 @@
// SPDX-License-Identifier: Apache-2.0
#include "tt_metal/host_api.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/data_movement/bcast/bcast.hpp"
#include "common/constants.hpp"
#include
#include
-#include
-#include
-#include
-
using namespace tt;
using namespace tt_metal;
using namespace constants;
@@ -53,9 +50,8 @@ int main(int argc, char** argv) {
}
Tensor a = ttnn::numpy::random::random(input_shape_a).to(Layout::TILE).to(device);
- Tensor b = ttnn::numpy::zeros({1, 1, TILE_HEIGHT, TILE_WIDTH}, DataType::BFLOAT16)
- .to(Layout::TILE)
- .to(device);
+ Tensor b = ttnn::zeros(
+ ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), DataType::BFLOAT16, Layout::TILE, *device);
 for (auto bcast_math : magic_enum::enum_values<ttnn::BcastOpMath>()) {
Tensor c = ttnn::bcast(0, a, b, bcast_math, bcast_dim);
@@ -72,28 +68,28 @@ int main(int argc, char** argv) {
{
Tensor a = ttnn::numpy::random::random({1, 1, 32, 4544}).to(Layout::TILE).to(device);
- Tensor b = ttnn::numpy::zeros({1, 1, 32, 4544}, DataType::BFLOAT16).to(Layout::TILE).to(device);
+ Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 4544}), DataType::BFLOAT16, Layout::TILE, *device);
Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::H);
Tensor d = c.cpu();
}
{
Tensor a = ttnn::numpy::random::random({1, 1, 32, 4544}).to(Layout::TILE).to(device);
- Tensor b = ttnn::numpy::zeros({1, 1, 32, 4544}, DataType::BFLOAT16).to(Layout::TILE).to(device);
+ Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 4544}), DataType::BFLOAT16, Layout::TILE, *device);
Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::ADD, ttnn::BcastOpDim::H);
Tensor d = c.cpu();
}
{
Tensor a = ttnn::numpy::random::random({1, 71, 32, 32}).to(Layout::TILE).to(device);
- Tensor b = ttnn::numpy::zeros({1, 1, 32, 32}, DataType::BFLOAT16).to(Layout::TILE).to(device);
+ Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 32}), DataType::BFLOAT16, Layout::TILE, *device);
Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::HW);
Tensor d = c.cpu();
}
{
Tensor a = ttnn::numpy::random::random({1, 71, 32, 64}).to(Layout::TILE).to(device);
- Tensor b = ttnn::numpy::zeros({1, 1, 32, 32}, DataType::BFLOAT16).to(Layout::TILE).to(device);
+ Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 32}), DataType::BFLOAT16, Layout::TILE, *device);
Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::HW);
Tensor d = c.cpu();
}
diff --git a/tests/tt_eager/ops/test_bmm_op.cpp b/tests/tt_eager/ops/test_bmm_op.cpp
index c7760e67354..f769870b595 100644
--- a/tests/tt_eager/ops/test_bmm_op.cpp
+++ b/tests/tt_eager/ops/test_bmm_op.cpp
@@ -3,15 +3,13 @@
// SPDX-License-Identifier: Apache-2.0
#include "tt_metal/host_api.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
#include "ttnn/tensor/tensor.hpp"
+#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/matmul/device/matmul_op.hpp"
#include "common/constants.hpp"
#include "ttnn/operations/numpy/functions.hpp"
-#include
-#include
-#include
-
using namespace tt;
using namespace tt_metal;
using namespace constants;
@@ -37,14 +35,14 @@ int main(int argc, char** argv) {
uint32_t Kt = 2;
uint32_t Nt = 4;
uint32_t B = 5;
- tt::tt_metal::LegacyShape shapea = {B, 1, Mt * TILE_HEIGHT, Kt * TILE_WIDTH};
- tt::tt_metal::LegacyShape shapeb = {B, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH};
- tt::tt_metal::LegacyShape shapeb1 = {1, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH};
+ ttnn::Shape shapea({B, 1, Mt * TILE_HEIGHT, Kt * TILE_WIDTH});
+ ttnn::Shape shapeb({B, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH});
+ ttnn::Shape shapeb1({1, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH});
// Allocates a DRAM buffer on device populated with values specified by initialize
- Tensor a = ttnn::numpy::random::random(shapea).to(Layout::TILE).to(device);
- Tensor b = ttnn::numpy::zeros(shapeb, DataType::BFLOAT16).to(Layout::TILE).to(device);
- Tensor b1 = ttnn::numpy::zeros(shapeb1, DataType::BFLOAT16).to(Layout::TILE).to(device);
+ Tensor a = ttnn::numpy::random::random(shapea.value).to(Layout::TILE).to(device);
+ Tensor b = ttnn::zeros(shapeb, DataType::BFLOAT16, Layout::TILE, *device);
+ Tensor b1 = ttnn::zeros(shapeb1, DataType::BFLOAT16, Layout::TILE, *device);
Tensor mm = ttnn::operations::matmul::matmul(
a,
diff --git a/tests/tt_eager/ops/test_tensor_utils.cpp b/tests/tt_eager/ops/test_tensor_utils.cpp
index 1121b455b2a..0b6c2e3d376 100644
--- a/tests/tt_eager/ops/test_tensor_utils.cpp
+++ b/tests/tt_eager/ops/test_tensor_utils.cpp
@@ -11,14 +11,11 @@
#include "ttnn/tensor/host_buffer/types.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor.hpp"
-#include "ttnn/operations/numpy/functions.hpp"
+#include "ttnn/operations/creation.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
-using std::vector;
-using tt::tt_metal::Tensor;
-using namespace tt::tt_metal;
-static vector> ref_weight_in = {
+static std::vector> ref_weight_in = {
{
16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16159, 16165, 16068, 16096, 16024, 16228, 15720,
16246, 16011, 16068, 16116, 16202, 16207, 16135, 16117, 16145, 16073, 16236, 16214, 15761, 16044, 15794, 16165,
@@ -246,7 +243,7 @@ static vector> ref_weight_in = {
}
};
-static vector> ref_weight_out = {
+static std::vector> ref_weight_out = {
{16140, 16151, 16183, 16216, 16154, 16219, 16139, 16216, 16088, 16156, 15971, 16157, 16069, 16241, 16231, 16174,
16102, 16056, 16250, 15716, 16154, 16102, 16189, 15523, 15648, 16098, 16016, 15972, 16228, 16243, 16174, 16100,
16101, 16216, 16250, 16179, 16206, 16137, 16180, 16101, 15821, 15819, 16235, 16052, 16182, 15912, 16128, 16159,
@@ -419,9 +416,11 @@ static vector> ref_weight_out = {
15832, 15895, 16234, 16062, 16231, 16173, 16122, 16016, 16187, 15560, 16229, 16046, 16243, 16219, 15849, 16135,
}};
-static vector weight_tensor_shape = {{8, 8, 3, 3}, {10, 10, 3, 3}, {12, 8, 3, 3}, {8, 15, 3, 3}};
-static vector bias_tensor_shape = {{1, 1, 1, 32}, {1, 1, 1, 60}, {12, 1, 1, 320}, {8, 1, 1, 48}};
-static vector shards = {8, 3, 5, 4};
+static std::vector weight_tensor_shape = {
+ {8, 8, 3, 3}, {10, 10, 3, 3}, {12, 8, 3, 3}, {8, 15, 3, 3}};
+static std::vector bias_tensor_shape = {
+ {1, 1, 1, 32}, {1, 1, 1, 60}, {12, 1, 1, 320}, {8, 1, 1, 48}};
+static std::vector shards = {8, 3, 5, 4};
template <typename T>
static uint32_t compare_out_with_ref(const owned_buffer::Buffer& out_buf, T& ref) {
@@ -447,7 +446,7 @@ static uint32_t compare_out_with_ref(const owned_buffer::Buffer& out_b
static void test_convert_conv_weight_tensor_to_tiled_layout_block_sharded() {
tt::log_info(tt::LogTest, "Running {}", __func__);
for (auto i = 0; i < weight_tensor_shape.size(); i++) {
- auto input_tensor = ttnn::numpy::zeros(weight_tensor_shape[i]);
+ auto input_tensor = ttnn::zeros(weight_tensor_shape[i]);
auto input_buffer = owned_buffer::get_as(input_tensor);
for (auto j = 0; j < input_buffer.size(); j++) {
input_buffer[j] = ref_weight_in[i][j];
diff --git a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp
index 1a0d13092ad..a94093f8ac6 100644
--- a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp
+++ b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp
@@ -7,12 +7,12 @@
#include
#include "common/constants.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/types.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp"
#include "tt_metal/host_api.hpp"
-#include "ttnn/operations/numpy/functions.hpp"
using namespace tt;
using namespace tt_metal;
@@ -37,7 +37,8 @@ int main(int argc, char** argv) {
////////////////////////////////////////////////////////////////////////////
ttnn::SimpleShape shape{1, 32, 61, 32};
// Allocates a DRAM buffer on device populated with values specified by initialize
- Tensor a = ttnn::numpy::arange(0, shape.volume(), 1).reshape(shape).to(device);
+ Tensor a = ttnn::arange(/*start=*/0, /*stop=*/shape.volume(), /*step=*/1, DataType::BFLOAT16, std::ref(*device))
+ .reshape(shape);
Tensor b = ttnn::tilize_with_zero_padding(a);
Tensor c = b.cpu();
////////////////////////////////////////////////////////////////////////////
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py
new file mode 100644
index 00000000000..0840e55513d
--- /dev/null
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import pytest
+
+
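+# num_cores_to_corerangeset_in_subcoregrids selects num_cores cores starting at
+# start_core, walking the given sub_core_grids row-wise or column-wise, and
+# returns them as a CoreRangeSet (see the expected sets in each case below).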
+@pytest.mark.parametrize(
+ "start_core, num_cores, sub_core_grids, row_wise, expected_core_range_set",
+ [
+ # Test Case 1: Basic row-wise scenario with enough cores in sub_core_grids
+ (
+ ttnn.CoreCoord(1, 0),
+ 32,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
+ ]
+ ),
+ True,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 0)),
+ ]
+ ),
+ ),
+ # Test Case 2: Basic Column-wise processing
+ (
+ ttnn.CoreCoord(1, 0),
+ 32,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
+ ]
+ ),
+ False,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(5, 1)),
+ ]
+ ),
+ ),
+ # Test Case 3: row-wise scenario with small target cores and start offset
+ (
+ ttnn.CoreCoord(3, 2),
+ 8,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
+ ]
+ ),
+ True,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(3, 2), ttnn.CoreCoord(3, 2)),
+ ttnn.CoreRange(ttnn.CoreCoord(1, 3), ttnn.CoreCoord(3, 4)),
+ ttnn.CoreRange(ttnn.CoreCoord(1, 5), ttnn.CoreCoord(1, 5)),
+ ]
+ ),
+ ),
+ # Test Case 4: col-wise scenario with small target cores and start offset
+ (
+ ttnn.CoreCoord(1, 8),
+ 8,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
+ ]
+ ),
+ False,
+ ttnn.CoreRangeSet(
+ [
+ ttnn.CoreRange(ttnn.CoreCoord(1, 8), ttnn.CoreCoord(1, 9)),
+ ttnn.CoreRange(ttnn.CoreCoord(2, 0), ttnn.CoreCoord(2, 5)),
+ ]
+ ),
+ ),
+ ],
+)
+def test_numcores_to_corerangeset_in_subcoregrids(
+ start_core, num_cores, sub_core_grids, row_wise, expected_core_range_set
+):
+ output_corerangeset = ttnn.num_cores_to_corerangeset_in_subcoregrids(
+ start_core, num_cores, sub_core_grids, row_wise=row_wise
+ )
+ assert output_corerangeset.to_json() == expected_core_range_set.to_json()
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
index 98bca8dc1a3..736b2f2db82 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
@@ -183,8 +183,10 @@ def test_sdpa_tt_with_program_cache(device, b, nh, nkv, s, d, q_chunk_size, k_ch
assert device.num_program_cache_entries() == 1
-def run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype):
+def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=None):
torch.manual_seed(1234)
+ if sk is None:
+ sk = sq
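+ # sq is the query sequence length; sk (the K/V sequence length) defaults to
+ # sq so the existing equal-length tests keep their behaviour.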
program_config = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=device.compute_with_storage_grid_size(),
@@ -200,16 +202,16 @@ def run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dty
packer_l1_acc=False,
)
- Q = fa_rand(b, nh, s, d)
- K = fa_rand(b, nkv, s, d)
- V = fa_rand(b, nkv, s, d)
+ Q = fa_rand(b, nh, sq, d)
+ K = fa_rand(b, nkv, sk, d)
+ V = fa_rand(b, nkv, sk, d)
# Generate random non-causal attention mask
mask = torch.bernoulli(
torch.full(
(
b,
- s,
- s,
+ sq,
+ sk,
),
0.25,
)
@@ -240,8 +242,8 @@ def run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dty
if nkv > 1 and nkv != nh:
assert nh % nkv == 0
- K = K.reshape(b, nkv, 1, s, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, s, d)
- V = V.reshape(b, nkv, 1, s, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, s, d)
+ K = K.reshape(b, nkv, 1, sk, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, sk, d)
+ V = V.reshape(b, nkv, 1, sk, d).repeat(1, 1, nh // nkv, 1, 1).reshape(b, nh, sk, d)
gt = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=False, attn_mask=mask)
@@ -274,3 +276,24 @@ def test_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dt
pytest.skip("Bad PCC for small chunks")
ttnn.device.DisablePersistentKernelCache()
run_sdpa_noncausal(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype)
+
+
+@skip_for_blackhole("Mismatching on BH, see #12349")
+@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
+@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
+@pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
+@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
+@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
+@pytest.mark.parametrize(
+ "b, nh, nkv, sq, sk, d",
+ (
+ [1, 8, 1, 4096, 2048, 128],
+ # [1, 4, 4, 128*1024, 6528, 128], # Llama-Vision long seq
+ [1, 4, 1, 2048, 6528, 128], # Llama-Vision
+ ),
+)
+def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_size, k_chunk_size, dtype):
+ if (sq % q_chunk_size != 0) or (sk % k_chunk_size != 0):
+ pytest.skip("sq and sk must be divisible by q_chunk_size and k_chunk_size")
+ ttnn.device.DisablePersistentKernelCache()
+ run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=sk)
diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp
index b762f10acf7..3ef44800178 100644
--- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp
+++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp
@@ -2,13 +2,11 @@
//
// SPDX-License-Identifier: Apache-2.0
-#include
#include
-#include
-#include
#include "common/bfloat16.hpp"
#include "common/constants.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/types.hpp"
#include "ttnn/tensor/tensor.hpp"
@@ -16,16 +14,16 @@
#include "ttnn/tensor/types.hpp"
#include "tests/tt_metal/tt_metal/common/dispatch_fixture.hpp"
#include "tt_metal/host_api.hpp"
-#include "ttnn/operations/numpy/functions.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
-using namespace tt;
-using namespace tt_metal;
-using namespace constants;
-
+namespace tt::tt_metal {
namespace {
+
+using ::tt::constants::TILE_HEIGHT;
+using ::tt::constants::TILE_WIDTH;
+
uint32_t get_device_buffer_address(const Tensor& tensor) {
 TT_FATAL(std::holds_alternative<DeviceStorage>(tensor.get_storage()), "Tensor storage is not DeviceStorage");
 auto buffer = std::get<DeviceStorage>(tensor.get_storage()).buffer;
@@ -33,13 +31,12 @@ uint32_t get_device_buffer_address(const Tensor& tensor) {
buffer->device()->push_work([&]() { result = buffer->address(); }, true);
return result;
}
-} // namespace
TEST_F(DispatchFixture, TestTensorOwnershipSanity) {
// Sanity test tensor read, write and update paths with synchronous
// Ensure that tensor data is copied and owned as expected
Device* device = this->devices_[0];
- Tensor host_tensor = ttnn::numpy::arange(0, 32 * 32 * 4, 1);
+ Tensor host_tensor = ttnn::arange(/*start=*/0, /*stop=*/32 * 32 * 4, /*step=*/1, DataType::FLOAT32);
Tensor readback_tensor(1);
auto func = [device, host_tensor, readback_tensor]() mutable {
@@ -122,18 +119,12 @@ TEST_F(DispatchFixture, TestAsyncEltwiseBinary) {
for (int i = 0; i < 5; i++) {
// Initialize tensors and move them to DRAM
- Tensor input_tensor_a =
- ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16, Layout::TILE)
- .to(device);
- Tensor input_tensor_b =
- ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16, Layout::TILE)
- .to(device);
- Tensor input_tensor_c =
- ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16, Layout::TILE)
- .to(device);
+ Tensor input_tensor_a = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16, Layout::TILE, *device);
+ Tensor input_tensor_b = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16, Layout::TILE, *device);
+ Tensor input_tensor_c = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16, Layout::TILE, *device);
Tensor output_tensor_device = ttnn::multiply(ttnn::add(input_tensor_a, input_tensor_b), input_tensor_c);
Tensor output_tensor_device_2 = ttnn::neg(ttnn::subtract(output_tensor_device, input_tensor_c));
@@ -181,12 +172,18 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
for (int i = 0; i < 5; i++) {
// Run for multiple loops to ensure deterministic behaviour with device addresses
// Initialize 2 tensors on device
- Tensor tensor1 = ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16)
- .to(device);
- Tensor tensor2 = ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16)
- .to(device);
+ Tensor tensor1 = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}),
+ static_cast<float>(i),
+ DataType::BFLOAT16,
+ /*layout=*/std::nullopt,
+ *device);
+ Tensor tensor2 = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}),
+ static_cast<float>(i),
+ DataType::BFLOAT16,
+ /*layout=*/std::nullopt,
+ *device);
uint32_t tensor2_device_buf_addr = get_device_buffer_address(tensor2);
// Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2
// deallocated
@@ -195,18 +192,23 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
EXPECT_EQ(tensor1.tensor_attributes->main_thread_ref_count, 2);
// To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the
// prev addr for tensor2
- Tensor tensor3 = ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16)
- .to(device);
+ Tensor tensor3 = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}),
+ static_cast<float>(i),
+ DataType::BFLOAT16,
+ /*layout=*/std::nullopt,
+ *device);
EXPECT_EQ(get_device_buffer_address(tensor3), tensor2_device_buf_addr);
EXPECT_EQ(get_device_buffer_address(tensor1), get_device_buffer_address(tensor2));
}
log_info(LogTest, "Testing Device tensor self-assignment through function");
for (int i = 0; i < 5; i++) {
- Tensor device_tensor =
- ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16)
- .to(device);
+ Tensor device_tensor = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}),
+ static_cast<float>(i),
+ DataType::BFLOAT16,
+ /*layout=*/std::nullopt,
+ *device);
uint32_t device_tensor_address = get_device_buffer_address(device_tensor);
// This step will copy the tensor to a temp rval and std::move it back to the caller's instance of device_tensor
// Ensure ref count and address remain unchanged
@@ -217,18 +219,19 @@ TEST_F(DispatchFixture, TestAsyncRefCountManager) {
log_info(LogTest, "Testing Device tensor move assignment");
for (int i = 0; i < 5; i++) {
- Tensor tensor1 = ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(i), DataType::BFLOAT16)
- .to(device);
+ Tensor tensor1 = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}),
+ static_cast<float>(i),
+ DataType::BFLOAT16,
+ /*layout=*/std::nullopt,
+ *device);
Tensor tensor2 = std::move(tensor1);
EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 1);
}
log_info(LogTest, "Testing Device tensor self-assignment");
- Tensor tensor_to_self_assign =
- ttnn::numpy::full(
- tt::tt_metal::LegacyShape({1, 1, 1024, 1024}), static_cast<float>(0), DataType::BFLOAT16)
- .to(device);
+ Tensor tensor_to_self_assign = ttnn::full(
+ ttnn::Shape({1, 1, 1024, 1024}), static_cast<float>(0), DataType::BFLOAT16, /*layout=*/std::nullopt, *device);
uint32_t tensor_to_self_assign_address = get_device_buffer_address(tensor_to_self_assign);
tensor_to_self_assign = tensor_to_self_assign;
EXPECT_EQ(tensor_to_self_assign.tensor_attributes->main_thread_ref_count, 1);
@@ -255,7 +258,7 @@ TEST_F(DispatchFixture, TestTensorAsyncDataMovement) {
{
// host_tensor only lives in this scope
- Tensor host_tensor = ttnn::numpy::arange(tensor_start, tensor_stop, 1);
+ Tensor host_tensor = ttnn::arange(tensor_start, tensor_stop, /*step=*/1, DataType::FLOAT32);
log_info(LogTest, "Spawning worker thread");
worker = std::thread([tensor_stop, host_tensor, readback_tensor, device]() mutable {
// Sleep for 3 seconds to ensure that main thread deallocates host_tensor
@@ -338,3 +341,5 @@ TEST_F(DispatchFixture, TestTensorAsyncDataMovement) {
EXPECT_EQ(readback_tensor.get_layout(), Layout::ROW_MAJOR);
EXPECT_EQ(readback_tensor.get_shape(), ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 32, tensor_stop / 32})));
}
+} // namespace
+} // namespace tt::tt_metal
diff --git a/tests/tt_eager/tensors/test_copy_and_move.cpp b/tests/tt_eager/tensors/test_copy_and_move.cpp
index f735791ad4c..656585f3351 100644
--- a/tests/tt_eager/tensors/test_copy_and_move.cpp
+++ b/tests/tt_eager/tensors/test_copy_and_move.cpp
@@ -2,12 +2,9 @@
//
// SPDX-License-Identifier: Apache-2.0
-#include
-#include
-#include
-
#include "common/bfloat16.hpp"
#include "common/constants.hpp"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/types.hpp"
#include "ttnn/tensor/tensor.hpp"
@@ -40,7 +37,7 @@ bool test_tensor_copy_semantics(Device* device) {
pass &= dev_a_data == dev_a_copy_data;
// host tensor updated with host tensor copy assignment
- Tensor host_c = ttnn::numpy::arange(0, tt_metal::compute_volume(single_tile_shape), 1)
+ Tensor host_c = ttnn::arange(/*start=*/0, /*stop=*/tt_metal::compute_volume(single_tile_shape), /*step=*/1)
.reshape(single_tile_shape)
.to(Layout::TILE);
Tensor host_c_copy = ttnn::numpy::random::random(single_tile_shape).to(Layout::TILE);
@@ -58,7 +55,7 @@ bool test_tensor_copy_semantics(Device* device) {
pass &= dev_a_data == host_d_copy_data;
// dev tensor updated with host tensor copy assignment
- Tensor host_e = ttnn::numpy::ones(single_tile_shape).to(Layout::TILE);
+ Tensor host_e = ttnn::ones(single_tile_shape, DataType::BFLOAT16, Layout::TILE);
Tensor dev_e_copy = ttnn::numpy::random::random(single_tile_shape).to(Layout::TILE).to(device);
dev_e_copy = host_e;
pass &= (dev_e_copy.storage_type() == StorageType::OWNED);
@@ -67,8 +64,8 @@ bool test_tensor_copy_semantics(Device* device) {
pass &= host_e_data == dev_e_copy_data;
// dev tensor updated with dev tensor copy assignment
- Tensor dev_b = ttnn::numpy::ones(single_tile_shape).to(Layout::TILE).to(device);
- Tensor dev_b_copy = ttnn::numpy::zeros(single_tile_shape).to(Layout::TILE).to(device);
+ Tensor dev_b = ttnn::ones(single_tile_shape, DataType::BFLOAT16, Layout::TILE, *device);
+ Tensor dev_b_copy = ttnn::zeros(single_tile_shape, DataType::BFLOAT16, Layout::TILE, *device);
dev_b_copy = dev_b;
pass &= (dev_b_copy.storage_type() == StorageType::DEVICE);
auto dev_b_on_host = dev_b.cpu();
diff --git a/tests/tt_metal/distributed/test_distributed.cpp b/tests/tt_metal/distributed/test_distributed.cpp
index 26df7dbcc78..45afeece3be 100644
--- a/tests/tt_metal/distributed/test_distributed.cpp
+++ b/tests/tt_metal/distributed/test_distributed.cpp
@@ -46,7 +46,7 @@ TEST(MeshDeviceSuite, Test1x1SystemMeshInitialize) {
auto& sys = tt::tt_metal::distributed::SystemMesh::instance();
auto config =
- tt::tt_metal::distributed::MeshDeviceConfig({1, 1}, std::pair(0, 0), {}, MeshType::RowMajor);
+ tt::tt_metal::distributed::MeshDeviceConfig(MeshShape(1, 1), MeshOffset(0, 0), {}, MeshType::RowMajor);
EXPECT_NO_THROW({
auto mesh = tt::tt_metal::distributed::MeshDevice::create(
diff --git a/tests/tt_metal/tools/profiler/test_device_profiler.py b/tests/tt_metal/tools/profiler/test_device_profiler.py
index 01d71f987aa..1c2ab3823cc 100644
--- a/tests/tt_metal/tools/profiler/test_device_profiler.py
+++ b/tests/tt_metal/tools/profiler/test_device_profiler.py
@@ -7,6 +7,7 @@
import re
import inspect
import pytest
+import subprocess
import pandas as pd
@@ -24,6 +25,30 @@
PROG_EXMP_DIR = "programming_examples/profiler"
+def get_device_data(setupStr=""):
+ postProcessRun = os.system(
+ f"cd {PROFILER_SCRIPTS_ROOT} && " f"./process_device_log.py {setupStr} --no-artifacts --no-print-stats"
+ )
+
+ assert postProcessRun == 0, f"Log process script crashed with exit code {postProcessRun}"
+
+ devicesData = {}
+ with open(f"{PROFILER_ARTIFACTS_DIR}/output/device/device_analysis_data.json", "r") as devicesDataJson:
+ devicesData = json.load(devicesDataJson)
+
+ return devicesData
+
+
+def run_gtest_profiler_test(testbin, testname):
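+ # Run a single gtest from the given binary and post-process the device
+ # profiler log, unless the test reports SKIPPED.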
+ clear_profiler_runtime_artifacts()
+ output = subprocess.check_output(
+ f"cd {TT_METAL_HOME} && {testbin} --gtest_filter={testname}", stderr=subprocess.STDOUT, shell=True
+ ).decode("UTF-8")
+ print(output)
+ if "SKIPPED" not in output:
+ get_device_data()
+
+
def run_device_profiler_test(testName=None, setup=False, slowDispatch=False):
name = inspect.stack()[1].function
testCommand = f"build/{PROG_EXMP_DIR}/{name}"
@@ -41,17 +66,7 @@ def run_device_profiler_test(testName=None, setup=False, slowDispatch=False):
if setup:
setupStr = f"-s {name}"
- postProcessRun = os.system(
- f"cd {PROFILER_SCRIPTS_ROOT} && " f"./process_device_log.py {setupStr} --no-artifacts --no-print-stats"
- )
-
- assert postProcessRun == 0, f"Log process script crashed with exit code {postProcessRun}"
-
- devicesData = {}
- with open(f"{PROFILER_ARTIFACTS_DIR}/output/device/device_analysis_data.json", "r") as devicesDataJson:
- devicesData = json.load(devicesDataJson)
-
- return devicesData
+ return get_device_data(setupStr)
def get_function_name():
@@ -231,6 +246,8 @@ def test_profiler_host_device_sync():
assert freq < (reportedFreq * (1 + TOLERANCE)), f"Frequency too large on device {device}"
assert freq > (reportedFreq * (1 - TOLERANCE)), f"Frequency too small on device {device}"
+ os.environ["TT_METAL_PROFILER_SYNC"] = "0"
+
def test_timestamped_events():
OP_COUNT = 2
@@ -268,3 +285,19 @@ def test_timestamped_events():
devicesData["data"]["devices"]["0"]["cores"]["DEVICE"]["riscs"]["TENSIX"]["events"]["all_events"]
)
assert eventCount in REF_COUNT_DICT[ENV_VAR_ARCH_NAME], "Wrong event count"
+
+
+def test_sub_device_profiler():
+ run_gtest_profiler_test(
+ "./build/test/tt_metal/unit_tests_dispatch", "CommandQueueSingleCardFixture.TensixTestSubDeviceBasicPrograms"
+ )
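+ # Enable host/device profiler sync for the eth variant, then restore the
+ # default of "0" before the trace test.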
+ os.environ["TT_METAL_PROFILER_SYNC"] = "1"
+ run_gtest_profiler_test(
+ "./build/test/tt_metal/unit_tests_dispatch",
+ "CommandQueueSingleCardFixture.TensixActiveEthTestSubDeviceBasicEthPrograms",
+ )
+ os.environ["TT_METAL_PROFILER_SYNC"] = "0"
+ run_gtest_profiler_test(
+ "./build/test/tt_metal/unit_tests_dispatch_trace",
+ "CommandQueueSingleCardTraceFixture.TensixTestSubDeviceTraceBasicPrograms",
+ )
diff --git a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp
index 624226798ec..8c332185840 100644
--- a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp
+++ b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_allocation.cpp
@@ -41,8 +41,7 @@ void validate_cb_address(
for (const auto& [buffer_index, expected_address] : address_per_buffer_index) {
auto base_index = UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG * buffer_index;
- EXPECT_EQ(
- expected_address >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES, cb_config_vector.at(base_index));
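+ // Assumes CB config entries now hold byte addresses directly, with no
+ // CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES shift applied.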
+ EXPECT_EQ(expected_address, cb_config_vector.at(base_index));
}
}
}
@@ -358,9 +357,8 @@ TEST_F(DeviceFixture, TensixTestUpdateCircularBufferPageSize) {
for (const auto& [buffer_index, expected_address] : address_per_buffer_index) {
auto base_index = UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG * buffer_index;
- EXPECT_EQ(
- expected_address >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config_vector.at(base_index)); // address validation
+ EXPECT_EQ(expected_address,
+ cb_config_vector.at(base_index)); // address validation
EXPECT_EQ(
num_pages_per_buffer_index.at(buffer_index),
cb_config_vector.at(base_index + 2)); // num pages validation
@@ -391,9 +389,8 @@ TEST_F(DeviceFixture, TensixTestUpdateCircularBufferPageSize) {
for (const auto& [buffer_index, expected_address] : address_per_buffer_index) {
auto base_index = UINT32_WORDS_PER_LOCAL_CIRCULAR_BUFFER_CONFIG * buffer_index;
- EXPECT_EQ(
- expected_address >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config_vector.at(base_index)); // address validation
+ EXPECT_EQ(expected_address,
+ cb_config_vector.at(base_index)); // address validation
EXPECT_EQ(
num_pages_per_buffer_index.at(buffer_index),
cb_config_vector.at(base_index + 2)); // num pages validation
diff --git a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp
index b9d1c369973..1278c8abb7d 100644
--- a/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp
+++ b/tests/tt_metal/tt_metal/api/circular_buffer/test_CircularBuffer_creation.cpp
@@ -65,22 +65,10 @@ TEST_F(DeviceFixture, TensixTestCreateCircularBufferAtValidIndices) {
uint32_t l1_unreserved_base = devices_.at(0)->get_base_allocator_addr(HalMemType::L1);
 std::map<uint8_t, std::vector<uint32_t>> golden_cb_config = {
- {0,
- {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.num_pages}},
- {2,
- {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.num_pages}},
- {16,
- {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.num_pages}},
- {24,
- {l1_unreserved_base >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.page_size >> CIRCULAR_BUFFER_LOG2_WORD_SIZE_BYTES,
- cb_config.num_pages}}};
+ {0, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}},
+ {2, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}},
+ {16, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}},
+ {24, {l1_unreserved_base, cb_config.page_size, cb_config.num_pages}}};
 std::map<uint8_t, tt::DataFormat> data_format_spec = {
{0, cb_config.data_format},
{2, cb_config.data_format},
diff --git a/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp b/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp
index ff143198763..b51338d1d0f 100644
--- a/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp
+++ b/tests/tt_metal/tt_metal/common/command_queue_fixture.hpp
@@ -8,7 +8,7 @@
#include "dispatch_fixture.hpp"
#include "hostdevcommon/common_values.hpp"
#include "impl/device/device.hpp"
-#include "umd/device/tt_cluster_descriptor_types.h"
+#include "umd/device/types/cluster_descriptor_types.h"
#include "tt_metal/host_api.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/test_utils/env_vars.hpp"
diff --git a/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp b/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp
index 7656ac8c147..57bfbcdb934 100644
--- a/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp
+++ b/tests/tt_metal/tt_metal/common/dispatch_fixture.hpp
@@ -46,7 +46,7 @@ class DispatchFixture : public ::testing::Test {
}
void ReadBuffer(
tt::tt_metal::Device* device,
- std::shared_ptr<tt::tt_metal::Buffer> out_buffer,
+ const std::shared_ptr<tt::tt_metal::Buffer>& out_buffer,
 std::vector<uint32_t>& dst_vec) {
if (this->slow_dispatch_) {
tt::tt_metal::detail::ReadFromBuffer(out_buffer, dst_vec);
diff --git a/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp b/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp
index 7e1626ba7fe..21cf4dc2943 100644
--- a/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp
+++ b/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp
@@ -8,7 +8,7 @@
#include "host_api.hpp"
#include "dispatch_fixture.hpp"
-#include "umd/device/tt_cluster_descriptor_types.h"
+#include "umd/device/types/cluster_descriptor_types.h"
#include "tt_metal/test_utils/env_vars.hpp"
#include "tt_metal/impl/device/device_pool.hpp"
diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp
index 356f7766820..5dd7eea0042 100644
--- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp
+++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_EnqueueProgram.cpp
@@ -101,7 +101,7 @@ bool cb_config_successful(Device* device, Program& program, const DummyProgramMu
tt::tt_metal::detail::ReadFromDeviceL1(
device,
core_coord,
- program.get_sem_base_addr(device, core_coord, CoreType::WORKER),
+ program.get_cb_base_addr(device, core_coord, CoreType::WORKER),
cb_config_buffer_size,
cb_config_vector);
@@ -110,8 +110,8 @@ bool cb_config_successful(Device* device, Program& program, const DummyProgramMu
const uint32_t index = program_config.cb_config_vector[i].cb_id * sizeof(uint32_t);
const uint32_t cb_num_pages = program_config.cb_config_vector[i].num_pages;
const uint32_t cb_size = cb_num_pages * program_config.cb_config_vector[i].page_size;
- const bool addr_match = cb_config_vector.at(index) == ((cb_addr) >> 4);
- const bool size_match = cb_config_vector.at(index + 1) == (cb_size >> 4);
+ const bool addr_match = cb_config_vector.at(index) == cb_addr;
+ const bool size_match = cb_config_vector.at(index + 1) == cb_size;
const bool num_pages_match = cb_config_vector.at(index + 2) == cb_num_pages;
pass &= (addr_match and size_match and num_pages_match);
@@ -860,15 +860,15 @@ TEST_F(CommandQueueSingleCardProgramFixture, TensixTestMultiCBSharedAddressSpace
uint32_t cb_addr = device->get_base_allocator_addr(HalMemType::L1);
uint32_t intermediate_index = intermediate_cb * sizeof(uint32_t);
- bool addr_match_intermediate = cb_config_vector.at(intermediate_index) == ((cb_addr) >> 4);
- bool size_match_intermediate = cb_config_vector.at(intermediate_index + 1) == (cb_size >> 4);
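+ // As above: compare raw byte values now that the >>4 word conversion is gone.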
+ bool addr_match_intermediate = cb_config_vector.at(intermediate_index) == cb_addr;
+ bool size_match_intermediate = cb_config_vector.at(intermediate_index + 1) == cb_size;
bool num_pages_match_intermediate = cb_config_vector.at(intermediate_index + 2) == num_tiles;
bool pass_intermediate = (addr_match_intermediate and size_match_intermediate and num_pages_match_intermediate);
EXPECT_TRUE(pass_intermediate);
uint32_t out_index = out_cb * sizeof(uint32_t);
- bool addr_match_out = cb_config_vector.at(out_index) == ((cb_addr) >> 4);
- bool size_match_out = cb_config_vector.at(out_index + 1) == (cb_size >> 4);
+ bool addr_match_out = cb_config_vector.at(out_index) == cb_addr;
+ bool size_match_out = cb_config_vector.at(out_index + 1) == cb_size;
bool num_pages_match_out = cb_config_vector.at(out_index + 2) == num_tiles;
bool pass_out = (addr_match_out and size_match_out and num_pages_match_out);
EXPECT_TRUE(pass_out);
diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp
index d54453532aa..d0c969bee1b 100644
--- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp
+++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_dispatch_program_with_kernel_created_from_string.cpp
@@ -11,7 +11,7 @@
#include "impl/kernels/data_types.hpp"
#include "impl/kernels/kernel_types.hpp"
#include "impl/program/program.hpp"
-#include "umd/device/tt_cluster_descriptor_types.h"
+#include "umd/device/types/cluster_descriptor_types.h"
#include "program_with_kernel_created_from_string_fixture.hpp"
TEST_F(ProgramWithKernelCreatedFromStringFixture, TensixDataMovementKernel) {
diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp
index 6016433f556..f140433f3a9 100644
--- a/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp
+++ b/tests/tt_metal/tt_metal/dispatch/dispatch_program/test_sub_device.cpp
@@ -108,6 +108,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceBasicPrograms) {
EnqueueProgram(device->command_queue(), incrementer_program, false);
}
Synchronize(device);
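+ // Collect device-side profiler results for the programs run above.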
+ detail::DumpDeviceProfileResults(device);
}
}
@@ -136,5 +137,6 @@ TEST_F(CommandQueueSingleCardFixture, TensixActiveEthTestSubDeviceBasicEthProgra
EnqueueProgram(device->command_queue(), incrementer_program, false);
}
Synchronize(device);
+ detail::DumpDeviceProfileResults(device);
}
}
diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp
index 74fa8256ab8..5caff9052aa 100644
--- a/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp
+++ b/tests/tt_metal/tt_metal/dispatch/dispatch_trace/test_sub_device.cpp
@@ -63,6 +63,7 @@ TEST_F(CommandQueueSingleCardTraceFixture, TensixTestSubDeviceTraceBasicPrograms
ReplayTrace(device, device->command_queue().id(), tid_2, false);
}
Synchronize(device);
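+ // As in the dispatch_program tests, collect device profiler results once trace replay completes.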
+ detail::DumpDeviceProfileResults(device);
}
}
diff --git a/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp b/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp
index 17a4da1cd7b..75c4c3f5cca 100644
--- a/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp
+++ b/tests/tt_metal/tt_metal/dispatch/multi_command_queue_fixture.hpp
@@ -9,7 +9,7 @@
#include "hostdevcommon/common_values.hpp"
#include "impl/device/device.hpp"
#include "llrt/hal.hpp"
-#include "umd/device/tt_cluster_descriptor_types.h"
+#include "umd/device/types/cluster_descriptor_types.h"
#include "tt_metal/host_api.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/test_utils/env_vars.hpp"
diff --git a/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp b/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp
index 55c9d3d40a4..fc87c2b58df 100644
--- a/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp
+++ b/tests/tt_metal/tt_metal/dispatch/random_program_fixture.hpp
@@ -9,6 +9,7 @@
#include "llrt/hal.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/detail/tt_metal.hpp"
+#include "tt_metal/hw/inc/circular_buffer_constants.h"
#include "tt_metal/impl/kernels/kernel.hpp"
#include "tt_metal/common/tt_backend_api_types.hpp"
#include "dispatch_test_utils.hpp"
@@ -141,13 +142,14 @@ class RandomProgramFixture : virtual public CommandQueueSingleCardProgramFixture
const uint32_t num_cbs = this->generate_random_num(min, max);
std::vector<uint32_t> cb_page_sizes;
for (uint32_t cb_idx = 0; cb_idx < num_cbs; cb_idx++) {
- const uint32_t cb_page_size = this->generate_random_num(MIN_CB_PAGE_SIZE, MAX_CB_PAGE_SIZE, 16);
+ const uint32_t cb_page_size =
+ this->generate_random_num(MIN_CB_PAGE_SIZE, MAX_CB_PAGE_SIZE, CIRCULAR_BUFFER_COMPUTE_WORD_SIZE);
const uint32_t cb_total_size =
this->generate_random_num(MIN_CB_TOTAL_SIZE, MAX_CB_TOTAL_SIZE, cb_page_size);
CircularBufferConfig config = CircularBufferConfig(cb_total_size, {{cb_idx, tt::DataFormat::Float16_b}})
.set_page_size(cb_idx, cb_page_size);
CreateCircularBuffer(program, cores, config);
- cb_page_sizes.push_back(cb_page_size / 16);
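+ // Record page sizes in bytes; the old code stored 16-byte word counts.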
+ cb_page_sizes.push_back(cb_page_size);
}
return cb_page_sizes;
}
diff --git a/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp b/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp
index 1a2618ddc09..4c4077c4adc 100644
--- a/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp
+++ b/tests/tt_metal/tt_metal/integration/test_autonomous_relay_streams.cpp
@@ -11,7 +11,7 @@
#include
#include "gtest/gtest.h"
-#include "umd/device/tt_arch_types.h"
+#include "umd/device/types/arch.h"
#include "command_queue_fixture.hpp"
#include "tt_metal/common/logger.hpp"
#include "impl/device/device.hpp"
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp
index 12ad5aa9c30..ff9b279d489 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp
@@ -17,7 +17,7 @@ void kernel_main() {
volatile tt_l1_ptr uint32_t* done_address_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(done_address);
- uint64_t pcie_noc_xy_encoding = (uint64_t)NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX);
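+ // NOC_XY_PCIE_ENCODING no longer takes a NOC index argument; only the PCIe NOC coordinates are passed.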
+ uint64_t pcie_noc_xy_encoding = (uint64_t)NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y);
while (done_address_ptr[0] == 0) {
uint64_t host_src_addr = pcie_noc_xy_encoding | pcie_read_ptr;
noc_async_read(host_src_addr, done_address, read_sizeB);
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp
index 600483ffe02..22127a8fb8a 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp
@@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0
-#include "umd/device/tt_cluster_descriptor_types.h"
+#include "umd/device/types/cluster_descriptor_types.h"
#include "tt_metal/host_api.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/impl/dispatch/command_queue.hpp"
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp
index f5353a53a35..f351b0a75d7 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp
@@ -9,7 +9,7 @@
#include
#include
-#include "umd/device/tt_arch_types.h"
+#include "umd/device/types/arch.h"
#include "impl/device/device.hpp"
#include "impl/kernels/kernel_types.hpp"
#include "tt_backend_api_types.hpp"
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp
index fbbdcfaf5c7..3123ea1736a 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp
@@ -10,7 +10,7 @@
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/common/logger.hpp"
-#include "umd/device/tt_arch_types.h"
+#include "umd/device/types/arch.h"
#include "impl/device/device.hpp"
#include "impl/kernels/data_types.hpp"
#include "impl/kernels/kernel_types.hpp"
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp
index 1cc7f4e6c7e..d2544f39271 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp
@@ -10,7 +10,7 @@
#include
#include