Skip to content

Commit

Permalink
Fix docstrings and args default values (keras-team#1938)
Browse files Browse the repository at this point in the history
* update sam docstring to show correct backbone in docstring

* update mix_transformer to mit

* update file names

* update init files

* address matt's feedback
  • Loading branch information
divyashreepathihalli authored Oct 18, 2024
1 parent adce690 commit efedb6f
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion keras_hub/src/models/mit/mit_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
```python
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_imagenet")
backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
# Evaluate model
model(images)
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/mobilenet/mobilenet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
stackwise_activation,
output_num_filters,
inverted_res_block,
image_shape=(224, 224, 3),
image_shape=(None, None, 3),
input_activation="hard_swish",
output_activation="hard_swish",
depth_multiplier=1.0,
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/resnet/resnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ResNetBackbone(FeaturePyramidBackbone):
input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3))
# Pretrained ResNet backbone.
model = keras_hub.models.ResNetBackbone.from_preset("resnet50")
model = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")
model(input_data)
# Randomly initialized ResNetV2 backbone with a custom config.
Expand Down
14 changes: 13 additions & 1 deletion keras_hub/src/models/vae/vae_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class VAEBackbone(Backbone):
"""VAE backbone used in latent diffusion models.
"""Variational Autoencoder(VAE) backbone used in latent diffusion models.
When encoding, this model generates mean and log variance of the input
images. When decoding, it reconstructs images from the latent space.
Expand Down Expand Up @@ -51,6 +51,18 @@ class VAEBackbone(Backbone):
`"channels_last"`.
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
to use for the model's computations and weights.
Example:
```Python
backbone = VAEBackbone(
encoder_num_filters=[32, 32, 32, 32],
encoder_num_blocks=[1, 1, 1, 1],
decoder_num_filters=[32, 32, 32, 32],
decoder_num_blocks=[1, 1, 1, 1],
)
input_data = ops.ones((2, self.height, self.width, 3))
output = backbone(input_data)
```
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion keras_hub/src/models/vgg/vgg_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class VGGBackbone(Backbone):
stackwise_num_filters: list of ints, filter size for convolutional
blocks per VGG block. For both VGG16 and VGG19 this is [
64, 128, 256, 512, 512].
image_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
image_shape: tuple, optional shape tuple, defaults to (None, None, 3).
Examples:
```python
Expand Down

0 comments on commit efedb6f

Please sign in to comment.