Skip to content

Commit

Permalink
[XeTileToXeGPU] Remove old implementation of XeTileToXeGPU. (#984)
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 authored Dec 16, 2024
1 parent 2bd7407 commit eb22052
Show file tree
Hide file tree
Showing 51 changed files with 758 additions and 4,612 deletions.
5 changes: 1 addition & 4 deletions include/imex/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,7 @@ def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::gpu::GPUModul
let options = [
Option<"device", "device", "std::string",
/*default=*/"\"pvc\"",
"gpu platform architecture where these ops are running">,
Option<"EnableTransform", "enable-2d-transform", "bool",
/*default=*/"false",
"Using 2D transform or 4D Conversion.">
"gpu platform architecture where these ops are running">
];
}

Expand Down
9 changes: 2 additions & 7 deletions include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/DialectConversion.h>

#include "XeTileToXeGPUConversion.h"

namespace mlir {
class MLIRContext;
class ModuleOp;
Expand All @@ -37,12 +35,9 @@ namespace imex {
#define GEN_PASS_DECL_CONVERTXETILETOXEGPU
#include "imex/Conversion/Passes.h.inc"

class XeOneToNTypeConverter;

/// Populate the given list with patterns rewrite XeTile Ops
void populateXeTileToXeGPUConversionPatterns(XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
imex::TileUsageAnalysis &analysis);
void populateXeTileToXeGPUConversionPatterns(mlir::TypeConverter &converter,
mlir::RewritePatternSet &patterns);

/// Create a pass to convert the XeTile dialect to the XeGPU dialect.
std::unique_ptr<mlir::OperationPass<mlir::gpu::GPUModuleOp>>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- TypeConverter.h - XeTileToXeGPU conversion -------*- C++ -*-===//
//===- XeTileOneToNConversion.h --- XeTileOneToNConversion -----*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
Expand All @@ -9,7 +9,7 @@
///
/// \file
/// This file defines the XeOneToNConversion, the base class for
/// XeTileToXeGPU conversion, XeOneToNTypeConverter, converting types used in
/// doing OneToN conversion, XeOneToNTypeConverter, converting types used in
/// XeTile dialect to types used in XeGPU dialect, XeOneToNPatternRewriter a
/// wrapper around ConversionPatterRewriter providng interface for supporting
/// OneToN replace.
Expand Down
16 changes: 13 additions & 3 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ class VectorShapeCastPattern : public OpConversionPattern<ShapeCastOp> {

if (!dstType)
return failure();

if (dstType == adaptor.getSource().getType() ||
shapeCastOp.getResultVectorType().getNumElements() == 1) {
rewriter.replaceOp(shapeCastOp, adaptor.getSource());
Expand Down Expand Up @@ -760,7 +761,10 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
target.addDynamicallyLegalDialect<scf::SCFDialect>(
[&](Operation *op) { return isLegalXeGPUSCFOp(op, typeConverter); });

target.addIllegalOp<ShapeCastOp>();
target.addDynamicallyLegalOp<ShapeCastOp>([&](ShapeCastOp op) {
return typeConverter.isLegal(op.getType()) &&
typeConverter.isLegal(op.getSource().getType());
});

// TODO: can we change it to addDynamicLegalOp?
target.addLegalOp<mlir::UnrealizedConversionCastOp>();
Expand All @@ -786,15 +790,21 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
});

typeConverter.addConversion([&](VectorType type) -> Type {
// TODO: it looks like needs some improvement for matching upstream
// passes
// TODO: I don't think we need to convert 2D VectorType to
// 1D VectorType. It needs to removed after we move vector
// linearization after this pass

unsigned rank = type.getRank();
auto elemType = type.getElementType();

if (rank < 1)
return elemType;

// TODO: a temporary fix to avoid do type conversion
// for create_mask result
if (elemType.isInteger(1))
return type;

unsigned sum = 1;
for (unsigned i = 0; i < rank; i++) {
sum *= type.getShape()[i];
Expand Down
Loading

0 comments on commit eb22052

Please sign in to comment.