diff --git a/XLA_VERSION b/XLA_VERSION index 6d2544760..412bc18ca 100644 --- a/XLA_VERSION +++ b/XLA_VERSION @@ -1 +1 @@ -2fb20601f1cc6cab7f29f8bc73d90cd31e74bba0 \ No newline at end of file +b44f55da3dac449f03466815ac431474f86fd73f \ No newline at end of file diff --git a/dev.sh b/dev.sh index f0d395ace..df3df6ad8 100644 --- a/dev.sh +++ b/dev.sh @@ -8,12 +8,7 @@ short_revision () { echo "${rev%%"${rev##??????????}"}" } -install_xla () { - if [ -z "$2" ]; then - echo "Usage: install_xla ." - exit 1; - fi - +install_git_repository () { if [ "$(ls -A "$2")" ]; then echo "Directory at path $2 is not empty, refusing to install XLA to this directory." exit 1; @@ -22,8 +17,16 @@ install_xla () { ( cd "$2" git init - git remote add origin https://github.com/openxla/xla + git remote add origin "$3" git fetch --depth 1 origin "$1" git checkout FETCH_HEAD ) } + +install_xla () { + install_git_repository "$1" "$2" https://github.com/openxla/xla +} + +install_enzyme () { + install_git_repository "$1" "$2" https://github.com/EnzymeAD/Enzyme-JAX.git +} diff --git a/pjrt-plugins/xla-cpu/.gitignore b/pjrt-plugins/xla-cpu/.gitignore new file mode 100644 index 000000000..bb617c857 --- /dev/null +++ b/pjrt-plugins/xla-cpu/.gitignore @@ -0,0 +1 @@ +xla/ diff --git a/pjrt-plugins/xla-cpu/build.sh b/pjrt-plugins/xla-cpu/build.sh index 094d8c41d..69a3a5140 100755 --- a/pjrt-plugins/xla-cpu/build.sh +++ b/pjrt-plugins/xla-cpu/build.sh @@ -23,7 +23,8 @@ case $osu in ;; esac -xla_dir=$(mktemp -d) +xla_dir=pjrt-plugins/xla-cpu/xla +mkdir "$xla_dir" install_xla "$rev" "$xla_dir" ( cd "$xla_dir" diff --git a/pjrt-plugins/xla-cuda/build.sh b/pjrt-plugins/xla-cuda/build.sh index 32ffbf971..271063872 100755 --- a/pjrt-plugins/xla-cuda/build.sh +++ b/pjrt-plugins/xla-cuda/build.sh @@ -18,7 +18,8 @@ case $osu in ;; esac -xla_dir=$(mktemp -d) +xla_dir=pjrt-plugins/xla-cuda/xla +mkdir "$xla_dir" install_xla "$rev" "$xla_dir" ( cd "$xla_dir" diff --git a/spidr/backend/.gitignore b/spidr/backend/.gitignore index 24a3274c5..2fbfae974 100644 --- a/spidr/backend/.gitignore +++ b/spidr/backend/.gitignore @@ -1 +1,2 @@ +/Enzyme-JAX /xla diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index ba1a39bfa..fd36a2838 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -12,21 +12,43 @@ cc_binary( linkshared = True, linkstatic = True, srcs = [ + "//src/Enzyme-JAX/src/enzyme_ad/jax", + "//src/Enzyme-JAX/src/enzyme_ad/jax/Passes", + "//src/Enzyme/enzyme/Enzyme/MLIR/Dialect", + "//src/Enzyme/enzyme/Enzyme/MLIR/Passes", + "//src/llvm/Support", + "//src/mlir/IR", + "//src/mlir/Pass", + "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", + "//src/xla/hlo/translate", + "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", "//src/xla/pjrt/c", + "//src/xla/service", "//src", ], deps = [ + "//src/Enzyme-JAX/src/enzyme_ad/jax", + "//src/Enzyme-JAX/src/enzyme_ad/jax/Passes", + "//src/Enzyme/enzyme/Enzyme/MLIR/Dialect", + "//src/Enzyme/enzyme/Enzyme/MLIR/Passes", + "//src/llvm/Support", + "//src/mlir/IR", + "//src/mlir/Pass", + "//src/stablehlo/dialect", "//src/xla", "//src/xla/client", "//src/xla/hlo/builder", "//src/xla/hlo/builder/lib", + "//src/xla/hlo/translate", + "//src/xla/mlir_hlo/mhlo/IR", "//src/xla/pjrt", "//src/xla/pjrt/c", + "//src/xla/service", "//src", ], ) diff --git a/spidr/backend/ENZYME_JAX_VERSION b/spidr/backend/ENZYME_JAX_VERSION new file mode 100644 index 000000000..cb76d9f9c --- /dev/null +++ b/spidr/backend/ENZYME_JAX_VERSION @@ -0,0 +1 @@ +b6d6563aa3a3050474a4250bf18322f7ebf0b486 \ No newline at end of file diff --git a/spidr/backend/VERSION b/spidr/backend/VERSION index e3b86dd9c..9beca35dc 100644 --- a/spidr/backend/VERSION +++ b/spidr/backend/VERSION @@ -1 +1 @@ -0.0.16 +0.0.15 \ No newline at end of file diff --git a/spidr/backend/WORKSPACE b/spidr/backend/WORKSPACE index 991ab29d9..b57874318 100644 --- a/spidr/backend/WORKSPACE +++ b/spidr/backend/WORKSPACE @@ -1,3 +1,5 @@ +### xla + # this must be a local repository not http archive # so we can run ./configure.py before invoking bazel local_repository(name = "xla", path = "xla") @@ -28,3 +30,45 @@ xla_workspace0() load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") cuda_configure(name = "local_config_cuda") + +### Enzyme-JAX +# note enzyme-jax specifies XLA versions, which we're currently ignoring. Do we need to use their versions? +local_repository(name = "enzyme-jax", path = "Enzyme-JAX") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "hedron_compile_commands", + + # Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here. + # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e", + # When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..." +) +# load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") +# hedron_compile_commands_setup() +# load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive") +# hedron_compile_commands_setup_transitive() +# load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive") +# hedron_compile_commands_setup_transitive_transitive() +# load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive") +# hedron_compile_commands_setup_transitive_transitive_transitive() + +load("@enzyme-jax//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256") + +http_archive( + name = "jax", + sha256 = JAX_SHA256, + strip_prefix = "jax-" + JAX_COMMIT, + urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)], + patch_args = ["-p1"], + patches = ["@enzyme-jax//:patches/jax.patch"], +) + +http_archive( + name = "enzyme", + sha256 = ENZYME_SHA256, + strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", + urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], +) diff --git a/spidr/backend/build.sh b/spidr/backend/build.sh index e88541581..3d9227dae 100755 --- a/spidr/backend/build.sh +++ b/spidr/backend/build.sh @@ -3,7 +3,8 @@ script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) cd "$script_dir/../.." . ./dev.sh -rev="$(cat XLA_VERSION)" +xla_rev="$(cat XLA_VERSION)" +enzyme_rev="$(cat spidr/backend/ENZYME_JAX_VERSION)" osu="$(uname)" case $osu in @@ -26,8 +27,13 @@ esac ( cd spidr/backend mkdir xla - install_xla "$rev" xla + install_xla "$xla_rev" xla (cd xla; ./configure.py --backend=cpu --os=$os) + # depending on Enzyme-JAX is problematic as it fixes the XLA version. Can we only depend on enzyme? + # seems unlikely that they could decouple XLA entirely. They almost certainly can't decouple stablehlo + mkdir Enzyme-JAX + install_enzyme "$enzyme_rev" Enzyme-JAX + cat everything >> Enzyme-JAX/BUILD bazel build //:c_xla rm -rf xla ) diff --git a/spidr/backend/everything b/spidr/backend/everything new file mode 100644 index 000000000..3705d060e --- /dev/null +++ b/spidr/backend/everything @@ -0,0 +1,54 @@ + +cc_library( + name = "everything", + srcs = [ + "//src/enzyme_ad/jax:TransformOps", + "//src/enzyme_ad/jax:XLADerivatives", + "//src/enzyme_ad/jax:RegistryUtils.cpp", + ], + hdrs = [ + "//src/enzyme_ad/jax:TransformOps", + "//src/enzyme_ad/jax:XLADerivatives", + "//src/enzyme_ad/jax:RegistryUtils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "@enzyme//:EnzymeMLIR", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVGPUDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:Transforms", + "//src/enzyme_ad/jax:TransformOps", + "//src/enzyme_ad/jax:XLADerivatives", + "@stablehlo//:chlo_ops", + "@stablehlo//stablehlo/tests:check_ops", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + + "@llvm-project//llvm:X86AsmParser", + "@llvm-project//llvm:X86CodeGen", + ], +) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD new file mode 100644 index 000000000..42058b418 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "jax", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@enzyme-jax//:everything", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD new file mode 100644 index 000000000..d9b330d82 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/BUILD @@ -0,0 +1,33 @@ +cc_library( + name = "Passes", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/hlo/builder/lib:math", + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "@enzyme-jax//:everything", + "//src/mlir/IR", + "//src/mlir/Pass", + ], + visibility = ["//visibility:public"], +) + +cc_binary( + name = "example", + linkstatic = True, + srcs = glob(["*.cpp"]), + deps = [ + "@xla//xla/hlo/builder:xla_builder", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/hlo/builder/lib:math", + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "@enzyme-jax//:everything", + "//src/mlir/IR", + "//src/mlir/Pass", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp new file mode 100644 index 000000000..7e68901f5 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp @@ -0,0 +1,231 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "../../../../../mlir/IR/BuiltinOps.h" +#include "../../../../../mlir/IR/DialectRegistry.h" +#include "../../../../../mlir/Pass/Pass.h" + +#include "stablehlo/dialect/Register.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" + +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Dialect/Ops.h" +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Enzyme/MLIR/Passes/Passes.h" + +#include "src/enzyme_ad/jax/Dialect/Dialect.h" +#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/TransformOps/TransformOps.h" +#include "src/enzyme_ad/jax/RegistryUtils.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" + +#include "llvm/Support/TargetSelect.h" + +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/tests/CheckOps.h" + +//class MemRefInsider +// : public mlir::MemRefElementTypeInterface::FallbackModel {}; +// +//template +//struct PtrElementModel +// : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< +// PtrElementModel, T> {}; + +extern "C" { + void regsiterenzymeXLAPasses_() { + regsiterenzymeXLAPasses(); + } + + void registerenzymePasses() { + mlir::registerenzymePasses(); + } + + Pass* createDifferentiatePass() { + return reinterpret_cast(mlir::enzyme::createDifferentiatePass().release()); + } + + ModuleOp* emitEnzymeADOp(ModuleOp& module_op) { + printf("emitEnzymeADOp\n"); +// xla::XlaBuilder builder("root"); +// auto xlaScalarf64 = xla::ShapeUtil::MakeScalarShape((xla::PrimitiveType) 12); +// auto arg = xla::Parameter(&builder, 0, xlaScalarf64, "arg"); +// auto proto = builder.Build(xla::Square(arg))->proto(); +// +// mlir::MLIRContext ctx; +// mlir::DialectRegistry registry_; +// ctx.appendDialectRegistry(registry_); +// mlir::mhlo::registerAllMhloDialects(registry_); +// mlir::stablehlo::registerAllDialects(registry_); +// +// auto module_op_ = xla::ConvertHloToStablehlo(ctx, &proto).value().release(); + auto module_op_ = reinterpret_cast(module_op); + auto ctx = module_op_.getContext(); + mlir::DialectRegistry registry_; + + ctx->loadDialect(); // as suggested in MLIR tutorial + registry_.insert(); + registry_.insert(); + prepareRegistry(registry_); + + ctx->appendDialectRegistry(registry_); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerenzymePasses(); + regsiterenzymeXLAPasses(); + + mlir::registerCSEPass(); + mlir::registerConvertAffineToStandardPass(); + mlir::registerSCCPPass(); + mlir::registerInlinerPass(); + mlir::registerCanonicalizerPass(); + mlir::registerSymbolDCEPass(); + mlir::registerLoopInvariantCodeMotionPass(); + mlir::registerConvertSCFToOpenMPPass(); + mlir::affine::registerAffinePasses(); + mlir::registerReconcileUnrealizedCasts(); + +// registry_.addExtension(+[](mlir::MLIRContext *ctx, mlir::LLVM::LLVMDialect *dialect) { +// mlir::LLVM::LLVMFunctionType::attachInterface(*ctx); +// mlir::LLVM::LLVMArrayType::attachInterface(*ctx); +// mlir::LLVM::LLVMPointerType::attachInterface(*ctx); +// mlir::LLVM::LLVMStructType::attachInterface(*ctx); +// mlir::MemRefType::attachInterface>(*ctx); +// mlir::LLVM::LLVMStructType::attachInterface< +// PtrElementModel>(*ctx); +// mlir::LLVM::LLVMPointerType::attachInterface< +// PtrElementModel>(*ctx); +// mlir::LLVM::LLVMArrayType::attachInterface>(*ctx); +// }); + + mlir::transform::registerInterpreterPass(); + mlir::enzyme::registerGenerateApplyPatternsPass(); + mlir::enzyme::registerRemoveTransformPass(); + + printf("module_op_.getOperation()\n"); + module_op_.getOperation()->dump(); + + auto& region = module_op_.getOperation()->getRegion(0); + auto& block = region.front(); + auto& operation = block.front(); + + mlir::SymbolTable::setSymbolName(&operation, "tmp"); +// mlir::SymbolTable::setSymbolVisibility(&operation, mlir::SymbolTable::Visibility::Private); + + auto scalarf64 = mlir::RankedTensorType::get({}, mlir::FloatType::getF64(ctx)); + auto func_type = mlir::FunctionType::get(ctx, {scalarf64}, {scalarf64}); + auto func_op = mlir::func::FuncOp::create(mlir::UnknownLoc::get(ctx), "main", func_type); + + block.push_back(func_op); + + auto entry_block = func_op.addEntryBlock(); + + auto activity = mlir::enzyme::ActivityAttr::get(ctx, mlir::enzyme::Activity::enzyme_active); + auto ret_activity = mlir::enzyme::ActivityAttr::get( + ctx, mlir::enzyme::Activity::enzyme_activenoneed + ); + + mlir::NamedAttrList attrs; + attrs.set("fn", operation.getAttr("sym_name")); + attrs.set("activity", activity); + attrs.set("ret_activity", ret_activity); + + auto autodiff = mlir::Operation::create( + mlir::UnknownLoc::get(ctx), + mlir::OperationName("enzyme.autodiff", ctx), + mlir::TypeRange({scalarf64}), + mlir::ValueRange(entry_block->getArgument(0)), + std::move(attrs), +// mlir::NamedAttributeList({ +// mlir::NamedAttribute(mlir::StringAttr::get("fn", ctx), operation.getAttr("sym_name")), +// mlir::NamedAttribute(mlir::StringAttr::get("activity", ctx), activity), +// mlir::NamedAttribute(mlir::StringAttr::get("ret_activity", ctx), ret_activity), +// }), + mlir::OpaqueProperties(nullptr) + ); + +// auto state = mlir::OperationState(mlir::UnknownLoc::get(ctx), "enzyme.autodiff"); +// state.addOperands(mlir::ValueRange(entry_block->getArgument(0))); +// state.addTypes({scalarf64}); +// state.addAttribute("fn", operation.getAttr("sym_name")); +// state.addAttribute("activity", {activity}); +// state.addAttribute("ret_activity", {ret_activity}); +// auto autodiff = mlir::Operation::create(state); + entry_block->push_back(autodiff); + + auto return_op = mlir::OpBuilder(ctx).create( + mlir::UnknownLoc::get(ctx), + mlir::ValueRange(autodiff->getOpResult(0)) + ); + entry_block->push_back(return_op); + + printf("module_op_.getOperation()\n"); + module_op_.getOperation()->dump(); + + mlir::PassManager pm(ctx); + printf("0\n"); + pm.addPass(mlir::enzyme::createDifferentiatePass()); + printf("1\n"); + pm.run(func_op); + printf("2\n"); + + return reinterpret_cast(new mlir::ModuleOp(func_op)); + } +} diff --git a/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp new file mode 100644 index 000000000..5a0f5e983 --- /dev/null +++ b/spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/RegistryUtils.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +//#include "src/enzyme_ad/jax/RegistryUtils.h" +// +//#include "../../../../mlir/IR/DialectRegistry.h" +// +//extern "C" { +// void prepareRegistry_(DialectRegistry& registry) { +// prepareRegistry(reinterpret_cast(registry)); +// } +//} diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD new file mode 100644 index 000000000..2de39da72 --- /dev/null +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/BUILD @@ -0,0 +1,14 @@ +cc_library( + name = "Dialect", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@stablehlo//:register", + "@enzyme//:EnzymeMLIR", + "@llvm-project//mlir:IR", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp new file mode 100644 index 000000000..74aad7d40 --- /dev/null +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Dialect/Dialect.cpp @@ -0,0 +1,25 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/DialectRegistry.h" +#include "Enzyme/MLIR/Dialect/Dialect.h" + +#include "../../../../../mlir/IR/DialectRegistry.h" + +extern "C" { + void DialectRegistry_insert_EnzymeDialect(DialectRegistry& s) { + reinterpret_cast(s).insert(); + } +} diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD new file mode 100644 index 000000000..ab3aeb8f2 --- /dev/null +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Passes", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@enzyme//:EnzymeMLIR", + "//src/mlir/Pass", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp new file mode 100644 index 000000000..1777112ab --- /dev/null +++ b/spidr/backend/src/Enzyme/enzyme/Enzyme/MLIR/Passes/Passes.cpp @@ -0,0 +1,15 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ \ No newline at end of file diff --git a/spidr/backend/src/ffi.cpp b/spidr/backend/src/ffi.cpp index 2a77d13ac..f940b423e 100644 --- a/spidr/backend/src/ffi.cpp +++ b/spidr/backend/src/ffi.cpp @@ -29,6 +29,10 @@ extern "C" { return ptr == nullptr; } + string* string_new() { + return reinterpret_cast(new std::string()); + } + void string_delete(string* s) { delete reinterpret_cast(s); } diff --git a/spidr/backend/src/llvm/Support/BUILD b/spidr/backend/src/llvm/Support/BUILD new file mode 100644 index 000000000..12ee6525d --- /dev/null +++ b/spidr/backend/src/llvm/Support/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Support", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//llvm:Support", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/llvm/Support/raw_ostream.cpp b/spidr/backend/src/llvm/Support/raw_ostream.cpp new file mode 100644 index 000000000..bb30c8b06 --- /dev/null +++ b/spidr/backend/src/llvm/Support/raw_ostream.cpp @@ -0,0 +1,32 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "llvm/Support/raw_ostream.h" + +#include "../../ffi.h" +#include "raw_ostream.h" + +extern "C" { + struct raw_string_ostream; + + raw_string_ostream* raw_string_ostream_new(string& o) { + auto& o_ = reinterpret_cast(o); + return reinterpret_cast(new llvm::raw_string_ostream(o_)); + } + + void raw_string_ostream_delete(raw_string_ostream* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/llvm/Support/raw_ostream.h b/spidr/backend/src/llvm/Support/raw_ostream.h new file mode 100644 index 000000000..09f078918 --- /dev/null +++ b/spidr/backend/src/llvm/Support/raw_ostream.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct raw_ostream; +} diff --git a/spidr/backend/src/mlir/IR/Attributes.h b/spidr/backend/src/mlir/IR/Attributes.h new file mode 100644 index 000000000..fb5e8e3a4 --- /dev/null +++ b/spidr/backend/src/mlir/IR/Attributes.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Attribute; +} diff --git a/spidr/backend/src/mlir/IR/BUILD b/spidr/backend/src/mlir/IR/BUILD new file mode 100644 index 000000000..f034b361b --- /dev/null +++ b/spidr/backend/src/mlir/IR/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "IR", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//mlir:IR", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/mlir/IR/Block.cpp b/spidr/backend/src/mlir/IR/Block.cpp new file mode 100644 index 000000000..43e56ec4e --- /dev/null +++ b/spidr/backend/src/mlir/IR/Block.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/Block.h" + +#include "Block.h" + +extern "C" { + Block* Block_new() { + return reinterpret_cast(new mlir::Block()); + } + + void Block_delete(Block* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/Block.h b/spidr/backend/src/mlir/IR/Block.h new file mode 100644 index 000000000..0b730556d --- /dev/null +++ b/spidr/backend/src/mlir/IR/Block.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Block; +} diff --git a/spidr/backend/src/mlir/IR/BuiltinOps.cpp b/spidr/backend/src/mlir/IR/BuiltinOps.cpp new file mode 100644 index 000000000..ffa671600 --- /dev/null +++ b/spidr/backend/src/mlir/IR/BuiltinOps.cpp @@ -0,0 +1,31 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/BuiltinOps.h" + +#include "BuiltinOps.h" +#include "MLIRContext.h" + +extern "C" { + void ModuleOp_delete(ModuleOp* s) { + delete reinterpret_cast(s); + } + + // who owns this? + MLIRContext* ModuleOp_getContext(ModuleOp& s) { + auto s_ = reinterpret_cast(s); + return reinterpret_cast(s_.getContext()); + } +} diff --git a/spidr/backend/src/mlir/IR/BuiltinOps.h b/spidr/backend/src/mlir/IR/BuiltinOps.h new file mode 100644 index 000000000..0fb5ccbec --- /dev/null +++ b/spidr/backend/src/mlir/IR/BuiltinOps.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct ModuleOp; +} diff --git a/spidr/backend/src/mlir/IR/DialectRegistry.cpp b/spidr/backend/src/mlir/IR/DialectRegistry.cpp new file mode 100644 index 000000000..dfc543d57 --- /dev/null +++ b/spidr/backend/src/mlir/IR/DialectRegistry.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/DialectRegistry.h" + +#include "DialectRegistry.h" + +extern "C" { + DialectRegistry* DialectRegistry_new() { + return reinterpret_cast(new mlir::DialectRegistry()); + } + + void DialectRegistry_delete(DialectRegistry* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/DialectRegistry.h b/spidr/backend/src/mlir/IR/DialectRegistry.h new file mode 100644 index 000000000..58c7ab272 --- /dev/null +++ b/spidr/backend/src/mlir/IR/DialectRegistry.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct DialectRegistry; +} diff --git a/spidr/backend/src/mlir/IR/Location.cpp b/spidr/backend/src/mlir/IR/Location.cpp new file mode 100644 index 000000000..8d4c7d7be --- /dev/null +++ b/spidr/backend/src/mlir/IR/Location.cpp @@ -0,0 +1,28 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/Location.h" + +#include "Location.h" + +extern "C" { + Location* Location_new(...) { + return nullptr; //reinterpret_cast(new mlir::Location(...)); + } + + void Location_delete(Location* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/IR/Location.h b/spidr/backend/src/mlir/IR/Location.h new file mode 100644 index 000000000..438fe63cd --- /dev/null +++ b/spidr/backend/src/mlir/IR/Location.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Location; +} diff --git a/spidr/backend/src/mlir/IR/MLIRContext.cpp b/spidr/backend/src/mlir/IR/MLIRContext.cpp new file mode 100644 index 000000000..361e43a90 --- /dev/null +++ b/spidr/backend/src/mlir/IR/MLIRContext.cpp @@ -0,0 +1,39 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/MLIRContext.h" + +#include "DialectRegistry.h" +#include "MLIRContext.h" + +extern "C" { + MLIRContext* MLIRContext_new() { + return reinterpret_cast(new mlir::MLIRContext); + } + + void MLIRContext_delete(MLIRContext* s) { + delete reinterpret_cast(s); + } + +// DialectRegistry* MLIRContext_getDialectRegistry(MLIRContext& s) { +// auto& s_ = reinterpret_cast(s); +// return reinterpret_cast(s_.getDialectRegistry()); +// } + + void MLIRContext_appendDialectRegistry(MLIRContext& s, DialectRegistry& registry) { + auto& registry_ = reinterpret_cast(registry); + reinterpret_cast(s).appendDialectRegistry(registry_); + } +} diff --git a/spidr/backend/src/mlir/IR/MLIRContext.h b/spidr/backend/src/mlir/IR/MLIRContext.h new file mode 100644 index 000000000..efa58bc0c --- /dev/null +++ b/spidr/backend/src/mlir/IR/MLIRContext.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct MLIRContext; +} diff --git a/spidr/backend/src/mlir/IR/Operation.h b/spidr/backend/src/mlir/IR/Operation.h new file mode 100644 index 000000000..31743deed --- /dev/null +++ b/spidr/backend/src/mlir/IR/Operation.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Operation; +} diff --git a/spidr/backend/src/mlir/IR/OperationSupport.cpp b/spidr/backend/src/mlir/IR/OperationSupport.cpp new file mode 100644 index 000000000..fe82530af --- /dev/null +++ b/spidr/backend/src/mlir/IR/OperationSupport.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/OperationSupport.h" + +#include "Attributes.h" +#include "Block.h" +#include "Location.h" +#include "ValueRange.h" + +extern "C" { + struct OperationState; + + OperationState* OperationState_new(Location& location, char* name) { + auto& location_ = reinterpret_cast(location); + auto op_state = new mlir::OperationState(location_, name); + return reinterpret_cast(op_state); + } + + void OperationState_delete(OperationState* s) { + delete reinterpret_cast(s); + } + + void OperationState_addOperands(OperationState& s, ValueRange& newOperands) { + auto& s_ = reinterpret_cast(s); + auto& newOperands_ = reinterpret_cast(newOperands); + s_.addOperands(newOperands_); + } + + void OperationState_addAttribute(OperationState& s, char* name, Attribute& attr) { + auto& s_ = reinterpret_cast(s); + auto& attr_ = reinterpret_cast(attr); + s_.addAttribute(name, attr_); + } + +// void OperationState_addSuccessors(OperationState& s, Block* successor) { +// auto& s_ = reinterpret_cast(s); +// auto successor_ = reinterpret_cast(successor); +// s_.addSuccessors(successor_); +// } +} diff --git a/spidr/backend/src/mlir/IR/ValueRange.h b/spidr/backend/src/mlir/IR/ValueRange.h new file mode 100644 index 000000000..c568df6d0 --- /dev/null +++ b/spidr/backend/src/mlir/IR/ValueRange.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct ValueRange; +} diff --git a/spidr/backend/src/mlir/Pass/BUILD b/spidr/backend/src/mlir/Pass/BUILD new file mode 100644 index 000000000..125748060 --- /dev/null +++ b/spidr/backend/src/mlir/Pass/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "Pass", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@llvm-project//mlir:Pass", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/mlir/Pass/Pass.cpp b/spidr/backend/src/mlir/Pass/Pass.cpp new file mode 100644 index 000000000..07ee4e861 --- /dev/null +++ b/spidr/backend/src/mlir/Pass/Pass.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/Pass/Pass.h" + +#include "Pass.h" + +extern "C" { + void Pass_delete(Pass* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/mlir/Pass/Pass.h b/spidr/backend/src/mlir/Pass/Pass.h new file mode 100644 index 000000000..edbf60f2d --- /dev/null +++ b/spidr/backend/src/mlir/Pass/Pass.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct Pass; +} diff --git a/spidr/backend/src/mlir/Pass/PassManager.cpp b/spidr/backend/src/mlir/Pass/PassManager.cpp new file mode 100644 index 000000000..6116840fe --- /dev/null +++ b/spidr/backend/src/mlir/Pass/PassManager.cpp @@ -0,0 +1,50 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" + +#include "Pass.h" +#include "../IR/BuiltinOps.h" +#include "../IR/MLIRContext.h" +#include "../IR/Operation.h" + +extern "C" { + struct PassManager; + + PassManager* PassManager_new(MLIRContext* ctx) { + auto ctx_ = reinterpret_cast(ctx); + return reinterpret_cast(new mlir::PassManager(ctx_)); + } + + void PassManager_delete(PassManager* s) { + delete reinterpret_cast(s); + } + + void PassManager_addPass(PassManager& s, Pass* pass) { + return; // i hate cpp +// auto& s_ = reinterpret_cast(s); +// auto pass_ = reinterpret_cast(pass); +// auto pass__ = std::unique_ptr{std::exchange(pass_, nullptr)}; +// s_.addPass(pass__); + } + + int PassManager_run(PassManager& s, Operation* op) { + auto& s_ = reinterpret_cast(s); + auto op_ = reinterpret_cast(op); + return (int) s_.run(op_).succeeded(); + } +} diff --git a/spidr/backend/src/stablehlo/dialect/BUILD b/spidr/backend/src/stablehlo/dialect/BUILD new file mode 100644 index 000000000..5f76ca13f --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/BUILD @@ -0,0 +1,14 @@ +cc_library( + name = "dialect", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@stablehlo//:register", + "@stablehlo//:stablehlo_serialization", + "//src/llvm/Support", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/stablehlo/dialect/Register.cpp b/spidr/backend/src/stablehlo/dialect/Register.cpp new file mode 100644 index 000000000..505668a34 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Register.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "stablehlo/dialect/Register.h" + +#include "../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerAllDialects(DialectRegistry& registry) { + mlir::stablehlo::registerAllDialects(reinterpret_cast(registry)); + } +} diff --git a/spidr/backend/src/stablehlo/dialect/Serialization.cpp b/spidr/backend/src/stablehlo/dialect/Serialization.cpp new file mode 100644 index 000000000..8ba423ca7 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Serialization.cpp @@ -0,0 +1,30 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "stablehlo/dialect/Serialization.h" + +#include "../../mlir/IR/BuiltinOps.h" +#include "../../llvm/Support/raw_ostream.h" +#include "../../ffi.h" + +extern "C" { + int serializePortableArtifact(ModuleOp& module, string& version, raw_ostream& os) { + auto& module_ = reinterpret_cast(module); + auto& version_ = reinterpret_cast(version); + auto& os_ = reinterpret_cast(os); + auto result = mlir::stablehlo::serializePortableArtifact(module_, version_, os_); + return (int) result.succeeded(); + } +} diff --git a/spidr/backend/src/stablehlo/dialect/Version.cpp b/spidr/backend/src/stablehlo/dialect/Version.cpp new file mode 100644 index 000000000..c402990c2 --- /dev/null +++ b/spidr/backend/src/stablehlo/dialect/Version.cpp @@ -0,0 +1,36 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "stablehlo/dialect/Version.h" + +#include "../../ffi.h" + +extern "C" { + struct Version; + + void Version_delete(Version* s) { + delete reinterpret_cast(s); + } + + Version* Version_getMinimumVersion() { + auto version = mlir::vhlo::Version::getMinimumVersion(); + return reinterpret_cast(new mlir::vhlo::Version(version)); + } + + string* Version_toString(Version& s) { + auto& s_ = reinterpret_cast(s); + return reinterpret_cast(new std::string(s_.toString())); + } +} diff --git a/spidr/backend/src/xla/hlo/builder/BUILD b/spidr/backend/src/xla/hlo/builder/BUILD index e729f1eef..48be5352b 100644 --- a/spidr/backend/src/xla/hlo/builder/BUILD +++ b/spidr/backend/src/xla/hlo/builder/BUILD @@ -8,6 +8,7 @@ cc_library( "@xla//xla/hlo/builder:xla_builder", "//src", "//src/xla", + "//src/xla/service", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/xla/hlo/builder/xla_builder.cpp b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp index 89c46188f..32b06a48d 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_builder.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp @@ -290,6 +290,18 @@ extern "C" { return reinterpret_cast(new xla::XlaOp(res)); } + XlaOp* Call( + XlaBuilder* builder, XlaComputation& computation, XlaOp* operands, size_t operands_len + ) { + auto builder_ = reinterpret_cast(builder); + auto& computation_ = reinterpret_cast(computation); + auto operands_ = reinterpret_cast(operands); + auto operands_span = absl::Span(operands_, operands_len); + + auto res = xla::Call(builder_, computation_, operands_span); + return reinterpret_cast(new xla::XlaOp(res)); + } + XlaOp* Add(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Add, lhs, rhs); } XlaOp* Sub(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Sub, lhs, rhs); } XlaOp* Mul(XlaOp& lhs, XlaOp& rhs) { return binOp(xla::Mul, lhs, rhs); } diff --git a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp index 1cba3a527..2c75b2089 100644 --- a/spidr/backend/src/xla/hlo/builder/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -14,18 +14,26 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "xla/hlo/builder/xla_computation.h" +#include "xla/shape.h" #include "../../../ffi.h" +#include "../../service/hlo.proto.h" +#include "../../shape.h" #include "xla_computation.h" extern "C" { + XlaComputation* XlaComputation_new(HloModuleProto& proto) { + auto& proto_ = reinterpret_cast(proto); + // this moves the proto? should we then not GC it? + return reinterpret_cast(new xla::XlaComputation(proto_)); + } + void XlaComputation_delete(XlaComputation* s) { delete reinterpret_cast(s); } - string* XlaComputation_SerializeAsString(XlaComputation* s) { + HloModuleProto* XlaComputation_proto(XlaComputation* s) { auto s_ = reinterpret_cast(s); - auto serialized = s_->proto().SerializeAsString(); - return reinterpret_cast(new std::string(serialized)); + return reinterpret_cast(new xla::HloModuleProto(s_->proto())); } } diff --git a/spidr/backend/src/xla/hlo/ir/BUILD b/spidr/backend/src/xla/hlo/ir/BUILD new file mode 100644 index 000000000..b09f5e688 --- /dev/null +++ b/spidr/backend/src/xla/hlo/ir/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "ir", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/ir:hlo", + "//src/xla/service", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.cpp b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp new file mode 100644 index 000000000..4944fe917 --- /dev/null +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.cpp @@ -0,0 +1,37 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" + +#include "hlo_module.h" + +#include "../../service/hlo.proto.h" +#include "../../service/hlo_module_config.h" + +extern "C" { + HloModule* HloModule_CreateFromProto(HloModuleProto& proto, HloModuleConfig& module_config) { + // put print statements in all C functions to see if the error's coming from elsewhere + auto& proto_ = reinterpret_cast(proto); + auto& module_config_ = reinterpret_cast(module_config); + auto module = xla::HloModule::CreateFromProto(proto_, module_config_); + return reinterpret_cast(module.value().release()); + } + + void HloModule_delete(HloModule* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/xla/hlo/ir/hlo_module.h b/spidr/backend/src/xla/hlo/ir/hlo_module.h new file mode 100644 index 000000000..5fc43b5a6 --- /dev/null +++ b/spidr/backend/src/xla/hlo/ir/hlo_module.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct HloModule; +} diff --git a/spidr/backend/src/xla/hlo/translate/BUILD b/spidr/backend/src/xla/hlo/translate/BUILD new file mode 100644 index 000000000..75212dc84 --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "translate", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/translate:stablehlo", + "//src/mlir/IR", + "//src/xla/service", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/hlo/translate/stablehlo.cpp b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp new file mode 100644 index 000000000..6ef61891c --- /dev/null +++ b/spidr/backend/src/xla/hlo/translate/stablehlo.cpp @@ -0,0 +1,40 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "mlir/IR/BuiltinOps.h" +#include "xla/service/hlo.pb.h" +#include "xla/hlo/translate/stablehlo.h" + +#include "../../service/hlo.proto.h" +#include "../../../mlir/IR/BuiltinOps.h" +#include "../../../mlir/IR/MLIRContext.h" + +extern "C" { + ModuleOp* ConvertHloToStablehlo(MLIRContext& ctx, HloModuleProto* hlo_module) { + auto& ctx_ = reinterpret_cast(ctx); + auto hlo_module_ = reinterpret_cast(hlo_module); + auto module_op = xla::ConvertHloToStablehlo(ctx_, hlo_module_); + module_op.value()->dump(); + return reinterpret_cast(new mlir::ModuleOp(module_op.value().release())); + } + + HloModuleProto* ConvertStablehloToHlo(ModuleOp& module) { + auto& module_ = reinterpret_cast(module); + auto hlo = xla::ConvertStablehloToHlo(module_).value().release(); + // mode ToProto to separate function? + auto res = hlo->ToProto(); + return reinterpret_cast(new xla::HloModuleProto(res)); + } +} diff --git a/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD new file mode 100644 index 000000000..e7f37a3c4 --- /dev/null +++ b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/BUILD @@ -0,0 +1,12 @@ +cc_library( + name = "IR", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/mlir_hlo:hlo_dialect_registration", + "//src/mlir/IR", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp new file mode 100644 index 000000000..eb9319d4d --- /dev/null +++ b/spidr/backend/src/xla/mlir_hlo/mhlo/IR/register.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/mlir_hlo/mhlo/IR/register.h" + +#include "../../../../mlir/IR/DialectRegistry.h" + +extern "C" { + void registerAllMhloDialects(DialectRegistry& registry) { + mlir::mhlo::registerAllMhloDialects(reinterpret_cast(registry)); + } +} diff --git a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp index 040efb300..c17be41fc 100644 --- a/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp +++ b/spidr/backend/src/xla/pjrt/c/pjrt_c_api.cpp @@ -138,7 +138,7 @@ extern "C" { } PJRT_Program* PJRT_Program_new(char* code, size_t code_size) { - auto format = pjrt::kHloFormat; + auto format = pjrt::kMlirFormat; return new PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, .extension_start = nullptr, diff --git a/spidr/backend/src/xla/service/BUILD b/spidr/backend/src/xla/service/BUILD new file mode 100644 index 000000000..38c7f2d3e --- /dev/null +++ b/spidr/backend/src/xla/service/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "service", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/service", + "//src/xla", + "//src", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/service/hlo.proto.cpp b/spidr/backend/src/xla/service/hlo.proto.cpp new file mode 100644 index 000000000..f62a63af0 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo.proto.cpp @@ -0,0 +1,24 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "xla/service/hlo.pb.h" + +#include "hlo.proto.h" + +extern "C" { + void HloModuleProto_delete(HloModuleProto* s) { + delete reinterpret_cast(s); + } +} diff --git a/spidr/backend/src/xla/service/hlo.proto.h b/spidr/backend/src/xla/service/hlo.proto.h new file mode 100644 index 000000000..336bbeaf3 --- /dev/null +++ b/spidr/backend/src/xla/service/hlo.proto.h @@ -0,0 +1,18 @@ +/* +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +extern "C" { + struct HloModuleProto; +} diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index fa7b670c0..486eeaf6b 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,6 +8,28 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, + Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes, + Compiler.EnzymeJAX.Src.EnzymeAD.JAX.RegistryUtils, + + Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect, + Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes, + + Compiler.LLVM.Support.RawOStream, + + Compiler.MLIR.IR.Block, + Compiler.MLIR.IR.BuiltinOps, + Compiler.MLIR.IR.DialectRegistry, + Compiler.MLIR.IR.Location, + Compiler.MLIR.IR.MLIRContext, + Compiler.MLIR.IR.OperationSupport, + Compiler.MLIR.IR.ValueRange, + Compiler.MLIR.Pass.PassManager, + Compiler.MLIR.Pass.Pass, + + Compiler.StableHLO.Dialect.Register, + Compiler.StableHLO.Dialect.Serialization, + Compiler.StableHLO.Dialect.Version, + Compiler.Xla.Client.ExecutableBuildOptions, Compiler.Xla.HLO.Builder.Lib.Arithmetic, Compiler.Xla.HLO.Builder.Lib.Constants, @@ -16,8 +38,11 @@ modules = Compiler.Xla.HLO.Builder.Lib.PRNG, Compiler.Xla.HLO.Builder.XlaBuilder, Compiler.Xla.HLO.Builder.XlaComputation, + Compiler.Xla.HLO.Translate.StableHLO, + Compiler.Xla.MLIRHLO.MHLO.IR.Register, Compiler.Xla.PJRT.C.PjrtCApi, Compiler.Xla.PJRT.PjrtExecutable, + Compiler.Xla.Service.HloProto, Compiler.Xla.Literal, Compiler.Xla.Shape, Compiler.Xla.ShapeUtil, diff --git a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Dialect/Dialect.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Dialect/Dialect.idr new file mode 100644 index 000000000..25b240e8f --- /dev/null +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Dialect/Dialect.idr @@ -0,0 +1,28 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.MLIR.Pass.Pass +import Compiler.FFI + +%foreign (libxla "DialectRegistry_insert_EnzymeDialect") +prim__dialectRegistryInsertEnzymeDialect : GCAnyPtr -> PrimIO () + +export +insertEnzymeDialect : HasIO io => DialectRegistry -> io () +insertEnzymeDialect (MkDialectRegistry reg) = primIO $ prim__dialectRegistryInsertEnzymeDialect reg diff --git a/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr new file mode 100644 index 000000000..51f375301 --- /dev/null +++ b/spidr/src/Compiler/Enzyme/Enzyme/Enzyme/MLIR/Passes/Passes.idr @@ -0,0 +1,42 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes + +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.DialectRegistry +import Compiler.MLIR.Pass.Pass +import Compiler.FFI + +%foreign (libxla "createDifferentiatePass") +prim__createDifferentiatePass : PrimIO AnyPtr + +export +createDifferentiatePass : HasIO io => io Pass +createDifferentiatePass = do + pass <- primIO prim__createDifferentiatePass + pass <- onCollectAny pass (primIO . Pass.prim__delete) + pure (MkPass pass) + +%foreign (libxla "emitEnzymeADOp") +prim__emitEnzymeADOp : GCAnyPtr -> PrimIO AnyPtr + +export +emitEnzymeADOp : HasIO io => ModuleOp -> io ModuleOp +emitEnzymeADOp (MkModuleOp op) = do + op <- primIO $ prim__emitEnzymeADOp op + op <- onCollectAny op (primIO . BuiltinOps.prim__delete) + pure (MkModuleOp op) diff --git a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr new file mode 100644 index 000000000..dbae4dd88 --- /dev/null +++ b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/Passes/Passes.idr @@ -0,0 +1,33 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes + +import Compiler.FFI + +%foreign (libxla "regsiterenzymeXLAPasses_") +prim__regsiterenzymeXLAPasses : PrimIO () + +export +regsiterenzymeXLAPasses : HasIO io => io () +regsiterenzymeXLAPasses = primIO prim__regsiterenzymeXLAPasses + +%foreign (libxla "registerenzymePasses") +prim__registerenzymePasses : PrimIO () + +export +registerenzymePasses : HasIO io => io () +registerenzymePasses = primIO prim__registerenzymePasses diff --git a/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr new file mode 100644 index 000000000..65f9bff4c --- /dev/null +++ b/spidr/src/Compiler/EnzymeJAX/Src/EnzymeAD/JAX/RegistryUtils.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.EnzymeJAX.Src.EnzymeAD.JAX.RegistryUtils + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "prepareRegistry_") +prim__prepareRegistry : GCAnyPtr -> PrimIO () + +export +prepareRegistry : HasIO io => DialectRegistry -> io () +prepareRegistry (MkDialectRegistry registry) = primIO $ prim__prepareRegistry registry diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index 6c26186f5..310c64496 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -23,9 +23,19 @@ import Data.IOArray import Data.List import Data.List.Elem -import Compiler.Expr -import Compiler.FFI -import Compiler.LiteralRW +import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.Passes.Passes +import Compiler.EnzymeJAX.Src.EnzymeAD.JAX.RegistryUtils +import Compiler.Enzyme.Enzyme.Enzyme.MLIR.Dialect.Dialect +import Compiler.Enzyme.Enzyme.Enzyme.MLIR.Passes.Passes +import Compiler.LLVM.Support.RawOStream +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.DialectRegistry +import Compiler.MLIR.IR.MLIRContext +import Compiler.MLIR.Pass.PassManager +import Compiler.MLIR.Pass.Pass +import Compiler.StableHLO.Dialect.Register +import Compiler.StableHLO.Dialect.Serialization +import Compiler.StableHLO.Dialect.Version import Compiler.Xla.Client.ExecutableBuildOptions import Compiler.Xla.HLO.Builder.Lib.Arithmetic import Compiler.Xla.HLO.Builder.Lib.Constants @@ -34,34 +44,62 @@ import Compiler.Xla.HLO.Builder.Lib.Matrix import Compiler.Xla.HLO.Builder.Lib.PRNG import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.HLO.Builder.XlaComputation +import Compiler.Xla.HLO.Translate.StableHLO +import Compiler.Xla.MLIRHLO.MHLO.IR.Register import Compiler.Xla.PJRT.C.PjrtCApi import Compiler.Xla.PJRT.PjrtExecutable +import Compiler.Xla.Service.HloProto import Compiler.Xla.Literal import Compiler.Xla.Shape import Compiler.Xla.ShapeUtil import Compiler.Xla.XlaData +import Compiler.Expr +import Compiler.FFI +import Compiler.LiteralRW import Literal import Primitive import Types import Util import Device +import System + export data Err = OutOfBounds Nat Nat | ValueNotFound Nat | PjrtErr PjrtError + | SerializationError String + | MlirPassError String export Show Err where show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}" show (ValueNotFound idx) = "Value not found at index \{show idx}" - show (PjrtErr err)= show err + show (PjrtErr err) = show err + show (SerializationError err) = "SerializationError: \{err}" + show (MlirPassError err) = "MlirPassError: \{err}" public export 0 ErrIO : Type -> Type ErrIO = EitherT Err IO +serializeStableHLO : ModuleOp -> ErrIO CharArray +serializeStableHLO stablehlo = do + code <- cppString + version <- toString !getMinimumVersion + ok <- serializePortableArtifact stablehlo version !(rawStringOStream code) + if ok then stringToCharArray code else throwE (SerializationError "Failed to serialize StableHLO") + +hloModuleProtoToStableHLO : HloModuleProto -> ErrIO ModuleOp +hloModuleProtoToStableHLO proto = do + dialectRegistry <- mkDialectRegistry + registerAllMhloDialects dialectRegistry + registerAllDialects dialectRegistry + mlirCtx <- mkMLIRContext + appendDialectRegistry mlirCtx dialectRegistry + convertHloToStablehlo mlirCtx proto + covering interpret : IOArray XlaOp => XlaBuilder -> Fn arity -> ErrIO XlaOp @@ -97,6 +135,31 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do interpretE (Var x) = get x interpretE (Tuple xs) = tuple xlaBuilder !(traverse interpretE xs) interpretE (GetTupleElement idx x) = getTupleElement !(interpretE x) idx + interpretE (Grad f x) = do + putStrLn "interpretE (Grad _ _)" + computation <- compile xlaBuilder f + stablehlo <- hloModuleProtoToStableHLO !(proto computation) + -- ctx <- getContext stablehlo + -- reg <- mkDialectRegistry + -- appendDialectRegistry ctx reg + -- insertEnzymeDialect reg + enzymeOp <- emitEnzymeADOp stablehlo + --regsiterenzymeXLAPasses + --prepareRegistry reg + --registerenzymePasses + -- need other dialects? + -- surely the ModuleOp already has stablehlo registered, since it's stablehlo code + -- StableHLO.Dialect.Register.registerAllDialects reg + -- registerStableHLODialectAutoDiffInterface reg + --mgr <- mkPassManager ctx + --addPass mgr !createDifferentiatePass + -- True <- run mgr enzymeOp + -- | False => throwE $ MlirPassError "Failed to run differentiate pass on StableHLO" + hloProto <- convertStablehloToHlo enzymeOp + computation <- mkXlaComputation hloProto + -- x should be correct shape, because we're sending R^{n0, n1, ..} -> R + -- to R^{n0, n1, ..} -> R^{n0, n1, ..} i.e. we're only changing the output shape + call xlaBuilder computation [!(interpretE x)] interpretE (MinValue {dtype}) = minValue {dtype} xlaBuilder interpretE (MaxValue {dtype}) = maxValue {dtype} xlaBuilder interpretE (MinFiniteValue {dtype}) = minFiniteValue {dtype} xlaBuilder @@ -223,11 +286,12 @@ execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ V execute (MkDevice api client) f@(MkFn _ _ env) shapes = do xlaBuilder <- mkXlaBuilder "root" computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f + code <- serializeStableHLO !(hloModuleProtoToStableHLO !(proto computation)) + executableBuildOptions <- mkExecutableBuildOptions + compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) + program <- mkPjrtProgram code bimapEitherT PjrtErr id $ do - code <- serializeAsString computation - executableBuildOptions <- mkExecutableBuildOptions - compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions) - loadedExec <- pjrtClientCompile api client !(mkPjrtProgram code) compileOptions + loadedExec <- pjrtClientCompile api client program compileOptions free code free compileOptions delete executableBuildOptions diff --git a/spidr/src/Compiler/Expr.idr b/spidr/src/Compiler/Expr.idr index c5c4ff68d..1912b8062 100644 --- a/spidr/src/Compiler/Expr.idr +++ b/spidr/src/Compiler/Expr.idr @@ -106,6 +106,7 @@ data Expr : Type where Var : Nat -> Expr Tuple : List Expr -> Expr GetTupleElement : (index : Nat) -> Expr -> Expr + Grad : Fn 1 -> Expr -> Expr -- temporary name MinValue : Primitive dtype => Expr MaxValue : Primitive dtype => Expr MinFiniteValue : Primitive dtype => Expr @@ -184,6 +185,7 @@ showExpr indent (FromLiteral {shape, dtype} x) = "Lit \{shape} \{xlaIdentifier { showExpr indent (Var k) = "Var \{k}" showExpr indent (Tuple xs) = "Tuple \{showExprList indent xs}" showExpr indent (GetTupleElement k x) = "GetTupleElement {index = \{k}} (\{showExpr indent x})" +showExpr indent (Grad _ _) = "Grad" showExpr indent (MinValue {dtype}) = "MinValue {dtype = \{xlaIdentifier {dtype}}}" showExpr indent (MaxValue {dtype}) = "MaxValue {dtype = \{xlaIdentifier {dtype}}}" showExpr indent (MinFiniteValue {dtype}) = "MinFiniteValue {dtype = \{xlaIdentifier {dtype}}}" diff --git a/spidr/src/Compiler/FFI.idr b/spidr/src/Compiler/FFI.idr index aec92c193..ea4bfd34f 100644 --- a/spidr/src/Compiler/FFI.idr +++ b/spidr/src/Compiler/FFI.idr @@ -26,15 +26,32 @@ libxla fname = "C:" ++ fname ++ ",libc_xla" public export data CharArray = MkCharArray (Ptr Char) Bits64 +public export +data CppString = MkCppString AnyPtr + namespace CharArray export free : HasIO io => CharArray -> io () free (MkCharArray arr _) = free $ prim__forgetPtr arr +export +%foreign (libxla "string_new") +prim__mkString : PrimIO AnyPtr + +||| It is up to the caller to `delete` the string. +export +cppString : HasIO io => io CppString +cppString = MkCppString <$> primIO prim__mkString + export %foreign (libxla "string_delete") prim__stringDelete : AnyPtr -> PrimIO () +namespace CppString + export + delete : HasIO io => CppString -> io () + delete (MkCppString str) = primIO $ prim__stringDelete str + export %foreign (libxla "string_data") prim__stringData : AnyPtr -> PrimIO $ Ptr Char @@ -47,6 +64,15 @@ export %foreign (libxla "idx") prim__index : Int -> AnyPtr -> AnyPtr +||| Deletes the `string`. It is up to the caller to `free` the `CharArray`. +export +stringToCharArray : HasIO io => CppString -> io CharArray +stringToCharArray (MkCppString str) = do + data' <- primIO $ prim__stringData str + let size = prim__stringSize str + primIO $ prim__stringDelete str + pure (MkCharArray data' size) + export cIntToBool : Int -> Bool cIntToBool 0 = False diff --git a/spidr/src/Compiler/LLVM/Support/RawOStream.idr b/spidr/src/Compiler/LLVM/Support/RawOStream.idr new file mode 100644 index 000000000..f8b11e32a --- /dev/null +++ b/spidr/src/Compiler/LLVM/Support/RawOStream.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.LLVM.Support.RawOStream + +import Compiler.FFI + +public export +data RawStringOStream = MkRawStringOStream GCAnyPtr + +%foreign (libxla "raw_string_ostream_new") +prim__mkRawStringOStream : AnyPtr -> PrimIO AnyPtr + +%foreign (libxla "raw_string_ostream_delete") +prim__delete : AnyPtr -> PrimIO () + +export +rawStringOStream : HasIO io => CppString -> io RawStringOStream +rawStringOStream (MkCppString str) = do + os <- primIO $ prim__mkRawStringOStream str + os <- onCollectAny os (primIO . prim__delete) + pure (MkRawStringOStream os) diff --git a/spidr/src/Compiler/MLIR/IR/Attributes.idr b/spidr/src/Compiler/MLIR/IR/Attributes.idr new file mode 100644 index 000000000..c6b225c42 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Attributes.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.Attributes + +import Compiler.FFI + +public export +data Attribute = MkAttribute GCAnyPtr + +%foreign (libxla "Attribute_new") +prim__mkAttribute : PrimIO AnyPtr + +%foreign (libxla "Attribute_delete") +prim__deleteAttribute : AnyPtr -> PrimIO () + +export +mkAttribute : HasIO io => io Attribute +mkAttribute = do + attr <- primIO prim__mkAttribute + attr <- onCollectAny attr (primIO . prim__deleteAttribute) + pure (MkAttribute attr) diff --git a/spidr/src/Compiler/MLIR/IR/Block.idr b/spidr/src/Compiler/MLIR/IR/Block.idr new file mode 100644 index 000000000..921418676 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Block.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.Block + +import Compiler.FFI + +public export +data Block = MkBlock GCAnyPtr + +%foreign (libxla "Block_new") +prim__mkBlock : PrimIO AnyPtr + +%foreign (libxla "Block_delete") +prim__deleteBlock : AnyPtr -> PrimIO () + +export +mkBlock : HasIO io => io Block +mkBlock = do + block <- primIO prim__mkBlock + block <- onCollectAny block (primIO . prim__deleteBlock) + pure (MkBlock block) diff --git a/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr new file mode 100644 index 000000000..0ac3fa2ee --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/BuiltinOps.idr @@ -0,0 +1,38 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.BuiltinOps + +import Compiler.MLIR.IR.MLIRContext +import Compiler.FFI + +public export +data ModuleOp = MkModuleOp GCAnyPtr + +export +%foreign (libxla "ModuleOp_delete") +prim__delete : AnyPtr -> PrimIO () + +export +%foreign (libxla "ModuleOp_getContext") +prim__moduleOp : GCAnyPtr -> PrimIO AnyPtr + +export +getContext : HasIO io => ModuleOp -> io MLIRContext +getContext (MkModuleOp op) = do + ctx <- primIO $ prim__moduleOp op + ctx <- onCollectAny ctx (const $ pure ()) -- I reckon we've already GC'ed this + pure (MkMLIRContext ctx) diff --git a/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr b/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr new file mode 100644 index 000000000..329b35308 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/DialectRegistry.idr @@ -0,0 +1,35 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.DialectRegistry + +import Compiler.FFI + +public export +data DialectRegistry = MkDialectRegistry GCAnyPtr + +%foreign (libxla "DialectRegistry_new") +prim__mkDialectRegistry : PrimIO AnyPtr + +%foreign (libxla "DialectRegistry_delete") +prim__deleteDialectRegistry : AnyPtr -> PrimIO () + +export +mkDialectRegistry : HasIO io => io DialectRegistry +mkDialectRegistry = do + registry <- primIO prim__mkDialectRegistry + registry <- onCollectAny registry (primIO . prim__deleteDialectRegistry) + pure (MkDialectRegistry registry) diff --git a/spidr/src/Compiler/MLIR/IR/Location.idr b/spidr/src/Compiler/MLIR/IR/Location.idr new file mode 100644 index 000000000..d845e477c --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Location.idr @@ -0,0 +1,22 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.Location + +import Compiler.FFI + +public export +data Location = MkLocation GCAnyPtr diff --git a/spidr/src/Compiler/MLIR/IR/MLIRContext.idr b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr new file mode 100644 index 000000000..42c91caa3 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/MLIRContext.idr @@ -0,0 +1,54 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.MLIRContext + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +public export +data MLIRContext = MkMLIRContext GCAnyPtr + +%foreign (libxla "MLIRContext_new") +prim__mkMLIRContext : PrimIO AnyPtr + +%foreign (libxla "MLIRContext_delete") +prim__deleteMLIRContext : AnyPtr -> PrimIO () + +export +mkMLIRContext : HasIO io => io MLIRContext +mkMLIRContext = do + ctx <- primIO prim__mkMLIRContext + ctx <- onCollectAny ctx (primIO . prim__deleteMLIRContext) + pure (MkMLIRContext ctx) + +%foreign (libxla "MLIRContext_getDialectRegistry") +prim__getDialectRegistry : GCAnyPtr -> PrimIO AnyPtr + +export +getDialectRegistry : HasIO io => MLIRContext -> io DialectRegistry +getDialectRegistry (MkMLIRContext ctx) = do + registry <- primIO $ prim__getDialectRegistry ctx + registry <- onCollectAny registry (const $ pure ()) -- correct? + pure (MkDialectRegistry registry) + +%foreign (libxla "MLIRContext_appendDialectRegistry") +prim__appendDialectRegistry : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +appendDialectRegistry : HasIO io => MLIRContext -> DialectRegistry -> io () +appendDialectRegistry (MkMLIRContext ctx) (MkDialectRegistry registry) = + primIO $ prim__appendDialectRegistry ctx registry diff --git a/spidr/src/Compiler/MLIR/IR/Operation.idr b/spidr/src/Compiler/MLIR/IR/Operation.idr new file mode 100644 index 000000000..ba44b6583 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/Operation.idr @@ -0,0 +1,26 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.BuiltinOps + +import Compiler.FFI + +public export +data ModuleOp = MkModuleOp GCAnyPtr + +export +%foreign (libxla "ModuleOp_delete") +prim__delete : AnyPtr -> PrimIO () diff --git a/spidr/src/Compiler/MLIR/IR/OperationSupport.idr b/spidr/src/Compiler/MLIR/IR/OperationSupport.idr new file mode 100644 index 000000000..fcb69cbf0 --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/OperationSupport.idr @@ -0,0 +1,54 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.OperationSupport + +import Compiler.MLIR.IR.Attributes +import Compiler.MLIR.IR.Location +import Compiler.MLIR.IR.ValueRange +import Compiler.FFI + +public export +data OperationState = MkOperationState GCAnyPtr + +%foreign (libxla "OperationState_new") +prim__mkOperationState : GCAnyPtr -> String -> PrimIO AnyPtr + +%foreign (libxla "OperationState_delete") +prim__delete : AnyPtr -> PrimIO () + +export +mkOperationState : HasIO io => Location -> String -> io OperationState +mkOperationState (MkLocation location) name = do + opState <- primIO $ prim__mkOperationState location name + opState <- onCollectAny opState (primIO . OperationSupport.prim__delete) + pure (MkOperationState opState) + +%foreign (libxla "OperationState_addOperands") +prim__operationStateAddOperands : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +addOperands : HasIO io => OperationState -> ValueRange -> io () +addOperands (MkOperationState opState) (MkValueRange valueRange) = + primIO $ prim__operationStateAddOperands opState valueRange + +%foreign (libxla "OperationState_addAttribute") +prim__operationStateAddAttribute : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +addAttribute : HasIO io => OperationState -> Attribute -> io () +addAttribute (MkOperationState opState) (MkAttribute attribute) = + primIO $ prim__operationStateAddAttribute opState attribute diff --git a/spidr/src/Compiler/MLIR/IR/ValueRange.idr b/spidr/src/Compiler/MLIR/IR/ValueRange.idr new file mode 100644 index 000000000..b347b30fa --- /dev/null +++ b/spidr/src/Compiler/MLIR/IR/ValueRange.idr @@ -0,0 +1,22 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.IR.ValueRange + +import Compiler.FFI + +public export +data ValueRange = MkValueRange GCAnyPtr diff --git a/spidr/src/Compiler/MLIR/Pass/Pass.idr b/spidr/src/Compiler/MLIR/Pass/Pass.idr new file mode 100644 index 000000000..af268a944 --- /dev/null +++ b/spidr/src/Compiler/MLIR/Pass/Pass.idr @@ -0,0 +1,26 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.Pass.Pass + +import Compiler.FFI + +public export +data Pass = MkPass GCAnyPtr + +export +%foreign (libxla "Pass_delete") +prim__delete : AnyPtr -> PrimIO () diff --git a/spidr/src/Compiler/MLIR/Pass/PassManager.idr b/spidr/src/Compiler/MLIR/Pass/PassManager.idr new file mode 100644 index 000000000..eb4972221 --- /dev/null +++ b/spidr/src/Compiler/MLIR/Pass/PassManager.idr @@ -0,0 +1,54 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.MLIR.Pass.PassManager + +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.MLIRContext +import Compiler.MLIR.Pass.Pass +import Compiler.FFI + +public export +data PassManager = MkPassManager GCAnyPtr + +%foreign (libxla "PassManager_new") +prim__mkPassManager : GCAnyPtr -> PrimIO AnyPtr + +%foreign (libxla "PassManager_delete") +prim__delete : AnyPtr -> PrimIO () + +export +mkPassManager : HasIO io => MLIRContext -> io PassManager +mkPassManager (MkMLIRContext ctx) = do + manager <- primIO $ prim__mkPassManager ctx + manager <- onCollectAny manager (primIO . PassManager.prim__delete) + pure (MkPassManager manager) + +%foreign (libxla "PassManager_addPass") +prim__passManagerAddPass : GCAnyPtr -> GCAnyPtr -> PrimIO () + +export +addPass : HasIO io => PassManager -> Pass -> io () +addPass (MkPassManager manager) (MkPass pass) = primIO $ prim__passManagerAddPass manager pass + +%foreign (libxla "PassManager_run") +prim__passManagerRun : GCAnyPtr -> GCAnyPtr -> PrimIO Int + +export +run : HasIO io => PassManager -> ModuleOp -> io Bool +run (MkPassManager manager) (MkModuleOp op) = do + ok <- primIO $ prim__passManagerRun manager op + pure (cIntToBool ok) diff --git a/spidr/src/Compiler/StableHLO/Dialect/Register.idr b/spidr/src/Compiler/StableHLO/Dialect/Register.idr new file mode 100644 index 000000000..e51220ac2 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Register.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Register + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerAllDialects") +prim__registerAllDialects : GCAnyPtr -> PrimIO () + +export +registerAllDialects : HasIO io => DialectRegistry -> io () +registerAllDialects (MkDialectRegistry reg) = primIO $ prim__registerAllDialects reg diff --git a/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr new file mode 100644 index 000000000..299f9269c --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Serialization.idr @@ -0,0 +1,30 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Serialization + +import Compiler.LLVM.Support.RawOStream +import Compiler.MLIR.IR.BuiltinOps +import Compiler.FFI + +%foreign (libxla "serializePortableArtifact") +prim__serializePortableArtifact : GCAnyPtr -> AnyPtr -> GCAnyPtr -> PrimIO Int + +export +serializePortableArtifact : HasIO io => ModuleOp -> CppString -> RawStringOStream -> io Bool +serializePortableArtifact (MkModuleOp moduleOp) (MkCppString version) (MkRawStringOStream os) = do + ok <- primIO $ prim__serializePortableArtifact moduleOp version os + pure (cIntToBool ok) diff --git a/spidr/src/Compiler/StableHLO/Dialect/Version.idr b/spidr/src/Compiler/StableHLO/Dialect/Version.idr new file mode 100644 index 000000000..bad9ca363 --- /dev/null +++ b/spidr/src/Compiler/StableHLO/Dialect/Version.idr @@ -0,0 +1,45 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.StableHLO.Dialect.Version + +import Compiler.FFI + +export +data Version = MkVersion GCAnyPtr + +%foreign (libxla "Version_delete") +prim__delete : AnyPtr -> PrimIO () + +%foreign (libxla "Version_getMinimumVersion") +prim__versionGetMinimumVersion : PrimIO AnyPtr + +export +getMinimumVersion : HasIO io => io Version +getMinimumVersion = do + version <- primIO prim__versionGetMinimumVersion + version <- onCollectAny version (primIO . prim__delete) + pure (MkVersion version) + +%foreign (libxla "Version_toString") +prim__versionToString : GCAnyPtr -> PrimIO AnyPtr + +||| It is up to the caller to `delete` the string. +export +toString : HasIO io => Version -> io CppString +toString (MkVersion version) = do + str <- primIO $ prim__versionToString version + pure (MkCppString str) diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr index c85bc08bf..d1ca3ad42 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr @@ -347,6 +347,17 @@ cholesky (MkXlaOp a) lower = do opPtr <- onCollectAny opPtr XlaOp.delete pure (MkXlaOp opPtr) +%foreign (libxla "Call") +prim__call : GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> Bits64 -> PrimIO AnyPtr + +export +call : HasIO io => XlaBuilder -> XlaComputation -> List XlaOp -> io XlaOp +call (MkXlaBuilder builder) (MkXlaComputation computation) operands = do + MkXlaOpArray operandsXlaOpArrayPtr <- mkXlaOpArray operands + op <- primIO $ prim__call builder computation operandsXlaOpArrayPtr (cast $ length operands) + op <- onCollectAny op XlaOp.delete + pure (MkXlaOp op) + %foreign (libxla "Add") prim__add : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index 1e35ba4dc..c4cd9a8a8 100644 --- a/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -17,28 +17,36 @@ limitations under the License. module Compiler.Xla.HLO.Builder.XlaComputation import Compiler.FFI +import Compiler.Xla.Shape +import Compiler.Xla.Service.HloProto public export data XlaComputation : Type where MkXlaComputation : GCAnyPtr -> XlaComputation +%foreign (libxla "XlaComputation_new") +prim__mkXlaComputation : GCAnyPtr -> PrimIO AnyPtr + %foreign (libxla "XlaComputation_delete") prim__delete : AnyPtr -> PrimIO () export delete : AnyPtr -> IO () -delete = primIO . prim__delete +delete = primIO . XlaComputation.prim__delete export -%foreign (libxla "XlaComputation_SerializeAsString") -prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO AnyPtr +mkXlaComputation : HasIO io => HloModuleProto -> io XlaComputation +mkXlaComputation (MkHloModuleProto proto) = do + comp <- primIO $ prim__mkXlaComputation proto + comp <- onCollectAny comp XlaComputation.delete + pure (MkXlaComputation comp) + +%foreign (libxla "XlaComputation_proto") +prim__xlaComputationProto : GCAnyPtr -> PrimIO AnyPtr -||| It is up to the caller to deallocate the CharArray. export -serializeAsString : HasIO io => XlaComputation -> io CharArray -serializeAsString (MkXlaComputation computation) = do - str <- primIO $ prim__xlaComputationSerializeAsString computation - data' <- primIO $ prim__stringData str - let size = prim__stringSize str - primIO $ prim__stringDelete str - pure (MkCharArray data' size) +proto : HasIO io => XlaComputation -> io HloModuleProto +proto (MkXlaComputation comp) = do + proto <- primIO $ prim__xlaComputationProto comp + proto <- onCollectAny proto (primIO . HloProto.prim__delete) + pure (MkHloModuleProto proto) diff --git a/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr new file mode 100644 index 000000000..8f2f39854 --- /dev/null +++ b/spidr/src/Compiler/Xla/HLO/Translate/StableHLO.idr @@ -0,0 +1,42 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.HLO.Translate.StableHLO + +import Compiler.FFI +import Compiler.MLIR.IR.BuiltinOps +import Compiler.MLIR.IR.MLIRContext +import Compiler.Xla.Service.HloProto + +%foreign (libxla "ConvertHloToStablehlo") +prim__convertHloToStablehlo : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr + +export +convertHloToStablehlo : HasIO io => MLIRContext -> HloModuleProto -> io ModuleOp +convertHloToStablehlo (MkMLIRContext ctx) (MkHloModuleProto proto) = do + moduleOp <- primIO $ prim__convertHloToStablehlo ctx proto + moduleOp <- onCollectAny moduleOp (primIO . BuiltinOps.prim__delete) + pure (MkModuleOp moduleOp) + +%foreign (libxla "ConvertStablehloToHlo") +prim__convertStablehloToHlo : GCAnyPtr -> PrimIO AnyPtr + +export +convertStablehloToHlo : HasIO io => ModuleOp -> io HloModuleProto +convertStablehloToHlo (MkModuleOp op) = do + hlo <- primIO $ prim__convertStablehloToHlo op + hlo <- onCollectAny hlo (primIO . BuiltinOps.prim__delete) + pure (MkHloModuleProto hlo) diff --git a/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr b/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr new file mode 100644 index 000000000..77f82fd41 --- /dev/null +++ b/spidr/src/Compiler/Xla/MLIRHLO/MHLO/IR/Register.idr @@ -0,0 +1,27 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.MLIRHLO.MHLO.IR.Register + +import Compiler.MLIR.IR.DialectRegistry +import Compiler.FFI + +%foreign (libxla "registerAllMhloDialects") +prim__registerAllMhloDialects : GCAnyPtr -> PrimIO () + +export +registerAllMhloDialects : HasIO io => DialectRegistry -> io () +registerAllMhloDialects (MkDialectRegistry reg) = primIO $ prim__registerAllMhloDialects reg diff --git a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr index 9933b6c30..9ab2f0e62 100644 --- a/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr +++ b/spidr/src/Compiler/Xla/PJRT/C/PjrtCApi.idr @@ -73,9 +73,9 @@ export Show PjrtError where show e = let code = case e.code of - Nothing => "not found" + Nothing => "unknown" Just c => show c - in "PjrtError \{show e.message} (code \{code})" + in "PjrtError (error code \{code})\n\{e.message}" %foreign (libxla "PJRT_Error_Destroy_Args_new") prim__mkPjrtErrorDestroyArgs : AnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr index 987cb1fdd..a2e1cc136 100644 --- a/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr +++ b/spidr/src/Compiler/Xla/PJRT/PjrtExecutable.idr @@ -36,12 +36,9 @@ mkCompileOptions (MkExecutableBuildOptions executableBuildOptions) = do %foreign (libxla "CompileOptions_SerializeAsString") prim__compileOptionsSerializeAsString : GCAnyPtr -> PrimIO AnyPtr -||| It is up to the caller to deallocate the CharArray. +||| It is up to the caller to `free` the `CharArray`. export serializeAsString : HasIO io => CompileOptions -> io CharArray serializeAsString (MkCompileOptions options) = do - str <- primIO $ prim__compileOptionsSerializeAsString options - data' <- primIO $ prim__stringData str - let size = prim__stringSize str - primIO $ prim__stringDelete str - pure (MkCharArray data' size) + str <- primIO (prim__compileOptionsSerializeAsString options) + stringToCharArray (MkCppString str) diff --git a/spidr/src/Compiler/Xla/Service/HloProto.idr b/spidr/src/Compiler/Xla/Service/HloProto.idr new file mode 100644 index 000000000..5cd389aae --- /dev/null +++ b/spidr/src/Compiler/Xla/Service/HloProto.idr @@ -0,0 +1,26 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +||| For internal spidr use only. +module Compiler.Xla.Service.HloProto + +import Compiler.FFI + +public export +data HloModuleProto = MkHloModuleProto GCAnyPtr + +export +%foreign (libxla "HloModuleProto_delete") +prim__delete : AnyPtr -> PrimIO () diff --git a/spidr/src/Tensor.idr b/spidr/src/Tensor.idr index 024a2a41e..8720dd562 100644 --- a/spidr/src/Tensor.idr +++ b/spidr/src/Tensor.idr @@ -238,6 +238,18 @@ export castDtype : Primitive.Integral a => Tensor shape a -> Tensor shape F64 castDtype $ MkTensor x = MkTensor $ ConvertElementType {dtype = F64} x +||| Reverse-mode automatic differentiation. +export +grad : (Tensor shape F64 -> Tag $ Tensor [] F64) -> Tensor shape F64 -> Tag $ Tensor shape F64 +grad f (MkTensor x) = MkTagT $ do + addr <- reserve + let MkTagT app = f (MkTensor $ Var addr) + (env, MkTensor res) = runState (emptyFrom !get) app + g = MkFn [(addr, MkParameter [] F64)] res env + + updateCounterFrom env + pure $ MkTensor $ Grad g x + ----------------------------- structural operations ---------------------------- ||| Reshape a `Tensor`. For example, `reshape {to = [2, 1]} (tensor [3, 4])` is @@ -1335,7 +1347,7 @@ cos = unary Cos ||| The element-wise tangent. export tan : Tensor shape F64 -> Tensor shape F64 -tan = unary Tan +tan x = sin x / cos x ||| The element-wise inverse sine. export diff --git a/test/runner/TestRunner.idr b/test/runner/TestRunner.idr index 792d2ae12..835c775c3 100644 --- a/test/runner/TestRunner.idr +++ b/test/runner/TestRunner.idr @@ -30,12 +30,12 @@ import Unit.TestUtil export run : Device -> IO () -run device = test [ +run device = test [{- Utils.TestComparison.group , TestUtils.group , Unit.TestUtil.group , Unit.TestLiteral.group - , Unit.TestTensor.group + ,-} Unit.TestTensor.group{- , Unit.TestDistribution.group - , Unit.Model.TestKernel.group + , Unit.Model.TestKernel.group-} ] diff --git a/test/runner/Unit/TestTensor.idr b/test/runner/Unit/TestTensor.idr index e774435ab..20c5762fb 100644 --- a/test/runner/Unit/TestTensor.idr +++ b/test/runner/Unit/TestTensor.idr @@ -15,6 +15,7 @@ limitations under the License. --} module Unit.TestTensor +import Unit.TestTensor.AD import Unit.TestTensor.Elementwise import Unit.TestTensor.HigherOrder import Unit.TestTensor.Sampling @@ -477,7 +478,7 @@ trace = fixedProperty $ export group : Device => Group -group = MkGroup "Tensor" $ [ +group = MkGroup "Tensor" $ [{- ("eval . tensor", tensorThenEval) , ("eval multiple tensors (tuple)", evalTuple) , ("eval multiple tensors (tuple) for non-trivial graph", evalTupleNonTrivial) @@ -498,11 +499,12 @@ group = MkGroup "Tensor" $ [ , ("cholesky", cholesky) , (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse) , (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems) - , ("trace", trace) + , ("trace", trace)-} ] ++ concat (the (List _) [ - Unit.TestTensor.Elementwise.all + Unit.TestTensor.AD.all + {-, Unit.TestTensor.Elementwise.all , Unit.TestTensor.HigherOrder.all , Unit.TestTensor.Sampling.all , Unit.TestTensor.Slice.all - , Unit.TestTensor.Structure.all + , Unit.TestTensor.Structure.all-} ]) diff --git a/test/runner/Unit/TestTensor/AD.idr b/test/runner/Unit/TestTensor/AD.idr new file mode 100644 index 000000000..84b0ca994 --- /dev/null +++ b/test/runner/Unit/TestTensor/AD.idr @@ -0,0 +1,36 @@ +{-- +Copyright 2023 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Unit.TestTensor.AD + +import System + +import Device +import Tensor + +import Utils +import Utils.Comparison +import Utils.Cases + +square : Device => Property +square = fixedProperty $ do + -- square (tensor 3.0) ===# tensor 9.0 + grad (pure . square) (tensor 3.0) ===# pure (tensor 6.0) + +export +all : Device => List (PropertyName, Property) +all = [ + ("grad square", square) + ] diff --git a/test/runner/runner.ipkg b/test/runner/runner.ipkg index 8563d2bf0..f837aba18 100644 --- a/test/runner/runner.ipkg +++ b/test/runner/runner.ipkg @@ -8,6 +8,7 @@ depends = modules = Unit.Model.TestKernel, + Unit.TestTensor.AD, Unit.TestTensor.Elementwise, Unit.TestTensor.HigherOrder, Unit.TestTensor.Sampling, diff --git a/test/xla-cpu/Main.idr b/test/xla-cpu/Main.idr new file mode 100644 index 000000000..854d1eae4 --- /dev/null +++ b/test/xla-cpu/Main.idr @@ -0,0 +1,25 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Main + +import System + +import TestRunner +import PjrtPluginXlaCpu + +partial +main : IO () +main = eitherT (die . show) run device diff --git a/test/xla-cpu/xla-cpu.ipkg b/test/xla-cpu/xla-cpu.ipkg index 24255b025..39fd35065 100644 --- a/test/xla-cpu/xla-cpu.ipkg +++ b/test/xla-cpu/xla-cpu.ipkg @@ -5,4 +5,4 @@ depends = runner executable = test -main = XlaCpu +main = Main diff --git a/test/xla-cuda/Main.idr b/test/xla-cuda/Main.idr new file mode 100644 index 000000000..4a727f497 --- /dev/null +++ b/test/xla-cuda/Main.idr @@ -0,0 +1,25 @@ +{-- +Copyright 2024 Joel Berkeley + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--} +module Main + +import System + +import TestRunner +import PjrtPluginXlaCuda + +partial +main : IO () +main = eitherT (die . show) run device diff --git a/test/xla-cuda/xla-cuda.ipkg b/test/xla-cuda/xla-cuda.ipkg index 66c3f269b..9d76e1994 100644 --- a/test/xla-cuda/xla-cuda.ipkg +++ b/test/xla-cuda/xla-cuda.ipkg @@ -5,4 +5,4 @@ depends = runner executable = test -main = XlaCuda +main = Main