From 6e8fc5bcbc58805e9f188d9def6189ffd674b604 Mon Sep 17 00:00:00 2001
From: joaocmd
Date: Sat, 17 Jun 2023 18:43:56 +0100
Subject: [PATCH] Change SwiftFormerEncoderBlock

---
 .../swiftformer/modeling_tf_swiftformer.py | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py
index b26813181887a4..84f76cac70d471 100644
--- a/src/transformers/models/swiftformer/modeling_tf_swiftformer.py
+++ b/src/transformers/models/swiftformer/modeling_tf_swiftformer.py
@@ -383,10 +383,11 @@ def build(self, input_shape: tf.TensorShape):
 
     def call(self, x: tf.Tensor, training: bool = False):
         x = self.local_representation(x, training=training)
-        batch_size, channels, height, width = x.shape
-        res = self.attn(tf.reshape(tf.transpose(x, perm=(0, 2, 3, 1)), (batch_size, height * width, channels)))
+        batch_size, height, width, channels = x.shape
+        # FIXME: pytorch -> b c h w -> b h w c (tensorflow, keep same order)
+        # Attention layer uses channels last why? should I go for channels first?
+        res = self.attn(tf.reshape(x, (batch_size, height * width, channels)))
         res = tf.reshape(res, (batch_size, height, width, channels))
-        res = tf.tranpose(res, perm=(0, 3, 1, 2))
         if self.use_layer_scale:
             x = x + self.drop_path(self.layer_scale_1 * res, training=training)
             x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training)
@@ -498,7 +499,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
         """
         VISION_DUMMY_INPUTS = tf.random.uniform(
             # FIXME: In the vit these values come from the config except the batch size, what should I put here?
-            shape=(5, self.config.num_channels, 256, 256),
+            shape=(5, self.config.num_channels, 224, 224),
             dtype=tf.float32,
         )
         return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@@ -738,14 +739,15 @@ def call(
                     loss = loss_fct(labels.squeeze(), logits.squeeze())
                 else:
                     loss = loss_fct(labels, logits)
+            # FIXME: multilabel or multiclass?
             elif self.config.problem_type == "single_label_classification":
-                loss_fct = tf.keras.losses.CategoricalCrossentropy
-                loss = loss_fct(labels.view(-1), logits.view(-1, self.num_labels), from_logits=False)
+                # FIXME: from_logits? Initially I had False from somewhere
+                loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+                loss = loss_fct(labels, logits)
             elif self.config.problem_type == "multi_label_classification":
-                loss_fct = tf.keras.losses.CategoricalCrossentropy
-                loss = loss_fct(
-                    labels, logits, from_logits=True
-                )  # FIXME: should we use from_logits in multi_label_classification?
+                # FIXME: from_logits? Initially I had False from somewhere
+                loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+                loss = loss_fct(labels, logits)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
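
Note on the first hunk's FIXME about channel order: TensorFlow layers conventionally operate channels-last (NHWC), while the PyTorch original is channels-first (NCHW). Once local_representation yields an NHWC tensor, the attention input only needs a single reshape, and both the inverse transpose and its "tranpose" typo disappear with the removed line. Below is a minimal shape-round-trip sketch under that assumption; DummyAttention is a hypothetical stand-in for the real attention layer, which consumes and returns (batch, tokens, channels):

import tensorflow as tf

class DummyAttention(tf.keras.layers.Layer):
    # Hypothetical stand-in: the real layer maps (batch, tokens, channels)
    # to (batch, tokens, channels); identity is enough to check shapes.
    def call(self, x):
        return x

batch_size, height, width, channels = 2, 7, 7, 48
x = tf.random.uniform((batch_size, height, width, channels))  # NHWC input

attn = DummyAttention()
# Channels-last: flatten the spatial grid directly, no transpose needed.
res = attn(tf.reshape(x, (batch_size, height * width, channels)))
res = tf.reshape(res, (batch_size, height, width, channels))
assert res.shape == x.shape  # back to NHWC, matching the residual branch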
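
On the second hunk's FIXME about the dummy input shape: in the TF ViT model the spatial size comes from the config, and 224 matches the usual vision default, so a config-driven version would follow the same pattern. A sketch under the assumption that SwiftFormerConfig exposes an image_size attribute (not verified here; the getattr fallback keeps 224 if it does not), with hypothetical stub classes standing in for the config and model:

from typing import Dict

import tensorflow as tf

class ConfigStub:
    # Hypothetical stand-in for SwiftFormerConfig; only the fields used below.
    num_channels = 3
    image_size = 224

class ModelStub:
    config = ConfigStub()

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        # ASSUMPTION: the config may not define image_size; fall back to 224.
        image_size = getattr(self.config, "image_size", 224)
        VISION_DUMMY_INPUTS = tf.random.uniform(
            shape=(5, self.config.num_channels, image_size, image_size),
            dtype=tf.float32,
        )
        return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}

print(ModelStub().dummy_inputs["pixel_values"].shape)  # (5, 3, 224, 224)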
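
On the third hunk's FIXMEs: single_label_classification is the multiclass case (one integer class id per example), so SparseCategoricalCrossentropy is the Keras counterpart of PyTorch's CrossEntropyLoss, and from_logits=True is correct whenever the classifier head returns raw logits. For multi_label_classification, the PyTorch classification heads in transformers use BCEWithLogitsLoss over multi-hot targets, whose Keras counterpart is BinaryCrossentropy(from_logits=True) rather than a sparse loss. A small sketch contrasting the two branches on made-up logits:

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.5, -3.0]])  # raw scores, no softmax/sigmoid applied

# single_label_classification: integer class ids, softmax cross-entropy.
sparse_ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
single_labels = tf.constant([0])
print(float(sparse_ce(single_labels, logits)))

# multi_label_classification: multi-hot targets, one sigmoid per label,
# mirroring PyTorch's BCEWithLogitsLoss.
binary_ce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
multi_labels = tf.constant([[1.0, 0.0, 1.0, 0.0]])
print(float(binary_ce(multi_labels, logits)))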