Skip to content

Commit

Permalink
Add StableHLO dialect version for bytecode (openxla#2040)
Browse files Browse the repository at this point in the history
This enables (at least) invalidating previously generated StableHLO
bytecode serialization (not using VHLO) by bumping the version. The
version is optionally set and if not set, not serialized - this would
result in no changes to what is serialized by default today unless
explicitly set.

This would enable a better error message than the existing where a
loaded bytecode would fail and knowledge required as to why:

Before 

```
<unknown>:0: error: 'stablehlo.broadcast_in_dim' op attribute 'broadcast_dimensions' failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr.
<unknown>:0: note: see current operation: %4 = "stablehlo.broadcast_in_dim"(%1) {broadcast_dimensions = array<i64>} : (tensor<f32>) -> tensor<500xf32>
<unknown>:0: note: in bytecode version 6 produced by: MLIR19.0.0git
```

Approximately after:

```
reading newer dialect version than supported
```

(Note this could also be refined, but at least point to version skew vs
folks being concerned about generation error)

This does not represent a policy change nor an automatic test added.
  • Loading branch information
jpienaar authored Apr 3, 2024
1 parent 8a55376 commit 6331196
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 0 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,7 @@ cc_library(
":stablehlo_ops_inc_gen",
":stablehlo_type_inference",
":stablehlo_types_inc_gen",
":version",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ComplexDialect",
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ add_mlir_dialect_library(StablehloOps
StablehloAssemblyFormat
StablehloBase
StablehloTypeInference
Version
MLIRArithDialect
MLIRDataLayoutInterfaces
MLIRInferTypeOpInterface
Expand Down
37 changes: 37 additions & 0 deletions stablehlo/dialect/StablehloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {
// TO ADD TYPE: Include a write method for each type in StableHLO
// Ex: void write(SomeType attr, DialectBytecodeWriter &writer) const;
void write(TokenType type, DialectBytecodeWriter &writer) const;

//===--------------------------------------------------------------------===//
// Version

std::unique_ptr<DialectVersion> readVersion(
DialectBytecodeReader &reader) const override final;

void writeVersion(DialectBytecodeWriter &writer) const override final;
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -695,6 +703,35 @@ void StablehloBytecodeInterface::write(TokenType type,
writer.writeVarInt(stablehlo_encoding::kTokenType);
}

std::unique_ptr<DialectVersion> StablehloBytecodeInterface::readVersion(
DialectBytecodeReader &reader) const {
uint64_t major, minor, patch;
if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor)) ||
failed(reader.readVarInt(patch)))
return nullptr;

auto version = std::make_unique<StablehloDialectVersion>(
/*major=*/major, /*minor=*/minor, /*patch=*/patch);
if (version && StablehloDialectVersion::getCurrentVersion() < *version) {
// Note: dialect bytecode reader does not expose emitWarning.
// TODO(jpienaar): Update when it does.
mlir::emitWarning(mlir::UnknownLoc::get(getContext()))
<< "reading newer dialect than supported";
return nullptr;
}

return version;
}

void StablehloBytecodeInterface::writeVersion(
DialectBytecodeWriter &writer) const {
if (auto version = cast<StablehloDialect>(getDialect())->getVersion()) {
writer.writeVarInt(static_cast<uint64_t>(version->getMajor()));
writer.writeVarInt(static_cast<uint64_t>(version->getMinor()));
writer.writeVarInt(static_cast<uint64_t>(version->getPatch()));
}
}

} // namespace

void addBytecodeInterface(StablehloDialect *dialect) {
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3039,5 +3039,14 @@ Operation* StablehloDialect::materializeConstant(OpBuilder& builder,
return builder.create<ConstantOp>(loc, type, elementsAttr);
}

std::optional<StablehloDialectVersion> StablehloDialect::getVersion() const {
return version;
}

void StablehloDialect::setVersion(
std::optional<StablehloDialectVersion> version) {
this->version = version;
}

} // namespace stablehlo
} // namespace mlir
34 changes: 34 additions & 0 deletions stablehlo/dialect/StablehloOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/Base.h"
#include "stablehlo/dialect/Version.h"

#define GET_TYPEDEF_CLASSES
#include "stablehlo/dialect/StablehloTypeDefs.h.inc"
Expand All @@ -52,6 +53,29 @@ limitations under the License.
namespace mlir {
namespace stablehlo {

struct StablehloDialectVersion : public mlir::DialectVersion {
StablehloDialectVersion(int64_t major, int64_t minor, int64_t patch)
: dialectVersion(major, minor, patch) {}

int64_t getMajor() const { return dialectVersion.getMajor(); }
int64_t getMinor() const { return dialectVersion.getMinor(); }
int64_t getPatch() const { return dialectVersion.getPatch(); }

static StablehloDialectVersion getCurrentVersion() {
// The same version as VHLO as this is serialization related only.
auto vhloVer = vhlo::Version::getCurrentVersion();
return {vhloVer.getMajor(), vhloVer.getMinor(), vhloVer.getPatch()};
}

bool operator<(const StablehloDialectVersion &other) const {
return this->dialectVersion < other.dialectVersion;
}

private:
// The dialect version read from bytecode.
vhlo::Version dialectVersion;
};

class StablehloDialect : public Dialect {
public:
explicit StablehloDialect(MLIRContext *context);
Expand All @@ -73,6 +97,16 @@ class StablehloDialect : public Dialect {

// Prints an attribute registered to this dialect.
void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;

// Get the set dialect version.
std::optional<StablehloDialectVersion> getVersion() const;

// Set dialect version.
// Note: there is currently no validation.
void setVersion(std::optional<StablehloDialectVersion> version);

private:
std::optional<StablehloDialectVersion> version;
};

// Verifies the source target pairs attached to collective permute.
Expand Down

0 comments on commit 6331196

Please sign in to comment.