Skip to content

Commit

Permalink
MiT PatchEmbedding input shape fix (#140)
Browse files Browse the repository at this point in the history
* fix

* MiT patch embedding fix

* adjustment
  • Loading branch information
DavidLandup0 authored Jun 4, 2023
1 parent c256451 commit f45e928
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deepvision/models/classification/mix_transformer/mit_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(

for i in range(self.num_stages):
patch_embed_layer = OverlappingPatchingAndEmbedding(
in_channels=3 if i == 0 else embed_dims[i - 1],
in_channels=input_shape[0] if i == 0 else embed_dims[i - 1],
out_channels=embed_dims[0] if i == 0 else embed_dims[i],
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
Expand Down
2 changes: 1 addition & 1 deletion deepvision/models/classification/mix_transformer/mit_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(

for i in range(num_stages):
patch_embed_layer = OverlappingPatchingAndEmbedding(
in_channels=3 if i == 0 else embed_dims[i - 1],
in_channels=input_shape[-1] if i == 0 else embed_dims[i - 1],
out_channels=embed_dims[0] if i == 0 else embed_dims[i],
patch_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
Expand Down

0 comments on commit f45e928

Please sign in to comment.