Commit 4f3b1b1
#9021: adding resnet api into ci.
shwetankTT committed Jul 13, 2024
1 parent e9b5260 commit 4f3b1b1
Showing 2 changed files with 22 additions and 7 deletions.
3 changes: 2 additions & 1 deletion tests/scripts/run_python_model_tests.sh
@@ -40,7 +40,7 @@ if [ "$ARCH_NAME" != "wormhole_b0" ]; then
 
   # Resnet18 tests with conv on cpu and with conv on device
   pytest $TT_METAL_HOME/models/demos/resnet/tests/test_resnet18.py
-
+  pytest $TT_METAL_HOME/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py
   # Falcon tests
   pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"
   pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_512 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"
@@ -53,6 +53,7 @@ if [ "$ARCH_NAME" != "wormhole_b0" ]; then
   pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_attn_matmul.py -k "not attn_matmul_from_cache"
   # higher sequence lengths and different formats trigger memory issues
   pytest $TT_METAL_HOME/models/demos/falcon7b/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM"
+  pytest $TT_METAL_HOME/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py -k "pretrained_weight_false"
 
   SLOW_MATMULS=1 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest $TT_METAL_HOME/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py -k 512 --timeout=420
 fi
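
Aside: the -k "pretrained_weight_false" filter works because pytest matches the -k expression against test node IDs, and the parametrization added in the test file (shown below) supplies those IDs via ids=[...]. A minimal, self-contained sketch of the same pattern; test_toy is a hypothetical name, not from this repo:

import pytest

# The strings in ids= become part of each test's node ID, e.g.
# test_toy[pretrained_weight_false], so `pytest -k pretrained_weight_false`
# selects only that variant.
@pytest.mark.parametrize(
    "use_pretrained_weight",
    [True, False],
    ids=["pretrained_weight_true", "pretrained_weight_false"],
)
def test_toy(use_pretrained_weight):
    assert isinstance(use_pretrained_weight, bool)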
26 changes: 20 additions & 6 deletions tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py
@@ -133,7 +133,7 @@
 
 class ResNet50TestInfra:
     def __init__(
-        self, device, batch_size, act_dtype, weight_dtype, math_fidelity, dealloc_input, final_output_mem_config
+        self, device, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight, dealloc_input, final_output_mem_config
     ):
         super().__init__()
         torch.manual_seed(0)
@@ -147,7 +147,11 @@ def __init__(
         self.dealloc_input = dealloc_input
         self.final_output_mem_config = final_output_mem_config
 
-        torch_model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval()
+        torch_model = (
+            torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval()
+            if use_pretrained_weight
+            else torchvision.models.resnet50().eval()
+        )
 
         model_config = {
             "MATH_FIDELITY": math_fidelity,
@@ -225,15 +229,15 @@ def create_test_infra(
     act_dtype,
     weight_dtype,
     math_fidelity,
+    use_pretrained_weight=True,
     dealloc_input=True,
     final_output_mem_config=ttnn.L1_MEMORY_CONFIG,
 ):
     return ResNet50TestInfra(
-        device, batch_size, act_dtype, weight_dtype, math_fidelity, dealloc_input, final_output_mem_config
+        device, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight, dealloc_input, final_output_mem_config
     )
 
 
-@skip_for_wormhole_b0("PCC error with B=16. Fitting issue with B=20 due to 1x1s2 repleacement.")
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
 @pytest.mark.parametrize(
     "batch_size, act_dtype, weight_dtype, math_fidelity",
@@ -250,13 +254,23 @@ def create_test_infra(
         (20, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),
     ),
 )
-def test_resnet_50(device, use_program_cache, batch_size, act_dtype, weight_dtype, math_fidelity):
+@pytest.mark.parametrize(
+    "use_pretrained_weight",
+    [True, False],
+    ids=[
+        "pretrained_weight_true",
+        "pretrained_weight_false",
+    ],
+)
+def test_resnet_50(
+    device, use_program_cache, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight
+):
     if batch_size == 8:
         pytest.skip("Skipping batch size 8 due to memory config issue")
     if is_wormhole_b0() and batch_size == 20:
         pytest.skip("Skipping batch size 20 for Wormhole B0 due to fitting issue")
 
-    test_infra = create_test_infra(device, batch_size, act_dtype, weight_dtype, math_fidelity)
+    test_infra = create_test_infra(device, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight)
     enable_memory_reports()
     test_infra.preprocess_torch_input()
     # First run configures convs JIT
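
With the blanket @skip_for_wormhole_b0 decorator deleted, skipping is now decided inside the test body, so only the configurations that actually fail on that architecture are skipped while the rest (such as the pretrained_weight_false selection wired into the CI script above) keep running. A minimal sketch of the in-test skip pattern; is_wormhole_b0 here is a stand-in that reads ARCH_NAME, not the repo's own helper:

import os
import pytest

def is_wormhole_b0() -> bool:
    # Stand-in for the repo's architecture probe; mirrors the ARCH_NAME
    # variable the shell script branches on.
    return os.environ.get("ARCH_NAME") == "wormhole_b0"

@pytest.mark.parametrize("batch_size", [8, 16, 20])
def test_example(batch_size):
    # Skips are raised per configuration at runtime, so unaffected
    # parametrizations still run on every architecture.
    if batch_size == 8:
        pytest.skip("memory config issue")
    if is_wormhole_b0() and batch_size == 20:
        pytest.skip("fitting issue on Wormhole B0")
    assert batch_size > 0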
