Skip to content

Commit

Permalink
Basic simplifier for indexing maps.
Browse files Browse the repository at this point in the history
Doesn't yet handle everything that can and should be handled, see TODOs in tests.

PiperOrigin-RevId: 586327437
  • Loading branch information
jreiffers authored and tensorflower-gardener committed Nov 29, 2023
1 parent 34c440c commit 7486607
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 11 deletions.
264 changes: 264 additions & 0 deletions third_party/xla/xla/service/gpu/model/tile_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ namespace gpu {
namespace {

using llvm::SmallVector;
using mlir::AffineBinaryOpExpr;
using mlir::AffineDimExpr;
using mlir::AffineExpr;
using mlir::AffineExprKind;
using mlir::AffineMap;
using mlir::AffineSymbolExpr;
using mlir::getAffineBinaryOpExpr;
using mlir::getAffineConstantExpr;
using mlir::getAffineDimExpr;
Expand Down Expand Up @@ -590,8 +593,269 @@ std::string ToStringImpl(const T& value) {
return ss.str();
}

struct IndexingMapSimplifier {
struct Bounds {
int64_t lower;
int64_t upper;
};

Bounds BoundsInclusive(AffineExpr expr) {
auto bound = bounds.find(expr);
if (bound != bounds.end()) return bound->second;

switch (expr.getKind()) {
case AffineExprKind::Constant: {
int64_t value = mlir::cast<mlir::AffineConstantExpr>(expr).getValue();
CHECK_GE(value, 0);
return bounds[expr] = {value, value};
}
case AffineExprKind::DimId: {
int64_t size =
dimension_sizes[mlir::cast<AffineDimExpr>(expr).getPosition()];
return bounds[expr] = {0, size - 1};
}
case AffineExprKind::SymbolId: {
int64_t size =
symbol_sizes[mlir::cast<AffineSymbolExpr>(expr).getPosition()];
return bounds[expr] = {0, size - 1};
}
default:
auto binary_op = mlir::dyn_cast<AffineBinaryOpExpr>(expr);
CHECK(binary_op);
auto lhs = BoundsInclusive(binary_op.getLHS());
auto rhs = BoundsInclusive(binary_op.getRHS());

auto& result = bounds[expr];
switch (expr.getKind()) {
case AffineExprKind::Add:
return result = {lhs.lower + rhs.lower, lhs.upper + rhs.upper};
case AffineExprKind::Mul:
return result = {lhs.lower * rhs.lower, lhs.upper * rhs.upper};
case AffineExprKind::Mod: {
CHECK_EQ(rhs.lower, rhs.upper) << "RHS of mod must be a constant";
int64_t m = rhs.lower;
if (lhs.upper < m) {
return result = lhs;
}
return result = {0, m - 1};
}
case AffineExprKind::FloorDiv: {
CHECK_EQ(rhs.lower, rhs.upper)
<< "RHS of floor_div must be a constant";
int64_t d = rhs.lower;
return result = {lhs.lower / d, lhs.upper / d};
}
default:
// We don't use ceildiv, so we don't support it.
LOG(FATAL) << "Unsupported expression";
}
}
}

// Simplifier for mod.
// - Rewrites (a * 100 + ...) % 100 to (...) % 100
// - Rewrites a % b to a if a is known to be less than b.
AffineExpr RewriteMod(AffineBinaryOpExpr mod) {
auto lhs_simplified = SimplifyOnce(mod.getLHS());

auto lhs = BoundsInclusive(lhs_simplified);
auto rhs = BoundsInclusive(mod.getRHS());

// a % b where b is always larger than a?
if (lhs.upper < rhs.lower) return lhs_simplified;

// The logic below assumes we have a constant RHS.
if (rhs.lower != rhs.upper) return mod;
int64_t m = rhs.lower;

auto new_lhs = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) {
if (expr.getKind() != AffineExprKind::Mul) {
return true;
}

auto mul_rhs =
BoundsInclusive(mlir::cast<AffineBinaryOpExpr>(expr).getRHS());
bool remove = mul_rhs.lower == mul_rhs.upper && (mul_rhs.lower % m) == 0;
return !remove; // We keep it if we don't remove it!
});

// If we weren't able to remove or simplify anything, return the original
// expression.
if (new_lhs == mod.getLHS()) {
return mod;
}
// If we removed everything, return 0.
if (!new_lhs) {
return getAffineConstantExpr(0, mlir_context);
}
// Otherwise, return new_sum % m.
return getAffineBinaryOpExpr(AffineExprKind::Mod, new_lhs, mod.getRHS());
}

// Simplifier for floordiv.
// - Rewrites (a * 100 + ...) / 100 to a + (...) / 100
// - Rewrites a / 100 to 0 when a is known to be less than 100.
AffineExpr RewriteFloorDiv(AffineBinaryOpExpr div) {
auto lhs_simplified = SimplifyOnce(div.getLHS());
auto lhs = BoundsInclusive(lhs_simplified);
auto rhs = BoundsInclusive(div.getRHS());

if (lhs.upper < rhs.lower) {
return getAffineConstantExpr(0, mlir_context);
}

// The logic below assumes we have a constant RHS.
if (rhs.lower != rhs.upper) return div;
int64_t d = rhs.lower;

AffineExpr extracted = getAffineConstantExpr(0, mlir_context);
auto new_dividend = RewriteSumIf(lhs_simplified, [&](AffineExpr expr) {
if (auto multiplier = GetConstantRhsMultiplier(expr)) {
// (x * 7 + ...) / 3 -> can't extract. We could extract x * 2 and keep
// one x, but we currently have no reason to do that.
if (*multiplier % d != 0) return true;
int64_t factor = *multiplier / d;
extracted = getAffineBinaryOpExpr(
AffineExprKind::Add, extracted,
getAffineBinaryOpExpr(AffineExprKind::Mul,
cast<AffineBinaryOpExpr>(expr).getLHS(),
getAffineConstantExpr(factor, mlir_context)));
// Remove from dividend.
return false;
}

// Not a constant multiplier, keep in dividend.
return true;
});

// If we removed everything, skip the div.
if (!new_dividend) return extracted;
// If we removed nothing, return the original division.
if (extracted == getAffineConstantExpr(0, mlir_context) &&
new_dividend == div.getLHS()) {
return div;
}

return getAffineBinaryOpExpr(
AffineExprKind::Add, extracted,
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, new_dividend,
div.getRHS()));
}

std::optional<int64_t> GetConstantRhsMultiplier(AffineExpr expr) {
if (expr.getKind() != AffineExprKind::Mul) return std::nullopt;
auto bound = BoundsInclusive(mlir::cast<AffineBinaryOpExpr>(expr).getRHS());
if (bound.lower != bound.upper) return std::nullopt;
return bound.lower;
}

AffineExpr RewriteSumIf(AffineExpr expr,
const std::function<bool(AffineExpr)>& pred) {
if (expr.getKind() == AffineExprKind::Add) {
auto add = mlir::dyn_cast<AffineBinaryOpExpr>(expr);
auto lhs = RewriteSumIf(add.getLHS(), pred);
auto rhs = RewriteSumIf(add.getRHS(), pred);
if (lhs == add.getLHS() && rhs == add.getRHS()) {
return add;
}
if (lhs && rhs) {
return getAffineBinaryOpExpr(AffineExprKind::Add, lhs, rhs);
}
return lhs ? lhs : (rhs ? rhs : nullptr);
}
return pred(expr) ? expr : nullptr;
}

// Attempts to simplify the expression, but doesn't attempt to simplify the
// result further.
AffineExpr SimplifyOnce(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Mul:
case AffineExprKind::Add: {
auto binop = mlir::cast<AffineBinaryOpExpr>(expr);
auto lhs = SimplifyOnce(binop.getLHS());
auto rhs = SimplifyOnce(binop.getRHS());
if (lhs == binop.getLHS() && rhs == binop.getRHS()) {
return expr;
}
return getAffineBinaryOpExpr(expr.getKind(), lhs, rhs);
}
case AffineExprKind::Mod:
return RewriteMod(cast<AffineBinaryOpExpr>(expr));
case AffineExprKind::FloorDiv:
return RewriteFloorDiv(cast<AffineBinaryOpExpr>(expr));
default:
return expr;
}
}

// Simplifies the expression as much as possible.
AffineExpr Simplify(AffineExpr expr) {
while (true) {
auto simplified = SimplifyOnce(expr);
if (simplified == expr) return expr;
expr = simplified;
}
}

MLIRContext* mlir_context;
absl::Span<const int64_t> dimension_sizes;
absl::Span<const int64_t> symbol_sizes;
llvm::DenseMap<AffineExpr, Bounds> bounds{};
};

} // namespace

bool IndexingMap::Simplify(absl::Span<const int64_t> dimension_sizes) {
IndexingMapSimplifier simplifier{affine_map.getContext(), dimension_sizes,
input_dims_sizes};
std::vector<AffineExpr> results;
bool any_changed = false;
for (auto expr : affine_map.getResults()) {
auto simplified = simplifier.Simplify(expr);
any_changed |= simplified != expr;
results.push_back(simplified);
}

if (!any_changed) {
return false;
}

affine_map =
AffineMap::get(affine_map.getNumDims(), affine_map.getNumSymbols(),
results, affine_map.getContext());
return true;
}

bool HloOperandIndexing::Simplify(absl::Span<const int64_t> dimension_sizes) {
std::vector<IndexingMap> to_remove;
std::vector<IndexingMap> to_add;
for (auto map : indexing_maps) {
to_remove.push_back(map);
if (map.Simplify(dimension_sizes)) {
to_add.push_back(map);
} else {
to_remove.pop_back();
}
}
for (auto& map : to_remove) {
indexing_maps.erase(map);
}
for (auto& map : to_add) {
indexing_maps.insert(map);
}
return !to_remove.empty();
}

bool HloInstructionIndexing::Simplify(
absl::Span<const int64_t> dimension_sizes) {
bool any_simplified = false;
for (auto& operand_indexing : operand_indexing_maps) {
any_simplified |= operand_indexing.Simplify(dimension_sizes);
}
return any_simplified;
}

std::string ToString(const AffineMap& affine_map) {
std::string s;
llvm::raw_string_ostream ss(s);
Expand Down
8 changes: 8 additions & 0 deletions third_party/xla/xla/service/gpu/model/tile_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ namespace gpu {
// could not be expressed via dimensions of the output.
struct IndexingMap {
std::string ToString() const;
// Returns true if the map was simplified.
bool Simplify(absl::Span<const int64_t> dimension_sizes);

mlir::AffineMap affine_map;
std::vector<int64_t> input_dims_sizes;
Expand All @@ -84,6 +86,9 @@ H AbslHashValue(H h, const IndexingMap& indexing_map) {
struct HloOperandIndexing {
std::string ToString() const;

// Returns true if the indexing was simplified.
bool Simplify(absl::Span<const int64_t> dimension_sizes);

absl::flat_hash_set<IndexingMap> indexing_maps;
int64_t operand_id;
};
Expand All @@ -95,6 +100,9 @@ std::ostream& operator<<(std::ostream& out,
struct HloInstructionIndexing {
std::string ToString() const;

// Returns true if the indexing was simplified.
bool Simplify(absl::Span<const int64_t> dimension_sizes);

std::vector<HloOperandIndexing> operand_indexing_maps;
};
std::ostream& operator<<(std::ostream& out,
Expand Down
Loading

0 comments on commit 7486607

Please sign in to comment.