diff --git a/models/experimental/functional_mobilenetv2/README.md b/models/experimental/functional_mobilenetv2/README.md new file mode 100644 index 000000000000..99061f50e3fa --- /dev/null +++ b/models/experimental/functional_mobilenetv2/README.md @@ -0,0 +1,26 @@ +# MobileNetV2 +The MobileNetV2 model is a convolutional neural network (CNN) architecture designed for efficient mobile and embedded vision applications. It was introduced in the paper ["MobileNetV2: Inverted Residuals and Linear Bottlenecks"](https://arxiv.org/abs/1801.04381).
+The MobileNetV2 model has been pre-trained on the ImageNet dataset and can be used for various tasks such as image classification, object detection, and semantic segmentation. It has achieved state-of-the-art performance on several benchmarks for mobile and embedded vision applications. + +## How to Run + +To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables. +For more information, refer to the [installation and build guide](https://docs.tenstorrent.com/tt-metalium/latest/get_started/get_started.html#install-and-build). + +To run the functional Mobilenetv2 model test on a single-chip: +```sh +pytest --disable-warnings models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py +``` + +To run the functional Mobilenetv2 demo on a single-chip: +```sh +pytest --disable-warnings models/experimental/functional_mobilenetv2/demo/demo.py +``` + +## Supported Hardware +- N150 + +## Other Details + +- Inputs by default are random data in test_ttnn_mobilenetv2 and images can be fed as input in demo.py. +- The model weights will be automatically downloaded from Google Drive using wget implemented in weights_download.sh. diff --git a/models/experimental/functional_mobilenetv2/demo/demo.py b/models/experimental/functional_mobilenetv2/demo/demo.py new file mode 100644 index 000000000000..ce6b36f962a1 --- /dev/null +++ b/models/experimental/functional_mobilenetv2/demo/demo.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import ttnn +import torch + +from tests.ttnn.utils_for_testing import assert_with_pcc + +from models.experimental.functional_mobilenetv2.reference.mobilenetv2 import Mobilenetv2 +from models.experimental.functional_mobilenetv2.tt.model_preprocessing import ( + create_mobilenetv2_input_tensors, + create_mobilenetv2_model_parameters, +) +from models.experimental.functional_mobilenetv2.tt import ttnn_mobilenetv2 +import os +from models.utility_functions import ( + skip_for_grayskull, +) +from PIL import Image +from torchvision import transforms + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@skip_for_grayskull() +def test_mobilenetv2_demo(device, reset_seeds): + if not os.path.exists("models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth"): + os.system( + "bash models/experimental/functional_mobilenetv2/weights_download.sh" + ) # execute the weights_download.sh file + + state_dict = torch.load("models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth") + ds_state_dict = {k: v for k, v in state_dict.items()} + torch_model = Mobilenetv2() + + new_state_dict = {} + + for (name1, parameter1), (name2, parameter2) in zip(torch_model.state_dict().items(), ds_state_dict.items()): + if isinstance(parameter2, torch.FloatTensor): + new_state_dict[name1] = parameter2 + + torch_model.load_state_dict(new_state_dict) + torch_model.eval() + + transform = transforms.Compose( + [ + transforms.Resize(128), + transforms.CenterCrop(128), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + img = Image.open("models/experimental/functional_mobilenetv2/demo/images/strawberry.jpg") + + img_t = transform(img) + torch_input_tensor = torch.unsqueeze(img_t, 0) + + # torch_input_tensor, ttnn_input_tensor = create_mobilenetv2_input_tensors() + torch_output_tensor = torch_model(torch_input_tensor) + + 
parameters = create_mobilenetv2_model_parameters(torch_model, torch_input_tensor, device=device) + + ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1)) + ttnn_input_tensor = ttnn_input_tensor.reshape( + 1, + 1, + ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2], + ttnn_input_tensor.shape[3], + ) + ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16) + + ttnn_model = ttnn_mobilenetv2.MobileNetV2(parameters, device, torch_model) + output_tensor = ttnn_model(device, ttnn_input_tensor) + + # + # Tensor Postprocessing + # + output_tensor = ttnn.to_torch(output_tensor) + output_tensor = output_tensor.reshape(torch_output_tensor.shape) + output_tensor = output_tensor.to(torch_input_tensor.dtype) + + with open("models/experimental/functional_mobilenetv2/demo/imagenet_classes.txt") as f: + classes = [line.strip() for line in f.readlines()] + + # Get the predicted class index and probability + _, index = torch.max(output_tensor, 1) # Get the index of the highest probability class + percentage = torch.nn.functional.softmax(output_tensor, dim=1)[0] * 100 # Calculate the class probabilities + + # Print the predicted class and its probability + print("\033[1m" + f"Predicted class: {classes[index[0]]}") + print( + "\033[1m" + f"Probability: {percentage[index[0]].item():.2f}%" + ) # Format the probability with two decimal places + + _, indices = torch.sort(output_tensor, descending=True) + [(classes[idx], percentage[idx].item()) for idx in indices[0][:5]] + + assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.95) diff --git a/models/experimental/functional_mobilenetv2/demo/imagenet_classes.txt b/models/experimental/functional_mobilenetv2/demo/imagenet_classes.txt new file mode 100644 index 000000000000..34c3324c3a3b --- /dev/null +++ b/models/experimental/functional_mobilenetv2/demo/imagenet_classes.txt @@ -0,0 +1,1000 @@ +0, tench +1, goldfish +2, great_white_shark +3, tiger_shark +4, hammerhead 
+5, electric_ray +6, stingray +7, cock +8, hen +9, ostrich +10, brambling +11, goldfinch +12, house_finch +13, junco +14, indigo_bunting +15, robin +16, bulbul +17, jay +18, magpie +19, chickadee +20, water_ouzel +21, kite +22, bald_eagle +23, vulture +24, great_grey_owl +25, European_fire_salamander +26, common_newt +27, eft +28, spotted_salamander +29, axolotl +30, bullfrog +31, tree_frog +32, tailed_frog +33, loggerhead +34, leatherback_turtle +35, mud_turtle +36, terrapin +37, box_turtle +38, banded_gecko +39, common_iguana +40, American_chameleon +41, whiptail +42, agama +43, frilled_lizard +44, alligator_lizard +45, Gila_monster +46, green_lizard +47, African_chameleon +48, Komodo_dragon +49, African_crocodile +50, American_alligator +51, triceratops +52, thunder_snake +53, ringneck_snake +54, hognose_snake +55, green_snake +56, king_snake +57, garter_snake +58, water_snake +59, vine_snake +60, night_snake +61, boa_constrictor +62, rock_python +63, Indian_cobra +64, green_mamba +65, sea_snake +66, horned_viper +67, diamondback +68, sidewinder +69, trilobite +70, harvestman +71, scorpion +72, black_and_gold_garden_spider +73, barn_spider +74, garden_spider +75, black_widow +76, tarantula +77, wolf_spider +78, tick +79, centipede +80, black_grouse +81, ptarmigan +82, ruffed_grouse +83, prairie_chicken +84, peacock +85, quail +86, partridge +87, African_grey +88, macaw +89, sulphur-crested_cockatoo +90, lorikeet +91, coucal +92, bee_eater +93, hornbill +94, hummingbird +95, jacamar +96, toucan +97, drake +98, red-breasted_merganser +99, goose +100, black_swan +101, tusker +102, echidna +103, platypus +104, wallaby +105, koala +106, wombat +107, jellyfish +108, sea_anemone +109, brain_coral +110, flatworm +111, nematode +112, conch +113, snail +114, slug +115, sea_slug +116, chiton +117, chambered_nautilus +118, Dungeness_crab +119, rock_crab +120, fiddler_crab +121, king_crab +122, American_lobster +123, spiny_lobster +124, crayfish +125, hermit_crab +126, 
isopod +127, white_stork +128, black_stork +129, spoonbill +130, flamingo +131, little_blue_heron +132, American_egret +133, bittern +134, crane +135, limpkin +136, European_gallinule +137, American_coot +138, bustard +139, ruddy_turnstone +140, red-backed_sandpiper +141, redshank +142, dowitcher +143, oystercatcher +144, pelican +145, king_penguin +146, albatross +147, grey_whale +148, killer_whale +149, dugong +150, sea_lion +151, Chihuahua +152, Japanese_spaniel +153, Maltese_dog +154, Pekinese +155, Shih-Tzu +156, Blenheim_spaniel +157, papillon +158, toy_terrier +159, Rhodesian_ridgeback +160, Afghan_hound +161, basset +162, beagle +163, bloodhound +164, bluetick +165, black-and-tan_coonhound +166, Walker_hound +167, English_foxhound +168, redbone +169, borzoi +170, Irish_wolfhound +171, Italian_greyhound +172, whippet +173, Ibizan_hound +174, Norwegian_elkhound +175, otterhound +176, Saluki +177, Scottish_deerhound +178, Weimaraner +179, Staffordshire_bullterrier +180, American_Staffordshire_terrier +181, Bedlington_terrier +182, Border_terrier +183, Kerry_blue_terrier +184, Irish_terrier +185, Norfolk_terrier +186, Norwich_terrier +187, Yorkshire_terrier +188, wire-haired_fox_terrier +189, Lakeland_terrier +190, Sealyham_terrier +191, Airedale +192, cairn +193, Australian_terrier +194, Dandie_Dinmont +195, Boston_bull +196, miniature_schnauzer +197, giant_schnauzer +198, standard_schnauzer +199, Scotch_terrier +200, Tibetan_terrier +201, silky_terrier +202, soft-coated_wheaten_terrier +203, West_Highland_white_terrier +204, Lhasa +205, flat-coated_retriever +206, curly-coated_retriever +207, golden_retriever +208, Labrador_retriever +209, Chesapeake_Bay_retriever +210, German_short-haired_pointer +211, vizsla +212, English_setter +213, Irish_setter +214, Gordon_setter +215, Brittany_spaniel +216, clumber +217, English_springer +218, Welsh_springer_spaniel +219, cocker_spaniel +220, Sussex_spaniel +221, Irish_water_spaniel +222, kuvasz +223, schipperke +224, 
groenendael +225, malinois +226, briard +227, kelpie +228, komondor +229, Old_English_sheepdog +230, Shetland_sheepdog +231, collie +232, Border_collie +233, Bouvier_des_Flandres +234, Rottweiler +235, German_shepherd +236, Doberman +237, miniature_pinscher +238, Greater_Swiss_Mountain_dog +239, Bernese_mountain_dog +240, Appenzeller +241, EntleBucher +242, boxer +243, bull_mastiff +244, Tibetan_mastiff +245, French_bulldog +246, Great_Dane +247, Saint_Bernard +248, Eskimo_dog +249, malamute +250, Siberian_husky +251, dalmatian +252, affenpinscher +253, basenji +254, pug +255, Leonberg +256, Newfoundland +257, Great_Pyrenees +258, Samoyed +259, Pomeranian +260, chow +261, keeshond +262, Brabancon_griffon +263, Pembroke +264, Cardigan +265, toy_poodle +266, miniature_poodle +267, standard_poodle +268, Mexican_hairless +269, timber_wolf +270, white_wolf +271, red_wolf +272, coyote +273, dingo +274, dhole +275, African_hunting_dog +276, hyena +277, red_fox +278, kit_fox +279, Arctic_fox +280, grey_fox +281, tabby +282, tiger_cat +283, Persian_cat +284, Siamese_cat +285, Egyptian_cat +286, cougar +287, lynx +288, leopard +289, snow_leopard +290, jaguar +291, lion +292, tiger +293, cheetah +294, brown_bear +295, American_black_bear +296, ice_bear +297, sloth_bear +298, mongoose +299, meerkat +300, tiger_beetle +301, ladybug +302, ground_beetle +303, long-horned_beetle +304, leaf_beetle +305, dung_beetle +306, rhinoceros_beetle +307, weevil +308, fly +309, bee +310, ant +311, grasshopper +312, cricket +313, walking_stick +314, cockroach +315, mantis +316, cicada +317, leafhopper +318, lacewing +319, dragonfly +320, damselfly +321, admiral +322, ringlet +323, monarch +324, cabbage_butterfly +325, sulphur_butterfly +326, lycaenid +327, starfish +328, sea_urchin +329, sea_cucumber +330, wood_rabbit +331, hare +332, Angora +333, hamster +334, porcupine +335, fox_squirrel +336, marmot +337, beaver +338, guinea_pig +339, sorrel +340, zebra +341, hog +342, wild_boar +343, 
warthog +344, hippopotamus +345, ox +346, water_buffalo +347, bison +348, ram +349, bighorn +350, ibex +351, hartebeest +352, impala +353, gazelle +354, Arabian_camel +355, llama +356, weasel +357, mink +358, polecat +359, black-footed_ferret +360, otter +361, skunk +362, badger +363, armadillo +364, three-toed_sloth +365, orangutan +366, gorilla +367, chimpanzee +368, gibbon +369, siamang +370, guenon +371, patas +372, baboon +373, macaque +374, langur +375, colobus +376, proboscis_monkey +377, marmoset +378, capuchin +379, howler_monkey +380, titi +381, spider_monkey +382, squirrel_monkey +383, Madagascar_cat +384, indri +385, Indian_elephant +386, African_elephant +387, lesser_panda +388, giant_panda +389, barracouta +390, eel +391, coho +392, rock_beauty +393, anemone_fish +394, sturgeon +395, gar +396, lionfish +397, puffer +398, abacus +399, abaya +400, academic_gown +401, accordion +402, acoustic_guitar +403, aircraft_carrier +404, airliner +405, airship +406, altar +407, ambulance +408, amphibian +409, analog_clock +410, apiary +411, apron +412, ashcan +413, assault_rifle +414, backpack +415, bakery +416, balance_beam +417, balloon +418, ballpoint +419, Band_Aid +420, banjo +421, bannister +422, barbell +423, barber_chair +424, barbershop +425, barn +426, barometer +427, barrel +428, barrow +429, baseball +430, basketball +431, bassinet +432, bassoon +433, bathing_cap +434, bath_towel +435, bathtub +436, beach_wagon +437, beacon +438, beaker +439, bearskin +440, beer_bottle +441, beer_glass +442, bell_cote +443, bib +444, bicycle-built-for-two +445, bikini +446, binder +447, binoculars +448, birdhouse +449, boathouse +450, bobsled +451, bolo_tie +452, bonnet +453, bookcase +454, bookshop +455, bottlecap +456, bow +457, bow_tie +458, brass +459, brassiere +460, breakwater +461, breastplate +462, broom +463, bucket +464, buckle +465, bulletproof_vest +466, bullet_train +467, butcher_shop +468, cab +469, caldron +470, candle +471, cannon +472, canoe +473, 
can_opener +474, cardigan +475, car_mirror +476, carousel +477, carpenter's_kit +478, carton +479, car_wheel +480, cash_machine +481, cassette +482, cassette_player +483, castle +484, catamaran +485, CD_player +486, cello +487, cellular_telephone +488, chain +489, chainlink_fence +490, chain_mail +491, chain_saw +492, chest +493, chiffonier +494, chime +495, china_cabinet +496, Christmas_stocking +497, church +498, cinema +499, cleaver +500, cliff_dwelling +501, cloak +502, clog +503, cocktail_shaker +504, coffee_mug +505, coffeepot +506, coil +507, combination_lock +508, computer_keyboard +509, confectionery +510, container_ship +511, convertible +512, corkscrew +513, cornet +514, cowboy_boot +515, cowboy_hat +516, cradle +517, crane +518, crash_helmet +519, crate +520, crib +521, Crock_Pot +522, croquet_ball +523, crutch +524, cuirass +525, dam +526, desk +527, desktop_computer +528, dial_telephone +529, diaper +530, digital_clock +531, digital_watch +532, dining_table +533, dishrag +534, dishwasher +535, disk_brake +536, dock +537, dogsled +538, dome +539, doormat +540, drilling_platform +541, drum +542, drumstick +543, dumbbell +544, Dutch_oven +545, electric_fan +546, electric_guitar +547, electric_locomotive +548, entertainment_center +549, envelope +550, espresso_maker +551, face_powder +552, feather_boa +553, file +554, fireboat +555, fire_engine +556, fire_screen +557, flagpole +558, flute +559, folding_chair +560, football_helmet +561, forklift +562, fountain +563, fountain_pen +564, four-poster +565, freight_car +566, French_horn +567, frying_pan +568, fur_coat +569, garbage_truck +570, gasmask +571, gas_pump +572, goblet +573, go-kart +574, golf_ball +575, golfcart +576, gondola +577, gong +578, gown +579, grand_piano +580, greenhouse +581, grille +582, grocery_store +583, guillotine +584, hair_slide +585, hair_spray +586, half_track +587, hammer +588, hamper +589, hand_blower +590, hand-held_computer +591, handkerchief +592, hard_disc +593, harmonica 
+594, harp +595, harvester +596, hatchet +597, holster +598, home_theater +599, honeycomb +600, hook +601, hoopskirt +602, horizontal_bar +603, horse_cart +604, hourglass +605, iPod +606, iron +607, jack-o'-lantern +608, jean +609, jeep +610, jersey +611, jigsaw_puzzle +612, jinrikisha +613, joystick +614, kimono +615, knee_pad +616, knot +617, lab_coat +618, ladle +619, lampshade +620, laptop +621, lawn_mower +622, lens_cap +623, letter_opener +624, library +625, lifeboat +626, lighter +627, limousine +628, liner +629, lipstick +630, Loafer +631, lotion +632, loudspeaker +633, loupe +634, lumbermill +635, magnetic_compass +636, mailbag +637, mailbox +638, maillot +639, maillot +640, manhole_cover +641, maraca +642, marimba +643, mask +644, matchstick +645, maypole +646, maze +647, measuring_cup +648, medicine_chest +649, megalith +650, microphone +651, microwave +652, military_uniform +653, milk_can +654, minibus +655, miniskirt +656, minivan +657, missile +658, mitten +659, mixing_bowl +660, mobile_home +661, Model_T +662, modem +663, monastery +664, monitor +665, moped +666, mortar +667, mortarboard +668, mosque +669, mosquito_net +670, motor_scooter +671, mountain_bike +672, mountain_tent +673, mouse +674, mousetrap +675, moving_van +676, muzzle +677, nail +678, neck_brace +679, necklace +680, nipple +681, notebook +682, obelisk +683, oboe +684, ocarina +685, odometer +686, oil_filter +687, organ +688, oscilloscope +689, overskirt +690, oxcart +691, oxygen_mask +692, packet +693, paddle +694, paddlewheel +695, padlock +696, paintbrush +697, pajama +698, palace +699, panpipe +700, paper_towel +701, parachute +702, parallel_bars +703, park_bench +704, parking_meter +705, passenger_car +706, patio +707, pay-phone +708, pedestal +709, pencil_box +710, pencil_sharpener +711, perfume +712, Petri_dish +713, photocopier +714, pick +715, pickelhaube +716, picket_fence +717, pickup +718, pier +719, piggy_bank +720, pill_bottle +721, pillow +722, ping-pong_ball +723, 
pinwheel +724, pirate +725, pitcher +726, plane +727, planetarium +728, plastic_bag +729, plate_rack +730, plow +731, plunger +732, Polaroid_camera +733, pole +734, police_van +735, poncho +736, pool_table +737, pop_bottle +738, pot +739, potter's_wheel +740, power_drill +741, prayer_rug +742, printer +743, prison +744, projectile +745, projector +746, puck +747, punching_bag +748, purse +749, quill +750, quilt +751, racer +752, racket +753, radiator +754, radio +755, radio_telescope +756, rain_barrel +757, recreational_vehicle +758, reel +759, reflex_camera +760, refrigerator +761, remote_control +762, restaurant +763, revolver +764, rifle +765, rocking_chair +766, rotisserie +767, rubber_eraser +768, rugby_ball +769, rule +770, running_shoe +771, safe +772, safety_pin +773, saltshaker +774, sandal +775, sarong +776, sax +777, scabbard +778, scale +779, school_bus +780, schooner +781, scoreboard +782, screen +783, screw +784, screwdriver +785, seat_belt +786, sewing_machine +787, shield +788, shoe_shop +789, shoji +790, shopping_basket +791, shopping_cart +792, shovel +793, shower_cap +794, shower_curtain +795, ski +796, ski_mask +797, sleeping_bag +798, slide_rule +799, sliding_door +800, slot +801, snorkel +802, snowmobile +803, snowplow +804, soap_dispenser +805, soccer_ball +806, sock +807, solar_dish +808, sombrero +809, soup_bowl +810, space_bar +811, space_heater +812, space_shuttle +813, spatula +814, speedboat +815, spider_web +816, spindle +817, sports_car +818, spotlight +819, stage +820, steam_locomotive +821, steel_arch_bridge +822, steel_drum +823, stethoscope +824, stole +825, stone_wall +826, stopwatch +827, stove +828, strainer +829, streetcar +830, stretcher +831, studio_couch +832, stupa +833, submarine +834, suit +835, sundial +836, sunglass +837, sunglasses +838, sunscreen +839, suspension_bridge +840, swab +841, sweatshirt +842, swimming_trunks +843, swing +844, switch +845, syringe +846, table_lamp +847, tank +848, tape_player +849, teapot 
+850, teddy +851, television +852, tennis_ball +853, thatch +854, theater_curtain +855, thimble +856, thresher +857, throne +858, tile_roof +859, toaster +860, tobacco_shop +861, toilet_seat +862, torch +863, totem_pole +864, tow_truck +865, toyshop +866, tractor +867, trailer_truck +868, tray +869, trench_coat +870, tricycle +871, trimaran +872, tripod +873, triumphal_arch +874, trolleybus +875, trombone +876, tub +877, turnstile +878, typewriter_keyboard +879, umbrella +880, unicycle +881, upright +882, vacuum +883, vase +884, vault +885, velvet +886, vending_machine +887, vestment +888, viaduct +889, violin +890, volleyball +891, waffle_iron +892, wall_clock +893, wallet +894, wardrobe +895, warplane +896, washbasin +897, washer +898, water_bottle +899, water_jug +900, water_tower +901, whiskey_jug +902, whistle +903, wig +904, window_screen +905, window_shade +906, Windsor_tie +907, wine_bottle +908, wing +909, wok +910, wooden_spoon +911, wool +912, worm_fence +913, wreck +914, yawl +915, yurt +916, web_site +917, comic_book +918, crossword_puzzle +919, street_sign +920, traffic_light +921, book_jacket +922, menu +923, plate +924, guacamole +925, consomme +926, hot_pot +927, trifle +928, ice_cream +929, ice_lolly +930, French_loaf +931, bagel +932, pretzel +933, cheeseburger +934, hotdog +935, mashed_potato +936, head_cabbage +937, broccoli +938, cauliflower +939, zucchini +940, spaghetti_squash +941, acorn_squash +942, butternut_squash +943, cucumber +944, artichoke +945, bell_pepper +946, cardoon +947, mushroom +948, Granny_Smith +949, strawberry +950, orange +951, lemon +952, fig +953, pineapple +954, banana +955, jackfruit +956, custard_apple +957, pomegranate +958, hay +959, carbonara +960, chocolate_sauce +961, dough +962, meat_loaf +963, pizza +964, potpie +965, burrito +966, red_wine +967, espresso +968, cup +969, eggnog +970, alp +971, bubble +972, cliff +973, coral_reef +974, geyser +975, lakeside +976, promontory +977, sandbar +978, seashore +979, 
valley +980, volcano +981, ballplayer +982, groom +983, scuba_diver +984, rapeseed +985, daisy +986, yellow_lady's_slipper +987, corn +988, acorn +989, hip +990, buckeye +991, coral_fungus +992, agaric +993, gyromitra +994, stinkhorn +995, earthstar +996, hen-of-the-woods +997, bolete +998, ear +999, toilet_tissue diff --git a/models/experimental/functional_mobilenetv2/demo/images/strawberry.jpg b/models/experimental/functional_mobilenetv2/demo/images/strawberry.jpg new file mode 100644 index 000000000000..2cefb972eb87 Binary files /dev/null and b/models/experimental/functional_mobilenetv2/demo/images/strawberry.jpg differ diff --git a/models/experimental/functional_mobilenetv2/reference/mobilenetv2.py b/models/experimental/functional_mobilenetv2/reference/mobilenetv2.py new file mode 100644 index 000000000000..f9bf6130f614 --- /dev/null +++ b/models/experimental/functional_mobilenetv2/reference/mobilenetv2.py @@ -0,0 +1,384 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + + +import torch +import torch.nn as nn + + +class Mobilenetv2(nn.Module): + def __init__(self): + super().__init__() + self.c1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False) + self.b1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU6(inplace=True) + + self.c2 = nn.Conv2d(32, 32, 3, 1, 1, groups=32, bias=False) + self.b2 = nn.BatchNorm2d(32) + + self.c3 = nn.Conv2d(32, 16, 1, 1, bias=False) + self.b3 = nn.BatchNorm2d(16) + + self.c4 = nn.Conv2d(16, 96, 1, 1, bias=False) + self.b4 = nn.BatchNorm2d(96) + + self.c5 = nn.Conv2d(96, 96, 3, 2, 1, groups=96, bias=False) + self.b5 = nn.BatchNorm2d(96) + + self.c6 = nn.Conv2d(96, 24, 1, 1, bias=False) + self.b6 = nn.BatchNorm2d(24) + + self.c7 = nn.Conv2d(24, 144, 1, 1, bias=False) + self.b7 = nn.BatchNorm2d(144) + + self.c8 = nn.Conv2d(144, 144, 3, 1, 1, groups=144, bias=False) + self.b8 = nn.BatchNorm2d(144) + + self.c9 = nn.Conv2d(144, 24, 1, 1, bias=False) + self.b9 = nn.BatchNorm2d(24) + + self.c10 = nn.Conv2d(24, 
144, 1, 1, bias=False) + self.b10 = nn.BatchNorm2d(144) + + self.c11 = nn.Conv2d(144, 144, 3, 2, 1, groups=144, bias=False) + self.b11 = nn.BatchNorm2d(144) + + self.c12 = nn.Conv2d(144, 32, 1, 1, bias=False) + self.b12 = nn.BatchNorm2d(32) + + self.c13 = nn.Conv2d(32, 192, 1, 1, bias=False) + self.b13 = nn.BatchNorm2d(192) + + self.c14 = nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=False) + self.b14 = nn.BatchNorm2d(192) + + self.c15 = nn.Conv2d(192, 32, 1, 1, bias=False) + self.b15 = nn.BatchNorm2d(32) + + self.c16 = nn.Conv2d(32, 192, 1, 1, bias=False) + self.b16 = nn.BatchNorm2d(192) + + self.c17 = nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=False) + self.b17 = nn.BatchNorm2d(192) + + self.c18 = nn.Conv2d(192, 32, 1, 1, bias=False) + self.b18 = nn.BatchNorm2d(32) + + self.c19 = nn.Conv2d(32, 192, 1, 1, bias=False) + self.b19 = nn.BatchNorm2d(192) + + self.c20 = nn.Conv2d(192, 192, 3, 2, 1, groups=192, bias=False) + self.b20 = nn.BatchNorm2d(192) + + self.c21 = nn.Conv2d(192, 64, 1, 1, bias=False) + self.b21 = nn.BatchNorm2d(64) + + self.c22 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b22 = nn.BatchNorm2d(384) + + self.c23 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False) + self.b23 = nn.BatchNorm2d(384) + + self.c24 = nn.Conv2d(384, 64, 1, 1, bias=False) + self.b24 = nn.BatchNorm2d(64) + + self.c25 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b25 = nn.BatchNorm2d(384) + + self.c26 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False) + self.b26 = nn.BatchNorm2d(384) + + self.c27 = nn.Conv2d(384, 64, 1, 1, bias=False) + self.b27 = nn.BatchNorm2d(64) + + self.c28 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b28 = nn.BatchNorm2d(384) + + self.c29 = nn.Conv2d(384, 384, 3, 1, 1, groups=384, bias=False) + self.b29 = nn.BatchNorm2d(384) + + self.c30 = nn.Conv2d(384, 64, 1, 1, bias=False) + self.b30 = nn.BatchNorm2d(64) + + self.c31 = nn.Conv2d(64, 384, 1, 1, bias=False) + self.b31 = nn.BatchNorm2d(384) + + self.c32 = nn.Conv2d(384, 384, 3, 1, 1, 
groups=384, bias=False) + self.b32 = nn.BatchNorm2d(384) + + self.c33 = nn.Conv2d(384, 96, 1, 1, bias=False) + self.b33 = nn.BatchNorm2d(96) + + self.c34 = nn.Conv2d(96, 576, 1, 1, bias=False) + self.b34 = nn.BatchNorm2d(576) + + self.c35 = nn.Conv2d(576, 576, 3, 1, 1, groups=576, bias=False) + self.b35 = nn.BatchNorm2d(576) + + self.c36 = nn.Conv2d(576, 96, 1, 1, bias=False) + self.b36 = nn.BatchNorm2d(96) + + self.c37 = nn.Conv2d(96, 576, 1, 1, bias=False) + self.b37 = nn.BatchNorm2d(576) + + self.c38 = nn.Conv2d(576, 576, 3, 1, 1, groups=576, bias=False) + self.b38 = nn.BatchNorm2d(576) + + self.c39 = nn.Conv2d(576, 96, 1, 1, bias=False) + self.b39 = nn.BatchNorm2d(96) + + self.c40 = nn.Conv2d(96, 576, 1, 1, bias=False) + self.b40 = nn.BatchNorm2d(576) + + self.c41 = nn.Conv2d(576, 576, 3, 2, 1, groups=576, bias=False) + self.b41 = nn.BatchNorm2d(576) + + self.c42 = nn.Conv2d(576, 160, 1, 1, bias=False) + self.b42 = nn.BatchNorm2d(160) + + self.c43 = nn.Conv2d(160, 960, 1, 1, bias=False) + self.b43 = nn.BatchNorm2d(960) + + self.c44 = nn.Conv2d(960, 960, 3, 1, 1, groups=960, bias=False) + self.b44 = nn.BatchNorm2d(960) + + self.c45 = nn.Conv2d(960, 160, 1, 1, bias=False) + self.b45 = nn.BatchNorm2d(160) + + self.c46 = nn.Conv2d(160, 960, 1, 1, bias=False) + self.b46 = nn.BatchNorm2d(960) + + self.c47 = nn.Conv2d(960, 960, 3, 1, 1, groups=960, bias=False) + self.b47 = nn.BatchNorm2d(960) + + self.c48 = nn.Conv2d(960, 160, 1, 1, bias=False) + self.b48 = nn.BatchNorm2d(160) + + self.c49 = nn.Conv2d(160, 960, 1, 1, bias=False) + self.b49 = nn.BatchNorm2d(960) + + self.c50 = nn.Conv2d(960, 960, 3, 1, 1, groups=960, bias=False) + self.b50 = nn.BatchNorm2d(960) + + self.c51 = nn.Conv2d(960, 320, 1, 1, bias=False) + self.b51 = nn.BatchNorm2d(320) + + self.c52 = nn.Conv2d(320, 1280, 1, 1, bias=False) + self.b52 = nn.BatchNorm2d(1280) + + self.l1 = nn.Linear(in_features=1280, out_features=1000) + + def forward(self, input: torch.Tensor): + x1 = self.c1(input) + x1_b = 
self.b1(x1) + x1_m = self.relu(x1_b) + + x2 = self.c2(x1_m) + x2_b = self.b2(x2) + x2_m = self.relu(x2_b) + + x3 = self.c3(x2_m) + x3_b = self.b3(x3) + + x4 = self.c4(x3_b) + x4_b = self.b4(x4) + x4_m = self.relu(x4_b) + + x5 = self.c5(x4_m) + x5_b = self.b5(x5) + x5_m = self.relu(x5_b) + + x6 = self.c6(x5_m) + x6_b = self.b6(x6) + + x7 = self.c7(x6_b) + x7_b = self.b7(x7) + x7_m = self.relu(x7_b) + + x8 = self.c8(x7_m) + x8_b = self.b8(x8) + x8_m = self.relu(x8_b) + + x9 = self.c9(x8_m) + x9_b = self.b9(x9) + a1 = x9_b + x6_b + x10 = self.c10(a1) + x10_b = self.b10(x10) + x10_m = self.relu(x10_b) + + x11 = self.c11(x10_m) + x11_b = self.b11(x11) + x11_m = self.relu(x11_b) + + x12 = self.c12(x11_m) + x12_b = self.b12(x12) + + x13 = self.c13(x12_b) + x13_b = self.b13(x13) + x13_m = self.relu(x13_b) + + x14 = self.c14(x13_m) + x14_b = self.b14(x14) + x14_m = self.relu(x14_b) + + x15 = self.c15(x14_m) + x15_b = self.b15(x15) + + a2 = x15_b + x12_b + + x16 = self.c16(a2) + x16_b = self.b16(x16) + x16_m = self.relu(x16_b) + + x17 = self.c17(x16_m) + x17_b = self.b17(x17) + x17_m = self.relu(x17_b) + + x18 = self.c18(x17_m) + x18_b = self.b18(x18) + + a3 = a2 + x18_b + + x19 = self.c19(a3) + x19_b = self.b19(x19) + x19_m = self.relu(x19_b) + + x20 = self.c20(x19_m) + x20_b = self.b20(x20) + x20_m = self.relu(x20_b) + + x21 = self.c21(x20_m) + x21_b = self.b21(x21) + + x22 = self.c22(x21_b) + x22_b = self.b22(x22) + x22_m = self.relu(x22_b) + + x23 = self.c23(x22_m) + x23_b = self.b23(x23) + x23_m = self.relu(x23_b) + + x24 = self.c24(x23_m) + x24_b = self.b24(x24) + + a4 = x21_b + x24_b + + x25 = self.c25(a4) + x25_b = self.b25(x25) + x25_m = self.relu(x25_b) + + x26 = self.c26(x25_m) + x26_b = self.b26(x26) + x26_m = self.relu(x26_b) + + x27 = self.c27(x26_m) + x27_b = self.b27(x27) + + a5 = a4 + x27_b + + x28 = self.c28(a5) + x28_b = self.b28(x28) + x28_m = self.relu(x28_b) + + x29 = self.c29(x28_m) + x29_b = self.b29(x29) + x29_m = self.relu(x29_b) + + x30 = 
self.c30(x29_m) + x30_b = self.b30(x30) + + a6 = a5 + x30_b + + x31 = self.c31(a6) + x31_b = self.b31(x31) + x31_m = self.relu(x31_b) + + x32 = self.c32(x31_m) + x32_b = self.b32(x32) + x32_m = self.relu(x32_b) + + x33 = self.c33(x32_m) + x33_b = self.b33(x33) + + x34 = self.c34(x33_b) + x34_b = self.b34(x34) + x34_m = self.relu(x34_b) + + x35 = self.c35(x34_m) + x35_b = self.b35(x35) + x35_m = self.relu(x35_b) + + x36 = self.c36(x35_m) + x36_b = self.b36(x36) + + a7 = x33_b + x36_b + + x37 = self.c37(a7) + x37_b = self.b37(x37) + x37_m = self.relu(x37_b) + + x38 = self.c38(x37_m) + x38_b = self.b38(x38) + x38_m = self.relu(x38_b) + + x39 = self.c39(x38_m) + x39_b = self.b39(x39) + + a8 = a7 + x39_b + + x40 = self.c40(a8) + x40_b = self.b40(x40) + x40_m = self.relu(x40_b) + + x41 = self.c41(x40_m) + x41_b = self.b41(x41) + x41_m = self.relu(x41_b) + + x42 = self.c42(x41_m) + x42_b = self.b42(x42) + + x43 = self.c43(x42_b) + x43_b = self.b43(x43) + x43_m = self.relu(x43_b) + + x44 = self.c44(x43_m) + x44_b = self.b44(x44) + x44_m = self.relu(x44_b) + + x45 = self.c45(x44_m) + x45_b = self.b45(x45) + + a9 = x45_b + x42_b + + x46 = self.c46(a9) + x46_b = self.b46(x46) + x46_m = self.relu(x46_b) + + x47 = self.c47(x46_m) + x47_b = self.b47(x47) + x47_m = self.relu(x47_b) + + x48 = self.c48(x47_m) + x48_b = self.b48(x48) + + a10 = a9 + x48_b + + x49 = self.c49(a10) + x49_b = self.b49(x49) + x49_m = self.relu(x49_b) + + x50 = self.c50(x49_m) + x50_b = self.b50(x50) + x50_m = self.relu(x50_b) + + x51 = self.c51(x50_m) + x51_b = self.b51(x51) + + x52 = self.c52(x51_b) + x52_b = self.b52(x52) + x52_m = self.relu(x52_b) + x = nn.functional.adaptive_avg_pool2d(x52_m, (1, 1)) + x = torch.flatten(x, 1) + x53 = self.l1(x) + return x53 diff --git a/models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py b/models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py new file mode 100644 index 000000000000..3d271b6f979c --- /dev/null +++ 
# --- reconstructed from mangled diff: b/models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py (@@ -0,0 +1,58 @@) ---
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import pytest
import ttnn
import torch

from tests.ttnn.utils_for_testing import assert_with_pcc

from models.experimental.functional_mobilenetv2.reference.mobilenetv2 import Mobilenetv2
from models.experimental.functional_mobilenetv2.tt.model_preprocessing import (
    create_mobilenetv2_input_tensors,
    create_mobilenetv2_model_parameters,
)
from models.experimental.functional_mobilenetv2.tt import ttnn_mobilenetv2
import os
from models.utility_functions import (
    skip_for_grayskull,
)

# Path of the pretrained torchvision checkpoint fetched by weights_download.sh.
WEIGHTS_PATH = "models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth"


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@skip_for_grayskull()
def test_mobilenetv2(device, reset_seeds):
    """Run the TTNN MobileNetV2 against the torch reference and require PCC >= 0.95."""
    if not os.path.exists(WEIGHTS_PATH):
        # Download the pretrained weights on first use. The original ignored the
        # return code, so a failed download surfaced later as a confusing
        # torch.load error; fail fast here instead.
        rc = os.system("bash models/experimental/functional_mobilenetv2/weights_download.sh")
        if rc != 0:
            pytest.fail(f"weights download failed (exit status {rc})")

    # map_location keeps deserialization working on CPU-only hosts even if the
    # checkpoint ever carries device-pinned tensors.
    state_dict = torch.load(WEIGHTS_PATH, map_location="cpu")
    ds_state_dict = dict(state_dict)
    torch_model = Mobilenetv2()

    # Remap checkpoint weights onto the reference module's parameter names.
    # NOTE(review): this assumes both state dicts enumerate parameters in the
    # same order — verify if either the reference model or checkpoint changes.
    # The FloatTensor check drops non-float entries (e.g. BN num_batches_tracked).
    new_state_dict = {}
    for (name1, parameter1), (name2, parameter2) in zip(torch_model.state_dict().items(), ds_state_dict.items()):
        if isinstance(parameter2, torch.FloatTensor):
            new_state_dict[name1] = parameter2

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()
    torch_input_tensor, ttnn_input_tensor = create_mobilenetv2_input_tensors()
    torch_output_tensor = torch_model(torch_input_tensor)

    parameters = create_mobilenetv2_model_parameters(torch_model, torch_input_tensor, device=device)

    ttnn_model = ttnn_mobilenetv2.MobileNetV2(parameters, device, torch_model)
    output_tensor = ttnn_model(device, ttnn_input_tensor)

    #
    # Tensor Postprocessing
    #
    output_tensor = ttnn.to_torch(output_tensor)
    output_tensor = output_tensor.reshape(torch_output_tensor.shape)
    output_tensor = output_tensor.to(torch_input_tensor.dtype)
    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.95)


# --- reconstructed from mangled diff: b/models/experimental/functional_mobilenetv2/tt/model_preprocessing.py (index a21a3d2f45f2, @@ -0,0 +1,36 @@) ---
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

from models.experimental.functional_mobilenetv2.reference.mobilenetv2 import Mobilenetv2
from ttnn.model_preprocessing import infer_ttnn_module_args


def create_mobilenetv2_input_tensors(batch=1, input_channels=3, input_height=128, input_width=128):
    """Return (torch NCHW input, matching TTNN host tensor).

    The TTNN tensor is the NHWC permutation flattened to (1, 1, N*H*W, C),
    the layout expected by ttnn.conv2d.
    """
    torch_input_tensor = torch.randn(batch, input_channels, input_height, input_width)
    ttnn_input_tensor = torch.permute(torch_input_tensor, (0, 2, 3, 1))
    ttnn_input_tensor = ttnn_input_tensor.reshape(
        1,
        1,
        ttnn_input_tensor.shape[0] * ttnn_input_tensor.shape[1] * ttnn_input_tensor.shape[2],
        ttnn_input_tensor.shape[3],
    )
    ttnn_input_tensor = ttnn.from_torch(ttnn_input_tensor, dtype=ttnn.bfloat16)

    return torch_input_tensor, ttnn_input_tensor


def create_mobilenetv2_model_parameters(model: Mobilenetv2, input_tensor, device):
    """Trace the torch reference once to infer per-conv TTNN args; attach each
    torch submodule (for weight folding) and the raw classifier weights."""
    parameters = infer_ttnn_module_args(model=model, run_model=lambda model: model(input_tensor), device=None)
    assert parameters is not None
    for key in parameters.keys():
        # Keep a handle to the torch module so conv/BN weights can be folded later.
        parameters[key].module = getattr(model, key)

    # The linear head is not traced by infer_ttnn_module_args; pass its weights through.
    parameters["l1"] = {"weight": model.l1.weight, "bias": model.l1.bias}

    return parameters


# --- mangled diff continues: diff --git a/models/experimental/functional_mobilenetv2/tt/ttnn_mobilenetv2.py (new file mode 100644, index on next line) ---
000000000000..7f18ddac0fa0 --- /dev/null +++ b/models/experimental/functional_mobilenetv2/tt/ttnn_mobilenetv2.py @@ -0,0 +1,464 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch + +from ttnn.model_preprocessing import ParameterDict + +from torch import nn +from ttnn.model_preprocessing import preprocess_linear_weight, preprocess_linear_bias + + +class MobileNetV2Conv2D: + def fold_batch_norm2d_into_conv2d(self, conv, bn): + if not bn.track_running_stats: + raise RuntimeError("BatchNorm2d must have track_running_stats=True to be folded into Conv2d") + weight = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + eps = bn.eps + scale = bn.weight + shift = bn.bias + weight = weight * (scale / torch.sqrt(running_var + eps))[:, None, None, None] + bias = shift - running_mean * (scale / torch.sqrt(running_var + eps)) + return weight, bias + + def __init__( + self, + conv, + bn=None, + device=None, + cache={}, + activation="", + activation_dtype=ttnn.bfloat8_b, + weights_dtype=ttnn.bfloat8_b, + use_1d_systolic_array=True, + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ): + self.device = device + self.batch_size = conv.batch_size + self.input_height = conv.input_height + self.input_width = conv.input_width + self.in_channels = conv.in_channels + self.out_channels = conv.out_channels + self.kernel_size = conv.kernel_size + self.padding = conv.padding + self.stride = conv.stride + self.groups = conv.groups + self.use_1d_systolic_array = use_1d_systolic_array + self.deallocate_activation = True + self.cache = cache + + self.conv_config = ttnn.Conv2dConfig( + dtype=activation_dtype, + weights_dtype=weights_dtype, + math_fidelity=ttnn.MathFidelity.LoFi, + shard_layout=shard_layout, + deallocate_activation=self.deallocate_activation, + fp32_dest_acc_enabled=True, + packer_l1_accum_enabled=False, + enable_act_double_buffer=False, + enable_split_reader=False, + 
enable_subblock_padding=False, + reshard_if_not_optimal=True if self.use_1d_systolic_array else False, + activation=activation, + ) + config_override = conv.conv_blocking_and_parallelization_config_override + if config_override and "act_block_h" in config_override: + self.conv_config.act_block_h_override = config_override["act_block_h"] + + if bn is not None: + weight, bias = self.fold_batch_norm2d_into_conv2d(conv.module, bn.module) + else: + weight, bias = conv.module.weight, conv.module.bias + + weight = weight + bias = torch.reshape(bias, (1, 1, 1, -1)) + self.weight = ttnn.from_torch(weight, dtype=ttnn.float32) + self.bias = ttnn.from_torch(bias, dtype=ttnn.float32) + + def __call__(self, x): + x, output_height, output_width, self.weight, self.bias = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.weight, + bias_tensor=self.bias, + device=self.device, + in_channels=self.in_channels, + out_channels=self.out_channels, + input_height=self.input_height, + input_width=self.input_width, + batch_size=self.batch_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + conv_config=self.conv_config, + conv_op_cache=self.cache, + groups=self.groups, + ) + return x + + +class MobileNetV2: + def input_preprocessor(self, tensor, n, c, h, w): + tensor = ttnn.to_torch(tensor).to(torch.float32) + tensor = torch.reshape(tensor, (n, h, w, c)) + tensor = torch.permute(tensor, (0, 3, 1, 2)) + return tensor + + def __init__(self, parameters: ParameterDict, device, model) -> None: + self.device = device + self.model = model + self.parameters = parameters + + self.c1 = MobileNetV2Conv2D(parameters.c1, parameters.b1, device) + self.c2 = MobileNetV2Conv2D(parameters.c2, parameters.b2, device) + + self.c3 = MobileNetV2Conv2D(parameters.c3, parameters.b3, device) + + self.c4 = MobileNetV2Conv2D(parameters.c4, parameters.b4, device) + + self.c5 = MobileNetV2Conv2D(parameters.c5, parameters.b5, device) + + self.c6 = MobileNetV2Conv2D(parameters.c6, 
parameters.b6, device) + + self.c7 = MobileNetV2Conv2D(parameters.c7, parameters.b7, device) + + self.c8 = MobileNetV2Conv2D(parameters.c8, parameters.b8, device) + + self.c9 = MobileNetV2Conv2D(parameters.c9, parameters.b9, device) + + self.c10 = MobileNetV2Conv2D(parameters.c10, parameters.b10, device) + + self.c11 = MobileNetV2Conv2D(parameters.c11, parameters.b11, device) + + self.c12 = MobileNetV2Conv2D(parameters.c12, parameters.b12, device) + + self.c13 = MobileNetV2Conv2D(parameters.c13, parameters.b13, device) + self.c14 = MobileNetV2Conv2D(parameters.c14, parameters.b14, device) + self.c15 = MobileNetV2Conv2D(parameters.c15, parameters.b15, device) + self.c16 = MobileNetV2Conv2D(parameters.c16, parameters.b16, device) + self.c17 = MobileNetV2Conv2D(parameters.c17, parameters.b17, device) + self.c18 = MobileNetV2Conv2D(parameters.c18, parameters.b18, device) + self.c19 = MobileNetV2Conv2D(parameters.c19, parameters.b19, device) + self.c20 = MobileNetV2Conv2D(parameters.c20, parameters.b20, device) + self.c21 = MobileNetV2Conv2D(parameters.c21, parameters.b21, device) + self.c22 = MobileNetV2Conv2D(parameters.c22, parameters.b22, device) + self.c23 = MobileNetV2Conv2D(parameters.c23, parameters.b23, device) + self.c24 = MobileNetV2Conv2D(parameters.c24, parameters.b24, device) + self.c25 = MobileNetV2Conv2D(parameters.c25, parameters.b25, device) + self.c26 = MobileNetV2Conv2D(parameters.c26, parameters.b26, device) + self.c27 = MobileNetV2Conv2D(parameters.c27, parameters.b27, device) + self.c28 = MobileNetV2Conv2D(parameters.c28, parameters.b28, device) + self.c29 = MobileNetV2Conv2D(parameters.c29, parameters.b29, device) + self.c30 = MobileNetV2Conv2D(parameters.c30, parameters.b30, device) + self.c31 = MobileNetV2Conv2D(parameters.c31, parameters.b31, device) + self.c32 = MobileNetV2Conv2D(parameters.c32, parameters.b32, device) + self.c33 = MobileNetV2Conv2D(parameters.c33, parameters.b33, device) + self.c34 = MobileNetV2Conv2D(parameters.c34, 
parameters.b34, device) + + self.c35 = MobileNetV2Conv2D( + parameters.c35, parameters.b35, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c36 = MobileNetV2Conv2D(parameters.c36, parameters.b36, device) + self.c37 = MobileNetV2Conv2D(parameters.c37, parameters.b37, device) + + self.c38 = MobileNetV2Conv2D( + parameters.c38, parameters.b38, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c39 = MobileNetV2Conv2D(parameters.c39, parameters.b39, device) + self.c40 = MobileNetV2Conv2D(parameters.c40, parameters.b40, device) + + self.c41 = MobileNetV2Conv2D( + parameters.c41, parameters.b41, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c42 = MobileNetV2Conv2D(parameters.c42, parameters.b42, device) + self.c43 = MobileNetV2Conv2D(parameters.c43, parameters.b43, device) + + self.c44 = MobileNetV2Conv2D( + parameters.c44, parameters.b44, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c45 = MobileNetV2Conv2D(parameters.c45, parameters.b45, device) + self.c46 = MobileNetV2Conv2D(parameters.c46, parameters.b46, device) + + self.c47 = MobileNetV2Conv2D( + parameters.c47, parameters.b47, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c48 = MobileNetV2Conv2D(parameters.c48, parameters.b48, device) + self.c49 = MobileNetV2Conv2D(parameters.c49, parameters.b49, device) + + self.c50 = MobileNetV2Conv2D( + parameters.c50, parameters.b50, device, shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED + ) + + self.c51 = MobileNetV2Conv2D(parameters.c51, parameters.b51, device) + self.c52 = MobileNetV2Conv2D(parameters.c52, parameters.b52, device) + + self.l1_weight = parameters.l1["weight"] + self.l1_bias = parameters.l1["bias"] + + def __call__( + self, + device, + x, + ): + output_tensor = self.c1(x) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c2(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = 
self.c3(output_tensor) + + output_tensor = self.c4(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c5(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c6(output_tensor) + output_tensor_c6 = output_tensor + + if output_tensor_c6.is_sharded(): + output_tensor_c6 = ttnn.sharded_to_interleaved(output_tensor_c6, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c7(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c8(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c9(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.add(output_tensor_c6, output_tensor) + + output_tensor = self.c10(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c11(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c12(output_tensor) + output_tensor_c12 = output_tensor + + if output_tensor_c12.is_sharded(): + output_tensor_c12 = ttnn.sharded_to_interleaved(output_tensor_c12, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c13(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c14(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c15(output_tensor) + output_tensor_c15 = output_tensor + + if output_tensor_c15.is_sharded(): + output_tensor_c15 = ttnn.sharded_to_interleaved(output_tensor_c15, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c15 + output_tensor_c12 + output_tensor_a2 = output_tensor + + if output_tensor_a2.is_sharded(): + output_tensor_a2 = ttnn.sharded_to_interleaved(output_tensor_a2, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c16(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c17(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = 
self.c18(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_a2 + output_tensor + + output_tensor = self.c19(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c20(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c21(output_tensor) + + output_tensor_c21 = output_tensor + if output_tensor_c21.is_sharded(): + output_tensor_c21 = ttnn.sharded_to_interleaved(output_tensor_c21, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c22(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c23(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c24(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c21 + output_tensor + + output_tensor_a4 = output_tensor + + if output_tensor_a4.is_sharded(): + output_tensor_a4 = ttnn.sharded_to_interleaved(output_tensor_a4, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c25(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c26(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c27(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_a4 + output_tensor + output_tensor_a5 = output_tensor + if output_tensor_a5.is_sharded(): + output_tensor_a5 = ttnn.sharded_to_interleaved(output_tensor_a5, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c28(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c29(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c30(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = 
ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.add(output_tensor_a5, output_tensor) + + output_tensor = self.c31(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c32(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c33(output_tensor) + + output_tensor_c33 = output_tensor + if output_tensor_c33.is_sharded(): + output_tensor_c33 = ttnn.sharded_to_interleaved(output_tensor_c33, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c34(output_tensor_c33) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c35(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c36(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c33 + output_tensor + + output_tensor_a7 = output_tensor + + if output_tensor_a7.is_sharded(): + output_tensor_a7 = ttnn.sharded_to_interleaved(output_tensor_a7, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c37(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c38(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c39(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_a7 + output_tensor + + output_tensor = self.c40(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c41(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c42(output_tensor) + + output_tensor_c42 = output_tensor + if output_tensor_c42.is_sharded(): + output_tensor_c42 = ttnn.sharded_to_interleaved(output_tensor_c42, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c43(output_tensor_c42) + output_tensor = ttnn.relu6(output_tensor) + output_tensor = ttnn.to_layout(output_tensor, 
ttnn.ROW_MAJOR_LAYOUT) + + output_tensor = self.c44(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c45(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor_c42 + output_tensor + output_tensor_a9 = output_tensor + + if output_tensor_a9.is_sharded(): + output_tensor_a9 = ttnn.sharded_to_interleaved(output_tensor_a9, ttnn.L1_MEMORY_CONFIG) + + output_tensor = self.c46(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = self.c47(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c48(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = output_tensor + output_tensor_a9 + + output_tensor = self.c49(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = self.c50(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + output_tensor = self.c51(output_tensor) + + output_tensor = self.c52(output_tensor) + output_tensor = ttnn.relu6(output_tensor) + + if output_tensor.is_sharded(): + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.global_avg_pool2d(output_tensor) + + output_tensor = self.input_preprocessor(output_tensor, 1, 1280, 1, 1) + + output_tensor = torch.flatten(output_tensor, 1) + + output_tensor = ttnn.from_torch(output_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG) + + self.l1_weight = preprocess_linear_weight(self.l1_weight, dtype=ttnn.bfloat16) + self.l1_bias = preprocess_linear_bias(self.l1_bias, dtype=ttnn.bfloat16) + 
self.l1_weight = ttnn.to_device(self.l1_weight, device) + self.l1_bias = ttnn.to_device(self.l1_bias, device) + + output_tensor = ttnn.linear(output_tensor, self.l1_weight, bias=self.l1_bias) + + return ttnn.from_device(output_tensor) diff --git a/models/experimental/functional_mobilenetv2/weights_download.sh b/models/experimental/functional_mobilenetv2/weights_download.sh new file mode 100644 index 000000000000..89dbc939007e --- /dev/null +++ b/models/experimental/functional_mobilenetv2/weights_download.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Output filename +OUTPUT="models/experimental/functional_mobilenetv2/mobilenet_v2-b0353104.pth" + +# Create output directory if it doesn't exist +mkdir -p "$(dirname "$OUTPUT")" + +# Download the file using wget +if wget "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" -O "${OUTPUT}"; then + echo "File downloaded successfully: ${OUTPUT}" +else + echo "Error downloading the file." + exit 1 +fi diff --git a/tests/nightly/single_card/functional_mobilenetv2/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py b/tests/nightly/single_card/functional_mobilenetv2/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py new file mode 120000 index 000000000000..c1a01b75951e --- /dev/null +++ b/tests/nightly/single_card/functional_mobilenetv2/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py @@ -0,0 +1 @@ +../../../../../../../models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py \ No newline at end of file diff --git a/tests/scripts/single_card/nightly/run_common_models.sh b/tests/scripts/single_card/nightly/run_common_models.sh index 9c057b224827..1b2afd0f8bc1 100755 --- a/tests/scripts/single_card/nightly/run_common_models.sh +++ b/tests/scripts/single_card/nightly/run_common_models.sh @@ -15,3 +15,4 @@ env pytest -n auto tests/nightly/single_card/common_models/ ; fail+=$? 
if [[ $fail -ne 0 ]]; then exit 1 fi +env pytest models/experimental/functional_mobilenetv2/test/test_ttnn_mobilenetv2.py