forked from tenstorrent/tt-metal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_preprocessing.py
41 lines (31 loc) · 1.52 KB
/
custom_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
import ttnn
def preprocess_groupnorm_parameter(parameter, *, dtype):
    """Convert a torch GroupNorm weight/bias tensor to a tiled ttnn tensor.

    Args:
        parameter: torch tensor holding the GroupNorm parameter.
        dtype: target ttnn dtype (keyword-only).

    Returns:
        The parameter as a ttnn tensor in TILE_LAYOUT.
    """
    return ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT)
def preprocess_conv_parameter(parameter, *, dtype):
    """Convert a torch Conv2d weight/bias tensor to a tiled ttnn tensor.

    Args:
        parameter: torch tensor holding the convolution parameter.
        dtype: target ttnn dtype (keyword-only).

    Returns:
        The parameter as a ttnn tensor in TILE_LAYOUT.
    """
    return ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT)
def custom_preprocessor(model, name):
    """Convert a torch module's parameters into ttnn tensors for tt-metal.

    Supported module types: GroupNorm, Conv2d, Linear, LayerNorm. Any other
    module yields an empty dict (the caller's default preprocessing applies).

    Args:
        model: the torch ``nn.Module`` whose parameters are converted.
        name: the module's name in the parent model (unused here, but part
            of the preprocessor callback interface).

    Returns:
        Dict mapping "weight"/"bias" to ttnn tensors; keys are present only
        for parameters the module actually has.
    """
    parameters = {}
    # A module is exactly one of these types, so `elif` makes the dispatch
    # explicit and skips the redundant isinstance checks.
    if isinstance(model, nn.GroupNorm):
        # GroupNorm(affine=False) has no weight/bias — guard against None,
        # matching the bias guard in the Linear/LayerNorm branch below.
        if model.weight is not None:
            parameters["weight"] = preprocess_groupnorm_parameter(model.weight, dtype=ttnn.bfloat16)
        if model.bias is not None:
            parameters["bias"] = preprocess_groupnorm_parameter(model.bias, dtype=ttnn.bfloat16)
    elif isinstance(model, nn.Conv2d):
        # Reorder (out_ch, in_ch, kh, kw) -> (kh, kw, out_ch, in_ch) as
        # expected by the downstream ttnn conv implementation.
        weight = torch.permute(model.weight, (2, 3, 0, 1))
        parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16)
        # BUGFIX: Conv2d(bias=False) has model.bias == None; the original
        # converted it unconditionally, which would crash in ttnn.from_torch.
        if model.bias is not None:
            parameters["bias"] = preprocess_conv_parameter(model.bias, dtype=ttnn.bfloat16)
    elif isinstance(model, (nn.Linear, nn.LayerNorm)):
        # Linear stores weight as (out, in); transpose for the ttnn matmul
        # convention, then pad the rank up to 4D as ttnn expects.
        weight = model.weight.T.contiguous()
        while len(weight.shape) < 4:
            weight = weight.unsqueeze(0)
        parameters["weight"] = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
        if model.bias is not None:
            bias = model.bias
            while len(bias.shape) < 4:
                bias = bias.unsqueeze(0)
            parameters["bias"] = ttnn.from_torch(bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
    return parameters