
Add TF swiftformer #23342

Merged
merged 111 commits into from
Apr 19, 2024

Conversation

joaocmd
Contributor

@joaocmd joaocmd commented May 12, 2023

What does this PR do?

Adds the TensorFlow version of the "SwiftFormer".

Fixes #22771

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts @D-Roberts

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@joaocmd joaocmd changed the title Add tf swiftformer [WIP] Add tf swiftformer May 13, 2023
@amyeroberts
Collaborator

Hi @joaocmd,
Rapid work on opening the TF port! Let me or @Rocketknight1 know when the PR is ready for review, or if you run into any issues while porting.

@joaocmd
Contributor Author

joaocmd commented May 18, 2023

Hi @Rocketknight1, could I get some pointers as to why I'm getting errors like the following in most of the tests:

E               ValueError: Exception encountered when calling layer 'tf_swift_former_model_18' (type TFSwiftFormerModel).
E               
E               The following keyword arguments are not supported by this model: ['input_ids'].
E               
E               Call arguments received by layer 'tf_swift_former_model_18' (type TFSwiftFormerModel):
E                 • pixel_values={'pixel_values': 'tf.Tensor(shape=(13, 224, 224, 3), dtype=float32)'}
E                 • output_hidden_states=None
E                 • return_dict=None
E                 • training=False

src/transformers/modeling_tf_utils.py:500: ValueError

The PyTorch version has the following docstring, but I don't see where the input_ids part is being taken care of.

"""
    Here we also overwrite some of the tests of test_modeling_common.py, as SwiftFormer does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
"""

Thanks!

@joaocmd
Contributor Author

joaocmd commented May 28, 2023

It seems like it is entering the else branch at line 581 of src/transformers/modeling_tf_utils.py:

if "args" in output:
    if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
        tensor_name = output["args"].name.split(":")[0]
        output[tensor_name] = output["args"]
    else:
        # `args` in this case is always the first parameter, then `input_ids`
        output["input_ids"] = output["args"]

    del output["args"]

Thus it is injecting the input_ids argument into the dictionary.

@amyeroberts @Rocketknight1 How should I get around this? It must be some misconfiguration in my tests or models.

@amyeroberts
Collaborator

@joaocmd Just looking at the error and the CI runs, I think the issue might be a missing @unpack_inputs decorator on the call method for the MainLayer class

@joaocmd
Contributor Author

joaocmd commented May 30, 2023

@joaocmd Just looking at the error and the CI runs, I think the issue might be a missing @unpack_inputs decorator on the call method for the MainLayer class

Thank you @amyeroberts! It seems like that wasn't causing any issue (yet), but thanks to your comment I found out that I had a duplicate @unpack_inputs in one of the models.
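For readers following along, here is a minimal sketch of what a decorator in the spirit of @unpack_inputs does. This is an illustrative stand-in, not the actual transformers implementation: it routes a dict passed as the first positional argument into the named parameters of call, which is why a missing (or duplicated) application of the real decorator scrambles argument handling, as in the pixel_values={'pixel_values': ...} error above.

```python
import functools

def unpack_inputs(func):
    """Simplified sketch: route a dict passed as the first positional
    argument into individual keyword arguments of `call`."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # If the first positional argument is a dict of named inputs,
        # unpack it into keyword arguments instead of passing it through.
        if args and isinstance(args[0], dict):
            kwargs.update(args[0])
            args = args[1:]
        return func(self, *args, **kwargs)
    return wrapper

class TinyModel:  # hypothetical model class for illustration
    @unpack_inputs
    def call(self, pixel_values=None, training=False):
        return ("called", pixel_values, training)

m = TinyModel()
# Without unpacking, the whole dict would land in `pixel_values`;
# with it, each key is routed to the matching parameter.
print(m.call({"pixel_values": [1, 2, 3], "training": True}))  # ('called', [1, 2, 3], True)
```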

@joaocmd
Contributor Author

joaocmd commented Jun 17, 2023

Hi @amyeroberts and @Rocketknight1, can I get some help with the tests that are still failing? I'm getting ValueError: cannot reshape array of size 10368 into shape (3,3,3,24) for these two tests:

  • tests/models/swiftformer/test_modeling_tf_swiftformer.py::TFSwiftFormerModelTest::test_compile_tf_model
  • tests/models/swiftformer/test_modeling_tf_swiftformer.py::TFSwiftFormerModelTest::test_save_load

But I don't understand exactly what is being reshaped into the wrong shape. Could I get some insight as to what these tests are doing and why it might be failing? Thanks!

@amyeroberts
Collaborator

Hi @joaocmd, there have been some large updates to our TF models regarding how they're built - @Rocketknight1 can give you more details :)

Are these errors happening if you rebase on main?

@joaocmd
Contributor Author

joaocmd commented Jun 22, 2023

Hi @joaocmd, there have been some large updates to our TF models regarding how they're built - @Rocketknight1 can give you more details :)

Are these errors happening if you rebase on main?

Hi @amyeroberts, just rebased the branch. I think it's failing on the same tests but the error on these two tests changed to:

NotImplementedError: Could not infer input image shape from config, please override input_signature to specify input shapes.

Looking at the stack trace it seems like the image size should have been specified:

if hasattr(vision_config, "image_size"):
    pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
elif hasattr(vision_config, "input_size"):
    pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
else:
    raise NotImplementedError(  # <------ this error here
        "Could not infer input image shape from config, please override input_signature to specify input shapes."
    )

Shouldn't this also affect the original model?

@amyeroberts
Collaborator

@joaocmd Regarding the error, no, it shouldn't affect the original model. image_size is a parameter we add to the configs even when it's not always used by the model, as it's often important for parameterizing other things or for understanding the model, and we explicitly allow this here. It should have been added, and we can add it in this PR, but the PT model can do without it.

You'll notice that the error is being raised in modeling_tf_utils.py. This is because when constructing a TF model, we have to pass dummy inputs through it to build it. In PyTorch this isn't necessary, because we explicitly set the input and output dimensions when creating each layer, so the weight matrices can be created immediately. image_size is needed to know the shape of the dummy inputs to pass in.
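As an illustrative sketch of that build-time requirement (the helper name below is hypothetical, not the library's API), the dummy-input shape can only be derived when the config exposes an image size, following the same fallback chain as the snippet quoted earlier:

```python
from types import SimpleNamespace

def dummy_pixel_values_shape(config, batch_size=2, num_channels=3):
    """Sketch of how a TF model derives the dummy-input shape it needs to
    build its weights (PyTorch sets layer dims explicitly instead)."""
    # Prefer `image_size`, fall back to `input_size`, otherwise give up.
    if hasattr(config, "image_size"):
        size = config.image_size
    elif hasattr(config, "input_size"):
        size = config.input_size
    else:
        raise NotImplementedError(
            "Could not infer input image shape from config, "
            "please override input_signature to specify input shapes."
        )
    # NCHW dummy-input shape used to build the weights symbolically.
    return (batch_size, num_channels, size, size)

print(dummy_pixel_values_shape(SimpleNamespace(image_size=224)))  # (2, 3, 224, 224)
```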

As a side note, did you force push after rebasing? From the PR history, it looks like you might not have. As rebasing is a form of "rewriting history" it's necessary to force push.

@joaocmd joaocmd force-pushed the add_tf_swiftformer branch from 888a7d8 to 2586f4f Compare June 23, 2023 20:45
@joaocmd
Contributor Author

joaocmd commented Jun 23, 2023

Thanks @amyeroberts, understood. As for the rebase, I had not done one in quite some time and it seems like I did mess it up. I think that is now fixed.

Since I started this PR I have had a fundamental question about Hugging Face's approach to TensorFlow models. The default layout in TensorFlow is NHWC, while in PyTorch it is NCHW. How should I approach this difference in my PR? Based on modeling_tf_vit.py, I suppose the correct approach is to assume that images are given in the PyTorch format and transpose them in the first block. Is that so? And how does that affect the following blocks?
Also, if we were implementing a model for semantic segmentation, which would return an image with the same size as the original one, would that be returned in the PyTorch format or the default TensorFlow format?

Thank you!

@amyeroberts
Collaborator

@joaocmd The pattern we use for the TF vision models is to transpose from the NCHW format in the first MainLayer class (e.g. here) and then transpose back if pixel values are returned (e.g. here). For some of the older models, e.g. ViT, this pattern may not have been applied, as these were the first models to be added.

This pattern means the model is written in the TF compatible NHWC format throughout, but all of our vision models accept and return images in NCHW.
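A minimal sketch of that transpose-in/transpose-out pattern, using NumPy as a stand-in for tf.transpose (the function names are illustrative, not the library's):

```python
import numpy as np

def main_layer(pixel_values_nchw):
    # Transpose at the entrance: NCHW (PyTorch convention) -> NHWC
    # (TF convention), so all internal layers compute in NHWC.
    x = np.transpose(pixel_values_nchw, (0, 2, 3, 1))
    # ... conv / batch-norm layers would operate on NHWC here ...
    return x

def model(pixel_values_nchw):
    hidden_nhwc = main_layer(pixel_values_nchw)
    # Transpose back on the way out so callers still see NCHW.
    return np.transpose(hidden_nhwc, (0, 3, 1, 2))

x = np.zeros((1, 3, 224, 224))
print(main_layer(x).shape)  # (1, 224, 224, 3)
print(model(x).shape)       # (1, 3, 224, 224)
```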

@joaocmd
Contributor Author

joaocmd commented Jun 27, 2023

Thank you @amyeroberts, that makes sense. I've already updated it to match the pattern.

I'm still having some trouble with test_compile_tf_model. Initially it was failing because a shape of (None, 56, 56, 48) was being passed to a reshape (204e216#diff-7f093399e807b53ca4b63460f610dcc550c2937cb18cd513d71dc49ce6e1b699R385).
I changed that line to use [-1, width * height, channels] as the shape, which seems to have fixed that case. However, it is now failing because a shape of (None, None, None, 48) is being passed to that reshape call. Is this expected for this test? According to the stack trace, it seems to be triggered by a tf.keras.Model.save() call (https://github.com/joaocmd/transformers/blob/add_tf_swiftformer/tests/test_modeling_tf_common.py#L711).
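As a small illustration of why the [-1, width * height, channels] form helps (NumPy stand-in; reshape with -1 behaves the same way in TF):

```python
import numpy as np

# -1 lets the reshape infer one dimension from the data, so a
# statically unknown (None) batch size is no obstacle.
x = np.zeros((5, 56, 56, 48))           # (batch, height, width, channels)
_, h, w, c = x.shape
flat = x.reshape(-1, h * w, c)          # batch inferred at runtime
print(flat.shape)  # (5, 3136, 48)
```

Note that -1 can only infer a single dimension; when height and width are themselves None at trace time, they cannot come from the static shape and would have to be read with a runtime shape op such as tf.shape.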

I've also noticed that there was an overhaul to the serving and dummy_inputs interface (814de8f). But maybe @Rocketknight1 can better explain the consequences of this change to mine (and other) PRs.

@amyeroberts amyeroberts mentioned this pull request Jul 5, 2023
2 tasks
@amyeroberts
Collaborator

@joaocmd Yes, there was a big refactor of the serving_output logic. For most models, there's no need to have serving_output, dummy_inputs or serving implemented. You should be able to remove these and have the test_prepare_serving_output test pass.

Looking at the CI run, I don't see test_compile_tf_model failing. Were you able to resolve it? Or perhaps are you referring to test_save_load?

@github-actions

github-actions bot commented Aug 8, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@joaocmd
Contributor Author

joaocmd commented Aug 13, 2023

Hi @amyeroberts! Sorry for the late response as I've been quite busy... It was failing more tests on my local machine than on the CI run, but after merging the main branch locally they now seem to match.
I am currently struggling with test_save_load:

ValueError: cannot reshape array of size 10368 into shape (3,3,3,24)

I can't find the reason for this error. I set up a breakpoint and found that the symbolic_weight_name at that point is kernel:0, so I assume it belongs to some convolutional layer, but I didn't get any further than that. Do you have any suggestions? Thank you!

Edit:

I believe the weight belongs to the patch embeddings layer, which I initialized with a tf.keras.Sequential call:

self.patch_embedding = tf.keras.Sequential(
    [
        tf.keras.layers.ZeroPadding2D(padding=(1, 1)),
        tf.keras.layers.Conv2D(out_chs // 2, kernel_size=3, strides=2),
        tf.keras.layers.BatchNormalization(
            epsilon=config.batch_norm_eps, momentum=0.9
        ),  # FIXME: is this the equivalent momentum?
        tf.keras.layers.Activation("relu"),
        tf.keras.layers.ZeroPadding2D(padding=(1, 1)),
        tf.keras.layers.Conv2D(out_chs, kernel_size=3, strides=2),
        tf.keras.layers.BatchNormalization(
            epsilon=config.batch_norm_eps, momentum=0.9
        ),  # FIXME: is this the equivalent momentum?
        tf.keras.layers.Activation("relu"),
    ],
    name="patch_embeddings",
)

I think the problem is that both Conv2D layers are being given the same name. What is the correct approach here? Should I rewrite the PyTorch version to not use nn.Sequential?

@amyeroberts
Collaborator

@joaocmd I would suggest rewriting the torch version to not use sequential, but only for the purposes of debugging, i.e. we wouldn't commit these changes to main. This way you'll be able to go line by line, comparing the TF and PT outputs and seeing where any shape differences are coming from.

@joaocmd
Contributor Author

joaocmd commented Aug 15, 2023

Hi @amyeroberts I might be misunderstanding something but I think test_save_load does not test any PyTorch to TensorFlow equivalence. I think the problem is that when the two convolutional layers inside the Sequential module are saved they are stored under the same name, so a shape mismatch happens. Do I understand this correctly?

@amyeroberts
Collaborator

@joaocmd Ah, apologies, I misread your comment. Yes, I believe you're right about the naming issue. What I suggest is following the pattern in other ports where nn.Sequential has been used. For example in deit, for the sequential block in PT, a new layer is implemented for TF.
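A toy sketch of why duplicate names break save/load (illustrative only, not Keras internals): checkpoints key weights by layer name, so two sub-layers that resolve to the same name collide.

```python
def save_weights(layers):
    """Toy checkpoint: map weight names to kernel shapes."""
    ckpt = {}
    for name, kernel_shape in layers:
        ckpt[f"patch_embeddings/{name}/kernel:0"] = kernel_shape
    return ckpt

# Both convs resolve to the same default name: only one entry survives,
# and loading later tries to fit the (3, 3, 24, 48) kernel (size 10368)
# into the (3, 3, 3, 24) slot, i.e. the "cannot reshape" error above.
bad = save_weights([("conv2d", (3, 3, 3, 24)), ("conv2d", (3, 3, 24, 48))])
good = save_weights([("conv2d_1", (3, 3, 3, 24)), ("conv2d_2", (3, 3, 24, 48))])
print(len(bad), len(good))  # 1 2
```

Giving each sub-layer a unique, explicit name (as a dedicated TF layer class does) keeps the checkpoint entries distinct.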

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@amyeroberts
Collaborator

@joaocmd From next week I'll be off for a few weeks. If you have any implementation questions, please ask @Rocketknight1 :)

@joaocmd
Contributor Author

joaocmd commented Sep 24, 2023

Hi @Rocketknight1, could you give me some pointers on the three remaining tests? I haven't looked closely into test_modeling_tf_swiftformer.py::TFSwiftFormerModelIntegrationTest::test_inference_image_classification_head yet because I think it makes sense to leave that one for last, but correct me if I'm wrong.

However, I am trying to understand what is wrong with TFSwiftFormerModelTest::test_save_load (AssertionError: 0.42552373 not less than or equal to 1e-05), but I have come to no conclusion yet.

There is also this current error that might be due to some misnamed layers, but I am not sure: tests/models/swiftformer/test_modeling_tf_swiftformer.py::TFSwiftFormerModelTest::test_pt_tf_model_equivalence - AttributeError: patch_embed.patch_embeddings.0.weight not found in PyTorch model.

Thank you!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Oct 27, 2023
@amyeroberts
Collaborator

@joaocmd Are you still working on this? If so, @Rocketknight1 could you help?

@joaocmd
Contributor Author

joaocmd commented Oct 28, 2023

Hi @amyeroberts, I haven't made any changes since my last comment, as I was stuck and had some other responsibilities. I would still like to finish it, especially because I believe it's very close to done.

@Rocketknight1
Member

Hi, I'm sorry, I'm not sure how I missed your last comment - this is entirely my fault! Let me investigate the errors you were getting and I'll see if we can get this PR over the line.

@joaocmd joaocmd force-pushed the add_tf_swiftformer branch from 3a2b6fd to 200dbd0 Compare April 16, 2024 20:22
@joaocmd
Contributor Author

joaocmd commented Apr 16, 2024

Hi @amyeroberts! I think I've addressed all the comments. The CI pipeline is failing on a ruff check even though I've run it locally. Do you know why?

@amyeroberts
Collaborator

@joaocmd It might be a version mismatch. To ensure you have the same versions of libraries like ruff in your env, you can run pip install -e .[quality]. I'd then try running make fixup again.

@joaocmd
Contributor Author

joaocmd commented Apr 17, 2024

Thanks @amyeroberts! Is there anything missing on my side?

@amyeroberts amyeroberts (Collaborator) left a comment

Looks great - thanks for all the work adding this model!

@amyeroberts amyeroberts changed the title [WIP] Add tf swiftformer Add TF swiftformer Apr 19, 2024
@amyeroberts amyeroberts merged commit d2cec09 into huggingface:main Apr 19, 2024
23 checks passed
@joaocmd
Contributor Author

joaocmd commented Apr 19, 2024

Thank you @amyeroberts and @Rocketknight1 for your patience and support! It was a pleasure ☺️

ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* Duplicate swiftformer

* Convert SwiftFormerPatchEmbedding

* Convert SwiftFormerEmbeddings

* Convert TFSwiftFormerMlp

* Convert TFSwiftFormerConvEncoder

* Convert TFSwiftFormerLocalRepresentation

* convert TFSwiftFormerEncoderBlock

* Convert SwiftFormerStage

* Convert SwiftFormerEncoder

* Add TFSWiftFormerPreTrainedModel

* Convert SwiftFormerForImageClassification

* Add kwargs and start drop path

* Fix syntax

* Change Model class name

* Add TFSwiftFormer to __init__

* Duplicate test_modeling_swiftformer

* First test conversions

* Change require_torch to require_tf

* Add exports to swiftformer __init__

* Add TFSwiftFormerModel wrapper

* Fix __init__ and run black

* Remove docstring from MainLayer, fix padding

* Use keras.layers.Activation on keras.Sequential

* Fix swiftformer exports

* Fix activation layer from config

* Remove post_inits

* Use tf.keras.layers.ZeroPadding2D

* Convert torch normalize

* Change tf test input shape

* Fix softmax and reduce_sum

* Convert expand_dims and repeat

* Add missing reshape and tranpose

* Simplify TFSwiftFormerEncoderBlock.call

* Fix mismatch in patch embeddings

* Fix expected output shape to match channels last

* Fix swiftformer typo

* Disable test_onnx

* Fix TFSwiftFormerForImageClassification call

* Add unpack inputs

* Convert flatten(2).mean(-1)

* Change vision dummy inputs (to be reviewed)

* Change test_forward_signature to use .call

* Fix @unpack_inputs

* Set return_tensors="tf" and rename class

* Rename wrongly named patch_embeddings layer

* Add serving_output and change dummy_input shape

* Make dimensions BCHW and transpose inside embedding layer

* Change SwiftFormerEncoderBlock

* Fix ruff problems

* Add image size to swiftformer config

* Change tranpose to MainLayer and use -1 for reshape

* Remove serving_outputs and dummy_inputs

* Remove test_initialization test from tf model

* Make Sequential component a separate layer

* Fix layers' names

* Tranpose encoder outputs

* Fix tests and check if hidden states is not None

* Fix TFSwiftFormerForImageClassification

* Run make fixup

* Run make fix-copies

* Update modeling_tf_auto

* Update docs

* Fix modeling auto mapping

* Update modelint_tf_swiftformer docs

* Fill image_size doc and type

* Add reduction=None to loss computation

* Update docs

* make style

* Debug: Delete the tip to see if that changes anything

* Re-add tip

* Remove add_code_sample_docstrings

* Remove unused import

* Get the debug to actually tell us the problem it has with the docs

* Try a substitution to match the PyTorch file?

* Add swiftformer to ignore list

* Add build() methods

* Update copyright year

Co-authored-by: amyeroberts <[email protected]>

* Remove FIXME comment

* Remove from_pt

* Update copyright year

Co-authored-by: amyeroberts <[email protected]>

* Rename one-letter variables

* Remove FIXMEs related to momentum

* Remove old TODO comment

* Remove outstanding FIXME comments

* Get dropout rate from config

* Add specific dropout config for MLP

* Add convencoder dropout to config

* Pass config to SwiftFormerDropPath layer

* Fix drop_path variable name and add Adapted from comment

* Run ruff

* Removed copied from comment

* Run fix copies

* Change drop_path to identity to match pt

* Cleanup build() methods and move to new keras imports

* Update docs/source/en/model_doc/swiftformer.md

Co-authored-by: Matt <[email protected]>

* Raise error if drop_path_rate > 0.0

* Apply suggestions from code review

Replace (self.dim), with self.dim,

Co-authored-by: Matt <[email protected]>

* Remove drop_path function

* Add training to TFSwiftFormerEncoder

* Set self.built = True last

Co-authored-by: amyeroberts <[email protected]>

* Should have been added to previous commit

Co-authored-by: amyeroberts <[email protected]>

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Change default_feature_extractor to default_image_processor

Co-authored-by: amyeroberts <[email protected]>

* Import Keras from modeling_tf_utils

* Remove relative import

* Run ruff --fix

* Move import keras to tf_available

* Add copied from comment to test_forward_signature

* Reduce batch size and num_labels

* Extract loss logic to hf_compute_loss

* Run ruff format

---------

Co-authored-by: Matt <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Matt <[email protected]>
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
itazap pushed a commit that referenced this pull request May 14, 2024