diff --git a/setup.py b/setup.py
index 73ce1b0..6b1c0c9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'transformer-in-transformer',
   packages = find_packages(),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'Transformer in Transformer - Pytorch',
   author = 'Phil Wang',
diff --git a/transformer_in_transformer/tnt.py b/transformer_in_transformer/tnt.py
index 17041f4..845f9c5 100644
--- a/transformer_in_transformer/tnt.py
+++ b/transformer_in_transformer/tnt.py
@@ -89,6 +89,7 @@ def __init__(
         pixel_size,
         depth,
         num_classes,
+        channels = 3,
         heads = 8,
         dim_head = 64,
         ff_dropout = 0.,
@@ -116,7 +117,7 @@ def __init__(
             Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1 = patch_size, p2 = patch_size),
             nn.Unfold(kernel_size = kernel_size, stride = stride, padding = padding),
             Rearrange('... c n -> ... n c'),
-            nn.Linear(3 * kernel_size ** 2, pixel_dim)
+            nn.Linear(channels * kernel_size ** 2, pixel_dim)
         )
 
         self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
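
The tnt.py change makes the input channel count configurable rather than hard-coded to RGB: the new `channels` keyword (default 3, so existing code is unaffected) is threaded into the pixel-token projection, whose input size becomes `channels * kernel_size ** 2`. A minimal usage sketch follows; only the arguments visible in the diff hunks (pixel_size, depth, num_classes, channels, heads, dim_head, ff_dropout) are confirmed by this patch, and the remaining constructor arguments and all values are assumptions for illustration:

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,    # assumed arg: input image height/width
    patch_dim = 512,     # assumed arg: outer (patch) token dimension
    pixel_dim = 24,      # assumed arg: inner (pixel) token dimension
    patch_size = 16,     # assumed arg: size of each image patch
    pixel_size = 4,      # assumed arg: size of each pixel-level sub-patch
    depth = 6,
    num_classes = 1000,
    channels = 1         # new in 0.1.2: e.g. grayscale input (was fixed at 3)
)

img = torch.randn(2, 1, 256, 256)   # batch of single-channel images
logits = tnt(img)                   # expected shape: (2, 1000)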