# NOTE: reconstructed from a unified diff of
# python/tvm/relay/op/contrib/forge/forge_passes.py. The same diff also
# (a) lets an existing squeeze-handling callback accept ``axis is None`` by
# squeezing every size-1 dimension, and (b) registers DecomposeGridSample()
# and DecomposeFloor(), in that order, ahead of ExpandMultipleDims() in
# run_forge_compile_passes. Those partial hunks are not reproduced here.


class DecomposeGridSample(DFPatternCallback):
    """Rewrite ``image.grid_sample`` (bilinear) into basic Relay ops.

    The matched call is replaced by an explicit bilinear interpolation built
    from ``strided_slice``, ``floor``, ``clip``, ``nn.pad``, ``take`` and
    elementwise arithmetic:

    1. The image ``(N, C, H, W)`` and the grid ``(N, Hg, Wg, 2)`` are split
       along the batch axis and every sample is processed independently.
    2. The x / y grid components (last axis index 0 / 1) are mapped from the
       normalised range to pixel coordinates, honouring ``align_corners``.
    3. The image is zero-padded by one pixel on each spatial border, and the
       four neighbouring pixel indices are shifted by that padding and
       clipped to the padded extent, so out-of-range samples read padding.
    4. The padded image is flattened per channel and the four corner values
       are gathered with ``take`` using flattened ``x + y * padded_width``
       indices.
    5. The corners are blended with the bilinear weights and the per-sample
       results ``(1, C, Hg, Wg)`` are concatenated back along the batch axis.

    Only ``method == "bilinear"`` is supported. No attribute other than
    ``method`` and ``align_corners`` is read, so zero padding is always
    used regardless of any padding mode set on the op.

    The decomposition is adapted from the MMCV point-sample implementation:
    https://mmcv.readthedocs.io/en/latest/_modules/mmcv/ops/point_sample.html
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.data = wildcard()
        self.grid = wildcard()
        self.grid_sample = is_op("image.grid_sample")(self.data, self.grid)
        self.pattern = self.grid_sample

    @staticmethod
    def _channels_last_grid(grid_node):
        """Return the grid in ``(N, Hg, Wg, 2)`` layout.

        The grid operand of ``image.grid_sample`` is laid out as
        ``(N, 2, Hg, Wg)``. When it was produced by ``transpose`` with axes
        ``[0, 3, 1, 2]`` the transpose input is reused directly (this is the
        only case the original code handled); otherwise an explicit inverse
        transpose is emitted instead of blindly taking ``args[0]``.
        """
        if isinstance(grid_node, tvm.relay.expr.Call):
            op_name = getattr(grid_node.op, "name", None)
            axes = getattr(grid_node.attrs, "axes", None) if op_name == "transpose" else None
            if axes is not None and [int(ax) % 4 for ax in axes] == [0, 3, 1, 2]:
                return grid_node.args[0]
        return tvm.relay.transpose(grid_node, axes=[0, 2, 3, 1])

    @staticmethod
    def _bilinear_sample_single(image, grid, image_dims, grid_dims, align_corners):
        """Build the bilinear sampling expression for one batch element.

        Args:
            image: Relay expression of shape ``(1, C, H, W)``.
            grid: Relay expression of shape ``(1, Hg, Wg, 2)``.
            image_dims: ``(C, H, W)`` as Python ints.
            grid_dims: ``(Hg, Wg)`` as Python ints.
            align_corners: Truthy to map -1 / 1 onto the corner pixel centres.

        Returns:
            Relay expression of shape ``(1, C, Hg, Wg)``.
        """
        channels, height, width = image_dims
        grid_h, grid_w = grid_dims

        def grid_component(index):
            # Slice one coordinate plane out of the last axis and drop that axis.
            plane = tvm.relay.strided_slice(
                grid,
                begin=[0, 0, 0, index],
                end=[1, grid_h, grid_w, index + 1],
                strides=[1, 1, 1, 1],
            )
            return tvm.relay.squeeze(plane, axis=[3])

        x = grid_component(0)
        y = grid_component(1)

        # Map normalised grid coordinates to (fractional) pixel coordinates.
        one = tvm.relay.const(1.0)
        two = tvm.relay.const(2.0)
        if align_corners:
            x = ((x + one) / two) * tvm.relay.const(float(width - 1))
            y = ((y + one) / two) * tvm.relay.const(float(height - 1))
        else:
            x = ((x + one) * tvm.relay.const(float(width)) - one) / two
            y = ((y + one) * tvm.relay.const(float(height)) - one) / two

        # Integer coordinates of the top-left and bottom-right neighbours.
        int_one = tvm.relay.const(1, dtype="int32")
        x0 = tvm.relay.floor(x).astype("int32")
        y0 = tvm.relay.floor(y).astype("int32")
        x1 = x0 + int_one
        y1 = y0 + int_one

        # Bilinear weights use the *unclipped* neighbours, as in the original.
        x0_f = tvm.relay.cast(x0, "float32")
        x1_f = tvm.relay.cast(x1, "float32")
        y0_f = tvm.relay.cast(y0, "float32")
        y1_f = tvm.relay.cast(y1, "float32")
        wa = tvm.relay.expand_dims((x1_f - x) * (y1_f - y), axis=1)  # (x0, y0)
        wb = tvm.relay.expand_dims((x1_f - x) * (y - y0_f), axis=1)  # (x0, y1)
        wc = tvm.relay.expand_dims((x - x0_f) * (y1_f - y), axis=1)  # (x1, y0)
        wd = tvm.relay.expand_dims((x - x0_f) * (y - y0_f), axis=1)  # (x1, y1)

        # One-pixel zero border; indices are shifted by it and clipped to the
        # padded extent so out-of-range samples land on the border.
        padded = tvm.relay.nn.pad(image, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
        padded_w = width + 2
        max_col = padded_w - 1
        max_row = height + 1
        x0 = tvm.relay.clip(x0 + int_one, a_min=0, a_max=max_col)
        x1 = tvm.relay.clip(x1 + int_one, a_min=0, a_max=max_col)
        y0 = tvm.relay.clip(y0 + int_one, a_min=0, a_max=max_row)
        y1 = tvm.relay.clip(y1 + int_one, a_min=0, a_max=max_row)

        # Flatten the spatial dims so a pixel is addressed by x + y * padded_w.
        row_stride = tvm.relay.const(padded_w, dtype="int32")
        padded = tvm.relay.reshape(padded, (1, channels, -1))

        def flat_index(col, row):
            return tvm.relay.reshape(col + row * row_stride, newshape=[-1])

        # Order matches the weights: (x0, y0), (x0, y1), (x1, y0), (x1, y1).
        corner_indices = [
            flat_index(x0, y0),
            flat_index(x0, y1),
            flat_index(x1, y0),
            flat_index(x1, y1),
        ]
        corner_parts = [[] for _ in corner_indices]

        # Gather channel by channel, as the original decomposition does.
        # NOTE(review): presumably per-channel to keep each ``take`` a plain
        # row gather for the backend -- confirm before fusing into one take.
        channel_slices = tvm.relay.split(padded, indices_or_sections=channels, axis=1)
        for ch in range(channels):
            table = tvm.relay.squeeze(channel_slices[ch], axis=[0])  # (1, H'*W')
            table = tvm.relay.transpose(table, [1, 0])  # (H'*W', 1)
            table = tvm.relay.expand_dims(table, axis=1)  # (H'*W', 1, 1)
            for parts, index in zip(corner_parts, corner_indices):
                picked = tvm.relay.take(table, index, axis=0)
                picked = tvm.relay.reshape(picked, newshape=[grid_h, grid_w])
                parts.append(tvm.relay.expand_dims(picked, axis=0))

        # (C, Hg, Wg) per corner -> (C, 1, Hg, Wg) to broadcast with weights.
        v00, v01, v10, v11 = [
            tvm.relay.expand_dims(tvm.relay.concatenate(parts, axis=0), axis=1)
            for parts in corner_parts
        ]

        blended = v00 * wa + v01 * wb + v10 * wc + v11 * wd
        return tvm.relay.transpose(blended, [1, 0, 2, 3])

    def callback(self, pre, post, node_map):
        """Return the decomposed expression replacing the matched call.

        Raises:
            AssertionError: if the sampling method is not ``"bilinear"``, if
                data or grid is not 4-D, or if their batch sizes differ.
        """
        from tvm.relay.frontend.common import infer_shape

        # Validate once, up front, with explicit raises: ``assert False`` is
        # stripped under ``python -O`` and would otherwise be re-evaluated
        # for every batch element.
        mode = post.attrs["method"]
        if mode != "bilinear":
            raise AssertionError(f"Unsupported mode: {mode}. Only 'bilinear' is supported.")
        align_corners = post.attrs["align_corners"]

        data = node_map[self.data][0]
        grid = self._channels_last_grid(node_map[self.grid][0])

        # Check the ranks *before* unpacking so the intended message is
        # actually reachable.
        data_shape = infer_shape(data)
        grid_shape = infer_shape(grid)
        if len(data_shape) != 4 or len(grid_shape) != 4:
            raise AssertionError("Length of data and grid shapes should be 4.")
        batch_size, channels, height, width = data_shape
        grid_batch, grid_h, grid_w, _ = grid_shape
        if grid_batch != batch_size:
            raise AssertionError(
                f"Batch size of data ({batch_size}) and grid ({grid_batch}) should match."
            )

        # Every per-sample shape follows from the shapes above, so no further
        # type inference is needed inside the loop.
        image_parts = tvm.relay.split(data, indices_or_sections=batch_size, axis=0)
        grid_parts = tvm.relay.split(grid, indices_or_sections=batch_size, axis=0)
        results = [
            self._bilinear_sample_single(
                image_parts[batch_idx],
                grid_parts[batch_idx],
                (channels, height, width),
                (grid_h, grid_w),
                align_corners,
            )
            for batch_idx in range(batch_size)
        ]

        # Concatenate results along the batch dimension
        return tvm.relay.concatenate(results, axis=0)


class DecomposeFloor(DFPatternCallback):
    """Rewrite ``floor(x)`` using casts, comparisons and a subtraction.

    The replacement computes ``t = cast(cast(x, int32), dtype)`` and returns
    ``t - 1`` where ``x < 0 and x != t``, and ``t`` elsewhere.

    Limitation: the value goes through ``int32``, so the result is only
    meaningful for finite inputs whose magnitude fits in int32. The formula
    also relies on the float-to-int32 cast truncating toward zero.
    """

    def __init__(self):
        super().__init__(rewrite_once=True, require_type=True)

        self.data = wildcard()
        self.floor = is_op("floor")(self.data)

        self.pattern = self.floor

    def callback(self, pre, post, node_map):
        """Return the decomposed expression replacing the matched ``floor``."""
        data = node_map[self.data][0]

        # Use the operand's own dtype (available because require_type=True)
        # rather than assuming float32, so the comparisons type-check for
        # other float widths. The float32 graph is unchanged.
        dtype = pre.args[0].checked_type.dtype
        zero = tvm.relay.const(0.0, "float32")
        if dtype != "float32":
            zero = tvm.relay.cast(zero, dtype)

        # Truncated value of x, back in the operand dtype.
        truncated = tvm.relay.cast(tvm.relay.cast(data, "int32"), dtype)
        is_negative = tvm.relay.less(data, zero)
        has_fraction = tvm.relay.not_equal(data, truncated)

        # Subtract 1 only for negative non-integer values.
        adjustment = tvm.relay.where(is_negative, tvm.relay.cast(has_fraction, dtype), zero)
        return tvm.relay.subtract(truncated, adjustment)