Skip to content

Commit

Permalink
#5337: Merge branch 'mixtral-ttlib-matmuls' of github.com:tenstorrent…
Browse files Browse the repository at this point in the history
…/tt-metal into mixtral-ttlib-matmuls
  • Loading branch information
yieldthought committed May 31, 2024
2 parents 050312f + 2a85990 commit ed8ec62
Show file tree
Hide file tree
Showing 18 changed files with 1,031 additions and 672 deletions.
21 changes: 21 additions & 0 deletions docs/source/tt-metalium/tools/watcher.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,24 @@ watcher log:
0x00000020,0x0000001f,0x0000001e,0x0000001d,0x0000001c,0x0000001b,0x0000001a,0x00000019,
0x00000018,0x00000017,0x00000016,0x00000015,0x00000014,0x00000013,0x00000012,0x00000011,
0x00000010,0x0000000f,0x0000000e,0x0000000d,0x0000000c,0x0000000b,0x0000000a]
Debug Delays
------------
Watcher can insert NOC transaction delays for debugging purposes. These delays can be specified by
transaction type and location. Environment variable `TT_METAL_WATCHER_DELAY` specifies the number
of clock cycles to wait for. Similarly to DPRINT, the delay can be set for all cores, or a
or a subset by setting environment variable `TT_METAL_*_DEBUG_DELAY_CORES`: x,y OR (x1,y1),(x2,y2),(x3,y3) OR (x1,y1)-(x2,y2) OR all.
The * can be one of: READ, WRITE or ATOMIC indicating whether the delays will be inserted before read, write or atomic NOC
transactions. Finally, the delay can be set for a specific RISCs (BRISC, NCRISC, TRISC0, TRISC1, TRISC2) through the
environment variable `TT_METAL_*_DEBUG_DELAY_RISCVS`: (one of: BR,NC,TR0,TR1,TR2); if not set, the delay
is applied to all RISCs.
Note that `TT_METAL_WATCHER` must be set and `TT_METAL_WATCHER_DISABLE_NOC_SANITIZE` must not be
set for the delays to be applied.

For example, the following command will run test_eltwise_binary with a delay of 10 iterations added to both READ and WRITE
transactions on BRISC core at location 0,0:

.. code-block::
TT_METAL_WATCHER=1 TT_METAL_WATCHER_DEBUG_DELAY=10 TT_METAL_READ_DEBUG_DELAY_CORES=0,0 TT_METAL_WRITE_DEBUG_DELAY_CORES=0,0 TT_METAL_READ_DEBUG_DELAY_RISCVS=BR TT_METAL_WRITE_DEBUG_DELAY_RISCVS=BR ./build/test/tt_metal/test_eltwise_binary
61 changes: 38 additions & 23 deletions models/demos/t3000/mixtral8x7b/scripts/op_perf_results.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from argparse import ArgumentParser
import csv

Expand All @@ -10,6 +11,7 @@ def main():
"Parse an op perf results CSV and show performance data using the min allgather time and max other time over devices, optionally only for a specific signpost region."
)
parser.add_argument("csv", help="Input CSV file")
parser.add_argument("--all", help="Show all times for each device", action="store_true")
parser.add_argument("--signpost", help="Only include data after this signpost and before any others")
args = parser.parse_args()

Expand All @@ -18,12 +20,12 @@ def main():

print(f'{"Op":20} {"Time (us)"}')
for block in blocks:
print(block)
print(block.long_str() if args.all else block.short_str())

total_time_ns = sum(block.time for block in blocks)
total_time_ns = sum(block.time() for block in blocks)
total_time_s = total_time_ns / 1e9
tokens_per_s = 1 / total_time_s
print(f"Tokens/s/user: {tokens_per_s:.2f} ({total_time_s:.2f}s latency)")
print(f"Tokens/s/user: {tokens_per_s:.2f} ({total_time_s*1000:.1f} ms latency)")
if signposts_seen and not args.signpost:
print(f"Warning - this file contains the following signposts that were not used for this analysis:")
for s in signposts_seen:
Expand All @@ -43,12 +45,18 @@ class Block:
def __init__(self, op_name, times):
self.op_name = op_name
self.times = times
self.time = min(times) if "AllGather" in op_name else max(times)

def __str__(self):
def time(self):
return min(self.times) if "AllGather" in self.op_name else max(self.times)

def short_str(self):
short_name = self.op_name.split("::")[-1].split(")")[0]
time_range = max(self.times) - min(self.times)
return f"{short_name:20} {self.time/1000:-6.0f} ± {time_range/1000:-5.0f}"
return f"{short_name:20} {self.time()/1000:-6.0f} ± {time_range/1000:-5.0f}"

def long_str(self):
short_name = self.op_name.split("::")[-1].split(")")[0]
return f"{short_name:20} {self.time()/1000:-6.0f} <-" + " | ".join(f"{t/1000:-5.0f}" for t in self.times)

def __repr__(self):
return f"Block({self.op_name}, {self.times})"
Expand All @@ -60,10 +68,8 @@ def make_blocks(header, rows, signpost):
and a list of times for each device.
"""

# loop through the rows until we find a repeated device i, then emit a block
blocks = []
device_ids = set()
times = []
# group rows by device then merge them together
block_by_device = defaultdict(list)
stop_on_signpost = False
signposts_seen = []

Expand All @@ -84,22 +90,31 @@ def make_blocks(header, rows, signpost):
elif op_name == signpost:
# clear any previous data and stop on the next signpost
stop_on_signpost = True
blocks = []
device_ids = set()
times = []
block_by_device = defaultdict(list)
elif op_type == "tt_dnn_device":
device_id = int(row[DEVICE_ID])
time = int(row[FW_DURATION])
if device_id in device_ids:
blocks.append(Block(block_op_name, times))
device_ids = set()
times = []

block_op_name = op_name
device_ids.add(device_id)
times.append(time)

return blocks, signposts_seen
block_by_device[device_id].append(Block(op_name, [time]))

# merge each device block into a single block with all the device times,
# checking that the op name matches
# blocks_by_device is a dict of device_id -> Block
# we want to get a list of Block (with all device times)

device_ids = list(block_by_device.keys())
merged_blocks = block_by_device[device_ids[0]]

for device_id in device_ids[1:]:
assert len(block_by_device[device_id]) == len(
merged_blocks
), f"Device {device_id} has {len(block_by_device[device_id])} ops, expected {len(merged_blocks)} from previous devices"
for row, b in enumerate(block_by_device[device_id]):
assert (
b.op_name == merged_blocks[row].op_name
), f"Op name mismatch at row {row}: device {device_id} has {b.op_name} != {merged_blocks[row].op_name}"
merged_blocks[row].times += b.times

return merged_blocks, signposts_seen


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
w2_out = ttnn.experimental.operations.primary.matmul_1d(
w2_in,
self.w2,
program_config=self.model_config["FF2_OUTPUT_PROGCFG"],
program_config=self.model_config[
"FF3_OUTPUT_PROGCFG"
], # FF3 config avoids random hangs. TODO: Investigate why.
output_mem_config=self.model_config["FF2_OUTPUT_MEMCFG"],
compute_kernel_config=self.model_args.get_compute_kernel_config(),
output_dtype=ttnn.bfloat8_b,
Expand Down
106 changes: 0 additions & 106 deletions models/demos/wormhole/mistral7b/scripts/op_perf_results.py

This file was deleted.

1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(UNIT_TESTS_COMMON_SRC
${CMAKE_CURRENT_SOURCE_DIR}/dram/test_dram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_assert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_noc_sanitize.cpp
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_noc_sanitize_delays.cpp
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_pause.cpp
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_ringbuf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_waypoint.cpp
Expand Down
28 changes: 14 additions & 14 deletions tests/tt_metal/tt_metal/unit_tests_common/common/dprint_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ class DPrintFixture: public CommonFixture {
// The core range (physical) needs to be set >= the set of all cores
// used by all tests using this fixture, so set dprint enabled for
// all cores and all devices
tt::llrt::OptionsG.set_dprint_enabled(true);
tt::llrt::OptionsG.set_dprint_all_cores(CoreType::WORKER, true);
tt::llrt::OptionsG.set_dprint_all_cores(CoreType::ETH, true);
tt::llrt::OptionsG.set_dprint_all_chips(true);
tt::llrt::OptionsG.set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, true);
tt::llrt::OptionsG.set_feature_all_cores(tt::llrt::RunTimeDebugFeatureDprint, CoreType::WORKER, true);
tt::llrt::OptionsG.set_feature_all_cores(tt::llrt::RunTimeDebugFeatureDprint, CoreType::ETH, true);
tt::llrt::OptionsG.set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, true);
// Send output to a file so the test can check after program is run.
tt::llrt::OptionsG.set_dprint_file_name(dprint_file_name);
tt::llrt::OptionsG.set_feature_file_name(tt::llrt::RunTimeDebugFeatureDprint, dprint_file_name);
tt::llrt::OptionsG.set_test_mode_enabled(true);
watcher_previous_enabled = tt::llrt::OptionsG.get_watcher_enabled();
tt::llrt::OptionsG.set_watcher_enabled(false);
Expand All @@ -49,7 +49,7 @@ class DPrintFixture: public CommonFixture {
disabled[core_desc.dispatch_core_type].insert(core);
}
}
tt::llrt::OptionsG.set_dprint_disabled_cores(disabled);
tt::llrt::OptionsG.set_feature_disabled_cores(tt::llrt::RunTimeDebugFeatureDprint, disabled);

ExtraSetUp();

Expand All @@ -65,12 +65,12 @@ class DPrintFixture: public CommonFixture {
std::remove(dprint_file_name.c_str());

// Reset DPrint settings
tt::llrt::OptionsG.set_dprint_cores({});
tt::llrt::OptionsG.set_dprint_enabled(false);
tt::llrt::OptionsG.set_dprint_all_cores(CoreType::WORKER, false);
tt::llrt::OptionsG.set_dprint_all_cores(CoreType::ETH, false);
tt::llrt::OptionsG.set_dprint_all_chips(false);
tt::llrt::OptionsG.set_dprint_file_name("");
tt::llrt::OptionsG.set_feature_cores(tt::llrt::RunTimeDebugFeatureDprint, {});
tt::llrt::OptionsG.set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, false);
tt::llrt::OptionsG.set_feature_all_cores(tt::llrt::RunTimeDebugFeatureDprint, CoreType::WORKER, false);
tt::llrt::OptionsG.set_feature_all_cores(tt::llrt::RunTimeDebugFeatureDprint, CoreType::ETH, false);
tt::llrt::OptionsG.set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, false);
tt::llrt::OptionsG.set_feature_file_name(tt::llrt::RunTimeDebugFeatureDprint, "");
tt::llrt::OptionsG.set_test_mode_enabled(false);
tt::llrt::OptionsG.set_watcher_enabled(watcher_previous_enabled);
}
Expand All @@ -97,7 +97,7 @@ class DPrintFixtureDisableDevices: public DPrintFixture {
protected:
void ExtraSetUp() override {
// For this test, mute each devices using the environment variable
tt::llrt::OptionsG.set_dprint_all_chips(false);
tt::llrt::OptionsG.set_dprint_chip_ids({});
tt::llrt::OptionsG.set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, false);
tt::llrt::OptionsG.set_feature_chip_ids(tt::llrt::RunTimeDebugFeatureDprint, {});
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <thread>
#include "common_fixture.hpp"
#include "impl/debug/watcher_server.hpp"
#include "llrt/rtoptions.hpp"

// A version of CommonFixture with watcher enabled
class WatcherFixture: public CommonFixture {
Expand Down Expand Up @@ -84,3 +85,40 @@ class WatcherFixture: public CommonFixture {
tt::watcher_clear_log();
}
};

// A version of WatcherFixture with read and write debug delays enabled
class WatcherDelayFixture : public WatcherFixture {
public:
tt::llrt::TargetSelection saved_target_selection[tt::llrt::RunTimeDebugFeatureCount];

std::map<CoreType, std::vector<CoreCoord>> delayed_cores;

void SetUp() override {
tt::llrt::OptionsG.set_watcher_debug_delay(5000000);
delayed_cores[CoreType::WORKER] = {{0, 0}, {1, 1}};

// Store the previous state of the watcher features
saved_target_selection[tt::llrt::RunTimeDebugFeatureReadDebugDelay] = tt::llrt::OptionsG.get_feature_targets(tt::llrt::RunTimeDebugFeatureReadDebugDelay);
saved_target_selection[tt::llrt::RunTimeDebugFeatureWriteDebugDelay] = tt::llrt::OptionsG.get_feature_targets(tt::llrt::RunTimeDebugFeatureWriteDebugDelay);
saved_target_selection[tt::llrt::RunTimeDebugFeatureAtomicDebugDelay] = tt::llrt::OptionsG.get_feature_targets(tt::llrt::RunTimeDebugFeatureAtomicDebugDelay);

// Enable read and write debug delay for the test core
tt::llrt::OptionsG.set_feature_enabled(tt::llrt::RunTimeDebugFeatureReadDebugDelay, true);
tt::llrt::OptionsG.set_feature_cores(tt::llrt::RunTimeDebugFeatureReadDebugDelay, delayed_cores);
tt::llrt::OptionsG.set_feature_enabled(tt::llrt::RunTimeDebugFeatureWriteDebugDelay, true);
tt::llrt::OptionsG.set_feature_cores(tt::llrt::RunTimeDebugFeatureWriteDebugDelay, delayed_cores);

// Call parent
WatcherFixture::SetUp();
}

void TearDown() override {
// Call parent
WatcherFixture::TearDown();

// Restore
tt::llrt::OptionsG.set_feature_targets(tt::llrt::RunTimeDebugFeatureReadDebugDelay, saved_target_selection[tt::llrt::RunTimeDebugFeatureReadDebugDelay]);
tt::llrt::OptionsG.set_feature_targets(tt::llrt::RunTimeDebugFeatureWriteDebugDelay, saved_target_selection[tt::llrt::RunTimeDebugFeatureWriteDebugDelay]);
tt::llrt::OptionsG.set_feature_targets(tt::llrt::RunTimeDebugFeatureAtomicDebugDelay, saved_target_selection[tt::llrt::RunTimeDebugFeatureAtomicDebugDelay]);
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ TEST(DPrintErrorChecking, TestPrintInvalidCore) {
// device setup, but not the print server should simply ignore the invalid cores.
std::map<CoreType, std::vector<CoreCoord>> dprint_cores;
dprint_cores[CoreType::WORKER] = {{0, 0}, {1, 1}, {100, 100}};
tt::llrt::OptionsG.set_dprint_cores(dprint_cores); // Only (100, 100) is invalid.
tt::llrt::OptionsG.set_dprint_enabled(true);
tt::llrt::OptionsG.set_feature_cores(tt::llrt::RunTimeDebugFeatureDprint, dprint_cores);
tt::llrt::OptionsG.set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, true);

const int device_id = 0;
Device* device = nullptr;
Expand All @@ -27,6 +27,6 @@ TEST(DPrintErrorChecking, TestPrintInvalidCore) {
// We expect that even though illegal worker cores were requested, device setup did not hang.
// So just make sure that device setup worked and then close the device.
EXPECT_TRUE(device != nullptr);
tt::llrt::OptionsG.set_dprint_enabled(false);
tt::llrt::OptionsG.set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, false);
tt::tt_metal::CloseDevice(device);
}
Loading

0 comments on commit ed8ec62

Please sign in to comment.