From 70f86f608344788fec551f54dab8b1abe5243c33 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 1 Dec 2024 02:16:35 +0000 Subject: [PATCH 01/38] convert HLO to StableHLO --- spidr/backend/BUILD | 6 +++ spidr/backend/VERSION | 2 +- spidr/backend/src/xla/hlo/builder/BUILD | 1 + .../src/xla/hlo/builder/xla_computation.cpp | 19 ++++++--- spidr/backend/src/xla/hlo/ir/BUILD | 12 ++++++ spidr/backend/src/xla/hlo/ir/hlo_module.cpp | 39 +++++++++++++++++++ spidr/backend/src/xla/hlo/ir/hlo_module.h | 18 +++++++++ spidr/backend/src/xla/hlo/translate/BUILD | 13 +++++++ .../src/xla/hlo/translate/portable_api.cpp | 28 +++++++++++++ spidr/backend/src/xla/service/BUILD | 13 +++++++ spidr/backend/src/xla/service/hlo.proto.cpp | 31 +++++++++++++++ spidr/backend/src/xla/service/hlo.proto.h | 18 +++++++++ .../src/xla/service/hlo_module_config.cpp | 33 ++++++++++++++++ .../src/xla/service/hlo_module_config.h | 18 +++++++++ spidr/backend/src/xla/shape.cpp | 4 ++ spidr/backend/src/xla/shape.h | 1 + spidr/src/Compiler/FFI.idr | 8 ++++ .../Xla/HLO/Builder/XlaComputation.idr | 24 +++++++----- spidr/src/Compiler/Xla/HLO/IR/HloModule.idr | 35 +++++++++++++++++ .../Xla/HLO/Translate/PortableAPI.idr | 28 +++++++++++++ .../src/Compiler/Xla/PJRT/PjrtExecutable.idr | 8 +--- .../Compiler/Xla/Service/HloModuleConfig.idr | 18 +++++++++ spidr/src/Compiler/Xla/Service/HloProto.idr | 18 +++++++++ spidr/src/Compiler/Xla/Shape.idr | 11 +++++- 24 files changed, 382 insertions(+), 24 deletions(-) create mode 100644 spidr/backend/src/xla/hlo/ir/BUILD create mode 100644 spidr/backend/src/xla/hlo/ir/hlo_module.cpp create mode 100644 spidr/backend/src/xla/hlo/ir/hlo_module.h create mode 100644 spidr/backend/src/xla/hlo/translate/BUILD create mode 100644 spidr/backend/src/xla/hlo/translate/portable_api.cpp create mode 100644 spidr/backend/src/xla/service/BUILD create mode 100644 spidr/backend/src/xla/service/hlo.proto.cpp create mode 100644 spidr/backend/src/xla/service/hlo.proto.h create mode 100644 spidr/backend/src/xla/service/hlo_module_config.cpp create mode 100644 spidr/backend/src/xla/service/hlo_module_config.h create mode 100644 spidr/src/Compiler/Xla/HLO/IR/HloModule.idr create mode 100644 spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr create mode 100644 spidr/src/Compiler/Xla/Service/HloModuleConfig.idr create mode 100644 spidr/src/Compiler/Xla/Service/HloProto.idr diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index ba1a39bfa..65ec73bc4 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -16,8 +16,11 @@ cc_binary( "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", + "//src/xla/hlo/ir", + "//src/xla/hlo/translate", "//src/xla/pjrt", "//src/xla/pjrt/c", + "//src/xla/service", "//src", ], deps = [ @@ -25,8 +28,11 @@ cc_binary( "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", + "//src/xla/hlo/ir", + "//src/xla/hlo/translate", "//src/xla/pjrt", "//src/xla/pjrt/c", + "//src/xla/service", "//src", ], ) diff --git a/spidr/backend/VERSION b/spidr/backend/VERSION index 9789c4ccb..ceddfb28f 100644 --- a/spidr/backend/VERSION +++ b/spidr/backend/VERSION @@ -1 +1 @@ -0.0.14 +0.0.15 diff --git a/spidr/backend/src/xla/hlo/builder/BUILD b/spidr/backend/src/xla/hlo/builder/BUILD index e729f1eef..48be5352b 100644 --- a/spidr/backend/src/xla/hlo/builder/BUILD +++ b/spidr/backend/src/xla/hlo/builder/BUILD @@ -8,6 +8,7 @@ cc_library( "@xla//xla/hlo/builder:xla_builder", "//src", "//src/xla", + "//src/xla/service", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index 1cba3a527..efd371cbe 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -14,8 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "xla/hlo/builder/xla_computation.h" +#include "xla/shape.h" #include "../../../ffi.h" +#include "../../service/hlo.proto.h" +#include "../../shape.h" #include "xla_computation.h" extern "C" { @@ -23,9 +26,15 @@ extern "C" { delete reinterpret_cast(s); } - string* XlaComputation_SerializeAsString(XlaComputation* s) { - auto s_ = reinterpret_cast(s); - auto serialized = s_->proto().SerializeAsString(); - return reinterpret_cast(new std::string(serialized)); + ProgramShape* XlaComputation_GetProgramShape(XlaComputation* s) { + auto res = reinterpret_cast(s)->GetProgramShape(); + return reinterpret_cast(new xla::ProgramShape(*res)); } -} + + HloModuleProto* XlaComputation_proto(XlaComputation* s) { + auto res = reinterpret_cast(s)->proto(); + // I think the proto is owned by, and lives as long as, the XlaComputation + // so this is probably wrong since Idris will GC the XlaComputation + return reinterpret_cast(&res); + } +} \ No newline at end of file diff --git a/spidr/backend/src/xla/hlo/ir/BUILD b/spidr/backend/src/xla/hlo/ir/BUILD new file mode 100644 index 000000000..b09f5e688 --- /dev/null +++ b/spidr/backend/src/xla/hlo/ir/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "ir", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/ir:hlo", + "//src/xla/service", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp new file mode 100644 index 000000000..6cdc30383 --- /dev/null +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp @@ -0,0 +1,39 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" + +#include "hlo_module.h" + +#include "../../service/hlo.proto.h" +#include "../../service/hlo_module_config.h" + +extern "C" { + // if proto is a nuisance, could pass in the XlaComputation instead + HloModule* HloModule_CreateFromProto(HloModuleProto& proto, HloModuleConfig& module_config) { + auto& proto_ = reinterpret_cast(proto); + auto& module_config_ = reinterpret_cast(module_config); + + auto module = xla::HloModule::CreateFromProto(proto_, module_config_); + // this looks suspicious, but I'm pretty sure we own the HloModule + return reinterpret_cast(&*module); + } + + void HloModule_delete(HloModule* s) { + delete reinterpret_cast(s); + } +} \ No newline at end of file diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.h b/spidr/backend/src/xla/hlo/ir/hlo_module.h new file mode 100644 index 000000000..10ad53e06 --- /dev/null +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct HloModule; +} \ No newline at end of file diff --git a/spidr/backend/src/xla/hlo/translate/BUILD b/spidr/backend/src/xla/hlo/translate/BUILD new file mode 100644 index 000000000..7ebd74269 --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "translate", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/translate:portable_api", + "//src/xla/hlo/ir", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp new file mode 100644 index 000000000..5837df22e --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/portable_api.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/portable_api.h" + +#include "../ir/hlo_module.h" +#include "../../../ffi.h" + +extern "C" { + string* ConvertHloToStablehlo(HloModule& hlo_module) { + auto& hlo_module_ = reinterpret_cast(hlo_module); + auto res = xla::ConvertHloToStablehlo(hlo_module_); + return reinterpret_cast(new std::string(*res)); + } +} \ No newline at end of file diff --git a/spidr/backend/src/xla/service/BUILD b/spidr/backend/src/xla/service/BUILD new file mode 100644 index 000000000..38c7f2d3e --- /dev/null +++ b/spidr/backend/src/xla/service/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "service", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/service", + "//src/xla", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/service/hlo.proto.cpp b/spidr/backend/src/xla/service/hlo.proto.cpp new file mode 100644 index 000000000..09ba05c65 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo.proto.cpp @@ -0,0 +1,31 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/service/hlo.pb.h" +// #include "xla/service/..." // try to import from some random place + +#include "../../ffi.h" +#include "hlo.proto.h" + +extern "C" { + string* HloModuleProto_SerializeAsString(HloModuleProto& s) { + auto s_ = reinterpret_cast(s); + return reinterpret_cast(new std::string(s_.SerializeAsString())); + } + + void HloModuleProto_delete(HloModuleProto* s) { + delete reinterpret_cast(s); + } +} \ No newline at end of file diff --git a/spidr/backend/src/xla/service/hlo.proto.h b/spidr/backend/src/xla/service/hlo.proto.h new file mode 100644 index 000000000..336bbeaf3 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo.proto.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct HloModuleProto; +} diff --git a/spidr/backend/src/xla/service/hlo_module_config.cpp b/spidr/backend/src/xla/service/hlo_module_config.cpp new file mode 100644 index 000000000..cec0467ca --- /dev/null +++ b/spidr/backend/src/xla/service/hlo_module_config.cpp @@ -0,0 +1,33 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" + +#include "hlo_module_config.h" + +#include "../shape.h" + +extern "C" { + HloModuleConfig* HloModuleConfig_new(ProgramShape& program_shape) { + auto& program_shape_ = reinterpret_cast(program_shape); + auto config = new xla::HloModuleConfig(program_shape_); + return reinterpret_cast(config); + } + + void HloModuleConfig_delete(HloModuleConfig* s) { + delete reinterpret_cast(s); + } +} \ No newline at end of file diff --git a/spidr/backend/src/xla/service/hlo_module_config.h b/spidr/backend/src/xla/service/hlo_module_config.h new file mode 100644 index 000000000..eb1da11d2 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo_module_config.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct HloModuleConfig; +} \ No newline at end of file diff --git a/spidr/backend/src/xla/shape.cpp b/spidr/backend/src/xla/shape.cpp index b85223928..2bed88cde 100644 --- a/spidr/backend/src/xla/shape.cpp +++ b/spidr/backend/src/xla/shape.cpp @@ -29,4 +29,8 @@ extern "C" { void set_array_Shape(Shape* arr, int idx, Shape* shape) { reinterpret_cast(arr)[idx] = *reinterpret_cast(shape); } + + void ProgramShape_delete(ProgramShape* s) { + delete reinterpret_cast(s); + } } diff --git a/spidr/backend/src/xla/shape.h b/spidr/backend/src/xla/shape.h index 27da41111..a05f9e411 100644 --- a/spidr/backend/src/xla/shape.h +++ b/spidr/backend/src/xla/shape.h @@ -15,4 +15,5 @@ limitations under the License. */ extern "C" { struct Shape; + struct ProgramShape; } diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index aec92c193..ad7d9ab4c 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -47,6 +47,14 @@ export %foreign (libxla "idx") prim__index : Int -> AnyPtr -> AnyPtr +export +stringToCharArray : HasIO io => AnyPtr -> io MkCharArray +stringToCharArray str = do + data' <- primIO $ prim__stringData str + let size = prim__stringSize str + primIO $ prim__stringDelete str + pure (MkCharArray data' size) + export cIntToBool : Int -> Bool cIntToBool 0 = False diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index 1e35ba4dc..1ff858178 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -17,6 +17,7 @@ limitations under the License. module Compiler.Xla.HLO.Builder.XlaComputation import Compiler.FFI +import Compiler.Xla.Shape public export data XlaComputation : Type where @@ -29,16 +30,19 @@ export delete : AnyPtr -> IO () delete = primIO . prim__delete +%foreign (libxla "XlaComputation_GetProgramShape") +prim__xlaComputationGetProgramShape : GCAnyPtr -> PrimIO AnyPtr + export -%foreign (libxla "XlaComputation_SerializeAsString") -prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO AnyPtr +getProgramShape : HasIO io => XlaComputation -> io ProgramShape +getProgramShape (MkXlaComputation comp) = do + pshape <- primIO $ prim__xlaComputationGetProgramShape comp + pshape <- onCollectAny (primIO . prim__ProgramShape_delete) pshape + pure (MkProgramShape pshape) + +%foreign (libxla "XlaComputation_proto") +prim__xlaComputationProto : GCAnyPtr -> AnyPtr -||| It is up to the caller to deallocate the CharArray. export -serializeAsString : HasIO io => XlaComputation -> io CharArray -serializeAsString (MkXlaComputation computation) = do - str <- primIO $ prim__xlaComputationSerializeAsString computation - data' <- primIO $ prim__stringData str - let size = prim__stringSize str - primIO $ prim__stringDelete str - pure (MkCharArray data' size) +proto : XlaComputation -> HloModuleProto -- HasIO? ownership? +proto (MkXlaComputation comp) = MkHloModuleProto $ prim__xlaComputationProto comp diff --git a/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr b/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr new file mode 100644 index 000000000..0c36cf80a --- /dev/null +++ b/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.HLO.IR.HloModule + +import Compiler.FFI + +public export +data HloModule = MkHloModule GCAnyPtr + +%foreign (libxla "HloModule_delete") +prim__hloModuleDelete : AnyPtr -> PrimIO () + +%foreign (libxla "HloModule_CreateFromProto") +prim__hloModuleCreateFromProto : AnyPtr -> GCAnyPtr -> PrimIO AnyPtr + +export +createFromProto : HasIO io => HloModuleProto -> HloModuleConfig -> io HloModule +createFromProto (MkHloModuleProto proto) (MkHloModuleConfig config) = do + module <- primIO $ prim__hloModuleCreateFromProto proto config + module <- onCollectAny (primIO . prim__hloModuleDelete) module + pure (MkHloModule module) diff --git a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr b/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr new file mode 100644 index 000000000..c5f17eec4 --- /dev/null +++ b/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr @@ -0,0 +1,28 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.HLO.Translate + +import Compiler.FFI +import Compiler.Xla.HLO.IR.HloModule + +%foreign (libxla "ConvertHloToStablehlo") +prim__convertHloToStablehlo : GCAnyPtr -> PrimIO AnyPtr + +export +convertHloToStablehlo : HasIO io => HloModule -> io CharArray +convertHloToStablehlo (MkHloModule module) = + primIO (prim__convertHloToStablehlo module) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr index 987cb1fdd..458d77a0f 100644 --- a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr +++ b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr @@ -39,9 +39,5 @@ prim__compileOptionsSerializeAsString : GCAnyPtr -> PrimIO AnyPtr ||| It is up to the caller to deallocate the CharArray. export serializeAsString : HasIO io => CompileOptions -> io CharArray -serializeAsString (MkCompileOptions options) = do - str <- primIO $ prim__compileOptionsSerializeAsString options - data' <- primIO $ prim__stringData str - let size = prim__stringSize str - primIO $ prim__stringDelete str - pure (MkCharArray data' size) +serializeAsString (MkCompileOptions options) = + primIO (prim__compileOptionsSerializeAsString options) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr b/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr new file mode 100644 index 000000000..027236d62 --- /dev/null +++ b/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr @@ -0,0 +1,18 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. + +newHloModuleConfig : HasIO io => io HloModuleConfig diff --git a/spidr/src/Compiler/Xla/Service/HloProto.idr b/spidr/src/Compiler/Xla/Service/HloProto.idr new file mode 100644 index 000000000..5c46d08d0 --- /dev/null +++ b/spidr/src/Compiler/Xla/Service/HloProto.idr @@ -0,0 +1,18 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. + +serializeAsString : HloModuleProto -> String diff --git a/spidr/src/Compiler/Xla/Shape.idr b/spidr/src/Compiler/Xla/Shape.idr index 95d2156a9..dee238c75 100644 --- a/spidr/src/Compiler/Xla/Shape.idr +++ b/spidr/src/Compiler/Xla/Shape.idr @@ -25,11 +25,11 @@ namespace Xla MkShape : GCAnyPtr -> Shape %foreign (libxla "Shape_delete") -prim__delete : AnyPtr -> PrimIO () +prim__Shape_delete : AnyPtr -> PrimIO () export delete : AnyPtr -> IO () -delete = primIO . prim__delete +delete = primIO . prim__Shape_delete %foreign (libxla "sizeof_Shape") sizeOfShape : Int @@ -48,3 +48,10 @@ mkShapeArray shapes = do primIO $ prim__setArrayShape arr (cast idx) shape) (enumerate (fromList shapes)) arr <- onCollectAny arr free pure (MkShapeArray arr) + +public export +data ProgramShape = MkProgramShape GCAnyPtr + +export +%foreign (libxla "ProgramShape_delete") +prim__ProgramShape_delete : AnyPtr -> PrimIO () From 866d15a018beb4b8fb15efed5c5743314c0e6c81 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 1 Dec 2024 20:06:23 +0000 Subject: [PATCH 02/38] working sometimes --- .../src/xla/hlo/builder/xla_computation.cpp | 4 +--- spidr/backend/src/xla/hlo/ir/hlo_module.cpp | 4 +--- .../src/xla/hlo/translate/portable_api.cpp | 2 +- spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp | 3 ++- spidr/spidr.ipkg | 4 ++++ spidr/src/Compiler/Eval.idr | 16 ++++++++++++- spidr/src/Compiler/FFI.idr | 2 +- .../Xla/HLO/Builder/XlaComputation.idr | 14 +++++++---- spidr/src/Compiler/Xla/HLO/IR/HloModule.idr | 12 ++++++---- .../Xla/HLO/Translate/PortableAPI.idr | 6 ++--- .../Compiler/Xla/Service/HloModuleConfig.idr | 20 +++++++++++++++- spidr/src/Compiler/Xla/Service/HloProto.idr | 18 +++++++++++++- test/runner/TestRunner.idr | 14 +++++------ test/runner/Unit/TestTensor.idr | 24 +++++++++++++++---- 14 files changed, 107 insertions(+), 36 deletions(-) diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index efd371cbe..73834499a 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -33,8 +33,6 @@ extern "C" { HloModuleProto* XlaComputation_proto(XlaComputation* s) { auto res = reinterpret_cast(s)->proto(); - // I think the proto is owned by, and lives as long as, the XlaComputation - // so this is probably wrong since Idris will GC the XlaComputation - return reinterpret_cast(&res); + return reinterpret_cast(new xla::HloModuleProto(res)); } } \ No newline at end of file diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp index 6cdc30383..8301554e9 100644 --- a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp @@ -27,10 +27,8 @@ extern "C" { HloModule* HloModule_CreateFromProto(HloModuleProto& proto, HloModuleConfig& module_config) { auto& proto_ = reinterpret_cast(proto); auto& module_config_ = reinterpret_cast(module_config); - auto module = xla::HloModule::CreateFromProto(proto_, module_config_); - // this looks suspicious, but I'm pretty sure we own the HloModule - return reinterpret_cast(&*module); + return reinterpret_cast(module.value().release()); } void HloModule_delete(HloModule* s) { diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp index 5837df22e..cc1ddf896 100644 --- a/spidr/backend/src/xla/hlo/translate/portable_api.cpp +++ b/spidr/backend/src/xla/hlo/translate/portable_api.cpp @@ -23,6 +23,6 @@ extern "C" { string* ConvertHloToStablehlo(HloModule& hlo_module) { auto& hlo_module_ = reinterpret_cast(hlo_module); auto res = xla::ConvertHloToStablehlo(hlo_module_); - return reinterpret_cast(new std::string(*res)); + return reinterpret_cast(new std::string(res.value())); } } \ No newline at end of file diff --git a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp index 18290860c..fc7fe2f00 100644 --- a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp +++ b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp @@ -127,7 +127,7 @@ extern "C" { } PJRT_Program* PJRT_Program_new(char* code, size_t code_size) { - auto format = pjrt::kHloFormat; + auto format = pjrt::kMlirFormat; return new PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, .extension_start = nullptr, @@ -159,6 +159,7 @@ extern "C" { } PJRT_Error* pjrt_client_compile(PJRT_Api* api, PJRT_Client_Compile_Args* args) { + printf("%.*s\n", args->program->code_size, args->program->code); return api->PJRT_Client_Compile(args); } diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index fa7b670c0..33cf3546d 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -16,8 +16,12 @@ modules = Compiler.Xla.HLO.Builder.Lib.PRNG, Compiler.Xla.HLO.Builder.XlaBuilder, Compiler.Xla.HLO.Builder.XlaComputation, + Compiler.Xla.HLO.IR.HloModule, + Compiler.Xla.HLO.Translate.PortableAPI, Compiler.Xla.PJRT.C.PjrtCApi, Compiler.Xla.PJRT.PjrtExecutable, + Compiler.Xla.Service.HloModuleConfig, + Compiler.Xla.Service.HloProto, Compiler.Xla.Literal, Compiler.Xla.Shape, Compiler.Xla.ShapeUtil, diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 6c26186f5..e9b39a853 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -34,8 +34,12 @@ import Compiler.Xla.HLO.Builder.Lib.Matrix import Compiler.Xla.HLO.Builder.Lib.PRNG import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.HLO.Builder.XlaComputation +import Compiler.Xla.HLO.IR.HloModule +import Compiler.Xla.HLO.Translate.PortableAPI import Compiler.Xla.PJRT.C.PjrtCApi import Compiler.Xla.PJRT.PjrtExecutable +import Compiler.Xla.Service.HloModuleConfig +import Compiler.Xla.Service.HloProto import Compiler.Xla.Literal import Compiler.Xla.Shape import Compiler.Xla.ShapeUtil @@ -224,7 +228,17 @@ execute (MkDevice api client) f@(MkFn _ _ env) shapes = do xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f bimapEitherT PjrtErr id $ do - code <- serializeAsString computation + -- printLn 0 + proto <- proto computation + -- printLn 1 + programShape <- getProgramShape computation + -- printLn 2 + moduleConfig <- hloModuleConfig programShape + -- printLn 3 + module' <- createFromProto proto moduleConfig + -- printLn 4 + code <- convertHloToStablehlo module' + -- printLn 5 executableBuildOptions <- mkExecutableBuildOptions compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) loadedExec <- pjrtClientCompile api client !(mkPjrtProgram code) compileOptions diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index ad7d9ab4c..1db22cbdb 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -48,7 +48,7 @@ export prim__index : Int -> AnyPtr -> AnyPtr export -stringToCharArray : HasIO io => AnyPtr -> io MkCharArray +stringToCharArray : HasIO io => AnyPtr -> io CharArray stringToCharArray str = do data' <- primIO $ prim__stringData str let size = prim__stringSize str diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index 1ff858178..b579be951 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -18,6 +18,7 @@ module Compiler.Xla.HLO.Builder.XlaComputation import Compiler.FFI import Compiler.Xla.Shape +import Compiler.Xla.Service.HloProto public export data XlaComputation : Type where @@ -28,7 +29,7 @@ prim__delete : AnyPtr -> PrimIO () export delete : AnyPtr -> IO () -delete = primIO . prim__delete +delete = primIO . XlaComputation.prim__delete %foreign (libxla "XlaComputation_GetProgramShape") prim__xlaComputationGetProgramShape : GCAnyPtr -> PrimIO AnyPtr @@ -37,12 +38,15 @@ export getProgramShape : HasIO io => XlaComputation -> io ProgramShape getProgramShape (MkXlaComputation comp) = do pshape <- primIO $ prim__xlaComputationGetProgramShape comp - pshape <- onCollectAny (primIO . prim__ProgramShape_delete) pshape + pshape <- onCollectAny pshape (primIO . prim__ProgramShape_delete) pure (MkProgramShape pshape) %foreign (libxla "XlaComputation_proto") -prim__xlaComputationProto : GCAnyPtr -> AnyPtr +prim__xlaComputationProto : GCAnyPtr -> PrimIO AnyPtr export -proto : XlaComputation -> HloModuleProto -- HasIO? ownership? -proto (MkXlaComputation comp) = MkHloModuleProto $ prim__xlaComputationProto comp +proto : HasIO io => XlaComputation -> io HloModuleProto +proto (MkXlaComputation comp) = do + proto <- primIO $ prim__xlaComputationProto comp + proto <- onCollectAny proto (primIO . HloProto.prim__delete) + pure (MkHloModuleProto proto) diff --git a/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr b/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr index 0c36cf80a..ba438842e 100644 --- a/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr +++ b/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr @@ -17,19 +17,21 @@ limitations under the License. module Compiler.Xla.HLO.IR.HloModule import Compiler.FFI +import Compiler.Xla.Service.HloModuleConfig +import Compiler.Xla.Service.HloProto public export data HloModule = MkHloModule GCAnyPtr %foreign (libxla "HloModule_delete") -prim__hloModuleDelete : AnyPtr -> PrimIO () +prim__delete : AnyPtr -> PrimIO () %foreign (libxla "HloModule_CreateFromProto") -prim__hloModuleCreateFromProto : AnyPtr -> GCAnyPtr -> PrimIO AnyPtr +prim__hloModuleCreateFromProto : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr export createFromProto : HasIO io => HloModuleProto -> HloModuleConfig -> io HloModule createFromProto (MkHloModuleProto proto) (MkHloModuleConfig config) = do - module <- primIO $ prim__hloModuleCreateFromProto proto config - module <- onCollectAny (primIO . prim__hloModuleDelete) module - pure (MkHloModule module) + module' <- primIO $ prim__hloModuleCreateFromProto proto config + module' <- onCollectAny module' (primIO . HloModule.prim__delete) + pure (MkHloModule module') diff --git a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr b/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr index c5f17eec4..8fb7bfa0d 100644 --- a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr +++ b/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.HLO.Translate +module Compiler.Xla.HLO.Translate.PortableAPI import Compiler.FFI import Compiler.Xla.HLO.IR.HloModule @@ -24,5 +24,5 @@ prim__convertHloToStablehlo : GCAnyPtr -> PrimIO AnyPtr export convertHloToStablehlo : HasIO io => HloModule -> io CharArray -convertHloToStablehlo (MkHloModule module) = - primIO (prim__convertHloToStablehlo module) >>= stringToCharArray +convertHloToStablehlo (MkHloModule module') = + primIO (prim__convertHloToStablehlo module') >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr b/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr index 027236d62..15899d502 100644 --- a/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr +++ b/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr @@ -14,5 +14,23 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. +module Compiler.Xla.Service.HloModuleConfig -newHloModuleConfig : HasIO io => io HloModuleConfig +import Compiler.FFI +import Compiler.Xla.Shape + +public export +data HloModuleConfig = MkHloModuleConfig GCAnyPtr + +%foreign (libxla "HloModuleConfig_new") +prim__hloModuleConfig : GCAnyPtr -> PrimIO AnyPtr + +%foreign (libxla "HloModuleConfig_delete") +prim__delete : AnyPtr -> PrimIO () + +export +hloModuleConfig : HasIO io => ProgramShape -> io HloModuleConfig +hloModuleConfig (MkProgramShape pshape) = do + config <- primIO $ prim__hloModuleConfig pshape + config <- onCollectAny config (primIO . prim__delete) + pure (MkHloModuleConfig config) diff --git a/spidr/src/Compiler/Xla/Service/HloProto.idr b/spidr/src/Compiler/Xla/Service/HloProto.idr index 5c46d08d0..0e8d9f221 100644 --- a/spidr/src/Compiler/Xla/Service/HloProto.idr +++ b/spidr/src/Compiler/Xla/Service/HloProto.idr @@ -14,5 +14,21 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. +module Compiler.Xla.Service.HloProto -serializeAsString : HloModuleProto -> String +import Compiler.FFI + +public export +data HloModuleProto = MkHloModuleProto GCAnyPtr + +%foreign (libxla "HloModuleProto_SerializeAsString") +prim__hloModuleProtoSerializeAsString : GCAnyPtr -> PrimIO AnyPtr + +export +%foreign (libxla "HloModuleProto_delete") +prim__delete : AnyPtr -> PrimIO () + +export +serializeAsString : HasIO io => HloModuleProto -> io CharArray +serializeAsString (MkHloModuleProto proto) = + primIO (prim__hloModuleProtoSerializeAsString proto) >>= stringToCharArray diff --git a/test/runner/TestRunner.idr b/test/runner/TestRunner.idr index 792d2ae12..c5113e1b5 100644 --- a/test/runner/TestRunner.idr +++ b/test/runner/TestRunner.idr @@ -31,11 +31,11 @@ import Unit.TestUtil export run : Device -> IO () run device = test [ - Utils.TestComparison.group - , TestUtils.group - , Unit.TestUtil.group - , Unit.TestLiteral.group - , Unit.TestTensor.group - , Unit.TestDistribution.group - , Unit.Model.TestKernel.group + -- Utils.TestComparison.group + --, TestUtils.group + --, Unit.TestUtil.group + --, Unit.TestLiteral.group + Unit.TestTensor.group + --, Unit.TestDistribution.group + --, Unit.Model.TestKernel.group ] diff --git a/test/runner/Unit/TestTensor.idr b/test/runner/Unit/TestTensor.idr index e774435ab..f03037cd8 100644 --- a/test/runner/Unit/TestTensor.idr +++ b/test/runner/Unit/TestTensor.idr @@ -475,10 +475,26 @@ trace : Device => Property trace = fixedProperty $ trace (tensor {dtype = S32} [[-1, 5], [1, 4]]) ===# pure 3 +fixed : Device => Property +fixed @{device} = fixedProperty $ do + let x = Scalar 0.0 + x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) + x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) + + let x = Scalar (the Int32 1) + x === unsafePerformIO (Tag.eval device $ pure $ tensor {dtype = S32} x) + x === unsafePerformIO (Tag.eval device $ pure $ tensor {dtype = S32} x) + + let x = [0.0] + x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) + x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) + --x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) + export group : Device => Group group = MkGroup "Tensor" $ [ - ("eval . tensor", tensorThenEval) + ("fixed", fixed) + {- ("eval . tensor", tensorThenEval) , ("eval multiple tensors (tuple)", evalTuple) , ("eval multiple tensors (tuple) for non-trivial graph", evalTupleNonTrivial) , ("can read/write finite numeric bounds to/from XLA", canConvertAtXlaNumericBounds) @@ -498,11 +514,11 @@ group = MkGroup "Tensor" $ [ , ("cholesky", cholesky) , (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse) , (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems) - , ("trace", trace) + , ("trace", trace)-} ] ++ concat (the (List _) [ - Unit.TestTensor.Elementwise.all + {- Unit.TestTensor.Elementwise.all , Unit.TestTensor.HigherOrder.all , Unit.TestTensor.Sampling.all , Unit.TestTensor.Slice.all - , Unit.TestTensor.Structure.all + , Unit.TestTensor.Structure.all-} ]) From 793bfa121731a7593d9ac4511a98cf9b8a64f741 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 2 Dec 2024 22:44:49 +0000 Subject: [PATCH 03/38] working! thanks kevin --- .../src/xla/hlo/translate/portable_api.cpp | 2 +- spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp | 1 - spidr/src/Compiler/Eval.idr | 12 ++-------- test/runner/TestRunner.idr | 14 +++++------ test/runner/Unit/TestTensor.idr | 24 ++++--------------- 5 files changed, 14 insertions(+), 39 deletions(-) diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp index cc1ddf896..7dc3519a0 100644 --- a/spidr/backend/src/xla/hlo/translate/portable_api.cpp +++ b/spidr/backend/src/xla/hlo/translate/portable_api.cpp @@ -22,7 +22,7 @@ limitations under the License. extern "C" { string* ConvertHloToStablehlo(HloModule& hlo_module) { auto& hlo_module_ = reinterpret_cast(hlo_module); - auto res = xla::ConvertHloToStablehlo(hlo_module_); + auto res = xla::ConvertHloToStablehlo(hlo_module_, true); return reinterpret_cast(new std::string(res.value())); } } \ No newline at end of file diff --git a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp index fc7fe2f00..4730d0589 100644 --- a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp +++ b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp @@ -159,7 +159,6 @@ extern "C" { } PJRT_Error* pjrt_client_compile(PJRT_Api* api, PJRT_Client_Compile_Args* args) { - printf("%.*s\n", args->program->code_size, args->program->code); return api->PJRT_Client_Compile(args); } diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index e9b39a853..a04d27d19 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -228,17 +228,9 @@ execute (MkDevice api client) f@(MkFn _ _ env) shapes = do xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f bimapEitherT PjrtErr id $ do - -- printLn 0 - proto <- proto computation - -- printLn 1 - programShape <- getProgramShape computation - -- printLn 2 - moduleConfig <- hloModuleConfig programShape - -- printLn 3 - module' <- createFromProto proto moduleConfig - -- printLn 4 + moduleConfig <- hloModuleConfig !(getProgramShape computation) + module' <- createFromProto !(proto computation) moduleConfig code <- convertHloToStablehlo module' - -- printLn 5 executableBuildOptions <- mkExecutableBuildOptions compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) loadedExec <- pjrtClientCompile api client !(mkPjrtProgram code) compileOptions diff --git a/test/runner/TestRunner.idr b/test/runner/TestRunner.idr index c5113e1b5..792d2ae12 100644 --- a/test/runner/TestRunner.idr +++ b/test/runner/TestRunner.idr @@ -31,11 +31,11 @@ import Unit.TestUtil export run : Device -> IO () run device = test [ - -- Utils.TestComparison.group - --, TestUtils.group - --, Unit.TestUtil.group - --, Unit.TestLiteral.group - Unit.TestTensor.group - --, Unit.TestDistribution.group - --, Unit.Model.TestKernel.group + Utils.TestComparison.group + , TestUtils.group + , Unit.TestUtil.group + , Unit.TestLiteral.group + , Unit.TestTensor.group + , Unit.TestDistribution.group + , Unit.Model.TestKernel.group ] diff --git a/test/runner/Unit/TestTensor.idr b/test/runner/Unit/TestTensor.idr index f03037cd8..e774435ab 100644 --- a/test/runner/Unit/TestTensor.idr +++ b/test/runner/Unit/TestTensor.idr @@ -475,26 +475,10 @@ trace : Device => Property trace = fixedProperty $ trace (tensor {dtype = S32} [[-1, 5], [1, 4]]) ===# pure 3 -fixed : Device => Property -fixed @{device} = fixedProperty $ do - let x = Scalar 0.0 - x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) - x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) - - let x = Scalar (the Int32 1) - x === unsafePerformIO (Tag.eval device $ pure $ tensor {dtype = S32} x) - x === unsafePerformIO (Tag.eval device $ pure $ tensor {dtype = S32} x) - - let x = [0.0] - x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) - x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) - --x ==~ unsafePerformIO (eval device $ pure $ tensor {dtype = F64} x) - export group : Device => Group group = MkGroup "Tensor" $ [ - ("fixed", fixed) - {- ("eval . tensor", tensorThenEval) + ("eval . tensor", tensorThenEval) , ("eval multiple tensors (tuple)", evalTuple) , ("eval multiple tensors (tuple) for non-trivial graph", evalTupleNonTrivial) , ("can read/write finite numeric bounds to/from XLA", canConvertAtXlaNumericBounds) @@ -514,11 +498,11 @@ group = MkGroup "Tensor" $ [ , ("cholesky", cholesky) , (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse) , (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems) - , ("trace", trace)-} + , ("trace", trace) ] ++ concat (the (List _) [ - {- Unit.TestTensor.Elementwise.all + Unit.TestTensor.Elementwise.all , Unit.TestTensor.HigherOrder.all , Unit.TestTensor.Sampling.all , Unit.TestTensor.Slice.all - , Unit.TestTensor.Structure.all-} + , Unit.TestTensor.Structure.all ]) From e07c7f8c72e0a32829954f82a896f037331f6fdf Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 2 Dec 2024 22:48:21 +0000 Subject: [PATCH 04/38] wip --- spidr/backend/src/xla/hlo/translate/portable_api.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp index 7dc3519a0..337f7da55 100644 --- a/spidr/backend/src/xla/hlo/translate/portable_api.cpp +++ b/spidr/backend/src/xla/hlo/translate/portable_api.cpp @@ -22,6 +22,8 @@ limitations under the License. extern "C" { string* ConvertHloToStablehlo(HloModule& hlo_module) { auto& hlo_module_ = reinterpret_cast(hlo_module); + // the implementation of this function shows how to get the actual MLIR module, which is + // crucial for enzyme! auto res = xla::ConvertHloToStablehlo(hlo_module_, true); return reinterpret_cast(new std::string(res.value())); } From c8bf80c5ab05cea5affe49e05dcff65088d66687 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 2 Dec 2024 22:54:56 +0000 Subject: [PATCH 05/38] wip --- spidr/backend/src/xla/hlo/builder/xla_computation.cpp | 2 +- spidr/backend/src/xla/hlo/ir/hlo_module.cpp | 2 +- spidr/backend/src/xla/hlo/ir/hlo_module.h | 2 +- spidr/backend/src/xla/hlo/translate/portable_api.cpp | 2 +- spidr/backend/src/xla/service/hlo.proto.cpp | 2 +- spidr/backend/src/xla/service/hlo_module_config.cpp | 2 +- spidr/backend/src/xla/service/hlo_module_config.h | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index 73834499a..f6896b9bc 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -35,4 +35,4 @@ extern "C" { auto res = reinterpret_cast(s)->proto(); return reinterpret_cast(new xla::HloModuleProto(res)); } -} \ No newline at end of file +} diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp index 8301554e9..ee00b9e9c 100644 --- a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp @@ -34,4 +34,4 @@ extern "C" { void HloModule_delete(HloModule* s) { delete reinterpret_cast(s); } -} \ No newline at end of file +} diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.h b/spidr/backend/src/xla/hlo/ir/hlo_module.h index 10ad53e06..5fc43b5a6 100644 --- a/spidr/backend/src/xla/hlo/ir/hlo_module.h +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.h @@ -15,4 +15,4 @@ limitations under the License. */ extern "C" { struct HloModule; -} \ No newline at end of file +} diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp index 337f7da55..38c89e4d5 100644 --- a/spidr/backend/src/xla/hlo/translate/portable_api.cpp +++ b/spidr/backend/src/xla/hlo/translate/portable_api.cpp @@ -27,4 +27,4 @@ extern "C" { auto res = xla::ConvertHloToStablehlo(hlo_module_, true); return reinterpret_cast(new std::string(res.value())); } -} \ No newline at end of file +} diff --git a/spidr/backend/src/xla/service/hlo.proto.cpp b/spidr/backend/src/xla/service/hlo.proto.cpp index 09ba05c65..195d17026 100644 --- a/spidr/backend/src/xla/service/hlo.proto.cpp +++ b/spidr/backend/src/xla/service/hlo.proto.cpp @@ -28,4 +28,4 @@ extern "C" { void HloModuleProto_delete(HloModuleProto* s) { delete reinterpret_cast(s); } -} \ No newline at end of file +} diff --git a/spidr/backend/src/xla/service/hlo_module_config.cpp b/spidr/backend/src/xla/service/hlo_module_config.cpp index cec0467ca..59478448a 100644 --- a/spidr/backend/src/xla/service/hlo_module_config.cpp +++ b/spidr/backend/src/xla/service/hlo_module_config.cpp @@ -30,4 +30,4 @@ extern "C" { void HloModuleConfig_delete(HloModuleConfig* s) { delete reinterpret_cast(s); } -} \ No newline at end of file +} diff --git a/spidr/backend/src/xla/service/hlo_module_config.h b/spidr/backend/src/xla/service/hlo_module_config.h index eb1da11d2..dff48c9ca 100644 --- a/spidr/backend/src/xla/service/hlo_module_config.h +++ b/spidr/backend/src/xla/service/hlo_module_config.h @@ -15,4 +15,4 @@ limitations under the License. */ extern "C" { struct HloModuleConfig; -} \ No newline at end of file +} From b9a5f04d386d36de6a30f0e074cfe6bf90647572 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 2 Dec 2024 23:57:36 +0000 Subject: [PATCH 06/38] temp remove pack switch HEAD --- .github/workflows/checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 30cc6c6be..af39d9a89 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -148,7 +148,7 @@ jobs: - name: Install build dependencies run: | apt-get update && apt-get install -y curl - pack switch HEAD + # pack switch HEAD - name: Build tests working-directory: test/xla-cpu run: | @@ -169,7 +169,7 @@ jobs: brew install chezscheme git clone https://github.com/stefan-hoeck/idris2-pack.git (cd idris2-pack && make micropack SCHEME=chez) - ~/.pack/bin/pack switch HEAD + # ~/.pack/bin/pack switch HEAD - name: Build tests working-directory: test/xla-cpu run: | @@ -189,7 +189,7 @@ jobs: - name: Install build dependencies run: | apt-get update && apt-get install -y curl - pack switch HEAD + # pack switch HEAD - name: Build tests working-directory: test/xla-cuda run: | From fea8646f8cda039955b79f75f87df4be51549f09 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:53:24 +0000 Subject: [PATCH 07/38] revert --- .github/workflows/checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index af39d9a89..30cc6c6be 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -148,7 +148,7 @@ jobs: - name: Install build dependencies run: | apt-get update && apt-get install -y curl - # pack switch HEAD + pack switch HEAD - name: Build tests working-directory: test/xla-cpu run: | @@ -169,7 +169,7 @@ jobs: brew install chezscheme git clone https://github.com/stefan-hoeck/idris2-pack.git (cd idris2-pack && make micropack SCHEME=chez) - # ~/.pack/bin/pack switch HEAD + ~/.pack/bin/pack switch HEAD - name: Build tests working-directory: test/xla-cpu run: | @@ -189,7 +189,7 @@ jobs: - name: Install build dependencies run: | apt-get update && apt-get install -y curl - # pack switch HEAD + pack switch HEAD - name: Build tests working-directory: test/xla-cuda run: | From 364158b6625c083f4b06561ee23b2effc92172b6 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Wed, 4 Dec 2024 12:50:50 +0000 Subject: [PATCH 08/38] wip --- spidr/backend/src/xla/hlo/ir/hlo_module.cpp | 1 - .../src/xla/hlo/translate/portable_api.cpp | 3 +++ spidr/src/Compiler/Eval.idr | 21 +++++++++++++------ spidr/src/Compiler/FFI.idr | 1 + .../Xla/HLO/Translate/PortableAPI.idr | 1 + .../src/Compiler/Xla/PJRT/PjrtExecutable.idr | 2 +- spidr/src/Compiler/Xla/Service/HloProto.idr | 1 + 7 files changed, 22 insertions(+), 8 deletions(-) diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp index ee00b9e9c..077d56e1b 100644 --- a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp @@ -23,7 +23,6 @@ limitations under the License. #include "../../service/hlo_module_config.h" extern "C" { - // if proto is a nuisance, could pass in the XlaComputation instead HloModule* HloModule_CreateFromProto(HloModuleProto& proto, HloModuleConfig& module_config) { auto& proto_ = reinterpret_cast(proto); auto& module_config_ = reinterpret_cast(module_config); diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp index 38c89e4d5..1d9400e60 100644 --- a/spidr/backend/src/xla/hlo/translate/portable_api.cpp +++ b/spidr/backend/src/xla/hlo/translate/portable_api.cpp @@ -21,10 +21,13 @@ limitations under the License. extern "C" { string* ConvertHloToStablehlo(HloModule& hlo_module) { + printf("ConvertHloToStablehlo ...\n"); auto& hlo_module_ = reinterpret_cast(hlo_module); + printf("0\n"); // the implementation of this function shows how to get the actual MLIR module, which is // crucial for enzyme! auto res = xla::ConvertHloToStablehlo(hlo_module_, true); + printf("1\n"); return reinterpret_cast(new std::string(res.value())); } } diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index a04d27d19..21f1c2675 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -225,15 +225,24 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do export covering execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal execute (MkDevice api client) f@(MkFn _ _ env) shapes = do + putStrLn "execute ..." xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f + printLn 0 + moduleConfig <- hloModuleConfig !(getProgramShape computation) + printLn 1 + module' <- createFromProto !(proto computation) moduleConfig + printLn 2 + code <- convertHloToStablehlo module' + printLn 3 + executableBuildOptions <- mkExecutableBuildOptions + printLn 4 + compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) + printLn 5 + program <- mkPjrtProgram code bimapEitherT PjrtErr id $ do - moduleConfig <- hloModuleConfig !(getProgramShape computation) - module' <- createFromProto !(proto computation) moduleConfig - code <- convertHloToStablehlo module' - executableBuildOptions <- mkExecutableBuildOptions - compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) - loadedExec <- pjrtClientCompile api client !(mkPjrtProgram code) compileOptions + loadedExec <- pjrtClientCompile api client program compileOptions + printLn 6 free code free compileOptions delete executableBuildOptions diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index 1db22cbdb..a2f24eab1 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -47,6 +47,7 @@ export %foreign (libxla "idx") prim__index : Int -> AnyPtr -> AnyPtr +||| Deletes the `string`. It is up to the caller to `free` the `CharArray`. export stringToCharArray : HasIO io => AnyPtr -> io CharArray stringToCharArray str = do diff --git a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr b/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr index 8fb7bfa0d..7aaba289a 100644 --- a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr +++ b/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr @@ -22,6 +22,7 @@ import Compiler.Xla.HLO.IR.HloModule %foreign (libxla "ConvertHloToStablehlo") prim__convertHloToStablehlo : GCAnyPtr -> PrimIO AnyPtr +||| It is up to the caller to `free` the `CharArray`. export convertHloToStablehlo : HasIO io => HloModule -> io CharArray convertHloToStablehlo (MkHloModule module') = diff --git a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr index 458d77a0f..acdb83406 100644 --- a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr +++ b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr @@ -36,7 +36,7 @@ mkCompileOptions (MkExecutableBuildOptions executableBuildOptions) = do %foreign (libxla "CompileOptions_SerializeAsString") prim__compileOptionsSerializeAsString : GCAnyPtr -> PrimIO AnyPtr -||| It is up to the caller to deallocate the CharArray. +||| It is up to the caller to `free` the `CharArray`. export serializeAsString : HasIO io => CompileOptions -> io CharArray serializeAsString (MkCompileOptions options) = diff --git a/spidr/src/Compiler/Xla/Service/HloProto.idr b/spidr/src/Compiler/Xla/Service/HloProto.idr index 0e8d9f221..9e7ce2b2c 100644 --- a/spidr/src/Compiler/Xla/Service/HloProto.idr +++ b/spidr/src/Compiler/Xla/Service/HloProto.idr @@ -28,6 +28,7 @@ export %foreign (libxla "HloModuleProto_delete") prim__delete : AnyPtr -> PrimIO () +||| It is up to the caller to `free` the `CharArray`. export serializeAsString : HasIO io => HloModuleProto -> io CharArray serializeAsString (MkHloModuleProto proto) = From efab2a55f0bf9f362571b69c8860892725d89b91 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:38:07 +0000 Subject: [PATCH 09/38] build mlir::ModuleOp --- spidr/backend/BUILD | 6 +++ spidr/backend/src/mlir/IR/BUILD | 11 +++++ spidr/backend/src/mlir/IR/BuiltinOps.h | 18 ++++++++ spidr/backend/src/mlir/IR/DialectRegistry.cpp | 28 ++++++++++++ spidr/backend/src/mlir/IR/DialectRegistry.h | 18 ++++++++ spidr/backend/src/mlir/IR/MLIRContext.cpp | 37 ++++++++++++++++ spidr/backend/src/mlir/IR/MLIRContext.h | 18 ++++++++ spidr/backend/src/stablehlo/dialect/BUILD | 12 +++++ .../src/stablehlo/dialect/Register.cpp | 24 ++++++++++ spidr/backend/src/xla/hlo/ir/hlo_module.cpp | 1 + spidr/backend/src/xla/hlo/translate/BUILD | 3 +- .../src/xla/hlo/translate/portable_api.cpp | 33 -------------- .../src/xla/hlo/translate/stablehlo.cpp | 44 +++++++++++++++++++ spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD | 12 +++++ .../src/xla/mlir_hlo/mhlo/IR/register.cpp | 24 ++++++++++ spidr/spidr.ipkg | 6 +++ spidr/src/Compiler/Eval.idr | 24 ++++++---- spidr/src/Compiler/MLIR/IR/BuiltinOps.idr | 20 +++++++++ .../src/Compiler/MLIR/IR/DialectRegistry.idr | 35 +++++++++++++++ spidr/src/Compiler/MLIR/IR/MLIRContext.idr | 44 +++++++++++++++++++ .../Compiler/StableHLO/Dialect/Register.idr | 27 ++++++++++++ .../Compiler/Xla/HLO/Translate/StableHLO.idr | 41 +++++++++++++++++ .../Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr | 27 ++++++++++++ 23 files changed, 471 insertions(+), 42 deletions(-) create mode 100644 spidr/backend/src/mlir/IR/BUILD create mode 100644 spidr/backend/src/mlir/IR/BuiltinOps.h create mode 100644 spidr/backend/src/mlir/IR/DialectRegistry.cpp create mode 100644 spidr/backend/src/mlir/IR/DialectRegistry.h create mode 100644 spidr/backend/src/mlir/IR/MLIRContext.cpp create mode 100644 spidr/backend/src/mlir/IR/MLIRContext.h create mode 100644 spidr/backend/src/stablehlo/dialect/BUILD create mode 100644 spidr/backend/src/stablehlo/dialect/Register.cpp delete mode 100644 spidr/backend/src/xla/hlo/translate/portable_api.cpp create mode 100644 spidr/backend/src/xla/hlo/translate/stablehlo.cpp create mode 100644 spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD create mode 100644 spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp create mode 100644 spidr/src/Compiler/MLIR/IR/BuiltinOps.idr create mode 100644 spidr/src/Compiler/MLIR/IR/DialectRegistry.idr create mode 100644 spidr/src/Compiler/MLIR/IR/MLIRContext.idr create mode 100644 spidr/src/Compiler/StableHLO/Dialect/Register.idr create mode 100644 spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr create mode 100644 spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index 65ec73bc4..760fdf40c 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,24 +12,30 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ + "//src/mlir/IR", + "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", "//src/xla/hlo/ir", "//src/xla/hlo/translate", + "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", "//src/xla/pjrt/c", "//src/xla/service", "//src", ], deps = [ + "//src/mlir/IR", + "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", "//src/xla/hlo/ir", "//src/xla/hlo/translate", + "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", "//src/xla/pjrt/c", "//src/xla/service", diff --git a/spidr/backend/src/mlir/IR/BUILD b/spidr/backend/src/mlir/IR/BUILD new file mode 100644 index 000000000..42fb4a6db --- /dev/null +++ b/spidr/backend/src/mlir/IR/BUILD @@ -0,0 +1,11 @@ +cc_library( + name = "IR", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//mlir:IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/mlir/IR/BuiltinOps.h b/spidr/backend/src/mlir/IR/BuiltinOps.h new file mode 100644 index 000000000..0fb5ccbec --- /dev/null +++ b/spidr/backend/src/mlir/IR/BuiltinOps.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct ModuleOp; +} diff --git a/spidr/backend/src/mlir/IR/DialectRegistry.cpp b/spidr/backend/src/mlir/IR/DialectRegistry.cpp new file mode 100644 index 000000000..dfc543d57 --- /dev/null +++ b/spidr/backend/src/mlir/IR/DialectRegistry.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/DialectRegistry.h" + +#include "DialectRegistry.h" + +extern "C" { + DialectRegistry* DialectRegistry_new() { + return reinterpret_cast(new mlir::DialectRegistry()); + } + + void DialectRegistry_delete(DialectRegistry* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/DialectRegistry.h b/spidr/backend/src/mlir/IR/DialectRegistry.h new file mode 100644 index 000000000..58c7ab272 --- /dev/null +++ b/spidr/backend/src/mlir/IR/DialectRegistry.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct DialectRegistry; +} diff --git a/spidr/backend/src/mlir/IR/MLIRContext.cpp b/spidr/backend/src/mlir/IR/MLIRContext.cpp new file mode 100644 index 000000000..f806806ef --- /dev/null +++ b/spidr/backend/src/mlir/IR/MLIRContext.cpp @@ -0,0 +1,37 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/MLIRContext.h" + +#include "DialectRegistry.h" +#include "MLIRContext.h" + +extern "C" { + MLIRContext* MLIRContext_new() { + printf("MLIRContext_new ...\n"); + auto res = reinterpret_cast(new mlir::MLIRContext); + printf("0\n"); + return res; + } + + void MLIRContext_delete(MLIRContext* s) { + delete reinterpret_cast(s); + } + + void MLIRContext_appendDialectRegistry(MLIRContext& s, DialectRegistry& registry) { + auto& registry_ = reinterpret_cast(registry); + reinterpret_cast(s).appendDialectRegistry(registry_); + } +} diff --git a/spidr/backend/src/mlir/IR/MLIRContext.h b/spidr/backend/src/mlir/IR/MLIRContext.h new file mode 100644 index 000000000..efa58bc0c --- /dev/null +++ b/spidr/backend/src/mlir/IR/MLIRContext.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct MLIRContext; +} diff --git a/spidr/backend/src/stablehlo/dialect/BUILD b/spidr/backend/src/stablehlo/dialect/BUILD new file mode 100644 index 000000000..ab4578b99 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "dialect", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@stablehlo//:register", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/stablehlo/dialect/Register.cpp b/spidr/backend/src/stablehlo/dialect/Register.cpp new file mode 100644 index 000000000..505668a34 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Register.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "stablehlo/dialect/Register.h" + +#include "../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerAllDialects(DialectRegistry& registry) { + mlir::stablehlo::registerAllDialects(reinterpret_cast(registry)); + } +} diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp index 077d56e1b..4944fe917 100644 --- a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp @@ -24,6 +24,7 @@ limitations under the License. extern "C" { HloModule* HloModule_CreateFromProto(HloModuleProto& proto, HloModuleConfig& module_config) { + // put print statements in all C functions to see if the error's coming from elsewhere auto& proto_ = reinterpret_cast(proto); auto& module_config_ = reinterpret_cast(module_config); auto module = xla::HloModule::CreateFromProto(proto_, module_config_); diff --git a/spidr/backend/src/xla/hlo/translate/BUILD b/spidr/backend/src/xla/hlo/translate/BUILD index 7ebd74269..a538cca98 100644 --- a/spidr/backend/src/xla/hlo/translate/BUILD +++ b/spidr/backend/src/xla/hlo/translate/BUILD @@ -5,7 +5,8 @@ cc_library( srcs = glob(["*.cpp"]), hdrs = glob(["*.h"]), deps = [ - "@xla//xla/hlo/translate:portable_api", + "@xla//xla/hlo/translate:stablehlo", + "//src/mlir/IR", "//src/xla/hlo/ir", "//src", ], diff --git a/spidr/backend/src/xla/hlo/translate/portable_api.cpp b/spidr/backend/src/xla/hlo/translate/portable_api.cpp deleted file mode 100644 index 1d9400e60..000000000 --- a/spidr/backend/src/xla/hlo/translate/portable_api.cpp +++ /dev/null @@ -1,33 +0,0 @@ -/* -Copyright 2024 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/translate/portable_api.h" - -#include "../ir/hlo_module.h" -#include "../../../ffi.h" - -extern "C" { - string* ConvertHloToStablehlo(HloModule& hlo_module) { - printf("ConvertHloToStablehlo ...\n"); - auto& hlo_module_ = reinterpret_cast(hlo_module); - printf("0\n"); - // the implementation of this function shows how to get the actual MLIR module, which is - // crucial for enzyme! - auto res = xla::ConvertHloToStablehlo(hlo_module_, true); - printf("1\n"); - return reinterpret_cast(new std::string(res.value())); - } -} diff --git a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp new file mode 100644 index 000000000..8edc9c57d --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp @@ -0,0 +1,44 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/BuiltinOps.h" +#include "xla/service/hlo.pb.h" +#include "xla/hlo/translate/stablehlo.h" + + +#include "mlir/Bytecode/BytecodeWriter.h" + +#include "../../service/hlo.proto.h" +#include "../../../mlir/IR/BuiltinOps.h" +#include "../../../mlir/IR/MLIRContext.h" +#include "../../../ffi.h" + +extern "C" { + ModuleOp* ConvertHloToStablehlo(MLIRContext& ctx, HloModuleProto* hlo_module) { + auto& ctx_ = reinterpret_cast(ctx); + auto hlo_module_ = reinterpret_cast(hlo_module); + auto module_op = xla::ConvertHloToStablehlo(ctx_, hlo_module_); + return reinterpret_cast(new mlir::ModuleOp(module_op.value().release())); + } + + string* SerializeUsingBytecode(ModuleOp& module) { + auto module_ = reinterpret_cast(module); + auto bytecode = new std::string(); + llvm::raw_string_ostream os(*bytecode); + mlir::BytecodeWriterConfig config; + mlir::writeBytecodeToFile(module_, os, config); + return reinterpret_cast(bytecode); + } +} diff --git a/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD new file mode 100644 index 000000000..e7f37a3c4 --- /dev/null +++ b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "IR", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp new file mode 100644 index 000000000..eb9319d4d --- /dev/null +++ b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/mlir_hlo/mhlo/IR/register.h" + +#include "../../../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerAllMhloDialects(DialectRegistry& registry) { + mlir::mhlo::registerAllMhloDialects(reinterpret_cast(registry)); + } +} diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index 33cf3546d..dd834d0c6 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,6 +8,10 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, + Compiler.MLIR.IR.BuiltinOps, + Compiler.MLIR.IR.DialectRegistry, + Compiler.MLIR.IR.MLIRContext, + Compiler.StableHLO.Dialect.Register, Compiler.Xla.Client.ExecutableBuildOptions, Compiler.Xla.HLO.Builder.Lib.Arithmetic, Compiler.Xla.HLO.Builder.Lib.Constants, @@ -18,6 +22,8 @@ modules = Compiler.Xla.HLO.Builder.XlaComputation, Compiler.Xla.HLO.IR.HloModule, Compiler.Xla.HLO.Translate.PortableAPI, + Compiler.Xla.HLO.Translate.StableHLO, + Compiler.Xla.MLIRHLO.MHLO.IR.Register, Compiler.Xla.PJRT.C.PjrtCApi, Compiler.Xla.PJRT.PjrtExecutable, Compiler.Xla.Service.HloModuleConfig, diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 21f1c2675..4f254a957 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -26,6 +26,10 @@ import Data.List.Elem import Compiler.Expr import Compiler.FFI import Compiler.LiteralRW +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.DialectRegistry +import Compiler.MLIR.IR.MLIRContext +import Compiler.StableHLO.Dialect.Register import Compiler.Xla.Client.ExecutableBuildOptions import Compiler.Xla.HLO.Builder.Lib.Arithmetic import Compiler.Xla.HLO.Builder.Lib.Constants @@ -35,7 +39,8 @@ import Compiler.Xla.HLO.Builder.Lib.PRNG import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.HLO.Builder.XlaComputation import Compiler.Xla.HLO.IR.HloModule -import Compiler.Xla.HLO.Translate.PortableAPI +import Compiler.Xla.HLO.Translate.StableHLO +import Compiler.Xla.MLIRHLO.MHLO.IR.Register import Compiler.Xla.PJRT.C.PjrtCApi import Compiler.Xla.PJRT.PjrtExecutable import Compiler.Xla.Service.HloModuleConfig @@ -50,6 +55,8 @@ import Types import Util import Device +import System + export data Err = OutOfBounds Nat Nat @@ -226,19 +233,20 @@ export covering execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal execute (MkDevice api client) f@(MkFn _ _ env) shapes = do putStrLn "execute ..." + dialectRegistry <- mkDialectRegistry + registerAllMhloDialects dialectRegistry + registerAllDialects dialectRegistry + mlirCtx <- mkMLIRContext + appendDialectRegistry mlirCtx dialectRegistry xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f - printLn 0 - moduleConfig <- hloModuleConfig !(getProgramShape computation) printLn 1 - module' <- createFromProto !(proto computation) moduleConfig + code <- serializeUsingBytecode !(convertHloToStablehlo mlirCtx !(proto computation)) printLn 2 - code <- convertHloToStablehlo module' - printLn 3 executableBuildOptions <- mkExecutableBuildOptions - printLn 4 + printLn 3 compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) - printLn 5 + printLn 4 program <- mkPjrtProgram code bimapEitherT PjrtErr id $ do loadedExec <- pjrtClientCompile api client program compileOptions diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr new file mode 100644 index 000000000..9f06c7327 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -0,0 +1,20 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.BuiltinOps + +public export +data ModuleOp = MkModuleOp AnyPtr -- need to GC diff --git a/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr b/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr new file mode 100644 index 000000000..329b35308 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.DialectRegistry + +import Compiler.FFI + +public export +data DialectRegistry = MkDialectRegistry GCAnyPtr + +%foreign (libxla "DialectRegistry_new") +prim__mkDialectRegistry : PrimIO AnyPtr + +%foreign (libxla "DialectRegistry_delete") +prim__deleteDialectRegistry : AnyPtr -> PrimIO () + +export +mkDialectRegistry : HasIO io => io DialectRegistry +mkDialectRegistry = do + registry <- primIO prim__mkDialectRegistry + registry <- onCollectAny registry (primIO . prim__deleteDialectRegistry) + pure (MkDialectRegistry registry) diff --git a/spidr/src/Compiler/MLIR/IR/MLIRContext.idr b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr new file mode 100644 index 000000000..645a21e56 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr @@ -0,0 +1,44 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.MLIRContext + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +public export +data MLIRContext = MkMLIRContext GCAnyPtr + +%foreign (libxla "MLIRContext_new") +prim__mkMLIRContext : PrimIO AnyPtr + +%foreign (libxla "MLIRContext_delete") +prim__deleteMLIRContext : AnyPtr -> PrimIO () + +export +mkMLIRContext : HasIO io => io MLIRContext +mkMLIRContext = do + ctx <- primIO prim__mkMLIRContext + ctx <- onCollectAny ctx (primIO . prim__deleteMLIRContext) + pure (MkMLIRContext ctx) + +%foreign (libxla "MLIRContext_appendDialectRegistry") +prim__appendDialectRegistry : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +appendDialectRegistry : HasIO io => MLIRContext -> DialectRegistry -> io () +appendDialectRegistry (MkMLIRContext ctx) (MkDialectRegistry registry) = + primIO $ prim__appendDialectRegistry ctx registry diff --git a/spidr/src/Compiler/StableHLO/Dialect/Register.idr b/spidr/src/Compiler/StableHLO/Dialect/Register.idr new file mode 100644 index 000000000..e51220ac2 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Register.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Register + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerAllDialects") +prim__registerAllDialects : GCAnyPtr -> PrimIO () + +export +registerAllDialects : HasIO io => DialectRegistry -> io () +registerAllDialects (MkDialectRegistry reg) = primIO $ prim__registerAllDialects reg diff --git a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr new file mode 100644 index 000000000..1dcefc61a --- /dev/null +++ b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr @@ -0,0 +1,41 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.HLO.Translate.StableHLO + +import Compiler.FFI +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.MLIRContext +import Compiler.Xla.Service.HloProto + +%foreign (libxla "ConvertHloToStablehlo") +prim__convertHloToStablehlo : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr + +export +convertHloToStablehlo : HasIO io => MLIRContext -> HloModuleProto -> io ModuleOp +convertHloToStablehlo (MkMLIRContext ctx) (MkHloModuleProto proto) = do + moduleOp <- primIO $ prim__convertHloToStablehlo ctx proto + pure (MkModuleOp moduleOp) + +%foreign (libxla "SerializeUsingBytecode") +prim__serializeUsingBytecode : AnyPtr -> PrimIO AnyPtr + +export +serializeUsingBytecode : HasIO io => ModuleOp -> io CharArray +serializeUsingBytecode (MkModuleOp mop) = do + putStrLn "serializeUsingBytecode ..." + printLn 0 + primIO (prim__serializeUsingBytecode mop) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr b/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr new file mode 100644 index 000000000..77f82fd41 --- /dev/null +++ b/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.MLIRHLO.MHLO.IR.Register + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerAllMhloDialects") +prim__registerAllMhloDialects : GCAnyPtr -> PrimIO () + +export +registerAllMhloDialects : HasIO io => DialectRegistry -> io () +registerAllMhloDialects (MkDialectRegistry reg) = primIO $ prim__registerAllMhloDialects reg From 01c3cf21ec1c218c00dbec6b610d5151f0df1e82 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 8 Dec 2024 00:24:23 +0000 Subject: [PATCH 10/38] working on linux --- spidr/backend/BUILD | 2 - spidr/backend/src/ffi.cpp | 4 ++ spidr/backend/src/mlir/IR/BUILD | 1 + spidr/backend/src/mlir/IR/MLIRContext.cpp | 5 +- spidr/backend/src/stablehlo/dialect/BUILD | 2 + .../src/stablehlo/dialect/Serialization.cpp | 57 +++++++++++++++++++ .../src/xla/hlo/builder/xla_computation.cpp | 9 +-- spidr/backend/src/xla/hlo/translate/BUILD | 3 +- .../src/xla/hlo/translate/stablehlo.cpp | 13 ----- spidr/backend/src/xla/shape.cpp | 4 -- spidr/backend/src/xla/shape.h | 1 - spidr/spidr.ipkg | 4 +- spidr/src/Compiler/Eval.idr | 21 +++---- spidr/src/Compiler/FFI.idr | 4 ++ .../StableHLO/Dialect/Serialization.idr | 39 +++++++++++++ .../Xla/HLO/Builder/XlaComputation.idr | 10 ---- .../Compiler/Xla/HLO/Translate/StableHLO.idr | 10 ---- spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr | 4 +- .../Compiler/Xla/Service/HloModuleConfig.idr | 36 ------------ spidr/src/Compiler/Xla/Shape.idr | 11 +--- .../xla-cpu/Main.idr | 17 ++++-- test/xla-cpu/xla-cpu.ipkg | 2 +- .../xla-cuda/Main.idr | 26 +++------ test/xla-cuda/xla-cuda.ipkg | 2 +- 24 files changed, 148 insertions(+), 139 deletions(-) create mode 100644 spidr/backend/src/stablehlo/dialect/Serialization.cpp create mode 100644 spidr/src/Compiler/StableHLO/Dialect/Serialization.idr delete mode 100644 spidr/src/Compiler/Xla/Service/HloModuleConfig.idr rename spidr/backend/src/xla/service/hlo_module_config.h => test/xla-cpu/Main.idr (79%) rename spidr/backend/src/xla/service/hlo_module_config.cpp => test/xla-cuda/Main.idr (50%) diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index 760fdf40c..cff2e8398 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -18,7 +18,6 @@ cc_binary( "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", - "//src/xla/hlo/ir", "//src/xla/hlo/translate", "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", @@ -33,7 +32,6 @@ cc_binary( "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", - "//src/xla/hlo/ir", "//src/xla/hlo/translate", "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", diff --git a/spidr/backend/src/ffi.cpp b/spidr/backend/src/ffi.cpp index 2a77d13ac..f940b423e 100644 --- a/spidr/backend/src/ffi.cpp +++ b/spidr/backend/src/ffi.cpp @@ -29,6 +29,10 @@ extern "C" { return ptr == nullptr; } + string* string_new() { + return reinterpret_cast(new std::string()); + } + void string_delete(string* s) { delete reinterpret_cast(s); } diff --git a/spidr/backend/src/mlir/IR/BUILD b/spidr/backend/src/mlir/IR/BUILD index 42fb4a6db..f034b361b 100644 --- a/spidr/backend/src/mlir/IR/BUILD +++ b/spidr/backend/src/mlir/IR/BUILD @@ -6,6 +6,7 @@ cc_library( hdrs = glob(["*.h"]), deps = [ "@llvm-project//mlir:IR", + "//src", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/mlir/IR/MLIRContext.cpp b/spidr/backend/src/mlir/IR/MLIRContext.cpp index f806806ef..9083b03e7 100644 --- a/spidr/backend/src/mlir/IR/MLIRContext.cpp +++ b/spidr/backend/src/mlir/IR/MLIRContext.cpp @@ -20,10 +20,7 @@ limitations under the License. extern "C" { MLIRContext* MLIRContext_new() { - printf("MLIRContext_new ...\n"); - auto res = reinterpret_cast(new mlir::MLIRContext); - printf("0\n"); - return res; + return reinterpret_cast(new mlir::MLIRContext); } void MLIRContext_delete(MLIRContext* s) { diff --git a/spidr/backend/src/stablehlo/dialect/BUILD b/spidr/backend/src/stablehlo/dialect/BUILD index ab4578b99..8cfce51e4 100644 --- a/spidr/backend/src/stablehlo/dialect/BUILD +++ b/spidr/backend/src/stablehlo/dialect/BUILD @@ -5,7 +5,9 @@ cc_library( srcs = glob(["*.cpp"]), hdrs = glob(["*.h"]), deps = [ + "@llvm-project//mlir:IR", "@stablehlo//:register", + "@stablehlo//:stablehlo_serialization", "//src/mlir/IR", ], visibility = ["//visibility:public"], diff --git a/spidr/backend/src/stablehlo/dialect/Serialization.cpp b/spidr/backend/src/stablehlo/dialect/Serialization.cpp new file mode 100644 index 000000000..87c066038 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Serialization.cpp @@ -0,0 +1,57 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/Bytecode/BytecodeWriter.h" +#include "stablehlo/dialect/Serialization.h" +#include "stablehlo/dialect/Version.h" + +#include "../../mlir/IR/BuiltinOps.h" +#include "../../ffi.h" + +extern "C" { + int serializePortableArtifact(ModuleOp& module, string& str) { + auto& module_ = reinterpret_cast(module); + auto& str_ = reinterpret_cast(str); + +// std::string s; +// llvm::raw_string_ostream os0(s); +// module_.print(os0); +// printf("serializePortableArtifact ...\n"); +// printf("... debug print:\n"); +// printf("%s\n", s.c_str()); + + llvm::raw_string_ostream os(str_); +// if (mlir::writeBytecodeToFile(module_, os).failed()) { +// return (int) false; +// } + +// printf("... serialization:\n"); +// printf("%s\n", str_.c_str()); + auto version = mlir::vhlo::Version::getMinimumVersion().toString(); + auto result = mlir::stablehlo::serializePortableArtifact(module_, version, os); + return (int) result.succeeded(); + } + + string* printModule(ModuleOp& module) { + auto& module_ = reinterpret_cast(module); + auto str = new std::string(); + llvm::raw_string_ostream os(*str); + module_.print(os); + + printf("... debug print:\n"); + printf("%s\n", str->c_str()); + return reinterpret_cast(str); + } +} diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index f6896b9bc..8695e74ee 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -26,13 +26,8 @@ extern "C" { delete reinterpret_cast(s); } - ProgramShape* XlaComputation_GetProgramShape(XlaComputation* s) { - auto res = reinterpret_cast(s)->GetProgramShape(); - return reinterpret_cast(new xla::ProgramShape(*res)); - } - HloModuleProto* XlaComputation_proto(XlaComputation* s) { - auto res = reinterpret_cast(s)->proto(); - return reinterpret_cast(new xla::HloModuleProto(res)); + auto s_ = reinterpret_cast(s); + return reinterpret_cast(new xla::HloModuleProto(s_->proto())); } } diff --git a/spidr/backend/src/xla/hlo/translate/BUILD b/spidr/backend/src/xla/hlo/translate/BUILD index a538cca98..75212dc84 100644 --- a/spidr/backend/src/xla/hlo/translate/BUILD +++ b/spidr/backend/src/xla/hlo/translate/BUILD @@ -7,8 +7,7 @@ cc_library( deps = [ "@xla//xla/hlo/translate:stablehlo", "//src/mlir/IR", - "//src/xla/hlo/ir", - "//src", + "//src/xla/service", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp index 8edc9c57d..e358aeca7 100644 --- a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp +++ b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp @@ -17,13 +17,9 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/hlo/translate/stablehlo.h" - -#include "mlir/Bytecode/BytecodeWriter.h" - #include "../../service/hlo.proto.h" #include "../../../mlir/IR/BuiltinOps.h" #include "../../../mlir/IR/MLIRContext.h" -#include "../../../ffi.h" extern "C" { ModuleOp* ConvertHloToStablehlo(MLIRContext& ctx, HloModuleProto* hlo_module) { @@ -32,13 +28,4 @@ extern "C" { auto module_op = xla::ConvertHloToStablehlo(ctx_, hlo_module_); return reinterpret_cast(new mlir::ModuleOp(module_op.value().release())); } - - string* SerializeUsingBytecode(ModuleOp& module) { - auto module_ = reinterpret_cast(module); - auto bytecode = new std::string(); - llvm::raw_string_ostream os(*bytecode); - mlir::BytecodeWriterConfig config; - mlir::writeBytecodeToFile(module_, os, config); - return reinterpret_cast(bytecode); - } } diff --git a/spidr/backend/src/xla/shape.cpp b/spidr/backend/src/xla/shape.cpp index 2bed88cde..b85223928 100644 --- a/spidr/backend/src/xla/shape.cpp +++ b/spidr/backend/src/xla/shape.cpp @@ -29,8 +29,4 @@ extern "C" { void set_array_Shape(Shape* arr, int idx, Shape* shape) { reinterpret_cast(arr)[idx] = *reinterpret_cast(shape); } - - void ProgramShape_delete(ProgramShape* s) { - delete reinterpret_cast(s); - } } diff --git a/spidr/backend/src/xla/shape.h b/spidr/backend/src/xla/shape.h index a05f9e411..27da41111 100644 --- a/spidr/backend/src/xla/shape.h +++ b/spidr/backend/src/xla/shape.h @@ -15,5 +15,4 @@ limitations under the License. */ extern "C" { struct Shape; - struct ProgramShape; } diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index dd834d0c6..a0dfb6bce 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -12,6 +12,7 @@ modules = Compiler.MLIR.IR.DialectRegistry, Compiler.MLIR.IR.MLIRContext, Compiler.StableHLO.Dialect.Register, + Compiler.StableHLO.Dialect.Serialization, Compiler.Xla.Client.ExecutableBuildOptions, Compiler.Xla.HLO.Builder.Lib.Arithmetic, Compiler.Xla.HLO.Builder.Lib.Constants, @@ -20,13 +21,10 @@ modules = Compiler.Xla.HLO.Builder.Lib.PRNG, Compiler.Xla.HLO.Builder.XlaBuilder, Compiler.Xla.HLO.Builder.XlaComputation, - Compiler.Xla.HLO.IR.HloModule, - Compiler.Xla.HLO.Translate.PortableAPI, Compiler.Xla.HLO.Translate.StableHLO, Compiler.Xla.MLIRHLO.MHLO.IR.Register, Compiler.Xla.PJRT.C.PjrtCApi, Compiler.Xla.PJRT.PjrtExecutable, - Compiler.Xla.Service.HloModuleConfig, Compiler.Xla.Service.HloProto, Compiler.Xla.Literal, Compiler.Xla.Shape, diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 4f254a957..e719cd255 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -30,6 +30,7 @@ import Compiler.MLIR.IR.BuiltinOps import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.IR.MLIRContext import Compiler.StableHLO.Dialect.Register +import Compiler.StableHLO.Dialect.Serialization import Compiler.Xla.Client.ExecutableBuildOptions import Compiler.Xla.HLO.Builder.Lib.Arithmetic import Compiler.Xla.HLO.Builder.Lib.Constants @@ -38,12 +39,10 @@ import Compiler.Xla.HLO.Builder.Lib.Matrix import Compiler.Xla.HLO.Builder.Lib.PRNG import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.HLO.Builder.XlaComputation -import Compiler.Xla.HLO.IR.HloModule import Compiler.Xla.HLO.Translate.StableHLO import Compiler.Xla.MLIRHLO.MHLO.IR.Register import Compiler.Xla.PJRT.C.PjrtCApi import Compiler.Xla.PJRT.PjrtExecutable -import Compiler.Xla.Service.HloModuleConfig import Compiler.Xla.Service.HloProto import Compiler.Xla.Literal import Compiler.Xla.Shape @@ -62,12 +61,14 @@ data Err = OutOfBounds Nat Nat | ValueNotFound Nat | PjrtErr PjrtError + | SerializationError String export Show Err where show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}" show (ValueNotFound idx) = "Value not found at index \{show idx}" - show (PjrtErr err)= show err + show (PjrtErr err) = show err + show (SerializationError err) = "SerializationError: \{err}" public export 0 ErrIO : Type -> Type @@ -232,25 +233,21 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do export covering execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal execute (MkDevice api client) f@(MkFn _ _ env) shapes = do - putStrLn "execute ..." + xlaBuilder <- mkXlaBuilder "root" + computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f dialectRegistry <- mkDialectRegistry registerAllMhloDialects dialectRegistry registerAllDialects dialectRegistry mlirCtx <- mkMLIRContext + stablehlo <- convertHloToStablehlo mlirCtx !(proto computation) appendDialectRegistry mlirCtx dialectRegistry - xlaBuilder <- mkXlaBuilder "root" - computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f - printLn 1 - code <- serializeUsingBytecode !(convertHloToStablehlo mlirCtx !(proto computation)) - printLn 2 + Just code <- serializePortableArtifact stablehlo | Nothing => throwE (SerializationError "Failed to serialize StableHLO") + -- code <- printModule stablehlo executableBuildOptions <- mkExecutableBuildOptions - printLn 3 compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) - printLn 4 program <- mkPjrtProgram code bimapEitherT PjrtErr id $ do loadedExec <- pjrtClientCompile api client program compileOptions - printLn 6 free code free compileOptions delete executableBuildOptions diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index a2f24eab1..307b9574b 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -31,6 +31,10 @@ namespace CharArray free : HasIO io => CharArray -> io () free (MkCharArray arr _) = free $ prim__forgetPtr arr +export +%foreign (libxla "string_new") +prim__stringNew : PrimIO AnyPtr + export %foreign (libxla "string_delete") prim__stringDelete : AnyPtr -> PrimIO () diff --git a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr new file mode 100644 index 000000000..4b8d55e02 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr @@ -0,0 +1,39 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Serialization + +import Compiler.MLIR.IR.BuiltinOps +import Compiler.FFI + +%foreign (libxla "serializePortableArtifact") +prim__serializePortableArtifact : AnyPtr -> AnyPtr -> PrimIO Int + +export +serializePortableArtifact : HasIO io => ModuleOp -> io (Maybe CharArray) +serializePortableArtifact (MkModuleOp moduleOp) = do + str <- primIO prim__stringNew + ok <- primIO $ prim__serializePortableArtifact moduleOp str + case cIntToBool ok of + True => Just <$> stringToCharArray str + False => free str >> pure Nothing + +%foreign (libxla "printModule") +prim__printModule : AnyPtr -> PrimIO AnyPtr + +export +printModule : HasIO io => ModuleOp -> io CharArray +printModule (MkModuleOp moduleOp) = primIO (prim__printModule moduleOp) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index b579be951..a9a4c455f 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -31,16 +31,6 @@ export delete : AnyPtr -> IO () delete = primIO . XlaComputation.prim__delete -%foreign (libxla "XlaComputation_GetProgramShape") -prim__xlaComputationGetProgramShape : GCAnyPtr -> PrimIO AnyPtr - -export -getProgramShape : HasIO io => XlaComputation -> io ProgramShape -getProgramShape (MkXlaComputation comp) = do - pshape <- primIO $ prim__xlaComputationGetProgramShape comp - pshape <- onCollectAny pshape (primIO . prim__ProgramShape_delete) - pure (MkProgramShape pshape) - %foreign (libxla "XlaComputation_proto") prim__xlaComputationProto : GCAnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr index 1dcefc61a..b58bbe055 100644 --- a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr +++ b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr @@ -29,13 +29,3 @@ convertHloToStablehlo : HasIO io => MLIRContext -> HloModuleProto -> io ModuleOp convertHloToStablehlo (MkMLIRContext ctx) (MkHloModuleProto proto) = do moduleOp <- primIO $ prim__convertHloToStablehlo ctx proto pure (MkModuleOp moduleOp) - -%foreign (libxla "SerializeUsingBytecode") -prim__serializeUsingBytecode : AnyPtr -> PrimIO AnyPtr - -export -serializeUsingBytecode : HasIO io => ModuleOp -> io CharArray -serializeUsingBytecode (MkModuleOp mop) = do - putStrLn "serializeUsingBytecode ..." - printLn 0 - primIO (prim__serializeUsingBytecode mop) >>= stringToCharArray diff --git a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr index 734f9fa16..24f0b7f9f 100644 --- a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr +++ b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr @@ -73,9 +73,9 @@ export Show PjrtError where show e = let code = case e.code of - Nothing => "not found" + Nothing => "unknown" Just c => show c - in "PjrtError \{show e.message} (code \{code})" + in "PjrtError (error code \{code})\n\{e.message}" %foreign (libxla "PJRT_Error_Destroy_Args_new") prim__mkPjrtErrorDestroyArgs : AnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr b/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr deleted file mode 100644 index 15899d502..000000000 --- a/spidr/src/Compiler/Xla/Service/HloModuleConfig.idr +++ /dev/null @@ -1,36 +0,0 @@ -{-- -Copyright 2024 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. ---} -||| For internal spidr use only. -module Compiler.Xla.Service.HloModuleConfig - -import Compiler.FFI -import Compiler.Xla.Shape - -public export -data HloModuleConfig = MkHloModuleConfig GCAnyPtr - -%foreign (libxla "HloModuleConfig_new") -prim__hloModuleConfig : GCAnyPtr -> PrimIO AnyPtr - -%foreign (libxla "HloModuleConfig_delete") -prim__delete : AnyPtr -> PrimIO () - -export -hloModuleConfig : HasIO io => ProgramShape -> io HloModuleConfig -hloModuleConfig (MkProgramShape pshape) = do - config <- primIO $ prim__hloModuleConfig pshape - config <- onCollectAny config (primIO . prim__delete) - pure (MkHloModuleConfig config) diff --git a/spidr/src/Compiler/Xla/Shape.idr b/spidr/src/Compiler/Xla/Shape.idr index dee238c75..95d2156a9 100644 --- a/spidr/src/Compiler/Xla/Shape.idr +++ b/spidr/src/Compiler/Xla/Shape.idr @@ -25,11 +25,11 @@ namespace Xla MkShape : GCAnyPtr -> Shape %foreign (libxla "Shape_delete") -prim__Shape_delete : AnyPtr -> PrimIO () +prim__delete : AnyPtr -> PrimIO () export delete : AnyPtr -> IO () -delete = primIO . prim__Shape_delete +delete = primIO . prim__delete %foreign (libxla "sizeof_Shape") sizeOfShape : Int @@ -48,10 +48,3 @@ mkShapeArray shapes = do primIO $ prim__setArrayShape arr (cast idx) shape) (enumerate (fromList shapes)) arr <- onCollectAny arr free pure (MkShapeArray arr) - -public export -data ProgramShape = MkProgramShape GCAnyPtr - -export -%foreign (libxla "ProgramShape_delete") -prim__ProgramShape_delete : AnyPtr -> PrimIO () diff --git a/spidr/backend/src/xla/service/hlo_module_config.h b/test/xla-cpu/Main.idr similarity index 79% rename from spidr/backend/src/xla/service/hlo_module_config.h rename to test/xla-cpu/Main.idr index dff48c9ca..854d1eae4 100644 --- a/spidr/backend/src/xla/service/hlo_module_config.h +++ b/test/xla-cpu/Main.idr @@ -1,4 +1,4 @@ -/* +{-- Copyright 2024 Joel Berkeley Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,7 +12,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -*/ -extern "C" { - struct HloModuleConfig; -} +--} +module Main + +import System + +import TestRunner +import PjrtPluginXlaCpu + +partial +main : IO () +main = eitherT (die . show) run device diff --git a/test/xla-cpu/xla-cpu.ipkg b/test/xla-cpu/xla-cpu.ipkg index 24255b025..39fd35065 100644 --- a/test/xla-cpu/xla-cpu.ipkg +++ b/test/xla-cpu/xla-cpu.ipkg @@ -5,4 +5,4 @@ depends = runner executable = test -main = XlaCpu +main = Main diff --git a/spidr/backend/src/xla/service/hlo_module_config.cpp b/test/xla-cuda/Main.idr similarity index 50% rename from spidr/backend/src/xla/service/hlo_module_config.cpp rename to test/xla-cuda/Main.idr index 59478448a..4a727f497 100644 --- a/spidr/backend/src/xla/service/hlo_module_config.cpp +++ b/test/xla-cuda/Main.idr @@ -1,4 +1,4 @@ -/* +{-- Copyright 2024 Joel Berkeley Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,22 +12,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -*/ -#include "xla/service/hlo_module_config.h" -#include "xla/shape.h" +--} +module Main -#include "hlo_module_config.h" +import System -#include "../shape.h" +import TestRunner +import PjrtPluginXlaCuda -extern "C" { - HloModuleConfig* HloModuleConfig_new(ProgramShape& program_shape) { - auto& program_shape_ = reinterpret_cast(program_shape); - auto config = new xla::HloModuleConfig(program_shape_); - return reinterpret_cast(config); - } - - void HloModuleConfig_delete(HloModuleConfig* s) { - delete reinterpret_cast(s); - } -} +partial +main : IO () +main = eitherT (die . show) run device diff --git a/test/xla-cuda/xla-cuda.ipkg b/test/xla-cuda/xla-cuda.ipkg index 66c3f269b..9d76e1994 100644 --- a/test/xla-cuda/xla-cuda.ipkg +++ b/test/xla-cuda/xla-cuda.ipkg @@ -5,4 +5,4 @@ depends = runner executable = test -main = XlaCuda +main = Main From de00a64b2603e6d34b247c768119051e67a2a3f3 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 8 Dec 2024 01:14:40 +0000 Subject: [PATCH 11/38] wip --- .../src/stablehlo/dialect/Serialization.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/spidr/backend/src/stablehlo/dialect/Serialization.cpp b/spidr/backend/src/stablehlo/dialect/Serialization.cpp index 87c066038..dbac5e90c 100644 --- a/spidr/backend/src/stablehlo/dialect/Serialization.cpp +++ b/spidr/backend/src/stablehlo/dialect/Serialization.cpp @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "mlir/Bytecode/BytecodeWriter.h" #include "stablehlo/dialect/Serialization.h" #include "stablehlo/dialect/Version.h" @@ -24,21 +23,7 @@ extern "C" { int serializePortableArtifact(ModuleOp& module, string& str) { auto& module_ = reinterpret_cast(module); auto& str_ = reinterpret_cast(str); - -// std::string s; -// llvm::raw_string_ostream os0(s); -// module_.print(os0); -// printf("serializePortableArtifact ...\n"); -// printf("... debug print:\n"); -// printf("%s\n", s.c_str()); - llvm::raw_string_ostream os(str_); -// if (mlir::writeBytecodeToFile(module_, os).failed()) { -// return (int) false; -// } - -// printf("... serialization:\n"); -// printf("%s\n", str_.c_str()); auto version = mlir::vhlo::Version::getMinimumVersion().toString(); auto result = mlir::stablehlo::serializePortableArtifact(module_, version, os); return (int) result.succeeded(); @@ -49,9 +34,6 @@ extern "C" { auto str = new std::string(); llvm::raw_string_ostream os(*str); module_.print(os); - - printf("... debug print:\n"); - printf("%s\n", str->c_str()); return reinterpret_cast(str); } } From 6cea6ba34439d646ede0ec0d2998de8f7ea39d1e Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 8 Dec 2024 01:16:35 +0000 Subject: [PATCH 12/38] wip --- spidr/backend/src/stablehlo/dialect/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/spidr/backend/src/stablehlo/dialect/BUILD b/spidr/backend/src/stablehlo/dialect/BUILD index 8cfce51e4..91d6dd3ed 100644 --- a/spidr/backend/src/stablehlo/dialect/BUILD +++ b/spidr/backend/src/stablehlo/dialect/BUILD @@ -5,7 +5,6 @@ cc_library( srcs = glob(["*.cpp"]), hdrs = glob(["*.h"]), deps = [ - "@llvm-project//mlir:IR", "@stablehlo//:register", "@stablehlo//:stablehlo_serialization", "//src/mlir/IR", From 4f66569e9668793d6126b8867961aed3fa3cb038 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 8 Dec 2024 02:33:44 +0000 Subject: [PATCH 13/38] tan --- spidr/src/Tensor.idr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spidr/src/Tensor.idr b/spidr/src/Tensor.idr index 24272644c..8d97cd785 100644 --- a/spidr/src/Tensor.idr +++ b/spidr/src/Tensor.idr @@ -1336,7 +1336,7 @@ cos = unary Cos ||| The element-wise tangent. export tan : Tensor shape F64 -> Tensor shape F64 -tan = unary Tan +tan x = sin x / cos x ||| The element-wise inverse sine. export From cf0d80f4f477cde6486db0df80b53cbeb70918c7 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Fri, 13 Dec 2024 22:08:48 +0000 Subject: [PATCH 14/38] wip --- .../src/stablehlo/dialect/Serialization.cpp | 15 ++----- .../backend/src/stablehlo/dialect/Version.cpp | 36 +++++++++++++++ spidr/spidr.ipkg | 1 + spidr/src/Compiler/Eval.idr | 22 +++++---- spidr/src/Compiler/FFI.idr | 3 ++ .../StableHLO/Dialect/Serialization.idr | 13 ++---- .../Compiler/StableHLO/Dialect/Version.idr | 45 +++++++++++++++++++ 7 files changed, 105 insertions(+), 30 deletions(-) create mode 100644 spidr/backend/src/stablehlo/dialect/Version.cpp create mode 100644 spidr/src/Compiler/StableHLO/Dialect/Version.idr diff --git a/spidr/backend/src/stablehlo/dialect/Serialization.cpp b/spidr/backend/src/stablehlo/dialect/Serialization.cpp index dbac5e90c..5d334466b 100644 --- a/spidr/backend/src/stablehlo/dialect/Serialization.cpp +++ b/spidr/backend/src/stablehlo/dialect/Serialization.cpp @@ -14,26 +14,17 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "stablehlo/dialect/Serialization.h" -#include "stablehlo/dialect/Version.h" #include "../../mlir/IR/BuiltinOps.h" #include "../../ffi.h" extern "C" { - int serializePortableArtifact(ModuleOp& module, string& str) { + int serializePortableArtifact(ModuleOp& module, string& version, string& str) { auto& module_ = reinterpret_cast(module); + auto& version_ = reinterpret_cast(str); auto& str_ = reinterpret_cast(str); llvm::raw_string_ostream os(str_); - auto version = mlir::vhlo::Version::getMinimumVersion().toString(); - auto result = mlir::stablehlo::serializePortableArtifact(module_, version, os); + auto result = mlir::stablehlo::serializePortableArtifact(module_, version_, os); return (int) result.succeeded(); } - - string* printModule(ModuleOp& module) { - auto& module_ = reinterpret_cast(module); - auto str = new std::string(); - llvm::raw_string_ostream os(*str); - module_.print(os); - return reinterpret_cast(str); - } } diff --git a/spidr/backend/src/stablehlo/dialect/Version.cpp b/spidr/backend/src/stablehlo/dialect/Version.cpp new file mode 100644 index 000000000..c402990c2 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Version.cpp @@ -0,0 +1,36 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "stablehlo/dialect/Version.h" + +#include "../../ffi.h" + +extern "C" { + struct Version; + + void Version_delete(Version* s) { + delete reinterpret_cast(s); + } + + Version* Version_getMinimumVersion() { + auto version = mlir::vhlo::Version::getMinimumVersion(); + return reinterpret_cast(new mlir::vhlo::Version(version)); + } + + string* Version_toString(Version& s) { + auto& s_ = reinterpret_cast(s); + return reinterpret_cast(new std::string(s_.toString())); + } +} diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index a0dfb6bce..b83273331 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -13,6 +13,7 @@ modules = Compiler.MLIR.IR.MLIRContext, Compiler.StableHLO.Dialect.Register, Compiler.StableHLO.Dialect.Serialization, + Compiler.StableHLO.Dialect.Version, Compiler.Xla.Client.ExecutableBuildOptions, Compiler.Xla.HLO.Builder.Lib.Arithmetic, Compiler.Xla.HLO.Builder.Lib.Constants, diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index e719cd255..133fed6af 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -31,6 +31,7 @@ import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.IR.MLIRContext import Compiler.StableHLO.Dialect.Register import Compiler.StableHLO.Dialect.Serialization +import Compiler.StableHLO.Dialect.Version import Compiler.Xla.Client.ExecutableBuildOptions import Compiler.Xla.HLO.Builder.Lib.Arithmetic import Compiler.Xla.HLO.Builder.Lib.Constants @@ -229,20 +230,25 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do !(interpretE key) !(interpretE initialState) ThreeFry !(mkShape {dtype = F64} shape) tuple xlaBuilder [value rngOutput, state rngOutput] +hloModuleProtoToStableHLO : HloModuleProto -> ErrIO CharArray +hloModuleProtoToStableHLO proto = do + dialectRegistry <- mkDialectRegistry + registerAllMhloDialects dialectRegistry + registerAllDialects dialectRegistry + mlirCtx <- mkMLIRContext + stablehlo <- convertHloToStablehlo mlirCtx proto + appendDialectRegistry mlirCtx dialectRegistry + Just code <- serializePortableArtifact stablehlo !(toString !getMinimumVersion) + | Nothing => throwE (SerializationError "Failed to serialize StableHLO") + pure code + ||| It is up to the caller to free the `Literal`s. export covering execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal execute (MkDevice api client) f@(MkFn _ _ env) shapes = do xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f - dialectRegistry <- mkDialectRegistry - registerAllMhloDialects dialectRegistry - registerAllDialects dialectRegistry - mlirCtx <- mkMLIRContext - stablehlo <- convertHloToStablehlo mlirCtx !(proto computation) - appendDialectRegistry mlirCtx dialectRegistry - Just code <- serializePortableArtifact stablehlo | Nothing => throwE (SerializationError "Failed to serialize StableHLO") - -- code <- printModule stablehlo + code <- hloModuleProtoToStableHLO !(proto computation) executableBuildOptions <- mkExecutableBuildOptions compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) program <- mkPjrtProgram code diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index 307b9574b..27e23c05c 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -26,6 +26,9 @@ libxla fname = "C:" ++ fname ++ ",libc_xla" public export data CharArray = MkCharArray (Ptr Char) Bits64 +public export +data CppString = MkCppString GCAnyPtr + namespace CharArray export free : HasIO io => CharArray -> io () diff --git a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr index 4b8d55e02..533497805 100644 --- a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr +++ b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr @@ -23,17 +23,10 @@ import Compiler.FFI prim__serializePortableArtifact : AnyPtr -> AnyPtr -> PrimIO Int export -serializePortableArtifact : HasIO io => ModuleOp -> io (Maybe CharArray) -serializePortableArtifact (MkModuleOp moduleOp) = do +serializePortableArtifact : HasIO io => ModuleOp -> CppString -> io (Maybe CharArray) +serializePortableArtifact (MkModuleOp moduleOp) (MkCppString version) = do str <- primIO prim__stringNew - ok <- primIO $ prim__serializePortableArtifact moduleOp str + ok <- primIO $ prim__serializePortableArtifact moduleOp version str case cIntToBool ok of True => Just <$> stringToCharArray str False => free str >> pure Nothing - -%foreign (libxla "printModule") -prim__printModule : AnyPtr -> PrimIO AnyPtr - -export -printModule : HasIO io => ModuleOp -> io CharArray -printModule (MkModuleOp moduleOp) = primIO (prim__printModule moduleOp) >>= stringToCharArray diff --git a/spidr/src/Compiler/StableHLO/Dialect/Version.idr b/spidr/src/Compiler/StableHLO/Dialect/Version.idr new file mode 100644 index 000000000..0d8d4d531 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Version.idr @@ -0,0 +1,45 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Version + +import Compiler.FFI + +export +data Version = MkVersion GCAnyPtr + +%foreign (libxla "Version_delete") +prim__delete : AnyPtr -> PrimIO () + +%foreign (libxla "Version_getMinimumVersion") +prim__versionGetMinimumVersion : PrimIO AnyPtr + +export +getMinimumVersion : HasIO io => io Version +getMinimumVersion = do + version <- primIO prim__versionGetMinimumVersion + version <- onCollectAny version (primIO . prim__delete) + pure (MkVersion version) + +%foreign (libxla "Version_toString") +prim__versionToString : GCAnyPtr -> PrimIO AnyPtr + +export +toString : HasIO io => Version -> io CppString +toString (MkVersion version) = do + str <- primIO $ prim__versionToString version + str <- onCollectAny str (primIO . prim__stringDelete) + pure (MkCppString str) From ed838f8f791dcd80719b658ef44df3976bef4235 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Fri, 13 Dec 2024 22:45:16 +0000 Subject: [PATCH 15/38] llvm --- spidr/backend/BUILD | 2 ++ spidr/backend/src/llvm/Support/BUILD | 12 +++++++ .../backend/src/llvm/Support/raw_ostream.cpp | 32 +++++++++++++++++ spidr/backend/src/llvm/Support/raw_ostream.h | 18 ++++++++++ spidr/backend/src/stablehlo/dialect/BUILD | 1 + .../src/stablehlo/dialect/Serialization.cpp | 10 +++--- spidr/src/Compiler/Eval.idr | 8 +++-- spidr/src/Compiler/FFI.idr | 18 +++++++--- .../src/Compiler/LLVM/Support/RawOStream.idr | 35 +++++++++++++++++++ .../StableHLO/Dialect/Serialization.idr | 14 ++++---- .../Compiler/StableHLO/Dialect/Version.idr | 2 +- .../src/Compiler/Xla/PJRT/PjrtExecutable.idr | 5 +-- spidr/src/Compiler/Xla/Service/HloProto.idr | 9 ----- 13 files changed, 134 insertions(+), 32 deletions(-) create mode 100644 spidr/backend/src/llvm/Support/BUILD create mode 100644 spidr/backend/src/llvm/Support/raw_ostream.cpp create mode 100644 spidr/backend/src/llvm/Support/raw_ostream.h create mode 100644 spidr/src/Compiler/LLVM/Support/RawOStream.idr diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index cff2e8398..34a2dee6a 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,6 +12,7 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ + "//src/llvm/Support", "//src/mlir/IR", "//src/stablehlo/dialect", "//src/xla", @@ -26,6 +27,7 @@ cc_binary( "//src", ], deps = [ + "//src/llvm/Support", "//src/mlir/IR", "//src/stablehlo/dialect", "//src/xla", diff --git a/spidr/backend/src/llvm/Support/BUILD b/spidr/backend/src/llvm/Support/BUILD new file mode 100644 index 000000000..12ee6525d --- /dev/null +++ b/spidr/backend/src/llvm/Support/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Support", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//llvm:Support", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/llvm/Support/raw_ostream.cpp b/spidr/backend/src/llvm/Support/raw_ostream.cpp new file mode 100644 index 000000000..bb30c8b06 --- /dev/null +++ b/spidr/backend/src/llvm/Support/raw_ostream.cpp @@ -0,0 +1,32 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "llvm/Support/raw_ostream.h" + +#include "../../ffi.h" +#include "raw_ostream.h" + +extern "C" { + struct raw_string_ostream; + + raw_string_ostream* raw_string_ostream_new(string& o) { + auto& o_ = reinterpret_cast(o); + return reinterpret_cast(new llvm::raw_string_ostream(o_)); + } + + void raw_string_ostream_delete(raw_string_ostream* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/llvm/Support/raw_ostream.h b/spidr/backend/src/llvm/Support/raw_ostream.h new file mode 100644 index 000000000..09f078918 --- /dev/null +++ b/spidr/backend/src/llvm/Support/raw_ostream.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct raw_ostream; +} diff --git a/spidr/backend/src/stablehlo/dialect/BUILD b/spidr/backend/src/stablehlo/dialect/BUILD index 91d6dd3ed..5f76ca13f 100644 --- a/spidr/backend/src/stablehlo/dialect/BUILD +++ b/spidr/backend/src/stablehlo/dialect/BUILD @@ -7,6 +7,7 @@ cc_library( deps = [ "@stablehlo//:register", "@stablehlo//:stablehlo_serialization", + "//src/llvm/Support", "//src/mlir/IR", ], visibility = ["//visibility:public"], diff --git a/spidr/backend/src/stablehlo/dialect/Serialization.cpp b/spidr/backend/src/stablehlo/dialect/Serialization.cpp index 5d334466b..8ba423ca7 100644 --- a/spidr/backend/src/stablehlo/dialect/Serialization.cpp +++ b/spidr/backend/src/stablehlo/dialect/Serialization.cpp @@ -16,15 +16,15 @@ limitations under the License. #include "stablehlo/dialect/Serialization.h" #include "../../mlir/IR/BuiltinOps.h" +#include "../../llvm/Support/raw_ostream.h" #include "../../ffi.h" extern "C" { - int serializePortableArtifact(ModuleOp& module, string& version, string& str) { + int serializePortableArtifact(ModuleOp& module, string& version, raw_ostream& os) { auto& module_ = reinterpret_cast(module); - auto& version_ = reinterpret_cast(str); - auto& str_ = reinterpret_cast(str); - llvm::raw_string_ostream os(str_); - auto result = mlir::stablehlo::serializePortableArtifact(module_, version_, os); + auto& version_ = reinterpret_cast(version); + auto& os_ = reinterpret_cast(os); + auto result = mlir::stablehlo::serializePortableArtifact(module_, version_, os_); return (int) result.succeeded(); } } diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 133fed6af..de4a03d48 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -26,6 +26,7 @@ import Data.List.Elem import Compiler.Expr import Compiler.FFI import Compiler.LiteralRW +import Compiler.LLVM.Support.RawOStream import Compiler.MLIR.IR.BuiltinOps import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.IR.MLIRContext @@ -238,9 +239,10 @@ hloModuleProtoToStableHLO proto = do mlirCtx <- mkMLIRContext stablehlo <- convertHloToStablehlo mlirCtx proto appendDialectRegistry mlirCtx dialectRegistry - Just code <- serializePortableArtifact stablehlo !(toString !getMinimumVersion) - | Nothing => throwE (SerializationError "Failed to serialize StableHLO") - pure code + code <- cppString + version <- toString !getMinimumVersion + ok <- serializePortableArtifact stablehlo version !(rawStringOStream code) + if ok then stringToCharArray code else throwE (SerializationError "Failed to serialize StableHLO") ||| It is up to the caller to free the `Literal`s. export covering diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index 27e23c05c..ea4bfd34f 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -27,7 +27,7 @@ public export data CharArray = MkCharArray (Ptr Char) Bits64 public export -data CppString = MkCppString GCAnyPtr +data CppString = MkCppString AnyPtr namespace CharArray export @@ -36,12 +36,22 @@ namespace CharArray export %foreign (libxla "string_new") -prim__stringNew : PrimIO AnyPtr +prim__mkString : PrimIO AnyPtr + +||| It is up to the caller to `delete` the string. +export +cppString : HasIO io => io CppString +cppString = MkCppString <$> primIO prim__mkString export %foreign (libxla "string_delete") prim__stringDelete : AnyPtr -> PrimIO () +namespace CppString + export + delete : HasIO io => CppString -> io () + delete (MkCppString str) = primIO $ prim__stringDelete str + export %foreign (libxla "string_data") prim__stringData : AnyPtr -> PrimIO $ Ptr Char @@ -56,8 +66,8 @@ prim__index : Int -> AnyPtr -> AnyPtr ||| Deletes the `string`. It is up to the caller to `free` the `CharArray`. export -stringToCharArray : HasIO io => AnyPtr -> io CharArray -stringToCharArray str = do +stringToCharArray : HasIO io => CppString -> io CharArray +stringToCharArray (MkCppString str) = do data' <- primIO $ prim__stringData str let size = prim__stringSize str primIO $ prim__stringDelete str diff --git a/spidr/src/Compiler/LLVM/Support/RawOStream.idr b/spidr/src/Compiler/LLVM/Support/RawOStream.idr new file mode 100644 index 000000000..f8b11e32a --- /dev/null +++ b/spidr/src/Compiler/LLVM/Support/RawOStream.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.LLVM.Support.RawOStream + +import Compiler.FFI + +public export +data RawStringOStream = MkRawStringOStream GCAnyPtr + +%foreign (libxla "raw_string_ostream_new") +prim__mkRawStringOStream : AnyPtr -> PrimIO AnyPtr + +%foreign (libxla "raw_string_ostream_delete") +prim__delete : AnyPtr -> PrimIO () + +export +rawStringOStream : HasIO io => CppString -> io RawStringOStream +rawStringOStream (MkCppString str) = do + os <- primIO $ prim__mkRawStringOStream str + os <- onCollectAny os (primIO . prim__delete) + pure (MkRawStringOStream os) diff --git a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr index 533497805..5e493981d 100644 --- a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr +++ b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr @@ -16,17 +16,15 @@ limitations under the License. ||| For internal spidr use only. module Compiler.StableHLO.Dialect.Serialization +import Compiler.LLVM.Support.RawOStream import Compiler.MLIR.IR.BuiltinOps import Compiler.FFI %foreign (libxla "serializePortableArtifact") -prim__serializePortableArtifact : AnyPtr -> AnyPtr -> PrimIO Int +prim__serializePortableArtifact : AnyPtr -> AnyPtr -> GCAnyPtr -> PrimIO Int export -serializePortableArtifact : HasIO io => ModuleOp -> CppString -> io (Maybe CharArray) -serializePortableArtifact (MkModuleOp moduleOp) (MkCppString version) = do - str <- primIO prim__stringNew - ok <- primIO $ prim__serializePortableArtifact moduleOp version str - case cIntToBool ok of - True => Just <$> stringToCharArray str - False => free str >> pure Nothing +serializePortableArtifact : HasIO io => ModuleOp -> CppString -> RawStringOStream -> io Bool +serializePortableArtifact (MkModuleOp moduleOp) (MkCppString version) (MkRawStringOStream os) = do + ok <- primIO $ prim__serializePortableArtifact moduleOp version os + pure (cIntToBool ok) diff --git a/spidr/src/Compiler/StableHLO/Dialect/Version.idr b/spidr/src/Compiler/StableHLO/Dialect/Version.idr index 0d8d4d531..bad9ca363 100644 --- a/spidr/src/Compiler/StableHLO/Dialect/Version.idr +++ b/spidr/src/Compiler/StableHLO/Dialect/Version.idr @@ -37,9 +37,9 @@ getMinimumVersion = do %foreign (libxla "Version_toString") prim__versionToString : GCAnyPtr -> PrimIO AnyPtr +||| It is up to the caller to `delete` the string. export toString : HasIO io => Version -> io CppString toString (MkVersion version) = do str <- primIO $ prim__versionToString version - str <- onCollectAny str (primIO . prim__stringDelete) pure (MkCppString str) diff --git a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr index acdb83406..a2e1cc136 100644 --- a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr +++ b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr @@ -39,5 +39,6 @@ prim__compileOptionsSerializeAsString : GCAnyPtr -> PrimIO AnyPtr ||| It is up to the caller to `free` the `CharArray`. export serializeAsString : HasIO io => CompileOptions -> io CharArray -serializeAsString (MkCompileOptions options) = - primIO (prim__compileOptionsSerializeAsString options) >>= stringToCharArray +serializeAsString (MkCompileOptions options) = do + str <- primIO (prim__compileOptionsSerializeAsString options) + stringToCharArray (MkCppString str) diff --git a/spidr/src/Compiler/Xla/Service/HloProto.idr b/spidr/src/Compiler/Xla/Service/HloProto.idr index 9e7ce2b2c..5cd389aae 100644 --- a/spidr/src/Compiler/Xla/Service/HloProto.idr +++ b/spidr/src/Compiler/Xla/Service/HloProto.idr @@ -21,15 +21,6 @@ import Compiler.FFI public export data HloModuleProto = MkHloModuleProto GCAnyPtr -%foreign (libxla "HloModuleProto_SerializeAsString") -prim__hloModuleProtoSerializeAsString : GCAnyPtr -> PrimIO AnyPtr - export %foreign (libxla "HloModuleProto_delete") prim__delete : AnyPtr -> PrimIO () - -||| It is up to the caller to `free` the `CharArray`. -export -serializeAsString : HasIO io => HloModuleProto -> io CharArray -serializeAsString (MkHloModuleProto proto) = - primIO (prim__hloModuleProtoSerializeAsString proto) >>= stringToCharArray From 59a317dd1f2985a2ce637da617c37570630382e0 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Fri, 13 Dec 2024 23:11:47 +0000 Subject: [PATCH 16/38] tidy --- .../src/mlir/IR/BuiltinOps.cpp} | 23 +++++------- spidr/backend/src/xla/service/hlo.proto.cpp | 7 ---- spidr/src/Compiler/MLIR/IR/BuiltinOps.idr | 8 +++- .../StableHLO/Dialect/Serialization.idr | 2 +- spidr/src/Compiler/Xla/HLO/IR/HloModule.idr | 37 ------------------- .../Compiler/Xla/HLO/Translate/StableHLO.idr | 1 + 6 files changed, 18 insertions(+), 60 deletions(-) rename spidr/{src/Compiler/Xla/HLO/Translate/PortableAPI.idr => backend/src/mlir/IR/BuiltinOps.cpp} (53%) delete mode 100644 spidr/src/Compiler/Xla/HLO/IR/HloModule.idr diff --git a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr b/spidr/backend/src/mlir/IR/BuiltinOps.cpp similarity index 53% rename from spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr rename to spidr/backend/src/mlir/IR/BuiltinOps.cpp index 7aaba289a..5609c3a95 100644 --- a/spidr/src/Compiler/Xla/HLO/Translate/PortableAPI.idr +++ b/spidr/backend/src/mlir/IR/BuiltinOps.cpp @@ -1,4 +1,4 @@ -{-- +/* Copyright 2024 Joel Berkeley Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,18 +12,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ---} -||| For internal spidr use only. -module Compiler.Xla.HLO.Translate.PortableAPI +*/ +#include "mlir/IR/BuiltinOps.h" -import Compiler.FFI -import Compiler.Xla.HLO.IR.HloModule +#include "BuiltinOps.h" -%foreign (libxla "ConvertHloToStablehlo") -prim__convertHloToStablehlo : GCAnyPtr -> PrimIO AnyPtr - -||| It is up to the caller to `free` the `CharArray`. -export -convertHloToStablehlo : HasIO io => HloModule -> io CharArray -convertHloToStablehlo (MkHloModule module') = - primIO (prim__convertHloToStablehlo module') >>= stringToCharArray +extern "C" { + void ModuleOp_delete(ModuleOp* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/xla/service/hlo.proto.cpp b/spidr/backend/src/xla/service/hlo.proto.cpp index 195d17026..f62a63af0 100644 --- a/spidr/backend/src/xla/service/hlo.proto.cpp +++ b/spidr/backend/src/xla/service/hlo.proto.cpp @@ -14,17 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "xla/service/hlo.pb.h" -// #include "xla/service/..." // try to import from some random place -#include "../../ffi.h" #include "hlo.proto.h" extern "C" { - string* HloModuleProto_SerializeAsString(HloModuleProto& s) { - auto s_ = reinterpret_cast(s); - return reinterpret_cast(new std::string(s_.SerializeAsString())); - } - void HloModuleProto_delete(HloModuleProto* s) { delete reinterpret_cast(s); } diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr index 9f06c7327..ba44b6583 100644 --- a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -16,5 +16,11 @@ limitations under the License. ||| For internal spidr use only. module Compiler.MLIR.IR.BuiltinOps +import Compiler.FFI + public export -data ModuleOp = MkModuleOp AnyPtr -- need to GC +data ModuleOp = MkModuleOp GCAnyPtr + +export +%foreign (libxla "ModuleOp_delete") +prim__delete : AnyPtr -> PrimIO () diff --git a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr index 5e493981d..299f9269c 100644 --- a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr +++ b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr @@ -21,7 +21,7 @@ import Compiler.MLIR.IR.BuiltinOps import Compiler.FFI %foreign (libxla "serializePortableArtifact") -prim__serializePortableArtifact : AnyPtr -> AnyPtr -> GCAnyPtr -> PrimIO Int +prim__serializePortableArtifact : GCAnyPtr -> AnyPtr -> GCAnyPtr -> PrimIO Int export serializePortableArtifact : HasIO io => ModuleOp -> CppString -> RawStringOStream -> io Bool diff --git a/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr b/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr deleted file mode 100644 index ba438842e..000000000 --- a/spidr/src/Compiler/Xla/HLO/IR/HloModule.idr +++ /dev/null @@ -1,37 +0,0 @@ -{-- -Copyright 2024 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. ---} -||| For internal spidr use only. -module Compiler.Xla.HLO.IR.HloModule - -import Compiler.FFI -import Compiler.Xla.Service.HloModuleConfig -import Compiler.Xla.Service.HloProto - -public export -data HloModule = MkHloModule GCAnyPtr - -%foreign (libxla "HloModule_delete") -prim__delete : AnyPtr -> PrimIO () - -%foreign (libxla "HloModule_CreateFromProto") -prim__hloModuleCreateFromProto : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr - -export -createFromProto : HasIO io => HloModuleProto -> HloModuleConfig -> io HloModule -createFromProto (MkHloModuleProto proto) (MkHloModuleConfig config) = do - module' <- primIO $ prim__hloModuleCreateFromProto proto config - module' <- onCollectAny module' (primIO . HloModule.prim__delete) - pure (MkHloModule module') diff --git a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr index b58bbe055..ee634a6cf 100644 --- a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr +++ b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr @@ -28,4 +28,5 @@ export convertHloToStablehlo : HasIO io => MLIRContext -> HloModuleProto -> io ModuleOp convertHloToStablehlo (MkMLIRContext ctx) (MkHloModuleProto proto) = do moduleOp <- primIO $ prim__convertHloToStablehlo ctx proto + moduleOp <- onCollectAny moduleOp (primIO . BuiltinOps.prim__delete) pure (MkModuleOp moduleOp) From b983e1377a7bbaed5cf626585685505f79f8c4f8 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Fri, 13 Dec 2024 23:20:40 +0000 Subject: [PATCH 17/38] wip --- spidr/spidr.ipkg | 1 + 1 file changed, 1 insertion(+) diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index b83273331..7807de7b0 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,6 +8,7 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, + Compiler.LLVM.Support.RawOStream, Compiler.MLIR.IR.BuiltinOps, Compiler.MLIR.IR.DialectRegistry, Compiler.MLIR.IR.MLIRContext, From 189f40adf9bd30689f17b48aee4aae03581b8656 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:36:02 +0000 Subject: [PATCH 18/38] draft AD with enzyme --- dev.sh | 17 ++++--- spidr/backend/.gitignore | 1 + spidr/backend/BUILD | 4 ++ spidr/backend/ENZYME_JAX_VERSION | 1 + spidr/backend/WORKSPACE | 27 +++++++++++ spidr/backend/build.sh | 8 +++- .../src/enzyme_ad/jax/Implementations/BUILD | 12 +++++ .../StableHLOAutoDiffOpInterfaceImpl.cpp | 25 ++++++++++ spidr/backend/src/mlir/IR/Operation.h | 18 +++++++ spidr/backend/src/mlir/Pass/BUILD | 12 +++++ spidr/backend/src/mlir/Pass/PassManager.cpp | 39 +++++++++++++++ .../src/xla/hlo/builder/xla_builder.cpp | 12 +++++ .../src/xla/hlo/builder/xla_computation.cpp | 6 +++ .../src/xla/hlo/translate/stablehlo.cpp | 7 +++ spidr/spidr.ipkg | 2 + .../StableHLOAutoDiffOpInterfaceImpl.idr | 28 +++++++++++ .../Compiler/EnzymeJAX/Support/RawOStream.idr | 35 ++++++++++++++ spidr/src/Compiler/Eval.idr | 48 +++++++++++++------ spidr/src/Compiler/Expr.idr | 2 + spidr/src/Compiler/MLIR/Pass/PassManager.idr | 46 ++++++++++++++++++ .../Compiler/Xla/HLO/Builder/XlaBuilder.idr | 11 +++++ .../Xla/HLO/Builder/XlaComputation.idr | 10 ++++ .../Compiler/Xla/HLO/Translate/StableHLO.idr | 10 ++++ spidr/src/Tensor.idr | 11 +++++ test/runner/Unit/TestTensor.idr | 4 +- test/runner/Unit/TestTensor/AD.idr | 43 +++++++++++++++++ test/runner/runner.ipkg | 1 + 27 files changed, 416 insertions(+), 24 deletions(-) create mode 100644 spidr/backend/ENZYME_JAX_VERSION create mode 100644 spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD create mode 100644 spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp create mode 100644 spidr/backend/src/mlir/IR/Operation.h create mode 100644 spidr/backend/src/mlir/Pass/BUILD create mode 100644 spidr/backend/src/mlir/Pass/PassManager.cpp create mode 100644 spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr create mode 100644 spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr create mode 100644 spidr/src/Compiler/MLIR/Pass/PassManager.idr create mode 100644 test/runner/Unit/TestTensor/AD.idr diff --git a/dev.sh b/dev.sh index f0d395ace..bf14e875c 100644 --- a/dev.sh +++ b/dev.sh @@ -8,12 +8,7 @@ short_revision () { echo "${rev%%"${rev##??????????}"}" } -install_xla () { - if [ -z "$2" ]; then - echo "Usage: install_xla ." - exit 1; - fi - +install_git_repository () { if [ "$(ls -A "$2")" ]; then echo "Directory at path $2 is not empty, refusing to install XLA to this directory." exit 1; @@ -22,8 +17,16 @@ install_xla () { ( cd "$2" git init - git remote add origin https://github.com/openxla/xla + git remote add origin $3 git fetch --depth 1 origin "$1" git checkout FETCH_HEAD ) } + +install_xla () { + install_git_repository $1 $2 https://github.com/openxla/xla +} + +install_enzyme () { + install_git_repository $1 $2 https://github.com/EnzymeAD/Enzyme-JAX.git +} diff --git a/spidr/backend/.gitignore b/spidr/backend/.gitignore index 24a3274c5..2fbfae974 100644 --- a/spidr/backend/.gitignore +++ b/spidr/backend/.gitignore @@ -1 +1,2 @@ +/Enzyme-JAX /xla diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index 34a2dee6a..cad409a7d 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,8 +12,10 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ + "//src/Enzyme-JAX/src/enzyme_ad/jax/Implementations", "//src/llvm/Support", "//src/mlir/IR", + "//src/mlir/Pass", "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", @@ -27,8 +29,10 @@ cc_binary( "//src", ], deps = [ + "//src/Enzyme-JAX/src/enzyme_ad/jax/Implementations", "//src/llvm/Support", "//src/mlir/IR", + "//src/mlir/Pass", "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", diff --git a/spidr/backend/ENZYME_JAX_VERSION b/spidr/backend/ENZYME_JAX_VERSION new file mode 100644 index 000000000..6183501f6 --- /dev/null +++ b/spidr/backend/ENZYME_JAX_VERSION @@ -0,0 +1 @@ +v0.0.9 \ No newline at end of file diff --git a/spidr/backend/WORKSPACE b/spidr/backend/WORKSPACE index 991ab29d9..4f53e783a 100644 --- a/spidr/backend/WORKSPACE +++ b/spidr/backend/WORKSPACE @@ -1,3 +1,5 @@ +### xla + # this must be a local repository not http archive # so we can run ./configure.py before invoking bazel local_repository(name = "xla", path = "xla") @@ -28,3 +30,28 @@ xla_workspace0() load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") cuda_configure(name = "local_config_cuda") + +### Enzyme-JAX +# note enzyme-jax specifies XLA versions, which we're currently ignoring. Do we need to use their versions? + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +local_repository(name = "enzyme-jax", path = "Enzyme-JAX") + +load("@enzyme-jax//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256") + +http_archive( + name = "jax", + sha256 = JAX_SHA256, + strip_prefix = "jax-" + JAX_COMMIT, + urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)], + patch_args = ["-p1"], + patches = ["@enzyme-jax//:patches/jax.patch"], +) + +http_archive( + name = "enzyme", + sha256 = ENZYME_SHA256, + strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", + urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], +) diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index e88541581..5f32cc19e 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -3,7 +3,8 @@ script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) cd "$script_dir/../.." . ./dev.sh -rev="$(cat XLA_VERSION)" +xla_rev="$(cat XLA_VERSION)" +enzyme_rev="$(cat spidr/backend/ENZYME_JAX_VERSION)" osu="$(uname)" case $osu in @@ -26,8 +27,11 @@ esac ( cd spidr/backend mkdir xla - install_xla "$rev" xla + install_xla "$xla_rev" xla (cd xla; ./configure.py --backend=cpu --os=$os) + mkdir Enzyme-JAX + install_enzyme "$enzyme_rev" Enzyme-JAX + sed -i -e 's/"-Werror=unused-variable",//g' Enzyme-JAX/src/enzyme_ad/jax/BUILD bazel build //:c_xla rm -rf xla ) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD new file mode 100644 index 000000000..a3208d05e --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Implementations", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@enzyme-jax//:enzymexlamlir-opt", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000..09e36b435 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,25 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" + +#include "../../../../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerStableHLODialectAutoDiffInterface(DialectRegistry& registry) { + auto& registry_ = reinterpret_cast(registry); + mlir::enzyme::registerStableHLODialectAutoDiffInterface(registry_); + } +} diff --git a/spidr/backend/src/mlir/IR/Operation.h b/spidr/backend/src/mlir/IR/Operation.h new file mode 100644 index 000000000..31743deed --- /dev/null +++ b/spidr/backend/src/mlir/IR/Operation.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Operation; +} diff --git a/spidr/backend/src/mlir/Pass/BUILD b/spidr/backend/src/mlir/Pass/BUILD new file mode 100644 index 000000000..125748060 --- /dev/null +++ b/spidr/backend/src/mlir/Pass/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Pass", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//mlir:Pass", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/mlir/Pass/PassManager.cpp b/spidr/backend/src/mlir/Pass/PassManager.cpp new file mode 100644 index 000000000..50b5afcc5 --- /dev/null +++ b/spidr/backend/src/mlir/Pass/PassManager.cpp @@ -0,0 +1,39 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/Pass/PassManager.h" + +#include "../IR/BuiltinOps.h" +#include "../IR/MLIRContext.h" +#include "../IR/Operation.h" + +extern "C" { + struct PassManager; + + PassManager* PassManager_new(MLIRContext* ctx) { + auto ctx_ = reinterpret_cast(ctx); + return reinterpret_cast(new mlir::PassManager(ctx_)); + } + + void PassManager_delete(PassManager* s) { + delete reinterpret_cast(s); + } + + int PassManager_run(PassManager& s, Operation* op) { + auto& s_ = reinterpret_cast(s); + auto op_ = reinterpret_cast(op); + return (int) s_.run(op_).succeeded(); + } +} diff --git a/spidr/backend/src/xla/hlo/builder/xla_builder.cpp b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp index 195562645..289ca6938 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_builder.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp @@ -290,6 +290,18 @@ extern "C" { return reinterpret_cast(new xla::XlaOp(res)); } + XlaOp* Call( + XlaBuilder* builder, XlaComputation& computation, XlaOp* operands, size_t operands_len + ) { + auto builder_ = reinterpret_cast(builder); + auto& computation_ = reinterpret_cast(computation); + auto operands_ = reinterpret_cast(operands); + auto operands_span = absl::Span(operands_, operands_len); + + auto res = xla::Call(builder_, computation_, operands_span); + return reinterpret_cast(new xla::XlaOp(res)); + } + XlaOp* Add(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Add, lhs, rhs); } XlaOp* Sub(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Sub, lhs, rhs); } XlaOp* Mul(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Mul, lhs, rhs); } diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index 8695e74ee..2c75b2089 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -22,6 +22,12 @@ limitations under the License. #include "xla_computation.h" extern "C" { + XlaComputation* XlaComputation_new(HloModuleProto& proto) { + auto& proto_ = reinterpret_cast(proto); + // this moves the proto? should we then not GC it? + return reinterpret_cast(new xla::XlaComputation(proto_)); + } + void XlaComputation_delete(XlaComputation* s) { delete reinterpret_cast(s); } diff --git a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp index e358aeca7..6b023a961 100644 --- a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp +++ b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp @@ -28,4 +28,11 @@ extern "C" { auto module_op = xla::ConvertHloToStablehlo(ctx_, hlo_module_); return reinterpret_cast(new mlir::ModuleOp(module_op.value().release())); } + + HloModuleProto* ConvertStablehloToHlo(ModuleOp& module) { + auto& module_ = reinterpret_cast(module); + // mode ToProto to separate function? + auto res = xla::ConvertStablehloToHlo(module_).value().release()->ToProto(); + return reinterpret_cast(new xla::HloModuleProto(res)); + } } diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index 7807de7b0..646582345 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,10 +8,12 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, + Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl, Compiler.LLVM.Support.RawOStream, Compiler.MLIR.IR.BuiltinOps, Compiler.MLIR.IR.DialectRegistry, Compiler.MLIR.IR.MLIRContext, + Compiler.MLIR.Pass.PassManager, Compiler.StableHLO.Dialect.Register, Compiler.StableHLO.Dialect.Serialization, Compiler.StableHLO.Dialect.Version, diff --git a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr new file mode 100644 index 000000000..2be237d73 --- /dev/null +++ b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr @@ -0,0 +1,28 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerStableHLODialectAutoDiffInterface") +prim__registerStableHLODialectAutoDiffInterface : GCAnyPtr -> PrimIO () + +export +registerStableHLODialectAutoDiffInterface : HasIO io => DialectRegistry -> io () +registerStableHLODialectAutoDiffInterface (MkDialectRegistry reg) = + primIO $ prim__registerStableHLODialectAutoDiffInterface reg diff --git a/spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr b/spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr new file mode 100644 index 000000000..f8b11e32a --- /dev/null +++ b/spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.LLVM.Support.RawOStream + +import Compiler.FFI + +public export +data RawStringOStream = MkRawStringOStream GCAnyPtr + +%foreign (libxla "raw_string_ostream_new") +prim__mkRawStringOStream : AnyPtr -> PrimIO AnyPtr + +%foreign (libxla "raw_string_ostream_delete") +prim__delete : AnyPtr -> PrimIO () + +export +rawStringOStream : HasIO io => CppString -> io RawStringOStream +rawStringOStream (MkCppString str) = do + os <- primIO $ prim__mkRawStringOStream str + os <- onCollectAny os (primIO . prim__delete) + pure (MkRawStringOStream os) diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index de4a03d48..b508a357b 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -26,10 +26,12 @@ import Data.List.Elem import Compiler.Expr import Compiler.FFI import Compiler.LiteralRW +import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl import Compiler.LLVM.Support.RawOStream import Compiler.MLIR.IR.BuiltinOps import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.IR.MLIRContext +import Compiler.MLIR.Pass.PassManager import Compiler.StableHLO.Dialect.Register import Compiler.StableHLO.Dialect.Serialization import Compiler.StableHLO.Dialect.Version @@ -76,6 +78,22 @@ public export 0 ErrIO : Type -> Type ErrIO = EitherT Err IO +serializeStableHLO : ModuleOp -> ErrIO CharArray +serializeStableHLO stablehlo = do + code <- cppString + version <- toString !getMinimumVersion + ok <- serializePortableArtifact stablehlo version !(rawStringOStream code) + if ok then stringToCharArray code else throwE (SerializationError "Failed to serialize StableHLO") + +hloModuleProtoToStableHLO : HloModuleProto -> ErrIO ModuleOp +hloModuleProtoToStableHLO proto = do + dialectRegistry <- mkDialectRegistry + registerAllMhloDialects dialectRegistry + registerAllDialects dialectRegistry + mlirCtx <- mkMLIRContext + appendDialectRegistry mlirCtx dialectRegistry + convertHloToStablehlo mlirCtx proto + covering interpret : IOArray XlaOp => XlaBuilder -> Fn arity -> ErrIO XlaOp @@ -111,6 +129,21 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do interpretE (Var x) = get x interpretE (Tuple xs) = tuple xlaBuilder !(traverse interpretE xs) interpretE (GetTupleElement idx x) = getTupleElement !(interpretE x) idx + interpretE (Grad f x) = do + reg <- mkDialectRegistry + StableHLO.Dialect.Register.registerAllDialects reg + registerStableHLODialectAutoDiffInterface reg + ctx <- mkMLIRContext + appendDialectRegistry ctx reg + mgr <- mkPassManager ctx + computation <- compile xlaBuilder f + stablehlo <- hloModuleProtoToStableHLO !(proto computation) -- using the wrong function, we want the module, not the string + True <- run mgr stablehlo | False => ?err + hloProto <- convertStablehloToHlo stablehlo + computation <- mkXlaComputation hloProto + -- x should be correct shape, because we're sending R^{n0, n1, ..} -> R + -- to R^{n0, n1, ..} -> R^{n0, n1, ..} i.e. we're only changing the output shape + call xlaBuilder computation [!(interpretE x)] interpretE (MinValue {dtype}) = minValue {dtype} xlaBuilder interpretE (MaxValue {dtype}) = maxValue {dtype} xlaBuilder interpretE (MinFiniteValue {dtype}) = minFiniteValue {dtype} xlaBuilder @@ -231,26 +264,13 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do !(interpretE key) !(interpretE initialState) ThreeFry !(mkShape {dtype = F64} shape) tuple xlaBuilder [value rngOutput, state rngOutput] -hloModuleProtoToStableHLO : HloModuleProto -> ErrIO CharArray -hloModuleProtoToStableHLO proto = do - dialectRegistry <- mkDialectRegistry - registerAllMhloDialects dialectRegistry - registerAllDialects dialectRegistry - mlirCtx <- mkMLIRContext - stablehlo <- convertHloToStablehlo mlirCtx proto - appendDialectRegistry mlirCtx dialectRegistry - code <- cppString - version <- toString !getMinimumVersion - ok <- serializePortableArtifact stablehlo version !(rawStringOStream code) - if ok then stringToCharArray code else throwE (SerializationError "Failed to serialize StableHLO") - ||| It is up to the caller to free the `Literal`s. export covering execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal execute (MkDevice api client) f@(MkFn _ _ env) shapes = do xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f - code <- hloModuleProtoToStableHLO !(proto computation) + code <- serializeStableHLO !(hloModuleProtoToStableHLO !(proto computation)) executableBuildOptions <- mkExecutableBuildOptions compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) program <- mkPjrtProgram code diff --git a/spidr/src/Compiler/Expr.idr b/spidr/src/Compiler/Expr.idr index c5c4ff68d..1912b8062 100644 --- a/spidr/src/Compiler/Expr.idr +++ b/spidr/src/Compiler/Expr.idr @@ -106,6 +106,7 @@ data Expr : Type where Var : Nat -> Expr Tuple : List Expr -> Expr GetTupleElement : (index : Nat) -> Expr -> Expr + Grad : Fn 1 -> Expr -> Expr -- temporary name MinValue : Primitive dtype => Expr MaxValue : Primitive dtype => Expr MinFiniteValue : Primitive dtype => Expr @@ -184,6 +185,7 @@ showExpr indent (FromLiteral {shape, dtype} x) = "Lit \{shape} \{xlaIdentifier { showExpr indent (Var k) = "Var \{k}" showExpr indent (Tuple xs) = "Tuple \{showExprList indent xs}" showExpr indent (GetTupleElement k x) = "GetTupleElement {index = \{k}} (\{showExpr indent x})" +showExpr indent (Grad _ _) = "Grad" showExpr indent (MinValue {dtype}) = "MinValue {dtype = \{xlaIdentifier {dtype}}}" showExpr indent (MaxValue {dtype}) = "MaxValue {dtype = \{xlaIdentifier {dtype}}}" showExpr indent (MinFiniteValue {dtype}) = "MinFiniteValue {dtype = \{xlaIdentifier {dtype}}}" diff --git a/spidr/src/Compiler/MLIR/Pass/PassManager.idr b/spidr/src/Compiler/MLIR/Pass/PassManager.idr new file mode 100644 index 000000000..896b471d9 --- /dev/null +++ b/spidr/src/Compiler/MLIR/Pass/PassManager.idr @@ -0,0 +1,46 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.Pass.PassManager + +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.MLIRContext +import Compiler.FFI + +public export +data PassManager = MkPassManager GCAnyPtr + +%foreign (libxla "PassManager_new") +prim__mkPassManager : GCAnyPtr -> PrimIO AnyPtr + +%foreign (libxla "PassManager_delete") +prim__delete : AnyPtr -> PrimIO () + +export +mkPassManager : HasIO io => MLIRContext -> io PassManager +mkPassManager (MkMLIRContext ctx) = do + manager <- primIO $ prim__mkPassManager ctx + manager <- onCollectAny manager (primIO . PassManager.prim__delete) + pure (MkPassManager manager) + +%foreign (libxla "PassManager_run") +prim__passManagerRun : GCAnyPtr -> GCAnyPtr -> PrimIO Int + +export +run : HasIO io => PassManager -> ModuleOp -> io Bool +run (MkPassManager manager) (MkModuleOp op) = do + ok <- primIO $ prim__passManagerRun manager op + pure (cIntToBool ok) diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr index c85bc08bf..d1ca3ad42 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr @@ -347,6 +347,17 @@ cholesky (MkXlaOp a) lower = do opPtr <- onCollectAny opPtr XlaOp.delete pure (MkXlaOp opPtr) +%foreign (libxla "Call") +prim__call : GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> Bits64 -> PrimIO AnyPtr + +export +call : HasIO io => XlaBuilder -> XlaComputation -> List XlaOp -> io XlaOp +call (MkXlaBuilder builder) (MkXlaComputation computation) operands = do + MkXlaOpArray operandsXlaOpArrayPtr <- mkXlaOpArray operands + op <- primIO $ prim__call builder computation operandsXlaOpArrayPtr (cast $ length operands) + op <- onCollectAny op XlaOp.delete + pure (MkXlaOp op) + %foreign (libxla "Add") prim__add : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index a9a4c455f..c4cd9a8a8 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -24,6 +24,9 @@ public export data XlaComputation : Type where MkXlaComputation : GCAnyPtr -> XlaComputation +%foreign (libxla "XlaComputation_new") +prim__mkXlaComputation : GCAnyPtr -> PrimIO AnyPtr + %foreign (libxla "XlaComputation_delete") prim__delete : AnyPtr -> PrimIO () @@ -31,6 +34,13 @@ export delete : AnyPtr -> IO () delete = primIO . XlaComputation.prim__delete +export +mkXlaComputation : HasIO io => HloModuleProto -> io XlaComputation +mkXlaComputation (MkHloModuleProto proto) = do + comp <- primIO $ prim__mkXlaComputation proto + comp <- onCollectAny comp XlaComputation.delete + pure (MkXlaComputation comp) + %foreign (libxla "XlaComputation_proto") prim__xlaComputationProto : GCAnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr index ee634a6cf..8f2f39854 100644 --- a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr +++ b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr @@ -30,3 +30,13 @@ convertHloToStablehlo (MkMLIRContext ctx) (MkHloModuleProto proto) = do moduleOp <- primIO $ prim__convertHloToStablehlo ctx proto moduleOp <- onCollectAny moduleOp (primIO . BuiltinOps.prim__delete) pure (MkModuleOp moduleOp) + +%foreign (libxla "ConvertStablehloToHlo") +prim__convertStablehloToHlo : GCAnyPtr -> PrimIO AnyPtr + +export +convertStablehloToHlo : HasIO io => ModuleOp -> io HloModuleProto +convertStablehloToHlo (MkModuleOp op) = do + hlo <- primIO $ prim__convertStablehloToHlo op + hlo <- onCollectAny hlo (primIO . BuiltinOps.prim__delete) + pure (MkHloModuleProto hlo) diff --git a/spidr/src/Tensor.idr b/spidr/src/Tensor.idr index 8d97cd785..ea84f355e 100644 --- a/spidr/src/Tensor.idr +++ b/spidr/src/Tensor.idr @@ -239,6 +239,17 @@ export castDtype : Primitive.Integral a => Tensor shape a -> Tensor shape F64 castDtype $ MkTensor x = MkTensor $ ConvertElementType {dtype = F64} x +export +grad : (Tensor shape F64 -> Tag $ Tensor [] F64) -> Tensor shape F64 -> Tag $ Tensor shape F64 +grad f (MkTensor x) = MkTagT $ do + addr <- reserve + let MkTagT app = f (MkTensor $ Var addr) + (env, MkTensor res) = runState (emptyFrom !get) app + g = MkFn [(addr, MkParameter [] F64)] res env + + updateCounterFrom env + pure $ MkTensor $ Grad g x + ----------------------------- structural operations ---------------------------- ||| Reshape a `Tensor`. For example, `reshape {to = [2, 1]} (tensor [3, 4])` is diff --git a/test/runner/Unit/TestTensor.idr b/test/runner/Unit/TestTensor.idr index e774435ab..0247a2656 100644 --- a/test/runner/Unit/TestTensor.idr +++ b/test/runner/Unit/TestTensor.idr @@ -15,6 +15,7 @@ limitations under the License. --} module Unit.TestTensor +import Unit.TestTensor.AD import Unit.TestTensor.Elementwise import Unit.TestTensor.HigherOrder import Unit.TestTensor.Sampling @@ -500,7 +501,8 @@ group = MkGroup "Tensor" $ [ , (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems) , ("trace", trace) ] ++ concat (the (List _) [ - Unit.TestTensor.Elementwise.all + Unit.TestTensor.AD.all + , Unit.TestTensor.Elementwise.all , Unit.TestTensor.HigherOrder.all , Unit.TestTensor.Sampling.all , Unit.TestTensor.Slice.all diff --git a/test/runner/Unit/TestTensor/AD.idr b/test/runner/Unit/TestTensor/AD.idr new file mode 100644 index 000000000..c67550a78 --- /dev/null +++ b/test/runner/Unit/TestTensor/AD.idr @@ -0,0 +1,43 @@ +{-- +Copyright 2023 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Unit.TestTensor.AD + +import System + +import Device +import Tensor + +import Utils +import Utils.Comparison +import Utils.Cases + +square : Device => Property +square = fixedProperty $ do + grad (pure . square) (tensor 1.0) ===# pure (tensor 2.0) +{- + shape <- forAll shapes + + x <- forAll (literal shape doubles) + let x' = tensor {dtype = F64} x + map id x ==~ unsafeEval (map pure x') + map (1.0 /) x ==~ Tag.unsafeEval (map (pure . (1.0 /)) x') +-} + +export +all : Device => List (PropertyName, Property) +all = [ + ("grad square", square) + ] diff --git a/test/runner/runner.ipkg b/test/runner/runner.ipkg index 8563d2bf0..f837aba18 100644 --- a/test/runner/runner.ipkg +++ b/test/runner/runner.ipkg @@ -8,6 +8,7 @@ depends = modules = Unit.Model.TestKernel, + Unit.TestTensor.AD, Unit.TestTensor.Elementwise, Unit.TestTensor.HigherOrder, Unit.TestTensor.Sampling, From 8abda252120740c6a21ba2c885d1378c6a7081f6 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:41:57 +0000 Subject: [PATCH 19/38] shellcheck --- dev.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev.sh b/dev.sh index bf14e875c..dda7ab22e 100644 --- a/dev.sh +++ b/dev.sh @@ -24,9 +24,9 @@ install_git_repository () { } install_xla () { - install_git_repository $1 $2 https://github.com/openxla/xla + install_git_repository "$1" "$2" https://github.com/openxla/xla } install_enzyme () { - install_git_repository $1 $2 https://github.com/EnzymeAD/Enzyme-JAX.git + install_git_repository "$1" "$2" https://github.com/EnzymeAD/Enzyme-JAX.git } From edd3f398654d7a660139d8f3e9fb3a86232f6655 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:43:04 +0000 Subject: [PATCH 20/38] shellcheck --- dev.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev.sh b/dev.sh index dda7ab22e..df3df6ad8 100644 --- a/dev.sh +++ b/dev.sh @@ -17,7 +17,7 @@ install_git_repository () { ( cd "$2" git init - git remote add origin $3 + git remote add origin "$3" git fetch --depth 1 origin "$1" git checkout FETCH_HEAD ) From c933e6a4e42a73ad1c040de21c79004bb7eeface Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 16 Dec 2024 20:56:19 +0000 Subject: [PATCH 21/38] update enzyme version --- spidr/backend/ENZYME_JAX_VERSION | 2 +- spidr/backend/VERSION | 2 +- spidr/backend/build.sh | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/spidr/backend/ENZYME_JAX_VERSION b/spidr/backend/ENZYME_JAX_VERSION index 6183501f6..f90916d1f 100644 --- a/spidr/backend/ENZYME_JAX_VERSION +++ b/spidr/backend/ENZYME_JAX_VERSION @@ -1 +1 @@ -v0.0.9 \ No newline at end of file +51687b09d49dee1044c6767c0aca9b3dbb3c97d5 \ No newline at end of file diff --git a/spidr/backend/VERSION b/spidr/backend/VERSION index ceddfb28f..9beca35dc 100644 --- a/spidr/backend/VERSION +++ b/spidr/backend/VERSION @@ -1 +1 @@ -0.0.15 +0.0.15 \ No newline at end of file diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index 5f32cc19e..dffb1bf9c 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -31,7 +31,6 @@ esac (cd xla; ./configure.py --backend=cpu --os=$os) mkdir Enzyme-JAX install_enzyme "$enzyme_rev" Enzyme-JAX - sed -i -e 's/"-Werror=unused-variable",//g' Enzyme-JAX/src/enzyme_ad/jax/BUILD bazel build //:c_xla rm -rf xla ) From 4e5a25390723c4ebf7a613f553fd9d2e604b3d84 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Thu, 19 Dec 2024 21:20:37 +0000 Subject: [PATCH 22/38] wip --- spidr/backend/BUILD | 6 ++-- spidr/backend/WORKSPACE | 19 +++++++++++- spidr/backend/build.sh | 1 + .../enzyme/Enzyme/MLIR/Dialect}/BUILD | 5 ++-- .../enzyme/Enzyme/MLIR/Dialect/Dialect.cpp} | 8 ++--- .../Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD | 12 ++++++++ .../enzyme/Enzyme/MLIR/Passes/Passes.cpp | 24 +++++++++++++++ spidr/backend/src/mlir/Pass/Pass.cpp | 24 +++++++++++++++ spidr/backend/src/mlir/Pass/Pass.h | 18 +++++++++++ spidr/backend/src/mlir/Pass/PassManager.cpp | 10 ++++++- spidr/spidr.ipkg | 3 +- .../enzyme/Enzyme/MLIR/Dialect/Dialect.idr} | 13 ++++---- .../enzyme/Enzyme/MLIR/Passes/Passes.idr | 30 +++++++++++++++++++ spidr/src/Compiler/Eval.idr | 22 ++++++++++---- .../RawOStream.idr => MLIR/Pass/Pass.idr} | 16 ++-------- spidr/src/Compiler/MLIR/Pass/PassManager.idr | 7 +++++ spidr/src/Tensor.idr | 1 + 17 files changed, 182 insertions(+), 37 deletions(-) rename spidr/backend/src/{Enzyme-JAX/src/enzyme_ad/jax/Implementations => Enzyme/enzyme/Enzyme/MLIR/Dialect}/BUILD (69%) rename spidr/backend/src/{Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp => Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp} (68%) create mode 100644 spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD create mode 100644 spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp create mode 100644 spidr/backend/src/mlir/Pass/Pass.cpp create mode 100644 spidr/backend/src/mlir/Pass/Pass.h rename spidr/src/Compiler/{EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr => Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr} (56%) create mode 100644 spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr rename spidr/src/Compiler/{EnzymeJAX/Support/RawOStream.idr => MLIR/Pass/Pass.idr} (58%) diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index cad409a7d..414b2b47c 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,7 +12,8 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ - "//src/Enzyme-JAX/src/enzyme_ad/jax/Implementations", + "//src/Enzyme/enzyme/Enzyme/MLIR/Dialect", + "//src/Enzyme/enzyme/Enzyme/MLIR/Passes", "//src/llvm/Support", "//src/mlir/IR", "//src/mlir/Pass", @@ -29,7 +30,8 @@ cc_binary( "//src", ], deps = [ - "//src/Enzyme-JAX/src/enzyme_ad/jax/Implementations", + "//src/Enzyme/enzyme/Enzyme/MLIR/Dialect", + "//src/Enzyme/enzyme/Enzyme/MLIR/Passes", "//src/llvm/Support", "//src/mlir/IR", "//src/mlir/Pass", diff --git a/spidr/backend/WORKSPACE b/spidr/backend/WORKSPACE index 4f53e783a..b57874318 100644 --- a/spidr/backend/WORKSPACE +++ b/spidr/backend/WORKSPACE @@ -33,10 +33,27 @@ cuda_configure(name = "local_config_cuda") ### Enzyme-JAX # note enzyme-jax specifies XLA versions, which we're currently ignoring. Do we need to use their versions? +local_repository(name = "enzyme-jax", path = "Enzyme-JAX") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -local_repository(name = "enzyme-jax", path = "Enzyme-JAX") +http_archive( + name = "hedron_compile_commands", + + # Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here. + # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e", + # When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..." +) +# load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") +# hedron_compile_commands_setup() +# load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive") +# hedron_compile_commands_setup_transitive() +# load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive") +# hedron_compile_commands_setup_transitive_transitive() +# load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive") +# hedron_compile_commands_setup_transitive_transitive_transitive() load("@enzyme-jax//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256") diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index dffb1bf9c..61d5612e4 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -31,6 +31,7 @@ esac (cd xla; ./configure.py --backend=cpu --os=$os) mkdir Enzyme-JAX install_enzyme "$enzyme_rev" Enzyme-JAX + # sed -i -e 's/"-Werror=unused-variable",//g' Enzyme-JAX/src/enzyme_ad/jax/BUILD bazel build //:c_xla rm -rf xla ) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD similarity index 69% rename from spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD rename to spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD index a3208d05e..f2b284155 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/BUILD +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD @@ -1,11 +1,12 @@ cc_library( - name = "Implementations", + name = "Dialect", linkstatic = True, alwayslink = True, srcs = glob(["*.cpp"]), hdrs = glob(["*.h"]), deps = [ - "@enzyme-jax//:enzymexlamlir-opt", + "@enzyme//:EnzymeMLIR", + "@llvm-project//mlir:IR", "//src/mlir/IR", ], visibility = ["//visibility:public"], diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp similarity index 68% rename from spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp rename to spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp index 09e36b435..74aad7d40 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp @@ -13,13 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" +#include "mlir/IR/DialectRegistry.h" +#include "Enzyme/MLIR/Dialect/Dialect.h" #include "../../../../../mlir/IR/DialectRegistry.h" extern "C" { - void registerStableHLODialectAutoDiffInterface(DialectRegistry& registry) { - auto& registry_ = reinterpret_cast(registry); - mlir::enzyme::registerStableHLODialectAutoDiffInterface(registry_); + void DialectRegistry_insert_EnzymeDialect(DialectRegistry& s) { + reinterpret_cast(s).insert(); } } diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD new file mode 100644 index 000000000..ab3aeb8f2 --- /dev/null +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Passes", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@enzyme//:EnzymeMLIR", + "//src/mlir/Pass", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp new file mode 100644 index 000000000..d1652d7d8 --- /dev/null +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "Enzyme/MLIR/Passes/Passes.h" + +#include "../../../../../mlir/Pass/Pass.h" + +extern "C" { + Pass* createDifferentiatePass() { + return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); + } +} diff --git a/spidr/backend/src/mlir/Pass/Pass.cpp b/spidr/backend/src/mlir/Pass/Pass.cpp new file mode 100644 index 000000000..07ee4e861 --- /dev/null +++ b/spidr/backend/src/mlir/Pass/Pass.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/Pass/Pass.h" + +#include "Pass.h" + +extern "C" { + void Pass_delete(Pass* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/Pass/Pass.h b/spidr/backend/src/mlir/Pass/Pass.h new file mode 100644 index 000000000..edbf60f2d --- /dev/null +++ b/spidr/backend/src/mlir/Pass/Pass.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Pass; +} diff --git a/spidr/backend/src/mlir/Pass/PassManager.cpp b/spidr/backend/src/mlir/Pass/PassManager.cpp index 50b5afcc5..88275ec05 100644 --- a/spidr/backend/src/mlir/Pass/PassManager.cpp +++ b/spidr/backend/src/mlir/Pass/PassManager.cpp @@ -13,7 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "mlir/Pass/PassManager.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/Pass.h" #include "../IR/BuiltinOps.h" #include "../IR/MLIRContext.h" @@ -31,6 +33,12 @@ extern "C" { delete reinterpret_cast(s); } + void PassManager_addPass(PassManager& s, Pass* pass) { + auto& s_ = reinterpret_cast(s); + auto pass_ = reinterpret_cast(pass); + s_.addPass(pass_); + } + int PassManager_run(PassManager& s, Operation* op) { auto& s_ = reinterpret_cast(s); auto op_ = reinterpret_cast(op); diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index 646582345..c86ac1821 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,7 +8,8 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, - Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl, + Compiler.Enzyme.enzyme.Enzyme.MLIR.Dialect.Dialect, + Compiler.Enzyme.enzyme.Enzyme.MLIR.Passes.Passes, Compiler.LLVM.Support.RawOStream, Compiler.MLIR.IR.BuiltinOps, Compiler.MLIR.IR.DialectRegistry, diff --git a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr b/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr similarity index 56% rename from spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr rename to spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr index 2be237d73..0748a87ff 100644 --- a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Implementations/StableHLOAutoDiffOpInterfaceImpl.idr +++ b/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr @@ -14,15 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl +module Compiler.Enzyme.enzyme.Enzyme.MLIR.Dialect.Dialect -import Compiler.MLIR.IR.DialectRegistry +import Compiler.MLIR.Pass.Pass import Compiler.FFI -%foreign (libxla "registerStableHLODialectAutoDiffInterface") -prim__registerStableHLODialectAutoDiffInterface : GCAnyPtr -> PrimIO () +%foreign (libxla "DialectRegistry_insert_EnzymeDialect") +prim__dialectRegistryInsertEnzymeDialect : GCAnyPtr -> PrimIO () export -registerStableHLODialectAutoDiffInterface : HasIO io => DialectRegistry -> io () -registerStableHLODialectAutoDiffInterface (MkDialectRegistry reg) = - primIO $ prim__registerStableHLODialectAutoDiffInterface reg +insertEnzymeDialect : HasIO io => DialectRegistry -> io () +insertEnzymeDialect (MkDialectRegistry reg) = primIO $ prim__dialectRegistryInsertEnzymeDialect reg diff --git a/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr b/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr new file mode 100644 index 000000000..27ef01067 --- /dev/null +++ b/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr @@ -0,0 +1,30 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Enzyme.enzyme.Enzyme.MLIR.Passes.Passes + +import Compiler.MLIR.Pass.Pass +import Compiler.FFI + +%foreign (libxla "createDifferentiatePass") +prim__createDifferentiatePass : PrimIO AnyPtr + +export +createDifferentiatePass : HasIO io => io () +createDifferentiatePass = do + pass <- primIO prim__createDifferentiatePass + pass <- onCollectAny pass (primIO . Pass.prim__delete) + pure (MkPass pass) diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index b508a357b..b871f3481 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -23,9 +23,8 @@ import Data.IOArray import Data.List import Data.List.Elem -import Compiler.Expr -import Compiler.FFI -import Compiler.LiteralRW +import Compiler.Enzyme.enzyme.Enzyme.MLIR.Dialect.Dialect +import Compiler.Enzyme.enzyme.Enzyme.MLIR.Passes.Passes import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl import Compiler.LLVM.Support.RawOStream import Compiler.MLIR.IR.BuiltinOps @@ -52,6 +51,9 @@ import Compiler.Xla.Literal import Compiler.Xla.Shape import Compiler.Xla.ShapeUtil import Compiler.Xla.XlaData +import Compiler.Expr +import Compiler.FFI +import Compiler.LiteralRW import Literal import Primitive import Types @@ -66,6 +68,7 @@ data Err | ValueNotFound Nat | PjrtErr PjrtError | SerializationError String + | MlirPassError String export Show Err where @@ -73,6 +76,7 @@ Show Err where show (ValueNotFound idx) = "Value not found at index \{show idx}" show (PjrtErr err) = show err show (SerializationError err) = "SerializationError: \{err}" + show (MlirPassError err) = "MlirPassError: \{err}" public export 0 ErrIO : Type -> Type @@ -131,14 +135,20 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do interpretE (GetTupleElement idx x) = getTupleElement !(interpretE x) idx interpretE (Grad f x) = do reg <- mkDialectRegistry + -- need other dialects? + insertEnzymeDialect reg StableHLO.Dialect.Register.registerAllDialects reg - registerStableHLODialectAutoDiffInterface reg + -- registerStableHLODialectAutoDiffInterface reg + -- should we instead be getting the context from the stablehlo ModuleOp? ctx <- mkMLIRContext appendDialectRegistry ctx reg mgr <- mkPassManager ctx + addPass mgr !createDifferentiatePass computation <- compile xlaBuilder f - stablehlo <- hloModuleProtoToStableHLO !(proto computation) -- using the wrong function, we want the module, not the string - True <- run mgr stablehlo | False => ?err + stablehlo <- hloModuleProtoToStableHLO !(proto computation) + enzymeOp <- ?enzymeAutodiffReverseOp stablehlo + True <- run mgr enzymeOp + | False => throwE $ MlirPassError "Failed to run differentiate pass on StableHLO" hloProto <- convertStablehloToHlo stablehlo computation <- mkXlaComputation hloProto -- x should be correct shape, because we're sending R^{n0, n1, ..} -> R diff --git a/spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr b/spidr/src/Compiler/MLIR/Pass/Pass.idr similarity index 58% rename from spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr rename to spidr/src/Compiler/MLIR/Pass/Pass.idr index f8b11e32a..1f1058893 100644 --- a/spidr/src/Compiler/EnzymeJAX/Support/RawOStream.idr +++ b/spidr/src/Compiler/MLIR/Pass/Pass.idr @@ -14,22 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.LLVM.Support.RawOStream +module Compiler.MLIR.Pass.Pass import Compiler.FFI public export -data RawStringOStream = MkRawStringOStream GCAnyPtr +data Pass = MkPass GCAnyPtr -%foreign (libxla "raw_string_ostream_new") -prim__mkRawStringOStream : AnyPtr -> PrimIO AnyPtr - -%foreign (libxla "raw_string_ostream_delete") +%foreign (libxla "Pass_delete") prim__delete : AnyPtr -> PrimIO () - -export -rawStringOStream : HasIO io => CppString -> io RawStringOStream -rawStringOStream (MkCppString str) = do - os <- primIO $ prim__mkRawStringOStream str - os <- onCollectAny os (primIO . prim__delete) - pure (MkRawStringOStream os) diff --git a/spidr/src/Compiler/MLIR/Pass/PassManager.idr b/spidr/src/Compiler/MLIR/Pass/PassManager.idr index 896b471d9..b44a3d50f 100644 --- a/spidr/src/Compiler/MLIR/Pass/PassManager.idr +++ b/spidr/src/Compiler/MLIR/Pass/PassManager.idr @@ -36,6 +36,13 @@ mkPassManager (MkMLIRContext ctx) = do manager <- onCollectAny manager (primIO . PassManager.prim__delete) pure (MkPassManager manager) +%foreign (libxla "PassManager_addPass") +prim__passManagerAddPass : GCAnyPtr -> GCAnyPtr -> PrimIO Int + +export +addPass : HasIO io => PassManager -> Pass -> io () +addPass (MkPassManager manager) (MkPass pass) = primIO $ prim__passManagerAddPass manager pass + %foreign (libxla "PassManager_run") prim__passManagerRun : GCAnyPtr -> GCAnyPtr -> PrimIO Int diff --git a/spidr/src/Tensor.idr b/spidr/src/Tensor.idr index 30fc3b956..8720dd562 100644 --- a/spidr/src/Tensor.idr +++ b/spidr/src/Tensor.idr @@ -238,6 +238,7 @@ export castDtype : Primitive.Integral a => Tensor shape a -> Tensor shape F64 castDtype $ MkTensor x = MkTensor $ ConvertElementType {dtype = F64} x +||| Reverse-mode automatic differentiation. export grad : (Tensor shape F64 -> Tag $ Tensor [] F64) -> Tensor shape F64 -> Tag $ Tensor shape F64 grad f (MkTensor x) = MkTagT $ do From 8828373e9c8c0275502910e99c515681ce1d8e95 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 23 Dec 2024 01:51:35 +0000 Subject: [PATCH 23/38] first (almost e2e) draft --- .../enzyme/Enzyme/MLIR/Passes/Passes.cpp | 33 ++++++++++++ spidr/backend/src/mlir/IR/Attributes.h | 18 +++++++ spidr/backend/src/mlir/IR/Block.cpp | 28 ++++++++++ spidr/backend/src/mlir/IR/Block.h | 18 +++++++ spidr/backend/src/mlir/IR/Location.cpp | 28 ++++++++++ spidr/backend/src/mlir/IR/Location.h | 18 +++++++ .../backend/src/mlir/IR/OperationSupport.cpp | 54 +++++++++++++++++++ spidr/backend/src/mlir/IR/ValueRange.h | 18 +++++++ spidr/backend/src/mlir/Pass/PassManager.cpp | 11 ++-- .../src/xla/hlo/translate/stablehlo.cpp | 4 +- spidr/spidr.ipkg | 13 ++++- .../Enzyme/MLIR/Dialect/Dialect.idr | 3 +- .../Enzyme/MLIR/Passes/Passes.idr | 4 +- spidr/src/Compiler/Eval.idr | 7 +-- spidr/src/Compiler/MLIR/IR/Attributes.idr | 35 ++++++++++++ spidr/src/Compiler/MLIR/IR/Block.idr | 35 ++++++++++++ spidr/src/Compiler/MLIR/IR/BuiltinOps.idr | 4 +- spidr/src/Compiler/MLIR/IR/Location.idr | 22 ++++++++ spidr/src/Compiler/MLIR/IR/Operation.idr | 26 +++++++++ .../src/Compiler/MLIR/IR/OperationSupport.idr | 52 ++++++++++++++++++ spidr/src/Compiler/MLIR/IR/ValueRange.idr | 22 ++++++++ spidr/src/Compiler/MLIR/Pass/Pass.idr | 1 + spidr/src/Compiler/MLIR/Pass/PassManager.idr | 3 +- test/runner/TestRunner.idr | 6 +-- test/runner/Unit/TestTensor.idr | 8 +-- test/runner/Unit/TestTensor/AD.idr | 3 +- 26 files changed, 450 insertions(+), 24 deletions(-) create mode 100644 spidr/backend/src/mlir/IR/Attributes.h create mode 100644 spidr/backend/src/mlir/IR/Block.cpp create mode 100644 spidr/backend/src/mlir/IR/Block.h create mode 100644 spidr/backend/src/mlir/IR/Location.cpp create mode 100644 spidr/backend/src/mlir/IR/Location.h create mode 100644 spidr/backend/src/mlir/IR/OperationSupport.cpp create mode 100644 spidr/backend/src/mlir/IR/ValueRange.h rename spidr/src/Compiler/Enzyme/{enzyme => Enzyme}/Enzyme/MLIR/Dialect/Dialect.idr (90%) rename spidr/src/Compiler/Enzyme/{enzyme => Enzyme}/Enzyme/MLIR/Passes/Passes.idr (89%) create mode 100644 spidr/src/Compiler/MLIR/IR/Attributes.idr create mode 100644 spidr/src/Compiler/MLIR/IR/Block.idr create mode 100644 spidr/src/Compiler/MLIR/IR/Location.idr create mode 100644 spidr/src/Compiler/MLIR/IR/Operation.idr create mode 100644 spidr/src/Compiler/MLIR/IR/OperationSupport.idr create mode 100644 spidr/src/Compiler/MLIR/IR/ValueRange.idr diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp index d1652d7d8..e30bb7203 100644 --- a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp @@ -14,11 +14,44 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "Enzyme/MLIR/Passes/Passes.h" +#include "mlir/IR/BuiltinOps.h" #include "../../../../../mlir/Pass/Pass.h" +// for AD function +#include "mlir/IR/BuiltinTypes.h" + +#include "../../../../../mlir/IR/BuiltinOps.h" + extern "C" { Pass* createDifferentiatePass() { return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); } + + ModuleOp* emitEnzymeADOp(ModuleOp& module_op) { + auto module_op_ = reinterpret_cast(module_op); + auto state = mlir::OperationState("enzyme.autodiff", mlir::Location()); + auto ctx = module_op_.getContext(); + + auto scalarf64 = mlir::RankedTensorType::Builder() + .setShape({}) + .setElementType(mlir::FloatType::getF64(ctx)); + state.addTypes({scalarf64}); + + auto operand = module_op_.getOperand({0}) // complete guess + state.addOperands(ValueRange({operand})); + + auto operation = module_op_.front(); // complete guess + state.addAttribute("fn", operation.getAttr("sym_name")); + auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); + state.addAttribute("activity", {activity}); // mlir::enzyme::Activity::enzyme_active + auto ret_activity = mlir::enzyme::ActivityAttr::get( + ctx, mlir::enzyme::Activity::enzyme_active + ); + state.addAttribute("ret_activity", {ret_activity}); // mlir::enzyme::Activity::enzyme_activenoneed + + auto res = mlir::ModuleOp(&mlir::Builder(module_op_), state); // complete guess + + return reinterpret_cast(new mlir::ModuleOp(res)); + } } diff --git a/spidr/backend/src/mlir/IR/Attributes.h b/spidr/backend/src/mlir/IR/Attributes.h new file mode 100644 index 000000000..fb5e8e3a4 --- /dev/null +++ b/spidr/backend/src/mlir/IR/Attributes.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Attribute; +} diff --git a/spidr/backend/src/mlir/IR/Block.cpp b/spidr/backend/src/mlir/IR/Block.cpp new file mode 100644 index 000000000..43e56ec4e --- /dev/null +++ b/spidr/backend/src/mlir/IR/Block.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/Block.h" + +#include "Block.h" + +extern "C" { + Block* Block_new() { + return reinterpret_cast(new mlir::Block()); + } + + void Block_delete(Block* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/Block.h b/spidr/backend/src/mlir/IR/Block.h new file mode 100644 index 000000000..0b730556d --- /dev/null +++ b/spidr/backend/src/mlir/IR/Block.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Block; +} diff --git a/spidr/backend/src/mlir/IR/Location.cpp b/spidr/backend/src/mlir/IR/Location.cpp new file mode 100644 index 000000000..8d4c7d7be --- /dev/null +++ b/spidr/backend/src/mlir/IR/Location.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/Location.h" + +#include "Location.h" + +extern "C" { + Location* Location_new(...) { + return nullptr; //reinterpret_cast(new mlir::Location(...)); + } + + void Location_delete(Location* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/Location.h b/spidr/backend/src/mlir/IR/Location.h new file mode 100644 index 000000000..438fe63cd --- /dev/null +++ b/spidr/backend/src/mlir/IR/Location.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Location; +} diff --git a/spidr/backend/src/mlir/IR/OperationSupport.cpp b/spidr/backend/src/mlir/IR/OperationSupport.cpp new file mode 100644 index 000000000..fe82530af --- /dev/null +++ b/spidr/backend/src/mlir/IR/OperationSupport.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/OperationSupport.h" + +#include "Attributes.h" +#include "Block.h" +#include "Location.h" +#include "ValueRange.h" + +extern "C" { + struct OperationState; + + OperationState* OperationState_new(Location& location, char* name) { + auto& location_ = reinterpret_cast(location); + auto op_state = new mlir::OperationState(location_, name); + return reinterpret_cast(op_state); + } + + void OperationState_delete(OperationState* s) { + delete reinterpret_cast(s); + } + + void OperationState_addOperands(OperationState& s, ValueRange& newOperands) { + auto& s_ = reinterpret_cast(s); + auto& newOperands_ = reinterpret_cast(newOperands); + s_.addOperands(newOperands_); + } + + void OperationState_addAttribute(OperationState& s, char* name, Attribute& attr) { + auto& s_ = reinterpret_cast(s); + auto& attr_ = reinterpret_cast(attr); + s_.addAttribute(name, attr_); + } + +// void OperationState_addSuccessors(OperationState& s, Block* successor) { +// auto& s_ = reinterpret_cast(s); +// auto successor_ = reinterpret_cast(successor); +// s_.addSuccessors(successor_); +// } +} diff --git a/spidr/backend/src/mlir/IR/ValueRange.h b/spidr/backend/src/mlir/IR/ValueRange.h new file mode 100644 index 000000000..c568df6d0 --- /dev/null +++ b/spidr/backend/src/mlir/IR/ValueRange.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct ValueRange; +} diff --git a/spidr/backend/src/mlir/Pass/PassManager.cpp b/spidr/backend/src/mlir/Pass/PassManager.cpp index 88275ec05..6116840fe 100644 --- a/spidr/backend/src/mlir/Pass/PassManager.cpp +++ b/spidr/backend/src/mlir/Pass/PassManager.cpp @@ -15,8 +15,9 @@ limitations under the License. */ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "Pass.h" #include "../IR/BuiltinOps.h" #include "../IR/MLIRContext.h" #include "../IR/Operation.h" @@ -34,9 +35,11 @@ extern "C" { } void PassManager_addPass(PassManager& s, Pass* pass) { - auto& s_ = reinterpret_cast(s); - auto pass_ = reinterpret_cast(pass); - s_.addPass(pass_); + return; // i hate cpp +// auto& s_ = reinterpret_cast(s); +// auto pass_ = reinterpret_cast(pass); +// auto pass__ = std::unique_ptr{std::exchange(pass_, nullptr)}; +// s_.addPass(pass__); } int PassManager_run(PassManager& s, Operation* op) { diff --git a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp index 6b023a961..6ef61891c 100644 --- a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp +++ b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp @@ -26,13 +26,15 @@ extern "C" { auto& ctx_ = reinterpret_cast(ctx); auto hlo_module_ = reinterpret_cast(hlo_module); auto module_op = xla::ConvertHloToStablehlo(ctx_, hlo_module_); + module_op.value()->dump(); return reinterpret_cast(new mlir::ModuleOp(module_op.value().release())); } HloModuleProto* ConvertStablehloToHlo(ModuleOp& module) { auto& module_ = reinterpret_cast(module); + auto hlo = xla::ConvertStablehloToHlo(module_).value().release(); // mode ToProto to separate function? - auto res = xla::ConvertStablehloToHlo(module_).value().release()->ToProto(); + auto res = hlo->ToProto(); return reinterpret_cast(new xla::HloModuleProto(res)); } } diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index c86ac1821..f6a09f748 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,16 +8,25 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, - Compiler.Enzyme.enzyme.Enzyme.MLIR.Dialect.Dialect, - Compiler.Enzyme.enzyme.Enzyme.MLIR.Passes.Passes, + Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect, + Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes, + Compiler.LLVM.Support.RawOStream, + + Compiler.MLIR.IR.Block, Compiler.MLIR.IR.BuiltinOps, Compiler.MLIR.IR.DialectRegistry, + Compiler.MLIR.IR.Location, Compiler.MLIR.IR.MLIRContext, + Compiler.MLIR.IR.OperationSupport, + Compiler.MLIR.IR.ValueRange, Compiler.MLIR.Pass.PassManager, + Compiler.MLIR.Pass.Pass, + Compiler.StableHLO.Dialect.Register, Compiler.StableHLO.Dialect.Serialization, Compiler.StableHLO.Dialect.Version, + Compiler.Xla.Client.ExecutableBuildOptions, Compiler.Xla.HLO.Builder.Lib.Arithmetic, Compiler.Xla.HLO.Builder.Lib.Constants, diff --git a/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Dialect/Dialect.idr similarity index 90% rename from spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr rename to spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Dialect/Dialect.idr index 0748a87ff..25b240e8f 100644 --- a/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.idr +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Dialect/Dialect.idr @@ -14,8 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Enzyme.enzyme.Enzyme.MLIR.Dialect.Dialect +module Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect +import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.Pass.Pass import Compiler.FFI diff --git a/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr similarity index 89% rename from spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr rename to spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr index 27ef01067..5791e81e2 100644 --- a/spidr/src/Compiler/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.idr +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Enzyme.enzyme.Enzyme.MLIR.Passes.Passes +module Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes import Compiler.MLIR.Pass.Pass import Compiler.FFI @@ -23,7 +23,7 @@ import Compiler.FFI prim__createDifferentiatePass : PrimIO AnyPtr export -createDifferentiatePass : HasIO io => io () +createDifferentiatePass : HasIO io => io Pass createDifferentiatePass = do pass <- primIO prim__createDifferentiatePass pass <- onCollectAny pass (primIO . Pass.prim__delete) diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index b871f3481..78ed8038c 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -23,14 +23,15 @@ import Data.IOArray import Data.List import Data.List.Elem -import Compiler.Enzyme.enzyme.Enzyme.MLIR.Dialect.Dialect -import Compiler.Enzyme.enzyme.Enzyme.MLIR.Passes.Passes -import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl +import Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect +import Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes +--import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl import Compiler.LLVM.Support.RawOStream import Compiler.MLIR.IR.BuiltinOps import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.IR.MLIRContext import Compiler.MLIR.Pass.PassManager +import Compiler.MLIR.Pass.Pass import Compiler.StableHLO.Dialect.Register import Compiler.StableHLO.Dialect.Serialization import Compiler.StableHLO.Dialect.Version diff --git a/spidr/src/Compiler/MLIR/IR/Attributes.idr b/spidr/src/Compiler/MLIR/IR/Attributes.idr new file mode 100644 index 000000000..d07815d42 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Attributes.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.Attributes + +import Compiler.FFI + +public export +data Attribute = MkAttribute GCAnyPtr + +%foreign (libxla "Attribute_new") +prim__mkAttribute : PrimIO AnyPtr + +%foreign (libxla "Attribute_delete") +prim__deleteAttribute : AnyPtr -> PrimIO () + +export +mkAttribute : HasIO io => io Attribute +mkAttribute = do + Attribute <- primIO prim__mkAttribute + Attribute <- onCollectAny Attribute (primIO . prim__deleteAttribute) + pure (MkAttribute Attribute) diff --git a/spidr/src/Compiler/MLIR/IR/Block.idr b/spidr/src/Compiler/MLIR/IR/Block.idr new file mode 100644 index 000000000..921418676 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Block.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.Block + +import Compiler.FFI + +public export +data Block = MkBlock GCAnyPtr + +%foreign (libxla "Block_new") +prim__mkBlock : PrimIO AnyPtr + +%foreign (libxla "Block_delete") +prim__deleteBlock : AnyPtr -> PrimIO () + +export +mkBlock : HasIO io => io Block +mkBlock = do + block <- primIO prim__mkBlock + block <- onCollectAny block (primIO . prim__deleteBlock) + pure (MkBlock block) diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr index ba44b6583..757c078b4 100644 --- a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -14,12 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.MLIR.IR.BuiltinOps +module Compiler.MLIR.IR.Operation import Compiler.FFI public export -data ModuleOp = MkModuleOp GCAnyPtr +data Operation = MkOperation GCAnyPtr export %foreign (libxla "ModuleOp_delete") diff --git a/spidr/src/Compiler/MLIR/IR/Location.idr b/spidr/src/Compiler/MLIR/IR/Location.idr new file mode 100644 index 000000000..d845e477c --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Location.idr @@ -0,0 +1,22 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.Location + +import Compiler.FFI + +public export +data Location = MkLocation GCAnyPtr diff --git a/spidr/src/Compiler/MLIR/IR/Operation.idr b/spidr/src/Compiler/MLIR/IR/Operation.idr new file mode 100644 index 000000000..ba44b6583 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Operation.idr @@ -0,0 +1,26 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.BuiltinOps + +import Compiler.FFI + +public export +data ModuleOp = MkModuleOp GCAnyPtr + +export +%foreign (libxla "ModuleOp_delete") +prim__delete : AnyPtr -> PrimIO () diff --git a/spidr/src/Compiler/MLIR/IR/OperationSupport.idr b/spidr/src/Compiler/MLIR/IR/OperationSupport.idr new file mode 100644 index 000000000..994c55939 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/OperationSupport.idr @@ -0,0 +1,52 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.Pass.OperationSupport + +import Compiler.MLIR.IR.Location +import Compiler.FFI + +public export +data OperationState = MkOperationState GCAnyPtr + +%foreign (libxla "OperationState_new") +prim__mkOperationState : GCAnyPtr -> String -> PrimIO AnyPtr + +%foreign (libxla "OperationState_delete") +prim__delete : AnyPtr -> PrimIO () + +export +mkOperationState : HasIO io => Location -> String -> io OperationState +mkOperationState (MkLocation location) name = do + opState <- primIO $ prim__mkOperationState location name + opState <- onCollectAny opState (primIO . OperationState.prim__delete) + pure (MkOperationState opState) + +%foreign (libxla "OperationState_addOperands") +prim__operationStateAddOperands : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +addOperands : HasIO io => OperationState -> ValueRange -> io () +addOperands (MkOperationState opState) (MkValueRange valueRange) = + primIO $ prim__operationStateAddOperands opState valueRange + +%foreign (libxla "OperationState_addAttribute") +prim__operationStateAddAttribute : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +addAttribute : HasIO io => OperationState -> Attribute -> io () +addAttribute (MkOperationState opState) (MkAttribute attribute) = + primIO $ prim__operationStateAddAttribute opState attribute diff --git a/spidr/src/Compiler/MLIR/IR/ValueRange.idr b/spidr/src/Compiler/MLIR/IR/ValueRange.idr new file mode 100644 index 000000000..b347b30fa --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/ValueRange.idr @@ -0,0 +1,22 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.ValueRange + +import Compiler.FFI + +public export +data ValueRange = MkValueRange GCAnyPtr diff --git a/spidr/src/Compiler/MLIR/Pass/Pass.idr b/spidr/src/Compiler/MLIR/Pass/Pass.idr index 1f1058893..af268a944 100644 --- a/spidr/src/Compiler/MLIR/Pass/Pass.idr +++ b/spidr/src/Compiler/MLIR/Pass/Pass.idr @@ -21,5 +21,6 @@ import Compiler.FFI public export data Pass = MkPass GCAnyPtr +export %foreign (libxla "Pass_delete") prim__delete : AnyPtr -> PrimIO () diff --git a/spidr/src/Compiler/MLIR/Pass/PassManager.idr b/spidr/src/Compiler/MLIR/Pass/PassManager.idr index b44a3d50f..eb4972221 100644 --- a/spidr/src/Compiler/MLIR/Pass/PassManager.idr +++ b/spidr/src/Compiler/MLIR/Pass/PassManager.idr @@ -18,6 +18,7 @@ module Compiler.MLIR.Pass.PassManager import Compiler.MLIR.IR.BuiltinOps import Compiler.MLIR.IR.MLIRContext +import Compiler.MLIR.Pass.Pass import Compiler.FFI public export @@ -37,7 +38,7 @@ mkPassManager (MkMLIRContext ctx) = do pure (MkPassManager manager) %foreign (libxla "PassManager_addPass") -prim__passManagerAddPass : GCAnyPtr -> GCAnyPtr -> PrimIO Int +prim__passManagerAddPass : GCAnyPtr -> GCAnyPtr -> PrimIO () export addPass : HasIO io => PassManager -> Pass -> io () diff --git a/test/runner/TestRunner.idr b/test/runner/TestRunner.idr index 792d2ae12..835c775c3 100644 --- a/test/runner/TestRunner.idr +++ b/test/runner/TestRunner.idr @@ -30,12 +30,12 @@ import Unit.TestUtil export run : Device -> IO () -run device = test [ +run device = test [{- Utils.TestComparison.group , TestUtils.group , Unit.TestUtil.group , Unit.TestLiteral.group - , Unit.TestTensor.group + ,-} Unit.TestTensor.group{- , Unit.TestDistribution.group - , Unit.Model.TestKernel.group + , Unit.Model.TestKernel.group-} ] diff --git a/test/runner/Unit/TestTensor.idr b/test/runner/Unit/TestTensor.idr index 0247a2656..20c5762fb 100644 --- a/test/runner/Unit/TestTensor.idr +++ b/test/runner/Unit/TestTensor.idr @@ -478,7 +478,7 @@ trace = fixedProperty $ export group : Device => Group -group = MkGroup "Tensor" $ [ +group = MkGroup "Tensor" $ [{- ("eval . tensor", tensorThenEval) , ("eval multiple tensors (tuple)", evalTuple) , ("eval multiple tensors (tuple) for non-trivial graph", evalTupleNonTrivial) @@ -499,12 +499,12 @@ group = MkGroup "Tensor" $ [ , ("cholesky", cholesky) , (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse) , (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems) - , ("trace", trace) + , ("trace", trace)-} ] ++ concat (the (List _) [ Unit.TestTensor.AD.all - , Unit.TestTensor.Elementwise.all + {-, Unit.TestTensor.Elementwise.all , Unit.TestTensor.HigherOrder.all , Unit.TestTensor.Sampling.all , Unit.TestTensor.Slice.all - , Unit.TestTensor.Structure.all + , Unit.TestTensor.Structure.all-} ]) diff --git a/test/runner/Unit/TestTensor/AD.idr b/test/runner/Unit/TestTensor/AD.idr index c67550a78..c0ac11bad 100644 --- a/test/runner/Unit/TestTensor/AD.idr +++ b/test/runner/Unit/TestTensor/AD.idr @@ -26,7 +26,8 @@ import Utils.Cases square : Device => Property square = fixedProperty $ do - grad (pure . square) (tensor 1.0) ===# pure (tensor 2.0) + sqrt (square $ tensor 3.0) ===# tensor 3.0 +-- grad (pure . square) (tensor 1.0) ===# pure (tensor 2.0) {- shape <- forAll shapes From 0a5f935c345118ee847d011ace9d181473f0b74d Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 23 Dec 2024 21:20:31 +0000 Subject: [PATCH 24/38] compiling with runtime errors --- .../enzyme/Enzyme/MLIR/Passes/Passes.cpp | 159 ++++++++++++++++-- .../Enzyme/Enzyme/MLIR/Passes/Passes.idr | 12 ++ spidr/src/Compiler/Eval.idr | 2 +- spidr/src/Compiler/MLIR/IR/Attributes.idr | 6 +- spidr/src/Compiler/MLIR/IR/BuiltinOps.idr | 4 +- .../src/Compiler/MLIR/IR/OperationSupport.idr | 6 +- test/runner/Unit/TestTensor/AD.idr | 2 +- 7 files changed, 171 insertions(+), 20 deletions(-) diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp index e30bb7203..c9763a7fb 100644 --- a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp @@ -15,34 +15,171 @@ limitations under the License. */ #include "Enzyme/MLIR/Passes/Passes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Location.h" #include "../../../../../mlir/Pass/Pass.h" -// for AD function +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Enzyme/MLIR/Dialect/Ops.h" +//#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h" +//#include "Enzyme/MLIR/PassDetails.h" + +//#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "../../../../../mlir/IR/BuiltinOps.h" +#include "../../../../../mlir/IR/DialectRegistry.h" + +//void registerStableHLODialectAutoDiffInterface( +// DialectRegistry ®istry) { +// registry.addExtension(+[](MLIRContext *context, +// stablehlo::StablehloDialect *) { +// registerInterfaces(context); +// +// // SortOp::attachInterface(*context); +// +// WhileOp::attachInterface(*context); +// SortOp::attachInterface(*context); +// ScatterOp::attachInterface(*context); +// ReduceOp::attachInterface(*context); +// +// CaseOp::attachInterface(*context); +// +// ScatterOp::attachInterface(*context); +// ScatterOp::attachInterface(*context); +// +// ReturnOp::attachInterface(*context); +// +// ReduceOp::attachInterface>(*context); +// IfOp::attachInterface(*context); +// IfOp::attachInterface(*context); +// IfOp::attachInterface(*context); +// +// WhileOp::attachInterface(*context); +// WhileOp::attachInterface(*context); +// ReduceOp::attachInterface>(*context); +// WhileOp::attachInterface>(*context); +// BroadcastInDimOp::attachInterface(*context); +// SliceOp::attachInterface(*context); +// DynamicUpdateSliceOp::attachInterface( +// *context); +// ReduceOp::attachInterface(*context); +// ConcatenateOp::attachInterface(*context); +// +// ConstantOp::attachInterface(*context); +// TransposeOp::attachInterface(*context); +// IfOp::attachInterface>(*context); +// WhileOp::attachInterface>(*context); +// +// ReverseOp::attachInterface>( +// *context); // TODO: simpler version with newly named dims +// ScatterOp::attachInterface>( +// *context); // TODO: simpler version with newly named dims +// ConvolutionOp::attachInterface>( +// *context); // TODO: simpler version with newly named dims +// }); +//} + +//void register_all(mlir::DialectRegistry& reg) { +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// registry.insert(); +// +// registry.insert(); +// +// mlir::registerenzymePasses(); +// regsiterenzymeXLAPasses(); +// mlir::enzyme::registerXLAAutoDiffInterfaces(registry); +// +// mlir::func::registerInlinerExtension(registry); +// +// // Register the standard passes we want. +// mlir::registerCSEPass(); +// mlir::registerConvertAffineToStandardPass(); +// mlir::registerSCCPPass(); +// mlir::registerInlinerPass(); +// mlir::registerCanonicalizerPass(); +// mlir::registerSymbolDCEPass(); +// mlir::registerLoopInvariantCodeMotionPass(); +// mlir::registerConvertSCFToOpenMPPass(); +// mlir::affine::registerAffinePasses(); +// mlir::registerReconcileUnrealizedCasts(); +// +// mlir::registerLLVMDialectImport(registry); +// mlir::registerNVVMDialectImport(registry); +// +// mlir::LLVM::registerInlinerInterface(registry); +// +// /* +// registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { +// LLVM::LLVMFunctionType::attachInterface(*ctx); +// LLVM::LLVMArrayType::attachInterface(*ctx); +// LLVM::LLVMPointerType::attachInterface(*ctx); +// LLVM::LLVMStructType::attachInterface(*ctx); +// MemRefType::attachInterface>(*ctx); +// LLVM::LLVMStructType::attachInterface< +// PtrElementModel>(*ctx); +// LLVM::LLVMPointerType::attachInterface< +// PtrElementModel>(*ctx); +// LLVM::LLVMArrayType::attachInterface>( +// *ctx); +// }); +// */ +// +// // Register the autodiff interface implementations for upstream dialects. +// enzyme::registerCoreDialectAutodiffInterfaces(registry); +// +// // Transform dialect and extensions. +// mlir::transform::registerInterpreterPass(); +// mlir::linalg::registerTransformDialectExtension(registry); +// mlir::enzyme::registerGenerateApplyPatternsPass(); +// mlir::enzyme::registerRemoveTransformPass(); +// mlir::enzyme::registerEnzymeJaxTransformExtension(registry); +//} extern "C" { Pass* createDifferentiatePass() { return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); } - ModuleOp* emitEnzymeADOp(ModuleOp& module_op) { + // doesn't belong here + ModuleOp* emitEnzymeADOp(ModuleOp& module_op, DialectRegistry& registry) { + auto& registry_ = reinterpret_cast(registry); + + mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry_); + // registerXLAAutoDiffInterfaces(registry_); + // mlir::linalg::registerTransformDialectExtension(registry_); // how to import? + // mlir::enzyme::registerEnzymeJaxTransformExtension(registry_); // not public + // mlir::func::registerInlinerExtension(registry_); // not tried + auto module_op_ = reinterpret_cast(module_op); - auto state = mlir::OperationState("enzyme.autodiff", mlir::Location()); + auto ctx = module_op_.getContext(); + auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); - auto scalarf64 = mlir::RankedTensorType::Builder() - .setShape({}) - .setElementType(mlir::FloatType::getF64(ctx)); + auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); state.addTypes({scalarf64}); - auto operand = module_op_.getOperand({0}) // complete guess - state.addOperands(ValueRange({operand})); + auto operands = module_op_.getOperation()->getOperands(); // complete guess + state.addOperands(mlir::ValueRange(operands)); - auto operation = module_op_.front(); // complete guess - state.addAttribute("fn", operation.getAttr("sym_name")); + auto operation = module_op_.getOperation(); // complete guess + state.addAttribute("fn", operation->getAttr("sym_name")); auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); state.addAttribute("activity", {activity}); // mlir::enzyme::Activity::enzyme_active auto ret_activity = mlir::enzyme::ActivityAttr::get( @@ -50,7 +187,7 @@ extern "C" { ); state.addAttribute("ret_activity", {ret_activity}); // mlir::enzyme::Activity::enzyme_activenoneed - auto res = mlir::ModuleOp(&mlir::Builder(module_op_), state); // complete guess + auto res = mlir::Operation::create(state); return reinterpret_cast(new mlir::ModuleOp(res)); } diff --git a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr index 5791e81e2..31bcf29e0 100644 --- a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr @@ -16,6 +16,8 @@ limitations under the License. ||| For internal spidr use only. module Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.DialectRegistry import Compiler.MLIR.Pass.Pass import Compiler.FFI @@ -28,3 +30,13 @@ createDifferentiatePass = do pass <- primIO prim__createDifferentiatePass pass <- onCollectAny pass (primIO . Pass.prim__delete) pure (MkPass pass) + +%foreign (libxla "emitEnzymeADOp") +prim__emitEnzymeADOp : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr + +export +emitEnzymeADOp : HasIO io => ModuleOp -> DialectRegistry -> io ModuleOp +emitEnzymeADOp (MkModuleOp op) (MkDialectRegistry reg) = do + op <- primIO $ prim__emitEnzymeADOp op reg + op <- onCollectAny op (primIO . BuiltinOps.prim__delete) + pure (MkModuleOp op) diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 78ed8038c..2734ffd39 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -147,7 +147,7 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do addPass mgr !createDifferentiatePass computation <- compile xlaBuilder f stablehlo <- hloModuleProtoToStableHLO !(proto computation) - enzymeOp <- ?enzymeAutodiffReverseOp stablehlo + enzymeOp <- emitEnzymeADOp stablehlo reg True <- run mgr enzymeOp | False => throwE $ MlirPassError "Failed to run differentiate pass on StableHLO" hloProto <- convertStablehloToHlo stablehlo diff --git a/spidr/src/Compiler/MLIR/IR/Attributes.idr b/spidr/src/Compiler/MLIR/IR/Attributes.idr index d07815d42..c6b225c42 100644 --- a/spidr/src/Compiler/MLIR/IR/Attributes.idr +++ b/spidr/src/Compiler/MLIR/IR/Attributes.idr @@ -30,6 +30,6 @@ prim__deleteAttribute : AnyPtr -> PrimIO () export mkAttribute : HasIO io => io Attribute mkAttribute = do - Attribute <- primIO prim__mkAttribute - Attribute <- onCollectAny Attribute (primIO . prim__deleteAttribute) - pure (MkAttribute Attribute) + attr <- primIO prim__mkAttribute + attr <- onCollectAny attr (primIO . prim__deleteAttribute) + pure (MkAttribute attr) diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr index 757c078b4..ba44b6583 100644 --- a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -14,12 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.MLIR.IR.Operation +module Compiler.MLIR.IR.BuiltinOps import Compiler.FFI public export -data Operation = MkOperation GCAnyPtr +data ModuleOp = MkModuleOp GCAnyPtr export %foreign (libxla "ModuleOp_delete") diff --git a/spidr/src/Compiler/MLIR/IR/OperationSupport.idr b/spidr/src/Compiler/MLIR/IR/OperationSupport.idr index 994c55939..fcb69cbf0 100644 --- a/spidr/src/Compiler/MLIR/IR/OperationSupport.idr +++ b/spidr/src/Compiler/MLIR/IR/OperationSupport.idr @@ -14,9 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.MLIR.Pass.OperationSupport +module Compiler.MLIR.IR.OperationSupport +import Compiler.MLIR.IR.Attributes import Compiler.MLIR.IR.Location +import Compiler.MLIR.IR.ValueRange import Compiler.FFI public export @@ -32,7 +34,7 @@ export mkOperationState : HasIO io => Location -> String -> io OperationState mkOperationState (MkLocation location) name = do opState <- primIO $ prim__mkOperationState location name - opState <- onCollectAny opState (primIO . OperationState.prim__delete) + opState <- onCollectAny opState (primIO . OperationSupport.prim__delete) pure (MkOperationState opState) %foreign (libxla "OperationState_addOperands") diff --git a/test/runner/Unit/TestTensor/AD.idr b/test/runner/Unit/TestTensor/AD.idr index c0ac11bad..3c7349f58 100644 --- a/test/runner/Unit/TestTensor/AD.idr +++ b/test/runner/Unit/TestTensor/AD.idr @@ -27,7 +27,7 @@ import Utils.Cases square : Device => Property square = fixedProperty $ do sqrt (square $ tensor 3.0) ===# tensor 3.0 --- grad (pure . square) (tensor 1.0) ===# pure (tensor 2.0) + grad (pure . square) (tensor 1.0) ===# pure (tensor 2.0) {- shape <- forAll shapes From 41bc6ad6163e5e303e1027a6f5594943927c8499 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 28 Dec 2024 00:29:58 +0000 Subject: [PATCH 25/38] wip --- XLA_VERSION | 2 +- spidr/backend/BUILD | 4 + spidr/backend/build.sh | 6 +- .../src/Enzyme-JAX/src/enzyme_ad/jax/BUILD | 12 ++ .../Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD | 33 ++++ .../src/enzyme_ad/jax/Passes/Passes.cpp | 175 +++++++++++++++++ .../src/enzyme_ad/jax/RegistryUtils.cpp | 24 +++ .../Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD | 1 + .../enzyme/Enzyme/MLIR/Passes/Passes.cpp | 181 +----------------- spidr/backend/src/mlir/IR/BuiltinOps.cpp | 7 + spidr/backend/src/mlir/IR/MLIRContext.cpp | 5 + .../src/xla/hlo/builder/xla_builder.cpp | 5 +- spidr/spidr.ipkg | 3 + .../Enzyme/Enzyme/MLIR/Passes/Passes.idr | 10 +- .../Src/EnzymeAD/JAX/Passes/Passes.idr | 33 ++++ .../Src/EnzymeAD/JAX/RegistryUtils.idr | 27 +++ spidr/src/Compiler/Eval.idr | 24 ++- spidr/src/Compiler/MLIR/IR/BuiltinOps.idr | 12 ++ spidr/src/Compiler/MLIR/IR/MLIRContext.idr | 10 + test/runner/Unit/TestTensor/AD.idr | 12 +- 20 files changed, 376 insertions(+), 210 deletions(-) create mode 100644 spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD create mode 100644 spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD create mode 100644 spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp create mode 100644 spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp create mode 100644 spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr create mode 100644 spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr diff --git a/XLA_VERSION b/XLA_VERSION index 2cf52560e..412bc18ca 100644 --- a/XLA_VERSION +++ b/XLA_VERSION @@ -1 +1 @@ -46bd34518db14456697224b9684f47bf4cdcf6ba \ No newline at end of file +b44f55da3dac449f03466815ac431474f86fd73f \ No newline at end of file diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index 414b2b47c..fd36a2838 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,6 +12,8 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ + "//src/Enzyme-JAX/src/enzyme_ad/jax", + "//src/Enzyme-JAX/src/enzyme_ad/jax/Passes", "//src/Enzyme/enzyme/Enzyme/MLIR/Dialect", "//src/Enzyme/enzyme/Enzyme/MLIR/Passes", "//src/llvm/Support", @@ -30,6 +32,8 @@ cc_binary( "//src", ], deps = [ + "//src/Enzyme-JAX/src/enzyme_ad/jax", + "//src/Enzyme-JAX/src/enzyme_ad/jax/Passes", "//src/Enzyme/enzyme/Enzyme/MLIR/Dialect", "//src/Enzyme/enzyme/Enzyme/MLIR/Passes", "//src/llvm/Support", diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index 61d5612e4..6da54e174 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -29,9 +29,9 @@ esac mkdir xla install_xla "$xla_rev" xla (cd xla; ./configure.py --backend=cpu --os=$os) - mkdir Enzyme-JAX - install_enzyme "$enzyme_rev" Enzyme-JAX - # sed -i -e 's/"-Werror=unused-variable",//g' Enzyme-JAX/src/enzyme_ad/jax/BUILD + # depending on Enzyme-JAX is problematic as it fixes the XLA version. Can we only depend on enzyme? +# mkdir Enzyme-JAX +# install_enzyme "$enzyme_rev" Enzyme-JAX bazel build //:c_xla rm -rf xla ) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD new file mode 100644 index 000000000..42058b418 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "jax", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@enzyme-jax//:everything", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD new file mode 100644 index 000000000..d9b330d82 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD @@ -0,0 +1,33 @@ +cc_library( + name = "Passes", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/hlo/builder/lib:math", + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "@enzyme-jax//:everything", + "//src/mlir/IR", + "//src/mlir/Pass", + ], + visibility = ["//visibility:public"], +) + +cc_binary( + name = "example", + linkstatic = True, + srcs = glob(["*.cpp"]), + deps = [ + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/hlo/builder/lib:math", + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "@enzyme-jax//:everything", + "//src/mlir/IR", + "//src/mlir/Pass", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp new file mode 100644 index 000000000..6b1e8d405 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -0,0 +1,175 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "../../../../../mlir/IR/BuiltinOps.h" +#include "../../../../../mlir/IR/DialectRegistry.h" +#include "../../../../../mlir/Pass/Pass.h" + +#include "stablehlo/dialect/Register.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" + +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Dialect/Ops.h" +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Enzyme/MLIR/Passes/Passes.h" + +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/TransformOps/TransformOps.h" +#include "src/enzyme_ad/jax/RegistryUtils.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" + +#include "llvm/Support/TargetSelect.h" + +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/tests/CheckOps.h" + +class MemRefInsider + : public mlir::MemRefElementTypeInterface::FallbackModel {}; + +template +struct PtrElementModel + : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< + PtrElementModel, T> {}; + +//extern "C" { +// void regsiterenzymeXLAPasses_() { +// regsiterenzymeXLAPasses(); +// } +// +// void registerenzymePasses() { +// mlir::registerenzymePasses(); +// } +// +//// Pass* createDifferentiatePass() { +//// return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); +//// } + +int main() { + xla::XlaBuilder builder("root"); + auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12); + auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg"); + auto proto = builder.Build(xla::Square(arg))->proto(); + + mlir::MLIRContext ctx; + mlir::DialectRegistry registry_; + ctx.appendDialectRegistry(registry_); + mlir::mhlo::registerAllMhloDialects(registry_); + mlir::stablehlo::registerAllDialects(registry_); + + auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release(); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + registry_.insert(); + prepareRegistry(registry_); + + mlir::registerenzymePasses(); + regsiterenzymeXLAPasses(); + + mlir::registerCSEPass(); + mlir::registerConvertAffineToStandardPass(); + mlir::registerSCCPPass(); + mlir::registerInlinerPass(); + mlir::registerCanonicalizerPass(); + mlir::registerSymbolDCEPass(); + mlir::registerLoopInvariantCodeMotionPass(); + mlir::registerConvertSCFToOpenMPPass(); + mlir::affine::registerAffinePasses(); + mlir::registerReconcileUnrealizedCasts(); + + registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) { + mlir::LLVM::LLVMFunctionType::attachInterface(*ctx); + mlir::LLVM::LLVMArrayType::attachInterface(*ctx); + mlir::LLVM::LLVMPointerType::attachInterface(*ctx); + mlir::LLVM::LLVMStructType::attachInterface(*ctx); + mlir::MemRefType::attachInterface>(*ctx); + mlir::LLVM::LLVMStructType::attachInterface< + PtrElementModel>(*ctx); + mlir::LLVM::LLVMPointerType::attachInterface< + PtrElementModel>(*ctx); + mlir::LLVM::LLVMArrayType::attachInterface>(*ctx); + }); + + mlir::transform::registerInterpreterPass(); + mlir::enzyme::registerGenerateApplyPatternsPass(); + mlir::enzyme::registerRemoveTransformPass(); + + auto state = mlir::OperationState(mlir::UnknownLoc::get(&ctx), "enzyme.autodiff"); + + auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(&ctx)); + state.addTypes({scalarf64}); + + auto operands = module_op_.getOperation()->getOperands(); // complete guess + state.addOperands(mlir::ValueRange(operands)); + + auto operation = module_op_.getOperation(); // complete guess + state.addAttribute("fn", operation->getAttr("sym_name")); + auto activity = mlir::enzyme::ActivityAttr::get(&ctx, mlir::enzyme::Activity::enzyme_active); + state.addAttribute("activity", {activity}); + auto ret_activity = mlir::enzyme::ActivityAttr::get( + &ctx, mlir::enzyme::Activity::enzyme_activenoneed + ); + state.addAttribute("ret_activity", {ret_activity}); + + auto res = mlir::Operation::create(state); + + return 0; + +// return reinterpret_cast(new mlir::ModuleOp(res)); +//} +} diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp new file mode 100644 index 000000000..5a0f5e983 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +//#include "src/enzyme_ad/jax/RegistryUtils.h" +// +//#include "../../../../mlir/IR/DialectRegistry.h" +// +//extern "C" { +// void prepareRegistry_(DialectRegistry& registry) { +// prepareRegistry(reinterpret_cast(registry)); +// } +//} diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD index f2b284155..2de39da72 100644 --- a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD @@ -5,6 +5,7 @@ cc_library( srcs = glob(["*.cpp"]), hdrs = glob(["*.h"]), deps = [ + "@stablehlo//:register", "@enzyme//:EnzymeMLIR", "@llvm-project//mlir:IR", "//src/mlir/IR", diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp index c9763a7fb..1777112ab 100644 --- a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp @@ -12,183 +12,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -*/ -#include "Enzyme/MLIR/Passes/Passes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Location.h" - -#include "../../../../../mlir/Pass/Pass.h" - -#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" -#include "Enzyme/MLIR/Dialect/Ops.h" -//#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h" -//#include "Enzyme/MLIR/PassDetails.h" - -//#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/IR/BuiltinTypes.h" - -#include "../../../../../mlir/IR/BuiltinOps.h" -#include "../../../../../mlir/IR/DialectRegistry.h" - -//void registerStableHLODialectAutoDiffInterface( -// DialectRegistry ®istry) { -// registry.addExtension(+[](MLIRContext *context, -// stablehlo::StablehloDialect *) { -// registerInterfaces(context); -// -// // SortOp::attachInterface(*context); -// -// WhileOp::attachInterface(*context); -// SortOp::attachInterface(*context); -// ScatterOp::attachInterface(*context); -// ReduceOp::attachInterface(*context); -// -// CaseOp::attachInterface(*context); -// -// ScatterOp::attachInterface(*context); -// ScatterOp::attachInterface(*context); -// -// ReturnOp::attachInterface(*context); -// -// ReduceOp::attachInterface>(*context); -// IfOp::attachInterface(*context); -// IfOp::attachInterface(*context); -// IfOp::attachInterface(*context); -// -// WhileOp::attachInterface(*context); -// WhileOp::attachInterface(*context); -// ReduceOp::attachInterface>(*context); -// WhileOp::attachInterface>(*context); -// BroadcastInDimOp::attachInterface(*context); -// SliceOp::attachInterface(*context); -// DynamicUpdateSliceOp::attachInterface( -// *context); -// ReduceOp::attachInterface(*context); -// ConcatenateOp::attachInterface(*context); -// -// ConstantOp::attachInterface(*context); -// TransposeOp::attachInterface(*context); -// IfOp::attachInterface>(*context); -// WhileOp::attachInterface>(*context); -// -// ReverseOp::attachInterface>( -// *context); // TODO: simpler version with newly named dims -// ScatterOp::attachInterface>( -// *context); // TODO: simpler version with newly named dims -// ConvolutionOp::attachInterface>( -// *context); // TODO: simpler version with newly named dims -// }); -//} - -//void register_all(mlir::DialectRegistry& reg) { -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// registry.insert(); -// -// registry.insert(); -// -// mlir::registerenzymePasses(); -// regsiterenzymeXLAPasses(); -// mlir::enzyme::registerXLAAutoDiffInterfaces(registry); -// -// mlir::func::registerInlinerExtension(registry); -// -// // Register the standard passes we want. -// mlir::registerCSEPass(); -// mlir::registerConvertAffineToStandardPass(); -// mlir::registerSCCPPass(); -// mlir::registerInlinerPass(); -// mlir::registerCanonicalizerPass(); -// mlir::registerSymbolDCEPass(); -// mlir::registerLoopInvariantCodeMotionPass(); -// mlir::registerConvertSCFToOpenMPPass(); -// mlir::affine::registerAffinePasses(); -// mlir::registerReconcileUnrealizedCasts(); -// -// mlir::registerLLVMDialectImport(registry); -// mlir::registerNVVMDialectImport(registry); -// -// mlir::LLVM::registerInlinerInterface(registry); -// -// /* -// registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { -// LLVM::LLVMFunctionType::attachInterface(*ctx); -// LLVM::LLVMArrayType::attachInterface(*ctx); -// LLVM::LLVMPointerType::attachInterface(*ctx); -// LLVM::LLVMStructType::attachInterface(*ctx); -// MemRefType::attachInterface>(*ctx); -// LLVM::LLVMStructType::attachInterface< -// PtrElementModel>(*ctx); -// LLVM::LLVMPointerType::attachInterface< -// PtrElementModel>(*ctx); -// LLVM::LLVMArrayType::attachInterface>( -// *ctx); -// }); -// */ -// -// // Register the autodiff interface implementations for upstream dialects. -// enzyme::registerCoreDialectAutodiffInterfaces(registry); -// -// // Transform dialect and extensions. -// mlir::transform::registerInterpreterPass(); -// mlir::linalg::registerTransformDialectExtension(registry); -// mlir::enzyme::registerGenerateApplyPatternsPass(); -// mlir::enzyme::registerRemoveTransformPass(); -// mlir::enzyme::registerEnzymeJaxTransformExtension(registry); -//} - -extern "C" { - Pass* createDifferentiatePass() { - return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); - } - - // doesn't belong here - ModuleOp* emitEnzymeADOp(ModuleOp& module_op, DialectRegistry& registry) { - auto& registry_ = reinterpret_cast(registry); - - mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry_); - // registerXLAAutoDiffInterfaces(registry_); - // mlir::linalg::registerTransformDialectExtension(registry_); // how to import? - // mlir::enzyme::registerEnzymeJaxTransformExtension(registry_); // not public - // mlir::func::registerInlinerExtension(registry_); // not tried - - auto module_op_ = reinterpret_cast(module_op); - - auto ctx = module_op_.getContext(); - auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); - - auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); - state.addTypes({scalarf64}); - - auto operands = module_op_.getOperation()->getOperands(); // complete guess - state.addOperands(mlir::ValueRange(operands)); - - auto operation = module_op_.getOperation(); // complete guess - state.addAttribute("fn", operation->getAttr("sym_name")); - auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); - state.addAttribute("activity", {activity}); // mlir::enzyme::Activity::enzyme_active - auto ret_activity = mlir::enzyme::ActivityAttr::get( - ctx, mlir::enzyme::Activity::enzyme_active - ); - state.addAttribute("ret_activity", {ret_activity}); // mlir::enzyme::Activity::enzyme_activenoneed - - auto res = mlir::Operation::create(state); - - return reinterpret_cast(new mlir::ModuleOp(res)); - } -} +*/ \ No newline at end of file diff --git a/spidr/backend/src/mlir/IR/BuiltinOps.cpp b/spidr/backend/src/mlir/IR/BuiltinOps.cpp index 5609c3a95..ffa671600 100644 --- a/spidr/backend/src/mlir/IR/BuiltinOps.cpp +++ b/spidr/backend/src/mlir/IR/BuiltinOps.cpp @@ -16,9 +16,16 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "BuiltinOps.h" +#include "MLIRContext.h" extern "C" { void ModuleOp_delete(ModuleOp* s) { delete reinterpret_cast(s); } + + // who owns this? + MLIRContext* ModuleOp_getContext(ModuleOp& s) { + auto s_ = reinterpret_cast(s); + return reinterpret_cast(s_.getContext()); + } } diff --git a/spidr/backend/src/mlir/IR/MLIRContext.cpp b/spidr/backend/src/mlir/IR/MLIRContext.cpp index 9083b03e7..361e43a90 100644 --- a/spidr/backend/src/mlir/IR/MLIRContext.cpp +++ b/spidr/backend/src/mlir/IR/MLIRContext.cpp @@ -27,6 +27,11 @@ extern "C" { delete reinterpret_cast(s); } +// DialectRegistry* MLIRContext_getDialectRegistry(MLIRContext& s) { +// auto& s_ = reinterpret_cast(s); +// return reinterpret_cast(s_.getDialectRegistry()); +// } + void MLIRContext_appendDialectRegistry(MLIRContext& s, DialectRegistry& registry) { auto& registry_ = reinterpret_cast(registry); reinterpret_cast(s).appendDialectRegistry(registry_); diff --git a/spidr/backend/src/xla/hlo/builder/xla_builder.cpp b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp index 289ca6938..0b925626b 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_builder.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp @@ -350,7 +350,10 @@ extern "C" { } XlaOp* Abs(XlaOp& operand) { return unaryOp(xla::Abs, operand); } - XlaOp* Exp(XlaOp& operand) { return unaryOp(xla::Exp, operand); } + XlaOp* Exp(XlaOp& operand) { + xla::XlaOp res = xla::Exp(reinterpret_cast(operand)); + return reinterpret_cast(new xla::XlaOp(res)); + } XlaOp* Floor(XlaOp& operand) { return unaryOp(xla::Floor, operand); } XlaOp* Ceil(XlaOp& operand) { return unaryOp(xla::Ceil, operand); } XlaOp* Log(XlaOp& operand) { return unaryOp(xla::Log, operand); } diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index f6a09f748..486eeaf6b 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,6 +8,9 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, + Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes, + Compiler.EnzymeJAX.Src.EnzymeAD.JAX.RegistryUtils, + Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect, Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes, diff --git a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr index 31bcf29e0..c31e148e8 100644 --- a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr @@ -32,11 +32,11 @@ createDifferentiatePass = do pure (MkPass pass) %foreign (libxla "emitEnzymeADOp") -prim__emitEnzymeADOp : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr +prim__emitEnzymeADOp : PrimIO () export -emitEnzymeADOp : HasIO io => ModuleOp -> DialectRegistry -> io ModuleOp -emitEnzymeADOp (MkModuleOp op) (MkDialectRegistry reg) = do - op <- primIO $ prim__emitEnzymeADOp op reg - op <- onCollectAny op (primIO . BuiltinOps.prim__delete) +emitEnzymeADOp : HasIO io => ModuleOp -> io ModuleOp +emitEnzymeADOp (MkModuleOp op) = do + _ <- primIO $ prim__emitEnzymeADOp + --op <- onCollectAny op (primIO . BuiltinOps.prim__delete) pure (MkModuleOp op) diff --git a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr new file mode 100644 index 000000000..dbae4dd88 --- /dev/null +++ b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr @@ -0,0 +1,33 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes + +import Compiler.FFI + +%foreign (libxla "regsiterenzymeXLAPasses_") +prim__regsiterenzymeXLAPasses : PrimIO () + +export +regsiterenzymeXLAPasses : HasIO io => io () +regsiterenzymeXLAPasses = primIO prim__regsiterenzymeXLAPasses + +%foreign (libxla "registerenzymePasses") +prim__registerenzymePasses : PrimIO () + +export +registerenzymePasses : HasIO io => io () +registerenzymePasses = primIO prim__registerenzymePasses diff --git a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr new file mode 100644 index 000000000..65f9bff4c --- /dev/null +++ b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.EnzymeJAX.Src.EnzymeAD.JAX.RegistryUtils + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "prepareRegistry_") +prim__prepareRegistry : GCAnyPtr -> PrimIO () + +export +prepareRegistry : HasIO io => DialectRegistry -> io () +prepareRegistry (MkDialectRegistry registry) = primIO $ prim__prepareRegistry registry diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 2734ffd39..5c47dd7d5 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -23,9 +23,10 @@ import Data.IOArray import Data.List import Data.List.Elem +import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes +import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.RegistryUtils import Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect import Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes ---import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Implementations.StableHLOAutoDiffOpInterfaceImpl import Compiler.LLVM.Support.RawOStream import Compiler.MLIR.IR.BuiltinOps import Compiler.MLIR.IR.DialectRegistry @@ -135,19 +136,22 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do interpretE (Tuple xs) = tuple xlaBuilder !(traverse interpretE xs) interpretE (GetTupleElement idx x) = getTupleElement !(interpretE x) idx interpretE (Grad f x) = do - reg <- mkDialectRegistry + computation <- compile xlaBuilder f + stablehlo <- hloModuleProtoToStableHLO !(proto computation) + ctx <- getContext stablehlo + -- reg <- mkDialectRegistry + -- appendDialectRegistry ctx reg + -- insertEnzymeDialect reg + enzymeOp <- emitEnzymeADOp stablehlo + --regsiterenzymeXLAPasses + --prepareRegistry reg + --registerenzymePasses -- need other dialects? - insertEnzymeDialect reg - StableHLO.Dialect.Register.registerAllDialects reg + -- surely the ModuleOp already has stablehlo registered, since it's stablehlo code + -- StableHLO.Dialect.Register.registerAllDialects reg -- registerStableHLODialectAutoDiffInterface reg - -- should we instead be getting the context from the stablehlo ModuleOp? - ctx <- mkMLIRContext - appendDialectRegistry ctx reg mgr <- mkPassManager ctx addPass mgr !createDifferentiatePass - computation <- compile xlaBuilder f - stablehlo <- hloModuleProtoToStableHLO !(proto computation) - enzymeOp <- emitEnzymeADOp stablehlo reg True <- run mgr enzymeOp | False => throwE $ MlirPassError "Failed to run differentiate pass on StableHLO" hloProto <- convertStablehloToHlo stablehlo diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr index ba44b6583..0ac3fa2ee 100644 --- a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -16,6 +16,7 @@ limitations under the License. ||| For internal spidr use only. module Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.MLIRContext import Compiler.FFI public export @@ -24,3 +25,14 @@ data ModuleOp = MkModuleOp GCAnyPtr export %foreign (libxla "ModuleOp_delete") prim__delete : AnyPtr -> PrimIO () + +export +%foreign (libxla "ModuleOp_getContext") +prim__moduleOp : GCAnyPtr -> PrimIO AnyPtr + +export +getContext : HasIO io => ModuleOp -> io MLIRContext +getContext (MkModuleOp op) = do + ctx <- primIO $ prim__moduleOp op + ctx <- onCollectAny ctx (const $ pure ()) -- I reckon we've already GC'ed this + pure (MkMLIRContext ctx) diff --git a/spidr/src/Compiler/MLIR/IR/MLIRContext.idr b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr index 645a21e56..42c91caa3 100644 --- a/spidr/src/Compiler/MLIR/IR/MLIRContext.idr +++ b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr @@ -35,6 +35,16 @@ mkMLIRContext = do ctx <- onCollectAny ctx (primIO . prim__deleteMLIRContext) pure (MkMLIRContext ctx) +%foreign (libxla "MLIRContext_getDialectRegistry") +prim__getDialectRegistry : GCAnyPtr -> PrimIO AnyPtr + +export +getDialectRegistry : HasIO io => MLIRContext -> io DialectRegistry +getDialectRegistry (MkMLIRContext ctx) = do + registry <- primIO $ prim__getDialectRegistry ctx + registry <- onCollectAny registry (const $ pure ()) -- correct? + pure (MkDialectRegistry registry) + %foreign (libxla "MLIRContext_appendDialectRegistry") prim__appendDialectRegistry : GCAnyPtr -> GCAnyPtr -> PrimIO () diff --git a/test/runner/Unit/TestTensor/AD.idr b/test/runner/Unit/TestTensor/AD.idr index 3c7349f58..365d02014 100644 --- a/test/runner/Unit/TestTensor/AD.idr +++ b/test/runner/Unit/TestTensor/AD.idr @@ -26,16 +26,8 @@ import Utils.Cases square : Device => Property square = fixedProperty $ do - sqrt (square $ tensor 3.0) ===# tensor 3.0 - grad (pure . square) (tensor 1.0) ===# pure (tensor 2.0) -{- - shape <- forAll shapes - - x <- forAll (literal shape doubles) - let x' = tensor {dtype = F64} x - map id x ==~ unsafeEval (map pure x') - map (1.0 /) x ==~ Tag.unsafeEval (map (pure . (1.0 /)) x') --} + square (tensor 3.0) ===# tensor 9.0 + grad (pure . square) (tensor 3.0) ===# pure (tensor 6.0) export all : Device => List (PropertyName, Property) From ff858281eaf989b9b8e4526b14c5f4e1a392e17c Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 15:36:40 +0000 Subject: [PATCH 26/38] moses suggestion --- spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index 6b1e8d405..871488cb2 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -109,6 +109,7 @@ int main() { ctx.appendDialectRegistry(registry_); mlir::mhlo::registerAllMhloDialects(registry_); mlir::stablehlo::registerAllDialects(registry_); + ctx.insertenzyme::EnzymeDialect(); auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release(); From 6bd295de370b7b17bd275502c999fdbf7b61d7a3 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 15:40:30 +0000 Subject: [PATCH 27/38] wip --- spidr/backend/build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index 6da54e174..08af0e449 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -30,8 +30,8 @@ esac install_xla "$xla_rev" xla (cd xla; ./configure.py --backend=cpu --os=$os) # depending on Enzyme-JAX is problematic as it fixes the XLA version. Can we only depend on enzyme? -# mkdir Enzyme-JAX -# install_enzyme "$enzyme_rev" Enzyme-JAX + mkdir Enzyme-JAX + install_enzyme "$enzyme_rev" Enzyme-JAX bazel build //:c_xla rm -rf xla ) From 768b5b56d547bc48161c7b4dea69f839fd3c62de Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 15:45:23 +0000 Subject: [PATCH 28/38] everything --- spidr/backend/build.sh | 2 ++ spidr/backend/everything | 54 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 spidr/backend/everything diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index 08af0e449..3d9227dae 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -30,8 +30,10 @@ esac install_xla "$xla_rev" xla (cd xla; ./configure.py --backend=cpu --os=$os) # depending on Enzyme-JAX is problematic as it fixes the XLA version. Can we only depend on enzyme? + # seems unlikely that they could decouple XLA entirely. They almost certainly can't decouple stablehlo mkdir Enzyme-JAX install_enzyme "$enzyme_rev" Enzyme-JAX + cat everything >> Enzyme-JAX/BUILD bazel build //:c_xla rm -rf xla ) diff --git a/spidr/backend/everything b/spidr/backend/everything new file mode 100644 index 000000000..3705d060e --- /dev/null +++ b/spidr/backend/everything @@ -0,0 +1,54 @@ + +cc_library( + name = "everything", + srcs = [ + "//src/enzyme_ad/jax:TransformOps", + "//src/enzyme_ad/jax:XLADerivatives", + "//src/enzyme_ad/jax:RegistryUtils.cpp", + ], + hdrs = [ + "//src/enzyme_ad/jax:TransformOps", + "//src/enzyme_ad/jax:XLADerivatives", + "//src/enzyme_ad/jax:RegistryUtils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "@enzyme//:EnzymeMLIR", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVGPUDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:Transforms", + "//src/enzyme_ad/jax:TransformOps", + "//src/enzyme_ad/jax:XLADerivatives", + "@stablehlo//:chlo_ops", + "@stablehlo//stablehlo/tests:check_ops", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + + "@llvm-project//llvm:X86AsmParser", + "@llvm-project//llvm:X86CodeGen", + ], +) From 964c8756fa4c73ce74d07894abc9bc3ce45fbfc2 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:59:48 +0000 Subject: [PATCH 29/38] enz version --- spidr/backend/ENZYME_JAX_VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spidr/backend/ENZYME_JAX_VERSION b/spidr/backend/ENZYME_JAX_VERSION index f90916d1f..cb76d9f9c 100644 --- a/spidr/backend/ENZYME_JAX_VERSION +++ b/spidr/backend/ENZYME_JAX_VERSION @@ -1 +1 @@ -51687b09d49dee1044c6767c0aca9b3dbb3c97d5 \ No newline at end of file +b6d6563aa3a3050474a4250bf18322f7ebf0b486 \ No newline at end of file From 3460cb741d0ee4b61e9035e5b84c6c420a6556f8 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:02:01 +0000 Subject: [PATCH 30/38] revert xla version --- XLA_VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/XLA_VERSION b/XLA_VERSION index 412bc18ca..2cf52560e 100644 --- a/XLA_VERSION +++ b/XLA_VERSION @@ -1 +1 @@ -b44f55da3dac449f03466815ac431474f86fd73f \ No newline at end of file +46bd34518db14456697224b9684f47bf4cdcf6ba \ No newline at end of file From 33843fc490b22c3085a7757387279669d2e5e1f8 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:06:42 +0000 Subject: [PATCH 31/38] wip --- XLA_VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/XLA_VERSION b/XLA_VERSION index 2cf52560e..412bc18ca 100644 --- a/XLA_VERSION +++ b/XLA_VERSION @@ -1 +1 @@ -46bd34518db14456697224b9684f47bf4cdcf6ba \ No newline at end of file +b44f55da3dac449f03466815ac431474f86fd73f \ No newline at end of file From d50956a57afe617f5e979e3537c2ca1220716bd2 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 18:03:10 +0000 Subject: [PATCH 32/38] wip --- .../backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index 871488cb2..1edad320b 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -109,13 +109,13 @@ int main() { ctx.appendDialectRegistry(registry_); mlir::mhlo::registerAllMhloDialects(registry_); mlir::stablehlo::registerAllDialects(registry_); - ctx.insertenzyme::EnzymeDialect(); auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release(); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + registry_.insert(); registry_.insert(); prepareRegistry(registry_); From 2b6d8dcee43a175bcc7c789fd4ba5259ca3256c4 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 4 Jan 2025 20:12:46 +0000 Subject: [PATCH 33/38] wip --- pjrt-plugins/xla-cpu/.gitignore | 1 + pjrt-plugins/xla-cpu/build.sh | 3 +- pjrt-plugins/xla-cuda/build.sh | 3 +- .../src/enzyme_ad/jax/Passes/Passes.cpp | 157 +++++++++--------- .../Enzyme/Enzyme/MLIR/Passes/Passes.idr | 6 +- 5 files changed, 89 insertions(+), 81 deletions(-) create mode 100644 pjrt-plugins/xla-cpu/.gitignore diff --git a/pjrt-plugins/xla-cpu/.gitignore b/pjrt-plugins/xla-cpu/.gitignore new file mode 100644 index 000000000..bb617c857 --- /dev/null +++ b/pjrt-plugins/xla-cpu/.gitignore @@ -0,0 +1 @@ +xla/ diff --git a/pjrt-plugins/xla-cpu/build.sh b/pjrt-plugins/xla-cpu/build.sh index 094d8c41d..69a3a5140 100755 --- a/pjrt-plugins/xla-cpu/build.sh +++ b/pjrt-plugins/xla-cpu/build.sh @@ -23,7 +23,8 @@ case $osu in ;; esac -xla_dir=$(mktemp -d) +xla_dir=pjrt-plugins/xla-cpu/xla +mkdir "$xla_dir" install_xla "$rev" "$xla_dir" ( cd "$xla_dir" diff --git a/pjrt-plugins/xla-cuda/build.sh b/pjrt-plugins/xla-cuda/build.sh index 16277cc5c..c61c51005 100755 --- a/pjrt-plugins/xla-cuda/build.sh +++ b/pjrt-plugins/xla-cuda/build.sh @@ -15,7 +15,8 @@ case $osu in ;; esac -xla_dir=$(mktemp -d) +xla_dir=pjrt-plugins/xla-cuda/xla +mkdir "$xla_dir" install_xla "$rev" "$xla_dir" ( cd "$xla_dir" diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index 1edad320b..e4a18e20b 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -85,7 +85,7 @@ struct PtrElementModel : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< PtrElementModel, T> {}; -//extern "C" { +extern "C" { // void regsiterenzymeXLAPasses_() { // regsiterenzymeXLAPasses(); // } @@ -98,79 +98,84 @@ struct PtrElementModel //// return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); //// } -int main() { - xla::XlaBuilder builder("root"); - auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12); - auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg"); - auto proto = builder.Build(xla::Square(arg))->proto(); - - mlir::MLIRContext ctx; - mlir::DialectRegistry registry_; - ctx.appendDialectRegistry(registry_); - mlir::mhlo::registerAllMhloDialects(registry_); - mlir::stablehlo::registerAllDialects(registry_); - - auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release(); - - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - registry_.insert(); - registry_.insert(); - prepareRegistry(registry_); - - mlir::registerenzymePasses(); - regsiterenzymeXLAPasses(); - - mlir::registerCSEPass(); - mlir::registerConvertAffineToStandardPass(); - mlir::registerSCCPPass(); - mlir::registerInlinerPass(); - mlir::registerCanonicalizerPass(); - mlir::registerSymbolDCEPass(); - mlir::registerLoopInvariantCodeMotionPass(); - mlir::registerConvertSCFToOpenMPPass(); - mlir::affine::registerAffinePasses(); - mlir::registerReconcileUnrealizedCasts(); - - registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) { - mlir::LLVM::LLVMFunctionType::attachInterface(*ctx); - mlir::LLVM::LLVMArrayType::attachInterface(*ctx); - mlir::LLVM::LLVMPointerType::attachInterface(*ctx); - mlir::LLVM::LLVMStructType::attachInterface(*ctx); - mlir::MemRefType::attachInterface>(*ctx); - mlir::LLVM::LLVMStructType::attachInterface< - PtrElementModel>(*ctx); - mlir::LLVM::LLVMPointerType::attachInterface< - PtrElementModel>(*ctx); - mlir::LLVM::LLVMArrayType::attachInterface>(*ctx); - }); - - mlir::transform::registerInterpreterPass(); - mlir::enzyme::registerGenerateApplyPatternsPass(); - mlir::enzyme::registerRemoveTransformPass(); - - auto state = mlir::OperationState(mlir::UnknownLoc::get(&ctx), "enzyme.autodiff"); - - auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(&ctx)); - state.addTypes({scalarf64}); - - auto operands = module_op_.getOperation()->getOperands(); // complete guess - state.addOperands(mlir::ValueRange(operands)); - - auto operation = module_op_.getOperation(); // complete guess - state.addAttribute("fn", operation->getAttr("sym_name")); - auto activity = mlir::enzyme::ActivityAttr::get(&ctx, mlir::enzyme::Activity::enzyme_active); - state.addAttribute("activity", {activity}); - auto ret_activity = mlir::enzyme::ActivityAttr::get( - &ctx, mlir::enzyme::Activity::enzyme_activenoneed - ); - state.addAttribute("ret_activity", {ret_activity}); - - auto res = mlir::Operation::create(state); - - return 0; - -// return reinterpret_cast(new mlir::ModuleOp(res)); -//} + ModuleOp* emitEnzymeADOp(ModuleOp& module_op) { +// xla::XlaBuilder builder("root"); +// auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12); +// auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg"); +// auto proto = builder.Build(xla::Square(arg))->proto(); +// +// mlir::MLIRContext ctx; +// mlir::DialectRegistry registry_; +// ctx.appendDialectRegistry(registry_); +// mlir::mhlo::registerAllMhloDialects(registry_); +// mlir::stablehlo::registerAllDialects(registry_); +// +// auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release(); + auto module_op_ = reinterpret_cast(module_op); + auto ctx = module_op_.getContext(); + mlir::DialectRegistry registry_; + + registry_.insert(); + registry_.insert(); + prepareRegistry(registry_); + + ctx->appendDialectRegistry(registry_); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerenzymePasses(); + regsiterenzymeXLAPasses(); + + mlir::registerCSEPass(); + mlir::registerConvertAffineToStandardPass(); + mlir::registerSCCPPass(); + mlir::registerInlinerPass(); + mlir::registerCanonicalizerPass(); + mlir::registerSymbolDCEPass(); + mlir::registerLoopInvariantCodeMotionPass(); + mlir::registerConvertSCFToOpenMPPass(); + mlir::affine::registerAffinePasses(); + mlir::registerReconcileUnrealizedCasts(); + + registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) { + mlir::LLVM::LLVMFunctionType::attachInterface(*ctx); + mlir::LLVM::LLVMArrayType::attachInterface(*ctx); + mlir::LLVM::LLVMPointerType::attachInterface(*ctx); + mlir::LLVM::LLVMStructType::attachInterface(*ctx); + mlir::MemRefType::attachInterface>(*ctx); + mlir::LLVM::LLVMStructType::attachInterface< + PtrElementModel>(*ctx); + mlir::LLVM::LLVMPointerType::attachInterface< + PtrElementModel>(*ctx); + mlir::LLVM::LLVMArrayType::attachInterface>(*ctx); + }); + + mlir::transform::registerInterpreterPass(); + mlir::enzyme::registerGenerateApplyPatternsPass(); + mlir::enzyme::registerRemoveTransformPass(); + + auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); + + auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); + state.addTypes({scalarf64}); + + auto operands = module_op_.getOperation()->getOperands(); // complete guess + state.addOperands(mlir::ValueRange(operands)); + + auto operation = module_op_.getOperation(); // complete guess + state.addAttribute("fn", operation->getAttr("sym_name")); + auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); + state.addAttribute("activity", {activity}); + auto ret_activity = mlir::enzyme::ActivityAttr::get( + ctx, mlir::enzyme::Activity::enzyme_activenoneed + ); + state.addAttribute("ret_activity", {ret_activity}); + + auto res = mlir::Operation::create(state); + +// return 0; + + return reinterpret_cast(new mlir::ModuleOp(res)); + } } diff --git a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr index c31e148e8..51f375301 100644 --- a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr @@ -32,11 +32,11 @@ createDifferentiatePass = do pure (MkPass pass) %foreign (libxla "emitEnzymeADOp") -prim__emitEnzymeADOp : PrimIO () +prim__emitEnzymeADOp : GCAnyPtr -> PrimIO AnyPtr export emitEnzymeADOp : HasIO io => ModuleOp -> io ModuleOp emitEnzymeADOp (MkModuleOp op) = do - _ <- primIO $ prim__emitEnzymeADOp - --op <- onCollectAny op (primIO . BuiltinOps.prim__delete) + op <- primIO $ prim__emitEnzymeADOp op + op <- onCollectAny op (primIO . BuiltinOps.prim__delete) pure (MkModuleOp op) From 69558cb9cfe193e210e9fc424ca8b35808e985ba Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 20 Jan 2025 23:26:44 +0000 Subject: [PATCH 34/38] use loadDialect --- spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index e4a18e20b..bf627a2e3 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -115,6 +115,7 @@ extern "C" { auto ctx = module_op_.getContext(); mlir::DialectRegistry registry_; + ctx.loadDialect(); registry_.insert(); registry_.insert(); prepareRegistry(registry_); From 8c319eec4b74fb78c23de6c2d8fbfc99b0ddbc09 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Mon, 20 Jan 2025 23:27:26 +0000 Subject: [PATCH 35/38] wip --- .../backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index bf627a2e3..0b3a89e74 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -115,7 +115,7 @@ extern "C" { auto ctx = module_op_.getContext(); mlir::DialectRegistry registry_; - ctx.loadDialect(); + ctx.loadDialect(); // as suggested in MLIR tutorial registry_.insert(); registry_.insert(); prepareRegistry(registry_); From d4332df7af3f21501e2334bd1303f71c660e93a9 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Wed, 22 Jan 2025 23:31:07 +0000 Subject: [PATCH 36/38] start debugging properly --- .../src/enzyme_ad/jax/Passes/Passes.cpp | 45 +++++++++++++------ spidr/src/Compiler/Eval.idr | 13 +++--- test/runner/Unit/TestTensor/AD.idr | 2 +- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index 0b3a89e74..736a30670 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -86,19 +86,20 @@ struct PtrElementModel PtrElementModel, T> {}; extern "C" { -// void regsiterenzymeXLAPasses_() { -// regsiterenzymeXLAPasses(); -// } -// -// void registerenzymePasses() { -// mlir::registerenzymePasses(); -// } -// -//// Pass* createDifferentiatePass() { -//// return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); -//// } + void regsiterenzymeXLAPasses_() { + regsiterenzymeXLAPasses(); + } + + void registerenzymePasses() { + mlir::registerenzymePasses(); + } + + Pass* createDifferentiatePass() { + return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); + } ModuleOp* emitEnzymeADOp(ModuleOp& module_op) { + printf("emitEnzymeADOp\n"); // xla::XlaBuilder builder("root"); // auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12); // auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg"); @@ -115,7 +116,7 @@ extern "C" { auto ctx = module_op_.getContext(); mlir::DialectRegistry registry_; - ctx.loadDialect(); // as suggested in MLIR tutorial + ctx->loadDialect(); // as suggested in MLIR tutorial registry_.insert(); registry_.insert(); prepareRegistry(registry_); @@ -161,6 +162,19 @@ extern "C" { auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); state.addTypes({scalarf64}); + printf("module_op_.getOperation()\n"); + module_op_.getOperation()->dump(); + + printf("module_op_.getOperation()->getName()\n"); + module_op_.getOperation()->getName().dump(); + printf("\n"); + + printf("module_op_.getOperation()->getNumOperands()\n"); + printf("%d\n", module_op_.getOperation()->getNumOperands()); + + printf("module_op_.getOperation()->getAttr('mhlo.cross_program_prefetches')\n"); + module_op_.getOperation()->getAttr("mhlo.cross_program_prefetches"); + auto operands = module_op_.getOperation()->getOperands(); // complete guess state.addOperands(mlir::ValueRange(operands)); @@ -175,7 +189,12 @@ extern "C" { auto res = mlir::Operation::create(state); -// return 0; + printf("enzyme op\n"); + res->dump(); + + mlir::PassManager pm(ctx); + pm.addPass(mlir::enzyme::createDifferentiatePass()); + pm.run(res); return reinterpret_cast(new mlir::ModuleOp(res)); } diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 5c47dd7d5..310c64496 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -136,9 +136,10 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do interpretE (Tuple xs) = tuple xlaBuilder !(traverse interpretE xs) interpretE (GetTupleElement idx x) = getTupleElement !(interpretE x) idx interpretE (Grad f x) = do + putStrLn "interpretE (Grad _ _)" computation <- compile xlaBuilder f stablehlo <- hloModuleProtoToStableHLO !(proto computation) - ctx <- getContext stablehlo + -- ctx <- getContext stablehlo -- reg <- mkDialectRegistry -- appendDialectRegistry ctx reg -- insertEnzymeDialect reg @@ -150,11 +151,11 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do -- surely the ModuleOp already has stablehlo registered, since it's stablehlo code -- StableHLO.Dialect.Register.registerAllDialects reg -- registerStableHLODialectAutoDiffInterface reg - mgr <- mkPassManager ctx - addPass mgr !createDifferentiatePass - True <- run mgr enzymeOp - | False => throwE $ MlirPassError "Failed to run differentiate pass on StableHLO" - hloProto <- convertStablehloToHlo stablehlo + --mgr <- mkPassManager ctx + --addPass mgr !createDifferentiatePass + -- True <- run mgr enzymeOp + -- | False => throwE $ MlirPassError "Failed to run differentiate pass on StableHLO" + hloProto <- convertStablehloToHlo enzymeOp computation <- mkXlaComputation hloProto -- x should be correct shape, because we're sending R^{n0, n1, ..} -> R -- to R^{n0, n1, ..} -> R^{n0, n1, ..} i.e. we're only changing the output shape diff --git a/test/runner/Unit/TestTensor/AD.idr b/test/runner/Unit/TestTensor/AD.idr index 365d02014..84b0ca994 100644 --- a/test/runner/Unit/TestTensor/AD.idr +++ b/test/runner/Unit/TestTensor/AD.idr @@ -26,7 +26,7 @@ import Utils.Cases square : Device => Property square = fixedProperty $ do - square (tensor 3.0) ===# tensor 9.0 + -- square (tensor 3.0) ===# tensor 9.0 grad (pure . square) (tensor 3.0) ===# pure (tensor 6.0) export From 2e8f5d442d54f3a380f3b99fca3ac9b2373522b8 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Thu, 23 Jan 2025 23:41:35 +0000 Subject: [PATCH 37/38] really solid progress --- .../src/enzyme_ad/jax/Passes/Passes.cpp | 92 ++++++++++++------- 1 file changed, 60 insertions(+), 32 deletions(-) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index 736a30670..fde274a87 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -77,13 +77,13 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/tests/CheckOps.h" -class MemRefInsider - : public mlir::MemRefElementTypeInterface::FallbackModel {}; - -template -struct PtrElementModel - : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< - PtrElementModel, T> {}; +//class MemRefInsider +// : public mlir::MemRefElementTypeInterface::FallbackModel {}; +// +//template +//struct PtrElementModel +// : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< +// PtrElementModel, T> {}; extern "C" { void regsiterenzymeXLAPasses_() { @@ -140,46 +140,74 @@ extern "C" { mlir::affine::registerAffinePasses(); mlir::registerReconcileUnrealizedCasts(); - registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) { - mlir::LLVM::LLVMFunctionType::attachInterface(*ctx); - mlir::LLVM::LLVMArrayType::attachInterface(*ctx); - mlir::LLVM::LLVMPointerType::attachInterface(*ctx); - mlir::LLVM::LLVMStructType::attachInterface(*ctx); - mlir::MemRefType::attachInterface>(*ctx); - mlir::LLVM::LLVMStructType::attachInterface< - PtrElementModel>(*ctx); - mlir::LLVM::LLVMPointerType::attachInterface< - PtrElementModel>(*ctx); - mlir::LLVM::LLVMArrayType::attachInterface>(*ctx); - }); +// registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) { +// mlir::LLVM::LLVMFunctionType::attachInterface(*ctx); +// mlir::LLVM::LLVMArrayType::attachInterface(*ctx); +// mlir::LLVM::LLVMPointerType::attachInterface(*ctx); +// mlir::LLVM::LLVMStructType::attachInterface(*ctx); +// mlir::MemRefType::attachInterface>(*ctx); +// mlir::LLVM::LLVMStructType::attachInterface< +// PtrElementModel>(*ctx); +// mlir::LLVM::LLVMPointerType::attachInterface< +// PtrElementModel>(*ctx); +// mlir::LLVM::LLVMArrayType::attachInterface>(*ctx); +// }); mlir::transform::registerInterpreterPass(); mlir::enzyme::registerGenerateApplyPatternsPass(); mlir::enzyme::registerRemoveTransformPass(); - auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); + printf("module_op_.getOperation()\n"); + module_op_.getOperation()->dump(); - auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); - state.addTypes({scalarf64}); + auto& region = module_op_.getOperation()->getRegion(0); + + printf("region.getNumArguments()\n"); + printf("%d\n", region.getNumArguments()); + + auto& block = region.front(); + + printf("block.getNumArguments()\n"); + printf("%d\n", block.getNumArguments()); + + auto& operation = block.front(); + + mlir::SymbolTable::setSymbolName(&operation, "tmp"); + mlir::SymbolTable::setSymbolVisibility(&operation, mlir::SymbolTable::Visibility::Private); printf("module_op_.getOperation()\n"); module_op_.getOperation()->dump(); - printf("module_op_.getOperation()->getName()\n"); - module_op_.getOperation()->getName().dump(); - printf("\n"); + printf("operation.getNumOperands()\n"); + printf("%d\n", operation.getNumOperands()); + + printf("operation\n"); + operation.dump(); + + printf("operation.getNumRegions()\n"); + printf("%d\n", operation.getNumRegions()); + + printf("operation.getRegion(0).getNumArguments()\n"); + printf("%d\n", operation.getRegion(0).getNumArguments()); - printf("module_op_.getOperation()->getNumOperands()\n"); - printf("%d\n", module_op_.getOperation()->getNumOperands()); + auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); + auto func_type = mlir::FunctionType::get({scalarf64}, {scalarf64}, ctx); + auto func_op = mlir::func::FuncOp::create(mlir::UnknownLoc::get(ctx), "main", func_type); + + // in FuncOp, emit an "enzyme.autodiff" op (i.e. function call) like in + // https://github.com/EnzymeAD/Enzyme-JAX/blob/fb483c06f697990c60cc3c0bda7fb1d730fca3de/test/lit_tests/grad_sum1d.mlir#L11 + // (you can reuse a lot of the stuff below for that), followed by a func::ReturnOp + // + // I think the differentiate pass should just work then, and we've got a mwe!(?) + + auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); - printf("module_op_.getOperation()->getAttr('mhlo.cross_program_prefetches')\n"); - module_op_.getOperation()->getAttr("mhlo.cross_program_prefetches"); +// state.addTypes({scalarf64}); - auto operands = module_op_.getOperation()->getOperands(); // complete guess + auto operands = operation.getOperands(); // complete guess state.addOperands(mlir::ValueRange(operands)); - auto operation = module_op_.getOperation(); // complete guess - state.addAttribute("fn", operation->getAttr("sym_name")); + state.addAttribute("fn", operation.getAttr("sym_name")); auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); state.addAttribute("activity", {activity}); auto ret_activity = mlir::enzyme::ActivityAttr::get( From e7e94f9592ff7df565bd8e0197fa8b16e0b7dc1d Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:29:41 +0000 Subject: [PATCH 38/38] most things there --- .../src/enzyme_ad/jax/Passes/Passes.cpp | 90 ++++++++++--------- 1 file changed, 46 insertions(+), 44 deletions(-) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp index fde274a87..7e68901f5 100644 --- a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -34,9 +34,11 @@ limitations under the License. #include "src/enzyme_ad/jax/TransformOps/TransformOps.h" #include "src/enzyme_ad/jax/RegistryUtils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/Passes.h" @@ -161,69 +163,69 @@ extern "C" { module_op_.getOperation()->dump(); auto& region = module_op_.getOperation()->getRegion(0); - - printf("region.getNumArguments()\n"); - printf("%d\n", region.getNumArguments()); - auto& block = region.front(); - - printf("block.getNumArguments()\n"); - printf("%d\n", block.getNumArguments()); - auto& operation = block.front(); mlir::SymbolTable::setSymbolName(&operation, "tmp"); - mlir::SymbolTable::setSymbolVisibility(&operation, mlir::SymbolTable::Visibility::Private); - - printf("module_op_.getOperation()\n"); - module_op_.getOperation()->dump(); - - printf("operation.getNumOperands()\n"); - printf("%d\n", operation.getNumOperands()); - - printf("operation\n"); - operation.dump(); - - printf("operation.getNumRegions()\n"); - printf("%d\n", operation.getNumRegions()); - - printf("operation.getRegion(0).getNumArguments()\n"); - printf("%d\n", operation.getRegion(0).getNumArguments()); +// mlir::SymbolTable::setSymbolVisibility(&operation, mlir::SymbolTable::Visibility::Private); auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); - auto func_type = mlir::FunctionType::get({scalarf64}, {scalarf64}, ctx); + auto func_type = mlir::FunctionType::get(ctx, {scalarf64}, {scalarf64}); auto func_op = mlir::func::FuncOp::create(mlir::UnknownLoc::get(ctx), "main", func_type); - // in FuncOp, emit an "enzyme.autodiff" op (i.e. function call) like in - // https://github.com/EnzymeAD/Enzyme-JAX/blob/fb483c06f697990c60cc3c0bda7fb1d730fca3de/test/lit_tests/grad_sum1d.mlir#L11 - // (you can reuse a lot of the stuff below for that), followed by a func::ReturnOp - // - // I think the differentiate pass should just work then, and we've got a mwe!(?) - - auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); + block.push_back(func_op); -// state.addTypes({scalarf64}); + auto entry_block = func_op.addEntryBlock(); - auto operands = operation.getOperands(); // complete guess - state.addOperands(mlir::ValueRange(operands)); - - state.addAttribute("fn", operation.getAttr("sym_name")); auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); - state.addAttribute("activity", {activity}); auto ret_activity = mlir::enzyme::ActivityAttr::get( ctx, mlir::enzyme::Activity::enzyme_activenoneed ); - state.addAttribute("ret_activity", {ret_activity}); - auto res = mlir::Operation::create(state); + mlir::NamedAttrList attrs; + attrs.set("fn", operation.getAttr("sym_name")); + attrs.set("activity", activity); + attrs.set("ret_activity", ret_activity); + + auto autodiff = mlir::Operation::create( + mlir::UnknownLoc::get(ctx), + mlir::OperationName("enzyme.autodiff", ctx), + mlir::TypeRange({scalarf64}), + mlir::ValueRange(entry_block->getArgument(0)), + std::move(attrs), +// mlir::NamedAttributeList({ +// mlir::NamedAttribute(mlir::StringAttr::get("fn", ctx), operation.getAttr("sym_name")), +// mlir::NamedAttribute(mlir::StringAttr::get("activity", ctx), activity), +// mlir::NamedAttribute(mlir::StringAttr::get("ret_activity", ctx), ret_activity), +// }), + mlir::OpaqueProperties(nullptr) + ); + +// auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); +// state.addOperands(mlir::ValueRange(entry_block->getArgument(0))); +// state.addTypes({scalarf64}); +// state.addAttribute("fn", operation.getAttr("sym_name")); +// state.addAttribute("activity", {activity}); +// state.addAttribute("ret_activity", {ret_activity}); +// auto autodiff = mlir::Operation::create(state); + entry_block->push_back(autodiff); + + auto return_op = mlir::OpBuilder(ctx).create( + mlir::UnknownLoc::get(ctx), + mlir::ValueRange(autodiff->getOpResult(0)) + ); + entry_block->push_back(return_op); - printf("enzyme op\n"); - res->dump(); + printf("module_op_.getOperation()\n"); + module_op_.getOperation()->dump(); mlir::PassManager pm(ctx); + printf("0\n"); pm.addPass(mlir::enzyme::createDifferentiatePass()); - pm.run(res); + printf("1\n"); + pm.run(func_op); + printf("2\n"); - return reinterpret_cast(new mlir::ModuleOp(res)); + return reinterpret_cast(new mlir::ModuleOp(func_op)); } }