From 755736be65a443c0bd9c98be88abc2c310a350a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Tue, 29 Oct 2024 23:58:58 +0100 Subject: [PATCH] Add C-API for constructing Complex Attributes (#208) * Add C-API to construct ComplexF32 and ComplexF64 attributes * Link against MLIR Complex dialect * Rename C functions and delete `float` version * Add the Julia bindings to the C routines * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- deps/ReactantExtra/API.cpp | 14 ++++++++++++++ deps/ReactantExtra/BUILD | 1 + src/mlir/IR/Attribute.jl | 23 +++++++++++++++++++++++ src/mlir/MLIR.jl | 13 +++++++++++++ 4 files changed, 51 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 9381626f7..8e8579dbb 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" @@ -63,6 +64,19 @@ using namespace mlir; using namespace llvm; using namespace xla; +// MLIR C-API extras +#pragma region MLIR Extra +MLIR_CAPI_EXPORTED MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx, MlirType type, double real, double imag) { + return wrap(complex::NumberAttr::get(unwrap(type), real, imag)); +} + +MLIR_CAPI_EXPORTED MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc, MlirType type, double real, double imag) { + return wrap(complex::NumberAttr::getChecked(unwrap(loc), unwrap(type), unwrap(type), real, imag)); +} + +MlirTypeID mlirComplexAttrGetTypeID(void) { return wrap(complex::NumberAttr::getTypeID()); } +#pragma endregion + // int google::protobuf::io::CodedInputStream::default_recursion_limit_ = 100; // int xla::_LayoutProto_default_instance_; diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 363233fda..61257a421 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -308,6 +308,7 @@ cc_library( "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:DLTIDialect", diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index 59c115a33..91ff127ca 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -133,6 +133,29 @@ function Base.Float64(attr::Attribute) return API.mlirFloatAttrGetValueDouble(attr) end +""" + Attribute(complex; context=context(), location=Location(), check=false) + +Creates a complex attribute in the given context with the given complex value and double-precision FP semantics. +""" +function Attribute( + c::T; context::Context=context(), location::Location=Location(), check::Bool=false +) where {T<:Complex} + if check + Attribute( + API.mlirComplexAttrDoubleGetChecked( + location, Type(T), Float64(real(c)), Float64(imag(c)) + ), + ) + else + Attribute( + API.mlirComplexAttrDoubleGet( + context, Type(T), Float64(real(c)), Float64(imag(c)) + ), + ) + end +end + """ isinteger(attr) diff --git a/src/mlir/MLIR.jl b/src/mlir/MLIR.jl index 71d11eecd..6bbf3cad4 100644 --- a/src/mlir/MLIR.jl +++ b/src/mlir/MLIR.jl @@ -11,6 +11,19 @@ module API let include("libMLIR_h.jl") end + + # MLIR C API - extra + function mlirComplexAttrDoubleGet(ctx, type, real, imag) + @ccall mlir_c.mlirComplexAttrDoubleGet( + ctx::MlirContext, type::MlirType, real::Cdouble, imag::Cdouble + )::MlirAttribute + end + + function mlirComplexAttrDoubleGetChecked(loc, type, real, imag) + @ccall mlir_c.mlirComplexAttrDoubleGetChecked( + loc::MlirLocation, type::MlirType, real::Cdouble, imag::Cdouble + )::MlirAttribute + end end # module API include("IR/IR.jl")