From d50b6fc2d9ab6fc97be8548beceba2db2e6b15d9 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Thu, 1 Feb 2024 17:59:54 -0600 Subject: [PATCH] Refactor portable api to use main registration method (#1981) Less moving pieces. Will also get the QuantizationDialect registration that was added earlier. We end up registering _some_ dialects that aren't needed in serialization, but the cost is likely not worth the duplication. --- BUILD.bazel | 1 + stablehlo/api/CMakeLists.txt | 1 + stablehlo/api/PortableApi.cpp | 13 +++++++------ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 78c255d13eb..b6c51c82241 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -958,6 +958,7 @@ cc_library( ], strip_include_prefix = ".", deps = [ + ":register", ":stablehlo_ops", ":stablehlo_serialization", ":version", diff --git a/stablehlo/api/CMakeLists.txt b/stablehlo/api/CMakeLists.txt index d4978487cc6..7ff4b1b4bdf 100644 --- a/stablehlo/api/CMakeLists.txt +++ b/stablehlo/api/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_dialect_library(StablehloPortableApi LINK_LIBS PUBLIC ChloOps StablehloOps + StablehloRegister StablehloSerialization Version VhloOps diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp index 07c856db684..ed53193c07a 100644 --- a/stablehlo/api/PortableApi.cpp +++ b/stablehlo/api/PortableApi.cpp @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" +#include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/Serialization.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/Version.h" @@ -29,10 +30,10 @@ limitations under the License. namespace mlir { namespace stablehlo { namespace { -void loadSerializationDialects(MLIRContext* context) { - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); +void loadSerializationDialects(MLIRContext& context) { + mlir::DialectRegistry registry; + mlir::stablehlo::registerAllDialects(registry); + context.appendDialectRegistry(registry); } } // namespace @@ -48,7 +49,7 @@ LogicalResult serializePortableArtifact(StringRef moduleStr, StringRef targetVersion, raw_ostream& os) { MLIRContext context; - loadSerializationDialects(&context); + loadSerializationDialects(context); auto module = mlir::parseSourceString(moduleStr, &context); if (!module || failed(module->verifyInvariants())) return failure(); @@ -58,7 +59,7 @@ LogicalResult serializePortableArtifact(StringRef moduleStr, LogicalResult deserializePortableArtifact(StringRef artifactStr, raw_ostream& os) { MLIRContext context; - loadSerializationDialects(&context); + loadSerializationDialects(context); auto module = deserializePortableArtifact(artifactStr, &context); if (!module) return failure();