diff --git a/.github/workflows/package-and-release.yaml b/.github/workflows/package-and-release.yaml index cfd9f035096..f6072be060c 100644 --- a/.github/workflows/package-and-release.yaml +++ b/.github/workflows/package-and-release.yaml @@ -161,7 +161,7 @@ jobs: with: name: changelog - name: Assert wheels exist - run: ls -arhl metal_libs-*+*.whl + run: ls -arhl ttnn-*+*.whl - name: Release # A major release has not been tagged yet, so we need to do this to avoid # Node 16 deprecation warning message @@ -178,7 +178,7 @@ jobs: README.md INSTALLING.md infra/machine_setup/scripts/setup_hugepages.py - metal_libs-*+*.whl + ttnn-*+*.whl fail_on_unmatched_files: true create-docker-release-image: needs: [ diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 10b7f88d158..bf5a83b870f 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -12,7 +12,9 @@ on: - ALL SWEEPS (Nightly) - add - tilize + - tilize_with_val_padding - untilize + - untilize_with_unpadding - ccl.line_all_gather - ccl.all_gather_n300 - ccl.all_gather_n300_focused diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index 5c60787fd36..5a94c2eed96 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -98,9 +98,6 @@ Pointwise Unary ttnn.asinh ttnn.atan ttnn.atanh - ttnn.bitwise_and - ttnn.bitwise_or - ttnn.bitwise_xor ttnn.bitwise_not ttnn.bitwise_left_shift ttnn.bitwise_right_shift @@ -162,7 +159,6 @@ Pointwise Unary ttnn.normalize_global ttnn.normalize_hw ttnn.polygamma - ttnn.pow ttnn.prelu ttnn.rad2deg ttnn.rdiv @@ -175,7 +171,6 @@ Pointwise Unary ttnn.remainder ttnn.round ttnn.rsqrt - ttnn.rsub ttnn.selu ttnn.sigmoid ttnn.sigmoid_accurate @@ -309,10 +304,14 @@ Pointwise Binary ttnn.logical_or_ ttnn.logical_xor_ ttnn.rpow + ttnn.rsub ttnn.ldexp ttnn.logical_and ttnn.logical_or ttnn.logical_xor + ttnn.bitwise_and + ttnn.bitwise_or + ttnn.bitwise_xor ttnn.logaddexp ttnn.logaddexp2 ttnn.hypot @@ -335,6 +334,7 @@ Pointwise Binary ttnn.maximum ttnn.minimum ttnn.outer + ttnn.pow ttnn.polyval ttnn.scatter ttnn.atan2 diff --git a/models/demos/wormhole/distilbert/README.md b/models/demos/wormhole/distilbert/README.md new file mode 100644 index 00000000000..4b26a482beb --- /dev/null +++ b/models/demos/wormhole/distilbert/README.md @@ -0,0 +1,35 @@ +## Distilbert Model + +# Platforms: + WH N300, N150 + +## Introduction +DistilBERT is a transformers model, smaller and faster than BERT, which was pretrained on the same corpus in a self-supervised fashion, using the BERT base model as a teacher. The DistilBERT Question Answering model is fine-tuned specifically for the task of extracting answers from a given context, making it highly efficient for question-answering applications. + +# Details +The entry point to distilebert model is distilbert_for_question_answering in `models/demos/wormhole/distilbert/tt/ttnn_optimized_distilbert.py`. The model picks up certain configs and weights from huggingface pretrained model. We have used `distilbert-base-uncased-distilled-squad` version from huggingface as our reference. + +This model, located in `models/demos/wormhole`, supports functionality on both N150 and N300 devices, depending on availability. If the device is N300, the weights and inputs are distributed across the device, allowing the model to run in parallel. + +## Sequence Size: 384 + +Sequence size determines the maximum length of input sequences processed by the model, optimizing performance and compatibility. It's recommended to set the `sequence_size` to 384 + +## Batch size: 8 + +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the `batch_size` to 8 + +Use `pytest --disable-warnings models/demos/wormhole/distilbert/demo/demo.py::test_demo[wormhole_b0-True-models.demos.wormhole.distilbert.tt.ttnn_optimized_distilbert-8-distilbert-base-uncased-distilled-squad-models/demos/wormhole/distilbert/demo/input_data.json]` to run the ttnn_optimized_distilbert demo. + + +If you wish to run the demo with a different input, change the pytest fixture input_loc to the desired location and use `pytest --disable-warnings models/demos/wormhole/distilbert/demo/demo.py::test_demo[wormhole_b0-True-models.demos.wormhole.distilbert.tt.ttnn_optimized_distilbert-8-distilbert-base-uncased-distilled-squad-]`. This file is expected to have exactly 8 inputs. + +Our second demo is designed to run SQuADV2 dataset, run this with `pytest --disable-warnings models/demos/wormhole/distilbert/demo/demo.py::test_demo_squadv2[wormhole_b0-True-3-8-models.demos.wormhole.distilbert.tt.ttnn_optimized_distilbert-distilbert-base-uncased-distilled-squad]`. + +If you wish to run for `n_iterations` samples, use `pytest --disable-warnings models/demos/wormhole/distilbert/demo/demo.py::test_demo_squadv2[wormhole_b0-True--8-models.demos.wormhole.distilbert.tt.ttnn_optimized_distilbert-distilbert-base-uncased-distilled-squad]` + +## Inputs + +The demo receives inputs from respective input_data.json by default. To modify the inputs or specify a different path, adjust the input_path parameter in the command accordingly. It's recommended to avoid direct modifications to the input_data.json file. + +# Owner Sudharsan Vijayaraghavan diff --git a/models/demos/wormhole/distilbert/demo/demo.py b/models/demos/wormhole/distilbert/demo/demo.py new file mode 100644 index 00000000000..dfd89c18939 --- /dev/null +++ b/models/demos/wormhole/distilbert/demo/demo.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 +import json +import pytest +import torch +from loguru import logger +import ttnn +from models.utility_functions import ( + disable_compilation_reports, + disable_persistent_kernel_cache, + profiler, +) +from models.demos.wormhole.distilbert.tt import ttnn_optimized_distilbert +from models.demos.wormhole.distilbert.distilbert_utils import ( + squadv2_1K_samples_input, + squadv2_answer_decode_batch, +) +from ttnn.model_preprocessing import ( + preprocess_model_parameters, +) +from models.utility_functions import is_wormhole_b0, skip_for_grayskull +from transformers import DistilBertForQuestionAnswering, AutoTokenizer, pipeline +import evaluate + + +def load_inputs(input_path, batch): + with open(input_path) as f: + input_data = json.load(f) + assert len(input_data) >= batch, f"Input data needs to have at least {batch} (batch size) entries." + context = [] + question = [] + for i in range(batch): + context.append(input_data[i]["context"]) + question.append(input_data[i]["question"]) + return context, question + + +def run_distilbert_question_and_answering_inference( + model_name, + batch_size, + sequence_size, + distilbert, + model_location_generator, + input_path, + mesh_device, +): + disable_persistent_kernel_cache() + + HF_model = DistilBertForQuestionAnswering.from_pretrained(model_name) + HF_model.eval() + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + profiler.start(f"preprocessing_parameter") + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: HF_model, + custom_preprocessor=ttnn_optimized_distilbert.custom_preprocessor, + device=mesh_device, + ) + profiler.end(f"preprocessing_parameter") + + # set up tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = HF_model.config + nlp = pipeline("question-answering", model=HF_model, tokenizer=tokenizer) + + context, question = load_inputs(input_path, batch_size) + preprocess_params, _, postprocess_params = nlp._sanitize_parameters(max_seq_len=sequence_size, padding="max_length") + inputs = nlp._args_parser({"question": question, "context": context}) + preprocessed_inputs = [] + for i in range(batch_size): + model_input = next(nlp.preprocess(inputs[0][i], **preprocess_params)) + single_input = { + "example": model_input["example"], + "inputs": model_input, + } + preprocessed_inputs.append(single_input) + + distilbert_input = tokenizer( + question, + context, + max_length=sequence_size, + padding="max_length", + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + + profiler.start(f"preprocessing_input") + position_ids = torch.arange(config.max_position_embeddings).expand((1, -1)) + position_ids = torch.cat([position_ids] * batch_size, dim=0) + input_ids, position_ids, attention_mask = distilbert.preprocess_inputs( + distilbert_input["input_ids"], + position_ids, + distilbert_input["attention_mask"], + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + profiler.end(f"preprocessing_input") + + mask_reshp = (batch_size, 1, 1, attention_mask.shape[1]) + score_shape = (batch_size, 12, 384, 384) + + mask = (distilbert_input["attention_mask"] == 0).view(mask_reshp).expand(score_shape) + min_val = torch.zeros(score_shape) + min_val_tensor = min_val.masked_fill(mask, torch.tensor(torch.finfo(torch.bfloat16).min)) + negative_val = torch.zeros(score_shape) + negative_val_tensor = negative_val.masked_fill(mask, -1) + + min_val_tensor = ttnn.from_torch( + min_val_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + + negative_val_tensor = ttnn.from_torch( + negative_val_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + profiler.start(f"inference_time") + tt_output = ttnn_optimized_distilbert.distilbert_for_question_answering( + config, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + parameters=parameters, + device=mesh_device, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + mesh_mapper=weights_mesh_mapper, + ip_mesh_mapper=inputs_mesh_mapper, + ) + profiler.end(f"inference_time") + + tt_output = ( + ttnn.to_torch(ttnn.from_device(tt_output), mesh_composer=output_mesh_composer) + .reshape(batch_size, 1, sequence_size, -1) + .to(torch.float32) + ) + tt_start_logits = tt_output[..., :, 0].squeeze(1) + tt_end_logits = tt_output[..., :, 1].squeeze(1) + model_answers = {} + + profiler.start("post_processing_output_to_string") + for i in range(batch_size): + tt_res = { + "start": tt_start_logits[i], + "end": tt_end_logits[i], + "example": preprocessed_inputs[i]["example"], + **preprocessed_inputs[i]["inputs"], + } + tt_answer = nlp.postprocess([tt_res], **postprocess_params) + logger.info(f"answer: {tt_answer['answer']}\n") + model_answers[i] = tt_answer["answer"] + profiler.end("post_processing_output_to_string") + + measurements = { + "preprocessing_parameter": profiler.get("preprocessing_parameter"), + "preprocessing_input": profiler.get("preprocessing_input"), + "inference_time": profiler.get("inference_time"), + "post_processing": profiler.get("post_processing_output_to_string"), + } + logger.info(f"preprocessing_parameter: {measurements['preprocessing_parameter']} s") + logger.info(f"preprocessing_input: {measurements['preprocessing_input']} s") + logger.info(f"inference_time: {measurements['inference_time']} s") + logger.info(f"post_processing : {measurements['post_processing']} s") + return measurements + + +def run_distilbert_question_and_answering_inference_squad_v2( + use_program_cache, + model_name, + batch_size, + sequence_size, + distilbert, + model_location_generator, + n_iterations, + mesh_device, +): + disable_persistent_kernel_cache() + HF_model = DistilBertForQuestionAnswering.from_pretrained(model_name) + HF_model.eval() + + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: HF_model, + custom_preprocessor=ttnn_optimized_distilbert.custom_preprocessor, + device=mesh_device, + ) + + # set up tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = HF_model.config + + nlp = pipeline("question-answering", model=HF_model, tokenizer=tokenizer) + attention_mask = True + token_type_ids = False + inputs_squadv2 = squadv2_1K_samples_input(tokenizer, sequence_size, attention_mask, token_type_ids, batch_size) + squad_metric = evaluate.load("squad_v2") + position_ids = torch.arange(config.max_position_embeddings).expand((1, -1)) + position_ids = torch.cat([position_ids] * batch_size, dim=0) + + with torch.no_grad(): + pred_labels = [] + cpu_pred_labels = [] + true_labels = [] + i = 0 + for batch in inputs_squadv2: + if i < n_iterations: + batch_data = batch[0] + curr_batch_size = batch_data["input_ids"].shape[0] + ttnn_distilbert_inputs = distilbert.preprocess_inputs( + batch_data["input_ids"], + position_ids, + batch_data["attention_mask"], + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + mask_reshp = (batch_size, 1, 1, batch_data["attention_mask"].shape[1]) + score_shape = (batch_size, 12, 384, 384) + + mask = (batch_data["attention_mask"] == 0).view(mask_reshp).expand(score_shape) + min_val = torch.zeros(score_shape) + min_val_tensor = min_val.masked_fill(mask, torch.tensor(torch.finfo(torch.bfloat16).min)) + negative_val = torch.zeros(score_shape) + negative_val_tensor = negative_val.masked_fill(mask, -1) + min_val_tensor = ttnn.from_torch( + min_val_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + negative_val_tensor = ttnn.from_torch( + negative_val_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + tt_output = ttnn_optimized_distilbert.distilbert_for_question_answering( + config, + input_ids=ttnn_distilbert_inputs[0], + attention_mask=ttnn_distilbert_inputs[2], + position_ids=ttnn_distilbert_inputs[1], + parameters=parameters, + device=mesh_device, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + mesh_mapper=weights_mesh_mapper, + ip_mesh_mapper=inputs_mesh_mapper, + ) + tt_output = ( + ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + .reshape(batch_size, 1, sequence_size, -1) + .to(torch.float32) + ) + cpu_output = HF_model(**batch_data) + references = batch[1] + question = batch[2] + context = batch[3] + cpu_predictions, tt_predictions = squadv2_answer_decode_batch( + HF_model, + tokenizer, + nlp, + references, + cpu_output, + tt_output, + curr_batch_size, + question, + context, + ) + pred_labels.extend(tt_predictions) + cpu_pred_labels.extend(cpu_predictions) + true_labels.extend(references) + del tt_output + i += 1 + eval_score = squad_metric.compute(predictions=pred_labels, references=true_labels) + cpu_eval_score = squad_metric.compute(predictions=cpu_pred_labels, references=true_labels) + logger.info(f"\tTT_Eval: exact: {eval_score['exact']} -- F1: {eval_score['f1']}") + logger.info(f"\tCPU_Eval: exact: {cpu_eval_score['exact']} -- F1: {cpu_eval_score['f1']}") + + +@skip_for_grayskull() +@pytest.mark.parametrize( + "model_name, input_loc", + ((["distilbert-base-uncased-distilled-squad", "models/demos/wormhole/distilbert/demo/input_data.json"]),), +) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("distilbert", [ttnn_optimized_distilbert]) +def test_demo(input_loc, model_name, distilbert, batch_size, model_location_generator, mesh_device): + disable_persistent_kernel_cache() + disable_compilation_reports() + + if ttnn.GetNumAvailableDevices() == 2: + batch_size = batch_size * 2 + + return run_distilbert_question_and_answering_inference( + model_name=model_name, + batch_size=batch_size, + sequence_size=384, + distilbert=distilbert, + model_location_generator=model_location_generator, + input_path=input_loc, + mesh_device=mesh_device, + ) + + +@skip_for_grayskull() +@pytest.mark.parametrize("model_name", ["distilbert-base-uncased-distilled-squad"]) +@pytest.mark.parametrize("distilbert", [ttnn_optimized_distilbert]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "n_iterations", + ((3),), +) +def test_demo_squadv2( + model_name, distilbert, batch_size, n_iterations, model_location_generator, use_program_cache, mesh_device +): + disable_persistent_kernel_cache() + disable_compilation_reports() + + if ttnn.GetNumAvailableDevices() == 2: + batch_size = batch_size * 2 + return run_distilbert_question_and_answering_inference_squad_v2( + use_program_cache=use_program_cache, + model_name=model_name, + batch_size=batch_size, + sequence_size=384, + distilbert=distilbert, + model_location_generator=model_location_generator, + n_iterations=n_iterations, + mesh_device=mesh_device, + ) diff --git a/models/demos/wormhole/distilbert/demo/input_data.json b/models/demos/wormhole/distilbert/demo/input_data.json new file mode 100644 index 00000000000..f182d147a45 --- /dev/null +++ b/models/demos/wormhole/distilbert/demo/input_data.json @@ -0,0 +1,66 @@ +[ + { + "context" : "Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. The prophet and founding hero of modern archaeology, Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art.", + "question" : "What discipline did Winkelmann create?" + }, + { + "context" : "The Norman dynasty had a major political, cultural and military impact on medieval Europe and even the Near East. The Normans were famed for their martial spirit and eventually for their Christian piety, becoming exponents of the Catholic orthodoxy into which they assimilated. They adopted the Gallo-Romance language of the Frankish land they settled, their dialect becoming known as Norman, Normaund or Norman French, an important literary language. The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure. The Normans are noted both for their culture, such as their unique Romanesque architecture and musical traditions, and for their significant military accomplishments and innovations. Norman adventurers founded the Kingdom of Sicily under Roger II after conquering southern Italy on the Saracens and Byzantines, and an expedition on behalf of their duke, William the Conqueror, led to the Norman conquest of England at the Battle of Hastings in 1066. Norman cultural and military influence spread from these new European centres to the Crusader states of the Near East, where their prince Bohemond I founded the Principality of Antioch in the Levant, to Scotland and Wales in Great Britain, to Ireland, and to the coasts of north Africa and the Canary Islands.", + "question" : "Who ruled the duchy of Normandy" + }, + { + "context" : "In many countries, there is a Gender pay gap in favor of males in the labor market. Several factors other than discrimination may contribute to this gap. On average, women are more likely than men to consider factors other than pay when looking for work, and may be less willing to travel or relocate. Thomas Sowell, in his book Knowledge and Decisions, claims that this difference is due to women not taking jobs due to marriage or pregnancy, but income studies show that that does not explain the entire difference. A U.S. Census's report stated that in US once other factors are accounted for there is still a difference in earnings between women and men. The income gap in other countries ranges from 53% in Botswana to -40% in Bahrain.", + "question" : "Who does a gender pay gap tend to favor?" + }, + { + "context" : "Most of the Huguenot congregations (or individuals) in North America eventually affiliated with other Protestant denominations with more numerous members. The Huguenots adapted quickly and often married outside their immediate French communities, which led to their assimilation. Their descendants in many families continued to use French first names and surnames for their children well into the nineteenth century. Assimilated, the French made numerous contributions to United States economic life, especially as merchants and artisans in the late Colonial and early Federal periods. For example, E.I. du Pont, a former student of Lavoisier, established the Eleutherian gunpowder mills.", + "question" : "How were Huguenot settlers assimilated into North American society at large?" + }, + { + "context" : "In the laboratory, biostratigraphers analyze rock samples from outcrop and drill cores for the fossils found in them. These fossils help scientists to date the core and to understand the depositional environment in which the rock units formed. Geochronologists precisely date rocks within the stratigraphic section in order to provide better absolute bounds on the timing and rates of deposition. Magnetic stratigraphers look for signs of magnetic reversals in igneous rock units within the drill cores. Other scientists perform stable isotope studies on the rocks to gain information about past climate.", + "question" : "Who analyzes rock samples from drill cores in the lab?" + }, + { + "context" : "Neutrophils and macrophages are phagocytes that travel throughout the body in pursuit of invading pathogens. Neutrophils are normally found in the bloodstream and are the most abundant type of phagocyte, normally representing 50% to 60% of the total circulating leukocytes. During the acute phase of inflammation, particularly as a result of bacterial infection, neutrophils migrate toward the site of inflammation in a process called chemotaxis, and are usually the first cells to arrive at the scene of infection. Macrophages are versatile cells that reside within tissues and produce a wide array of chemicals including enzymes, complement proteins, and regulatory factors such as interleukin 1. Macrophages also act as scavengers, ridding the body of worn-out cells and other debris, and as antigen-presenting cells that activate the adaptive immune system.", + "question" : "What is the process in which neutrophils move towards the site of inflammation called?" + }, + { + "context" : "In Afghanistan, the mujahideen's victory against the Soviet Union in the 1980s did not lead to justice and prosperity, due to a vicious and destructive civil war between political and tribal warlords, making Afghanistan one of the poorest countries on earth. In 1992, the Democratic Republic of Afghanistan ruled by communist forces collapsed, and democratic Islamist elements of mujahdeen founded the Islamic State of Afghanistan. In 1996, a more conservative and anti-democratic Islamist movement known as the Taliban rose to power, defeated most of the warlords and took over roughly 80% of Afghanistan.", + "question" : "When did the Democratic Republic of Afghanistan collapse?" + }, + { + "context" : "The largest single sensory feature is the aboral organ (at the opposite end from the mouth). Its main component is a statocyst, a balance sensor consisting of a statolith, a solid particle supported on four bundles of cilia, called \"balancers\", that sense its orientation. The statocyst is protected by a transparent dome made of long, immobile cilia. A ctenophore does not automatically try to keep the statolith resting equally on all the balancers. Instead its response is determined by the animal's \"mood\", in other words the overall state of the nervous system. For example, if a ctenophore with trailing tentacles captures prey, it will often put some comb rows into reverse, spinning the mouth towards the prey.", + "question" : "What is the main component of the aboral organ?" + }, + { + "context": "Mark Rothko was a Latvian-born American abstract painter. He is best known for his color field paintings that depicted irregular and painterly rectangular regions of color, which he produced from 1949 to 1970. Although Rothko did not personally subscribe to any one school, he is associated with the American Abstract Expressionist movement of modern art. Originally emigrating to Portland, Oregon, from Russian Empire (Latvia) with his family, Rothko later moved to New York City where his youthful period of artistic production dealt primarily with urban scenery.", + "question": "what is Rothko best known for?" + }, + { + "context": "Malignant narcissism is a psychological syndrome that could include aspects of narcissistic personality disorder (NPD) alongside a mix of antisocial, paranoid and sadistic personality disorder traits. The importance of malignant narcissism and of projection as a defense mechanism has been confirmed in paranoia, as well as the patient's vulnerability to malignant narcissistic regression. A person with malignant narcissism exhibits paranoia in addition to the symptoms of a Narcissistic Personality Disorder. Because a malignant narcissist's personality cannot tolerate any criticism, being mocked typically causes paranoia.", + "question": "What symptoms a malignant narcissist might exhibit in addition to the symptoms of a NPD patient?" + }, + { + "context": "The 14 July Revolution, also known as the 1958 Iraqi military coup, was a coup d'état that took place on 14 July 1958 in Iraq which resulted in the toppling of King Faisal II and the overthrow of the Hashemite-led Kingdom of Iraq. The Iraqi Republic established in its wake ended the Hashemite Arab Federation between Iraq and Jordan that had been established just six months earlier. In July 1958, units of the Royal Iraqi Army were dispatched to Jordan in support of King Hussein. A group of Iraqi Free Officers, led by Brigadier Abd al-Karim Qasim and Colonel Abdul Salam Arif, took advantage of the opportunity and instead marched on Baghdad. On 14 July, revolutionary forces seized control of the capital and proclaimed a new republic, headed by a Revolutionary Council.", + "question": "When was the Hashemite Arab Federation formed?" + }, + { + "context": "The Tasmanian devil is a carnivorous marsupial of the family Dasyuridae. It was formerly present across mainland Australia, but became extinct there around 3,500 years ago. The size of a small dog, the Tasmanian devil became the largest carnivorous marsupial in the world following the extinction of the thylacine in 1936. It is related to quolls, and distantly related to the thylacine. It is characterised by its stocky and muscular build, black fur, pungent odour, extremely loud and disturbing screech, keen sense of smell, and ferocity when feeding. The Tasmanian devil's large head and neck allow it to generate among the strongest bites per unit body mass of any extant predatory land mammal. It hunts prey and scavenges on carrion.", + "question": "What allows Tasmanian devil to generate strong bites?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who was the Normans' main enemy in Italy, the Byzantine Empire and Armenia?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who entered Italy soon after the Byzantine Empire?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who did the Normans fight in Italy?" + }, + { + "context": "Soon after the Normans began to enter Italy, they entered the Byzantine Empire and then Armenia, fighting against the Pechenegs, the Bulgars, and especially the Seljuk Turks. Norman mercenaries were first encouraged to come to the south by the Lombards to act against the Byzantines, but they soon fought in Byzantine service in Sicily. They were prominent alongside Varangian and Lombard contingents in the Sicilian campaign of George Maniaces in 1038-40. There is debate whether the Normans in Greek service actually were from Norman Italy, and it now seems likely only a few came from there. It is also unknown how many of the 'Franks', as the Byzantines called them, were Normans and not other Frenchmen.", + "question": "Who did the Normans encourage to come to the south?" + } +] diff --git a/models/demos/wormhole/distilbert/distilbert_utils.py b/models/demos/wormhole/distilbert/distilbert_utils.py new file mode 100644 index 00000000000..30df5f0f3a2 --- /dev/null +++ b/models/demos/wormhole/distilbert/distilbert_utils.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from torch.utils.data import Dataset +from typing import Any +from datasets import load_dataset +from loguru import logger + + +class SQUADV2Dataset(Dataset): + """Configurable SQuad-V2 Dataset.""" + + def __init__( + self, + dataset_question: Any, + dataset_context: Any, + dataset_reference: Any, + tokenizer: Any, + seq_len: int, + attention_mask: bool, + token_type_ids: bool, + ): + """Init and preprocess SST-2 dataset. + Parameters + ---------- + dataset : Any + SQUAD-v2 dataset + tokenizer : Any + tokenizer object from HuggingFace + split : str + Which split to use i.e. ["train", "validation", "test"] + seq_len : int + Sequence length + attention_mask : bool + token_type_ids : bool + """ + self.data = [] + for i in range(len(dataset_question)): + self.data.append( + ( + tokenizer( + dataset_question[i], + dataset_context[i], + max_length=seq_len, + padding="max_length", + truncation=True, + return_attention_mask=attention_mask, + # return_token_type_ids=token_type_ids, + return_tensors="pt", + ), + dataset_reference[i], + dataset_question[i], + dataset_context[i], + ) + ) + + def __len__(self): + """Return length of dataset. + Returns + ------- + int + Length of dataset + """ + return len(self.data) + + def __getitem__(self, index: int): + """Return sample from dataset. + Parameters + ---------- + index : int + Index of sample + Returns + ------- + Tuple + Data sample in format of X, y + """ + X = self.data[index] + return X + + +def squad_divide_chunks(dataset_question, dataset_context, dataset_reference, batch): + dataset_question_b = [] + dataset_context_b = [] + dataset_reference_b = [] + for i in range(0, len(dataset_question), batch): + dataset_question_b.append(dataset_question[i : i + batch]) + dataset_context_b.append(dataset_context[i : i + batch]) + dataset_reference_b.append(dataset_reference[i : i + batch]) + return dataset_question_b, dataset_context_b, dataset_reference_b + + +def squadv2_1K_samples_input(tokenizer, seq_len, attention_mask, token_type_ids, microbatch=8): + squadv2_dataset = load_dataset("squad_v2", use_auth_token=False, streaming=True)["validation"] + dataset_iter = iter(squadv2_dataset) + dataset_question = [] + dataset_context = [] + dataset_reference = [] + for _ in range(2048): + dataset_sgl = next(dataset_iter) + if len(dataset_sgl["answers"]["text"]) > 0: + dataset_question.append(dataset_sgl["question"]) + dataset_context.append(dataset_sgl["context"]) + dataset_reference.append({"answers": dataset_sgl["answers"], "id": dataset_sgl["id"]}) + if len(dataset_question) == 1024: + logger.info("SQuADv2 1024 samples load ..done") + break + dataset_question, dataset_context, dataset_reference = squad_divide_chunks( + dataset_question, dataset_context, dataset_reference, microbatch + ) + dataset_processed = SQUADV2Dataset( + dataset_question, + dataset_context, + dataset_reference, + tokenizer=tokenizer, + seq_len=seq_len, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + ) + return dataset_processed + + +def squadv2_answer_decode_batch( + HF_model, + tokenizer, + nlp, + references, + cpu_out, + tt_untilized_output, + BATCH_SIZE, + question, + context, + seq_len=384, + padding=None, +): + tt_predictions = [] + cpu_predictions = [] + preprocess_params, _, postprocess_params = nlp._sanitize_parameters(max_seq_len=seq_len, padding="max_length") + input_q = {"context": context, "question": question} + examples = nlp._args_parser(input_q) + for i in range(BATCH_SIZE): + logger.info(f"--REF-- {references[i]['answers']['text']}") + answer_start_scores = cpu_out["start_logits"][i] + answer_end_scores = cpu_out["end_logits"][i] + tt_start_logits = tt_untilized_output[..., :, 0].squeeze(1)[i] + tt_end_logits = tt_untilized_output[..., :, 1].squeeze(1)[i] + model_input = next(nlp.preprocess(examples[0][i], **preprocess_params)) + single_input = { + "data": ( + model_input["input_ids"], + model_input["attention_mask"], + model_input["token_type_ids"], + ), + "example": model_input["example"], + "inputs": model_input, + } + pt_res = { + "start": answer_start_scores, + "end": answer_end_scores, + "example": single_input["example"], + **single_input["inputs"], + } + cpu_answer_nlp = nlp.postprocess([pt_res], **postprocess_params)["answer"] + tt_res = { + "start": tt_start_logits, + "end": tt_end_logits, + "example": single_input["example"], + **single_input["inputs"], + } + tt_answer_nlp = nlp.postprocess([tt_res], **postprocess_params)["answer"] + logger.info(f"--CPU-- {cpu_answer_nlp}") + logger.info(f"--TT--- {tt_answer_nlp}") + logger.info(f"=======") + cpu_predictions.append( + {"prediction_text": cpu_answer_nlp, "id": references[i]["id"], "no_answer_probability": 0.0} + ) + tt_predictions.append( + {"prediction_text": tt_answer_nlp, "id": references[i]["id"], "no_answer_probability": 0.0} + ) + return cpu_predictions, tt_predictions diff --git a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py new file mode 100644 index 00000000000..e167e350dae --- /dev/null +++ b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger +import time + +from models.demos.wormhole.distilbert.tt import ttnn_optimized_distilbert +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + profiler, +) +from ttnn.model_preprocessing import ( + preprocess_model_parameters, +) +from models.perf.perf_utils import prep_perf_report +from transformers import DistilBertForQuestionAnswering, AutoTokenizer +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.utility_functions import is_grayskull, is_wormhole_b0, skip_for_grayskull + + +@skip_for_grayskull() +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("model_name", ["distilbert-base-uncased-distilled-squad"]) +@pytest.mark.parametrize( + "batch_size, seq_len, expected_inference_time, expected_compile_time", + ([8, 384, 15.00, 16.00],), +) +def test_performance_distilbert_for_qa( + mesh_device, + batch_size, + model_name, + seq_len, + expected_inference_time, + expected_compile_time, +): + if ttnn.GetNumAvailableDevices() == 2: + batch_size = batch_size * 2 + HF_model = DistilBertForQuestionAnswering.from_pretrained(model_name) + HF_model.eval() + + # set up tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = HF_model.config + + disable_persistent_kernel_cache() + + cpu_key = "ref_key" + + context = batch_size * [ + "Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. The prophet and founding hero of modern archaeology, Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art." + ] + + question = batch_size * ["What discipline did Winkelmann create?"] + inputs = tokenizer( + question, + context, + max_length=seq_len, + padding="max_length", + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + + profiler.start(f"preprocessing_parameter") + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: HF_model, + custom_preprocessor=ttnn_optimized_distilbert.custom_preprocessor, + device=mesh_device, + ) + profiler.end(f"preprocessing_parameter") + + mask_reshp = (batch_size, 1, 1, inputs["attention_mask"].shape[1]) + score_shape = (batch_size, 12, 384, 384) + + mask = (inputs["attention_mask"] == 0).view(mask_reshp).expand(score_shape) + min_val = torch.zeros(score_shape) + min_val_tensor = min_val.masked_fill(mask, torch.tensor(torch.finfo(torch.bfloat16).min)) + + negative_val = torch.zeros(score_shape) + negative_val_tensor = negative_val.masked_fill(mask, -1) + min_val_tensor = ttnn.from_torch( + min_val_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=inputs_mesh_mapper, device=mesh_device + ) + + negative_val_tensor = ttnn.from_torch( + negative_val_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + with torch.no_grad(): + profiler.start(cpu_key) + torch_out = HF_model(**inputs) + profiler.end(cpu_key) + + durations = [] + for _ in range(2): + position_ids = torch.arange(config.max_position_embeddings).expand((1, -1)) + position_ids = torch.cat([position_ids] * batch_size, dim=0) + profiler.start(f"preprocessing_input") + input_ids, position_ids, attention_mask = ttnn_optimized_distilbert.preprocess_inputs( + inputs["input_ids"], + position_ids, + inputs["attention_mask"], + device=mesh_device, + mesh_mapper=inputs_mesh_mapper, + ) + profiler.end(f"preprocessing_input") + + start = time.time() + tt_output = ttnn_optimized_distilbert.distilbert_for_question_answering( + config, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + parameters=parameters, + device=mesh_device, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + mesh_mapper=weights_mesh_mapper, + ip_mesh_mapper=inputs_mesh_mapper, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + prep_perf_report( + model_name=f"ttnn_{model_name}_optimized", + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" + logger.info("Exit Distilbert perf test") + + +@skip_for_grayskull() +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, test", + [ + [8, "distilbert-base-uncased-distilled-squad"], + ], +) +def test_distilbert_perf_device(batch_size, test, reset_seeds): + subdir = "ttnn_distilbert" + margin = 0.03 + num_iterations = 1 + + expected_perf = 224 + if ttnn.GetNumAvailableDevices() == 2: + batch_size = batch_size * 2 + + command = f"pytest tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py::test_distilbert_for_question_answering[silicon_arch_name=wormhole_b0-silicon_arch_wormhole_b0=True-sequence_size=768-batch_size=8-model_name=distilbert-base-uncased-distilled-squad]" + + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True) + prep_device_perf_report( + model_name=f"ttnn_distilbert{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments=test.replace("/", "_"), + ) diff --git a/models/demos/wormhole/distilbert/tt/ttnn_optimized_distilbert.py b/models/demos/wormhole/distilbert/tt/ttnn_optimized_distilbert.py new file mode 100644 index 00000000000..2e294ac8ec3 --- /dev/null +++ b/models/demos/wormhole/distilbert/tt/ttnn_optimized_distilbert.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from typing import Optional +import torch +from ttnn.model_preprocessing import ( + preprocess_linear_bias, + preprocess_linear_weight, +) +import ttnn.torch_tracer + + +def get_head_mask( + head_mask: Optional[ttnn.Tensor], + num_hidden_layers: int, + is_attention_chunked: bool = False, +): + head_mask = [ + None, + ] * num_hidden_layers + return head_mask + + +def attention( + config, + hidden_states, + mask, + head_mask=None, + output_attentions=None, + device=None, + base_address=None, + parameters=None, + num_cores_x=12, + min_val_tensor=None, + negative_val_tensor=None, + mesh_mapper=None, +): + batch_size, q_length, dim = hidden_states.shape + k_length = hidden_states.shape[1] + dim_per_head = config.dim // config.n_heads + + query_key_value_output = ttnn.linear( + hidden_states, + parameters.query_key_value.weight, + bias=parameters.query_key_value.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value_output, + memory_config=ttnn.L1_MEMORY_CONFIG, + num_heads=config.n_heads, + ) + ttnn.deallocate(query_key_value_output) + + query = query * (1 / (dim_per_head) ** 0.5) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + score_list = [] + + if batch_size <= 2: + inter_scores = attention_scores * negative_val_tensor + inter_scores = inter_scores + attention_scores + scores = inter_scores + min_val_tensor + else: + for i in range(2, batch_size + 1, 2): + inter_scores = attention_scores[i - 2 : i, :, :, :] * negative_val_tensor[i - 2 : i, :, :, :] + inter_scores = inter_scores + attention_scores[i - 2 : i, :, :, :] + score = inter_scores + min_val_tensor[i - 2 : i, :, :, :] + score = ttnn.permute(score, (1, 0, 2, 3)) + + score_list.append(score) + ttnn.deallocate(inter_scores) + + scores = ttnn.concat(score_list, dim=1) + scores = ttnn.permute(scores, (1, 0, 2, 3)) + + weights = ttnn.transformer.attention_softmax(scores, head_size=1) + ttnn.deallocate(scores) + context_layer = ttnn.matmul( + weights, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + ttnn.deallocate(weights) + ttnn.deallocate(value) + + context_layer = ttnn.permute(context_layer, [0, 1, 3, 2]) + context_layer = ttnn.reshape(context_layer, (batch_size, config.n_heads * dim_per_head, -1)) + + context_layer = ttnn.permute(context_layer, (0, 2, 1)) + + self_output = ttnn.linear( + context_layer, + parameters.out_lin.weight, + bias=parameters.out_lin.bias, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + ) + ttnn.deallocate(context_layer) + return self_output + + +def ffn(configs, hidden_state, device, base_address, parameters, num_cores_x=12, mesh_mapper=None): + batch_size, *_ = hidden_state.shape + + output = ttnn.linear( + hidden_state, + parameters.lin1.weight, + bias=parameters.lin1.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + activation="gelu", + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + output = ttnn.linear( + output, + parameters.lin2.weight, + bias=parameters.lin2.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + return output + + +def transformer_block( + config, + x, + attention_mask=None, + head_mask=None, + output_attentions: bool = False, + base_address=None, + parameters=None, + device=None, + min_val_tensor=None, + negative_val_tensor=None, + mesh_mapper=None, +): + sa_output = attention( + config, + x, + attention_mask, + head_mask, + output_attentions, + device=device, + base_address=base_address, + parameters=parameters.attention, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + ) + + sa_output = ttnn.layer_norm( + x + sa_output, + weight=parameters.sa_layer_norm.weight, + bias=parameters.sa_layer_norm.bias, + epsilon=1e-12, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(x) + + ffn_output = ffn(config, sa_output, device=device, base_address=base_address, parameters=parameters.ffn) + + ffn_output = ttnn.layer_norm( + ffn_output + sa_output, + weight=parameters.output_layer_norm.weight, + bias=parameters.output_layer_norm.bias, + epsilon=1e-12, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + ttnn.deallocate(sa_output) + return ffn_output + + +def transformer( + config, + x, + attention_mask=None, + head_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + base_address=None, + parameters=None, + device=None, + min_val_tensor=None, + negative_val_tensor=None, + mesh_mapper=None, +): + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_state = x + + for params in parameters.layer: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + layer_outputs = transformer_block( + config=config, + x=hidden_state, + attention_mask=attention_mask, + head_mask=None, + output_attentions=output_attentions, + base_address=f"{base_address}.layer", + parameters=params, + device=device, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + ) + hidden_state = layer_outputs + + return hidden_state + + +def distilbert( + config, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + position_ids=None, + min_val_tensor=None, + negative_val_tensor=None, + *, + base_address, + parameters, + device, + mesh_mapper=None, + ip_mesh_mapper=None, +): + output_attentions = output_attentions if output_attentions is not None else config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else config.output_hidden_states + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + head_mask = get_head_mask(head_mask, config.num_hidden_layers) + + if input_ids is not None: + word_embeddings = ttnn.embedding( + input_ids, + parameters.distilbert.embeddings.word_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + seq_length = word_embeddings.shape[1] + + if position_ids is not None: + position_ids = position_ids[:, :seq_length] + + position_embeddings = ttnn.embedding( + position_ids, + parameters.distilbert.embeddings.position_embeddings.weight, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + transpose = False + if word_embeddings.shape[0] > 1: + word_embeddings = ttnn.permute(word_embeddings, (1, 2, 0)) + position_embeddings = ttnn.permute(position_embeddings, (1, 2, 0)) + transpose = True + + embeddings = word_embeddings + position_embeddings + + ttnn.deallocate(word_embeddings) + + if transpose: + embeddings = ttnn.permute(embeddings, (2, 0, 1)) + + embeddings = ttnn.layer_norm( + embeddings, + epsilon=1e-12, + weight=parameters.distilbert.embeddings.LayerNorm.weight, + bias=parameters.distilbert.embeddings.LayerNorm.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return transformer( + config, + embeddings, + attention_mask, + head_mask, + output_attentions, + output_hidden_states, + base_address=f"distilbert.transformer", + parameters=parameters.distilbert.transformer, + device=device, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + ) + + +def distilbert_for_question_answering( + config, + input_ids, + attention_mask, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + min_val_tensor=None, + negative_val_tensor=None, + *, + parameters, + device, + base_address="", + mesh_mapper=None, + ip_mesh_mapper=None, +): + distilbert_output = distilbert( + config, + input_ids, + attention_mask, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + position_ids=position_ids, + device=device, + base_address=f"", + parameters=parameters, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + ip_mesh_mapper=ip_mesh_mapper, + ) + + qa_outputs = ttnn.linear( + distilbert_output, + parameters.qa_outputs.weight, + bias=parameters.qa_outputs.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=device.core_grid.y, x=device.core_grid.x), + ) + + return qa_outputs + + +def preprocess_inputs( + input_ids, + position_ids, + attention_mask, + device, + mesh_mapper, +): + input_ids = ttnn.from_torch(input_ids, mesh_mapper=mesh_mapper, device=device) + if position_ids is not None: + position_ids = ttnn.from_torch(position_ids, mesh_mapper=mesh_mapper, device=device) + attention_mask = ttnn.from_torch(attention_mask, mesh_mapper=mesh_mapper, device=device) + return (input_ids, position_ids, attention_mask) + + +def custom_preprocessor(torch_model, name): + parameters = {} + + if hasattr(torch_model, "q_lin") and hasattr(torch_model, "k_lin") and hasattr(torch_model, "v_lin"): + qkv_weight = torch.cat( + [ + torch_model.q_lin.weight, + torch_model.k_lin.weight, + torch_model.v_lin.weight, + ], + dim=0, + ) + qkv_bias = torch.cat( + [torch_model.q_lin.bias, torch_model.k_lin.bias, torch_model.v_lin.bias], + dim=0, + ) + output_weight = torch_model.out_lin.weight + output_bias = torch_model.out_lin.bias + parameters = {"query_key_value": {}, "out_lin": {}} + parameters["query_key_value"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat16) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat16) + parameters["out_lin"]["weight"] = preprocess_linear_weight(output_weight, dtype=ttnn.bfloat16) + parameters["out_lin"]["bias"] = preprocess_linear_bias(output_bias, dtype=ttnn.bfloat16) + return parameters diff --git a/pyproject.toml b/pyproject.toml index c06a688beb5..2281efdb014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] # To change to eager, or something similar -name = "metal_libs" +name = "ttnn" authors = [ {name = "Tenstorrent"}, {email = "info@tenstorrent.com"}, diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index fd3920176ee..ebb7179b098 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -21,6 +21,8 @@ run_perf_models_other() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests/test_performance.py -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/tests/test_perf_yolo.py -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/distilbert/tests/test_perf_distilbert.py -m $test_marker fi env pytest -n auto tests/ttnn/integration_tests/bert/test_performance.py -m $test_marker @@ -129,6 +131,8 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/bert_tiny/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/tests/ -m $test_marker + + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/distilbert/tests -m $test_marker fi ## Merge all the generated reports diff --git a/tests/scripts/set_up_end_to_end_tests_env.sh b/tests/scripts/set_up_end_to_end_tests_env.sh index 9a7e1e6c386..3c57afe72b1 100755 --- a/tests/scripts/set_up_end_to_end_tests_env.sh +++ b/tests/scripts/set_up_end_to_end_tests_env.sh @@ -19,7 +19,7 @@ set_up_end_to_end_tests_env() { python -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu python -m pip install -r requirements.txt - python -m pip install ../../metal_libs-*.whl + python -m pip install ../../ttnn-*.whl cd ../../ rm -rf tt_metal tt_eager ttnn models diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index 071c0f5734a..f0d5cf1f4a3 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -59,6 +59,9 @@ run_common_func_tests() { # SqueezeBERT pytest --disable-warnings models/demos/squeezebert/demo/demo.py --timeout 600; fail+=$? + # Distilbert + WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest --disable-warnings models/demos/wormhole/distilbert/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/sweep_framework/sweeps/tilize_with_val_padding.py b/tests/sweep_framework/sweeps/tilize_with_val_padding.py new file mode 100644 index 00000000000..12ffda7abf0 --- /dev/null +++ b/tests/sweep_framework/sweeps/tilize_with_val_padding.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +import math +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import comp_equal, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +from tt_lib.utils import tilize as tilize_util + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +def tilize_with_val_padding(x, output_tensor_shape, pad_value): + pad = torch.nn.functional.pad( + x, + tuple(j for i in reversed(range(len(x.shape))) for j in (0, output_tensor_shape[i] - x.shape[i])), + value=pad_value, + ) + tilized = tilize_util(pad) + return tilized + + +def _nearest_32(x): + return math.ceil(x / 32) * 32 + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "xfail": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 128, 128], [1, 1, 32, 32], 32), + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.ROW_MAJOR_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_a_layout"] == ttnn.TILE_LAYOUT: + return True, "Tile layout is not supported" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + number_generated = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + + padded_shape = [ + input_shape[0], + input_shape[1], + _nearest_32(input_shape[2]), + _nearest_32(input_shape[3]), + ] + + torch_output_tensor = tilize_with_val_padding(torch_input_tensor_a, padded_shape, number_generated) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.tilize_with_val_padding( + input_tensor_a, padded_shape, number_generated, memory_config=output_memory_config + ) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [comp_equal(torch_output_tensor, output_tensor), e2e_perf] diff --git a/tests/sweep_framework/sweeps/untilize_with_unpadding.py b/tests/sweep_framework/sweeps/untilize_with_unpadding.py new file mode 100644 index 00000000000..b1e8113a941 --- /dev/null +++ b/tests/sweep_framework/sweeps/untilize_with_unpadding.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tt_lib.utils import untilize as untilize_util + +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "xfail": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 128, 128], [1, 1, 32, 32], 16), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, + "xfail-2": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 128, 128], [1, 1, 1, 1], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([32, 32], [256, 256], [1, 1], 16), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "Row major layout is not supported" + return False, None + + +def untilize_with_unpadding(x, output_tensor_end, *args, **kwargs): + untilized = untilize_util(x) + unpad = untilized[ + : output_tensor_end[0] + 1, + : output_tensor_end[1] + 1, + : output_tensor_end[2] + 1, + : output_tensor_end[3] + 1, + ] + return unpad + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + end_shape = [] + for s in input_shape: + end_shape.append(random.randint(0, s - 1)) + + torch_output_tensor = untilize_with_unpadding(torch_input_tensor_a, end_shape) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.untilize_with_unpadding(input_tensor_a, end_shape, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py new file mode 100644 index 00000000000..5d2dd6284bd --- /dev/null +++ b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from transformers import ( + DistilBertForQuestionAnswering as HF_DistilBertForQuestionAnswering, +) +from transformers import AutoTokenizer +from models.demos.wormhole.distilbert.tt import ttnn_optimized_distilbert +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import is_wormhole_b0, skip_for_grayskull + + +@skip_for_grayskull() +@pytest.mark.parametrize("model_name", ["distilbert-base-uncased-distilled-squad"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [768]) +def test_distilbert_for_question_answering(mesh_device, model_name, batch_size, sequence_size, reset_seeds): + tokenizer = AutoTokenizer.from_pretrained(model_name) + HF_model = HF_DistilBertForQuestionAnswering.from_pretrained(model_name) + HF_model.eval() + + tt_model_name = f"ttnn_{model_name}_optimized" + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + if ttnn.GetNumAvailableDevices() == 2: + batch_size = batch_size * 2 + + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + parameters = preprocess_model_parameters( + model_name=tt_model_name, + initialize_model=lambda: HF_model, + custom_preprocessor=ttnn_optimized_distilbert.custom_preprocessor, + device=mesh_device, + ) + + model = HF_model.eval() + config = HF_model.config + + question = batch_size * ["Where do I live?"] + context = batch_size * ["My name is Merve and I live in İstanbul."] + inputs = tokenizer( + question, + context, + return_tensors="pt", + padding="max_length", + max_length=384, + truncation=True, + return_attention_mask=True, + ) + input_ids = inputs.input_ids + attention_mask = inputs.attention_mask + position_ids = torch.arange(config.max_position_embeddings).expand((1, -1)) + position_ids = torch.cat([position_ids] * batch_size, dim=0) + mask_reshp = (batch_size, 1, 1, attention_mask.shape[1]) + score_shape = (batch_size, 12, 384, 384) + + mask = (attention_mask == 0).view(mask_reshp).expand(score_shape) + min_val = torch.zeros(score_shape) + min_val_tensor = min_val.masked_fill(mask, torch.tensor(torch.finfo(torch.bfloat16).min)) + + negative_val = torch.zeros(score_shape) + negative_val_tensor = negative_val.masked_fill(mask, -1) + torch_output = model(input_ids, attention_mask) + + tt_model_name = f"ttnn_{model_name}_optimized" + + input_ids, position_ids, attention_mask = ttnn_optimized_distilbert.preprocess_inputs( + input_ids, position_ids, attention_mask, device=mesh_device, mesh_mapper=inputs_mesh_mapper + ) + + min_val_tensor = ttnn.from_torch( + min_val_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + negative_val_tensor = ttnn.from_torch( + negative_val_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + device=mesh_device, + ) + + tt_output = ttnn_optimized_distilbert.distilbert_for_question_answering( + config, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + parameters=parameters, + device=mesh_device, + min_val_tensor=min_val_tensor, + negative_val_tensor=negative_val_tensor, + mesh_mapper=weights_mesh_mapper, + ip_mesh_mapper=inputs_mesh_mapper, + ) + + tt_output = ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer) + start_logits, end_logits = tt_output.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + assert_with_pcc(torch_output.start_logits, start_logits, 0.99) + assert_with_pcc(torch_output.end_logits, end_logits, 0.99) diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index 6f444176af8..e07e110b6bc 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -18,7 +18,7 @@ def get_expected_times(functional_whisper): return { ttnn_functional_whisper: (11.7, 4.16), - ttnn_optimized_functional_whisper: (1.55, 1.35), + ttnn_optimized_functional_whisper: (1.57, 1.35), }[functional_whisper] diff --git a/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp b/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp index 768639149a1..314328ac286 100644 --- a/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp +++ b/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp @@ -556,6 +556,8 @@ INSTANTIATE_TEST_SUITE_P( .in0_block_w = 2, .out_subblock_h = 1, .out_subblock_w = 1, + .out_block_h = 64, + .out_block_w = 2, .per_core_M = 64, .per_core_N = 2, .fuse_batch = true, diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py index da305202e9c..805d73ca5c9 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py @@ -447,3 +447,22 @@ def test_unary_floor(input_shapes, device): golden_tensor = golden_function(in_data1) output_tensor = ttnn.to_torch(output_tensor) assert_with_pcc(golden_tensor, output_tensor, 0.999) + + +@skip_for_grayskull() +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +def test_unary_ceil(input_shapes, device): + in_data1 = torch.empty(input_shapes, dtype=torch.float32).uniform_(-43566, 43565) + input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.ceil(input_tensor1) + golden_function = ttnn.get_golden_function(ttnn.ceil) + golden_tensor = golden_function(in_data1) + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(golden_tensor, output_tensor, 0.999) diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index d66fb34f464..c411ab46631 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -958,6 +958,232 @@ def test_matmul_1d_tiny_tile( assert device.num_program_cache_entries() == 1 +def run_matmul_1d_multiple_output_blocks_per_core( + device, + m, + k, + n, + has_bias, + grid_size, + in_sharded, + out_sharded, + num_out_block_h, + num_out_block_w, + mcast_in0, + uneven_width, +): + if in_sharded or out_sharded: + fuse_batch = True + else: + fuse_batch = False + + if out_sharded and num_out_block_w > 1: + pytest.skip("out sharded not support multiple blocks on w dim") + + if not mcast_in0: + tmp = m + m = n + n = tmp + + in0_shape = [1, 1, m, k] + in1_shape = [1, 1, k, n] + bias_shape = [1, 1, n] + + num_cores = grid_size[0] * grid_size[1] + + if mcast_in0: + in0_block_w = k // num_cores // 32 + per_core_M = m // 32 + per_core_N = n // num_cores // 32 + uneven_width + else: + in0_block_w = k // 32 // 2 # test exracting shards + per_core_M = m // 32 // num_cores + per_core_N = n // 32 + out_block_h = per_core_M // num_out_block_h + out_block_w = per_core_N // num_out_block_w + out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w) + + logger.info(f"m: {m}") + logger.info(f"k: {k}") + logger.info(f"n: {n}") + logger.info(f"in0_block_w: {in0_block_w}") + logger.info(f"per_core_M: {per_core_M}") + logger.info(f"per_core_N: {per_core_N}") + logger.info(f"out_block_h: {out_block_h}") + logger.info(f"out_block_w: {out_block_w}") + logger.info(f"out_subblock_h: {out_subblock_h}") + logger.info(f"out_subblock_w: {out_subblock_w}") + + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + + if in_sharded: + if mcast_in0: + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in0_memory_config = ttnn.DRAM_MEMORY_CONFIG + in1_memory_config = ttnn.DRAM_MEMORY_CONFIG + in0_t = ttnn.from_torch( + in0, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in0_memory_config, + ) + in1_t = ttnn.from_torch( + in1, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in1_memory_config, + ) + + if has_bias: + bias = torch.randn(bias_shape).bfloat16().float() + bias_padded = bias.unsqueeze(2) + bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, 32 - bias_padded.size(2)), "constant", 0) + bias_t = ttnn.from_torch( + bias_padded, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size=grid_size, + in0_block_w=in0_block_w, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + out_block_h=out_block_h, + out_block_w=out_block_w, + per_core_M=per_core_M, + per_core_N=per_core_N, + fuse_batch=fuse_batch, + fused_activation=None, + mcast_in0=mcast_in0, + ) + + if is_grayskull(): + compute_kernel_config = ttnn.GrayskullComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=True, + ) + if out_sharded: + if mcast_in0: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + else: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + else: + out_mem_config = ttnn.DRAM_MEMORY_CONFIG + + if has_bias: + output_t = ttnn.linear( + in0_t, + in1_t, + bias=bias_t, + program_config=program_config, + memory_config=out_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + ) + else: + output_t = ttnn.matmul( + in0_t, + in1_t, + program_config=program_config, + memory_config=out_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + ) + output_tensor = ttnn.to_torch(output_t) + pt_out = in0 @ in1 + if has_bias: + pt_out += bias + + assert_with_pcc(pt_out, output_tensor, 0.999) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("m", [256]) +@pytest.mark.parametrize("k", [1024]) +@pytest.mark.parametrize("n", [2048]) +@pytest.mark.parametrize("has_bias", [False]) +@pytest.mark.parametrize("grid_size", [(8, 2)]) +@pytest.mark.parametrize("in_sharded", [True, False]) +@pytest.mark.parametrize("out_sharded", [True, False]) +@pytest.mark.parametrize("num_out_block_h", [1, 2]) +@pytest.mark.parametrize("num_out_block_w", [1, 2]) +@pytest.mark.parametrize("mcast_in0", [True, False]) +@pytest.mark.parametrize("uneven_width", [0, 2]) +def test_matmul_1d_multiple_output_blocks_per_core( + device, + m, + k, + n, + has_bias, + grid_size, + in_sharded, + out_sharded, + num_out_block_h, + num_out_block_w, + mcast_in0, + uneven_width, + use_program_cache, +): + for _ in range(2): + run_matmul_1d_multiple_output_blocks_per_core( + device, + m, + k, + n, + has_bias, + grid_size, + in_sharded, + out_sharded, + num_out_block_h, + num_out_block_w, + mcast_in0, + uneven_width, + ) + # dummy tensor to change tensor alloc + dummy_shape = [1, 1, 32, 32] + py_dummy_tensor = torch.randn(dummy_shape) + tt_dummy_tensor = ttnn.from_torch( + py_dummy_tensor, + dtype=ttnn.DataType.BFLOAT16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + assert device.num_program_cache_entries() == 1 + + # fmt: off @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("m_size,k_size,n_size", [ diff --git a/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py b/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py index c4989099cc4..a71ab46945a 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py @@ -15,27 +15,6 @@ from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import TILE_HEIGHT, TILE_WIDTH -def to_cpu(npu_tensor, shape, *, cpu_layout=ttnn.ROW_MAJOR_LAYOUT): - if npu_tensor is None: - return None - cpu_tensor = npu_tensor.cpu().to(cpu_layout).unpad_from_tile(shape).to_torch() - return cpu_tensor - - -def to_npu( - cpu_tensor, - device, - *, - npu_layout=ttnn.TILE_LAYOUT, - npu_dtype=ttnn.bfloat16, - padding_value=float("nan"), -): - if cpu_tensor is None: - return None - return ttnn.from_torch(cpu_tensor, npu_dtype, device=device, layout=npu_layout) - - -@pytest.mark.skip(reason="assertion fails during binary op input shape comparison because of different padding") @pytest.mark.parametrize("num_iters_of_each_case", [2]) @pytest.mark.parametrize("range_of_padding", [(0, 21, 10)]) # [0, 10, 20] @pytest.mark.parametrize("range_of_n", [(1, 4)]) @@ -88,66 +67,70 @@ def test_moreh_clip_grad_norm( param.grad = grad cpu_inputs.append(param) - npu_inputs.append(to_npu(grad.clone().bfloat16(), device, npu_dtype=npu_dtype)) + npu_inputs.append( + ttnn.from_torch(grad.clone().bfloat16(), dtype=npu_dtype, layout=ttnn.TILE_LAYOUT, device=device) + ) + # npu_inputs.append(to_npu(grad.clone().bfloat16(), device, npu_dtype=npu_dtype)) input_shapes.append(input_shape) cpu_total_norm = torch.nn.utils.clip_grad_norm_(cpu_inputs, max_norm, norm_type) npu_total_norm = ttnn.operations.moreh.clip_grad_norm(npu_inputs, max_norm, norm_type) - + actual_total_norm = ttnn.to_torch(npu_total_norm).reshape(1) expected_total_norm = cpu_total_norm - actual_total_norm = to_cpu(npu_total_norm, [1, 1, 1, 1]) rtol = atol = 0.1 # Check total_norm pass_total_norm, out_total_norm = comp_allclose_and_pcc( - actual_total_norm.double(), expected_total_norm.double(), rtol=rtol, atol=atol + actual_total_norm, expected_total_norm, rtol=rtol, atol=atol ) logger.debug(f"total_norm's {out_total_norm}") assert pass_total_norm # Check inputs for i in range(num_parameters): - expected_input_i = cpu_inputs[i].grad.double() - actual_input_i = to_cpu(npu_inputs[i], input_shapes[i]).double() + expected_input_i = cpu_inputs[i].grad + actual_input_i = ttnn.to_torch(npu_inputs[i]) pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol) logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}") assert pass_input_i -# @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") -# @pytest.mark.parametrize("error_if_nonfinite", [True, False]) -# def test_moreh_clip_grad_norm_with_error_if_nonfinite(error_if_nonfinite, device): -# torch.manual_seed(2023) - -# cpu_dtype = torch.bfloat16 -# npu_dtype = ttnn.bfloat16 - -# input_shape = [4, 4, 4 * TILE_HEIGHT, 4 * TILE_WIDTH] -# param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype)) -# grad = torch.randn(input_shape, dtype=cpu_dtype) -# param.grad = grad - -# max_norm = 1.0 -# norm_type = float("nan") - -# expected_error_msg = ( -# f"The total norm of order {norm_type} for gradients from `parameters` is non-finite, so it cannot be clipped." -# ) - -# # Check vanilla torch behavior -# try: -# torch.nn.utils.clip_grad_norm_((param), max_norm, norm_type, error_if_nonfinite) -# assert not error_if_nonfinite -# except RuntimeError as actual_error_msg: -# assert expected_error_msg in str(actual_error_msg) -# assert error_if_nonfinite - -# # Check tt behavior -# try: -# ttnn.operations.moreh.clip_grad_norm( -# [to_npu(param.grad.bfloat16(), device, npu_dtype=npu_dtype)], max_norm, norm_type, error_if_nonfinite -# ) -# assert not error_if_nonfinite -# except RuntimeError as actual_error_msg: -# assert expected_error_msg in str(actual_error_msg) -# assert error_if_nonfinite +@pytest.mark.parametrize("error_if_nonfinite", [True, False]) +def test_moreh_clip_grad_norm_with_error_if_nonfinite(error_if_nonfinite, device): + torch.manual_seed(2023) + + cpu_dtype = torch.bfloat16 + npu_dtype = ttnn.bfloat16 + + input_shape = [4, 4, 4 * TILE_HEIGHT, 4 * TILE_WIDTH] + param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype)) + grad = torch.randn(input_shape, dtype=cpu_dtype) + param.grad = grad + + max_norm = 1.0 + norm_type = float("nan") + + expected_error_msg = ( + f"The total norm of order {norm_type} for gradients from `parameters` is non-finite, so it cannot be clipped." + ) + + # Check vanilla torch behavior + try: + torch.nn.utils.clip_grad_norm_((param), max_norm, norm_type, error_if_nonfinite) + assert not error_if_nonfinite + except RuntimeError as actual_error_msg: + assert expected_error_msg in str(actual_error_msg) + assert error_if_nonfinite + + # Check tt behavior + try: + ttnn.operations.moreh.clip_grad_norm( + [ttnn.from_torch(param.grad.bfloat16(), dtype=npu_dtype, layout=ttnn.TILE_LAYOUT, device=device)], + max_norm, + norm_type, + error_if_nonfinite, + ) + assert not error_if_nonfinite + except RuntimeError as actual_error_msg: + assert expected_error_msg in str(actual_error_msg) + assert error_if_nonfinite diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h index d803f841dbf..c34da7ec973 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h @@ -9,6 +9,7 @@ #include "sfpi.h" #include "noc_nonblocking_api.h" #include "limits.h" +#include "ckernel_sfpu_floor.h" using namespace sfpi; @@ -20,7 +21,7 @@ inline void calculate_ceil() { for (int d = 0; d < ITERATIONS; d++) { vFloat result = dst_reg[0]; vFloat v = result; - vInt tmp = float_to_int16(result, 0); // TODO: Replace float_to_int16 to float_to_int32 once it is available + vInt tmp = float_to_int16(result, 0); result = int32_to_float(tmp, 0); v_if(result < v) { result = result + 1; } v_endif; @@ -31,5 +32,19 @@ inline void calculate_ceil() { } } +template +inline void calculate_ceil_float32() { + for (int d = 0; d < ITERATIONS; d++) { + vFloat result = dst_reg[0]; + vFloat v = result; + vInt tmp = float_to_int32(result); + result = int32_to_float(tmp, 0); + v_if(result < v) { result = result + 1; } + v_endif; + dst_reg[0] = result; + dst_reg++; + } +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h index 690528ed0c4..e328d869eb8 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h @@ -22,4 +22,9 @@ inline void llk_math_eltwise_unary_sfpu_ceil(uint dst_index, int vector_mode = ( llk_math_eltwise_unary_sfpu_params(ckernel::sfpu::calculate_ceil, dst_index, vector_mode); } +template +inline void llk_math_eltwise_unary_sfpu_ceil_float32(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_ceil_float32, dst_index, vector_mode); +} } // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h index d803f841dbf..983faf8db6f 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_ceil.h @@ -9,6 +9,7 @@ #include "sfpi.h" #include "noc_nonblocking_api.h" #include "limits.h" +#include "ckernel_sfpu_floor.h" using namespace sfpi; @@ -20,7 +21,7 @@ inline void calculate_ceil() { for (int d = 0; d < ITERATIONS; d++) { vFloat result = dst_reg[0]; vFloat v = result; - vInt tmp = float_to_int16(result, 0); // TODO: Replace float_to_int16 to float_to_int32 once it is available + vInt tmp = float_to_int16(result, 0); result = int32_to_float(tmp, 0); v_if(result < v) { result = result + 1; } v_endif; @@ -31,5 +32,18 @@ inline void calculate_ceil() { } } +template +inline void calculate_ceil_float32() { + for (int d = 0; d < ITERATIONS; d++) { + vFloat result = dst_reg[0]; + vFloat v = result; + vInt tmp = float_to_int32(result); + result = int32_to_float(tmp, 0); + v_if(result < v) { result = result + 1; } + v_endif; + dst_reg[0] = result; + dst_reg++; + } +} } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h index 690528ed0c4..e328d869eb8 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_ceil.h @@ -22,4 +22,9 @@ inline void llk_math_eltwise_unary_sfpu_ceil(uint dst_index, int vector_mode = ( llk_math_eltwise_unary_sfpu_params(ckernel::sfpu::calculate_ceil, dst_index, vector_mode); } +template +inline void llk_math_eltwise_unary_sfpu_ceil_float32(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_ceil_float32, dst_index, vector_mode); +} } // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/ceil.h b/tt_metal/include/compute_kernel_api/eltwise_unary/ceil.h index bc2f2e7863b..3d8d27724fd 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/ceil.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/ceil.h @@ -31,9 +31,25 @@ ALWI void ceil_tile_init() { MATH((llk_math_eltwise_unary_sfpu_ceil_init * | Argument | Description | Type | Valid * Range | Required | * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| - * | idst | The index of the tile in DST register buffer to modify the sign bit of | uint32_t | Must be + * | idst | The index of the tile in DST register buffer to perform ceil operation | uint32_t | Must be * less than the size of the DST register buffer | True | */ ALWI void ceil_tile(uint32_t idst) { MATH((llk_math_eltwise_unary_sfpu_ceil(idst))); } +/** + * Performs ceil operation on each row of a tile. + * in DST register at index tile_index. The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking and is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid + * Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to perform ceil operation | uint32_t | Must be + * less than the size of the DST register buffer | True | + */ +ALWI void ceil_tile_float32(uint32_t idst) { MATH((llk_math_eltwise_unary_sfpu_ceil_float32(idst))); } + } // namespace ckernel diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index bf215230584..9060a439669 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -754,6 +754,8 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co .in0_block_w = conv_blocking_config.act_block_w_ntiles, .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, + .out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), + .out_block_w = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), .per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), .per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), .fuse_batch = true, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index 67e00b2a311..514b3df54ca 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -249,22 +249,6 @@ constexpr auto ne_ = ttnn::register_operation_with_auto_launch_op< "ttnn::ne_", operations::binary::InplaceRelationalBinary>(); -constexpr auto rsub_binary = ttnn::register_operation_with_auto_launch_op< - "ttnn::rsub_binary", - operations::binary::BinaryOperation>(); -constexpr auto power_binary = ttnn::register_operation_with_auto_launch_op< - "ttnn::power_binary", - operations::binary::BinaryOperationSfpu>(); -constexpr auto bitwise_and_binary = ttnn::register_operation_with_auto_launch_op< - "ttnn::bitwise_and_binary", - operations::binary::BinaryOperationSfpu>(); -constexpr auto bitwise_or_binary = ttnn::register_operation_with_auto_launch_op< - "ttnn::bitwise_or_binary", - operations::binary::BinaryOperationSfpu>(); -constexpr auto bitwise_xor_binary = ttnn::register_operation_with_auto_launch_op< - "ttnn::bitwise_xor_binary", - operations::binary::BinaryOperationSfpu>(); - template ttnn::Tensor operator+(const ttnn::Tensor& input_tensor_a, InputBType scalar) { return add(input_tensor_a, scalar); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp index f366c1104f1..c89bf48fae6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp @@ -18,6 +18,69 @@ namespace operations { namespace binary { +/** + * @brief Performs element-wise power operation on the input with the exponent. + * When exponent is Tensor, the supported dtypes are float32 and bfloat16. + * The tested range for the input is (-30,30) and for the exponent is (-20, 20). + * + * @param input The input tensor, i.e the base. + * @param exponent The exponent + * @return The result tensor + */ +struct ExecutePower { + static Tensor invoke( + uint8_t queue_id, + const Tensor& input_tensor, + uint32_t exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + const Tensor& input_tensor, + uint32_t exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + uint8_t queue_id, + const Tensor& input_tensor, + float exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + const Tensor& input_tensor, + float exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + uint8_t queue_id, + float input_a, + const Tensor& exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + float input_a, + const Tensor& exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + uint8_t queue_id, + const Tensor& input_tensor, + const Tensor& exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + const Tensor& input_tensor, + const Tensor& exponent, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); +}; + template struct ExecuteBinaryCompositeOps { static Tensor invoke( @@ -436,5 +499,6 @@ constexpr auto rsub = ttnn::register_operation_with_auto_launch_op<"ttnn::rsub", constexpr auto bitwise_and = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_and", operations::binary::ExecuteBitwiseAnd>(); constexpr auto bitwise_or = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_or", operations::binary::ExecuteBitwiseOr>(); constexpr auto bitwise_xor = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_xor", operations::binary::ExecuteBitwiseXor>(); +constexpr auto pow = ttnn::register_operation_with_auto_launch_op<"ttnn::pow", operations::binary::ExecutePower>(); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index d244c63a83f..1f5cfde8f29 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -1265,6 +1265,124 @@ void bind_binary_inplace_operation( py::arg("activations") = std::nullopt, py::arg("input_tensor_a_activation") = std::nullopt}); } + +template +void bind_power(py::module& module, const binary_operation_t& operation, const std::string& note = "") { + auto doc = fmt::format( + R"doc( + Perform element-wise {0} operation on :attr:`input_tensor` with :attr:`exponent`. + + .. math:: + \mathrm{{output\_tensor}}_i = (\mathrm{{input\_tensor}}_i ** \mathrm{{exponent}}_i) + + Args: + input_tensor (ttnn.Tensor, float): the input tensor. + exponent (float, int, ttnn.Tensor): the exponent value. + + Keyword Args: + memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. + output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`. + queue_id (int, optional): command queue id. Defaults to `0`. + + Returns: + ttnn.Tensor: the output tensor. + + Note: + Supported dtypes, layouts, and ranks: + + .. list-table:: + :header-rows: 1 + + * - Dtypes + - Layouts + - Ranks + * - BFLOAT16, BFLOAT8_B + - TILE + - 2, 3, 4 + + {2} + + Example: + >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device) + >>> exponent = 2 + >>> output = {1}(tensor, exponent) + )doc", + ttnn::pow.base_name(), + ttnn::pow.python_fully_qualified_name(), + note); + + bind_registered_operation( + module, + ttnn::pow, + doc, + // integer exponent + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const Tensor& input_tensor, + uint32_t exponent, + const std::optional& memory_config, + const std::optional& output_tensor, + const uint8_t queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor, exponent, memory_config, output_tensor); + }, + py::arg("input_tensor"), + py::arg("exponent"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("queue_id") = 0}, + + // float exponent + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const Tensor& input_tensor, + float exponent, + const std::optional& memory_config, + std::optional output_tensor, + const uint8_t queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor, exponent, memory_config, output_tensor); + }, + py::arg("input_tensor"), + py::arg("exponent"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("queue_id") = ttnn::DefaultQueueId}, + + // tensor exponent + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const Tensor& input_tensor, + const Tensor& exponent, + const std::optional& memory_config, + std::optional output_tensor, + const uint8_t queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor, exponent, memory_config, output_tensor); + }, + py::arg("input_tensor"), + py::arg("exponent"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("queue_id") = ttnn::DefaultQueueId}, + + // scalar input - tensor exponent + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + float input, + const Tensor& exponent, + const std::optional& memory_config, + std::optional output_tensor, + const uint8_t queue_id) -> ttnn::Tensor { + return self(queue_id, input, exponent, memory_config, output_tensor); + }, + py::arg("input"), + py::arg("exponent"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("queue_id") = ttnn::DefaultQueueId}); +} } // namespace detail void py_module(py::module& module) { @@ -1434,7 +1552,7 @@ void py_module(py::module& module) { module, ttnn::bitwise_and, R"doc(Perform bitwise_and operation on :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc", - R"doc(\mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor\_b}}_i \verb|bitwise_and| \mathrm{{input\_tensor\_a}}_i)doc", + R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_and|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc", ". ", R"doc(INT32)doc"); @@ -1442,7 +1560,7 @@ void py_module(py::module& module) { module, ttnn::bitwise_or, R"doc(Perform bitwise_or operation on :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc", - R"doc(\mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor\_b}}_i \verb|bitwise_or| \mathrm{{input\_tensor\_a}}_i)doc", + R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_or|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc", ". ", R"doc(INT32)doc"); @@ -1450,7 +1568,7 @@ void py_module(py::module& module) { module, ttnn::bitwise_xor, R"doc(Perform bitwise_xor operation on :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc", - R"doc(\mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor\_b}}_i \verb|bitwise_xor| \mathrm{{input\_tensor\_a}}_i)doc", + R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_xor|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc", ". ", R"doc(INT32)doc"); @@ -1689,6 +1807,9 @@ void py_module(py::module& module) { R"doc(Performs Not equal to in-place operation on :attr:`input_a` and :attr:`input_b` and returns the tensor with the same layout as :attr:`input_tensor`)doc", R"doc(\mathrm{{input\_tensor\_a}}\: != \mathrm{{input\_tensor\_b}})doc", R"doc(BFLOAT16, BFLOAT8_B)doc"); + + detail::bind_power( + module, ttnn::pow, R"doc(When :attr:`exponent` is a Tensor, supported dtypes are: BFLOAT16, FLOAT32)doc"); } } // namespace binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index f09d2b08e8a..27cdec97ee9 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -131,7 +131,7 @@ Tensor ExecuteMaximum::invoke( Tensor ExecuteMaximum::invoke( const Tensor& input_a, float value, const std::optional& output_mem_config) { - Tensor t_diff = ttnn::rsub_unary(input_a, value, output_mem_config); + Tensor t_diff = ttnn::rsub(input_a, value, output_mem_config); Tensor result = ttnn::where(t_diff, value, input_a); return result; } @@ -526,6 +526,111 @@ Tensor ExecuteLCM::invoke( return ttnn::abs(result); } +// power - floating point exponent +Tensor ExecutePower::invoke( + uint8_t queue_id, + const Tensor& input_a, + float exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + TT_FATAL(exponent >= 0.0f, "works for positive exponents only"); + const uint32_t exponent_floor = static_cast(std::floor(exponent)); + if (static_cast(exponent_floor) == exponent) { + if (output_tensor.has_value()) { + ttnn::power(queue_id, input_a, exponent_floor, output_mem_config, output_tensor); + return output_tensor.value(); + } + return ttnn::power(queue_id, input_a, exponent_floor, output_mem_config); + } + const float exponent_trunc = exponent - static_cast(exponent_floor); + Tensor pow_trunc_log = ttnn::multiply( + queue_id, ttnn::log(queue_id, input_a, output_mem_config), exponent_trunc, std::nullopt, output_mem_config); + Tensor pow_frac = ttnn::exp(queue_id, pow_trunc_log, false, output_mem_config); + pow_trunc_log.deallocate(); + float t_nan = std::nanf(""); + Tensor result = ttnn::multiply( + queue_id, + ttnn::power(queue_id, input_a, exponent_floor, output_mem_config), + pow_frac, + std::nullopt, + output_mem_config); + // To handle negative inputs: + // in torch For -ve inputs with float exponent power returns nan + auto output_memory_config = output_tensor.has_value() ? output_tensor.value().memory_config() + : output_mem_config.value_or(input_a.memory_config()); + result = ttnn::where( + ttnn::ltz(queue_id, input_a, output_mem_config), t_nan, result, output_memory_config, output_tensor); + return result; +} + +// power - floating point exponent +Tensor ExecutePower::invoke( + const Tensor& input_a, + float exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + return ExecutePower::invoke(DefaultQueueId, input_a, exponent, output_mem_config, std::move(output_tensor)); +} + +// power - integer exponent +Tensor ExecutePower::invoke( + uint8_t queue_id, + const Tensor& input, + uint32_t exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + return ttnn::power(queue_id, input, exponent, output_mem_config, output_tensor); +} + +// power - integer exponent +Tensor ExecutePower::invoke( + const Tensor& input, + uint32_t exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + return ExecutePower::invoke(DefaultQueueId, input, exponent, output_mem_config, std::move(output_tensor)); +} + +// power - tensor exponent +Tensor ExecutePower::invoke( + uint8_t queue_id, + const Tensor& input, + const Tensor& exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + return BinaryOperationSfpu::invoke( + queue_id, input, exponent, std::nullopt, output_mem_config, output_tensor); +} + +// power - tensor exponent +Tensor ExecutePower::invoke( + const Tensor& input, + const Tensor& exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + return ExecutePower::invoke(DefaultQueueId, input, exponent, output_mem_config, std::move(output_tensor)); +} + +// power - scalar input +Tensor ExecutePower::invoke( + uint8_t queue_id, + float input_a, + const Tensor& exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + Tensor input = ttnn::full_like(exponent, input_a); + return ExecutePower::invoke(queue_id, input, exponent, output_mem_config, std::move(output_tensor)); +} + +// power - scalar input +Tensor ExecutePower::invoke( + float input_a, + const Tensor& exponent, + const std::optional& output_mem_config, + const std::optional& output_tensor) { + return ExecutePower::invoke(DefaultQueueId, input_a, exponent, output_mem_config, std::move(output_tensor)); +} + Tensor ExecuteRsub::invoke( uint8_t queue_id, const Tensor& input_tensor_a, @@ -535,8 +640,7 @@ Tensor ExecuteRsub::invoke( const std::optional& optional_output_tensor, const std::optional& activations, const std::optional& input_tensor_a_activation) { - - return ttnn::rsub_binary( + return BinaryOperation::invoke( queue_id, input_tensor_a, input_tensor_b, @@ -573,13 +677,8 @@ Tensor ExecuteRsub::invoke( const float input_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::rsub_unary( - queue_id, - input_tensor_a, - input_b, - memory_config, - optional_output_tensor); + return ttnn::operations::unary::ExecuteUnaryWithFloatParameter::invoke( + queue_id, input_tensor_a, input_b, memory_config, optional_output_tensor); } Tensor ExecuteRsub::invoke( @@ -603,14 +702,8 @@ Tensor ExecuteBitwiseAnd::invoke( const Tensor& input_tensor_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::bitwise_and_binary( - queue_id, - input_tensor_a, - input_tensor_b, - std::nullopt, - memory_config, - optional_output_tensor); + return BinaryOperationSfpu::invoke( + queue_id, input_tensor_a, input_tensor_b, std::nullopt, memory_config, optional_output_tensor); } Tensor ExecuteBitwiseAnd::invoke( @@ -633,13 +726,9 @@ Tensor ExecuteBitwiseAnd::invoke( const int32_t input_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::bitwise_and_unary( - queue_id, - input_tensor_a, - input_b, - memory_config, - optional_output_tensor); + return ttnn::operations::unary:: + ExecuteUnaryWithIntegerParameter::invoke( + queue_id, input_tensor_a, input_b, memory_config, optional_output_tensor); } Tensor ExecuteBitwiseAnd::invoke( @@ -663,14 +752,8 @@ Tensor ExecuteBitwiseOr::invoke( const Tensor& input_tensor_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::bitwise_or_binary( - queue_id, - input_tensor_a, - input_tensor_b, - std::nullopt, - memory_config, - optional_output_tensor); + return BinaryOperationSfpu::invoke( + queue_id, input_tensor_a, input_tensor_b, std::nullopt, memory_config, optional_output_tensor); } Tensor ExecuteBitwiseOr::invoke( @@ -693,13 +776,9 @@ Tensor ExecuteBitwiseOr::invoke( const int32_t input_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::bitwise_or_unary( - queue_id, - input_tensor_a, - input_b, - memory_config, - optional_output_tensor); + return ttnn::operations::unary:: + ExecuteUnaryWithIntegerParameter::invoke( + queue_id, input_tensor_a, input_b, memory_config, optional_output_tensor); } Tensor ExecuteBitwiseOr::invoke( @@ -723,14 +802,8 @@ Tensor ExecuteBitwiseXor::invoke( const Tensor& input_tensor_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::bitwise_xor_binary( - queue_id, - input_tensor_a, - input_tensor_b, - std::nullopt, - memory_config, - optional_output_tensor); + return BinaryOperationSfpu::invoke( + queue_id, input_tensor_a, input_tensor_b, std::nullopt, memory_config, optional_output_tensor); } Tensor ExecuteBitwiseXor::invoke( @@ -753,13 +826,9 @@ Tensor ExecuteBitwiseXor::invoke( const int32_t input_b, const std::optional& memory_config, const std::optional& optional_output_tensor) { - - return ttnn::bitwise_xor_unary( - queue_id, - input_tensor_a, - input_b, - memory_config, - optional_output_tensor); + return ttnn::operations::unary:: + ExecuteUnaryWithIntegerParameter::invoke( + queue_id, input_tensor_a, input_b, memory_config, optional_output_tensor); } Tensor ExecuteBitwiseXor::invoke( diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp index da0ef55c936..ec83abec1cf 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp @@ -82,6 +82,7 @@ enum class UnaryOpType { FLOOR, FLOOR_FLOAT32, CEIL, + CEIL_FLOAT32, LEFT_SHIFT, REMAINDER, FMOD, diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp index 0732e967602..4ed08212b53 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp @@ -51,6 +51,8 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_default(UnaryOpType op_type, std: case UnaryOpType::SIGNBIT: op_init_and_name = {"signbit_tile_init();", fmt::format("signbit_tile({});", idst)}; break; - case UnaryOpType::CEIL: op_init_and_name = {"ceil_tile_init();", fmt::format("ceil_tile({});", idst)}; break; case UnaryOpType::SIN: op_init_and_name = {"sin_tile_init();", fmt::format("sin_tile({});", idst)}; break; case UnaryOpType::COS: op_init_and_name = {"cos_tile_init();", fmt::format("cos_tile({});", idst)}; break; case UnaryOpType::ISFINITE: @@ -344,7 +344,11 @@ std::pair get_op_init_and_func_default(UnaryOpType op_type, std: op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)}; break; case UnaryOpType::FLOOR_FLOAT32: - op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile_float32({});", idst)}; break; + op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile_float32({});", idst)}; + break; + case UnaryOpType::CEIL: op_init_and_name = {"ceil_tile_init();", fmt::format("ceil_tile({});", idst)}; break; + case UnaryOpType::CEIL_FLOAT32: + op_init_and_name = {"ceil_tile_init();", fmt::format("ceil_tile_float32({});", idst)}; break; case UnaryOpType::RELU6: op_init_and_name = {"relu_max_tile_init();", fmt::format("relu_max_tile({}, 0x40c00000u);", idst)}; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index a170e3d9855..04baa0f9599 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -42,110 +42,6 @@ Tensor _tanhshrink(const Tensor& x, const std::optional& output_me return result; } -// power - floating point exponent -Tensor ExecutePower::invoke( - uint8_t queue_id, - const Tensor& input_a, - float exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - TT_FATAL(exponent >= 0.0f, "works for positive exponents only"); - const uint32_t exponent_floor = static_cast(std::floor(exponent)); - if (static_cast(exponent_floor) == exponent) { - if (output_tensor.has_value()) { - ttnn::power(queue_id, input_a, exponent_floor, output_mem_config, output_tensor); - return output_tensor.value(); - } - return ttnn::power(queue_id, input_a, exponent_floor, output_mem_config); - } - const float exponent_trunc = exponent - static_cast(exponent_floor); - Tensor pow_trunc_log = ttnn::multiply( - queue_id, ttnn::log(queue_id, input_a, output_mem_config), exponent_trunc, std::nullopt, output_mem_config); - Tensor pow_frac = ttnn::exp(queue_id, pow_trunc_log, false, output_mem_config); - pow_trunc_log.deallocate(); - float t_nan = std::nanf(""); - Tensor result = ttnn::multiply( - queue_id, - ttnn::power(queue_id, input_a, exponent_floor, output_mem_config), - pow_frac, - std::nullopt, - output_mem_config); - // To handle negative inputs: - // in torch For -ve inputs with float exponent power returns nan - auto output_memory_config = output_tensor.has_value() ? output_tensor.value().memory_config() - : output_mem_config.value_or(input_a.memory_config()); - result = ttnn::where( - ttnn::ltz(queue_id, input_a, output_mem_config), t_nan, result, output_memory_config, output_tensor); - return result; -} - -// power - floating point exponent -Tensor ExecutePower::invoke( - const Tensor& input_a, - float exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - return ExecutePower::invoke(DefaultQueueId, input_a, exponent, output_mem_config, std::move(output_tensor)); -} - -// power - integer exponent -Tensor ExecutePower::invoke( - uint8_t queue_id, - const Tensor& input, - uint32_t exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - return ttnn::power(queue_id, input, exponent, output_mem_config, output_tensor); -} - -// power - integer exponent -Tensor ExecutePower::invoke( - const Tensor& input, - uint32_t exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - return ExecutePower::invoke(DefaultQueueId, input, exponent, output_mem_config, std::move(output_tensor)); -} - -// power - tensor exponent -Tensor ExecutePower::invoke( - uint8_t queue_id, - const Tensor& input, - const Tensor& exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - return ttnn::power_binary(queue_id, input, exponent, std::nullopt, output_mem_config, output_tensor); -} - -// power - tensor exponent -Tensor ExecutePower::invoke( - const Tensor& input, - const Tensor& exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - return ExecutePower::invoke(DefaultQueueId, input, exponent, output_mem_config, std::move(output_tensor)); -} - -// power - scalar input -Tensor ExecutePower::invoke( - uint8_t queue_id, - float input_a, - const Tensor& exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - Tensor input = ttnn::full_like(exponent, input_a); - return ExecutePower::invoke(queue_id, input, exponent, output_mem_config, std::move(output_tensor)); -} - -// power - scalar input -Tensor ExecutePower::invoke( - float input_a, - const Tensor& exponent, - const std::optional& output_mem_config, - std::optional output_tensor) { - return ExecutePower::invoke(DefaultQueueId, input_a, exponent, output_mem_config, std::move(output_tensor)); -} - // acosh(x) = log(x + sqrt(x^2 - 1)) Tensor _acosh(const Tensor& input_a, const std::optional& output_mem_config) { TT_FATAL(input_a.storage_type() == StorageType::DEVICE, "Unary operation requires input to be on Device."); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp index c87dae81384..ec50c8ce692 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp @@ -99,7 +99,6 @@ template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; -template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; @@ -362,6 +361,32 @@ Tensor Floor::invoke( DefaultQueueId, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor); } +Tensor Ceil::invoke( + uint8_t queue_id, + const Tensor& input_tensor, + const std::optional& memory_config, + const std::optional& optional_output_tensor) { + UnaryOpType op_type = UnaryOpType::CEIL; + if (input_tensor.get_dtype() == DataType::FLOAT32) { + op_type = UnaryOpType::CEIL_FLOAT32; + } + + return detail::unary_impl(queue_id, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor); +} + +Tensor Ceil::invoke( + const Tensor& input_tensor, + const std::optional& memory_config, + const std::optional& optional_output_tensor) { + UnaryOpType op_type = UnaryOpType::CEIL; + if (input_tensor.get_dtype() == DataType::FLOAT32) { + op_type = UnaryOpType::CEIL_FLOAT32; + } + + return detail::unary_impl( + DefaultQueueId, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor); +} + Tensor Dropout::invoke( const Tensor& input, const uint32_t seed, diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp index a5a8d89087d..abe259f03de 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp @@ -161,6 +161,18 @@ struct Floor { const std::optional& optional_output_tensor = std::nullopt); }; +struct Ceil { + static Tensor invoke( + uint8_t queue_id, + const Tensor& input_tensor, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); + + static Tensor invoke( + const Tensor& input_tensor, + const std::optional& memory_config = std::nullopt, + const std::optional& optional_output_tensor = std::nullopt); +}; struct Dropout { static Tensor invoke( const Tensor& input, @@ -294,7 +306,6 @@ REGISTER_UNARY_OPERATION(erfinv, ERFINV); REGISTER_UNARY_OPERATION(exp2, EXP2); REGISTER_UNARY_OPERATION(expm1, EXPM1); REGISTER_UNARY_OPERATION(eqz, EQZ); -REGISTER_UNARY_OPERATION(ceil, CEIL); REGISTER_UNARY_OPERATION(gez, GEZ); REGISTER_UNARY_OPERATION(gtz, GTZ); REGISTER_UNARY_OPERATION(i0, I0); @@ -341,7 +352,6 @@ REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(rsqrt, RSQRT); // Unaries with float parameter REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(elu, ELU); -REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(rsub_unary, RSUB); REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(heaviside, HEAVISIDE); REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(leaky_relu, LEAKY_RELU); REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(relu_max, RELU_MAX); @@ -357,9 +367,6 @@ REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(ne_unary, UNARY_NE); REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(power, POWER, uint32_t); REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_left_shift, LEFT_SHIFT, int32_t); REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_right_shift, RIGHT_SHIFT, int32_t); -REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_and_unary, BITWISE_AND, int32_t); -REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_or_unary, BITWISE_OR, int32_t); -REGISTER_UNARY_OPERATION_WITH_INTEGER_PARAMETER(bitwise_xor_unary, BITWISE_XOR, int32_t); // Other unaries constexpr auto dropout = @@ -368,6 +375,7 @@ constexpr auto identity = ttnn::register_operation_with_auto_launch_op<"ttnn::identity", ttnn::operations::unary::Identity>(); constexpr auto floor = ttnn::register_operation_with_auto_launch_op<"ttnn::floor", ttnn::operations::unary::Floor>(); +constexpr auto ceil = ttnn::register_operation_with_auto_launch_op<"ttnn::ceil", ttnn::operations::unary::Ceil>(); constexpr auto softplus = ttnn::register_operation_with_auto_launch_op<"ttnn::softplus", ttnn::operations::unary::Softplus>(); constexpr auto prelu_sfpu = diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp index a523e09cb30..b7c2535aa01 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp @@ -12,69 +12,6 @@ namespace ttnn { namespace operations { namespace unary { -/** - * @brief Performs element-wise power operation on the input with the exponent. - * When exponent is Tensor, the supported dtypes are float32 and bfloat16. - * The tested range for the input is (-30,30) and for the exponent is (-20, 20). - * - * @param input The input tensor, i.e the base. - * @param exponent The exponent - * @return The result tensor - */ -struct ExecutePower { - static Tensor invoke( - uint8_t queue_id, - const Tensor& input_tensor, - uint32_t exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - const Tensor& input_tensor, - uint32_t exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - uint8_t queue_id, - const Tensor& input_tensor, - float exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - const Tensor& input_tensor, - float exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - uint8_t queue_id, - float input_a, - const Tensor& exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - float input_a, - const Tensor& exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - uint8_t queue_id, - const Tensor& input_tensor, - const Tensor& exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); - - static Tensor invoke( - const Tensor& input_tensor, - const Tensor& exponent, - const std::optional& memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); -}; - template struct ExecuteUnaryCompositeOp { static Tensor invoke(const Tensor& input_tensor, const std::optional& memory_config = std::nullopt) { @@ -235,7 +172,6 @@ auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) constexpr auto rdiv = ttnn::register_operation_with_auto_launch_op<"ttnn::rdiv", operations::unary::ExecuteRdiv>(); -constexpr auto pow = ttnn::register_operation_with_auto_launch_op<"ttnn::pow", operations::unary::ExecutePower>(); constexpr auto tanhshrink = ttnn::register_operation_with_auto_launch_op< "ttnn::tanhshrink", operations::unary::ExecuteUnaryCompositeOp>(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 5f0a6514b35..d343f42bb92 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -975,125 +975,6 @@ void bind_identity(py::module& module, const unary_operation_t& operation) { py::arg("queue_id") = 0}); } -template -void bind_power(py::module& module, const unary_operation_t& operation, const std::string& note = "") { - auto doc = fmt::format( - R"doc( - Perform element-wise {0} operation on :attr:`input_tensor` with :attr:`exponent`. - - .. math:: - \mathrm{{output\_tensor}}_i = \verb|{0}|(\mathrm{{input\_tensor}}_i ** \mathrm{{exponent}}_i) - - Args: - input_tensor (ttnn.Tensor): the input tensor. - exponent (float, int): the exponent value. - - Keyword Args: - memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. - output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`. - queue_id (int, optional): command queue id. Defaults to `0`. - - Returns: - ttnn.Tensor: the output tensor. - - Note: - Supported dtypes, layouts, and ranks: - - .. list-table:: - :header-rows: 1 - - * - Dtypes - - Layouts - - Ranks - * - BFLOAT16, BFLOAT8_B - - TILE - - 2, 3, 4 - - {2} - - Example: - >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device) - >>> exponent = 2 - >>> output = {1}(tensor, exponent) - )doc", - ttnn::pow.base_name(), - ttnn::pow.python_fully_qualified_name(), - note); - - bind_registered_operation( - module, - ttnn::pow, - doc, - // integer exponent - ttnn::pybind_overload_t{ - [](const unary_operation_t& self, - const Tensor& input_tensor, - uint32_t exponent, - const std::optional& memory_config, - std::optional output_tensor, - const uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor, exponent, memory_config, output_tensor); - }, - py::arg("input_tensor"), - py::arg("exponent"), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("output_tensor") = std::nullopt, - py::arg("queue_id") = 0}, - - // float exponent - ttnn::pybind_overload_t{ - [](const unary_operation_t& self, - const Tensor& input_tensor, - float exponent, - const std::optional& memory_config, - std::optional output_tensor, - const uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor, exponent, memory_config, output_tensor); - }, - py::arg("input_tensor"), - py::arg("exponent"), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("output_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}, - - // tensor exponent - ttnn::pybind_overload_t{ - [](const unary_operation_t& self, - const Tensor& input_tensor, - const Tensor& exponent, - const std::optional& memory_config, - std::optional output_tensor, - const uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor, exponent, memory_config, output_tensor); - }, - py::arg("input_tensor"), - py::arg("exponent"), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("output_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}, - - // scalar input - tensor exponent - ttnn::pybind_overload_t{ - [](const unary_operation_t& self, - float input, - const Tensor& exponent, - const std::optional& memory_config, - std::optional output_tensor, - const uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input, exponent, memory_config, output_tensor); - }, - py::arg("input"), - py::arg("exponent"), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("output_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId} - ); -} - template void bind_unary_composite( py::module& module, @@ -1923,7 +1804,6 @@ void py_module(py::module& module) { detail::bind_sigmoid_accurate(module, ttnn::sigmoid_accurate); detail::bind_unary_chain(module, ttnn::unary_chain); detail::bind_identity(module, ttnn::identity); - detail::bind_power(module, ttnn::pow, R"doc(When :attr:`exponent` is a Tensor, supported dtypes are: BFLOAT16, FLOAT32)doc"); // unary composite imported into ttnn detail::bind_unary_composite(module, ttnn::deg2rad, R"doc(Performs deg2rad function on :attr:`input_tensor`.)doc", "", R"doc(BFLOAT16, BFLOAT8_B)doc"); diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp index c84a4d4825a..0532ae005cc 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp @@ -71,14 +71,17 @@ void kernel_main() { // In case we need to send multiple blocks per shard, in0 sharded cb is cb2 and we extract the sub-blocks to cb0 constexpr uint32_t shard_read_stride = shard_width_in_tiles * in0_single_tile_size_bytes; constexpr uint32_t shard_read_width = in0_single_tile_size_bytes * in0_block_w; + constexpr uint32_t shard_num_tiles = shard_width_in_tiles * shard_height_in_tiles; + constexpr uint32_t in0_tensor_next_h_dim_block_stride_bytes = + in0_tensor_next_h_dim_block_stride * in0_single_tile_size_bytes; - uint64_t noc_shard_read_start_addr = 0; + uint32_t noc_shard_read_start_addr = 0; if constexpr (extract_shard_sub_blocks) { constexpr uint32_t cb_id_in2 = 2; // in0 sharded cb if extract_shard_sub_blocks - noc_shard_read_start_addr = get_noc_addr(get_read_ptr(cb_id_in2)); + noc_shard_read_start_addr = get_read_ptr(cb_id_in2); } else { - cb_reserve_back(cb_id_in0, in0_block_num_tiles); - cb_push_back(cb_id_in0, in0_block_num_tiles); + cb_reserve_back(cb_id_in0, shard_num_tiles); + cb_push_back(cb_id_in0, shard_num_tiles); } #else constexpr DataFormat in0_data_format = get_dataformat(cb_id_in0); @@ -113,9 +116,15 @@ void kernel_main() { #endif for (uint32_t b = 0; b < batch; ++b) { +#ifdef IN0_SHARDED + uint32_t in0_tensor_current_h_dim_block_start_addr = noc_shard_read_start_addr; +#endif uint32_t in0_tensor_current_h_dim_block_tile_id = in0_tensor_start_tile_id; for (uint32_t bh = 0; bh < num_blocks_h_dim; ++bh) { for (uint32_t bw = 0; bw < num_blocks_w_dim; ++bw) { +#ifdef IN0_SHARDED + uint32_t in0_tensor_current_inner_dim_block_start_addr = in0_tensor_current_h_dim_block_start_addr; +#endif uint32_t in0_tensor_current_inner_dim_block_start_tile_id = in0_tensor_current_h_dim_block_tile_id; for (uint32_t block = 0; block < num_blocks_inner_dim; ++block) { if constexpr (fuse_op) { @@ -159,16 +168,16 @@ void kernel_main() { in0_start_address = l1_write_addr_in0; // copy start address of block, to be used for mcasting #endif - uint64_t noc_shard_read_addr = noc_shard_read_start_addr; - noc_shard_read_start_addr += shard_read_width; + uint64_t noc_shard_read_addr = get_noc_addr(in0_tensor_current_inner_dim_block_start_addr); - for (uint32_t i = 0; i < shard_height_in_tiles; i++) { + for (uint32_t i = 0; i < in0_block_h; i++) { noc_async_read(noc_shard_read_addr, l1_write_addr_in0, shard_read_width); l1_write_addr_in0 += shard_read_width; noc_shard_read_addr += shard_read_stride; } + in0_tensor_current_inner_dim_block_start_addr += shard_read_width; noc_async_read_barrier(); } #endif @@ -216,6 +225,9 @@ void kernel_main() { #endif } } +#ifdef IN0_SHARDED + in0_tensor_current_h_dim_block_start_addr += in0_tensor_next_h_dim_block_stride_bytes; +#endif in0_tensor_current_h_dim_block_tile_id += in0_tensor_next_h_dim_block_stride; } in0_tensor_start_tile_id += MtKt; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index a57ef49b74c..c4745856b80 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -306,6 +306,8 @@ MatmulProgramConfig create_matmul_1d_systolic_array_program_config( .in0_block_w = k_tiles_per_core, .out_subblock_h = out_subblock_h, .out_subblock_w = out_subblock_w, + .out_block_h = batch_and_m_tiles_per_core, + .out_block_w = n_tiles_per_core, .per_core_M = batch_and_m_tiles_per_core, .per_core_N = n_tiles_per_core, .fuse_batch = true, @@ -357,6 +359,8 @@ MatmulMultiCoreReuseMultiCast1DProgramConfig get_mcast_1d_config( .in0_block_w = in0_block_w, .out_subblock_h = out_subblock_h, .out_subblock_w = out_subblock_w, + .out_block_h = per_core_M, + .out_block_w = per_core_N, .per_core_M = per_core_M, .per_core_N = per_core_N, .fuse_batch = fuse_batch, @@ -701,6 +705,8 @@ MatmulProgramConfig get_matmul_program_config( .in0_block_w = in0_block_w, .out_subblock_h = out_subblock_h, .out_subblock_w = out_subblock_w, + .out_block_h = per_core_M, + .out_block_w = per_core_N, .per_core_M = per_core_M, .per_core_N = per_core_N, .fuse_batch = true, @@ -1182,6 +1188,26 @@ void Matmul::validate( // TODO: For 1D and 2D mcasts, we don't check if tensor is single core or single row/col // We can uplift these variants to skip mcasting to support single core (1D) or single row/col (2D) if constexpr (std::is_same_v) { + TT_FATAL( + program_config.per_core_M % program_config.out_block_h == 0, + "Error: incompatible values {} and {}", + program_config.per_core_M, + program_config.out_block_h); + TT_FATAL( + program_config.per_core_N % program_config.out_block_w == 0, + "Error: incompatible values {} and {}", + program_config.per_core_N, + program_config.out_block_w); + TT_FATAL( + program_config.out_block_h % program_config.out_subblock_h == 0, + "Error: incompatible values {} and {}", + program_config.out_block_h, + program_config.out_subblock_h); + TT_FATAL( + program_config.out_block_w % program_config.out_subblock_w == 0, + "Error: incompatible values {} and {}", + program_config.out_block_w, + program_config.out_subblock_w); TT_FATAL( !(program_config.mcast_in0 && program_config.gather_in0), "Matmul1D does not support mcast_in0 and gather_in0 at the same time."); @@ -1806,6 +1832,8 @@ operation::ProgramWithCallbacks Matmul::create_program( program_config.in0_block_w, program_config.out_subblock_h, program_config.out_subblock_w, + program_config.out_block_h, + program_config.out_block_w, program_config.per_core_M, program_config.per_core_N, program_config.fuse_batch, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp index 4eea7a50f19..a4b41cb6519 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp @@ -43,6 +43,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized( uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, + uint32_t out_block_h, + uint32_t out_block_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, @@ -130,6 +132,8 @@ struct MatmulMultiCoreReuseMultiCast1DProgramConfig { std::size_t in0_block_w; std::size_t out_subblock_h; std::size_t out_subblock_w; + std::size_t out_block_h; + std::size_t out_block_w; std::size_t per_core_M; std::size_t per_core_N; bool fuse_batch; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp index b5352ab0d45..4e299855332 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp @@ -71,6 +71,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, + uint32_t out_block_h, + uint32_t out_block_w, uint32_t per_core_M, uint32_t per_core_N, std::optional fused_activation, @@ -113,7 +115,16 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); - uint32_t in0_block_tiles = per_core_M * in0_block_w; + bool do_not_inplace_interm0_out_CB = output_is_sharded && (per_core_M != out_block_h); + + uint32_t in0_block_h = out_block_h; + uint32_t in1_block_w = out_block_w; + uint32_t in0_num_blocks_y = per_core_M / out_block_h; + uint32_t in1_num_blocks_x = per_core_N / out_block_w; + uint32_t out_num_blocks_x = in1_num_blocks_x; + uint32_t out_num_blocks_y = in0_num_blocks_y; + + uint32_t in0_block_tiles = in0_block_h * in0_block_w; uint32_t in0_CB_tiles = in0_block_tiles; if (B * num_blocks > 1) { in0_CB_tiles = in0_CB_tiles * 2; // double buffer @@ -131,7 +142,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t in2_CB_tiles = in2_block_tiles; uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size; - uint32_t in1_block_tiles = per_core_N * in0_block_w; + uint32_t in1_block_tiles = out_block_w * in0_block_w; uint32_t in1_CB_tiles = in1_block_tiles; if (B * num_blocks > 1) { in1_CB_tiles = in1_CB_tiles * 2; // double buffer @@ -143,12 +154,17 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t in1_CB_size = in1_CB_tiles * in1_single_tile_size; - uint32_t out_block_tiles = per_core_M * per_core_N; + uint32_t out_block_tiles = out_block_h * out_block_w; + uint32_t out_shard_tiles = per_core_M * per_core_N; uint32_t out_CB_tiles = out_block_tiles; // No double buffer + if (output_is_sharded) { + out_CB_tiles = out_shard_tiles; + } uint32_t out_CB_size = out_CB_tiles * output_single_tile_size; - uint32_t interm0_CB_size = out_CB_tiles * interm0_single_tile_size; + uint32_t interm0_CB_tiles = out_block_tiles; // No double buffer + uint32_t interm0_CB_size = interm0_CB_tiles * interm0_single_tile_size; - uint32_t in3_block_tiles = per_core_N; + uint32_t in3_block_tiles = out_block_w; uint32_t in3_CB_tiles = in3_block_tiles; // No double buffer uint32_t in3_CB_size = in3_CB_tiles * bias_single_tile_size; @@ -252,7 +268,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0( } bool out_is_dram = out_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - uint32_t in0_num_subblocks = (per_core_M / out_subblock_h); + uint32_t in0_num_subblocks = (out_block_h / out_subblock_h); uint32_t in0_block_num_tiles = out_subblock_h * in0_block_w * in0_num_subblocks; std::vector in0_sender_compile_time_args; @@ -264,9 +280,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)in0_block_num_tiles, // in0_block_num_tiles (std::uint32_t)in0_block_num_tiles * in0_single_tile_size, // in0_block_size_bytes // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // num_blocks_x - (std::uint32_t)1, // num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // num_blocks_x + (std::uint32_t)out_num_blocks_y, // num_blocks_y // in0 mcast args (std::uint32_t)in0_mcast_sender_semaphore_id, (std::uint32_t)in0_mcast_receiver_semaphore_id, @@ -278,7 +294,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)(in0_shard_width_in_tiles), (std::uint32_t)(in0_shard_height_in_tiles), (std::uint32_t)(in0_block_w), - (std::uint32_t)per_core_M, // in0_block_h + (std::uint32_t)in0_block_h, // in0_block_h // batch args (std::uint32_t)B // batch @@ -289,21 +305,21 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)in0_is_dram, // in0 tensor args - (std::uint32_t)1, // in0_tensor_stride_w - (std::uint32_t)K, // in0_tensor_stride_h - (std::uint32_t)in0_block_w, // in0_tensor_next_block_stride - (std::uint32_t)K * per_core_M, // in0_tensor_next_h_dim_block_stride + (std::uint32_t)1, // in0_tensor_stride_w + (std::uint32_t)K, // in0_tensor_stride_h + (std::uint32_t)in0_block_w, // in0_tensor_next_block_stride + (std::uint32_t)K * in0_block_h, // in0_tensor_next_h_dim_block_stride // in0 block args - (std::uint32_t)in0_block_w, // in0_block_w - (std::uint32_t)per_core_M, // in0_block_h - (std::uint32_t)in0_block_w * per_core_M, // in0_block_num_tiles - (std::uint32_t) false, // extract_shard_sub_blocks (not used for interleaved) - (std::uint32_t)0, // shard_width_in_tiles (not used for interleaved) - (std::uint32_t)0, // shard_height_in_tiles (not used for interleaved) + (std::uint32_t)in0_block_w, // in0_block_w + (std::uint32_t)in0_block_h, // in0_block_h + (std::uint32_t)in0_block_num_tiles, // in0_block_num_tiles + (std::uint32_t)false, // extract_shard_sub_blocks (not used for interleaved) + (std::uint32_t)0, // shard_width_in_tiles (not used for interleaved) + (std::uint32_t)0, // shard_height_in_tiles (not used for interleaved) // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // num_blocks_x - (std::uint32_t)1, // num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // num_blocks_x + (std::uint32_t)out_num_blocks_y, // num_blocks_y // in0 mcast args (std::uint32_t)in0_mcast_sender_semaphore_id, (std::uint32_t)in0_mcast_receiver_semaphore_id, @@ -326,15 +342,15 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)1, // in1_tensor_stride_w (std::uint32_t)N, // in1_tensor_stride_h (std::uint32_t)in0_block_w * N, // in1_tensor_next_block_stride - (std::uint32_t)per_core_N, // in1_tensor_next_w_dim_block_stride + (std::uint32_t)in1_block_w, // in1_tensor_next_w_dim_block_stride // in1 block args - (std::uint32_t)per_core_N, // in1_block_w - (std::uint32_t)in0_block_w, // in1_block_h - (std::uint32_t)per_core_N * in0_block_w, // in1_block_num_tiles + (std::uint32_t)in1_block_w, // in1_block_w + (std::uint32_t)in0_block_w, // in1_block_h + (std::uint32_t)in1_block_w * in0_block_w, // in1_block_num_tiles // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // out_num_blocks_x - (std::uint32_t)1, // out_num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // out_num_blocks_x + (std::uint32_t)out_num_blocks_y, // out_num_blocks_y // in1 mcast args (std::uint32_t)0, (std::uint32_t)0, @@ -351,8 +367,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)N, // out_tensor_stride_h (std::uint32_t)out_subblock_w, // out_tensor_next_subblock_stride_w (std::uint32_t)out_subblock_h * N, // out_tensor_next_subblock_stride_h - (std::uint32_t)per_core_N, // out_tensor_next_w_dim_block_stride - (std::uint32_t)per_core_M * N, // out_tensor_next_h_dim_block_stride + (std::uint32_t)out_block_w, // out_tensor_next_w_dim_block_stride + (std::uint32_t)out_block_h * N, // out_tensor_next_h_dim_block_stride // out subblock args (std::uint32_t)out_subblock_w, // out_subblock_w (std::uint32_t)out_subblock_h, // out_subblock_h @@ -372,11 +388,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0( std::vector in0_receiver_compile_time_args = { // in0 block args - (std::uint32_t)in0_block_w * per_core_M, // in0_block_num_tiles + (std::uint32_t)in0_block_num_tiles, // in0_block_num_tiles // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // out_num_blocks_x - (std::uint32_t)1, // out_num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // out_num_blocks_x + (std::uint32_t)out_num_blocks_y, // out_num_blocks_y // in0 mcast args (std::uint32_t)in0_mcast_sender_semaphore_id, (std::uint32_t)in0_mcast_receiver_semaphore_id, @@ -518,7 +534,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w; - uint32_t in1_num_subblocks = (per_core_N / out_subblock_w); + uint32_t in1_num_subblocks = (out_block_w / out_subblock_w); uint32_t in1_block_num_tiles = out_subblock_w * in0_block_w * in1_num_subblocks; uint32_t in1_per_core_w = out_subblock_w * in1_num_subblocks; @@ -534,9 +550,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0( in1_block_num_tiles, // in1_block_num_tiles in1_per_core_w, // in1_per_core_w - num_blocks, // num_blocks - 1, // out_num_blocks_x - 1, // out_num_blocks_y + num_blocks, // num_blocks + out_num_blocks_x, // out_num_blocks_x + out_num_blocks_y, // out_num_blocks_y out_subblock_h, // out_subblock_h out_subblock_w, // out_subblock_w @@ -627,7 +643,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(0, {{output_cb_index, output_data_format}}); - if ((interm0_data_format != output_data_format) || (untilize_out && (in1_num_subblocks > 1))) { + if (do_not_inplace_interm0_out_CB || (interm0_data_format != output_data_format) || + (untilize_out && (in1_num_subblocks > 1))) { // output std::map output_cb_data_format_spec{ {output_cb_index, output_data_format}, @@ -697,20 +714,22 @@ operation::ProgramWithCallbacks create_program_mcast_in0( } // Parameters for last row, col, or block - uint32_t last_block_h = M % per_core_M == 0 ? per_core_M : M % per_core_M; - uint32_t last_block_w = N % per_core_N == 0 ? per_core_N : N % per_core_N; - uint32_t last_block_num_nonzero_subblocks_h = (last_block_h - 1) / out_subblock_h + 1; - uint32_t last_block_num_nonzero_subblocks_w = (last_block_w - 1) / out_subblock_w + 1; + uint32_t last_per_core_M = M % per_core_M == 0 ? per_core_M : M % per_core_M; + uint32_t last_per_core_N = N % per_core_N == 0 ? per_core_N : N % per_core_N; + uint32_t last_out_block_h = last_per_core_M % out_block_h == 0 ? out_block_h : last_per_core_M % out_block_h; + uint32_t last_out_block_w = last_per_core_N % out_block_w == 0 ? out_block_w : last_per_core_N % out_block_w; + uint32_t last_block_num_nonzero_subblocks_h = (last_out_block_h - 1) / out_subblock_h + 1; + uint32_t last_block_num_nonzero_subblocks_w = (last_out_block_w - 1) / out_subblock_w + 1; uint32_t last_subblock_of_last_block_h = - last_block_h % out_subblock_h == 0 ? out_subblock_h : last_block_h % out_subblock_h; + last_out_block_h % out_subblock_h == 0 ? out_subblock_h : last_out_block_h % out_subblock_h; uint32_t last_subblock_of_last_block_w = - last_block_w % out_subblock_w == 0 ? out_subblock_w : last_block_w % out_subblock_w; + last_out_block_w % out_subblock_w == 0 ? out_subblock_w : last_out_block_w % out_subblock_w; uint32_t last_block_padded_subblock_tiles_addr_skip = output_single_tile_size * (out_subblock_w - last_subblock_of_last_block_w); uint32_t last_block_padded_block_tiles_w_skip = - (out_subblock_w * out_subblock_h) * (per_core_N / out_subblock_w - last_block_num_nonzero_subblocks_w); + (out_subblock_w * out_subblock_h) * (out_block_w / out_subblock_w - last_block_num_nonzero_subblocks_w); uint32_t last_block_padded_block_tiles_h_skip = - (per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h); + (out_block_h / out_subblock_h - last_block_num_nonzero_subblocks_h) * (out_block_w * out_subblock_h); CoreCoord start_core_noc = top_left_core_physical; CoreCoord end_core_noc = bottom_right_core_physical; @@ -772,7 +791,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0( (std::uint32_t)end_core_noc.y, // in0_mcast_dest_noc_end_y // padding args - (std::uint32_t)per_core_M // last_block_h + (std::uint32_t)out_block_h // last_block_h }; if (fuse_op) { @@ -815,27 +834,27 @@ operation::ProgramWithCallbacks create_program_mcast_in0( if (output_idx_x == num_blocks_x - 1) { // padding args (READER) - mm_in1_sender_writer_args.push_back(last_block_w); + mm_in1_sender_writer_args.push_back(last_out_block_w); // padding args (WRITER) - mm_in1_sender_writer_args.push_back(per_core_M / out_subblock_h); + mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h); mm_in1_sender_writer_args.push_back(out_subblock_h); mm_in1_sender_writer_args.push_back(0); - mm_in1_sender_writer_args.push_back(per_core_N / out_subblock_w); // out_num_nonzero_subblocks_w + mm_in1_sender_writer_args.push_back(out_block_w / out_subblock_w); // out_num_nonzero_subblocks_w mm_in1_sender_writer_args.push_back(last_block_num_nonzero_subblocks_w); mm_in1_sender_writer_args.push_back(last_subblock_of_last_block_w); mm_in1_sender_writer_args.push_back(last_block_padded_subblock_tiles_addr_skip); mm_in1_sender_writer_args.push_back(last_block_padded_block_tiles_w_skip); } else { // padding args (READER) - mm_in1_sender_writer_args.push_back(per_core_N); + mm_in1_sender_writer_args.push_back(out_block_w); // padding args (WRITER) - mm_in1_sender_writer_args.push_back(per_core_M / out_subblock_h); + mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h); mm_in1_sender_writer_args.push_back(out_subblock_h); mm_in1_sender_writer_args.push_back(0); - mm_in1_sender_writer_args.push_back(per_core_N / out_subblock_w); // out_num_nonzero_subblocks_w - mm_in1_sender_writer_args.push_back(per_core_N / out_subblock_w); + mm_in1_sender_writer_args.push_back(out_block_w / out_subblock_w); // out_num_nonzero_subblocks_w + mm_in1_sender_writer_args.push_back(out_block_w / out_subblock_w); mm_in1_sender_writer_args.push_back(out_subblock_w); mm_in1_sender_writer_args.push_back(0); mm_in1_sender_writer_args.push_back(0); @@ -948,6 +967,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, + uint32_t out_block_h, + uint32_t out_block_w, uint32_t per_core_M, uint32_t per_core_N, std::optional fused_activation, @@ -991,10 +1012,19 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); - uint32_t in0_block_tiles = per_core_M * in0_block_w; + bool do_not_inplace_interm0_out_CB = output_is_sharded && (per_core_M != out_block_h); + + uint32_t in0_block_h = out_block_h; + uint32_t in1_block_w = out_block_w; + uint32_t in0_num_blocks_y = per_core_M / out_block_h; + uint32_t in1_num_blocks_x = per_core_N / out_block_w; + uint32_t out_num_blocks_x = in1_num_blocks_x; + uint32_t out_num_blocks_y = in0_num_blocks_y; + + uint32_t in0_block_tiles = in0_block_h * in0_block_w; uint32_t in0_CB_tiles = in0_block_tiles; if (in0_is_sharded) { - in0_CB_tiles = num_blocks * in0_CB_tiles * B; + in0_CB_tiles = num_blocks * per_core_M * in0_block_w * B; } else if (B * num_blocks > 1) { in0_CB_tiles = in0_CB_tiles * 2; // double buffer } @@ -1015,20 +1045,27 @@ operation::ProgramWithCallbacks create_program_mcast_in1( extract_shard_sub_blocks = true; } } + uint32_t in2_CB_tiles = in0_block_tiles; + uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size; - uint32_t in1_block_tiles = per_core_N * in0_block_w; + uint32_t in1_block_tiles = out_block_w * in0_block_w; uint32_t in1_CB_tiles = in1_block_tiles; if (B * num_blocks > 1) { in1_CB_tiles = in1_CB_tiles * 2; // double buffer } uint32_t in1_CB_size = in1_CB_tiles * in1_single_tile_size; - uint32_t out_block_tiles = per_core_M * per_core_N; + uint32_t out_block_tiles = out_block_h * out_block_w; + uint32_t out_shard_tiles = per_core_M * per_core_N; uint32_t out_CB_tiles = out_block_tiles; // No double buffer + if (output_is_sharded) { + out_CB_tiles = out_shard_tiles; + } uint32_t out_CB_size = out_CB_tiles * output_single_tile_size; - uint32_t interm0_CB_size = out_CB_tiles * interm0_single_tile_size; + uint32_t interm0_CB_tiles = out_block_tiles; // No double buffer + uint32_t interm0_CB_size = interm0_CB_tiles * interm0_single_tile_size; - uint32_t in3_block_tiles = per_core_N; + uint32_t in3_block_tiles = out_block_w; uint32_t in3_CB_tiles = in3_block_tiles; // No double buffer uint32_t in3_CB_size = in3_CB_tiles * bias_single_tile_size; @@ -1078,21 +1115,21 @@ operation::ProgramWithCallbacks create_program_mcast_in1( (std::uint32_t)in0_is_dram, // in0 tensor args - (std::uint32_t)1, // in0_tensor_stride_w - (std::uint32_t)K, // in0_tensor_stride_h - (std::uint32_t)in0_block_w, // in0_tensor_next_block_stride - (std::uint32_t)K * per_core_M, // in0_tensor_next_h_dim_block_stride + (std::uint32_t)1, // in0_tensor_stride_w + (std::uint32_t)K, // in0_tensor_stride_h + (std::uint32_t)in0_block_w, // in0_tensor_next_block_stride + (std::uint32_t)K * in0_block_h, // in0_tensor_next_h_dim_block_stride // in0 block args - (std::uint32_t)in0_block_w, // in0_block_w - (std::uint32_t)per_core_M, // in0_block_h - (std::uint32_t)in0_block_w * per_core_M, // in0_block_num_tiles + (std::uint32_t)in0_block_w, // in0_block_w + (std::uint32_t)in0_block_h, // in0_block_h + (std::uint32_t)in0_block_w * in0_block_h, // in0_block_num_tiles (std::uint32_t)extract_shard_sub_blocks, (std::uint32_t)in0_shard_width_in_tiles, (std::uint32_t)in0_shard_height_in_tiles, // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // out_num_blocks_x - (std::uint32_t)1, // out_num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // out_num_blocks_x + (std::uint32_t)out_num_blocks_y, // out_num_blocks_y // in0 mcast args (std::uint32_t)0, (std::uint32_t)0, @@ -1114,15 +1151,15 @@ operation::ProgramWithCallbacks create_program_mcast_in1( (std::uint32_t)1, // in1_tensor_stride_w (std::uint32_t)N, // in1_tensor_stride_h (std::uint32_t)in0_block_w * N, // in1_tensor_next_block_stride - (std::uint32_t)per_core_N, // in1_tensor_next_w_dim_block_stride + (std::uint32_t)in1_block_w, // in1_tensor_next_w_dim_block_stride // in1 block args - (std::uint32_t)per_core_N, // in1_block_w - (std::uint32_t)in0_block_w, // in1_block_h - (std::uint32_t)per_core_N * in0_block_w, // in1_block_num_tiles + (std::uint32_t)in1_block_w, // in1_block_w + (std::uint32_t)in0_block_w, // in1_block_h + (std::uint32_t)in1_block_w * in0_block_w, // in1_block_num_tiles // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // out_num_blocks_x - (std::uint32_t)1, // out_num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // out_num_blocks_x + (std::uint32_t)out_num_blocks_y, // out_num_blocks_y // in1 mcast args (std::uint32_t)in1_mcast_sender_semaphore_id, (std::uint32_t)in1_mcast_receiver_semaphore_id, @@ -1139,8 +1176,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( (std::uint32_t)N, // out_tensor_stride_h (std::uint32_t)out_subblock_w, // out_tensor_next_subblock_stride_w (std::uint32_t)out_subblock_h * N, // out_tensor_next_subblock_stride_h - (std::uint32_t)per_core_N, // out_tensor_next_w_dim_block_stride - (std::uint32_t)per_core_M * N, // out_tensor_next_h_dim_block_stride + (std::uint32_t)out_block_w, // out_tensor_next_w_dim_block_stride + (std::uint32_t)out_block_h * N, // out_tensor_next_h_dim_block_stride // out subblock args (std::uint32_t)out_subblock_w, // out_subblock_w (std::uint32_t)out_subblock_h, // out_subblock_h @@ -1164,11 +1201,11 @@ operation::ProgramWithCallbacks create_program_mcast_in1( // READER // in1 block args - (std::uint32_t)per_core_N * in0_block_w, // in1_block_num_tiles + (std::uint32_t)in1_block_w * in0_block_w, // in1_block_num_tiles // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - (std::uint32_t)1, // out_num_blocks_x - (std::uint32_t)1, // out_num_blocks_y + (std::uint32_t)num_blocks, // num_blocks + (std::uint32_t)out_num_blocks_x, // out_num_blocks_x + (std::uint32_t)out_num_blocks_y, // out_num_blocks_y // in1 mcast args (std::uint32_t)in1_mcast_sender_semaphore_id, (std::uint32_t)in1_mcast_receiver_semaphore_id, @@ -1181,8 +1218,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( (std::uint32_t)N, // out_tensor_stride_h (std::uint32_t)out_subblock_w, // out_tensor_next_subblock_stride_w (std::uint32_t)out_subblock_h * N, // out_tensor_next_subblock_stride_h - (std::uint32_t)per_core_N, // out_tensor_next_w_dim_block_stride - (std::uint32_t)per_core_M * N, // out_tensor_next_h_dim_block_stride + (std::uint32_t)out_block_w, // out_tensor_next_w_dim_block_stride + (std::uint32_t)out_block_h * N, // out_tensor_next_h_dim_block_stride // out subblock args (std::uint32_t)out_subblock_w, // out_subblock_w (std::uint32_t)out_subblock_h, // out_subblock_h @@ -1279,11 +1316,11 @@ operation::ProgramWithCallbacks create_program_mcast_in1( // Compute kernel compile time args - uint32_t in0_num_subblocks = (per_core_M / out_subblock_h); + uint32_t in0_num_subblocks = (out_block_h / out_subblock_h); uint32_t in0_block_num_tiles = out_subblock_h * in0_block_w * in0_num_subblocks; uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w; - uint32_t in1_num_subblocks = (per_core_N / out_subblock_w); + uint32_t in1_num_subblocks = (out_block_w / out_subblock_w); uint32_t in1_block_num_tiles = out_subblock_w * in0_block_w * in1_num_subblocks; uint32_t in1_per_core_w = out_subblock_w * in1_num_subblocks; @@ -1299,9 +1336,9 @@ operation::ProgramWithCallbacks create_program_mcast_in1( in1_block_num_tiles, // in1_block_num_tiles in1_per_core_w, // in1_per_core_w - num_blocks, // num_blocks - 1, // out_num_blocks_x - 1, // out_num_blocks_y + num_blocks, // num_blocks + out_num_blocks_x, // out_num_blocks_x + out_num_blocks_y, // out_num_blocks_y out_subblock_h, // out_subblock_h out_subblock_w, // out_subblock_w @@ -1349,7 +1386,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1( CBHandle cb_src2 = 0; if (in0_is_sharded and extract_shard_sub_blocks) { // in0_is_sharded is technically redundant tt_metal::CircularBufferConfig src2_cb_config = - tt_metal::CircularBufferConfig(in0_CB_size, {{src2_cb_index, in0_data_format}}) + tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}}) .set_page_size(src2_cb_index, in0_single_tile_size) .set_globally_allocated_address(*in0_buffer) .set_tile_dims(src2_cb_index, in0_tile); @@ -1359,8 +1396,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( "CB {} :: PS = {}, NP = {}, TOTAL = {}", src2_cb_index, in0_single_tile_size, - in0_CB_size / in0_single_tile_size, - in0_CB_size); + in2_CB_size / in0_single_tile_size, + in2_CB_size); } uint32_t src1_cb_index = 1; @@ -1384,7 +1421,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(0, {{output_cb_index, output_data_format}}); - if (interm0_data_format != output_data_format) { + if (do_not_inplace_interm0_out_CB || (interm0_data_format != output_data_format) || + (untilize_out && (in1_num_subblocks > 1))) { // output std::map output_cb_data_format_spec{ {output_cb_index, output_data_format}, @@ -1448,20 +1486,22 @@ operation::ProgramWithCallbacks create_program_mcast_in1( } // Parameters for last row, col, or block - uint32_t last_block_h = M % per_core_M == 0 ? per_core_M : M % per_core_M; - uint32_t last_block_w = N % per_core_N == 0 ? per_core_N : N % per_core_N; - uint32_t last_block_num_nonzero_subblocks_h = (last_block_h - 1) / out_subblock_h + 1; - uint32_t last_block_num_nonzero_subblocks_w = (last_block_w - 1) / out_subblock_w + 1; + uint32_t last_per_core_M = M % per_core_M == 0 ? per_core_M : M % per_core_M; + uint32_t last_per_core_N = N % per_core_N == 0 ? per_core_N : N % per_core_N; + uint32_t last_out_block_h = last_per_core_M % out_block_h == 0 ? out_block_h : last_per_core_M % out_block_h; + uint32_t last_out_block_w = last_per_core_N % out_block_w == 0 ? out_block_w : last_per_core_N % out_block_w; + uint32_t last_block_num_nonzero_subblocks_h = (last_out_block_h - 1) / out_subblock_h + 1; + uint32_t last_block_num_nonzero_subblocks_w = (last_out_block_w - 1) / out_subblock_w + 1; uint32_t last_subblock_of_last_block_h = - last_block_h % out_subblock_h == 0 ? out_subblock_h : last_block_h % out_subblock_h; + last_out_block_h % out_subblock_h == 0 ? out_subblock_h : last_out_block_h % out_subblock_h; uint32_t last_subblock_of_last_block_w = - last_block_w % out_subblock_w == 0 ? out_subblock_w : last_block_w % out_subblock_w; + last_out_block_w % out_subblock_w == 0 ? out_subblock_w : last_out_block_w % out_subblock_w; uint32_t last_block_padded_subblock_tiles_addr_skip = output_single_tile_size * (out_subblock_w - last_subblock_of_last_block_w); uint32_t last_block_padded_block_tiles_w_skip = - (out_subblock_w * out_subblock_h) * (per_core_N / out_subblock_w - last_block_num_nonzero_subblocks_w); + (out_subblock_w * out_subblock_h) * (out_block_w / out_subblock_w - last_block_num_nonzero_subblocks_w); uint32_t last_block_padded_block_tiles_h_skip = - (per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h); + (out_block_h / out_subblock_h - last_block_num_nonzero_subblocks_h) * (out_block_w * out_subblock_h); CoreCoord start_core_noc = bottom_right_core_physical; CoreCoord end_core_noc = top_left_core_physical; @@ -1494,13 +1534,13 @@ operation::ProgramWithCallbacks create_program_mcast_in1( (std::uint32_t)output_idx_x * per_core_N + output_idx_y * per_core_M * N, // out_tensor_start_tile_id // padding args (READER) - (std::uint32_t)per_core_N, // last_block_w + (std::uint32_t)out_block_w, // last_block_w // padding args (WRITER) - (std::uint32_t)per_core_M / out_subblock_h, + (std::uint32_t)out_block_h / out_subblock_h, (std::uint32_t)out_subblock_h, (std::uint32_t)0, - (std::uint32_t)per_core_N / out_subblock_w, - (std::uint32_t)per_core_N / out_subblock_w, + (std::uint32_t)out_block_w / out_subblock_w, + (std::uint32_t)out_block_w / out_subblock_w, (std::uint32_t)out_subblock_w, (std::uint32_t)0, (std::uint32_t)0}; @@ -1530,23 +1570,23 @@ operation::ProgramWithCallbacks create_program_mcast_in1( if (output_idx_y == num_blocks_y - 1) { // padding args (WRITER) - mm_in1_receiver_writer_args.push_back(per_core_M / out_subblock_h); + mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h); mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h); mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h); mm_in1_receiver_writer_args.push_back(last_block_padded_block_tiles_h_skip); - mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w); - mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w); + mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w); + mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w); mm_in1_receiver_writer_args.push_back(out_subblock_w); mm_in1_receiver_writer_args.push_back(0); mm_in1_receiver_writer_args.push_back(0); } else { // padding args (WRITER) - mm_in1_receiver_writer_args.push_back(per_core_M / out_subblock_h); - mm_in1_receiver_writer_args.push_back(per_core_M / out_subblock_h); + mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h); + mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h); mm_in1_receiver_writer_args.push_back(out_subblock_h); mm_in1_receiver_writer_args.push_back(0); - mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w); - mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w); + mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w); + mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w); mm_in1_receiver_writer_args.push_back(out_subblock_w); mm_in1_receiver_writer_args.push_back(0); mm_in1_receiver_writer_args.push_back(0); @@ -1973,6 +2013,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, + uint32_t out_block_h, + uint32_t out_block_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, @@ -2131,6 +2173,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( in0_block_w, out_subblock_h, out_subblock_w, + out_block_h, + out_block_w, per_core_M, per_core_N, fused_activation, @@ -2168,6 +2212,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( in0_block_w, out_subblock_h, out_subblock_w, + out_block_h, + out_block_w, per_core_M, per_core_N, fused_activation, @@ -2200,6 +2246,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized( uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, + uint32_t out_block_h, + uint32_t out_block_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, @@ -2222,6 +2270,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized( in0_block_w, out_subblock_h, out_subblock_w, + out_block_h, + out_block_w, per_core_M, per_core_N, fuse_batch, @@ -2258,6 +2308,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_helpe config.in0_block_w, config.out_subblock_h, config.out_subblock_w, + config.out_block_h, + config.out_block_w, config.per_core_M, config.per_core_N, config.fuse_batch, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp index d6a9b7b2bb6..46ac80aa248 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp @@ -129,7 +129,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( if (in0_is_sharded) { in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1]; in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / in0_tile.get_tile_shape()[0]; - in2_block_tiles = out_block_h * in0_shard_width_in_tiles; + in2_block_tiles = per_core_M * in0_shard_width_in_tiles; } uint32_t in2_CB_tiles = in2_block_tiles; @@ -363,12 +363,12 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( (std::uint32_t)in0_block_w, // in0_tensor_next_inner_dim_block_stride (std::uint32_t)K * in0_block_h, // in0_tensor_next_h_dim_block_stride // in0 block args - (std::uint32_t)in0_block_w, // in0_block_w - (std::uint32_t)in0_block_h, // in0_block_h - (std::uint32_t)in0_block_num_tiles, // in0_block_num_tiles - (std::uint32_t) false, // extract_shard_sub_blocks (not used for interleaved) - (std::uint32_t)0, // shard_width_in_tiles (not used for interleaved) - (std::uint32_t)0, // shard_height_in_tiles (not used for interleaved) + (std::uint32_t)in0_block_w, // in0_block_w + (std::uint32_t)in0_block_h, // in0_block_h + (std::uint32_t)in0_block_num_tiles, // in0_block_num_tiles + (std::uint32_t)false, // extract_shard_sub_blocks (not used for interleaved) + (std::uint32_t)in0_shard_width_in_tiles, // shard_width_in_tiles (not used for interleaved) + (std::uint32_t)in0_shard_height_in_tiles, // shard_height_in_tiles (not used for interleaved) // in0/in1 common args (std::uint32_t)num_blocks, // num_blocks (std::uint32_t)out_num_blocks_x, diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp index 7619134656e..de6d7348eb2 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp @@ -119,22 +119,43 @@ void py_module(py::module& module) { matmul_multi_core_reuse_multicast_1d_program_config .def( - py::init< - CoreCoord, - std::size_t, - std::size_t, - std::size_t, - std::size_t, - std::size_t, - bool, - std::optional, - bool, - bool>(), + py::init([](CoreCoord compute_with_storage_grid_size, + std::size_t in0_block_w, + std::size_t out_subblock_h, + std::size_t out_subblock_w, + std::optional out_block_h, + std::optional out_block_w, + std::size_t per_core_M, + std::size_t per_core_N, + bool fuse_batch, + std::optional fused_activation, + bool mcast_in0, + bool gather_in0) { + // Set out_block_h and out_block_w to defaults if they are not provided + std::size_t actual_out_block_h = out_block_h.value_or(per_core_M); + std::size_t actual_out_block_w = out_block_w.value_or(per_core_N); + + return MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size, + in0_block_w, + out_subblock_h, + out_subblock_w, + actual_out_block_h, + actual_out_block_w, + per_core_M, + per_core_N, + fuse_batch, + std::move(fused_activation), + mcast_in0, + gather_in0); + }), py::kw_only(), py::arg("compute_with_storage_grid_size"), py::arg("in0_block_w").noconvert(), py::arg("out_subblock_h").noconvert(), py::arg("out_subblock_w").noconvert(), + py::arg("out_block_h") = py::none(), + py::arg("out_block_w") = py::none(), py::arg("per_core_M").noconvert(), py::arg("per_core_N").noconvert(), py::arg("fuse_batch").noconvert(), @@ -147,6 +168,8 @@ void py_module(py::module& module) { .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::in0_block_w) .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_h) .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_w) + .def_readwrite("out_block_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_block_h) + .def_readwrite("out_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_block_w) .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_M) .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_N) .def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fuse_batch) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp index c45e48d828f..f24d2393a50 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp @@ -41,9 +41,9 @@ Tensor MorehClipGradNorm::invoke( const auto max_num_inputs = get_num_device_cores(device); const auto total_num_inputs = static_cast(inputs.size()); const auto num_iter = (total_num_inputs + max_num_inputs - 1) / max_num_inputs; - + // Store intermediate reduction of Sum[|e|^p] auto tmp_pow_sum = create_device_tensor( - SimpleShape{tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH * static_cast(inputs.size())}, + SimpleShape{static_cast(inputs.size()), 1, 1}, inputs.at(0).get_dtype(), Layout::TILE, device, @@ -91,11 +91,13 @@ Tensor MorehClipGradNorm::invoke( } // max_norm / (total_norm + 1e-6) - auto clip_coef = ttnn::multiply(ttnn::add(output_total_norm, 1e-6f), (1 / max_norm)); + Tensor max_norm_tensor = creation::create_scalar(max_norm, inputs.at(0).get_dtype(), Layout::TILE, device); + auto clip_coef = ttnn::div(max_norm_tensor, ttnn::add(output_total_norm, 1e-6f)); // min(clip_coef, 1.0f) Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device); auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar); scalar.deallocate(); + max_norm_tensor.deallocate(); // Run Step 3 // Inplace update inputs(inputs *= clip_coef_clamped) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp index 61590ae399e..1362a988ffb 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp @@ -58,7 +58,7 @@ void MAIN { for (uint32_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) { // Comput cb_xabs and mask(optional) // |x| - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_x, onetile); // comes from the reader cb_reserve_back(cb_xabs, onetile); @@ -83,61 +83,68 @@ void MAIN { abs_tile_init(); abs_tile(dst0); + cb_pop_front(cb_x, onetile); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_xabs); - - cb_pop_front(cb_x, onetile); cb_push_back(cb_xabs, onetile); - REL(); + tile_regs_release(); // |x + decimal|^p power_tile_to_cb(cb_xabs, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_correct_xpow, p, p_is_negative); if (tile_idx == 0) { - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_correct_xpow, onetile); cb_reserve_back(cb_xpowadd, onetile); copy_tile_init(); copy_tile(cb_correct_xpow, 0, dst0); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_xpowadd); cb_pop_front(cb_correct_xpow, onetile); cb_push_back(cb_xpowadd, onetile); - REL(); + tile_regs_release(); } else { - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_correct_xpow, onetile); cb_wait_front(cb_xpowadd, onetile); cb_reserve_back(cb_xpowadd, onetile); add_tiles_init(); add_tiles(cb_correct_xpow, cb_xpowadd, 0, 0, dst0); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_xpowadd); cb_pop_front(cb_correct_xpow, onetile); cb_pop_front(cb_xpowadd, onetile); cb_push_back(cb_xpowadd, onetile); - REL(); + tile_regs_release(); } } // Compute cb_y - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_xpowadd, onetile); cb_reserve_back(cb_y, onetile); reduce_init_delta(); reduce_tile(cb_xpowadd, cb_one, 0, 0, dst0); reduce_revert_delta(); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_y); cb_pop_front(cb_xpowadd, onetile); cb_push_back(cb_y, onetile); - REL(); + tile_regs_release(); cb_pop_front(cb_decimal, onetile); cb_pop_front(cb_one, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp index ee089be698a..160f1f96044 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp @@ -40,38 +40,38 @@ void MAIN { // Compute cb_x for (uint32_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) { if (tile_idx == 0) { - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_input, onetile); // comes from the reader cb_reserve_back(cb_x, onetile); copy_tile_init(); copy_tile(cb_input, 0, dst0); + cb_pop_front(cb_input, onetile); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_x); - - cb_pop_front(cb_input, onetile); cb_push_back(cb_x, onetile); - REL(); + tile_regs_release(); } else { - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_input, onetile); // comes from the reader cb_wait_front(cb_x, onetile); cb_reserve_back(cb_x, onetile); add_tiles_init(); add_tiles(cb_input, cb_x, 0, 0, dst0); + cb_pop_front(cb_x, onetile); + cb_pop_front(cb_input, onetile); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_x); - - cb_pop_front(cb_input, onetile); - cb_pop_front(cb_x, onetile); cb_push_back(cb_x, onetile); - REL(); + tile_regs_release(); } } - // x^p power_tile_to_cb(cb_x, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_y, p, p_is_negative); - } // void MAIN } // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_device_operation.cpp index e248513bfbb..29b009cf9b0 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_device_operation.cpp @@ -37,7 +37,8 @@ void MorehClipGradNormStep2Operation::validate_on_program_cache_hit( MorehClipGradNormStep2Operation::shape_return_value_t MorehClipGradNormStep2Operation::compute_output_shapes( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return SimpleShape{tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH}; + // output total_norm 1 element + return SimpleShape{1, 1}; }; MorehClipGradNormStep2Operation::tensor_return_value_t MorehClipGradNormStep2Operation::create_output_tensors( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp index 10b42ca378a..b0bf50fc4b8 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp @@ -25,18 +25,19 @@ void MAIN { // Compute cb_y for (uint32_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) { - ACQ(); + tile_regs_acquire(); cb_wait_front(cb_x, onetile); // comes from the reader cb_reserve_back(cb_y, onetile); mul_tiles_bcast_scalar_init_short(); mul_tiles_bcast_scalar(cb_x, cb_clip_coef_clamped, 0, 0, dst0); + cb_pop_front(cb_x, onetile); + tile_regs_commit(); + tile_regs_wait(); pack_tile(dst0, cb_y); - - cb_pop_front(cb_x, onetile); cb_push_back(cb_y, onetile); - REL(); + tile_regs_release(); } cb_pop_front(cb_clip_coef_clamped, onetile); diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp index 6cb44e93b09..6b5e63c9433 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp @@ -5,7 +5,7 @@ #include "ttnn/operations/reduction/generic/generic_reductions.hpp" #include "ttnn/operations/data_movement/transpose/transpose.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" -#include "ttnn/operations/eltwise/unary/unary_composite.hpp" +#include "ttnn/operations/eltwise/binary/binary_composite.hpp" #include "ttnn/operations/reduction/generic/device/reduce_op.hpp" #include "ttnn/operations/core/core.hpp" namespace ttnn { diff --git a/ttnn/ttnn/library_tweaks.py b/ttnn/ttnn/library_tweaks.py index 81778ca2f25..3b1f6f0128d 100644 --- a/ttnn/ttnn/library_tweaks.py +++ b/ttnn/ttnn/library_tweaks.py @@ -30,7 +30,7 @@ def prepare_dir_as_metal_home(ttnn_package_path, metal_home): metal_home.mkdir(exist_ok=True) version_file = metal_home / ".METAL_VERSION" - current_version = version("metal-libs").strip() + current_version = version("ttnn").strip() runtime_src = ttnn_package_path.parent / "runtime" assert (