Skip to content

Commit

Permalink
Make dimensions BCHW and transpose inside embedding layer
Browse files Browse the repository at this point in the history
  • Loading branch information
joaocmd committed Jun 22, 2023
1 parent 2d49c7d commit f5412dd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(self, config: SwiftFormerConfig, **kwargs):
)

def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
x = tf.transpose(x, perm=(0, 2, 3, 1))
return self.patch_embedding(x, training=training)


Expand Down
10 changes: 7 additions & 3 deletions tests/models/swiftformer/test_modeling_tf_swiftformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def __init__(

def prepare_config_and_inputs(self):
# FIXME: should be same shape as pytorch version??
pixel_values = floats_tensor([self.batch_size, self.image_size, self.image_size, self.num_channels])
# pixel_values = floats_tensor([self.batch_size, self.image_size, self.image_size, self.num_channels])
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

labels = None
if self.use_labels:
Expand Down Expand Up @@ -104,6 +105,7 @@ def get_config(self):
def create_and_check_model(self, config, pixel_values, labels):
model = TFSwiftFormerModel(config=config)
result = model(pixel_values)
# FIXME: channels_first or last?
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, 7, 7, self.embed_dims[-1]))

def create_and_check_for_image_classification(self, config, pixel_values, labels):
Expand Down Expand Up @@ -160,7 +162,7 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

@unittest.skip(reason="SwiftFormer does not use inputs_embeds")
@unittest.skip(reason="TFSwiftFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass

Expand Down Expand Up @@ -198,7 +200,7 @@ def test_model_from_pretrained(self):
model = TFSwiftFormerModel.from_pretrained(model_name, from_pt=True)
self.assertIsNotNone(model)

@unittest.skip(reason="SwiftFormer does not output attentions")
@unittest.skip(reason="TFSwiftFormer does not output attentions")
def test_attention_outputs(self):
pass

Expand All @@ -220,7 +222,9 @@ def check_hidden_states_output(inputs_dict, config, model_class):
hidden_states[i].shape,
tf.TensorShape(
[
# FIXME: channels_last?
self.model_tester.batch_size,
# self.model_tester.embed_dims[i // 2],
(self.model_tester.image_size // 4) // 2 ** (i // 2),
(self.model_tester.image_size // 4) // 2 ** (i // 2),
self.model_tester.embed_dims[i // 2],
Expand Down

0 comments on commit f5412dd

Please sign in to comment.