-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#0: Merge branch 'main' into llama32-vision
- Loading branch information
Showing
331 changed files
with
5,508 additions
and
1,101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,38 @@ on: | |
- "main" | ||
|
||
jobs: | ||
pre-commit: | ||
name: Run Pre-commit Hooks | ||
runs-on: ubuntu-latest | ||
permissions: | ||
contents: write | ||
pull-requests: write | ||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 # Fetch all history so 'origin/main' is available | ||
fetch-refs: true # Ensure all refs are fetched | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: 3.11 | ||
|
||
- name: Run Pre-commit | ||
uses: pre-commit/[email protected] | ||
with: | ||
extra_args: | | ||
--from-ref ${{ github.event_name == 'pull_request' && format('refs/remotes/origin/{0}', github.event.pull_request.base.ref) || 'HEAD^' }} \ | ||
--to-ref HEAD | ||
continue-on-error: false | ||
check-black: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Do Nothing | ||
run: echo "Black is covered by pre-commit. This is a placeholder to be removed after updating branch restrictions." | ||
|
||
|
||
check-spdx-licenses: | ||
runs-on: ubuntu-latest | ||
steps: | ||
|
@@ -27,11 +59,6 @@ jobs: | |
- uses: actions/checkout@v4 | ||
- name: Check kernel count in base metal is less than maximum | ||
run: if (( $(find tt_metal/kernels/ -type f | wc -l) > 7 )); then exit 1; fi | ||
check-black: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: psf/[email protected] | ||
check-doc: | ||
runs-on: ubuntu-latest | ||
steps: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
models/experimental/yolov4/ttnn/weight_parameter_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import re | ||
from collections import OrderedDict | ||
|
||
|
||
def update_weigth_keys(key): | ||
key = key.replace("downsample", "down") | ||
key = key.replace("neck", "neek") | ||
if ".res" in key: | ||
|
||
def res_name_update(match): | ||
chr = match.group(1) | ||
num = int(match.group(2)) | ||
if num == 0 or num == 1: | ||
return f".{chr}.0.conv.{num}." | ||
if num == 3 or num == 4: | ||
return f".{chr}.1.conv.{num-3}." | ||
|
||
key = re.sub(r"\.res\.", r".resblock.", key) | ||
key = re.sub(r"\.(\d+)\.(\d+)\.", res_name_update, key) | ||
return key | ||
if "neek" in key: | ||
|
||
def neek_underscore_update_rule(match): | ||
chr = match.group(1) | ||
num1 = int(match.group(2)) | ||
num2 = int(match.group(3)) | ||
dict = { | ||
(7, 2): 8, | ||
(7, 3): 9, | ||
(7, 4): 11, | ||
(8, 2): 12, | ||
(7, 5): 13, | ||
(9, 2): 15, | ||
(9, 3): 16, | ||
(9, 4): 18, | ||
(10, 2): 19, | ||
(9, 5): 20, | ||
} | ||
if chr == "b": | ||
return f".conv{dict[(num1, num2)]}.conv.1." | ||
return f".conv{dict[(num1, num2)]}.conv.0." | ||
|
||
def neck_rename_update(match): | ||
chr = match.group(1) | ||
num = int(match.group(2)) | ||
if num <= 7: | ||
return f".conv{num}.conv.1." if chr == "b" else f".conv{num}.conv.0." | ||
dict = {8: 10, 9: 14, 10: 17} | ||
return f".conv{dict[num]}.conv.1." if chr == "b" else f".conv{dict[num]}.conv.0." | ||
|
||
updated_name = re.sub(r"\.([a-z])(\d+)_(\d+)\.", neek_underscore_update_rule, key) | ||
if key != updated_name: # chk if name got updated | ||
return updated_name | ||
updated_name = re.sub(r"\.([a-z])(\d+)\.", neck_rename_update, key) | ||
if key != updated_name: | ||
return updated_name | ||
key = re.sub(r"\.c(\d+)\.", r".conv\1.conv.0.", key) | ||
key = re.sub(r"\.b(\d+)\.", r".conv\1.conv.1.", key) | ||
return key | ||
|
||
|
||
def update_weight_parameters(model_weight): | ||
ttnn_model_random_weight = OrderedDict() | ||
for key, weight in model_weight.items(): | ||
updated_key = update_weigth_keys(key) | ||
ttnn_model_random_weight[updated_key] = weight | ||
return ttnn_model_random_weight |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.