Use Argument linking to link `init_args` to `dict_kwargs` #375
Does your model have init parameters `input_width` and `input_height`? If so, you could link them directly:

```python
parser.link_arguments("data.input_width", "model.init_args.input_width")
parser.link_arguments("data.input_height", "model.init_args.input_height")
```

Any particular reason why you are using `dict_kwargs`?
Not all of the models I use have those init parameters. The reason why I use `dict_kwargs` is code like this:

```python
import pytorch_lightning as pl
import torch.nn as nn


class LitClassification(pl.LightningModule):
    """This is a PyTorch Lightning wrapper for image classification tasks."""

    def __init__(
        self,
        num_classes: int,
        model_name: str = "resnet18",
        loss_name: str = "cross_entropy_loss",
        in_chans: int = 3,
        pretrained: bool = False,
        init_lr: float = 1e-3,
        weight_decay: float = 1e-2,
        augmentations: list[str] | None = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = create_model(model_name, in_chans, num_classes, pretrained, *args, **kwargs)
        self.loss_fn = create_loss(loss_name, *args, **kwargs)
        self.lr = init_lr
        self.weight_decay = weight_decay
        self.num_classes = num_classes
        self.augmentation = ImageAugmentation(augmentations or [])
        self.save_hyperparameters()


def create_model(
    model_name: str, in_chans: int, num_classes: int, pretrained: bool = False, *args, **kwargs
) -> nn.Module:
    """Factory function for creating models.

    Args:
        model_name (str): Model name.
        in_chans (int): Number of model input channels.
        num_classes (int): Number of model classes.
        pretrained (bool, optional): If `True`, load pretrained weights. Defaults to `False`.

    Returns:
        nn.Module: Created model.
    """
    return eval(model_name)(in_chans, num_classes, *args, pretrained=pretrained, **kwargs)
```
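As a side note, `eval(model_name)` executes an arbitrary string, so a typo or untrusted input can run anything. An explicit registry gives the same name-to-class dispatch safely. A minimal sketch with hypothetical stand-in classes (real code would register actual `nn.Module` subclasses):

```python
from typing import Callable, Dict

# Hypothetical stand-in model classes, just to show the dispatch pattern.
class ResNet18:
    def __init__(self, in_chans: int, num_classes: int, pretrained: bool = False):
        self.in_chans, self.num_classes, self.pretrained = in_chans, num_classes, pretrained

class ViT:
    def __init__(self, in_chans: int, num_classes: int, pretrained: bool = False,
                 input_width: int = 224, input_height: int = 224):
        self.in_chans, self.num_classes = in_chans, num_classes
        self.input_shape = (input_height, input_width)

# Explicit registry: only names listed here can be constructed.
MODEL_REGISTRY: Dict[str, Callable] = {"resnet18": ResNet18, "vit": ViT}

def create_model(model_name: str, in_chans: int, num_classes: int,
                 pretrained: bool = False, **kwargs):
    try:
        cls = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"unknown model {model_name!r}") from None
    # **kwargs forwards transformer-only options such as input_width/input_height.
    return cls(in_chans, num_classes, pretrained=pretrained, **kwargs)
```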
The recommended way of implementing submodules (e.g. backbone, loss) is via dependency injection, see models-with-multiple-submodules. Linking with a `dict_kwargs` target is not supported. Do you see it as viable to change your code to use dependency injection?
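For context, the dependency-injection shape the linked docs recommend is: the wrapper receives an already-constructed submodule instead of a name string plus pass-through kwargs. A framework-free sketch (the class names here are illustrative, not Lightning API):

```python
class Backbone:
    """Hypothetical base type for injectable backbones."""
    def forward(self, x):
        raise NotImplementedError

class CNNBackbone(Backbone):
    def __init__(self, in_chans: int = 3):
        self.in_chans = in_chans
    def forward(self, x):
        return x  # placeholder

class LitClassification:
    # The backbone arrives fully constructed, so this wrapper never needs
    # to know each backbone's init signature (no *args/**kwargs pass-through).
    def __init__(self, backbone: Backbone, num_classes: int):
        self.backbone = backbone
        self.num_classes = num_classes
```

With jsonargparse/LightningCLI, a typed `backbone: Backbone` parameter like this can then be filled from config via `class_path`/`init_args`, which is what makes each backbone's own options configurable without widening the wrapper's signature.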
I understand. Using dependency injection would not solve my problem: I want a simple, short, unified command line, but dependency injection would make it more complex. Thanks for your advice anyway.
You could add …
Why exactly do you say this? If dependency injection is used, a class can be selected from the command line or config like:

```yaml
model:
  backbone: {class_name}
```

The automatically saved config indeed would be more complex, because the submodule gets expanded to a nested structure.
I'm building a model-training API server. I want the request body (training config) to be as simple as possible, because our users have no ML experience, and the request body structure is fixed; the extra arguments sit under `dict_kwargs`. So I think linking `init_args` to `dict_kwargs` would be useful here.
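One way to keep the request body flat while still feeding a nested trainer config is to translate it server-side. A sketch under assumed field names (none of these keys come from the thread; the shape-duplication at the end is exactly what the requested `init_args`-to-`dict_kwargs` link would do automatically):

```python
def to_trainer_config(request: dict) -> dict:
    # Flat, jargon-free request body -> nested CLI-style training config.
    width, height = request["image_width"], request["image_height"]
    return {
        "data": {"input_width": width, "input_height": height},
        "model": {
            "init_args": {"num_classes": request["num_classes"]},
            # Shape kwargs duplicated here for models whose factories need
            # them (e.g. transformers); CNNs simply ignore them.
            "dict_kwargs": {"input_width": width, "input_height": height},
        },
    }
```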
Using dependency injection sounds great; I just need to write more data-validation code when developing my API 🤔.
🚀 Feature request

Use Argument linking to link `init_args` to `dict_kwargs`.

Motivation

I try to link `data.input_width` to `model.dict_kwargs.input_width` because transformer models need the input image shape for model initialization, but most CNNs don't. So I have a factory function to create both CNNs and transformers; it uses kwargs to adapt to both types of models, and I want to link the dataset image-shape information to the model-initialization kwargs. When I use the code below, it raises an error:

Pitch

If we could link `init_args` and `dict_kwargs` arbitrarily, that would be great, or is this feature already implemented?

Alternatives