Skip to content

Commit

Permalink
Add decomposition for Grid_Sample and Floor op
Browse files Browse the repository at this point in the history
  • Loading branch information
meenakshiramanathan1 committed Dec 31, 2024
1 parent e405246 commit 863f508
Showing 1 changed file with 200 additions and 1 deletion.
201 changes: 200 additions & 1 deletion python/tvm/relay/op/contrib/forge/forge_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,10 @@ def callback(self, pre, post, node_map):
act = node_map[self.act][0]
axis = post.attrs.axis
input_shape = [int(dim) for dim in pre.args[0].checked_type.shape]
adjusted_axes = [(ax - len(input_shape)) if ax >= 0 else ax for ax in axis]
if axis is None:
adjusted_axes = [i - len(input_shape) for i, dim in enumerate(input_shape) if dim == 1]
else:
adjusted_axes = [(ax - len(input_shape)) if ax >= 0 else ax for ax in axis]
assert all(ax < 0 for ax in adjusted_axes), "Invalid squeeze dimension: all axes must be negative."
for ax in sorted(adjusted_axes):
act = tvm.relay.squeeze(act, axis=[ax])
Expand Down Expand Up @@ -3953,7 +3956,201 @@ def callback(self, pre, post, node_map):
num_newaxis-=1
return act
return post

class DecomposeGridSample(DFPatternCallback):
    """Decompose ``image.grid_sample`` (bilinear) into simpler Relay ops.

    Given an input feature map and a flow-field grid, the output is computed
    from input values at the pixel locations named by the grid. This is the
    operation used for image warping, spatial transformers and differentiable
    image resizing.

    Operands, in the convention of the framework-level op:
        im: input feature map, shape (N, C, H, W).
        grid: point coordinates, shape (N, Hg, Wg, 2); the last axis holds
            (x, y) normalised to [-1, 1].
        align_corners: if true, the extrema (-1 and 1) refer to the centre
            points of the input's corner pixels; otherwise they refer to the
            corner points of the corner pixels, which makes the sampling more
            resolution agnostic.

    The rewrite proceeds as follows:
        1. Image and grid are split along the batch axis and every sample is
           processed independently.
        2. The x and y grid components are sliced out and mapped to pixel
           coordinates according to ``align_corners``.
        3. The image is padded by one pixel on each spatial side and the
           corner indices are shifted by one and clipped into the padded
           image, so out-of-range coordinates read padding instead of
           indexing out of bounds.
        4. The four neighbouring pixels of every grid point are gathered,
           channel by channel, from the flattened padded image with ``take``.
        5. The gathered corners are blended with the bilinear weights and the
           per-sample results are concatenated along the batch axis.

    The result has shape (N, C, Hg, Wg).

    NOTE(review): only the ``method`` and ``align_corners`` attributes are
    read. The border handling relies on the default fill of ``nn.pad``
    (presumably zero); any ``padding_mode`` other than zero padding is not
    honoured - confirm against the frontends that emit this op.

    This decomposition is adapted from the implementation in
    MMCV: https://mmcv.readthedocs.io/en/latest/_modules/mmcv/ops/point_sample.html
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.data = wildcard()
        self.grid = wildcard()
        self.grid_sample = is_op("image.grid_sample")(self.data, self.grid)
        self.pattern = self.grid_sample

    @staticmethod
    def _grid_as_nhw2(grid_node):
        """Return the sampling grid laid out as (N, Hg, Wg, 2).

        ``image.grid_sample`` consumes its grid as (N, 2, Hg, Wg). When the
        matched grid is a ``transpose`` with axes [0, 3, 1, 2] (presumably
        inserted by the importing frontend), its input already has the
        wanted layout and is reused directly. Any other producer is
        transposed back explicitly instead of blindly taking ``args[0]``.
        """
        is_layout_transpose = (
            isinstance(grid_node, tvm.relay.Call)
            and isinstance(grid_node.op, tvm.ir.Op)
            and grid_node.op.name == "transpose"
            and grid_node.attrs.axes is not None
            and [int(axis) for axis in grid_node.attrs.axes] == [0, 3, 1, 2]
        )
        if is_layout_transpose:
            return grid_node.args[0]
        return tvm.relay.transpose(grid_node, axes=[0, 2, 3, 1])

    @staticmethod
    def _bilinear_sample(image, grid, image_shape, grid_shape, align_corners):
        """Build the bilinear sampling graph for one batch slice.

        Args:
            image: Relay expression of shape ``image_shape`` = (1, C, H, W).
            grid: Relay expression of shape ``grid_shape + (2,)``.
            image_shape: static (n, c, h, w) of ``image``.
            grid_shape: static (gn, gh, gw) of ``grid`` without the last axis.
            align_corners: value of the op's ``align_corners`` attribute.

        Returns:
            Relay expression of shape (1, C, gh, gw).
        """
        n, c, h, w = image_shape
        gn, gh, gw = grid_shape
        const = tvm.relay.const

        # Extract x and y components from the grid: (gn, gh, gw, 2) -> (gn, gh, gw)
        x = tvm.relay.strided_slice(grid, begin=[0, 0, 0, 0], end=[gn, gh, gw, 1], strides=[1, 1, 1, 1])
        x = tvm.relay.squeeze(x, axis=[3])
        y = tvm.relay.strided_slice(grid, begin=[0, 0, 0, 1], end=[gn, gh, gw, 2], strides=[1, 1, 1, 1])
        y = tvm.relay.squeeze(y, axis=[3])

        # Map normalised grid coordinates to image pixel coordinates
        if align_corners:
            x = ((x + const(1.0)) / const(2.0)) * const(float(w - 1))
            y = ((y + const(1.0)) / const(2.0)) * const(float(h - 1))
        else:
            x = ((x + const(1.0)) * const(float(w)) - const(1.0)) / const(2.0)
            y = ((y + const(1.0)) * const(float(h)) - const(1.0)) / const(2.0)

        # Integer coordinates of the top-left corner and of its +1 neighbour
        x0 = tvm.relay.floor(x).astype("int32")
        y0 = tvm.relay.floor(y).astype("int32")
        one = const(1, dtype="int32")
        x1 = x0 + one
        y1 = y0 + one

        # Bilinear weights. They must be derived from the unclipped corners so
        # that samples falling into the padding keep their true weight.
        to_right = tvm.relay.cast(x1, "float32") - x
        to_left = x - tvm.relay.cast(x0, "float32")
        to_bottom = tvm.relay.cast(y1, "float32") - y
        to_top = y - tvm.relay.cast(y0, "float32")
        wa = tvm.relay.expand_dims(to_right * to_bottom, axis=1)  # weight of (x0, y0)
        wb = tvm.relay.expand_dims(to_right * to_top, axis=1)  # weight of (x0, y1)
        wc = tvm.relay.expand_dims(to_left * to_bottom, axis=1)  # weight of (x1, y0)
        wd = tvm.relay.expand_dims(to_left * to_top, axis=1)  # weight of (x1, y1)

        # Add padding to the input image to handle boundary conditions
        im_padded = tvm.relay.nn.pad(image, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
        padded_h = h + 2
        padded_w = w + 2

        # Shift indices by the padding and clip them into the padded image
        pad_const = const(1, dtype="int32")
        x0 = tvm.relay.clip(x0 + pad_const, a_min=0, a_max=padded_w - 1)
        x1 = tvm.relay.clip(x1 + pad_const, a_min=0, a_max=padded_w - 1)
        y0 = tvm.relay.clip(y0 + pad_const, a_min=0, a_max=padded_h - 1)
        y1 = tvm.relay.clip(y1 + pad_const, a_min=0, a_max=padded_h - 1)

        # Flatten the spatial axes so that a pixel is addressed by col + row * width
        padded_w_const = const(padded_w, dtype="int32")
        im_padded = tvm.relay.reshape(im_padded, (n, c, -1))

        def flat_index(col, row):
            return tvm.relay.reshape(col + row * padded_w_const, newshape=[-1])

        # Same corner order as the weights: (x0, y0), (x0, y1), (x1, y0), (x1, y1)
        corner_indices = [
            flat_index(x0, y0),
            flat_index(x0, y1),
            flat_index(x1, y0),
            flat_index(x1, y1),
        ]

        # Gather the four corners channel by channel
        channel_slices = tvm.relay.split(im_padded, indices_or_sections=c, axis=1)
        corner_parts = [[] for _ in corner_indices]
        for channel in range(c):
            plane = tvm.relay.squeeze(channel_slices[channel], axis=[0])  # (1, Hp * Wp)
            plane = tvm.relay.transpose(plane, [1, 0])  # (Hp * Wp, 1)
            plane = tvm.relay.expand_dims(plane, axis=1)  # (Hp * Wp, 1, 1)
            for parts, indices in zip(corner_parts, corner_indices):
                gathered = tvm.relay.take(plane, indices, axis=0)
                gathered = tvm.relay.reshape(gathered, newshape=[gh, gw])
                parts.append(tvm.relay.expand_dims(gathered, axis=0))

        # Stack channels back: (C, gh, gw) -> (C, 1, gh, gw) to broadcast with the weights
        x0_y0, x0_y1, x1_y0, x1_y1 = [
            tvm.relay.expand_dims(tvm.relay.concatenate(parts, axis=0), axis=1)
            for parts in corner_parts
        ]

        # Blend the corners and move channels behind the batch axis
        output = (x0_y0 * wa + x0_y1 * wb + x1_y0 * wc + x1_y1 * wd)
        return tvm.relay.transpose(output, [1, 0, 2, 3])

    def callback(self, pre, post, node_map):
        """Replace a matched ``image.grid_sample`` call by its decomposition.

        Raises:
            AssertionError: if the sampling method is not ``bilinear`` or if
                data / grid are not 4-D.
        """
        from tvm.relay.frontend.common import infer_shape

        data = node_map[self.data][0]
        mode = post.attrs["method"]
        align_corners = post.attrs["align_corners"]

        # Raised explicitly (not via `assert`) so the checks survive `python -O`;
        # the exception type and messages match the previous assertions.
        if mode != 'bilinear':
            raise AssertionError(f"Unsupported mode: {mode}. Only 'bilinear' is supported.")

        data_shape = infer_shape(data)
        if len(data_shape) != 4:
            raise AssertionError("Length of data and grid shapes should be 4.")

        grid = self._grid_as_nhw2(node_map[self.grid][0])
        grid_shape = infer_shape(grid)
        if len(grid_shape) != 4:
            raise AssertionError("Length of data and grid shapes should be 4.")

        batch_size, c, h, w = data_shape
        _, gh, gw, _ = grid_shape
        # Every split slice carries batch_size-th of the leading axis (1 sample
        # for the image); shapes are derived once instead of per slice.
        sample_shape = (1, c, h, w)
        sample_grid_shape = (grid_shape[0] // batch_size, gh, gw)

        # Split the image and grid tensors across the batch size
        split_im = tvm.relay.split(data, indices_or_sections=batch_size, axis=0)
        split_grid = tvm.relay.split(grid, indices_or_sections=batch_size, axis=0)

        results = [
            self._bilinear_sample(
                split_im[batch_idx],
                split_grid[batch_idx],
                sample_shape,
                sample_grid_shape,
                align_corners,
            )
            for batch_idx in range(batch_size)
        ]

        # Concatenate results along the batch dimension
        final_output = tvm.relay.concatenate(results, axis=0)
        return final_output


class DecomposeFloor(DFPatternCallback):
    """Decompose ``floor`` into cast / compare / where / subtract ops.

    The rewrite truncates the input towards zero with a round trip through
    int32 and then subtracts one wherever truncation rounded up, i.e. for
    negative inputs that are not already integral:

        floor(x) = trunc(x) - 1   if x < 0 and x != trunc(x)
        floor(x) = trunc(x)       otherwise

    The generated graph is hard-wired to float32 values and an int32
    intermediate, so the result is float32 and is only exact for inputs whose
    truncated value is representable as int32.
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.data = wildcard()
        self.floor = is_op("floor")(self.data)

        self.pattern = self.floor

    def callback(self, pre, post, node_map):
        """Return the decomposed replacement for a matched ``floor`` call."""
        data = node_map[self.data][0]

        # Truncate towards zero: float32 -> int32 -> float32
        int_part = tvm.relay.cast(data, "int32")
        int_part = tvm.relay.cast(int_part, "float32")

        # Truncation only differs from floor for negative, non-integral inputs
        negative_mask = tvm.relay.less(data, tvm.relay.const(0.0, "float32"))
        not_equal_mask = tvm.relay.not_equal(data, int_part)

        # Compute the adjustment: 1.0 for negative non-integer values, else 0.0
        adjustment = tvm.relay.where(
            negative_mask,
            tvm.relay.cast(not_equal_mask, "float32"),
            tvm.relay.const(0.0, "float32"),
        )
        # Subtract the adjustment from the truncated value to get the floor
        floor_result = tvm.relay.subtract(int_part, adjustment)
        return floor_result

def _get_callback_name(callback):
if isinstance(callback, DFPatternCallback):
Expand Down Expand Up @@ -4002,6 +4199,8 @@ def run_forge_compile_passes(relay_module, params=None, inputs=None, target=None
return run_pattern_callbacks(
relay_module,
[
DecomposeGridSample(),
DecomposeFloor(),
ExpandMultipleDims(),
DecomposeReverse(),
ConvertLayout(),
Expand Down

0 comments on commit 863f508

Please sign in to comment.