Skip to content

Commit

Permalink
introduce string type + string attribute type + tests for string type
Browse files Browse the repository at this point in the history
  • Loading branch information
dshaaban01 committed Jan 6, 2025
1 parent 584a68e commit 15d2e21
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def Substrait_Dialect : Dialect {
more natural in MLIR to represent several message types as a single op and
express message sub-types with interfaces instead.
}];
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

#endif // SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITDIALECT
23 changes: 21 additions & 2 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
include "substrait-mlir/Dialect/Substrait/IR/SubstraitDialect.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

// Base class for Substrait dialect types.
class Substrait_Type<string name, string typeMnemonic, list<Trait> traits = []>
Expand All @@ -25,21 +26,39 @@ class Substrait_Attr<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}

def Substrait_String : Substrait_Type<"String", "string"> {
let summary = "Substrait string type";
let description = [{
This type represents a substrait string type.
}];
}

def Substrait_StrAttr : Substrait_Attr<"Str", "string", [DeclareAttrInterfaceMethods<TypedAttrInterface>]> {
let summary = "Substrait string attribute type";
let description = [{
This type represents a substrait string attribute type.
}];
let parameters = (ins StringRefParameter<> :$value);
let assemblyFormat = [{ `<` $value `>` }];
}

/// Currently supported atomic types. These correspond directly to the types in
/// https://github.com/substrait-io/substrait/blob/main/proto/substrait/type.proto.
// TODO(ingomueller): Add the other low-hanging fruits here.
def Substrait_AtomicTypes {
list<Type> types = [
SI1, // Boolean
SI32 // I32
SI32, // I32
Substrait_String // String
];
}

/// Attributes of currently supported atomic types.
def Substrait_AtomicAttributes {
list<Attr> attrs = [
SI1Attr, // Boolean
SI32Attr // I32
SI32Attr, // I32
Substrait_StrAttr // String
];
}

Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ void printCountAsAll(OpAsmPrinter &printer, Operation *op, IntegerAttr count) {
// Normal integer.
printer << count.getValue();
}

//===----------------------------------------------------------------------===//
// Substrait types
//===----------------------------------------------------------------------===//

/// Implement the getType method for custom type `StrAttr`.
::mlir::Type StrAttr::getType() const { return StringType::get(getContext()); }

//===----------------------------------------------------------------------===//
// Substrait operations
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 15 additions & 0 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ SubstraitExporter::exportType(Location loc, mlir::Type mlirType) {
return std::move(type);
}

// Handle String.
if (mlirType.isa<StringType>()) {
// TODO(ingomueller): support other nullability modes.
auto stringType = std::make_unique<proto::Type::String>();
stringType->set_nullability(
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);

auto type = std::make_unique<proto::Type>();
type->set_allocated_string(stringType.release());
return std::move(type);
}

if (auto tupleType = llvm::dyn_cast<TupleType>(mlirType)) {
auto structType = std::make_unique<proto::Type::Struct>();
for (mlir::Type fieldType : tupleType.getTypes()) {
Expand Down Expand Up @@ -428,6 +440,9 @@ SubstraitExporter::exportOperation(LiteralOp op) {
default:
op->emitOpError("has integer value with unsupported width");
}
} // `StringType`.
else if (literalType.isa<StringType>()) {
literal->set_string(value.cast<StrAttr>().getValue().str());
} else
op->emitOpError("has unsupported value");

Expand Down
6 changes: 6 additions & 0 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
return IntegerType::get(context, 1, IntegerType::Signed);
case proto::Type::kI32:
return IntegerType::get(context, 32, IntegerType::Signed);
case proto::Type::kString:
return StringType::get(context);
case proto::Type::kStruct: {
const proto::Type::Struct &structType = type.struct_();
llvm::SmallVector<mlir::Type> fieldTypes;
Expand Down Expand Up @@ -266,6 +268,10 @@ importLiteral(ImplicitLocOpBuilder builder,
IntegerType::get(context, 32, IntegerType::Signed), message.i32());
return builder.create<LiteralOp>(attr);
}
case Expression::Literal::LiteralTypeCase::kString: {
auto attr = StrAttr::get(context, message.string());
return builder.create<LiteralOp>(attr);
}
default: {
const pb::FieldDescriptor *desc =
Expression::Literal::GetDescriptor()->FindFieldByNumber(literalType);
Expand Down
24 changes: 24 additions & 0 deletions test/Dialect/Substrait/literal.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: substrait-opt %s \
// RUN: | FileCheck %s

// CHECK: substrait.plan version 0 : 42 : 1 {
// CHECK-NEXT: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.string> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.string<"hi"> : !substrait.string
// CHECK-NEXT: yield %[[V2]] : !substrait.string
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.string>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.string> {
^bb0(%arg : tuple<si1>):
%hi = literal #substrait.string<"hi"> : !substrait.string
yield %hi : !substrait.string
}
yield %1 : tuple<si1, !substrait.string>
}
}
14 changes: 14 additions & 0 deletions test/Dialect/Substrait/types.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: substrait-opt -split-input-file -split-input-file %s \
// RUN: | FileCheck %s

// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table @t1 as ["a"] : tuple<!substrait.string>
// CHECK-NEXT: yield %0 : tuple<!substrait.string>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<!substrait.string>
yield %0 : tuple<!substrait.string>
}
}
32 changes: 32 additions & 0 deletions test/Target/SubstraitPB/Export/literal.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: substrait-translate -substrait-to-protobuf %s \
// RUN: | FileCheck %s

// RUN: substrait-translate -substrait-to-protobuf %s \
// RUN: | substrait-translate -protobuf-to-substrait \
// RUN: | substrait-translate -substrait-to-protobuf \
// RUN: | FileCheck %s

// CHECK-LABEL: relations {
// CHECK-NEXT: rel {
// CHECK-NEXT: project {
// CHECK-NEXT: common {
// CHECK-NEXT: direct {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: input {
// CHECK-NEXT: read {
// CHECK: expressions {
// CHECK-NEXT: literal {
// CHECK-NEXT: string: "hi"

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.string> {
^bb0(%arg : tuple<si1>):
%hi = literal #substrait.string<"hi"> : !substrait.string
yield %hi : !substrait.string
}
yield %1 : tuple<si1, !substrait.string>
}
}
25 changes: 25 additions & 0 deletions test/Target/SubstraitPB/Export/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | FileCheck %s

// CHECK-LABEL: relations {
// CHECK-NEXT: rel {
// CHECK-NEXT: read {
// CHECK: base_schema {
// CHECK-NEXT: names: "a"
// CHECK-NEXT: struct {
// CHECK-NEXT: types {
// CHECK-NEXT: string {
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: named_table {

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<!substrait.string>
yield %0 : tuple<!substrait.string>
}
}

// -----

// CHECK-LABEL: relations {
// CHECK-NEXT: rel {
// CHECK-NEXT: read {
Expand Down
60 changes: 60 additions & 0 deletions test/Target/SubstraitPB/Import/literal.textpb
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# RUN: substrait-translate -protobuf-to-substrait %s \
# RUN: | FileCheck %s

# RUN: substrait-translate -protobuf-to-substrait %s \
# RUN: | substrait-translate -substrait-to-protobuf \
# RUN: | substrait-translate -protobuf-to-substrait \
# RUN: | FileCheck %s

# CHECK: substrait.plan version 0 : 42 : 1 {
# CHECK-NEXT: relation
# CHECK: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.string> {
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.string<"hi"> : !substrait.string
# CHECK-NEXT: yield %[[V2]] : !substrait.string
# CHECK-NEXT: }
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.string>


relations {
rel {
project {
common {
direct {
}
}
input {
read {
common {
direct {
}
}
base_schema {
names: "a"
struct {
types {
bool {
nullability: NULLABILITY_REQUIRED
}
}
nullability: NULLABILITY_REQUIRED
}
}
named_table {
names: "t1"
}
}
}
expressions {
literal {
string: "hi"
}
}
}
}
}
version {
minor_number: 42
patch_number: 1
}
36 changes: 36 additions & 0 deletions test/Target/SubstraitPB/Import/types.textpb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,42 @@
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
# RUN: | FileCheck %s

# CHECK: substrait.plan
# CHECK-NEXT: relation
# CHECK-NEXT: named_table
# CHECK-SAME: : tuple<!substrait.string>

relations {
rel {
read {
common {
direct {
}
}
base_schema {
names: "a"
struct {
types {
string {
nullability: NULLABILITY_REQUIRED
}
}
nullability: NULLABILITY_REQUIRED
}
}
named_table {
names: "t1"
}
}
}
}
version {
minor_number: 42
patch_number: 1
}

# -----

# CHECK: substrait.plan
# CHECK-NEXT: relation
# CHECK-NEXT: named_table
Expand Down

0 comments on commit 15d2e21

Please sign in to comment.