diff --git a/examples/BuddyBert/bert-main.cpp b/examples/BuddyBert/bert-main.cpp index 902c702c15..d3f0075491 100644 --- a/examples/BuddyBert/bert-main.cpp +++ b/examples/BuddyBert/bert-main.cpp @@ -24,18 +24,9 @@ using namespace buddy; -// Define ResultContainer -struct ResultContainer { - MemRef memRef3D; - MemRef memRef2D; - - ResultContainer(MemRef m1, MemRef m2) - : memRef3D(m1), memRef2D(m2) {} -}; - // Declare BERT forward function. extern "C" void -_mlir_ciface_forward(ResultContainer *result, MemRef *arg0, +_mlir_ciface_forward(MemRef *result, MemRef *arg0, MemRef *arg1, MemRef *arg2, MemRef *arg3, MemRef *arg4); @@ -94,9 +85,7 @@ int main() { pureStrContainer.tokenizeBert(vocabDir, 5); /// Initialize data containers. - MemRef result1({1, 5, 768}); - MemRef result2({1, 6}); - ResultContainer result(result1, result2); + MemRef result({1, 6}); MemRef attention_mask({1, 5}, 1LL); MemRef token_type_ids({1, 5}, 0LL); @@ -104,7 +93,7 @@ int main() { /// Execute forward inference of the model. _mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer, - &token_type_ids, &attention_mask); + &token_type_ids, &attention_mask); const auto inferenceEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration inferenceTime = @@ -113,8 +102,8 @@ int main() { int predict_label = -1; float max_logits = std::numeric_limits::min(); for (int i = 0; i < 6; i++) { - if (max_logits < result.memRef2D.getData()[i]) { - max_logits = result.memRef2D.getData()[i]; + if (max_logits < result.getData()[i]) { + max_logits = result.getData()[i]; predict_label = i; } } diff --git a/examples/BuddyWhisper/whisper-main.cpp b/examples/BuddyWhisper/whisper-main.cpp index 42e75f2c38..011b5c847e 100644 --- a/examples/BuddyWhisper/whisper-main.cpp +++ b/examples/BuddyWhisper/whisper-main.cpp @@ -33,7 +33,7 @@ using namespace std; using namespace buddy; using namespace dap; -constexpr size_t ParamsSize = 99148800; +constexpr size_t ParamsSize = 72593920; constexpr size_t MaxVocabSize = 51865; constexpr size_t MaxTokenLength = 448; @@ -125,9 +125,8 @@ int main() { Text outputContainer; Audio rawAudioContainer("../../examples/BuddyWhisper/audio.wav"); MemRef audioInput({1, 80, 3000}); - MemRef resultContainer[3] = { + MemRef resultContainer[2] = { MemRef({1, 1500, 512}, false, 0), - MemRef({1, 448, 512}, false, 0), MemRef({1, 448, MaxVocabSize}, false, 0), }; MemRef textContainer({1, MaxTokenLength}, 50258); @@ -156,7 +155,7 @@ int main() { inferenceEnd - inferenceStart; // Determine the generated token. - const float *startPtr = resultContainer[2].getData() + i * MaxVocabSize; + const float *startPtr = resultContainer[1].getData() + i * MaxVocabSize; const float *endPtr = startPtr + MaxVocabSize; int maxIndex = findMaxIndex(startPtr, endPtr); @@ -172,9 +171,8 @@ int main() { textContainer.getData()[i + 1] = maxIndex; outputContainer.appendTokenIdx(maxIndex); - // free(resultContainer[0].release()); - // free(resultContainer[1].release()); - // free(resultContainer[2].release()); + free(resultContainer[0].release()); + free(resultContainer[1].release()); } /// Print the final result diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 5d4f256f86..69a46e7842 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -171,6 +171,7 @@ def __init__( "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp, "ge.Scalar": GeOp, "gt.Tensor": GreaterThanOp, + "_unsafe_index.Tensor": UnsafeIndexOp, } @property @@ -261,11 +262,26 @@ def _compile_fx( return for torchdynamo's call. """ - params = { - **dict(gm.named_parameters(remove_duplicate=False)), - **dict(gm.named_buffers(remove_duplicate=False)), - } - params_flat, _ = pytree.tree_flatten(params) + # params = { + # # **dict(gm.named_parameters(remove_duplicate=False)), + # **dict(gm.named_buffers(remove_duplicate=False)), + # } + # print(len(params)) + # params_flat, _ = pytree.tree_flatten(params) + inputs_pos = [] + params_pos = [] + buffers_pos = [] + for i, node in enumerate(gm.graph.nodes): + if i >= len(inputs): + break + if not str(node).startswith("l_self"): + inputs_pos.append(i) + elif "buffer" in str(node): + buffers_pos.append(i) + else: + params_pos.append(i) + + params_flat = [inputs[i] for i in params_pos + buffers_pos] if self._verbose: print("Graph in tabular form:") @@ -275,7 +291,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): """Compile a FX graph in Aten/Prims IR to MLIR.""" nonlocal params_flat func_inputs = [] - for inp in _inputs[len(params_flat) :]: + for i in inputs_pos: + # for inp in _inputs[len(params_flat) :]: + inp = _inputs[i] inp_shape = inp.shape inp_dtype = self._torch_dtype_translate(str(inp.dtype)) func_inputs.append(TensorMeta(inp_shape, inp_dtype)) @@ -290,7 +308,20 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): self._func_name, self._verbose ) - for gm_node in _gm.graph.nodes: + param_nodes = [] + buffers_nodes = [] + input_nodes = [] + for i, node in enumerate(_gm.graph.nodes): + if i in params_pos: + param_nodes.append(node) + elif i in buffers_pos: + buffers_nodes.append(node) + elif i in inputs_pos: + input_nodes.append(node) + + gm_nodes = param_nodes + buffers_nodes + input_nodes + + for gm_node in gm_nodes: node_users = [] for user in gm_node.users.keys(): node_users.append(str(user)) diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 511adf6e35..c1a7b09746 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -553,3 +553,8 @@ def __init__(self) -> None: super().__init__() self._op_type = OpType.BroadcastType + +class UnsafeIndexOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ReshapeType diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index 00a2ccda57..ec6c827e6c 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1231,28 +1231,51 @@ def index_op( return input1_shape = ir.RankedTensorType(input1.type).shape input2 = node.args[1] + input2_dim_sum = 0 + for i in range(len(input2)): + input2_dim_sum += len(symbol_table.get((str(input2[i]), 0)).type.shape) output_shape = list(node.tensor_meta["shape"]) + input_shape = input1.type.shape dtype = node.tensor_meta["dtype"] mlir_dtype = mlir_element_type_get(dtype) if len(input2) < len(input1_shape): tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) output = tensor.EmptyOp(output_shape, mlir_dtype) - loops = ir.RankedTensorType( - symbol_table.get((str(input2[0]), 0)).type - ).shape generic_map = ir.AffineMap.get_permutation( - [i for i in range(len(output_shape))] + [i for i in range(max(len(output_shape), len(input_shape)))] ) - input_map = [ - ir.AffineMapAttr.get( - generic_map.get_submap([j for j in range(len(loops))]) + input_map = [] + for i in range(len(input2)): + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(i, i + len(input2_shape))] + ) + ) ) - for i in range(len(input2)) - ] + [ - ir.AffineMapAttr.get( - generic_map.get_submap([j for j in range(len(output_shape))]) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) ) - ] operands = [symbol_table.get((str(i), 0)) for i in input2] op = linalg.GenericOp( [tensor_type], @@ -1261,7 +1284,7 @@ def index_op( ir.ArrayAttr.get(input_map), ir.ArrayAttr.get( [ir.Attribute.parse("#linalg.iterator_type")] - * len(output_shape) + * max(len(output_shape), len(input_shape)) ), ) arguments = [ @@ -1273,7 +1296,9 @@ def index_op( indexcast_op = arith.IndexCastOp(ir.IndexType.get(), i) block.append(indexcast_op) index.append(indexcast_op.result) - for i in range(len(loops), len(output_shape) - len(input2) + 1): + for i in range( + input2_dim_sum, max(len(input_shape), len(output_shape)) + ): index_op = linalg.IndexOp(ir._i64Attr(i, None)) block.append(index_op) index.append(index_op.result) @@ -1573,6 +1598,9 @@ def softmax_op( if dim < 0: dim += len(output_shape) mlir_dtype = mlir_element_type_get(dtype) + max_vals = tosa.ReduceMaxOp(input1, dim) + sub_op_output = ir.RankedTensorType.get(input1.type.shape, mlir_dtype) + input1 = tosa.SubOp(sub_op_output, input1, max_vals) # tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) # output = tensor.EmptyOp(output_shape, mlir_dtype) # op = linalg.softmax( @@ -2118,6 +2146,202 @@ def greater_than_op( return op +def unsafe_index_op( + node: UnsafeIndexOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor _unsafe_index operation. + From buddy UnsafeIndexOp to MLIR linalg `generic` + operation. + Note: This op, get input node slice result by input index. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. + """ + assert len(node.args) == 2 + input1 = symbol_table.get((str(node.args[0]), 0)) + if input1 is None: + return + input1_shape = ir.RankedTensorType(input1.type).shape + input2 = node.args[1] + have_none = False + for i in input2: + if i == None: + have_none = True + break + input2_dim_sum = 0 + for i in range(len(input2)): + input2_dim_sum += ( + len(symbol_table.get((str(input2[i]), 0)).type.shape) + if input2[i] != None + else 0 + ) + output_shape = list(node.tensor_meta["shape"]) + input_shape = input1.type.shape + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + if len(input2) < len(input1_shape): + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(max(len(output_shape), len(input_shape)))] + ) + input_map = [] + for i in range(len(input2)): + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(i, i + len(input2_shape))] + ) + ) + ) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) + ) + operands = [symbol_table.get((str(i), 0)) for i in input2] + op = linalg.GenericOp( + [tensor_type], + operands, + [output], + ir.ArrayAttr.get(input_map), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * max(len(output_shape), len(input_shape)) + ), + ) + arguments = [ + ir.RankedTensorType(i.type).element_type for i in operands + ] + [ir.RankedTensorType(output.result.type).element_type] + block = ir.Block.create_at_start(op.region, arguments) + index = [] + for i in block.arguments[:-1]: + indexcast_op = arith.IndexCastOp(ir.IndexType.get(), i) + block.append(indexcast_op) + index.append(indexcast_op.result) + for i in range( + input2_dim_sum, max(len(input_shape), len(output_shape)) + ): + index_op = linalg.IndexOp(ir._i64Attr(i, None)) + block.append(index_op) + index.append(index_op.result) + value = tensor.ExtractOp(input1, index) + block.append(value) + block.append(linalg.YieldOp([value.result])) + else: + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + output = tensor.EmptyOp(output_shape, mlir_dtype) + generic_map = ir.AffineMap.get_permutation( + [i for i in range(max(len(output_shape), len(input_shape)))] + ) + input_map = [] + for i in range(len(input2)): + if input2[i] == None: + continue + input2_shape = symbol_table.get((str(input2[i]), 0)).type.shape + if have_none: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap([j for j in range(i, i + 1)]) + ) + ) + if len(input_shape) > len(output_shape): + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + j + for j in range( + len(input_shape) - len(output_shape), + len(input_shape), + ) + ] + ) + ) + ) + else: + input_map.append( + ir.AffineMapAttr.get( + generic_map.get_submap( + [j for j in range(len(output_shape))] + ) + ) + ) + if have_none: + operands = [] + for i in input2: + if i == None: + continue + input2_ = symbol_table.get((str(i), 0)) + input2_shape = input2_.type.shape + if i != None and len(input2_shape) > 1: + total_size = 1 + for x in input2_shape: + total_size *= x + reshape_op = tosa.ReshapeOp( + input2_, memoryview(array.array("i", [total_size])) + ) + operands.append(reshape_op.result) + + else: + operands = [symbol_table.get((str(i), 0)) for i in input2] + op = linalg.GenericOp( + [tensor_type], + operands, + [output], + ir.ArrayAttr.get(input_map), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * max(len(output_shape), len(input_shape)) + ), + ) + arguments = [ + ir.RankedTensorType(i.type).element_type for i in operands + ] + [ir.RankedTensorType(output.result.type).element_type] + block = ir.Block.create_at_start(op.region, arguments) + index = [] + None_count = 0 + for i in range(len(input2)): + if input2[i] == None: + None_count += 1 + index_op = linalg.IndexOp(ir._i64Attr(i, None)) + block.append(index_op) + index.append(index_op.result) + else: + indexcast_op = arith.IndexCastOp( + ir.IndexType.get(), block.arguments[i - None_count] + ) + block.append(indexcast_op) + index.append(indexcast_op.result) + value = tensor.ExtractOp(input1, index) + block.append(value) + block.append(linalg.YieldOp([value.result])) + return op + + ops_registry = { "MatmulOp": matmul_op, "ArangeOp": arange_op, @@ -2155,4 +2379,5 @@ def greater_than_op( "GtOp": gt_op, "GeOp": ge_op, "GreaterThanOp": greater_than_op, + "UnsafeIndexOp": unsafe_index_op, } diff --git a/requirements.txt b/requirements.txt index 5efad98526..6b2fd250c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --pre --extra-index-url https://download.pytorch.org/whl/cpu -torch == 2.4.0 +torch == 2.5.1 numpy < 2 transformers == 4.46.2 tokenizers >= 0.20