Integration of collective_broadcast into spec (openxla#1856)
This is the first PR for RFC openxla#1809. I did not add an interpreter
implementation as @GleasonK specifically asked me to leave that for new
staff joining his team.
chaserileyroberts authored Nov 30, 2023
1 parent df1262c commit dea89ec
Showing 16 changed files with 2,492 additions and 23 deletions.
72 changes: 67 additions & 5 deletions docs/spec.md
@@ -1647,6 +1647,68 @@ for this operation ([#560](https://github.com/openxla/stablehlo/issues/560)).

 [More Examples](../stablehlo/tests/interpret_clamp.mlir)

### collective_broadcast

#### Semantics

Within each process group in the StableHLO process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor.

The operation splits the StableHLO process grid into `process_groups`, which is
defined as follows:

* `cross_replica(replica_groups)` if `channel_id <= 0`.
* `cross_partition(replica_groups)` if `channel_id > 0`.

Afterwards, `result@process` is given by:

* `operand@process_groups[i, 0]` if there exists an `i` such that the process is
in `process_groups[i]`.
* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))`
otherwise.
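
The semantics above can be modeled as a small simulation. This is an illustrative sketch only, not part of StableHLO; the function `collective_broadcast_sim` and its argument names are hypothetical:

```python
# Illustrative model of collective_broadcast semantics (hypothetical helper,
# not part of StableHLO).

def collective_broadcast_sim(operands, process_groups):
    """operands: one tensor (nested list) per process, indexed by process id.
    process_groups: list of groups; group[0] is the source process."""
    def zeros_like(t):
        # Mirrors broadcast_in_dim(constant(0, element_type(result)), [], type(result))
        return [zeros_like(x) for x in t] if isinstance(t, list) else 0

    result = []
    for process in range(len(operands)):
        group = next((g for g in process_groups if process in g), None)
        if group is not None:
            # result@process = operand@process_groups[i, 0]
            result.append(operands[group[0]])
        else:
            # Processes outside every group receive a zero tensor.
            result.append(zeros_like(operands[process]))
    return result
```

Running this on the operands from the Examples section with `process_groups = [[2, 1]]` reproduces the documented per-process results.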

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
| (I1) | `operand` | tensor | (C3) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1), (C2) |
| (I3) | `channel_id` | constant of type `si64` | |

#### Outputs

| Name | Type | Constraints |
|----------|--------|-------------|
| `result` | tensor | (C3) |

#### Constraints

* (C1) `is_unique(replica_groups)`.
* (C2) `0 <= replica_groups < N` where `N` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_partitions` if `cross_partition` is used.
* (C3) `type(result) = type(operand)`.
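
Constraints (C1) and (C2) can be checked with a short sketch. Like the dialect verifier, it can only enforce the lower bound of (C2), since `num_replicas`/`num_partitions` are not statically known; the function name is illustrative:

```python
# Hypothetical checker for constraints (C1) and (C2); not part of StableHLO.

def check_replica_groups(replica_groups):
    """replica_groups: list of lists of replica/partition ids."""
    seen = set()
    for group in replica_groups:
        for replica_id in group:
            # (C2), lower bound only: ids must be non-negative.
            if replica_id < 0:
                raise ValueError(
                    f"replica_groups values must be non-negative, got {replica_id}")
            # (C1) is_unique: each id may appear at most once across all groups.
            if replica_id in seen:
                raise ValueError(f"replica id #{replica_id} seen more than once")
            seen.add(replica_id)
```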

#### Examples

```mlir
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
```

### collective_permute

#### Semantics
Expand Down Expand Up @@ -5949,11 +6011,11 @@ order and what kind of synchronization is introduced by it, is TBD

### Collective ops

There are five collective ops in StableHLO: `all_gather`, `all_reduce`,
`all_to_all`, `collective_permute` and `reduce_scatter`. All these ops split
the processes in the StableHLO process grid into **StableHLO process groups**
and execute a joint computation within each process group, independently from
other process groups.
There are six collective ops in StableHLO: `all_gather`, `all_reduce`,
`all_to_all`, `collective_broadcast`, `collective_permute`, and
`reduce_scatter`. All these ops split the processes in the StableHLO process
grid into **StableHLO process groups** and execute a joint computation within
each process group, independently from other process groups.

Within each process group, collective ops may introduce a synchronization
barrier. Further formalization, e.g. elaborating on when exactly this
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloAttrs.td
@@ -109,8 +109,9 @@ def StableHLO_OutputOperandAlias : AttrDef<StableHLO_Dialect, "OutputOperandAlia
}

// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
// optionally for collective instructions (AllToAll, AllReduce,
// CollectiveBroadcast, and CollectivePermute). Non-positive channel_id
// handle is equivalent to no channel id.
def StableHLO_ChannelHandle : AttrDef<StableHLO_Dialect, "ChannelHandle"> {
let cppNamespace = "::mlir::stablehlo";
let mnemonic = "channel_handle";
35 changes: 24 additions & 11 deletions stablehlo/dialect/StablehloOps.cpp
@@ -122,17 +122,6 @@ LogicalResult TypeExtensionsAttr::verifyEncoding(
getBounds(), RankedTensorType::get(shape, elementType), emitError);
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState,
Type resultType, Value operand,
DenseIntElementsAttr sourceTargetPairs) {
CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand,
sourceTargetPairs, /*channel_handle=*/nullptr);
}

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//
@@ -171,6 +160,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectiveBroadcastOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
@@ -807,10 +797,33 @@ LogicalResult AbsOp::inferReturnTypes(
return hlo::inferAbsOp(location, adaptor.getOperand(), inferredReturnTypes);
}

//===----------------------------------------------------------------------===//
// CollectiveBroadcastOp
//===----------------------------------------------------------------------===//

void CollectiveBroadcastOp::build(OpBuilder& odsBuilder,
OperationState& odsState, Type resultType,
Value operand,
DenseIntElementsAttr replica_groups) {
CollectiveBroadcastOp::build(odsBuilder, odsState, resultType, operand,
replica_groups, /*channel_handle=*/nullptr);
}

LogicalResult CollectiveBroadcastOp::verify() {
return hlo::verifyCollectiveBroadcastOp(getLoc(), getReplicaGroups());
}

//===----------------------------------------------------------------------===//
// CollectivePermuteOp
//===----------------------------------------------------------------------===//

void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState,
Type resultType, Value operand,
DenseIntElementsAttr sourceTargetPairs) {
CollectivePermuteOp::build(odsBuilder, odsState, resultType, operand,
sourceTargetPairs, /*channel_handle=*/nullptr);
}

LogicalResult CollectivePermuteOp::verify() {
return hlo::verifyCollectivePermuteOp(getLoc(), getSourceTargetPairs());
}
}
36 changes: 36 additions & 0 deletions stablehlo/dialect/StablehloOps.td
@@ -2020,6 +2020,42 @@ def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",
}];
}


def StableHLO_CollectiveBroadcastOp: StableHLO_Op<"collective_broadcast",
[HLO_CompatibleOperandsAndResultType /*collective_broadcast_c3*/]> {
let summary = "CollectiveBroadcast operation";
let description = [{
Within each process group in the process grid, send the value of the
`operand` tensor from the source process to the target processes and produce a
`result` tensor.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast

Example:
```mlir
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<1x2xi64>) -> tensor<1x2xi64>
```
}];

let arguments = (ins
HLO_Tensor:$operand, /*collective_broadcast_i1*/
I64ElementsAttr:$replica_groups, /*collective_broadcast_i2*/
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle /*collective_broadcast_i3*/
);
let results = (outs HLO_Tensor);
let hasVerifier = 1;
// channel_handle is only used for the SPMD partitioner, so we add a
// simplified builder method for convenience.
let builders = [
OpBuilder<(ins
"::mlir::Type":$result_type, "::mlir::Value":$operand,
"::mlir::DenseIntElementsAttr":$replica_groups)>];
}

def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute",
[HLO_CompatibleOperandsAndResultType /*collective_permute_c5*/]> {
let summary = "CollectivePermute operation";
Expand Down
29 changes: 29 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -3289,6 +3289,35 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
return success();
}

LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
DenseIntElementsAttr replicaGroups) {
// collective_broadcast_i2
auto replicaGroupType = replicaGroups.getType().cast<RankedTensorType>();
if (replicaGroupType.getRank() != 2)
return emitOptionalError(
location, "replica groups should be a rank 2 tensor, ",
"but instead it is of rank ", replicaGroupType.getRank());

auto replicaIds = replicaGroups.getValues<int64_t>();
llvm::SmallSet<int64_t, 8> replicaIdsSeen;
for (int64_t replicaId : replicaIds) {
// collective_broadcast_c2
// We only check that it is not negative, as it is impossible
// to statically know `num_replicas` or `num_partitions`.
if (replicaId < 0)
return emitOptionalError(
location, "replica_groups values must be non-negative, but was given ",
replicaId);

// collective_broadcast_c1
if (!replicaIdsSeen.insert(replicaId).second)
return emitOptionalError(location, "replica id #", replicaId,
" seen more than once");
}

return success();
}

LogicalResult verifyCollectivePermuteOp(
std::optional<Location> location, DenseIntElementsAttr sourceTargetPairs) {
auto type = sourceTargetPairs.getType().dyn_cast<RankedTensorType>();
3 changes: 3 additions & 0 deletions stablehlo/dialect/TypeInference.h
@@ -388,6 +388,9 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
DenseIntElementsAttr broadcastDimensions,
Value result);

LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
DenseIntElementsAttr replicaGroups);

LogicalResult verifyCollectivePermuteOp(std::optional<Location> location,
DenseIntElementsAttr sourceTargetPairs);

2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 15, 5); }
static Version getCurrentVersion() { return Version(0, 16, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
@@ -33,6 +33,7 @@ def VHLO_Dialect : Dialect {
0.12.0: MLIR bytecode version 1 => 3.
0.14.0: MLIR bytecode version 3 => 5 (revised to 4 in #1827).
0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO.
0.16.0: Introduce `collective_broadcast` operation.
}];

let useDefaultAttributePrinterParser = 0;
9 changes: 9 additions & 0 deletions stablehlo/dialect/VhloOps.td
@@ -240,6 +240,15 @@ def VHLO_ClzOpV1 : VHLO_Op<"count_leading_zeros_v1", "0.9.0", "current"> {
let results = (outs VHLO_AnyType:$result);
}

def VHLO_CollectiveBroadcastOpV1 : VHLO_Op<"collective_broadcast_v1", "0.16.0", "current"> {
let arguments = (ins
VHLO_AnyType:$operand,
VHLO_AnyAttr:$replica_groups,
VHLO_AnyAttr:$channel_id
);
let results = (outs VHLO_AnyType:$result);
}

def VHLO_CollectivePermuteOpV1 : VHLO_Op<"collective_permute_v1", "0.9.0", "current"> {
let arguments = (ins
VHLO_AnyType:$operand,
