Change SwiftFormerEncoderBlock
joaocmd committed Jun 22, 2023
1 parent f5412dd commit 6e8fc5b
Showing 1 changed file with 12 additions and 10 deletions.
src/transformers/models/swiftformer/modeling_tf_swiftformer.py (12 additions, 10 deletions)
@@ -383,10 +383,11 @@ def build(self, input_shape: tf.TensorShape):

 def call(self, x: tf.Tensor, training: bool = False):
     x = self.local_representation(x, training=training)
-    batch_size, channels, height, width = x.shape
-    res = self.attn(tf.reshape(tf.transpose(x, perm=(0, 2, 3, 1)), (batch_size, height * width, channels)))
+    batch_size, height, width, channels = x.shape
+    # FIXME: pytorch is b c h w -> tensorflow is b h w c (keep the same order?)
+    # The attention layer uses channels last. Why? Should I go for channels first?
+    res = self.attn(tf.reshape(x, (batch_size, height * width, channels)))
     res = tf.reshape(res, (batch_size, height, width, channels))
     res = tf.transpose(res, perm=(0, 3, 1, 2))
     if self.use_layer_scale:
         x = x + self.drop_path(self.layer_scale_1 * res, training=training)
         x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training)
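A note on the layout question in the FIXME above: PyTorch convolutions emit NCHW (b, c, h, w) activations, while TensorFlow defaults to NHWC (b, h, w, c). In NHWC, flattening the spatial grid into an attention sequence needs only a reshape, with no transpose; the trailing `tf.transpose(res, perm=(0, 3, 1, 2))` back to channels first only makes sense if the layers after the attention expect NCHW, otherwise the residual add against the NHWC `x` will not line up. A minimal sketch of the round trip, independent of the SwiftFormer code:

```python
import tensorflow as tf


def flatten_for_attention(x: tf.Tensor) -> tf.Tensor:
    # NHWC feature map -> (batch, seq_len, channels) token sequence.
    batch_size, height, width, channels = x.shape
    return tf.reshape(x, (batch_size, height * width, channels))


def restore_spatial(tokens: tf.Tensor, height: int, width: int) -> tf.Tensor:
    # Token sequence -> NHWC feature map; a plain reshape undoes the flatten.
    batch_size, _, channels = tokens.shape
    return tf.reshape(tokens, (batch_size, height, width, channels))


x = tf.random.uniform((2, 7, 7, 64))  # NHWC, as TensorFlow convs emit
tokens = flatten_for_attention(x)     # shape (2, 49, 64)
back = restore_spatial(tokens, 7, 7)  # shape (2, 7, 7, 64), still NHWC
```

Staying in NHWC end to end would avoid the per-block transposes entirely.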
@@ -498,7 +499,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
# FIXME: In the vit these values come from the config except the batch size, what should I put here?
shape=(5, self.config.num_channels, 256, 256),
shape=(5, self.config.num_channels, 224, 224),
dtype=tf.float32,
)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@@ -738,14 +739,15 @@ def call(
                 loss = loss_fct(labels.squeeze(), logits.squeeze())
             else:
                 loss = loss_fct(labels, logits)
+        # FIXME: multilabel or multiclass?
         elif self.config.problem_type == "single_label_classification":
-            loss_fct = tf.keras.losses.CategoricalCrossentropy
-            loss = loss_fct(labels.view(-1), logits.view(-1, self.num_labels), from_logits=False)
+            # FIXME: from_logits? Initially I had False from somewhere
+            loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+            loss = loss_fct(labels, logits)
         elif self.config.problem_type == "multi_label_classification":
-            loss_fct = tf.keras.losses.CategoricalCrossentropy
-            loss = loss_fct(
-                labels, logits, from_logits=True
-            )  # FIXME: should we use from_logits in multi_label_classification?
+            # FIXME: from_logits? Initially I had False from somewhere
+            loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+            loss = loss_fct(labels, logits)

         if not return_dict:
             output = (logits,) + outputs[1:]
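On the two remaining loss FIXMEs: Hugging Face heads return raw, unnormalized logits, so `from_logits=True` is the right setting for every Keras cross-entropy here. `SparseCategoricalCrossentropy` fits `single_label_classification` (integer class ids, softmax across classes), but for `multi_label_classification` the PyTorch reference uses `BCEWithLogitsLoss`, whose Keras counterpart is `BinaryCrossentropy` over multi-hot targets, not `SparseCategoricalCrossentropy`. A small self-contained comparison:

```python
import tensorflow as tf

# Raw logits for a batch of 2 examples over 3 classes, as a model head emits.
logits = tf.constant([[2.0, -1.0, 0.5], [0.1, 1.5, -0.3]])

# single_label_classification: integer class ids, softmax across classes.
single_labels = tf.constant([0, 1])
single_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
    single_labels, logits
)

# multi_label_classification: multi-hot targets, an independent sigmoid per
# class, i.e. the Keras counterpart of BCEWithLogitsLoss.
multi_labels = tf.constant([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
multi_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
    multi_labels, logits
)

print(float(single_loss), float(multi_loss))
```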
