diff --git a/include/aie/Dialect/AIE/IR/AIETargetModel.h b/include/aie/Dialect/AIE/IR/AIETargetModel.h index 9849c70b2c..b524e97578 100644 --- a/include/aie/Dialect/AIE/IR/AIETargetModel.h +++ b/include/aie/Dialect/AIE/IR/AIETargetModel.h @@ -19,7 +19,7 @@ namespace xilinx::AIE { -typedef struct TileID { +using TileID = struct TileID { // friend definition (will define the function as a non-member function in the // namespace surrounding the class). friend std::ostream &operator<<(std::ostream &os, const TileID &s) { @@ -50,7 +50,7 @@ typedef struct TileID { bool operator!=(const TileID &rhs) const { return !(*this == rhs); } int col, row; -} TileID; +}; class AIETargetModel { public: @@ -155,7 +155,7 @@ class AIETargetModel { virtual bool isLegalMemAffinity(int coreCol, int coreRow, int memCol, int memRow) const = 0; - /// Return the base address in the local address map of differnet memories. + /// Return the base address in the local address map of different memories. virtual uint32_t getMemInternalBaseAddress(TileID src) const = 0; virtual uint32_t getMemSouthBaseAddress() const = 0; virtual uint32_t getMemWestBaseAddress() const = 0; @@ -321,7 +321,7 @@ class AIE2TargetModel : public AIETargetModel { }; class VC1902TargetModel : public AIE1TargetModel { - llvm::SmallDenseSet noc_columns = { + llvm::SmallDenseSet nocColumns = { 2, 3, 6, 7, 10, 11, 18, 19, 26, 27, 34, 35, 42, 43, 46, 47}; public: @@ -332,11 +332,11 @@ class VC1902TargetModel : public AIE1TargetModel { int rows() const override { return 9; /* One Shim row and 8 Core rows. */ } bool isShimNOCTile(int col, int row) const override { - return row == 0 && noc_columns.contains(col); + return row == 0 && nocColumns.contains(col); } bool isShimPLTile(int col, int row) const override { - return row == 0 && !noc_columns.contains(col); + return row == 0 && !nocColumns.contains(col); } bool isShimNOCorPLTile(int col, int row) const override { @@ -504,7 +504,8 @@ class IPUTargetModel : public AIE2TargetModel { } // namespace xilinx::AIE namespace llvm { -template <> struct DenseMapInfo { +template <> +struct DenseMapInfo { using FirstInfo = DenseMapInfo; using SecondInfo = DenseMapInfo; @@ -528,7 +529,8 @@ template <> struct DenseMapInfo { }; } // namespace llvm -template <> struct std::hash { +template <> +struct std::hash { std::size_t operator()(const xilinx::AIE::TileID &s) const noexcept { std::size_t h1 = std::hash{}(s.col); std::size_t h2 = std::hash{}(s.row); diff --git a/python/util.py b/python/util.py index adbaba2c52..deca1e39e6 100644 --- a/python/util.py +++ b/python/util.py @@ -1,10 +1,10 @@ import multiprocessing import numbers import os -import warnings from collections import defaultdict from contextlib import ExitStack, contextmanager from dataclasses import dataclass +from pprint import pprint from typing import List, Tuple, Dict from typing import Optional from typing import Union @@ -205,12 +205,10 @@ def route_using_cp( solver.parameters.max_time_in_seconds = timeout # Create variable for each edge, for each path - flow_vars = {} - flat_flow_vars = [] - for flow in flows: - flow_var = {(i, j): model.NewIntVar(0, 1, "") for i, j in DG.edges} - flow_vars[flow] = flow_var - flat_flow_vars.append(flow_var) + flow_vars = { + flow: {(i, j): model.NewIntVar(0, 1, "") for i, j in DG.edges} for flow in flows + } + flat_flow_vars = list(flow_vars.values()) # Add flow-balance constraints at all nodes (besides sources and targets) for (src, tgt), flow_var in zip(flows, flat_flow_vars): @@ -303,12 +301,10 @@ def route_using_ilp( m.setParam("TimeLimit", timeout) # Create variable for each edge, for each path - flow_vars = {} - flat_flow_vars = [] - for flow in flows: - flow_var = m.addVars(DG.edges, vtype=GRB.BINARY, name="flow") - flow_vars[flow] = flow_var - flat_flow_vars.append(flow_var) + flow_vars = { + flow: m.addVars(DG.edges, vtype=GRB.BINARY, name="flow") for flow in flows + } + flat_flow_vars = list(flow_vars.values()) # Add flow-balance constraints at all nodes (besides sources and targets) for (src, tgt), flow_var in zip(flows, flat_flow_vars): @@ -512,6 +508,11 @@ def find_paths(self): ) self.routing_solution = get_routing_solution(DG, flow_paths) + + for pe, sws in self.routing_solution.items(): + print(pe) + pprint(sws, indent=2) + return self.routing_solution def is_legal(self):