Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decomposition for Grid_Sample and Floor op #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 204 additions & 1 deletion python/tvm/relay/op/contrib/forge/forge_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,12 @@ def callback(self, pre, post, node_map):
act = node_map[self.act][0]
axis = post.attrs.axis
input_shape = [int(dim) for dim in pre.args[0].checked_type.shape]
adjusted_axes = [(ax - len(input_shape)) if ax >= 0 else ax for ax in axis]
if axis is None:
meenakshiramanathan1 marked this conversation as resolved.
Show resolved Hide resolved
# PyTorch's `squeeze_()` removes all dimensions of size 1.
# Adjust all size-1 dimensions into negative axis indices.
adjusted_axes = [i - len(input_shape) for i, dim in enumerate(input_shape) if dim == 1]
else:
adjusted_axes = [(ax - len(input_shape)) if ax >= 0 else ax for ax in axis]
assert all(ax < 0 for ax in adjusted_axes), "Invalid squeeze dimension: all axes must be negative."
for ax in sorted(adjusted_axes):
act = tvm.relay.squeeze(act, axis=[ax])
Expand Down Expand Up @@ -3953,7 +3958,203 @@ def callback(self, pre, post, node_map):
num_newaxis-=1
return act
return post

class DecomposeGridSample(DFPatternCallback):
    """Decompose bilinear ``image.grid_sample`` into primitive Relay ops.

    Given an input feature map and a flow-field grid, the output is computed
    from input values sampled at the pixel locations named by the grid. This
    kind of sampling is widely used for image warping, spatial transformers
    and differentiable image resizing.

    Operands / attributes read by the rewrite:
        data: input feature map; must be rank 4 and is unpacked as
            (N, C, H, W).
        grid: sampling coordinates; must be rank 4 and is unpacked as
            (N, Hg, Wg, 2), with x in ``[..., 0]`` and y in ``[..., 1]``.
        method: only ``"bilinear"`` is supported.
        align_corners: if True, the extrema (-1 and 1) refer to the centers
            of the input's corner pixels; if False, they refer to the corner
            points of the corner pixels, which makes the sampling more
            resolution agnostic.

    Steps performed per batch element:
        1. Split data and grid along the batch axis so every sample is
           processed independently.
        2. Slice the x / y coordinates out of the grid and map them from the
           normalized range to pixel coordinates (honouring
           ``align_corners``).
        3. Take ``floor`` / ``floor + 1`` as the four neighbouring pixels and
           derive the bilinear weights from the fractional parts.
        4. Pad the image by one pixel on each spatial side and clip the
           shifted indices into the padded image, so coordinates outside the
           image read from the padding.
        5. Flatten the padded image and gather the four neighbours channel
           by channel with ``take``.
        6. Blend the neighbours with the weights and concatenate the
           per-sample results along the batch axis.

    Result: a tensor of shape (N, C, Hg, Wg).

    Raises:
        AssertionError: if data or grid is not rank 4, or if ``method`` is
            not ``"bilinear"``.

    NOTE(review): the ``padding_mode`` attribute is never consulted; only the
    pad-then-clip behaviour described above is implemented -- confirm that
    callers never rely on another padding mode.

    This decomposition is adapted from the implementation in
    MMCV: https://mmcv.readthedocs.io/en/latest/_modules/mmcv/ops/point_sample.html
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.data = wildcard()
        self.grid = wildcard()
        self.grid_sample = is_op("image.grid_sample")(self.data, self.grid)
        self.pattern = self.grid_sample

    @staticmethod
    def _to_pixel_coords(coord, size, align_corners):
        """Map normalized coordinates in [-1, 1] to pixel coordinates along an axis of length ``size``."""
        one = tvm.relay.const(1.0)
        two = tvm.relay.const(2.0)
        if align_corners:
            return ((coord + one) / two) * tvm.relay.const(float(size - 1))
        return ((coord + one) * tvm.relay.const(float(size)) - one) / two

    @staticmethod
    def _grid_component(grid, shape, index):
        """Slice component ``index`` out of the last grid axis and drop that axis."""
        gn, gh, gw = shape
        component = tvm.relay.strided_slice(
            grid,
            begin=[0, 0, 0, index],
            end=[gn, gh, gw, index + 1],
            strides=[1, 1, 1, 1],
        )
        return tvm.relay.squeeze(component, axis=[3])

    @staticmethod
    def _flat_index(col, row, row_stride):
        """Return the 1-D index ``col + row * row_stride`` into a row-major flattened image."""
        return tvm.relay.reshape(col + row * row_stride, newshape=[-1])

    def callback(self, pre, post, node_map):
        from tvm.relay.frontend.common import infer_shape

        # Attributes extraction
        data = node_map[self.data][0]
        # NOTE(review): this takes the first argument of whatever call feeds
        # the grid operand, presumably to undo a layout transpose inserted by
        # the frontend -- confirm; it fails if the grid operand is not a call.
        grid = node_map[self.grid][0].args[0]
        mode = post.attrs["method"]
        align_corners = post.attrs["align_corners"]

        # Validate before any tuple unpacking so that a bad rank is reported
        # with this message instead of an unpacking error. Explicit raises are
        # used because ``assert`` statements are stripped under ``python -O``.
        data_shape = infer_shape(data)
        grid_shape = infer_shape(grid)
        if len(data_shape) != 4 or len(grid_shape) != 4:
            raise AssertionError("Length of data and grid shapes should be 4.")
        if mode != "bilinear":
            raise AssertionError(f"Unsupported mode: {mode}. Only 'bilinear' is supported.")

        # Split the image and grid tensors across the batch size
        batch_size = data_shape[0]
        split_im = tvm.relay.split(data, indices_or_sections=batch_size, axis=0)
        split_grid = tvm.relay.split(grid, indices_or_sections=batch_size, axis=0)

        one_i32 = tvm.relay.const(1, dtype="int32")
        results = []
        for batch_idx in range(batch_size):
            sample = split_im[batch_idx]
            sample_grid = split_grid[batch_idx]
            n, c, h, w = infer_shape(sample)
            gn, gh, gw, _ = infer_shape(sample_grid)

            # Extract x and y components from the grid and map them to
            # image pixel coordinates.
            x = self._to_pixel_coords(
                self._grid_component(sample_grid, (gn, gh, gw), 0), w, align_corners
            )
            y = self._to_pixel_coords(
                self._grid_component(sample_grid, (gn, gh, gw), 1), h, align_corners
            )

            # Integer pixel indices of the top-left neighbour. The int32 cast
            # is required: without it a dtype mismatch shows up later when the
            # values are combined with int32 constants and used as indices.
            x0 = tvm.relay.floor(x).astype("int32")
            y0 = tvm.relay.floor(y).astype("int32")

            # Indices of the bottom-right neighbour
            x1 = x0 + one_i32
            y1 = y0 + one_i32

            # Interpolation weights. These must be derived from the unclipped
            # neighbour indices, i.e. before the padding shift / clip below.
            x0_f = tvm.relay.cast(x0, "float32")
            x1_f = tvm.relay.cast(x1, "float32")
            y0_f = tvm.relay.cast(y0, "float32")
            y1_f = tvm.relay.cast(y1, "float32")
            wa = tvm.relay.expand_dims((x1_f - x) * (y1_f - y), axis=1)  # weight of (x0, y0)
            wb = tvm.relay.expand_dims((x1_f - x) * (y - y0_f), axis=1)  # weight of (x0, y1)
            wc = tvm.relay.expand_dims((x - x0_f) * (y1_f - y), axis=1)  # weight of (x1, y0)
            wd = tvm.relay.expand_dims((x - x0_f) * (y - y0_f), axis=1)  # weight of (x1, y1)

            # Add padding to the input image to handle boundary conditions
            im_padded = tvm.relay.nn.pad(sample, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
            padded_h = h + 2
            padded_w = w + 2

            # Shift indices by the padding and clip them into the padded image
            x0 = tvm.relay.clip(x0 + one_i32, a_min=0, a_max=padded_w - 1)
            x1 = tvm.relay.clip(x1 + one_i32, a_min=0, a_max=padded_w - 1)
            y0 = tvm.relay.clip(y0 + one_i32, a_min=0, a_max=padded_h - 1)
            y1 = tvm.relay.clip(y1 + one_i32, a_min=0, a_max=padded_h - 1)

            # Flatten the spatial axes so that each neighbour is one 1-D index
            padded_w_const = tvm.relay.const(padded_w, dtype="int32")
            im_flat = tvm.relay.reshape(im_padded, (n, c, -1))
            corner_indices = [
                self._flat_index(x0, y0, padded_w_const),
                self._flat_index(x0, y1, padded_w_const),
                self._flat_index(x1, y0, padded_w_const),
                self._flat_index(x1, y1, padded_w_const),
            ]

            # Gather the four neighbours channel by channel. The channel-wise
            # (H*W, 1, 1) lookup table layout is kept exactly as before.
            channel_slices = tvm.relay.split(im_flat, indices_or_sections=c, axis=1)
            gathered = [[] for _ in corner_indices]
            for channel in range(c):
                table = tvm.relay.squeeze(channel_slices[channel], axis=[0])
                table = tvm.relay.transpose(table, [1, 0])
                table = tvm.relay.expand_dims(table, axis=1)
                for parts, flat_idx in zip(gathered, corner_indices):
                    picked = tvm.relay.take(table, flat_idx, axis=0)
                    picked = tvm.relay.reshape(picked, newshape=[gh, gw])
                    parts.append(tvm.relay.expand_dims(picked, axis=0))

            # Concatenate the channels back and add an axis so the values
            # broadcast against the (1, 1, Hg, Wg) weights.
            v00, v01, v10, v11 = [
                tvm.relay.expand_dims(tvm.relay.concatenate(parts, axis=0), axis=1)
                for parts in gathered
            ]

            # Compute the final output using bilinear weights, then move the
            # channel axis back behind the batch axis.
            output = v00 * wa + v01 * wb + v10 * wc + v11 * wd
            output = tvm.relay.transpose(output, [1, 0, 2, 3])
            results.append(output)

        # Concatenate results along the batch dimension
        return tvm.relay.concatenate(results, axis=0)


class DecomposeFloor(DFPatternCallback):
    """Decompose ``floor`` into cast / compare / select / subtract ops.

    The rewrite uses the identity::

        floor(x) = trunc(x) - 1   if x < 0 and x != trunc(x)
        floor(x) = trunc(x)       otherwise

    where ``trunc(x)`` is obtained by casting ``x`` to int32 and back to the
    input dtype. Because of the int32 round-trip, inputs whose magnitude does
    not fit in int32 are not handled by this decomposition.
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.data = wildcard()
        self.floor = is_op("floor")(self.data)

        self.pattern = self.floor

    def callback(self, pre, post, node_map):
        """Return the decomposed replacement for a matched ``floor`` call."""
        data = node_map[self.data][0]
        # Work in the operand's own dtype (available because the pattern is
        # matched with require_type=True) rather than assuming float32, so the
        # comparison and subtraction below never mix dtypes.
        dtype = pre.args[0].checked_type.dtype
        zero = tvm.relay.const(0.0, dtype)

        # Integer part of the input via an int32 round-trip.
        int_part = tvm.relay.cast(tvm.relay.cast(data, "int32"), dtype)
        negative_mask = tvm.relay.less(data, zero)
        not_equal_mask = tvm.relay.not_equal(data, int_part)

        # Compute the adjustment: subtract 1 for negative non-integer values
        adjustment = tvm.relay.where(negative_mask, tvm.relay.cast(not_equal_mask, dtype), zero)
        return tvm.relay.subtract(int_part, adjustment)

def _get_callback_name(callback):
if isinstance(callback, DFPatternCallback):
Expand Down Expand Up @@ -4002,6 +4203,8 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None
return run_pattern_callbacks(
relay_module,
[
DecomposeGridSample(),
DecomposeFloor(),
ExpandMultipleDims(),
DecomposeReverse(),
ConvertLayout(),
Expand Down