Wrong dimension order in unpatchify? #6
You are 100% correct that `unpatchify(patchify(image))` does not give back `image`. This is a bug on my side. However, it's actually harmless: all the information in a patch still ends up in that same patch after unpatchifying. The order only gets mixed within the patch, so the output is equivalent up to a fixed permutation, which the preceding `nn.Linear` will learn to recover. You can check this by running the following code, which always prints `True`:

```python
import torch


class PatchProcessor:
    def __init__(self, patch_size, out_channels):
        self.patch_size = patch_size
        self.out_channels = out_channels

    def unpatchify(self, x):
        # (N, L, p * p * c) -> (N, c, H, W); reads each token as (p, q, c)
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        # (B, C, H, W) -> (B, L, C * p * p); lays each token out as (c, p, q)
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x


patch_size = 4
out_channels = 3  # Assuming an RGB image
processor = PatchProcessor(patch_size, out_channels)

SIZE = 32
image = torch.arange(out_channels * SIZE * SIZE).reshape(1, out_channels, SIZE, SIZE).float()
patched_image = processor.patchify(image)
reconstructed_image = processor.unpatchify(patched_image)

# Compare the set of values inside every patch before and after the round trip.
for idx in range(0, SIZE, patch_size):
    for jdx in range(0, SIZE, patch_size):
        print(f"Patch ({idx}, {jdx}):")
        sets_bef = set(image[:, :, idx : idx + patch_size, jdx : jdx + patch_size].flatten().tolist())
        sets_aft = set(reconstructed_image[:, :, idx : idx + patch_size, jdx : jdx + patch_size].flatten().tolist())
        print(sets_bef == sets_aft)
```

However, this was not intended, and what you pointed out is correct. This is an unnecessary channel-wise shuffle that doesn't need to be there, so I'll remove it in the future.
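A minimal sketch of why the shuffle is harmless, assuming the same 32 x 32 RGB image and patch size 4 as the demo above. `unpatchify_as_written` mirrors the current `unpatchify`; `unpatchify_matching`, `perm`, `final` and `final_permuted` are hypothetical names introduced only for illustration. Reading a token as (p, q, c) instead of (c, p, q) is a fixed reordering of that token's features, and a linear layer whose output rows are reordered by the same permutation yields an identical image.

```python
import torch
import torch.nn as nn

P, C = 4, 3    # patch size and output channels, as in the demo above
H = W = 8      # patch grid size for a 32 x 32 image
D = P * P * C  # features per token


def unpatchify_as_written(x):
    # Reads each token as (p, q, c), like the current unpatchify.
    n = x.shape[0]
    x = x.reshape(n, H, W, P, P, C)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(n, C, H * P, W * P)


def unpatchify_matching(x):
    # Reads each token as (c, p, q), the layout patchify produces.
    n = x.shape[0]
    x = x.reshape(n, H, W, C, P, P)
    x = torch.einsum("nhwcpq->nchpwq", x)
    return x.reshape(n, C, H * P, W * P)


# perm[c*P*P + p*P + q] = p*P*C + q*C + c: maps one token layout onto the other.
perm = torch.arange(D).reshape(P, P, C).permute(2, 0, 1).flatten()

torch.manual_seed(0)
hidden = 16  # hypothetical hidden width
final = nn.Linear(hidden, D)

# Same layer with its output rows (weights and biases) reordered by perm.
final_permuted = nn.Linear(hidden, D)
with torch.no_grad():
    final_permuted.weight.copy_(final.weight[perm])
    final_permuted.bias.copy_(final.bias[perm])

tokens = torch.randn(2, H * W, hidden)
with torch.no_grad():
    a = unpatchify_as_written(final(tokens))
    b = unpatchify_matching(final_permuted(tokens))
print(torch.allclose(a, b))  # True
```

Since the permuted layer is just another `nn.Linear` with the same parameters rearranged, training can reach it as easily as the original one, which is the sense in which the layer learns to recover the order.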
In `minRF/dit.py`, lines 287 to 288 (at commit 72feb0c):
Should this not be:
I would expect `unpatchify(patchify(image)) == image`, but as written that is not the case.
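A minimal sketch of an `unpatchify` that exactly inverts the `patchify` from the demo above, assuming a square patch grid. The standalone `patchify(x, p)` and `unpatchify(x, p, c)` signatures are illustrative, and this is not necessarily the change proposed in the issue or applied in minRF. The only difference from the current version is that each token is read as (c, p, q), matching the layout `patchify` produces.

```python
import torch


def patchify(x, p):
    # (B, C, H, W) -> (B, L, C * p * p), same layout as PatchProcessor.patchify
    b, c, h, w = x.shape
    x = x.view(b, c, h // p, p, w // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)


def unpatchify(x, p, c):
    # Inverse of patchify: read each token as (c, p, q) rather than (p, q, c).
    n, l, _ = x.shape
    h = w = int(l ** 0.5)  # assumes a square patch grid
    x = x.reshape(n, h, w, c, p, p)
    x = torch.einsum("nhwcpq->nchpwq", x)
    return x.reshape(n, c, h * p, w * p)


image = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(1, 3, 32, 32)
print(torch.equal(unpatchify(patchify(image, 4), 4, 3), image))  # True
```

Equivalently, one could keep the current `unpatchify` and change the `permute` in `patchify` to `(0, 2, 4, 3, 5, 1)`, so tokens are laid out as (p, q, c).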