
#13397: Add data parallel support for SqueezeBERT model (PR #13418)

Open · wants to merge 1 commit into main from keerthana/functional_squeezebert_dataparallel
Conversation

@kkeerthana0573 (Contributor) commented Oct 3, 2024

Ticket

Link to GitHub Issue: #13397

Problem description

The SqueezeBERT model is configured to run on either N150 or N300, depending on the available machine.
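
For reference, a minimal sketch of how this single- versus two-device dispatch looks, based on the device check that appears in the review snippets later in this conversation (the import path and the per-device batch size of 8 are assumptions):

```python
import ttnn
from models.utility_functions import is_wormhole_b0  # import path is an assumption

# N300 exposes two Wormhole devices; N150 exposes one.
mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2
num_devices = 2 if mesh_device_flag else 1

# Scale the global batch across the mesh; each device keeps its own batch of 8.
batch_size_per_device = 8  # assumed per-device batch, matching the demo's batch_size=8
global_batch_size = batch_size_per_device * num_devices  # 16 on N300, 8 on N150
```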

Checklist

  • Post commit CI passes
  • Blackhole Post commit (if applicable)
  • Model regression CI testing passes (if applicable)
  • Device performance regression CI testing passes (if applicable)
  • New/Existing tests provide coverage for changes

@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch from e698cc6 to e041c8a on October 4, 2024 03:35
@tt-rkim (Collaborator) commented Nov 18, 2024

Can you post passing links...

@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch from e02f5ea to 7842fe6 on November 19, 2024 08:03
@tt-rkim (Collaborator) commented Nov 19, 2024

You posted the wrong device perf link.

I found it: https://github.com/tenstorrent/tt-metal/actions/runs/11908678446/job/33185055179

Please post the right link next time.

By the way, it seems to have failed on your model.

@kkeerthana0573 (Contributor Author):

@tt-rkim,
I might have overlooked the links. I'll update the PR.
Thank you.

@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch from 7842fe6 to 5189ca0 on November 20, 2024 07:55

@tt-rkim (Collaborator) commented Nov 20, 2024

I will approve to unblock since it's the only one left, but please ensure ttnn nightly passes.

@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch 2 times, most recently from 2f5a5b4 to 010fa8e on November 22, 2024 07:13
models/demos/wormhole/squeezebert/demo/demo.py (outdated, resolved)

del tt_output
i += 1
eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels)
Contributor:

it would be safer if the test failed when eval_score is lower than some threshold

Contributor Author:

The evaluation scores should ideally remain consistent for the batch size and number of iterations specified in the demo, though they may vary if the batch size changes. I’ve now added an assertion to validate the expected scores.

Contributor:

the check should be tighter. The purpose of this assert is to catch regressions if a bad commit goes in. For example, if I push a change to a kernel and the score drops from 98 to 97.9, this is a bug, and this test should catch it.

So, we want to assert if there are any changes in the eval score, maybe with a small margin.
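
A hedged sketch of such a tightened check, assuming the demo uses the Hugging Face SQuAD metric (whose result dict carries an "f1" key) and a placeholder baseline recorded from a known-good run:

```python
# Baseline from a known-good run; 98.0 is a placeholder, not the model's
# actual score.
EXPECTED_F1 = 98.0
MARGIN = 0.05  # small tolerance for run-to-run numerical noise

eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels)
assert eval_score["f1"] >= EXPECTED_F1 - MARGIN, (
    f"F1 regressed to {eval_score['f1']:.3f}; expected at least {EXPECTED_F1 - MARGIN:.3f}"
)
```

With a margin this small, even the 98 → 97.9 drop in the reviewer's example trips the assert.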



def get_expected_times(squeezebert):
    return {ttnn_functional_squeezebert: (29.29, 15.5)}[squeezebert]
Contributor:

what's the current run time? How close is it to the expected times?

@kkeerthana0573 (Contributor Author) commented Nov 29, 2024:

The current runtimes in the test file have been updated based on the average times observed during CI runs.
We’re uncertain about the target numbers. Is there any other metric, besides the average CI times, that we can use to determine the target numbers?

cc: @boris-drazic, @mbahnasTT.

Contributor:

You want the expected times to be close to the current time, but not so close that a small variation will cause the test to fail. Maybe a 10-20% margin is good.
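
A sketch of deriving the thresholds that way, assuming the test measures inference_time itself and treating the CI averages below as placeholders to be re-measured:

```python
# Averages observed across CI runs (placeholders, mirroring get_expected_times).
measured_compile_time_s = 29.29
measured_inference_time_s = 15.5
HEADROOM = 1.15  # 15% margin, inside the suggested 10-20% band

expected_compile_time_s = measured_compile_time_s * HEADROOM
expected_inference_time_s = measured_inference_time_s * HEADROOM

# inference_time is assumed to come from the test's own profiler measurement.
assert inference_time < expected_inference_time_s, (
    f"inference took {inference_time:.2f}s, expected under {expected_inference_time_s:.2f}s"
)
```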

@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch 3 times, most recently from 436aa1d to b25513c on November 29, 2024 11:22
@saichandax requested a review from uaydonat on December 3, 2024 04:56
@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch from b25513c to 8ddc243 on December 3, 2024 10:40
@kkeerthana0573 requested a review from a team as a code owner on December 3, 2024 10:40
@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch from 8ddc243 to 080e26d on December 3, 2024 12:58
mesh_device=mesh_device,
use_program_cache=use_program_cache,
model_name=model_name,
batch_size=8,
Contributor:

this should call with correct batch size


profiler.start(f"preprocessing_parameter")
mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2
batch_size = batch_size * 2 if mesh_device_flag else batch_size
Contributor:

do not overwrite the batch_size that the caller gives

tt_model_name = f"ttnn_{model_name}_optimized"

mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2
batch_size = batch_size * 2 if mesh_device_flag else batch_size
Contributor:

do not overwrite the batch_size that the caller gives
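
One way to honor both comments, sketched here with hypothetical names (run_demo stands in for the actual demo entry point, the fixtures are assumed to come from the repo's conftest, and the import path is an assumption): resolve the device count once at the call site and parametrize the test with the final batch size, so nothing downstream rescales it.

```python
import pytest
import ttnn
from models.utility_functions import is_wormhole_b0  # import path is an assumption

# Resolve the device count once, at collection time.
NUM_DEVICES = 2 if (is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2) else 1

@pytest.mark.parametrize("model_name", ["squeezebert/squeezebert-uncased"])
@pytest.mark.parametrize("batch_size", [8 * NUM_DEVICES])  # global batch, already scaled
def test_demo(mesh_device, use_program_cache, model_name, batch_size):
    # run_demo is a hypothetical stand-in for the demo entry point; it receives
    # the final batch size and must not multiply it again internally.
    run_demo(
        mesh_device=mesh_device,
        use_program_cache=use_program_cache,
        model_name=model_name,
        batch_size=batch_size,
    )
```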



@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch 3 times, most recently from 5dd5731 to 9d46e12 on December 9, 2024 14:54
@kkeerthana0573 force-pushed the keerthana/functional_squeezebert_dataparallel branch from 9d46e12 to 6a66479 on December 9, 2024 15:51