diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 4821968f..90588039 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -11,8 +11,7 @@ class SimpleLearnedAggregator(AbstractPVSitesEncoder): - """A simple model which learns a different weighted-average across all of the PV sites for each - GSP. + """A simple model which learns a different weighted-average across all PV sites for each GSP. Each sequence from each site is independently encodeded through some dense layers wih skip- connections, then the encoded form of each sequence is aggregated through a learned weighted-sum diff --git a/pvnet/optimizers.py b/pvnet/optimizers.py index 55aa5493..fda5c0fb 100644 --- a/pvnet/optimizers.py +++ b/pvnet/optimizers.py @@ -52,6 +52,12 @@ def __call__(self, model): def find_submodule_parameters(model, search_modules): + """Finds all parameters within given submodule types + + Args: + model: torch Module to search through + search_modules: List of submodule types to search for + """ if isinstance(model, search_modules): return model.parameters() @@ -66,6 +72,12 @@ def find_submodule_parameters(model, search_modules): def find_other_than_submodule_parameters(model, ignore_modules): + """Finds all parameters not with given submodule types + + Args: + model: torch Module to search through + search_modules: List of submodule types to ignore + """ if isinstance(model, ignore_modules): return []