From 476b9da4d15cc863f96a7014a37e1561e38a51a9 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Tue, 30 Apr 2024 22:09:38 -0400
Subject: [PATCH 001/183] WIP: Add src/vak/config/dataset.py

---
 src/vak/config/dataset.py | 40 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 src/vak/config/dataset.py

diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py
new file mode 100644
index 000000000..f795d74a3
--- /dev/null
+++ b/src/vak/config/dataset.py
@@ -0,0 +1,40 @@
+"""Class that represents dataset table in toml config file."""
+from __future__ import annotations
+
+import pathlib
+
+from attr import define, field
+import attr.validators
+
+
+@define
+class Dataset:
+    """Class that represents dataset table in toml config file.
+
+    Attributes
+    ----------
+    name : str, optional
+        Name of dataset. Only required for built-in datasets
+        from the :mod:`~vak.datasets` module.
+    path : pathlib.Path
+        Path to the directory that contains the dataset.
+        Equivalent to the `root` parameter of :mod:`torchvision`
+        datasets.
+    splits_path : pathlib.Path, optional
+        Path to file representing splits.
+    """
+    path: pathlib.Path = field(converter=pathlib.Path)
+    name: str | None = field(
+        converter=attr.converters.optional(str), default=None
+    )
+    splits_path: pathlib.Path | None = field(
+        converter=attr.converters.optional(pathlib.Path), default=None
+    )
+
+    @classmethod
+    def from_dict(cls, dict_: dict) -> Dataset:
+        return cls(
+            path=dict_.get('path'),
+            name=dict_.get('name'),
+            splits_path=dict_.get('splits_path')
+        )

From a625bb7ae2c6e388c8cfe21fb7542284d3c2ba91 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Tue, 30 Apr 2024 22:11:34 -0400
Subject: [PATCH 002/183] Add module-level docstring + type annotations in src/vak/config/parse.py

---
 src/vak/config/parse.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py
index 295a1a537..3c099ee80 100644
--- a/src/vak/config/parse.py
+++ b/src/vak/config/parse.py
@@ -1,4 +1,7 @@
-from pathlib import Path
+"""Functions to parse toml config files."""
+from __future__ import annotations
+
+import pathlib
 
 import toml
 from toml.decoder import TomlDecodeError
@@ -111,7 +114,9 @@ def _validate_sections_arg_convert_list(sections):
     return sections
 
 
-def from_toml(config_toml, toml_path=None, sections=None):
+def from_toml(
+    config_toml: dict, toml_path: str | pathlib.Path | None = None, sections: list[str] | None = None
+    ) -> Config:
     """load a TOML configuration file
 
     Parameters
@@ -119,7 +124,7 @@ def from_toml(config_toml, toml_path=None, sections=None):
     config_toml : dict
         Python ``dict`` containing a .toml configuration file,
         parsed by the ``toml`` library.
-    toml_path : str, Path
+    toml_path : str, pathlib.Path
         path to a configuration file in TOML format.
         Default is None. Not required, used only to make any error messages clearer.
     sections : str, list
@@ -154,19 +159,21 @@ def from_toml(config_toml, toml_path=None, sections=None):
     return Config(**config_dict)
 
 
-def _load_toml_from_path(toml_path):
-    """helper function to load toml config file,
+def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict:
+    """Load a toml file from a path, and return as a :class:`dict`. 
+
+    Helper function to load toml config file,
     factored out to use in other modules when needed
 
     checks if ``toml_path`` exists before opening,
     and tries to give a clear message if an error occurs when parsing"""
-    toml_path = Path(toml_path)
+    toml_path = pathlib.Path(toml_path)
 
     if not toml_path.is_file():
         raise FileNotFoundError(f".toml config file not found: {toml_path}")
 
     try:
         with toml_path.open("r") as fp:
-            config_toml = toml.load(fp)
+            config_toml: dict = toml.load(fp)
     except TomlDecodeError as e:
         raise Exception(
             f"Error when parsing .toml config file: {toml_path}"
@@ -175,12 +182,12 @@ def _load_toml_from_path(toml_path):
     return config_toml
 
 
-def from_toml_path(toml_path, sections=None):
-    """parse a TOML configuration file
+def from_toml_path(toml_path: str | pathlib.Path, sections: list[str] | None = None) -> Config:
+    """Parse a TOML configuration file and return as a :class:`Config`.
 
     Parameters
     ----------
-    toml_path : str, Path
+    toml_path : str, pathlib.Path
         path to a configuration file in TOML format.
         Parsed by ``toml`` library, then converted to an
         instance of ``vak.config.parse.Config`` by
@@ -195,7 +202,7 @@ def from_toml_path(toml_path, sections=None):
     Returns
    -------
     config : vak.config.parse.Config
-        instance of Config class, whose attributes correspond to
+        instance of :class:`Config` class, whose attributes correspond to
         sections in a config.toml file.
     """
     config_toml = _load_toml_from_path(toml_path)

From 272f67df8c392f75203a63ccc8e75984c9237f5d Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Tue, 30 Apr 2024 22:12:05 -0400
Subject: [PATCH 003/183] WIP: Fix how cli.prep adds dataset path to toml config file

---
 src/vak/cli/prep.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py
index 3a8ee6b8d..051a323ff 100644
--- a/src/vak/cli/prep.py
+++ b/src/vak/cli/prep.py
@@ -88,8 +88,8 @@ def prep(toml_path):
     # ---- figure out purpose of config file from sections; will save csv path in that section -------------------------
     purpose = purpose_from_toml(config_toml, toml_path)
     if (
-        "dataset_path" in config_toml[purpose.upper()]
-        and config_toml[purpose.upper()]["dataset_path"] is not None
+        "dataset" in config_toml[purpose.upper()]
+        and config_toml[purpose.upper()]["dataset"].get("path") is not None
     ):
         raise ValueError(
             f"config .toml file already has a 'dataset_path' option in the '{purpose.upper()}' section, "
@@ -142,7 +142,7 @@ def prep(toml_path):
     )
 
     # use config and section from above to add dataset_path to config.toml file
-    config_toml[section]["dataset_path"] = str(dataset_path)
+    config_toml[section]["dataset"]["path"] = str(dataset_path)
 
     with toml_path.open("w") as fp:
         toml.dump(config_toml, fp)

From fa35def8d14c3c8b71759793a731906fc493aeb3 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Tue, 30 Apr 2024 22:27:14 -0400
Subject: [PATCH 004/183] Change table names in src/vak/config/valid.toml

---
 src/vak/config/valid.toml | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/vak/config/valid.toml b/src/vak/config/valid.toml
index 11cd535f5..2103c1166 100644
--- a/src/vak/config/valid.toml
+++ b/src/vak/config/valid.toml
@@ -5,7 +5,7 @@
 # Options should be in the same order they are defined for the
 # attrs-based class that represents the config, for easy comparison
 # when changing that class + this file. 
-[PREP] +[vak.prep] data_dir = './tests/test_data/cbins/gy6or6/032312' output_dir = './tests/test_data/prep/learncurve' dataset_type = 'frame_classification' @@ -22,7 +22,7 @@ test_dur = 30 train_set_durs = [ 4.5, 6.0 ] num_replicates = 2 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000 ] @@ -33,7 +33,7 @@ freqbins_key = 'f' timebins_key = 't' audio_path_key = 'audio_path' -[TRAIN] +[vak.train] model = 'TweetyNet' root_results_dir = './tests/test_data/results/train' dataset_path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' @@ -54,7 +54,7 @@ train_dataset_params = {'window_size' = 80} val_transform_params = {'resize' = 128} val_dataset_params = {'window_size' = 80} -[EVAL] +[vak.eval] dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' @@ -68,7 +68,7 @@ post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} transform_params = {'resize' = 128} dataset_params = {'window_size' = 80} -[LEARNCURVE] +[vak.learncurve] model = 'TweetyNet' root_results_dir = './tests/test_data/results/learncurve' batch_size = 11 @@ -88,7 +88,7 @@ train_dataset_params = {'window_size' = 80} val_transform_params = {'resize' = 128} val_dataset_params = {'window_size' = 80} -[PREDICT] +[vak.predict] dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' From 988c150c1f14748221c184a2950955ac00d300f9 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 30 Apr 2024 22:39:09 -0400 Subject: [PATCH 005/183] Rename section -> table in config/parse.py --- src/vak/config/parse.py | 134 ++++++++++++++++++++-------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index 3c099ee80..1e8bc0aad 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -13,111 +13,111 @@ from .prep import PrepConfig from .spect_params import SpectParamsConfig from .train import TrainConfig -from .validators import are_options_valid, are_sections_valid - -SECTION_CLASSES = { - "EVAL": EvalConfig, - "LEARNCURVE": LearncurveConfig, - "PREDICT": PredictConfig, - "PREP": PrepConfig, - "SPECT_PARAMS": SpectParamsConfig, - "TRAIN": TrainConfig, +from .validators import are_options_valid, are_tables_valid + +TABLE_CLASSES = { + "eval": EvalConfig, + "learncurve": LearncurveConfig, + "predict": PredictConfig, + "prep": PrepConfig, + "spect_params": SpectParamsConfig, + "train": TrainConfig, } REQUIRED_OPTIONS = { - "EVAL": [ + "eval": [ "checkpoint_path", "output_dir", "model", ], - "LEARNCURVE": [ + "learncurve": [ "model", "root_results_dir", ], - "PREDICT": [ + "predict": [ "checkpoint_path", "model", ], - "PREP": [ + "prep": [ "data_dir", "output_dir", ], - "SPECT_PARAMS": None, - "TRAIN": [ + "spect_params": None, + "train": [ "model", "root_results_dir", ], } -def parse_config_section(config_toml, section_name, toml_path=None): - """parse section of config.toml file +def parse_config_table(config_toml, table_name, toml_path=None): + """Parse table of config.toml file Parameters ---------- config_toml : dict - containing config.toml file already loaded by parse function - section_name : str - name of section from configuration - file that should be parsed + 
Containing config.toml file already loaded by parse function + table_name : str + Name of table from configuration + file that should be parsed. toml_path : str path to a configuration file in TOML format. Default is None. Used for error messages if specified. Returns ------- - config : vak.config section class - instance of class that represents section of config.toml file, - e.g. PredictConfig for 'PREDICT' section + config : vak.config table class + instance of class that represents table of config.toml file, + e.g. PredictConfig for 'PREDICT' table """ - section = dict(config_toml[section_name].items()) + table = dict(config_toml[table_name].items()) - required_options = REQUIRED_OPTIONS[section_name] + required_options = REQUIRED_OPTIONS[table_name] if required_options is not None: for required_option in required_options: - if required_option not in section: + if required_option not in table: if toml_path: err_msg = ( f"the '{required_option}' option is required but was not found in the " - f"{section_name} section of the config.toml file: {toml_path}" + f"{table_name} table of the config.toml file: {toml_path}" ) else: err_msg = ( f"the '{required_option}' option is required but was not found in the " - f"{section_name} section of the toml config" + f"{table_name} table of the toml config" ) raise KeyError(err_msg) - return SECTION_CLASSES[section_name](**section) + return TABLE_CLASSES[table_name](**table) -def _validate_sections_arg_convert_list(sections): - if isinstance(sections, str): - sections = [sections] - elif isinstance(sections, list): +def _validate_tables_arg_convert_list(tables): + if isinstance(tables, str): + tables = [tables] + elif isinstance(tables, list): if not all( - [isinstance(section_name, str) for section_name in sections] + [isinstance(table_name, str) for table_name in tables] ): raise ValueError( - "all section names in 'sections' should be strings" + "all table names in 'tables' should be strings" ) if not all( [ - section_name in list(SECTION_CLASSES.keys()) - for section_name in sections + table_name in list(TABLE_CLASSES.keys()) + for table_name in tables ] ): raise ValueError( - "all section names in 'sections' should be valid names of sections. " - f"Values for 'sections were: {sections}.\n" - f"Valid section names are: {list(SECTION_CLASSES.keys())}" + "all table names in 'tables' should be valid names of tables. " + f"Values for 'tables were: {tables}.\n" + f"Valid table names are: {list(TABLE_CLASSES.keys())}" ) - return sections + return tables def from_toml( - config_toml: dict, toml_path: str | pathlib.Path | None = None, sections: list[str] | None = None + config_toml: dict, toml_path: str | pathlib.Path | None = None, tables: list[str] | None = None ) -> Config: - """load a TOML configuration file + """Load a TOML configuration file. Parameters ---------- @@ -127,33 +127,33 @@ def from_toml( toml_path : str, pathlib.Path path to a configuration file in TOML format. Default is None. Not required, used only to make any error messages clearer. - sections : str, list - name of section or sections from configuration + tables : str, list + Name of table or tables from configuration file that should be parsed. Can be a string - (single section) or list of strings (multiple - sections). Default is None, + (single table) or list of strings (multiple + tables). Default is None, in which case all are validated and parsed. 
Returns ------- config : vak.config.parse.Config instance of Config class, whose attributes correspond to - sections in a config.toml file. + tables in a config.toml file. """ - are_sections_valid(config_toml, toml_path) + are_tables_valid(config_toml, toml_path) - sections = _validate_sections_arg_convert_list(sections) + tables = _validate_tables_arg_convert_list(tables) config_dict = {} - if sections is None: - sections = list( - SECTION_CLASSES.keys() - ) # i.e., parse all sections, except model - for section_name in sections: - if section_name in config_toml: - are_options_valid(config_toml, section_name, toml_path) - config_dict[section_name.lower()] = parse_config_section( - config_toml, section_name, toml_path + if tables is None: + tables = list( + TABLE_CLASSES.keys() + ) # i.e., parse all tables, except model + for table_name in tables: + if table_name in config_toml: + are_options_valid(config_toml, table_name, toml_path) + config_dict[table_name.lower()] = parse_config_table( + config_toml, table_name, toml_path ) return Config(**config_dict) @@ -182,7 +182,7 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: return config_toml -def from_toml_path(toml_path: str | pathlib.Path, sections: list[str] | None = None) -> Config: +def from_toml_path(toml_path: str | pathlib.Path, tables: list[str] | None = None) -> Config: """Parse a TOML configuration file and return as a :class:`Config`. Parameters @@ -192,18 +192,18 @@ def from_toml_path(toml_path: str | pathlib.Path, sections: list[str] | None = N Parsed by ``toml`` library, then converted to an instance of ``vak.config.parse.Config`` by calling ``vak.parse.from_toml`` - sections : str, list - name of section or sections from configuration + tables : str, list + name of table or tables from configuration file that should be parsed. Can be a string - (single section) or list of strings (multiple - sections). Default is None, + (single table) or list of strings (multiple + tables). Default is None, in which case all are validated and parsed. Returns ------- config : vak.config.parse.Config instance of :class:`Config` class, whose attributes correspond to - sections in a config.toml file. + tables in a config.toml file. """ config_toml = _load_toml_from_path(toml_path) - return from_toml(config_toml, toml_path, sections) + return from_toml(config_toml, toml_path, tables) From 72e11e8067dd3cbb311276884e564aeb95fde2bb Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 30 Apr 2024 22:49:13 -0400 Subject: [PATCH 006/183] In cli/prep change 'section' -> 'table' and lowercase table names --- src/vak/cli/prep.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index 051a323ff..80b7384d0 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -10,7 +10,7 @@ from .. import config from .. import prep as prep_module from ..config.parse import _load_toml_from_path -from ..config.validators import are_sections_valid +from ..config.validators import are_tables_valid def purpose_from_toml( @@ -19,11 +19,11 @@ def purpose_from_toml( """determine "purpose" from toml config, i.e., the command that will be run after we ``prep`` the data. 

-    By convention this is the other section in the config file
+    By convention this is the other table in the config file
     that correspond to a cli command besides '[PREP]'
     """
     # validate, make sure there aren't multiple commands in one config file first
-    are_sections_valid(config_toml, toml_path=toml_path)
+    are_tables_valid(config_toml, toml_path=toml_path)
 
     from ..cli.cli import CLI_COMMANDS  # avoid circular imports
 
@@ -31,18 +31,18 @@ def purpose_from_toml(
         command for command in CLI_COMMANDS if command != "prep"
     )
     for command in commands_that_are_not_prep:
-        section_name = (
+        table_name = (
             command.upper()
-        )  # we write section names in uppercase, e.g. `[PREP]`, by convention
-        if section_name in config_toml:
-            return section_name.lower()  # this is the "purpose" of the file
+        )  # we write table names in uppercase, e.g. `[PREP]`, by convention
+        if table_name in config_toml:
+            return table_name.lower()  # this is the "purpose" of the file
 
 
 # note NO LOGGING -- we configure logger inside `core.prep`
 # so we can save log file inside dataset directory
 # see https://github.com/NickleDave/vak/issues/334
-SECTIONS_PREP_SHOULD_PARSE = ("PREP", "SPECT_PARAMS", "DATALOADER")
+SECTIONS_PREP_SHOULD_PARSE = ("prep", "spect_params", "dataloader")
 
 
 def prep(toml_path):
@@ -85,41 +85,41 @@ def prep(toml_path):
     # open here because need to check for `dataset_path` in this function, see #314 & #333
     config_toml = _load_toml_from_path(toml_path)
 
-    # ---- figure out purpose of config file from sections; will save csv path in that section -------------------------
+    # ---- figure out purpose of config file from tables; will save csv path in that table -------------------------
     purpose = purpose_from_toml(config_toml, toml_path)
     if (
         "dataset" in config_toml[purpose.upper()]
         and config_toml[purpose.upper()]["dataset"].get("path") is not None
     ):
         raise ValueError(
-            f"config .toml file already has a 'dataset_path' option in the '{purpose.upper()}' section, "
+            f"config .toml file already has a 'dataset_path' option in the '{purpose.upper()}' table, "
             f"and running `prep` would overwrite that value. To `prep` a new dataset, please remove "
-            f"the 'dataset_path' option from the '{purpose.upper()}' section in the config file:\n{toml_path}"
+            f"the 'dataset_path' option from the '{purpose.upper()}' table in the config file:\n{toml_path}"
         )
 
-    # now that we've checked that, go ahead and parse the sections we want
+    # now that we've checked that, go ahead and parse the tables we want
     cfg = config.parse.from_toml_path(
-        toml_path, sections=SECTIONS_PREP_SHOULD_PARSE
+        toml_path, tables=SECTIONS_PREP_SHOULD_PARSE
     )
-    # notice we ignore any other option/values in the 'purpose' section,
+    # notice we ignore any other option/values in the 'purpose' table,
     # see https://github.com/NickleDave/vak/issues/334 and https://github.com/NickleDave/vak/issues/314
     if cfg.prep is None:
         raise ValueError(
-            f"prep called with a config.toml file that does not have a PREP section: {toml_path}"
+            f"prep called with a config.toml file that does not have a [vak.prep] table: {toml_path}"
        )
 
     if purpose == "predict":
         if cfg.prep.labelset is not None:
             warnings.warn(
-                "config has a PREDICT section, but labelset option is specified in PREP section."
-                "This would cause an error because the dataframe.from_files section will attempt to "
+                "config has a [vak.predict] table, but labelset option is specified in [vak.prep] table." 
+ "This would cause an error because the dataframe.from_files method will attempt to " f"check whether the files in the data_dir ({cfg.prep.data_dir}) have labels in " "labelset, even though those files don't have annotation.\n" "Setting labelset to None." ) cfg.prep.labelset = None - section = purpose.upper() + table = purpose.upper() dataset_df, dataset_path = prep_module.prep( data_dir=cfg.prep.data_dir, @@ -141,8 +141,8 @@ def prep(toml_path): num_replicates=cfg.prep.num_replicates, ) - # use config and section from above to add dataset_path to config.toml file - config_toml[section]["dataset"]["path"] = str(dataset_path) + # use config and table from above to add dataset_path to config.toml file + config_toml[table]["dataset"]["path"] = str(dataset_path) with toml_path.open("w") as fp: toml.dump(config_toml, fp) From 7181f718d885df798a652abc02543d040141a5cc Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 30 Apr 2024 22:49:37 -0400 Subject: [PATCH 007/183] In config/config.py, change 'section' -> 'table' and lowercase table names --- src/vak/config/config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/vak/config/config.py b/src/vak/config/config.py index 377802b3b..771ca40a4 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -16,17 +16,17 @@ class Config: Attributes ---------- prep : vak.config.prep.PrepConfig - represents ``[PREP]`` section of config.toml file + represents ``[vak.prep]`` table of config.toml file spect_params : vak.config.spect_params.SpectParamsConfig - represents ``[SPECT_PARAMS]`` section of config.toml file + represents ``[SPECT_PARAMS]`` table of config.toml file train : vak.config.train.TrainConfig - represents ``[TRAIN]`` section of config.toml file + represents ``[vak.train]`` table of config.toml file eval : vak.config.eval.EvalConfig - represents ``[EVAL]`` section of config.toml file + represents ``[vak.eval]`` table of config.toml file predict : vak.config.predict.PredictConfig - represents ``[PREDICT]`` section of config.toml file. + represents ``[vak.predict]`` table of config.toml file. 
learncurve : vak.config.learncurve.LearncurveConfig - represents ``[LEARNCURVE]`` section of config.toml file + represents ``[vak.learncurve]`` table of config.toml file """ spect_params = attr.ib( From 9d4467cb05d6df83bfc07feb55ff40e1eaafa77c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 30 Apr 2024 22:50:06 -0400 Subject: [PATCH 008/183] Change '[PREP]' -> '[vak.prep]' in config/prep.py --- src/vak/config/prep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py index 7481d8cc2..031cd0e00 100644 --- a/src/vak/config/prep.py +++ b/src/vak/config/prep.py @@ -1,4 +1,4 @@ -"""parses [PREP] section of config""" +"""parses [vak.prep] section of config""" import inspect import attr From b5a64136cf95ed6c62afdd45f38bc55dcedef07a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 30 Apr 2024 22:56:38 -0400 Subject: [PATCH 009/183] WIP: Change table names in config files in tests/data_for_tests/configs --- ...ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml | 11 +++++------ ...onvEncoderUMAP_train_audio_cbin_annot_notmat.toml | 6 +++--- .../TweetyNet_eval_audio_cbin_annot_notmat.toml | 10 +++++----- ...TweetyNet_learncurve_audio_cbin_annot_notmat.toml | 12 ++++++------ .../TweetyNet_predict_audio_cbin_annot_notmat.toml | 8 ++++---- .../TweetyNet_train_audio_cbin_annot_notmat.toml | 10 +++++----- ...tyNet_train_continue_audio_cbin_annot_notmat.toml | 10 +++++----- ...etyNet_train_continue_spect_mat_annot_yarden.toml | 10 +++++----- .../TweetyNet_train_spect_mat_annot_yarden.toml | 10 +++++----- .../configs/invalid_option_config.toml | 6 +++--- .../configs/invalid_section_config.toml | 6 +++--- .../configs/invalid_train_and_learncurve_config.toml | 10 +++++----- 12 files changed, 54 insertions(+), 55 deletions(-) diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml index a2be49143..b38979f48 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "parametric umap" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412" @@ -8,20 +8,19 @@ annot_format = "notmat" labelset = "iabcdefghjk" test_dur = 0.2 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 32 transform_type = "log_spect_plus_one" -[EVAL] +[vak.eval] checkpoint_path = "tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP/results_230727_210112/ConvEncoderUMAP/checkpoints/checkpoint.pt" -model = "ConvEncoderUMAP" batch_size = 64 num_workers = 16 device = "cuda" output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/ConvEncoderUMAP" -[ConvEncoderUMAP.network] +[vak.eval.model.ConvEncoderUMAP.network] conv1_filters = 8 conv2_filters = 16 conv_kernel_size = 3 @@ -30,5 +29,5 @@ conv_padding = 1 n_features_linear = 32 n_components = 2 -[ConvEncoderUMAP.optimizer] +[vak.eval.model.ConvEncoderUMAP.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml index 456c99468..fef7afeb4 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml +++ 
b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "parametric umap" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -10,12 +10,12 @@ train_dur = 0.5 val_dur = 0.2 test_dur = 0.25 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 32 transform_type = "log_spect_plus_one" -[TRAIN] +[vak.train] model = "ConvEncoderUMAP" batch_size = 64 num_epochs = 1 diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml index 51f5157e4..4af0d82d4 100644 --- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" labelset = "iabcdefghjk" @@ -7,14 +7,14 @@ output_dir = "./tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat audio_format = "cbin" annot_format = "notmat" -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[EVAL] +[vak.eval] checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json" model = "TweetyNet" @@ -24,11 +24,11 @@ device = "cuda" spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet" -[EVAL.post_tfm_kwargs] +[vak.eval.post_tfm_kwargs] majority_vote = true min_segment_dur = 0.02 -[EVAL.transform_params] +[vak.eval.transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml index 0922283e8..9a909bf0a 100644 --- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -12,14 +12,14 @@ test_dur = 30 train_set_durs = [ 4, 6,] num_replicates = 2 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[LEARNCURVE] +[vak.learncurve] model = "TweetyNet" normalize_spectrograms = true batch_size = 11 @@ -31,14 +31,14 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet" -[LEARNCURVE.post_tfm_kwargs] +[vak.learncurve.post_tfm_kwargs] majority_vote = true min_segment_dur = 0.02 -[LEARNCURVE.train_dataset_params] +[vak.learncurve.train_dataset_params] window_size = 88 -[LEARNCURVE.val_transform_params] +[vak.learncurve.val_transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml 
b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml index da6a9175c..b2db522f0 100644 --- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml @@ -1,18 +1,18 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412" output_dir = "./tests/data_for_tests/generated/prep/predict/audio_cbin_annot_notmat/TweetyNet" spect_format = "npz" -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[PREDICT] +[vak.predict] spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" @@ -23,7 +23,7 @@ device = "cuda" output_dir = "./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet" annot_csv_filename = "bl26lb16.041912.annot.csv" -[PREDICT.transform_params] +[vak.predict.transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml index 2f72adfb1..7e06b4a86 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -10,14 +10,14 @@ train_dur = 50 val_dur = 15 test_dur = 30 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] +[vak.train] model = "TweetyNet" normalize_spectrograms = true batch_size = 11 @@ -29,10 +29,10 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet" -[TRAIN.train_dataset_params] +[vak.train.train_dataset_params] window_size = 88 -[TRAIN.val_transform_params] +[vak.train.val_transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml index 932208616..97c602a5d 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312" @@ -10,14 +10,14 @@ train_dur = 50 val_dur = 15 test_dur = 30 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] +[vak.train] model = "TweetyNet" normalize_spectrograms = true batch_size = 11 @@ -31,10 +31,10 @@ root_results_dir = "./tests/data_for_tests/generated/results/train_continue/audi checkpoint_path = 
"~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" -[TRAIN.train_dataset_params] +[vak.train.train_dataset_params] window_size = 88 -[TRAIN.val_transform_params] +[vak.train.val_transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml index aa384b6ed..c2826251f 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/spect_mat_annot_yarden/llb3/spect" @@ -10,14 +10,14 @@ labelset = "range: 1-3,6-14,17-19" train_dur = 213 val_dur = 213 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] +[vak.train] model = "TweetyNet" normalize_spectrograms = false batch_size = 11 @@ -30,10 +30,10 @@ device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train_continue/spect_mat_annot_yarden/TweetyNet" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" -[TRAIN.train_dataset_params] +[vak.train.train_dataset_params] window_size = 88 -[TRAIN.val_transform_params] +[vak.train.val_transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml index 770012f4f..9ed05126e 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/spect_mat_annot_yarden/llb3/spect" @@ -10,14 +10,14 @@ labelset = "range: 1-3,6-14,17-19" train_dur = 213 val_dur = 213 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size = 512 step_size = 64 freq_cutoffs = [ 500, 10000,] thresh = 6.25 transform_type = "log_spect" -[TRAIN] +[vak.train] model = "TweetyNet" normalize_spectrograms = false batch_size = 11 @@ -29,10 +29,10 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/spect_mat_annot_yarden/TweetyNet" -[TRAIN.train_dataset_params] +[vak.train.train_dataset_params] window_size = 88 -[TRAIN.val_transform_params] +[vak.train.val_transform_params] window_size = 88 [TweetyNet.network] diff --git a/tests/data_for_tests/configs/invalid_option_config.toml b/tests/data_for_tests/configs/invalid_option_config.toml index 5504fbf38..469e55435 100644 --- a/tests/data_for_tests/configs/invalid_option_config.toml +++ b/tests/data_for_tests/configs/invalid_option_config.toml @@ -1,7 +1,7 @@ # used to test that invalid option 'ouput_dir' (instead of 'output_dir') # raises a ValueError when passed to # vak.config.validators.are_options_valid -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" 
data_dir = '/home/user/data/subdir/' @@ -20,7 +20,7 @@ freq_cutoffs = [500, 10000] thresh = 6.25 transform_type = 'log_spect' -[TRAIN] +[vak.train] model = 'TweetyNet' root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true @@ -30,7 +30,7 @@ val_error_step = 1 checkpoint_step = 1 save_only_single_checkpoint_file = true -[TRAIN.dataset_params] +[vak.train.dataset_params] window_size = 88 [TweetyNet.optimizer] diff --git a/tests/data_for_tests/configs/invalid_section_config.toml b/tests/data_for_tests/configs/invalid_section_config.toml index f77cde3a3..517517d21 100644 --- a/tests/data_for_tests/configs/invalid_section_config.toml +++ b/tests/data_for_tests/configs/invalid_section_config.toml @@ -1,7 +1,7 @@ -# used to test that invalid section 'TRIAN' (instead of 'TRAIN') +# used to test that invalid section 'TRIAN' (instead of 'vak.train') # raises a ValueError when passed to # vak.config.validators.are_sections_valid -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = '/home/user/data/subdir/' @@ -20,7 +20,7 @@ freq_cutoffs = [500, 10000] thresh = 6.25 transform_type = 'log_spect' -[TRIAN] # <-- invalid section 'TRIAN' (instead of 'TRAIN') +[TRIAN] # <-- invalid section 'TRIAN' (instead of 'vak.train') model = 'TweetyNet' root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true diff --git a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml index ce13bb316..29e12ee2e 100644 --- a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml +++ b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" data_dir = "./tests/data_for_tests/source/cbins/gy6or6/032312" @@ -10,7 +10,7 @@ train_dur = 50 val_dur = 15 test_dur = 30 -[SPECT_PARAMS] +[vak.prep.spect_params] fft_size=512 step_size=64 freq_cutoffs = [500, 10000] @@ -18,8 +18,8 @@ thresh = 6.25 transform_type = "log_spect" # this .toml file should cause 'vak.config.parse.from_toml' to raise a ValueError -# because it defines both a TRAIN and a LEARNCURVE section -[TRAIN] +# because it defines both a vak.train and a vak.learncurve section +[vak.train] model = "TweetyNet" normalize_spectrograms = true batch_size = 11 @@ -31,7 +31,7 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat" -[LEARNCURVE] +[vak.learncurve] model = 'TweetyNet' normalize_spectrograms = true batch_size = 11 From 27f0596cc3c747083ac7feea1558463dbc188136 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Tue, 30 Apr 2024 22:58:11 -0400 Subject: [PATCH 010/183] Make tomlkit a dependency in pyproject.toml, drop toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bacf95d7a..f0895a996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "SoundFile >=0.10.3", "pandas >=1.0.1", "tensorboard >=2.8.0", - "toml >=0.10.2", + "tomlkit >=0.12.4", "torch >= 2.0.1", "torchvision >=0.15.2", "tqdm >=4.42.1", From 62fe4f48dcb72aa7e51765ad98871697ee416cac Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 08:28:49 -0400 Subject: [PATCH 011/183] Change config/parse.py to use tomlkit --- src/vak/config/parse.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git 
a/src/vak/config/parse.py b/src/vak/config/parse.py index 1e8bc0aad..29b31b952 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -3,8 +3,8 @@ import pathlib -import toml -from toml.decoder import TomlDecodeError +import tomlkit +import tomlkit.exceptions from .config import Config from .eval import EvalConfig @@ -163,23 +163,28 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: """Load a toml file from a path, and return as a :class:`dict`. Helper function to load toml config file, - factored out to use in other modules when needed - - checks if ``toml_path`` exists before opening, - and tries to give a clear message if an error occurs when parsing""" + factored out to use in other modules when needed. + Checks if ``toml_path`` exists before opening, + and tries to give a clear message if an error occurs when loading.""" toml_path = pathlib.Path(toml_path) if not toml_path.is_file(): raise FileNotFoundError(f".toml config file not found: {toml_path}") try: with toml_path.open("r") as fp: - config_toml: dict = toml.load(fp) - except TomlDecodeError as e: + config_toml: dict = tomlkit.load(fp) + except tomlkit.exceptions.TOMLKitError as e: raise Exception( f"Error when parsing .toml config file: {toml_path}" ) from e - return config_toml + if 'vak' not in config_toml: + raise ValueError( + "Toml file does not contain a top-level table named `vak`. " + f"Please see example configuration files here: " + ) + + return config_toml['vak'] def from_toml_path(toml_path: str | pathlib.Path, tables: list[str] | None = None) -> Config: From 3253a4ea2046daf3281e7dc4a31bf9d24259b6c1 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 08:32:35 -0400 Subject: [PATCH 012/183] Update example configs in doc/toml/ --- doc/toml/gy6or6_eval.toml | 16 ++++++++-------- doc/toml/gy6or6_predict.toml | 10 +++++----- doc/toml/gy6or6_train.toml | 14 +++++++------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/doc/toml/gy6or6_eval.toml b/doc/toml/gy6or6_eval.toml index fcd9a7203..da788a6b7 100644 --- a/doc/toml/gy6or6_eval.toml +++ b/doc/toml/gy6or6_eval.toml @@ -1,4 +1,4 @@ -[PREP] +[vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" # input_type: input to model, either audio ("audio") or spectrogram ("spect") @@ -19,7 +19,7 @@ train_dur = 50 val_dur = 15 # SPECT_PARAMS: parameters for computing spectrograms -[SPECT_PARAMS] +[vak.prep.spect_params] # fft_size: size of window used for Fast Fourier Transform, in number of samples fft_size = 512 # step_size: size of step to take when computing spectra with FFT for spectrogram @@ -27,7 +27,7 @@ fft_size = 512 step_size = 64 # EVAL: options for evaluating a trained model. This is done using the "test" split. -[EVAL] +[vak.eval] model = "TweetyNet" # checkpoint_path: path to saved model checkpoint checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" @@ -51,7 +51,7 @@ output_dir = "/PATH/TO/FOLDER/results/eval" # ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) # EVAL.post_tfm_kwargs: options for post-processing -[EVAL.post_tfm_kwargs] +[vak.eval.post_tfm_kwargs] # both these transforms require that there is an "unlabeled" label, # and they will only be applied to segments that are bordered on both sides # by the "unlabeled" label. 
@@ -68,9 +68,9 @@ min_segment_dur = 0.02 # transform_params: parameters used when transforming data # for a frame classification model, we use FrameDataset with the eval_item_transform, # that reshapes batches into consecutive adjacent windows with a specific `window_size` -[EVAL.transform_params] +[vak.eval.transform_params] window_size = 176 -# Note we do not specify any options for the network, and just use the defaults -# We need to put this "dummy" table here though for the config to parse correctly -[TweetyNet] +# Note we do not specify any options for the model, and just use the defaults +# We need to put this table here though so we know which model we are using +[vak.eval.model.TweetyNet] diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml index bcdfbd240..001c16cae 100644 --- a/doc/toml/gy6or6_predict.toml +++ b/doc/toml/gy6or6_predict.toml @@ -1,5 +1,5 @@ # PREP: options for preparing dataset -[PREP] +[vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" # input_type: input to model, either audio ("audio") or spectrogram ("spect") @@ -15,7 +15,7 @@ audio_format = "wav" # all data found in `data_dir` will be assigned to a "predict split" instead # SPECT_PARAMS: parameters for computing spectrograms -[SPECT_PARAMS] +[vak.prep.spect_params] # fft_size: size of window used for Fast Fourier Transform, in number of samples fft_size = 512 # step_size: size of step to take when computing spectra with FFT for spectrogram @@ -23,7 +23,7 @@ fft_size = 512 step_size = 64 # PREDICT: options for generating predictions with a trained model -[PREDICT] +[vak.predict] # model: the string name of the model. must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` model = "TweetyNet" # checkpoint_path: path to saved model checkpoint @@ -64,9 +64,9 @@ min_segment_dur = 0.01 # transform_params: parameters used when transforming data # for a frame classification model, we use FrameDataset with the eval_item_transform, # that reshapes batches into consecutive adjacent windows with a specific `window_size` -[PREDICT.transform_params] +[vak.predict.transform_params] window_size = 176 # Note we do not specify any options for the network, and just use the defaults # We need to put this "dummy" table here though for the config to parse correctly -[TweetyNet] +[vak.predict.model.TweetyNet] diff --git a/doc/toml/gy6or6_train.toml b/doc/toml/gy6or6_train.toml index e86b5f7c8..ab3a02b85 100644 --- a/doc/toml/gy6or6_train.toml +++ b/doc/toml/gy6or6_train.toml @@ -1,5 +1,5 @@ # PREP: options for preparing dataset -[PREP] +[vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" # input_type: input to model, either audio ("audio") or spectrogram ("spect") @@ -22,7 +22,7 @@ val_dur = 15 test_dur = 30 # SPECT_PARAMS: parameters for computing spectrograms -[SPECT_PARAMS] +[vak.prep.spect_params] # fft_size: size of window used for Fast Fourier Transform, in number of samples fft_size = 512 # step_size: size of step to take when computing spectra with FFT for spectrogram @@ -30,7 +30,7 @@ fft_size = 512 step_size = 64 # TRAIN: options for training model -[TRAIN] +[vak.train] # model: the string name of the model. must be a name within `vak.models` or added e.g. 
with `vak.model.decorators.model` model = "TweetyNet" # root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` @@ -60,21 +60,21 @@ device = "cuda" # train_dataset_params: parameters used when loading training dataset # for a frame classification model, we use a WindowDataset with a specific `window_size` -[TRAIN.train_dataset_params] +[vak.train.train_dataset_params] window_size = 176 # val_transform_params: parameters used when transforming validation data # for a frame classification model, we use FrameDataset with the eval_item_transform, # that reshapes batches into consecutive adjacent windows with a specific `window_size` -[TRAIN.val_transform_params] +[vak.train.val_transform_params] window_size = 176 # TweetyNet.optimizer: we specify options for the model's optimizer in this table -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] # lr: the learning rate lr = 0.001 # TweetyNet.network: we specify options for the model's network in this table -[TweetyNet.network] +[vak.train.model.TweetyNet.network] # hidden_size: the number of elements in the hidden state in the recurrent layer of the network hidden_size = 256 From fb70733a0d4df3e0df0af4ee7558d78952822165 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 08:37:26 -0400 Subject: [PATCH 013/183] Add link to example config files in docs, in error messages in config/validators.py --- src/vak/config/validators.py | 72 ++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 71d757d10..cfdb868fa 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -1,7 +1,7 @@ """validators used by attrs-based classes and by vak.parse.parse_config""" from pathlib import Path -import toml +import tomlkit from .. 
import models
 from ..common import constants
@@ -34,13 +34,13 @@ def is_valid_model_name(instance, attribute, value: str) -> None:
 
 
 def is_audio_format(instance, attribute, value):
-    """check if valid audio format"""
+    """Check if valid audio format"""
     if value not in constants.VALID_AUDIO_FORMATS:
         raise ValueError(f"{value} is not a valid format for audio files")
 
 
 def is_annot_format(instance, attribute, value):
-    """check if valid annotation format"""
+    """Check if valid annotation format"""
     if value not in constants.VALID_ANNOT_FORMATS:
         raise ValueError(
             f"{value} is not a valid format for annotation files.\n"
@@ -49,7 +49,7 @@ def is_annot_format(instance, attribute, value):
 
 
 def is_spect_format(instance, attribute, value):
-    """check if valid format for spectrograms"""
+    """Check if valid format for spectrograms"""
    if value not in constants.VALID_SPECT_FORMATS:
         raise ValueError(
             f"{value} is not a valid format for spectrogram files.\n"
@@ -60,70 +60,72 @@ def is_spect_format(instance, attribute, value):
 CONFIG_DIR = Path(__file__).parent
 VALID_TOML_PATH = CONFIG_DIR.joinpath("valid.toml")
 with VALID_TOML_PATH.open("r") as fp:
-    VALID_DICT = toml.load(fp)
-VALID_SECTIONS = list(VALID_DICT.keys())
+    VALID_DICT = tomlkit.load(fp)['vak']
+VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys())
 VALID_OPTIONS = {
-    section: list(options.keys()) for section, options in VALID_DICT.items()
+    table: list(options.keys()) for table, options in VALID_DICT.items()
 }
 
 
-def are_sections_valid(config_dict, toml_path=None):
-    sections = list(config_dict.keys())
+def are_tables_valid(config_dict, toml_path=None):
+    tables = list(config_dict.keys())
 
     from ..cli.cli import CLI_COMMANDS  # avoid circular import
 
     cli_commands_besides_prep = [
         command for command in CLI_COMMANDS if command != "prep"
     ]
-    sections_that_are_commands_besides_prep = [
-        section
-        for section in sections
-        if section.lower() in cli_commands_besides_prep
+    tables_that_are_commands_besides_prep = [
+        table
+        for table in tables
+        if table in cli_commands_besides_prep
     ]
 
-    if len(sections_that_are_commands_besides_prep) == 0:
+    if len(tables_that_are_commands_besides_prep) == 0:
         raise ValueError(
-            "did not find a section related to a vak command in config besides `prep`.\n"
-            f"Sections in config were: {sections}"
+            "Did not find a table related to a vak command in config besides `prep`.\n"
+            f"Tables in config were: {tables}\n"
+            "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml"
         )
 
-    if len(sections_that_are_commands_besides_prep) > 1:
+    if len(tables_that_are_commands_besides_prep) > 1:
         raise ValueError(
-            "found multiple sections related to a vak command in config besides `prep`.\n"
-            f"Those sections are: {sections_that_are_commands_besides_prep}. "
-            f"Please use just one command besides `prep` per .toml configuration file"
+            "Found multiple tables related to a vak command in config besides `prep`.\n"
+            f"Those tables are: {tables_that_are_commands_besides_prep}. 
" + f"Please use just one command besides `prep` per .toml configuration file.\n" + "See example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" ) - MODEL_NAMES = list(models.registry.MODEL_NAMES) - # add model names to valid sections so users can define model config in sections - valid_sections = VALID_SECTIONS + MODEL_NAMES - for section in sections: - if ( - section not in valid_sections - and f"{section}Model" not in valid_sections - ): + for table in tables: + if table not in VALID_TOP_LEVEL_TABLES: + # and f"{table}Model" not in valid_tables + # ): if toml_path: err_msg = ( - f"section defined in {toml_path} is not valid: {section}" + f"Top-level table defined in {toml_path} is not valid: {table}\n + f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n + "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" ) else: err_msg = ( - f"section defined in toml config is not valid: {section}" + f"Table defined in toml config is not valid: {table}\n" + f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n" + "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" ) raise ValueError(err_msg) -def are_options_valid(config_dict, section, toml_path=None): - user_options = set(config_dict[section].keys()) - valid_options = set(VALID_OPTIONS[section]) +def are_options_valid(config_dict, table, toml_path=None): + user_options = set(config_dict[table].keys()) + valid_options = set(VALID_OPTIONS[table]) if not user_options.issubset(valid_options): invalid_options = user_options - valid_options if toml_path: err_msg = ( - f"the following options from {section} section in " + f"The following options from '{table}' table in " f"the config file '{toml_path.name}' are not valid:\n{invalid_options}" ) else: err_msg = ( - f"the following options from {section} section in " + f"The following options from '{table}' table in " f"the toml config are not valid:\n{invalid_options}" ) raise ValueError(err_msg) From b6531389bef3d0056680fc8382c6cdd66fcd143e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:24:03 -0400 Subject: [PATCH 014/183] Remove 'spect_params' from REQUIRED_OPTIONS in config/parse.py, this is not a top-level table and will be an attribute of prep instead --- src/vak/config/parse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index 29b31b952..f8551d345 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -15,6 +15,7 @@ from .train import TrainConfig from .validators import are_options_valid, are_tables_valid + TABLE_CLASSES = { "eval": EvalConfig, "learncurve": LearncurveConfig, @@ -42,7 +43,6 @@ "data_dir", "output_dir", ], - "spect_params": None, "train": [ "model", "root_results_dir", From 83eddf0a50cb03ee11fcc9133c4cabfa10fa8c35 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:25:02 -0400 Subject: [PATCH 015/183] Rename 'config_toml' -> 'config_dict' in config/parse.py --- src/vak/config/parse.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index f8551d345..ddc4c44be 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -50,12 +50,12 @@ } -def parse_config_table(config_toml, table_name, toml_path=None): +def parse_config_table(config_dict, table_name, toml_path=None): """Parse table of config.toml file 
Parameters ---------- - config_toml : dict + config_dict : dict Containing config.toml file already loaded by parse function table_name : str Name of table from configuration @@ -70,7 +70,7 @@ def parse_config_table(config_toml, table_name, toml_path=None): instance of class that represents table of config.toml file, e.g. PredictConfig for 'PREDICT' table """ - table = dict(config_toml[table_name].items()) + table = dict(config_dict[table_name].items()) required_options = REQUIRED_OPTIONS[table_name] if required_options is not None: @@ -115,13 +115,13 @@ def _validate_tables_arg_convert_list(tables): def from_toml( - config_toml: dict, toml_path: str | pathlib.Path | None = None, tables: list[str] | None = None + config_dict: dict, toml_path: str | pathlib.Path | None = None, tables: str | list[str] | None = None ) -> Config: """Load a TOML configuration file. Parameters ---------- - config_toml : dict + config_dict : dict Python ``dict`` containing a .toml configuration file, parsed by the ``toml`` library. toml_path : str, pathlib.Path @@ -140,23 +140,26 @@ def from_toml( instance of Config class, whose attributes correspond to tables in a config.toml file. """ - are_tables_valid(config_toml, toml_path) - + are_tables_valid(config_dict, toml_path) tables = _validate_tables_arg_convert_list(tables) - config_dict = {} + config_kwargs = {} if tables is None: tables = list( TABLE_CLASSES.keys() ) # i.e., parse all tables, except model for table_name in tables: - if table_name in config_toml: - are_options_valid(config_toml, table_name, toml_path) - config_dict[table_name.lower()] = parse_config_table( - config_toml, table_name, toml_path + if table_name in config_dict: + are_options_valid(config_dict, table_name, toml_path) + config_kwargs[table_name.lower()] = parse_config_table( + config_dict, table_name, toml_path + ) + else: + raise KeyError( + f"A table specified in `tables` was not found in the config: {table_name}" ) - return Config(**config_dict) + return Config(**config_kwargs) def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: @@ -172,19 +175,19 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: try: with toml_path.open("r") as fp: - config_toml: dict = tomlkit.load(fp) + config_dict: dict = tomlkit.load(fp) except tomlkit.exceptions.TOMLKitError as e: raise Exception( f"Error when parsing .toml config file: {toml_path}" ) from e - if 'vak' not in config_toml: + if 'vak' not in config_dict: raise ValueError( "Toml file does not contain a top-level table named `vak`. " f"Please see example configuration files here: " ) - return config_toml['vak'] + return config_dict['vak'] def from_toml_path(toml_path: str | pathlib.Path, tables: list[str] | None = None) -> Config: @@ -210,5 +213,5 @@ def from_toml_path(toml_path: str | pathlib.Path, tables: list[str] | None = Non instance of :class:`Config` class, whose attributes correspond to tables in a config.toml file. 
""" - config_toml = _load_toml_from_path(toml_path) - return from_toml(config_toml, toml_path, tables) + config_dict = _load_toml_from_path(toml_path) + return from_toml(config_dict, toml_path, tables) From 1160a6b689df7126465b5c6d3e7fe5b2f56619f2 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:25:21 -0400 Subject: [PATCH 016/183] Fix function _validate_tables_arg_convert_list in config/parse.py --- src/vak/config/parse.py | 43 +++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index ddc4c44be..721893530 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -90,27 +90,32 @@ def parse_config_table(config_dict, table_name, toml_path=None): return TABLE_CLASSES[table_name](**table) -def _validate_tables_arg_convert_list(tables): +def _validate_tables_arg_convert_list(tables: str | list[str]) -> list[str]: if isinstance(tables, str): tables = [tables] - elif isinstance(tables, list): - if not all( - [isinstance(table_name, str) for table_name in tables] - ): - raise ValueError( - "all table names in 'tables' should be strings" - ) - if not all( - [ - table_name in list(TABLE_CLASSES.keys()) - for table_name in tables - ] - ): - raise ValueError( - "all table names in 'tables' should be valid names of tables. " - f"Values for 'tables were: {tables}.\n" - f"Valid table names are: {list(TABLE_CLASSES.keys())}" - ) + + if not isinstance(tables, list): + raise TypeError( + f"`tables` should be a string or list of strings but type was: {type(tables)}" + ) + + if not all( + [isinstance(table_name, str) for table_name in tables] + ): + raise ValueError( + "All table names in 'tables' should be strings" + ) + if not all( + [ + table_name in list(TABLE_CLASSES.keys()) + for table_name in tables + ] + ): + raise ValueError( + "All table names in 'tables' should be valid names of tables. 
" + f"Values for 'tables were: {tables}.\n" + f"Valid table names are: {list(TABLE_CLASSES.keys())}" + ) return tables From 97966243f7514840b68b077cde0d9f31d4f31fa4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:25:52 -0400 Subject: [PATCH 017/183] Fix error message formatting in src/vak/config/validators.py --- src/vak/config/validators.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index cfdb868fa..f31778cf7 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -96,12 +96,10 @@ def are_tables_valid(config_dict, toml_path=None): for table in tables: if table not in VALID_TOP_LEVEL_TABLES: - # and f"{table}Model" not in valid_tables - # ): if toml_path: err_msg = ( - f"Top-level table defined in {toml_path} is not valid: {table}\n - f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n + f"Top-level table defined in {toml_path} is not valid: {table}\n" + f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n" "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" ) else: From 60f7b02f6b8ad007a1359b64a1f248f6787d9729 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:26:54 -0400 Subject: [PATCH 018/183] Add ModelConfig class to config/model.py, add type annotations, fix config_from_toml_dict to look in specific section --- src/vak/config/model.py | 105 +++++++++++++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 13 deletions(-) diff --git a/src/vak/config/model.py b/src/vak/config/model.py index d1d36643b..6caad80e7 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -1,10 +1,12 @@ +"""Class representing the model table of a toml configuration file.""" from __future__ import annotations import pathlib -import toml +from attrs import define, field from .. import models +from . import parse, validators MODEL_TABLES = [ "network", @@ -14,7 +16,90 @@ ] -def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict: +@define +class ModelConfig: + """Class representing the model table of a toml configuration file. + + Attributes + ---------- + name : str + network : dict + Keyword arguments for the network class, + or a :class:`dict` of ``dict``s mapping + network names to keyword arguments. + optimizer: dict + Keyword arguments for the optimizer class. + loss : dict + Keyword arguments for the class representing the loss function. + metrics: dict + A :class:`dict` of ``dict``s mapping + metric names to keyword arguments. + """ + name: str + network: dict = field(validators=isinstance(dict)) + optimizer: dict = field(validators=isinstance(dict)) + loss: dict = field(validators=isinstance(dict)) + metrics: dict = field(validators=isinstance(dict)) + + @classmethod + def from_config_dict(cls, config_dict: dict): + """Return :class:`ModelConfig` instance from a :class:`dict`. + + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :func:`vak.config.parse.from_toml_path`, + and then using a top-level table key, + followed by key ``'model'``. + E.g., ``config_dict['train']['model']` or + ``config_dict['predict']['model']``. 
+
+
+        Examples
+        --------
+        config_dict = vak.config.parse._load_toml_from_path(toml_path)
+        model_config = vak.config.ModelConfig.from_config_dict(config_dict['train']['model'])
+        """
+        model_name = list(config_dict.keys())
+        if len(model_name) == 0:
+            raise ValueError(
+                "Did not find a single key in `config_dict` corresponding to model name. "
+                f"Instead found no keys. Config dict:\n{config_dict}\n"
+                "A configuration file should specify a single model per top-level table."
+            )
+        if len(model_name) > 1:
+            raise ValueError(
+                "Did not find a single key in `config_dict` corresponding to model name. "
+                f"Instead found multiple keys: {model_name}.\nConfig dict:\n{config_dict}.\n"
+                "A configuration file should specify a single model per top-level table."
+            )
+        model_name = model_name[0]
+        MODEL_NAMES = list(models.registry.MODEL_NAMES)
+        if model_name not in MODEL_NAMES:
+            raise ValueError(
+                f"Model name not found in registry: {model_name}\n"
+                f"Model names in registry:\n{MODEL_NAMES}"
+            )
+        model_config = config_dict[model_name]
+        if not all(
+            key in MODEL_TABLES for key in model_config.keys()
+        ):
+            invalid_keys = [
+                key for key in model_config.keys() if key not in MODEL_TABLES
+            ]
+            raise ValueError(
+                f"The following sub-tables in the model config are not valid: {invalid_keys}\n"
+                f"Valid sub-table names are: {MODEL_TABLES}"
+            )
+        # for any tables not specified, default to empty dict so we can still use ``**`` operator on it
+        for model_table in MODEL_TABLES:
+            if model_table not in model_config:
+                model_config[model_table] = {}
+        return cls(
+            name=model_name,
+            **model_config
+        )
+
+
+def config_from_toml_dict(toml_dict: dict, table: str, model_name: str) -> dict:
     """Get configuration for a model from a .toml configuration file
     loaded into a ``dict``.
 
@@ -39,9 +124,10 @@ def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict:
         raise ValueError(
             f"Invalid model name: {model_name}.\nValid model names are: {models.registry.MODEL_NAMES}"
         )
+    validators.are_tables_valid(toml_dict)
 
     try:
-        model_config = toml_dict[model_name]
+        model_config = toml_dict[table][model_name]
     except KeyError as e:
         raise ValueError(
             f"A config section specifies the model name '{model_name}', "
-            f"but there is no section named '{model_name}' in the config."
+            f"but there is no section named '{model_name}' in the '{table}' table of the config."
         ) from e
 
     # check if config declares parameters for required attributes;
@@ -58,7 +144,7 @@ def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict:
 
 
 def config_from_toml_path(
-    toml_path: str | pathlib.Path, model_name: str
+    toml_path: str | pathlib.Path, table: str, model_name: str
 ) -> dict:
     """Get configuration for a model from a .toml configuration file,
     given the path to the file.
@@ -79,12 +165,5 @@ def config_from_toml_path(
         as loaded from a .toml file,
        and used by the model method ``from_config``. 
""" - toml_path = pathlib.Path(toml_path) - if not toml_path.is_file(): - raise FileNotFoundError( - f"File not found, or not recognized as a file: {toml_path}" - ) - - with toml_path.open("r") as fp: - config_dict = toml.load(fp) - return config_from_toml_dict(config_dict, model_name) + toml_dict = parse._load_toml_from_path(toml_path) + return config_from_toml_dict(toml_dict, model_name) From 910c5df24a2e91faa73d3ae2e524a2394cf7372d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:28:48 -0400 Subject: [PATCH 019/183] Fixup fixing config_from_toml_dict to look in specific section --- src/vak/config/model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/vak/config/model.py b/src/vak/config/model.py index 6caad80e7..399384a3d 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -107,6 +107,8 @@ def config_from_toml_dict(toml_dict: dict, table: str, model_name: str) -> dict: ---------- toml_dict : dict Configuration from a .toml file, loaded into a dictionary. + table : str + Name of top-level table to get model config from. model_name : str Name of a model, specified as the ``model`` option in a table (such as TRAIN or PREDICT), @@ -131,7 +133,7 @@ def config_from_toml_dict(toml_dict: dict, table: str, model_name: str) -> dict: except KeyError as e: raise ValueError( f"A config section specifies the model name '{model_name}', " - f"but there is no section named '{model_name}' in the config." + f"but there is no section named '{model_name}' in the '{table}' table of the config." ) from e # check if config declares parameters for required attributes; @@ -153,6 +155,8 @@ def config_from_toml_path( ---------- toml_path : str, Path to configuration file in .toml format + table : str + Name of top-level table to get model config from. model_name : str of str, i.e. names of models specified by a section (such as TRAIN or PREDICT) that should each have corresponding sections @@ -166,4 +170,4 @@ def config_from_toml_path( and used by the model method ``from_config``. 
""" toml_dict = parse._load_toml_from_path(toml_path) - return config_from_toml_dict(toml_dict, model_name) + return config_from_toml_dict(toml_dict, table, model_name) From 32995bf83a55ec7317f6ff5e3b02086863bd7853 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:30:31 -0400 Subject: [PATCH 020/183] Rewrite config/eval.py with 'modern' attrs --- src/vak/config/eval.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index 6991b89a7..b82b5ef2b 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -1,7 +1,7 @@ -"""parses [EVAL] section of config""" -import attr -from attr import converters, validators -from attr.validators import instance_of +"""Class and functions for [vak.eval] table in config""" +from attrs import define, field +from attrs import converters, validators +from attrs.validators import instance_of from ..common import device from ..common.converters import expanded_user_path @@ -67,9 +67,9 @@ def are_valid_post_tfm_kwargs(instance, attribute, value): ) -@attr.s +@define class EvalConfig: - """class that represents [EVAL] section of config.toml file + """Class that represents [vak.eval] table in config.toml file Attributes ---------- @@ -119,18 +119,18 @@ class EvalConfig: """ # required, external files - checkpoint_path = attr.ib(converter=expanded_user_path) - output_dir = attr.ib(converter=expanded_user_path) + checkpoint_path: pathlib.Path = field(converter=expanded_user_path) + output_dir: pathlib.Path = field(converter=expanded_user_path) # required, model / dataloader - model = attr.ib( + model = field( validator=[instance_of(str), is_valid_model_name], ) - batch_size = attr.ib(converter=int, validator=instance_of(int)) + batch_size = field(converter=int, validator=instance_of(int)) # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = attr.ib( + dataset_path = field( converter=converters.optional(expanded_user_path), default=None, ) @@ -138,32 +138,32 @@ class EvalConfig: # "optional" but actually required for frame classification models # TODO: check model family in __post_init__ and raise ValueError if labelmap # TODO: not specified for a frame classification model? 
-    labelmap_path = attr.ib(
+    labelmap_path = field(
         converter=converters.optional(expanded_user_path), default=None
     )
 
     # optional, transform
-    spect_scaler_path = attr.ib(
+    spect_scaler_path = field(
         converter=converters.optional(expanded_user_path),
         default=None,
     )
 
-    post_tfm_kwargs = attr.ib(
+    post_tfm_kwargs = field(
         validator=validators.optional(are_valid_post_tfm_kwargs),
         converter=converters.optional(convert_post_tfm_kwargs),
         default=None,
     )
 
     # optional, data loader
-    num_workers = attr.ib(validator=instance_of(int), default=2)
-    device = attr.ib(validator=instance_of(str), default=device.get_default())
+    num_workers = field(validator=instance_of(int), default=2)
+    device = field(validator=instance_of(str), default=device.get_default())
 
-    transform_params = attr.ib(
+    transform_params = field(
         converter=converters.optional(dict),
         validator=validators.optional(instance_of(dict)),
         default=None,
     )
 
-    dataset_params = attr.ib(
+    dataset_params = field(
         converter=converters.optional(dict),
         validator=validators.optional(instance_of(dict)),
         default=None,
From f7757498f97a0f80f60ae2e63775ba62edcc4719 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Wed, 1 May 2024 09:38:03 -0400
Subject: [PATCH 021/183] Fixup rewrite config/eval with 'modern' attrs

---
 src/vak/config/eval.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py
index b82b5ef2b..8e8f802e3 100644
--- a/src/vak/config/eval.py
+++ b/src/vak/config/eval.py
@@ -1,4 +1,8 @@
-"""Class and functions for [vak.eval] table in config"""
+"""Class and functions for ``[vak.eval]`` table in configuration file."""
+from __future__ import annotations
+
+import pathlib
+
 from attrs import define, field
 from attrs import converters, validators
 from attrs.validators import instance_of
@@ -69,7 +73,7 @@ def are_valid_post_tfm_kwargs(instance, attribute, value):
 
 @define
 class EvalConfig:
-    """Class that represents [vak.eval] table in config.toml file
+    """Class that represents [vak.eval] table in configuration file.
 
     Attributes
     ----------
From 060abaa9b7a633727f9a0811eac508c4df55b9a1 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Wed, 1 May 2024 09:38:16 -0400
Subject: [PATCH 022/183] Rewrite config/learncurve.py with 'modern' attrs

---
 src/vak/config/learncurve.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py
index 13fc6021a..d1489de6f 100644
--- a/src/vak/config/learncurve.py
+++ b/src/vak/config/learncurve.py
@@ -1,14 +1,16 @@
-"""parses [LEARNCURVE] section of config"""
-import attr
-from attr import converters, validators
+"""Class that represents ``[vak.learncurve]`` table in configuration file."""
+from __future__ import annotations
+
+from attrs import define, field
+from attrs import converters, validators
 
 from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
 from .train import TrainConfig
 
 
-@attr.s
+@define
 class LearncurveConfig(TrainConfig):
-    """class that represents [LEARNCURVE] section of config.toml file
+    """Class that represents ``[vak.learncurve]`` table in configuration file.
 
     Attributes
     ----------
@@ -51,7 +53,7 @@ class LearncurveConfig(TrainConfig):
        these arguments and how they work. 
""" - post_tfm_kwargs = attr.ib( + post_tfm_kwargs = field( validator=validators.optional(are_valid_post_tfm_kwargs), converter=converters.optional(convert_post_tfm_kwargs), default=None, From 71e89bcfccd73ca9cafbef68c8f0c02eec39f7dd Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:38:28 -0400 Subject: [PATCH 023/183] Rewrite config/predict.py with 'modern' attrs --- src/vak/config/predict.py | 40 ++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index 852605165..1f5e402b1 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -1,8 +1,10 @@ -"""parses [PREDICT] section of config""" +"""Class that represents ``[vak.prep]`` section of configuration file.""" +from __future__ import annotations + import os from pathlib import Path -import attr +from attrs import define, field from attr import converters, validators from attr.validators import instance_of @@ -11,9 +13,9 @@ from .validators import is_valid_model_name -@attr.s +@define class PredictConfig: - """class that represents [PREDICT] section of config.toml file + """Class that represents ``[vak.prep]`` section of configuration file. Attributes ---------- @@ -79,52 +81,52 @@ class PredictConfig: """ # required, external files - checkpoint_path = attr.ib(converter=expanded_user_path) - labelmap_path = attr.ib(converter=expanded_user_path) + checkpoint_path = field(converter=expanded_user_path) + labelmap_path = field(converter=expanded_user_path) # required, model / dataloader - model = attr.ib( + model = field( validator=[instance_of(str), is_valid_model_name], ) - batch_size = attr.ib(converter=int, validator=instance_of(int)) + batch_size = field(converter=int, validator=instance_of(int)) # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = attr.ib( + dataset_path = field( converter=converters.optional(expanded_user_path), default=None, ) # optional, transform - spect_scaler_path = attr.ib( + spect_scaler_path = field( converter=converters.optional(expanded_user_path), default=None, ) # optional, data loader - num_workers = attr.ib(validator=instance_of(int), default=2) - device = attr.ib(validator=instance_of(str), default=device.get_default()) + num_workers = field(validator=instance_of(int), default=2) + device = field(validator=instance_of(str), default=device.get_default()) - annot_csv_filename = attr.ib( + annot_csv_filename = field( validator=validators.optional(instance_of(str)), default=None ) - output_dir = attr.ib( + output_dir = field( converter=expanded_user_path, default=Path(os.getcwd()), ) - min_segment_dur = attr.ib( + min_segment_dur = field( validator=validators.optional(instance_of(float)), default=None ) - majority_vote = attr.ib(validator=instance_of(bool), default=True) - save_net_outputs = attr.ib(validator=instance_of(bool), default=False) + majority_vote = field(validator=instance_of(bool), default=True) + save_net_outputs = field(validator=instance_of(bool), default=False) - transform_params = attr.ib( + transform_params = field( converter=converters.optional(dict), validator=validators.optional(instance_of(dict)), default=None, ) - dataset_params = attr.ib( + dataset_params = field( converter=converters.optional(dict), validator=validators.optional(instance_of(dict)), default=None, From 
f6cf94531e461c291513a1bbac363f91fc3fe98c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:38:38 -0400 Subject: [PATCH 024/183] Rewrite config/prep.py with 'modern' attrs --- src/vak/config/prep.py | 44 ++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py index 031cd0e00..aa7c98092 100644 --- a/src/vak/config/prep.py +++ b/src/vak/config/prep.py @@ -1,10 +1,12 @@ -"""parses [vak.prep] section of config""" +"""Class and functions for ``[vak.prep]`` table in configuration file.""" +from __future__ import annotations + import inspect -import attr +from attrs import define, field import dask.bag -from attr import converters, validators -from attr.validators import instance_of +from attrs import converters, validators +from attrs.validators import instance_of from .. import prep from ..common.converters import expanded_user_path, labelset_to_set @@ -60,9 +62,9 @@ def are_valid_dask_bag_kwargs(instance, attribute, value): ) -@attr.s +@define class PrepConfig: - """class to represent [PREP] section of config.toml file + """Class that represents ``[vak.prep]`` table of configuration file. Attributes ---------- @@ -127,10 +129,10 @@ class PrepConfig: Default is None. Required if config file has a learncurve section. """ - data_dir = attr.ib(converter=expanded_user_path) - output_dir = attr.ib(converter=expanded_user_path) + data_dir = field(converter=expanded_user_path) + output_dir = field(converter=expanded_user_path) - dataset_type = attr.ib(validator=instance_of(str)) + dataset_type = field(validator=instance_of(str)) @dataset_type.validator def is_valid_dataset_type(self, attribute, value): @@ -140,7 +142,7 @@ def is_valid_dataset_type(self, attribute, value): f"Valid dataset types are: {prep.constants.DATASET_TYPES}" ) - input_type = attr.ib(validator=instance_of(str)) + input_type = field(validator=instance_of(str)) @input_type.validator def is_valid_input_type(self, attribute, value): @@ -149,49 +151,49 @@ def is_valid_input_type(self, attribute, value): f"Invalid input type: {value}. 
Must be one of: {prep.constants.INPUT_TYPES}" ) - audio_format = attr.ib( + audio_format = field( validator=validators.optional(is_audio_format), default=None ) - spect_format = attr.ib( + spect_format = field( validator=validators.optional(is_spect_format), default=None ) - annot_file = attr.ib( + annot_file = field( converter=converters.optional(expanded_user_path), default=None, ) - annot_format = attr.ib( + annot_format = field( validator=validators.optional(is_annot_format), default=None ) - labelset = attr.ib( + labelset = field( converter=converters.optional(labelset_to_set), validator=validators.optional(instance_of(set)), default=None, ) - audio_dask_bag_kwargs = attr.ib( + audio_dask_bag_kwargs = field( validator=validators.optional(are_valid_dask_bag_kwargs), default=None ) - train_dur = attr.ib( + train_dur = field( converter=converters.optional(duration_from_toml_value), validator=validators.optional(is_valid_duration), default=None, ) - val_dur = attr.ib( + val_dur = field( converter=converters.optional(duration_from_toml_value), validator=validators.optional(is_valid_duration), default=None, ) - test_dur = attr.ib( + test_dur = field( converter=converters.optional(duration_from_toml_value), validator=validators.optional(is_valid_duration), default=None, ) - train_set_durs = attr.ib( + train_set_durs = field( validator=validators.optional(instance_of(list)), default=None ) - num_replicates = attr.ib( + num_replicates = field( validator=validators.optional(instance_of(int)), default=None ) From 9798707f9e8436563076a086e9167635ebf30e68 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:38:57 -0400 Subject: [PATCH 025/183] Rewrite config/train.py with 'modern' attrs --- src/vak/config/train.py | 50 ++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 034a110b4..21c378aa2 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -1,16 +1,16 @@ -"""parses [TRAIN] section of config""" -import attr -from attr import converters, validators -from attr.validators import instance_of +"""Class that represents ``[vak.train]`` table of configuration file.""" +from attrs import define, field +from attrs import converters, validators +from attrs.validators import instance_of from ..common import device from ..common.converters import bool_from_str, expanded_user_path from .validators import is_valid_model_name -@attr.s +@define class TrainConfig: - """class that represents [TRAIN] section of config.toml file + """Class that represents ``[vak.train]`` table of configuration file. 
Attributes ---------- @@ -64,81 +64,81 @@ class TrainConfig: """ # required - model = attr.ib( + model = field( validator=[instance_of(str), is_valid_model_name], ) - num_epochs = attr.ib(converter=int, validator=instance_of(int)) - batch_size = attr.ib(converter=int, validator=instance_of(int)) - root_results_dir = attr.ib(converter=expanded_user_path) + num_epochs = field(converter=int, validator=instance_of(int)) + batch_size = field(converter=int, validator=instance_of(int)) + root_results_dir = field(converter=expanded_user_path) # optional # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = attr.ib( + dataset_path = field( converter=converters.optional(expanded_user_path), default=None, ) - results_dirname = attr.ib( + results_dirname = field( converter=converters.optional(expanded_user_path), default=None, ) - normalize_spectrograms = attr.ib( + normalize_spectrograms = field( converter=bool_from_str, validator=validators.optional(instance_of(bool)), default=False, ) - num_workers = attr.ib(validator=instance_of(int), default=2) - device = attr.ib(validator=instance_of(str), default=device.get_default()) - shuffle = attr.ib( + num_workers = field(validator=instance_of(int), default=2) + device = field(validator=instance_of(str), default=device.get_default()) + shuffle = field( converter=bool_from_str, validator=instance_of(bool), default=True ) - val_step = attr.ib( + val_step = field( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) - ckpt_step = attr.ib( + ckpt_step = field( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) - patience = attr.ib( + patience = field( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) - checkpoint_path = attr.ib( + checkpoint_path = field( converter=converters.optional(expanded_user_path), default=None, ) - spect_scaler_path = attr.ib( + spect_scaler_path = field( converter=converters.optional(expanded_user_path), default=None, ) - train_transform_params = attr.ib( + train_transform_params = field( converter=converters.optional(dict), validator=validators.optional(instance_of(dict)), default=None, ) - train_dataset_params = attr.ib( + train_dataset_params = field( converter=converters.optional(dict), validator=validators.optional(instance_of(dict)), default=None, ) - val_transform_params = attr.ib( + val_transform_params = field( converter=converters.optional(dict), validator=validators.optional(instance_of(dict)), default=None, ) - val_dataset_params = attr.ib( + val_dataset_params = field( converter=converters.optional(dict), validator=validators.optional(instance_of(dict)), default=None, From 9c0d301ff8c0ad88656383fb7ba929677a064e03 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:55:47 -0400 Subject: [PATCH 026/183] Rename Dataset -> DatasetConfig in config/dataset.py --- src/vak/config/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index f795d74a3..cc9f7eb63 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -1,4 +1,4 @@ -"""Class that represents dataset table in toml config file.""" +"""Class that represents dataset table in configuration file.""" from __future__ import annotations import pathlib @@ -8,8 +8,8 
@@ @define -class Dataset: - """Class that represents dataset table in toml config file. +class DatasetConfig: + """Class that represents dataset table in configuration file. Attributes ---------- @@ -32,7 +32,7 @@ class Dataset: ) @classmethod - def from_dict(cls, dict_: dict) -> Dataset: + def from_config_dict(cls, dict_: dict) -> DatasetConfig: return cls( path=dict_.get('path'), name=dict_.get('name'), From 0de12daaf17d32d3908919d5cd89784f2af5e58d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 09:56:14 -0400 Subject: [PATCH 027/183] Add are_table_options_valid to config/validators.py, will be used by classmethods from_config_dict --- src/vak/config/validators.py | 41 +++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index f31778cf7..6a81f1912 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -1,5 +1,5 @@ """validators used by attrs-based classes and by vak.parse.parse_config""" -from pathlib import Path +import pathlib import tomlkit @@ -9,7 +9,7 @@ def is_a_directory(instance, attribute, value): """check if given path is a directory""" - if not Path(value).is_dir(): + if not pathlib.Path(value).is_dir(): raise NotADirectoryError( f"Value specified for {attribute.name} of {type(instance)} not recognized as a directory:\n" f"{value}" @@ -18,7 +18,7 @@ def is_a_directory(instance, attribute, value): def is_a_file(instance, attribute, value): """check if given path is a file""" - if not Path(value).is_file(): + if not pathlib.Path(value).is_file(): raise FileNotFoundError( f"Value specified for {attribute.name} of {type(instance)} not recognized as a file:\n" f"{value}" @@ -57,7 +57,7 @@ def is_spect_format(instance, attribute, value): ) -CONFIG_DIR = Path(__file__).parent +CONFIG_DIR = pathlib.Path(__file__).parent VALID_TOML_PATH = CONFIG_DIR.joinpath("valid.toml") with VALID_TOML_PATH.open("r") as fp: VALID_DICT = tomlkit.load(fp)['vak'] @@ -111,7 +111,11 @@ def are_tables_valid(config_dict, toml_path=None): raise ValueError(err_msg) -def are_options_valid(config_dict, table, toml_path=None): +def are_options_valid( + config_dict: dict, table: str, toml_path: str | pathlib.Path | None = None + ) -> None: + """Given a :class:`dict` containing the *entire* configuration loaded from a toml file, + validate the option names for a specific top-level table, e.g. ``vak.train`` or ``vak.predict``""" user_options = set(config_dict[table].keys()) valid_options = set(VALID_OPTIONS[table]) if not user_options.issubset(valid_options): @@ -127,3 +131,30 @@ def are_options_valid(config_dict, table, toml_path=None): f"the toml config are not valid:\n{invalid_options}" ) raise ValueError(err_msg) + + +def are_table_options_valid(table_config_dict: dict, table: str, toml_path: str | pathlib.Path | None = None) -> None: + """Given a :class:`dict` containing the configuration for a *specific* top-level table, + loaded from a toml file, validate the option names for that table, + e.g. ``vak.train`` or ``vak.predict``. + + This function assumes ``table_config_dict`` comes from the entire ``config_dict`` + returned by :func:`vak.config.parse.from_toml_path`, accessed using the table name as a key, + unlike :func:`are_options_valid`. This function is used by the ``from_config_dict`` + classmethod of the top-level tables. 
+
+    """
+    user_options = set(table_config_dict.keys())
+    valid_options = set(VALID_OPTIONS[table])
+    if not user_options.issubset(valid_options):
+        invalid_options = user_options - valid_options
+        if toml_path:
+            err_msg = (
+                f"The following options from '{table}' table in "
+                f"the config file '{toml_path.name}' are not valid:\n{invalid_options}"
+            )
+        else:
+            err_msg = (
+                f"The following options from '{table}' table in "
+                f"the toml config are not valid:\n{invalid_options}"
+            )
+        raise ValueError(err_msg)
From 2c272bc37434389cb988c5f7966a9566d5569d55 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Wed, 1 May 2024 09:56:29 -0400
Subject: [PATCH 028/183] WIP: Add from_config_dict classmethod to EvalConfig

---
 src/vak/config/eval.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py
index 8e8f802e3..bdb21e8f6 100644
--- a/src/vak/config/eval.py
+++ b/src/vak/config/eval.py
@@ -9,6 +9,8 @@
 from ..common import device
 from ..common.converters import expanded_user_path
+from .dataset import DatasetConfig
+from .model import ModelConfig
 from .validators import is_valid_model_name
 
 
@@ -172,3 +174,14 @@
         validator=validators.optional(instance_of(dict)),
         default=None,
     )
+
+    @classmethod
+    def from_config_dict(cls, config_dict: dict) -> EvalConfig:
+        """Return :class:`EvalConfig` instance from a :class:`dict`."""
+        # TODO: check for required keys here, including 'model' and 'dataset'
+        # use the `from_config_dict` classmethods so these become config classes
+        config_dict['model'] = ModelConfig.from_config_dict(config_dict['model'])
+        config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset'])
+        return cls(
+            **config_dict
+        )
\ No newline at end of file
From 53447f5f8fbf80a2478ae89493a747e86c5e06e2 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Wed, 1 May 2024 09:57:45 -0400
Subject: [PATCH 029/183] WIP: Add tests/test_config/test_dataset.py

---
 tests/test_config/test_dataset.py | 10 ++++++++++
 1 file changed, 10 insertions(+)
 create mode 100644 tests/test_config/test_dataset.py

diff --git a/tests/test_config/test_dataset.py b/tests/test_config/test_dataset.py
new file mode 100644
index 000000000..9b1a294fa
--- /dev/null
+++ b/tests/test_config/test_dataset.py
@@ -0,0 +1,10 @@
+import pytest
+
+import vak.config.dataset
+
+class TestDatasetConfig:
+    def test_init(self):
+        assert False
+
+    def test_from_config_dict(self, config_dict):
+        assert False
From a6cb67485d3d11933db6d7ed94e4a3928f8bfc5b Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Wed, 1 May 2024 19:57:37 -0400
Subject: [PATCH 030/183] Make fixes to ModelConfig class, fix circular imports in config/model.py

---
 src/vak/config/model.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/vak/config/model.py b/src/vak/config/model.py
index 399384a3d..307731951 100644
--- a/src/vak/config/model.py
+++ b/src/vak/config/model.py
@@ -4,9 +4,10 @@
 import pathlib
 
 from attrs import define, field
+from attrs.validators import instance_of
 
 from .. import models
-from . import parse, validators
+
 
 MODEL_TABLES = [
     "network",
@@ -36,10 +37,10 @@ class ModelConfig:
        metric names to keyword arguments. 
""" name: str - network: dict = field(validators=isinstance(dict)) - optimizer: dict = field(validators=isinstance(dict)) - loss: dict = field(validators=isinstance(dict)) - metrics: dict = field(validators=isinstance(dict)) + network: dict = field(validator=instance_of(dict)) + optimizer: dict = field(validator=instance_of(dict)) + loss: dict = field(validator=instance_of(dict)) + metrics: dict = field(validator=instance_of(dict)) @classmethod def from_config_dict(cls, config_dict: dict): @@ -126,6 +127,7 @@ def config_from_toml_dict(toml_dict: dict, table: str, model_name: str) -> dict: raise ValueError( f"Invalid model name: {model_name}.\nValid model names are: {models.registry.MODEL_NAMES}" ) + from . import validators # avoid circular import validators.are_tables_valid(toml_dict) try: @@ -169,5 +171,7 @@ def config_from_toml_path( as loaded from a .toml file, and used by the model method ``from_config``. """ + from . import parse # avoid circular import + toml_dict = parse._load_toml_from_path(toml_path) return config_from_toml_dict(toml_dict, table, model_name) From a039dd994f8a7393bcce24ca01067de4fea47417 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 19:57:44 -0400 Subject: [PATCH 031/183] Write tests in tests/test_config/test_dataset.py --- tests/test_config/test_dataset.py | 73 +++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 3 deletions(-) diff --git a/tests/test_config/test_dataset.py b/tests/test_config/test_dataset.py index 9b1a294fa..3ecff930e 100644 --- a/tests/test_config/test_dataset.py +++ b/tests/test_config/test_dataset.py @@ -1,10 +1,77 @@ +import pathlib + import pytest import vak.config.dataset + class TestDatasetConfig: - def test_init(self): - assert False + @pytest.mark.parametrize( + 'path, splits_path, name', + [ + # typical use by a user with default split + ('~/user/prepped/dataset', None, None), + # use by a user with a split specified + ('~/user/prepped/dataset', 'spilts/replicate-1.json', None), + # use of a built-in dataset, with a split specified + ('~/datasets/BioSoundSegBench', 'splits/Bengalese-Finch-song-gy6or6-replicate-1.json', 'BioSoundSegBench'), + ] + ) + def test_init(self, path, splits_path, name): + if name is None and splits_path is None: + dataset_config = vak.config.dataset.DatasetConfig( + path=path + ) + elif name is None: + dataset_config = vak.config.dataset.DatasetConfig( + path=path, + splits_path=splits_path, + ) + else: + dataset_config = vak.config.dataset.DatasetConfig( + name=name, + path=path, + splits_path=splits_path, + ) + assert isinstance(dataset_config, vak.config.dataset.DatasetConfig) + assert dataset_config.path == pathlib.Path(path) + if splits_path is not None: + assert dataset_config.splits_path == pathlib.Path(splits_path) + else: + assert dataset_config.splits_path is None + if name is not None: + assert dataset_config.name == name + else: + assert dataset_config.name is None + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'path' :'~/datasets/BioSoundSegBench', + 'splits_path': 'splits/Bengalese-Finch-song-gy6or6-replicate-1.json', + 'name': 'BioSoundSegBench', + }, + { + 'path' :'~/user/prepped/dataset', + }, + { + 'path' :'~/user/prepped/dataset', + 'splits_path': 'splits/replicate-1.json' + }, + ] + ) def test_from_config_dict(self, config_dict): - assert False + dataset_config = vak.config.dataset.DatasetConfig.from_config_dict(config_dict) + assert isinstance(dataset_config, vak.config.dataset.DatasetConfig) + assert dataset_config.path == 
pathlib.Path(config_dict['path']) + if 'splits_path' in config_dict: + assert dataset_config.splits_path == pathlib.Path(config_dict['splits_path']) + else: + assert dataset_config.splits_path is None + if 'name' in config_dict: + assert dataset_config.name == config_dict['name'] + else: + assert dataset_config.name is None + From d6588f0599558509d58162b6ec20e4a1320b6acf Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 20:06:51 -0400 Subject: [PATCH 032/183] Use tomlkit not toml in cli/prep.py --- src/vak/cli/prep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index 80b7384d0..e9e993e7d 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -5,7 +5,7 @@ import shutil import warnings -import toml +import tomlkit from .. import config from .. import prep as prep_module @@ -145,7 +145,7 @@ def prep(toml_path): config_toml[table]["dataset"]["path"] = str(dataset_path) with toml_path.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(config_toml, fp) # lastly, copy config to dataset directory root shutil.copy(src=toml_path, dst=dataset_path) From 53ef623e9124008d674d39ffc9bdf1ff927dcb33 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 21:34:03 -0400 Subject: [PATCH 033/183] Use tomlkit in tests/fixtures/annot.py --- tests/fixtures/annot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/annot.py b/tests/fixtures/annot.py index 40ba6dec0..9ccb5d637 100644 --- a/tests/fixtures/annot.py +++ b/tests/fixtures/annot.py @@ -1,7 +1,7 @@ """fixtures relating to annotation files""" import crowsetta import pytest -import toml +import tomlkit from .config import GENERATED_TEST_CONFIGS_ROOT @@ -75,7 +75,7 @@ def annot_list_notmat(): )[0] # get first config.toml from glob list # doesn't really matter which config, they all have labelset with a_train_notmat_config.open("r") as fp: - a_train_notmat_toml = toml.load(fp) + a_train_notmat_toml = tomlkit.load(fp) LABELSET_NOTMAT = a_train_notmat_toml["PREP"]["labelset"] From 1f80c19c8c8a3647d2c3057cb72a887eae212a53 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 21:35:23 -0400 Subject: [PATCH 034/183] Use tomlkit in tests/scripts/vaktestdata/configs.py --- tests/scripts/vaktestdata/configs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/scripts/vaktestdata/configs.py b/tests/scripts/vaktestdata/configs.py index cde39be49..1741e5b1b 100644 --- a/tests/scripts/vaktestdata/configs.py +++ b/tests/scripts/vaktestdata/configs.py @@ -3,8 +3,7 @@ import pathlib import shutil -# TODO: use tomli -import toml +import tomlkit import vak.cli.prep from . 
import constants @@ -70,7 +69,7 @@ def add_dataset_path_from_prepped_configs(): config_dataset_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.use_dataset_from_config with config_dataset_path.open("r") as fp: - dataset_config_toml = toml.load(fp) + dataset_config_toml = tomlkit.load(fp) purpose = vak.cli.prep.purpose_from_toml(dataset_config_toml) # next line, we can't use `section` here because we could get a KeyError, # e.g., when the config we are rewriting is an EVAL config, but @@ -80,10 +79,10 @@ def add_dataset_path_from_prepped_configs(): dataset_config_section = purpose.upper() # need to be 'TRAIN', not 'train' dataset_path = dataset_config_toml[dataset_config_section]['dataset_path'] with config_to_change_path.open("r") as fp: - config_to_change_toml = toml.load(fp) + config_to_change_toml = tomlkit.load(fp) config_to_change_toml[section]['dataset_path'] = dataset_path with config_to_change_path.open("w") as fp: - toml.dump(config_to_change_toml, fp) + tomlkit.dump(config_to_change_toml, fp) def fix_options_in_configs(config_metadata_list, command, single_train_result=True): @@ -104,7 +103,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # now use the config to find the results dir and get the values for the options we need to set # which are checkpoint_path, spect_scaler_path, and labelmap_path with config_to_use_result_from.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) root_results_dir = pathlib.Path(config_toml["TRAIN"]["root_results_dir"]) results_dir = sorted(root_results_dir.glob("results_*")) if len(results_dir) > 1: @@ -150,7 +149,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # now add these values to corresponding options in predict / eval config with config_to_fix.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) if command == 'train_continue': section = 'TRAIN' @@ -169,4 +168,4 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr config_toml[section]["labelmap_path"] = str(labelmap_path) with config_to_fix.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(config_toml, fp) From cd2d0dd5d72ada50ba78272c784e472b40154af6 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 21:35:33 -0400 Subject: [PATCH 035/183] Use tomlkit in tests/scripts/vaktestdata/source_files.py --- tests/scripts/vaktestdata/source_files.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/scripts/vaktestdata/source_files.py b/tests/scripts/vaktestdata/source_files.py index e53d0e2ee..4f72b134a 100644 --- a/tests/scripts/vaktestdata/source_files.py +++ b/tests/scripts/vaktestdata/source_files.py @@ -7,7 +7,7 @@ warnings.simplefilter('ignore', category=NumbaDeprecationWarning) import pandas as pd -import toml +import tomlkit import vak @@ -103,11 +103,11 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): ) with config_path.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) data_dir = constants.GENERATED_TEST_DATA_ROOT / config_metadata.data_dir config_toml['PREP']['data_dir'] = str(data_dir) with config_path.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(config_toml, fp) cfg = vak.config.parse.from_toml_path(config_path) From ad1bbf3355c100c5d3e2b1f5218bdc1ad3a0b82c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 21:36:35 -0400 Subject: [PATCH 036/183] Use tomlkit in 
tests/test_config/test_validators.py --- tests/test_config/test_validators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_config/test_validators.py b/tests/test_config/test_validators.py index 49ae7fad9..65771cc14 100644 --- a/tests/test_config/test_validators.py +++ b/tests/test_config/test_validators.py @@ -1,5 +1,5 @@ import pytest -import toml +import tomlkit import vak.config.validators @@ -7,7 +7,7 @@ def test_are_sections_valid(invalid_section_config_path): """test that invalid section name raises a ValueError""" with invalid_section_config_path.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) with pytest.raises(ValueError): vak.config.validators.are_sections_valid( config_toml, invalid_section_config_path @@ -18,7 +18,7 @@ def test_are_options_valid(invalid_option_config_path): """test that section with an invalid option name raises a ValueError""" section_with_invalid_option = "PREP" with invalid_option_config_path.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) with pytest.raises(ValueError): vak.config.validators.are_options_valid( config_toml, section_with_invalid_option, invalid_option_config_path From 249cf1665414ade6002932792c937d5c6950638a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 21:37:16 -0400 Subject: [PATCH 037/183] Remove spect_params attribute from Config in config/config.py, fix class' docstring --- src/vak/config/config.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/vak/config/config.py b/src/vak/config/config.py index 771ca40a4..330b83414 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -5,33 +5,26 @@ from .learncurve import LearncurveConfig from .predict import PredictConfig from .prep import PrepConfig -from .spect_params import SpectParamsConfig from .train import TrainConfig @attr.s class Config: - """class to represent config.toml file + """Class that represents a configuration file. Attributes ---------- prep : vak.config.prep.PrepConfig - represents ``[vak.prep]`` table of config.toml file - spect_params : vak.config.spect_params.SpectParamsConfig - represents ``[SPECT_PARAMS]`` table of config.toml file + Represents ``[vak.prep]`` table of config.toml file train : vak.config.train.TrainConfig - represents ``[vak.train]`` table of config.toml file + Represents ``[vak.train]`` table of config.toml file eval : vak.config.eval.EvalConfig - represents ``[vak.eval]`` table of config.toml file + Represents ``[vak.eval]`` table of config.toml file predict : vak.config.predict.PredictConfig - represents ``[vak.predict]`` table of config.toml file. + Represents ``[vak.predict]`` table of config.toml file. 
learncurve : vak.config.learncurve.LearncurveConfig - represents ``[vak.learncurve]`` table of config.toml file + Represents ``[vak.learncurve]`` table of config.toml file """ - - spect_params = attr.ib( - validator=instance_of(SpectParamsConfig), default=SpectParamsConfig() - ) prep = attr.ib(validator=optional(instance_of(PrepConfig)), default=None) train = attr.ib(validator=optional(instance_of(TrainConfig)), default=None) eval = attr.ib(validator=optional(instance_of(EvalConfig)), default=None) From 6654a0698fd4970d8cad1f1bca1014d92d499bc2 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 21:46:50 -0400 Subject: [PATCH 038/183] Reorder attributes, fix typo in docstring of DatasetConfig --- src/vak/config/dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index cc9f7eb63..60ed15ed1 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -13,28 +13,28 @@ class DatasetConfig: Attributes ---------- - name : str, optional - Name of dataset. Only required for built-in datasets - from the :mod:`~vak.datasets` module. path : pathlib.Path Path to the directory that contains the dataset. - Equivalent to the `root` parameter of :module:`torchvision` + Equivalent to the `root` parameter of :mod:`torchvision` datasets. splits_path : pathlib.Path, optional Path to file representing splits. + name : str, optional + Name of dataset. Only required for built-in datasets + from the :mod:`~vak.datasets` module. """ path: pathlib.Path = field(converter=pathlib.Path) - name: str | None = field( - converter=attr.converters.optional(str), default=None - ) splits_path: pathlib.Path | None = field( converter=attr.converters.optional(pathlib.Path), default=None ) + name: str | None = field( + converter=attr.converters.optional(str), default=None + ) @classmethod def from_config_dict(cls, dict_: dict) -> DatasetConfig: return cls( path=dict_.get('path'), + splits_path=dict_.get('splits_path'), name=dict_.get('name'), - splits_path=dict_.get('splits_path') ) From b2c8652a5069438a090decb59eb1695f8bba55ef Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:11:07 -0400 Subject: [PATCH 039/183] Rewrite config/parse.py assuming config classes have from_config_dict classmethod --- src/vak/config/parse.py | 105 +++++++++++++--------------------------- 1 file changed, 33 insertions(+), 72 deletions(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index 721893530..fdae90723 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -11,7 +11,6 @@ from .learncurve import LearncurveConfig from .predict import PredictConfig from .prep import PrepConfig -from .spect_params import SpectParamsConfig from .train import TrainConfig from .validators import are_options_valid, are_tables_valid @@ -21,7 +20,6 @@ "learncurve": LearncurveConfig, "predict": PredictConfig, "prep": PrepConfig, - "spect_params": SpectParamsConfig, "train": TrainConfig, } @@ -50,77 +48,40 @@ } -def parse_config_table(config_dict, table_name, toml_path=None): - """Parse table of config.toml file +def _validate_tables_to_parse_arg_convert_list(tables_to_parse: str | list[str]) -> list[str]: + """Helper function used by :func:`from_toml` that + validates the ``tables_to_parse`` argument, + and returns it as a list of strings.""" + if isinstance(tables_to_parse, str): + tables_to_parse = [tables_to_parse] - Parameters - ---------- - config_dict : dict - Containing config.toml file 
already loaded by parse function - table_name : str - Name of table from configuration - file that should be parsed. - toml_path : str - path to a configuration file in TOML format. Default is None. - Used for error messages if specified. - - Returns - ------- - config : vak.config table class - instance of class that represents table of config.toml file, - e.g. PredictConfig for 'PREDICT' table - """ - table = dict(config_dict[table_name].items()) - - required_options = REQUIRED_OPTIONS[table_name] - if required_options is not None: - for required_option in required_options: - if required_option not in table: - if toml_path: - err_msg = ( - f"the '{required_option}' option is required but was not found in the " - f"{table_name} table of the config.toml file: {toml_path}" - ) - else: - err_msg = ( - f"the '{required_option}' option is required but was not found in the " - f"{table_name} table of the toml config" - ) - raise KeyError(err_msg) - return TABLE_CLASSES[table_name](**table) - - -def _validate_tables_arg_convert_list(tables: str | list[str]) -> list[str]: - if isinstance(tables, str): - tables = [tables] - - if not isinstance(tables, list): + if not isinstance(tables_to_parse, list): raise TypeError( - f"`tables` should be a string or list of strings but type was: {type(tables)}" + f"`tables_to_parse` should be a string or list of strings but type was: {type(tables_to_parse)}" ) if not all( - [isinstance(table_name, str) for table_name in tables] + [isinstance(table_name, str) for table_name in tables_to_parse] ): raise ValueError( - "All table names in 'tables' should be strings" + "All table names in 'tables_to_parse' should be strings" ) if not all( [ table_name in list(TABLE_CLASSES.keys()) - for table_name in tables + for table_name in tables_to_parse ] ): raise ValueError( - "All table names in 'tables' should be valid names of tables. " - f"Values for 'tables were: {tables}.\n" + "All table names in 'tables_to_parse' should be valid names of tables. " + f"Values for 'tables were: {tables_to_parse}.\n" f"Valid table names are: {list(TABLE_CLASSES.keys())}" ) - return tables + return tables_to_parse def from_toml( - config_dict: dict, toml_path: str | pathlib.Path | None = None, tables: str | list[str] | None = None + config_dict: dict, toml_path: str | pathlib.Path | None = None, tables_to_parse: str | list[str] | None = None ) -> Config: """Load a TOML configuration file. @@ -132,8 +93,8 @@ def from_toml( toml_path : str, pathlib.Path path to a configuration file in TOML format. Default is None. Not required, used only to make any error messages clearer. - tables : str, list - Name of table or tables from configuration + tables_to_parse : str, list + Name of top-level table or tables from configuration file that should be parsed. Can be a string (single table) or list of strings (multiple tables). Default is None, @@ -146,22 +107,22 @@ def from_toml( tables in a config.toml file. 
""" are_tables_valid(config_dict, toml_path) - tables = _validate_tables_arg_convert_list(tables) + if tables_to_parse is None: + tables_to_parse = list( + TABLE_CLASSES.keys() + ) # i.e., parse all tables + else: + tables_to_parse = _validate_tables_to_parse_arg_convert_list(tables_to_parse) config_kwargs = {} - if tables is None: - tables = list( - TABLE_CLASSES.keys() - ) # i.e., parse all tables, except model - for table_name in tables: + for table_name in tables_to_parse: if table_name in config_dict: are_options_valid(config_dict, table_name, toml_path) - config_kwargs[table_name.lower()] = parse_config_table( - config_dict, table_name, toml_path - ) + table_config_dict = config_dict[table_name] + config_kwargs[table_name] = TABLE_CLASSES[table_name].from_config_dict(table_config_dict) else: raise KeyError( - f"A table specified in `tables` was not found in the config: {table_name}" + f"A table specified in `tables_to_parse` was not found in the config: {table_name}" ) return Config(**config_kwargs) @@ -195,18 +156,18 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: return config_dict['vak'] -def from_toml_path(toml_path: str | pathlib.Path, tables: list[str] | None = None) -> Config: +def from_toml_path(toml_path: str | pathlib.Path, tables_to_parse: list[str] | None = None) -> Config: """Parse a TOML configuration file and return as a :class:`Config`. Parameters ---------- toml_path : str, pathlib.Path - path to a configuration file in TOML format. + Path to a configuration file in TOML format. Parsed by ``toml`` library, then converted to an instance of ``vak.config.parse.Config`` by calling ``vak.parse.from_toml`` - tables : str, list - name of table or tables from configuration + tables_to_parse : str, list + Name of table or tables from configuration file that should be parsed. Can be a string (single table) or list of strings (multiple tables). Default is None, @@ -218,5 +179,5 @@ def from_toml_path(toml_path: str | pathlib.Path, tables: list[str] | None = Non instance of :class:`Config` class, whose attributes correspond to tables in a config.toml file. """ - config_dict = _load_toml_from_path(toml_path) - return from_toml(config_dict, toml_path, tables) + config_dict: dict = _load_toml_from_path(toml_path) + return from_toml(config_dict, toml_path, tables_to_parse) From 06804817c180bed8f893c784c74e607b5ba97f43 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:11:56 -0400 Subject: [PATCH 040/183] Rename `table` -> `table_name` in a couple validators in config/validators.py --- src/vak/config/validators.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 6a81f1912..7439d80aa 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -112,28 +112,28 @@ def are_tables_valid(config_dict, toml_path=None): def are_options_valid( - config_dict: dict, table: str, toml_path: str | pathlib.Path | None = None + config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None ) -> None: """Given a :class:`dict` containing the *entire* configuration loaded from a toml file, validate the option names for a specific top-level table, e.g. 
``vak.train`` or ``vak.predict``""" - user_options = set(config_dict[table].keys()) - valid_options = set(VALID_OPTIONS[table]) + user_options = set(config_dict[table_name].keys()) + valid_options = set(VALID_OPTIONS[table_name]) if not user_options.issubset(valid_options): invalid_options = user_options - valid_options if toml_path: err_msg = ( - f"The following options from '{table}' table in " + f"The following options from '{table_name}' table in " f"the config file '{toml_path.name}' are not valid:\n{invalid_options}" ) else: err_msg = ( - f"The following options from '{table}' table in " + f"The following options from '{table_name}' table in " f"the toml config are not valid:\n{invalid_options}" ) raise ValueError(err_msg) -def are_table_options_valid(table_config_dict: dict, table: str, toml_path: str | pathlib.Path | None = None) -> None: +def are_table_options_valid(table_config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None) -> None: """Given a :class:`dict` containing the configuration for a *specific* top-level table, loaded from a toml file, validate the option names for that table, e.g. ``vak.train`` or ``vak.predict``. @@ -144,17 +144,17 @@ def are_table_options_valid(table_config_dict: dict, table: str, toml_path: str classmethod of the top-level tables. """ user_options = set(table_config_dict.keys()) - valid_options = set(VALID_OPTIONS[table]) + valid_options = set(VALID_OPTIONS[table_name]) if not user_options.issubset(valid_options): invalid_options = user_options - valid_options if toml_path: err_msg = ( - f"The following options from '{table}' table in " + f"The following options from '{table_name}' table in " f"the config file '{toml_path.name}' are not valid:\n{invalid_options}" ) else: err_msg = ( - f"The following options from '{table}' table in " + f"The following options from '{table_name}' table in " f"the toml config are not valid:\n{invalid_options}" ) raise ValueError(err_msg) From 9cbf58a44d493f8e9cb4aec17d0768ddb1cd03e0 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:52:54 -0400 Subject: [PATCH 041/183] Remove use of config.model.config_from_toml_path in cli/eval.py --- src/vak/cli/eval.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 329bf38b0..fd9f96b76 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -37,9 +37,6 @@ def eval(toml_path): logger.info("Logging results to {}".format(cfg.eval.output_dir)) - model_name = cfg.eval.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - if cfg.eval.dataset_path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." 
@@ -48,8 +45,8 @@ def eval(toml_path): ) eval_module.eval( - model_name=model_name, - model_config=model_config, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model, dataset_path=cfg.eval.dataset_path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, From eeef546af9952e62c194a4975719d7b032805be4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:53:03 -0400 Subject: [PATCH 042/183] Remove use of config.model.config_from_toml_path in cli/learncurve.py --- src/vak/cli/learncurve.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 33c293a61..20d613bba 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -45,9 +45,6 @@ def learning_curve(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - model_name = cfg.learncurve.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - if cfg.learncurve.dataset_path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." @@ -56,8 +53,8 @@ def learning_curve(toml_path): ) learncurve.learning_curve( - model_name=model_name, - model_config=model_config, + model_name=cfg.learncurve.model.name, + model_config=cfg.learncurve.model, dataset_path=cfg.learncurve.dataset_path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, From 0a5eb9611361da4a9d87024c393e8bfeb3e60706 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:53:13 -0400 Subject: [PATCH 043/183] Remove use of config.model.config_from_toml_path in cli/predict.py --- src/vak/cli/predict.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 38701b87f..85f600ef9 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -35,9 +35,6 @@ def predict(toml_path): log_version(logger) logger.info("Logging results to {}".format(cfg.prep.output_dir)) - model_name = cfg.predict.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - if cfg.predict.dataset_path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." @@ -46,8 +43,8 @@ def predict(toml_path): ) predict_module.predict( - model_name=model_name, - model_config=model_config, + model_name=cfg.predict.model.name, + model_config=cfg.predict.model, dataset_path=cfg.predict.dataset_path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, From f8ebc2080730ec41c3fc595a95c47be7b0690b48 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:53:25 -0400 Subject: [PATCH 044/183] Remove use of config.model.config_from_toml_path in cli/train.py --- src/vak/cli/train.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 91c89bb95..6b93db32b 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -45,9 +45,6 @@ def train(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - model_name = cfg.train.model - model_config = config.model.config_from_toml_path(toml_path, model_name) - if cfg.train.dataset_path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." 
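Commits 041 through 044 make the same two-line change in each CLI function: instead of re-loading the toml file with `config.model.config_from_toml_path`, they read the model name and model config off the `Config` object that was already parsed. A minimal sketch of the access pattern, assuming the attrs-based classes this series builds up (the file name here is illustrative, not a real path):

```python
import vak

cfg = vak.config.parse.from_toml_path("gy6or6_train.toml")  # hypothetical config file

# before: a second pass over the same toml file
# model_name = cfg.train.model  # just a string
# model_config = config.model.config_from_toml_path(toml_path, model_name)

# after commits 041-044: one parse, then plain attribute access
model_name = cfg.train.model.name  # ModelConfig.name, a str
model_config = cfg.train.model     # the ModelConfig instance itself
```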
@@ -56,8 +53,8 @@ def train(toml_path): ) train_module.train( - model_name=model_name, - model_config=model_config, + model_name=cfg.train.model.name, + model_config=cfg.train.model, dataset_path=cfg.train.dataset_path, train_transform_params=cfg.train.train_transform_params, train_dataset_params=cfg.train.train_dataset_params, From cebd8d8476f8cc420b491c1b723d0f174c5257ef Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 22:54:05 -0400 Subject: [PATCH 045/183] Remove functions from config/model.py: config_from_toml_path and config_from_toml_dict --- src/vak/config/model.py | 77 ----------------------------------------- 1 file changed, 77 deletions(-) diff --git a/src/vak/config/model.py b/src/vak/config/model.py index 307731951..ae0d6d45e 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -98,80 +98,3 @@ def from_config_dict(cls, config_dict: dict): name=model_name, **model_config ) - - -def config_from_toml_dict(toml_dict: dict, table: str, model_name: str) -> dict: - """Get configuration for a model from a .toml configuration file - loaded into a ``dict``. - - Parameters - ---------- - toml_dict : dict - Configuration from a .toml file, loaded into a dictionary. - table : str - Name of top-level table to get model config from. - model_name : str - Name of a model, specified as the ``model`` option in a table - (such as TRAIN or PREDICT), - that should have its own corresponding table - specifying its configuration: hyperparameters such as learning rate, etc. - - Returns - ------- - model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - """ - if model_name not in models.registry.MODEL_NAMES: - raise ValueError( - f"Invalid model name: {model_name}.\nValid model names are: {models.registry.MODEL_NAMES}" - ) - from . import validators # avoid circular import - validators.are_tables_valid(toml_dict) - - try: - model_config = toml_dict[table][model_name] - except KeyError as e: - raise ValueError( - f"A config section specifies the model name '{model_name}', " - f"but there is no section named '{model_name}' in the '{table}' table of the config." - ) from e - - # check if config declares parameters for required attributes; - # if not, just put an empty dict that will get passed as the "kwargs" - for attr in MODEL_TABLES: - if attr not in model_config: - model_config[attr] = {} - - return model_config - - -def config_from_toml_path( - toml_path: str | pathlib.Path, table: str, model_name: str -) -> dict: - """Get configuration for a model from a .toml configuration file, - given the path to the file. - - Parameters - ---------- - toml_path : str, Path - to configuration file in .toml format - table : str - Name of top-level table to get model config from. - model_name : str - of str, i.e. names of models specified by a section - (such as TRAIN or PREDICT) that should each have corresponding sections - specifying their configuration: hyperparameters such as learning rate, etc. - - Returns - ------- - model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - """ - from . 
import parse # avoid circular import - - toml_dict = parse._load_toml_from_path(toml_path) - return config_from_toml_dict(toml_dict, table, model_name) From 7469f8da23e5e4a768ffd2c3436c2fa6815c80fc Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 23:03:45 -0400 Subject: [PATCH 046/183] Add `to_dict` method to ModelConfig --- src/vak/config/model.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/vak/config/model.py b/src/vak/config/model.py index ae0d6d45e..c1a4ce0cf 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -98,3 +98,19 @@ def from_config_dict(cls, config_dict: dict): name=model_name, **model_config ) + + def to_dict(self): + """Convert this :class:`ModelConfig` instance + to a :class:`dict` that can be passed + into functions that take a ``model_config`` argument, + like :func:`vak.train` and :func:`vak.predict`. + + This function drops the ``name`` attribute, + and returns all other attributes in a :class:`dict`. + """ + return { + 'network': self.network, + 'optimizer': self.optimizer, + 'loss': self.loss, + 'metrics': self.metrics, + } \ No newline at end of file From c0c11e4eae6e3e8ff67e5a66b98ebe72b04145e9 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 1 May 2024 23:06:06 -0400 Subject: [PATCH 047/183] Use to_dict() method of ModelConfig class in cli functions --- src/vak/cli/eval.py | 4 +++- src/vak/cli/learncurve.py | 2 +- src/vak/cli/predict.py | 2 +- src/vak/cli/train.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index fd9f96b76..5914bb41e 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -1,6 +1,8 @@ import logging from pathlib import Path +import attrs + from .. import config from .. 
import eval as eval_module from ..common.logging import config_logging_for_cli, log_version @@ -46,7 +48,7 @@ def eval(toml_path): eval_module.eval( model_name=cfg.eval.model.name, - model_config=cfg.eval.model, + model_config=cfg.eval.model.to_dict(), dataset_path=cfg.eval.dataset_path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 20d613bba..e68f54651 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -54,7 +54,7 @@ def learning_curve(toml_path): learncurve.learning_curve( model_name=cfg.learncurve.model.name, - model_config=cfg.learncurve.model, + model_config=cfg.learncurve.model.to_dict(), dataset_path=cfg.learncurve.dataset_path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 85f600ef9..a33c4955e 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -44,7 +44,7 @@ def predict(toml_path): predict_module.predict( model_name=cfg.predict.model.name, - model_config=cfg.predict.model, + model_config=cfg.predict.model.to_dict(), dataset_path=cfg.predict.dataset_path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 6b93db32b..cedac92bb 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -54,7 +54,7 @@ def train(toml_path): train_module.train( model_name=cfg.train.model.name, - model_config=cfg.train.model, + model_config=cfg.train.model.to_dict(), dataset_path=cfg.train.dataset_path, train_transform_params=cfg.train.train_transform_params, train_dataset_params=cfg.train.train_dataset_params, From e87b5ed63d5ade835956410bf56a54f3b0ca1fcf Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 08:56:29 -0400 Subject: [PATCH 048/183] Fix how we get labelset from config in tests/fixtures/annot.py --- tests/fixtures/annot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fixtures/annot.py b/tests/fixtures/annot.py index 9ccb5d637..c69c31ca7 100644 --- a/tests/fixtures/annot.py +++ b/tests/fixtures/annot.py @@ -76,7 +76,7 @@ def annot_list_notmat(): # doesn't really matter which config, they all have labelset with a_train_notmat_config.open("r") as fp: a_train_notmat_toml = tomlkit.load(fp) -LABELSET_NOTMAT = a_train_notmat_toml["PREP"]["labelset"] +LABELSET_NOTMAT = a_train_notmat_toml["vak"]["prep"]["labelset"] @pytest.fixture @@ -135,4 +135,5 @@ def annotated_annot_no_segments(request): Used to test edge case for `has_unlabeled`, see https://github.com/vocalpy/vak/issues/378 """ + return request.param From 07dd40e940945d1af38b2bcb8628140a76891fb7 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 08:56:51 -0400 Subject: [PATCH 049/183] WIP: Clean up / rewrite tests/fixtures/config.py --- tests/fixtures/config.py | 166 ++++++++++++++------------------------- 1 file changed, 58 insertions(+), 108 deletions(-) diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index dae6e50f4..92ceb7f5c 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -3,7 +3,7 @@ import shutil import pytest -import toml +import tomlkit from .test_data import GENERATED_TEST_DATA_ROOT, TEST_DATA_ROOT @@ -78,15 +78,6 @@ def generated_test_configs_root(): return GENERATED_TEST_CONFIGS_ROOT -ALL_GENERATED_CONFIGS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("*toml")) - - -# ---- path to config 
files ---- -@pytest.fixture -def all_generated_configs(): - return ALL_GENERATED_CONFIGS - - @pytest.fixture def specific_config_toml_path(generated_test_configs_root, list_of_schematized_configs, tmp_path): """returns a factory function @@ -177,7 +168,7 @@ def _specific_config( ) with config_copy_path.open("r") as fp: - config_toml = toml.load(fp) + config_toml = tomlkit.load(fp) for opt_dict in options_to_change: if opt_dict["value"] == 'DELETE-OPTION': @@ -187,13 +178,22 @@ def _specific_config( config_toml[opt_dict["section"]][opt_dict["option"]] = opt_dict["value"] with config_copy_path.open("w") as fp: - toml.dump(config_toml, fp) + tomlkit.dump(config_toml, fp) return config_copy_path return _specific_config +ALL_GENERATED_CONFIG_PATHS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("*toml")) + + +# ---- path to config files ---- +@pytest.fixture(params=ALL_GENERATED_CONFIG_PATHS) +def a_generated_config_path(request): + return request.param + + @pytest.fixture def all_generated_train_configs(generated_test_configs_root): return sorted(generated_test_configs_root.glob("test_train*toml")) @@ -206,9 +206,18 @@ def all_generated_learncurve_configs(generated_test_configs_root): return ALL_GENERATED_LEARNCURVE_CONFIGS +ALL_GENERATED_EVAL_CONFIG_PATHS = sorted( + GENERATED_TEST_CONFIGS_ROOT.glob("test_eval*toml") +) + + @pytest.fixture -def all_generated_eval_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("test_eval*toml")) +def all_generated_eval_configs(): + return ALL_GENERATED_EVAL_CONFIG_PATHS + +@pytest.fixture(params=ALL_GENERATED_EVAL_CONFIG_PATHS) +def a_generated_eval_config_toml(request): + return request.param @pytest.fixture @@ -216,13 +225,17 @@ def all_generated_predict_configs(generated_test_configs_root): return sorted(generated_test_configs_root.glob("test_predict*toml")) -# ---- config toml from paths ---- -def _return_toml(toml_path): - """return config files loaded into dicts with toml library - used to test functions that parse config sections, taking these dicts as inputs""" +# ---- config dicts from paths ---- +def _load_config_dict(toml_path): + """Return config as dict, loaded from toml file. + + Used to test functions that parse config sections, taking these dicts as inputs. 
+ + Note that we access the topmost table loaded from the toml: config_dict['vak'] + """ with toml_path.open("r") as fp: - config_toml = toml.load(fp) - return config_toml + config_dict = tomlkit.load(fp) + return config_dict['vak'] @pytest.fixture @@ -244,60 +257,48 @@ def _specific_config_toml( config_path = specific_config_toml_path( config_type, model, annot_format, audio_format, spect_format ) - return _return_toml(config_path) + return _load_config_dict(config_path) return _specific_config_toml -ALL_GENERATED_CONFIGS_TOML = [_return_toml(config) for config in ALL_GENERATED_CONFIGS] +ALL_GENERATED_CONFIG_DICTS = [ + _load_config_dict(config) + for config in ALL_GENERATED_CONFIG_PATHS +] +@pytest.fixture(params=ALL_GENERATED_CONFIG_DICTS) +def a_generated_config_dict(request): + return request.param -@pytest.fixture -def all_generated_configs_toml(): - return ALL_GENERATED_CONFIGS_TOML - - -@pytest.fixture -def all_generated_train_configs_toml(all_generated_train_configs): - return [_return_toml(config) for config in all_generated_train_configs] @pytest.fixture def all_generated_learncurve_configs_toml(all_generated_learncurve_configs): - return [_return_toml(config) for config in all_generated_learncurve_configs] - - -@pytest.fixture -def all_generated_eval_configs_toml(all_generated_eval_configs): - return [_return_toml(config) for config in all_generated_eval_configs] - - -@pytest.fixture -def all_generated_predict_configs_toml(all_generated_predict_configs): - return [_return_toml(config) for config in all_generated_predict_configs] + return [_load_config_dict(config) for config in all_generated_learncurve_configs] ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS = list(zip( - [_return_toml(config) for config in ALL_GENERATED_CONFIGS], - ALL_GENERATED_CONFIGS, + [_load_config_dict(config) for config in ALL_GENERATED_CONFIG_PATHS], + ALL_GENERATED_CONFIG_PATHS, )) # ---- config toml + path pairs ---- -@pytest.fixture -def all_generated_configs_toml_path_pairs(): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - # we duplicate the constant above because we need to remake - # the variables for each unit test. Otherwise tests that modify values - # for config options cause other tests to fail - return zip( - [_return_toml(config) for config in ALL_GENERATED_CONFIGS], - ALL_GENERATED_CONFIGS - ) +# @pytest.fixture +# def all_generated_configs_toml_path_pairs(): +# """zip of tuple pairs: (dict, pathlib.Path) +# where ``Path`` is path to .toml config file and ``dict`` is +# the .toml config from that path +# loaded into a dict with the ``toml`` library +# """ +# # we duplicate the constant above because we need to remake +# # the variables for each unit test. 
Otherwise tests that modify values +# # for config options cause other tests to fail +# return zip( +# [_load_config_dict(config) for config in ALL_GENERATED_CONFIGS], +# ALL_GENERATED_CONFIGS +# ) @pytest.fixture @@ -325,54 +326,3 @@ def _wrapped(model, return _wrapped - -@pytest.fixture -def all_generated_train_configs_toml_path_pairs(all_generated_train_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_train_configs], - all_generated_train_configs, - ) - - -@pytest.fixture -def all_generated_learncurve_configs_toml_path_pairs(all_generated_learncurve_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_learncurve_configs], - all_generated_learncurve_configs, - ) - - -@pytest.fixture -def all_generated_eval_configs_toml_path_pairs(all_generated_eval_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_eval_configs], - all_generated_eval_configs, - ) - - -@pytest.fixture -def all_generated_predict_configs_toml_path_pairs(all_generated_predict_configs): - """zip of tuple pairs: (dict, pathlib.Path) - where ``Path`` is path to .toml config file and ``dict`` is - the .toml config from that path - loaded into a dict with the ``toml`` library - """ - return zip( - [_return_toml(config) for config in all_generated_predict_configs], - all_generated_predict_configs, - ) From 22aa552aa530f778f0e2a1d419b3fcbc1c7c7290 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 10:29:43 -0400 Subject: [PATCH 050/183] Fix model tables in tests/data_for_tests/configs --- .../ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml | 5 ++--- .../configs/TweetyNet_eval_audio_cbin_annot_notmat.toml | 5 ++--- .../TweetyNet_learncurve_audio_cbin_annot_notmat.toml | 5 ++--- .../configs/TweetyNet_predict_audio_cbin_annot_notmat.toml | 5 ++--- .../configs/TweetyNet_train_audio_cbin_annot_notmat.toml | 5 ++--- .../TweetyNet_train_continue_audio_cbin_annot_notmat.toml | 5 ++--- .../TweetyNet_train_continue_spect_mat_annot_yarden.toml | 5 ++--- .../configs/TweetyNet_train_spect_mat_annot_yarden.toml | 5 ++--- tests/data_for_tests/configs/invalid_option_config.toml | 3 +-- ...nvalid_section_config.toml => invalid_table_config.toml} | 6 +++--- .../configs/invalid_train_and_learncurve_config.toml | 4 +--- 11 files changed, 21 insertions(+), 32 deletions(-) rename tests/data_for_tests/configs/{invalid_section_config.toml => invalid_table_config.toml} (81%) diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml index fef7afeb4..8be5a4d3a 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml @@ -16,7 +16,6 @@ step_size = 32 transform_type = "log_spect_plus_one" [vak.train] -model = "ConvEncoderUMAP" batch_size = 64 num_epochs = 1 
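The config fixes in this commit all follow the same recipe: drop the `model = "..."` key and nest the model's tables under the command table using dotted keys. A sketch of what the new layout looks like to a parser, using a hypothetical fragment and `tomlkit` as elsewhere in this series; with the `model = "..."` key gone, the model name is recovered as the single key under the `model` table, which is what `ModelConfig.from_config_dict` expects:

```python
import tomlkit

# hypothetical fragment in the new layout
toml_str = """
[vak.train.model.ConvEncoderUMAP.optimizer]
lr = 0.001
"""
config_dict = tomlkit.loads(toml_str)

model_table = config_dict["vak"]["train"]["model"]
model_name = list(model_table.keys())[0]
assert model_name == "ConvEncoderUMAP"
assert model_table[model_name]["optimizer"]["lr"] == 0.001
```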
val_step = 1 @@ -25,7 +24,7 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP" -[ConvEncoderUMAP.network] +[vak.train.model.ConvEncoderUMAP.network] conv1_filters = 8 conv2_filters = 16 conv_kernel_size = 3 @@ -34,5 +33,5 @@ conv_padding = 1 n_features_linear = 32 n_components = 2 -[ConvEncoderUMAP.optimizer] +[vak.train.model.ConvEncoderUMAP.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml index 4af0d82d4..d5975034b 100644 --- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml @@ -17,7 +17,6 @@ transform_type = "log_spect" [vak.eval] checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json" -model = "TweetyNet" batch_size = 11 num_workers = 16 device = "cuda" @@ -31,7 +30,7 @@ min_segment_dur = 0.02 [vak.eval.transform_params] window_size = 88 -[TweetyNet.network] +[vak.eval.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -42,5 +41,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.eval.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml index 9a909bf0a..9d0c5b3b8 100644 --- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml @@ -20,7 +20,6 @@ thresh = 6.25 transform_type = "log_spect" [vak.learncurve] -model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -41,7 +40,7 @@ window_size = 88 [vak.learncurve.val_transform_params] window_size = 88 -[TweetyNet.network] +[vak.learncurve.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -52,5 +51,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.learncurve.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml index b2db522f0..3c83d2826 100644 --- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml @@ -16,7 +16,6 @@ transform_type = "log_spect" spect_scaler_path = "/home/user/results_181014_194418/spect_scaler" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" labelmap_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json" -model = "TweetyNet" batch_size = 11 num_workers = 16 device = "cuda" @@ -26,7 +25,7 @@ annot_csv_filename = "bl26lb16.041912.annot.csv" [vak.predict.transform_params] window_size = 88 -[TweetyNet.network] +[vak.predict.model.TweetyNet.network] conv1_filters = 8 
conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -37,5 +36,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.predict.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml index 7e06b4a86..aba76e566 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml @@ -18,7 +18,6 @@ thresh = 6.25 transform_type = "log_spect" [vak.train] -model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -35,7 +34,7 @@ window_size = 88 [vak.train.val_transform_params] window_size = 88 -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -46,5 +45,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml index 97c602a5d..754a8ac4e 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml @@ -18,7 +18,6 @@ thresh = 6.25 transform_type = "log_spect" [vak.train] -model = "TweetyNet" normalize_spectrograms = true batch_size = 11 num_epochs = 2 @@ -37,7 +36,7 @@ window_size = 88 [vak.train.val_transform_params] window_size = 88 -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -48,5 +47,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml index c2826251f..59f94124b 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml @@ -18,7 +18,6 @@ thresh = 6.25 transform_type = "log_spect" [vak.train] -model = "TweetyNet" normalize_spectrograms = false batch_size = 11 num_epochs = 2 @@ -36,7 +35,7 @@ window_size = 88 [vak.train.val_transform_params] window_size = 88 -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 conv1_kernel_size = [3, 3] conv2_filters = 16 @@ -47,5 +46,5 @@ pool2_size = [4, 1] pool2_stride = [4, 1] hidden_size = 32 -[TweetyNet.optimizer] +[vak.train.model.TweetyNet.optimizer] lr = 0.001 diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml index 9ed05126e..588f2dd51 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml @@ -18,7 +18,6 @@ thresh = 6.25 transform_type = "log_spect" [vak.train] -model = "TweetyNet" normalize_spectrograms = false batch_size = 11 num_epochs = 2 @@ -35,7 +34,7 @@ window_size = 88 [vak.train.val_transform_params] window_size = 88 -[TweetyNet.network] +[vak.train.model.TweetyNet.network] conv1_filters = 8 
conv1_kernel_size = [3, 3]
 conv2_filters = 16
@@ -46,5 +45,5 @@ pool2_size = [4, 1]
 pool2_stride = [4, 1]
 hidden_size = 32
 
-[TweetyNet.optimizer]
+[vak.train.model.TweetyNet.optimizer]
 lr = 0.001
diff --git a/tests/data_for_tests/configs/invalid_option_config.toml b/tests/data_for_tests/configs/invalid_option_config.toml
index 469e55435..55aa334f1 100644
--- a/tests/data_for_tests/configs/invalid_option_config.toml
+++ b/tests/data_for_tests/configs/invalid_option_config.toml
@@ -21,7 +21,6 @@ thresh = 6.25
 transform_type = 'log_spect'
 
 [vak.train]
-model = 'TweetyNet'
 root_results_dir = '/home/user/data/subdir/'
 normalize_spectrograms = true
 num_epochs = 2
@@ -33,5 +32,5 @@ save_only_single_checkpoint_file = true
 [vak.train.dataset_params]
 window_size = 88
 
-[TweetyNet.optimizer]
+[vak.train.model.TweetyNet.optimizer]
 learning_rate = 0.001
diff --git a/tests/data_for_tests/configs/invalid_section_config.toml b/tests/data_for_tests/configs/invalid_table_config.toml
similarity index 81%
rename from tests/data_for_tests/configs/invalid_section_config.toml
rename to tests/data_for_tests/configs/invalid_table_config.toml
index 517517d21..daf0d4e0d 100644
--- a/tests/data_for_tests/configs/invalid_section_config.toml
+++ b/tests/data_for_tests/configs/invalid_table_config.toml
@@ -13,14 +13,14 @@ train_dur = 10
 val_dur = 5
 test_dur = 10
 
-[SPECTROGRAM]
+[vak.prep.spect_params]
 fft_size=512
 step_size=64
 freq_cutoffs = [500, 10000]
 thresh = 6.25
 transform_type = 'log_spect'
 
-[TRIAN] # <-- invalid section 'TRIAN' (instead of 'vak.train')
+[vak.trian] # <-- invalid table 'vak.trian' (instead of 'vak.train')
 model = 'TweetyNet'
 root_results_dir = '/home/user/data/subdir/'
 normalize_spectrograms = true
@@ -30,5 +30,5 @@ val_error_step = 1
 checkpoint_step = 1
 save_only_single_checkpoint_file = true
 
-[TweetyNet.optimizer]
+[vak.train.model.TweetyNet.optimizer]
 learning_rate = 0.001
diff --git a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml
index 29e12ee2e..a4fcd542d 100644
--- a/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml
+++ b/tests/data_for_tests/configs/invalid_train_and_learncurve_config.toml
@@ -20,7 +20,6 @@ transform_type = "log_spect"
 # this .toml file should cause 'vak.config.parse.from_toml' to raise a ValueError
 # because it defines both a vak.train and a vak.learncurve section
 [vak.train]
-model = "TweetyNet"
 normalize_spectrograms = true
 batch_size = 11
 num_epochs = 2
@@ -32,7 +31,6 @@ device = "cuda"
 root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat"
 
 [vak.learncurve]
-model = 'TweetyNet'
 normalize_spectrograms = true
 batch_size = 11
 num_epochs = 2
@@ -45,5 +43,5 @@ num_replicates = 2
 device = "cuda"
 root_results_dir = './tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat'
 
-[TweetyNet.optimizer]
+[vak.learncurve.model.TweetyNet.optimizer]
 lr = 0.001

From 10823a53b2961abf41c63666f4136a876a3ac499 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 10:30:49 -0400
Subject: [PATCH 051/183] Finish unit tests in tests/test_config/test_model.py

---
 tests/test_config/test_model.py | 149 ++++++++++++++++++++------------
 1 file changed, 96 insertions(+), 53 deletions(-)

diff --git a/tests/test_config/test_model.py b/tests/test_config/test_model.py
index 7f5e56c4b..881e31f18 100644
--- a/tests/test_config/test_model.py
+++ b/tests/test_config/test_model.py
@@ -1,62 +1,105 @@
-import copy
 import pytest
 
-from
..fixtures import ( - ALL_GENERATED_CONFIGS_TOML, - ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS -) - import vak.config.model -def _make_expected_config(model_config: dict) -> dict: - for attr in vak.config.model.MODEL_TABLES: - if attr not in model_config: - model_config[attr] = {} - return model_config - - -@pytest.mark.parametrize( - 'toml_dict', - ALL_GENERATED_CONFIGS_TOML -) -def test_config_from_toml_dict(toml_dict): - for section_name in ('TRAIN', 'EVAL', 'LEARNCURVE', 'PREDICT'): - try: - section = toml_dict[section_name] - except KeyError: - continue - model_name = section['model'] - # we need to copy so that we don't silently fail to detect mistakes - # by comparing a reference to the dict with itself - expected_model_config = copy.deepcopy( - toml_dict[model_name] +class TestModelConfig: + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'NonExistentModel': { + 'network': {}, + 'optimizer': {}, + 'loss': {}, + 'metrics': {}, + } + }, + { + 'TweetyNet': { + 'network': {}, + 'optimizer': {'lr': 1e-3}, + 'loss': {}, + 'metrics': {}, + } + }, + ] ) - expected_model_config = _make_expected_config(expected_model_config) - - model_config = vak.config.model.config_from_toml_dict(toml_dict, model_name) - - assert model_config == expected_model_config - - -@pytest.mark.parametrize( - 'toml_dict, toml_path', - ALL_GENERATED_CONFIGS_TOML_PATH_PAIRS -) -def test_config_from_toml_path(toml_dict, toml_path): - for section_name in ('TRAIN', 'EVAL', 'LEARNCURVE', 'PREDICT'): - try: - section = toml_dict[section_name] - except KeyError: - continue - model_name = section['model'] - # we need to copy so that we don't silently fail to detect mistakes - # by comparing a reference to the dict with itself - expected_model_config = copy.deepcopy( - toml_dict[model_name] + def test_init(self, config_dict): + name=list(config_dict.keys())[0] + config_dict_from_name = config_dict[name] + + model_config = vak.config.model.ModelConfig( + name=name, + **config_dict_from_name + ) + + assert isinstance(model_config, vak.config.model.ModelConfig) + assert model_config.name == name + for key, val in config_dict_from_name.items(): + assert hasattr(model_config, key) + assert getattr(model_config, key) == val + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'TweetyNet': { + 'optimizer': {'lr': 1e-3}, + } + }, + { + 'TweetyNet': { + 'network': {}, + 'optimizer': {'lr': 1e-3}, + 'loss': {}, + 'metrics': {}, + } + }, + { + 'ED_TCN': { + 'optimizer': {'lr': 1e-3}, + } + }, + { + "ConvEncoderUMAP": { + "optimizer": 1e-3 + } + } + ] ) - expected_model_config = _make_expected_config(expected_model_config) + def test_from_config_dict(self, config_dict): + model_config = vak.config.model.ModelConfig.from_config_dict(config_dict) + + name=list(config_dict.keys())[0] + config_dict_from_name = config_dict[name] + assert model_config.name == name + for attr in ('network', 'optimizer', 'loss', 'metrics'): + assert hasattr(model_config, attr) + if attr in config_dict_from_name: + assert getattr(model_config, attr) == config_dict_from_name[attr] + else: + assert getattr(model_config, attr) == {} + + def test_from_config_dict_real_config(self, a_generated_config_dict): + config_dict = None + for table_name in ('train', 'eval', 'predict', 'learncurve'): + if table_name in a_generated_config_dict: + config_dict = a_generated_config_dict[table_name]['model'] + if config_dict is None: + raise ValueError( + f"Didn't find top-level table for config: {a_generated_config_dict}" + ) + + model_config = 
vak.config.model.ModelConfig.from_config_dict(config_dict) - model_config = vak.config.model.config_from_toml_path(toml_path, model_name) + name=list(config_dict.keys())[0] + config_dict_from_name = config_dict[name] + assert model_config.name == name + for attr in ('network', 'optimizer', 'loss', 'metrics'): + assert hasattr(model_config, attr) + if attr in config_dict_from_name: + assert getattr(model_config, attr) == config_dict_from_name[attr] + else: + assert getattr(model_config, attr) == {} - assert model_config == expected_model_config From 2216f4b9039d801ddbcf78d29da7ac9a25f3b351 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 10:37:34 -0400 Subject: [PATCH 052/183] Fix model tables in doc/toml --- doc/toml/gy6or6_eval.toml | 1 - doc/toml/gy6or6_predict.toml | 4 +--- doc/toml/gy6or6_train.toml | 7 ++++--- src/vak/config/model.py | 5 ++++- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/doc/toml/gy6or6_eval.toml b/doc/toml/gy6or6_eval.toml index da788a6b7..80a147a0f 100644 --- a/doc/toml/gy6or6_eval.toml +++ b/doc/toml/gy6or6_eval.toml @@ -28,7 +28,6 @@ step_size = 64 # EVAL: options for evaluating a trained model. This is done using the "test" split. [vak.eval] -model = "TweetyNet" # checkpoint_path: path to saved model checkpoint checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" # labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml index 001c16cae..1144aac4b 100644 --- a/doc/toml/gy6or6_predict.toml +++ b/doc/toml/gy6or6_predict.toml @@ -24,8 +24,6 @@ step_size = 64 # PREDICT: options for generating predictions with a trained model [vak.predict] -# model: the string name of the model. must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` -model = "TweetyNet" # checkpoint_path: path to saved model checkpoint checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" # labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; @@ -68,5 +66,5 @@ min_segment_dur = 0.01 window_size = 176 # Note we do not specify any options for the network, and just use the defaults -# We need to put this "dummy" table here though for the config to parse correctly +# We need to put this table here though, to indicate which model we are using. [vak.predict.model.TweetyNet] diff --git a/doc/toml/gy6or6_train.toml b/doc/toml/gy6or6_train.toml index ab3a02b85..dde4a8926 100644 --- a/doc/toml/gy6or6_train.toml +++ b/doc/toml/gy6or6_train.toml @@ -31,8 +31,6 @@ step_size = 64 # TRAIN: options for training model [vak.train] -# model: the string name of the model. must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` -model = "TweetyNet" # root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` root_results_dir = "/PATH/TO/FOLDER/results/train" # batch_size: number of samples from dataset per batch fed into network @@ -69,8 +67,11 @@ window_size = 176 [vak.train.val_transform_params] window_size = 176 -# TweetyNet.optimizer: we specify options for the model's optimizer in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. 
with `vak.model.decorators.model`
+# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer`
 [vak.train.model.TweetyNet.optimizer]
+# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table
 # lr: the learning rate
 lr = 0.001
 
diff --git a/src/vak/config/model.py b/src/vak/config/model.py
index c1a4ce0cf..aacc6484d 100644
--- a/src/vak/config/model.py
+++ b/src/vak/config/model.py
@@ -59,7 +59,10 @@ def from_config_dict(cls, config_dict: dict):
         config_dict = vak.config.parse.from_toml_path(toml_path)
         model_config = vak.config.Model.from_config_dict(config_dict['train'])
     """
-        model_name = list(config_dict.keys())
+        try:
+            model_name = list(config_dict.keys())
+        except AttributeError as err:
+            raise TypeError(f"`config_dict` must be a dict but type was: {type(config_dict)}") from err
         if len(model_name) == 0:
             raise ValueError(
                 "Did not find a single key in `config_dict` corresponding to model name. "

From 9cbf58a44d493f8e9cb4aec17d0768ddb1cd03e0 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 10:39:08 -0400
Subject: [PATCH 053/183] Rename data_for_tests/configs/invalid_option_config.toml -> invalid_key_config.toml

---
 ...{invalid_option_config.toml => invalid_key_config.toml} | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
 rename tests/data_for_tests/configs/{invalid_option_config.toml => invalid_key_config.toml} (73%)

diff --git a/tests/data_for_tests/configs/invalid_option_config.toml b/tests/data_for_tests/configs/invalid_key_config.toml
similarity index 73%
rename from tests/data_for_tests/configs/invalid_option_config.toml
rename to tests/data_for_tests/configs/invalid_key_config.toml
index 55aa334f1..7e95d9332 100644
--- a/tests/data_for_tests/configs/invalid_option_config.toml
+++ b/tests/data_for_tests/configs/invalid_key_config.toml
@@ -1,11 +1,12 @@
-# used to test that invalid option 'ouput_dir' (instead of 'output_dir')
+# used to test that invalid key 'ouput_dir' (instead of 'output_dir')
 # raises a ValueError when passed to
-# vak.config.validators.are_options_valid
+# vak.config.validators.are_keys_valid
 [vak.prep]
 dataset_type = "frame classification"
 input_type = "spect"
 data_dir = '/home/user/data/subdir/'
-ouput_dir = '/why/do/i/keep/typing/ouput' # <-- invalid option 'ouput' instead of 'output'
+# next line, invalid key 'ouput' instead of 'output'
+ouput_dir = '/why/do/i/keep/typing/ouput'
 audio_format = 'cbin'
 annot_format = 'notmat'
 labelset = 'iabcdefghjk'

From cdf749525ebc01479939318cc83c06927d146bf4 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 10:45:32 -0400
Subject: [PATCH 054/183] Rename are_options_valid/are_table_options_valid -> are_keys_valid/are_table_keys_valid in config/validators.py

---
 src/vak/config/validators.py | 46 ++++++++++++++++++------------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py
index 7439d80aa..bc8001a85 100644
--- a/src/vak/config/validators.py
+++ b/src/vak/config/validators.py
@@ -62,8 +62,8 @@ def is_spect_format(instance, attribute, value):
 with VALID_TOML_PATH.open("r") as fp:
     VALID_DICT = tomlkit.load(fp)['vak']
 VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys())
-VALID_OPTIONS = {
-    table: list(options.keys()) for table, options in VALID_DICT.items()
+VALID_KEYS = {
+    table_name: list(table_config_dict.keys()) for table_name, table_config_dict in VALID_DICT.items()
 }
 
 
@@ -111,50 +111,50 @@ def are_tables_valid(config_dict, toml_path=None):
         raise ValueError(err_msg)
 
 
-def are_options_valid(
+def are_keys_valid(
config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None ) -> None: """Given a :class:`dict` containing the *entire* configuration loaded from a toml file, - validate the option names for a specific top-level table, e.g. ``vak.train`` or ``vak.predict``""" - user_options = set(config_dict[table_name].keys()) - valid_options = set(VALID_OPTIONS[table_name]) - if not user_options.issubset(valid_options): - invalid_options = user_options - valid_options + validate the key names for a specific top-level table, e.g. ``vak.train`` or ``vak.predict``""" + table_keys = set(config_dict[table_name].keys()) + valid_keys = set(VALID_KEYS[table_name]) + if not table_keys.issubset(valid_keys): + invalid_keys = table_keys - valid_keys if toml_path: err_msg = ( - f"The following options from '{table_name}' table in " - f"the config file '{toml_path.name}' are not valid:\n{invalid_options}" + f"The following keys from '{table_name}' table in " + f"the config file '{toml_path.name}' are not valid:\n{invalid_keys}" ) else: err_msg = ( - f"The following options from '{table_name}' table in " - f"the toml config are not valid:\n{invalid_options}" + f"The following keys from '{table_name}' table in " + f"the toml config are not valid:\n{invalid_keys}" ) raise ValueError(err_msg) -def are_table_options_valid(table_config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None) -> None: +def are_table_keys_valid(table_config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None) -> None: """Given a :class:`dict` containing the configuration for a *specific* top-level table, - loaded from a toml file, validate the option names for that table, + loaded from a toml file, validate the key names for that table, e.g. ``vak.train`` or ``vak.predict``. This function assumes ``table_config_dict`` comes from the entire ``config_dict`` returned by :func:`vak.config.parse.from_toml_path`, accessed using the table name as a key, - unlike :func:`are_options_valid`. This function is used by the ``from_config_dict`` + unlike :func:`are_keys_valid`. This function is used by the ``from_config_dict`` classmethod of the top-level tables. 
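The set arithmetic both validators rely on is easy to see in isolation. A sketch with hypothetical values standing in for `VALID_KEYS["train"]`, mirroring the typo that the `invalid_key_config.toml` fixture exercises:

```python
# hypothetical table config with a typo'd key
table_config_dict = {"batch_size": 11, "ouput_dir": "/tmp/results"}

table_keys = set(table_config_dict.keys())
valid_keys = {"batch_size", "output_dir", "num_epochs"}  # stand-in for VALID_KEYS["train"]

if not table_keys.issubset(valid_keys):
    invalid_keys = table_keys - valid_keys  # {"ouput_dir"}, reported in the ValueError
```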
""" - user_options = set(table_config_dict.keys()) - valid_options = set(VALID_OPTIONS[table_name]) - if not user_options.issubset(valid_options): - invalid_options = user_options - valid_options + table_keys = set(table_config_dict.keys()) + valid_keys = set(VALID_KEYS[table_name]) + if not table_keys.issubset(valid_keys): + invalid_keys = table_keys - valid_keys if toml_path: err_msg = ( - f"The following options from '{table_name}' table in " - f"the config file '{toml_path.name}' are not valid:\n{invalid_options}" + f"The following keys from '{table_name}' table in " + f"the config file '{toml_path.name}' are not valid:\n{invalid_keys}" ) else: err_msg = ( - f"The following options from '{table_name}' table in " - f"the toml config are not valid:\n{invalid_options}" + f"The following keys from '{table_name}' table in " + f"the toml config are not valid:\n{invalid_keys}" ) raise ValueError(err_msg) From cdf749525ebc01479939318cc83c06927d146bf4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 10:46:12 -0400 Subject: [PATCH 055/183] Rename two fixtures in fixtures/config.py: invalid_section_config_path -> invalid_table_config_path, invalid_option_config_path -> invalid_key_config_path --- tests/fixtures/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 92ceb7f5c..cef2bb69f 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -61,13 +61,13 @@ def config_that_doesnt_exist(tmp_path): @pytest.fixture -def invalid_section_config_path(test_configs_root): - return test_configs_root.joinpath("invalid_section_config.toml") +def invalid_table_config_path(test_configs_root): + return test_configs_root.joinpath("invalid_table_config.toml") @pytest.fixture -def invalid_option_config_path(test_configs_root): - return test_configs_root.joinpath("invalid_option_config.toml") +def invalid_key_config_path(test_configs_root): + return test_configs_root.joinpath("invalid_key_config.toml") GENERATED_TEST_CONFIGS_ROOT = GENERATED_TEST_DATA_ROOT.joinpath("configs") From 90f401f0b7950833f39a652cb851d195448c6a42 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 10:54:25 -0400 Subject: [PATCH 056/183] Fix validator names in config/parse.py, rename TABLE_CLASSES constant -> TABLE_CLASSES_MAP --- src/vak/config/parse.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index fdae90723..a61b39bf7 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -12,10 +12,10 @@ from .predict import PredictConfig from .prep import PrepConfig from .train import TrainConfig -from .validators import are_options_valid, are_tables_valid +from .validators import are_keys_valid, are_tables_valid -TABLE_CLASSES = { +TABLE_CLASSES_MAP = { "eval": EvalConfig, "learncurve": LearncurveConfig, "predict": PredictConfig, @@ -68,14 +68,14 @@ def _validate_tables_to_parse_arg_convert_list(tables_to_parse: str | list[str]) ) if not all( [ - table_name in list(TABLE_CLASSES.keys()) + table_name in list(TABLE_CLASSES_MAP.keys()) for table_name in tables_to_parse ] ): raise ValueError( "All table names in 'tables_to_parse' should be valid names of tables. 
" f"Values for 'tables were: {tables_to_parse}.\n" - f"Valid table names are: {list(TABLE_CLASSES.keys())}" + f"Valid table names are: {list(TABLE_CLASSES_MAP.keys())}" ) return tables_to_parse @@ -109,7 +109,7 @@ def from_toml( are_tables_valid(config_dict, toml_path) if tables_to_parse is None: tables_to_parse = list( - TABLE_CLASSES.keys() + TABLE_CLASSES_MAP.keys() ) # i.e., parse all tables else: tables_to_parse = _validate_tables_to_parse_arg_convert_list(tables_to_parse) @@ -117,9 +117,9 @@ def from_toml( config_kwargs = {} for table_name in tables_to_parse: if table_name in config_dict: - are_options_valid(config_dict, table_name, toml_path) + are_keys_valid(config_dict, table_name, toml_path) table_config_dict = config_dict[table_name] - config_kwargs[table_name] = TABLE_CLASSES[table_name].from_config_dict(table_config_dict) + config_kwargs[table_name] = TABLE_CLASSES_MAP[table_name].from_config_dict(table_config_dict) else: raise KeyError( f"A table specified in `tables_to_parse` was not found in the config: {table_name}" From 67ec4a906c384653d195494d319eee171b707c5b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 10:54:57 -0400 Subject: [PATCH 057/183] Rename config/valid.toml -> valid-version-1.0.toml, fix how model table is declared --- src/vak/config/{valid.toml => valid-version-1.0.toml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/vak/config/{valid.toml => valid-version-1.0.toml} (100%) diff --git a/src/vak/config/valid.toml b/src/vak/config/valid-version-1.0.toml similarity index 100% rename from src/vak/config/valid.toml rename to src/vak/config/valid-version-1.0.toml From 728893349007da59e2aee95f2c0e2b65f9d464a5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 10:56:05 -0400 Subject: [PATCH 058/183] Fix VALID_TOML_PATH in config/validators.py after renaming config/valid.toml -> config/valid-version-1.0.toml --- src/vak/config/validators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index bc8001a85..da8d55a04 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -58,7 +58,7 @@ def is_spect_format(instance, attribute, value): CONFIG_DIR = pathlib.Path(__file__).parent -VALID_TOML_PATH = CONFIG_DIR.joinpath("valid.toml") +VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.0.toml") with VALID_TOML_PATH.open("r") as fp: VALID_DICT = tomlkit.load(fp)['vak'] VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys()) From 4f72b0c83df3efd854bbf86b602a218d29376f54 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 11:21:44 -0400 Subject: [PATCH 059/183] Import config classes in vak/config/__init__.py --- src/vak/config/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 8f20f2224..4e2560714 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -1,6 +1,7 @@ """sub-package that parses config.toml files and returns config object""" from . 
import ( config, + dataset, eval, learncurve, model, @@ -12,6 +13,17 @@ validators, ) +from .config import Config +from .dataset import DatasetConfig +from .eval import EvalConfig +from .learncurve import LearncurveConfig +from .model import ModelConfig +from .predict import PredictConfig +from .prep import PrepConfig +from .spect_params import SpectParamsConfig +from .train import TrainConfig + + __all__ = [ "config", "eval", @@ -23,4 +35,13 @@ "spect_params", "train", "validators", + "Config", + "DatasetConfig", + "EvalConfig", + "LearncurveConfig", + "ModelConfig", + "PredictConfig", + "PrepConfig", + "SpectParamsConfig", + "TrainConfig", ] From bd0a0551c54264262e4fdd329c46e43c77f10061 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 12:34:16 -0400 Subject: [PATCH 060/183] Add _tomlkit_to_popo to tests/fixtures/config.py so we operate on dicts not tomlkit.TOMLDocument --- tests/fixtures/config.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index cef2bb69f..65279197b 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -225,6 +225,34 @@ def all_generated_predict_configs(generated_test_configs_root): return sorted(generated_test_configs_root.glob("test_predict*toml")) +def _tomlkit_to_popo(d): + """Convert tomlkit to "popo" (Plain-Old Python Objects) + + From https://github.com/python-poetry/tomlkit/issues/43#issuecomment-660415820 + """ + try: + result = getattr(d, "value") + except AttributeError: + result = d + + if isinstance(result, list): + result = [_tomlkit_to_popo(x) for x in result] + elif isinstance(result, dict): + result = { + _tomlkit_to_popo(key): _tomlkit_to_popo(val) for key, val in result.items() + } + elif isinstance(result, tomlkit.items.Integer): + result = int(result) + elif isinstance(result, tomlkit.items.Float): + result = float(result) + elif isinstance(result, tomlkit.items.String): + result = str(result) + elif isinstance(result, tomlkit.items.Bool): + result = bool(result) + + return result + + # ---- config dicts from paths ---- def _load_config_dict(toml_path): """Return config as dict, loaded from toml file. @@ -235,7 +263,7 @@ def _load_config_dict(toml_path): """ with toml_path.open("r") as fp: config_dict = tomlkit.load(fp) - return config_dict['vak'] + return _tomlkit_to_popo(config_dict['vak']) @pytest.fixture From e3593761b0af4807d2ac74aa804126a775bcbe2e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 12:34:34 -0400 Subject: [PATCH 061/183] Add _tomlkit_to_popo to config/parse.py so we operate on dicts not tomlkit.TOMLDocument --- src/vak/config/parse.py | 48 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py index a61b39bf7..2b689f166 100644 --- a/src/vak/config/parse.py +++ b/src/vak/config/parse.py @@ -23,7 +23,7 @@ "train": TrainConfig, } -REQUIRED_OPTIONS = { +REQUIRED_KEYS = { "eval": [ "checkpoint_path", "output_dir", @@ -128,6 +128,42 @@ def from_toml( return Config(**config_kwargs) +def _tomlkit_to_popo(d): + """Convert tomlkit to "popo" (Plain-Old Python Objects) + + From https://github.com/python-poetry/tomlkit/issues/43#issuecomment-660415820 + + We need this so we don't get a ``tomlkit.items._ConvertError`` when + the `from_config_dict` classmethods try to add a class to a ``config_dict``, + e.g. 
when :meth:`EvalConfig.from_config_dict` converts the ``spect_params`` + key-value pairs to a :class:`vak.config.SpectParamsConfig` instance + and then assigns it to the ``spect_params`` key. + We would get this error if we just return the result of :func:`tomlkit.load`, + which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. + """ + try: + result = getattr(d, "value") + except AttributeError: + result = d + + if isinstance(result, list): + result = [_tomlkit_to_popo(x) for x in result] + elif isinstance(result, dict): + result = { + _tomlkit_to_popo(key): _tomlkit_to_popo(val) for key, val in result.items() + } + elif isinstance(result, tomlkit.items.Integer): + result = int(result) + elif isinstance(result, tomlkit.items.Float): + result = float(result) + elif isinstance(result, tomlkit.items.String): + result = str(result) + elif isinstance(result, tomlkit.items.Bool): + result = bool(result) + + return result + + def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: """Load a toml file from a path, and return as a :class:`dict`. @@ -153,7 +189,15 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: f"Please see example configuration files here: " ) - return config_dict['vak'] + # Next line, convert TOMLDocument returned by tomlkit.load to a dict. + # We need this so we don't get a ``tomlkit.items._ConvertError`` when + # the `from_config_dict` classmethods try to add a class to a ``config_dict``, + # e.g. when :meth:`EvalConfig.from_config_dict` converts the ``spect_params`` + # key-value pairs to a :class:`vak.config.SpectParamsConfig` instance + # and then assigns it to the ``spect_params`` key. + # We would get this error if we just return the result of :func:`tomlkit.load`, + # which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. 
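To make the conversion concrete, a small sketch of what the helper does, with hypothetical values; `_tomlkit_to_popo` here is the function defined in the diff above, not an importable name:

```python
import tomlkit

doc = tomlkit.loads('[vak.prep]\nlabelset = "iabcdefghjk"\nstep_size = 64\n')

popo = _tomlkit_to_popo(doc["vak"])  # helper from the diff above
# popo is built-in types all the way down, safe to assign class instances into:
# {'prep': {'labelset': 'iabcdefghjk', 'step_size': 64}}
assert isinstance(popo, dict) and isinstance(popo["prep"]["step_size"], int)
```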
+ return _tomlkit_to_popo(config_dict['vak']) def from_toml_path(toml_path: str | pathlib.Path, tables_to_parse: list[str] | None = None) -> Config: From 0ff5860d6958db3b1ee473aeb463d791b888f8c9 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 12:37:14 -0400 Subject: [PATCH 062/183] Finish rewriting tests for tests/test_config/test_prep.py --- tests/test_config/test_prep.py | 121 ++++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 8 deletions(-) diff --git a/tests/test_config/test_prep.py b/tests/test_config/test_prep.py index 3912f11f0..6901c2b23 100644 --- a/tests/test_config/test_prep.py +++ b/tests/test_config/test_prep.py @@ -1,12 +1,117 @@ """tests for vak.config.prep module""" +import copy + +import pytest + import vak.config.prep -def test_parse_prep_config_returns_PrepConfig_instance( - configs_toml_path_pairs_by_model_factory, -): - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - for config_toml, toml_path in config_toml_path_pairs: - prep_section = config_toml["PREP"] - config = vak.config.prep.PrepConfig(**prep_section) - assert isinstance(config, vak.config.prep.PrepConfig) +class TestPrepConfig: + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 'spect_params': {'fft_size': 512, + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312', + 'dataset_type': 'frame classification', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet', + 'spect_params': {'fft_size': 512, + 'freq_cutoffs': [500, 10000], + 'step_size': 64, + 'thresh': 6.25, + 'transform_type': 'log_spect'}, + 'test_dur': 30, + 'train_dur': 50, + 'val_dur': 15 + }, + ] + ) + def test_init(self, config_dict): + config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) + + prep_config = vak.config.PrepConfig(**config_dict) + + assert isinstance(prep_config, vak.config.prep.PrepConfig) + for key, val in config_dict.items(): + assert hasattr(prep_config, key) + if key == 'data_dir' or key == 'output_dir': + assert getattr(prep_config, key) == vak.common.converters.expanded_user_path(val) + elif key == 'labelset': + assert getattr(prep_config, key) == vak.common.converters.labelset_to_set(val) + else: + assert getattr(prep_config, key) == val + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 'spect_params': {'fft_size': 512, + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312', + 'dataset_type': 'frame classification', + 
'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet', + 'spect_params': {'fft_size': 512, + 'freq_cutoffs': [500, 10000], + 'step_size': 64, + 'thresh': 6.25, + 'transform_type': 'log_spect'}, + 'test_dur': 30, + 'train_dur': 50, + 'val_dur': 15 + }, + ] + ) + def test_from_config_dict(self, config_dict): + # we have to make a copy since `from_config_dict` mutates the dict + config_dict_copy = copy.deepcopy(config_dict) + + prep_config = vak.config.prep.PrepConfig.from_config_dict(config_dict_copy) + + assert isinstance(prep_config, vak.config.prep.PrepConfig) + for key, val in config_dict.items(): + assert hasattr(prep_config, key) + if key == 'data_dir' or key == 'output_dir': + assert getattr(prep_config, key) == vak.common.converters.expanded_user_path(val) + elif key == 'labelset': + assert getattr(prep_config, key) == vak.common.converters.labelset_to_set(val) + elif key == 'spect_params': + assert getattr(prep_config, key) == vak.config.SpectParamsConfig(**val) + else: + assert getattr(prep_config, key) == val + + def test_from_config_dict_real_config( + self, a_generated_config_dict + ): + prep_config = vak.config.prep.PrepConfig.from_config_dict(a_generated_config_dict['prep']) + assert isinstance(prep_config, vak.config.prep.PrepConfig) From 3e5f1d26a4c22cd890a32ae6cd76c36aeec185dd Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 12:49:48 -0400 Subject: [PATCH 063/183] Rewrite EvalConfig with from_config_dict method --- src/vak/config/eval.py | 57 +++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index bdb21e8f6..4e87604c6 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -11,7 +11,6 @@ from ..common.converters import expanded_user_path from .dataset import DatasetConfig from .model import ModelConfig -from .validators import is_valid_model_name def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict: @@ -73,22 +72,35 @@ def are_valid_post_tfm_kwargs(instance, attribute, value): ) +REQUIRED_KEYS = ( + "checkpoint_path", + "dataset", + "output_dir", + "model", +) + + @define class EvalConfig: """Class that represents [vak.eval] table in configuration file. Attributes ---------- - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model output_dir : str Path to location where .csv files with evaluation metrics should be saved. labelmap_path : str path to 'labelmap.json' file. - model : str - Model name, e.g., ``model = "TweetyNet"`` + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` batch_size : int number of samples per batch presented to models during training. 
num_workers : int @@ -130,15 +142,11 @@ class EvalConfig: # required, model / dataloader model = field( - validator=[instance_of(str), is_valid_model_name], + validator=instance_of(ModelConfig), ) batch_size = field(converter=int, validator=instance_of(int)) - - # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at - # what sections are defined to figure out where to add dataset_path after it creates the csv - dataset_path = field( - converter=converters.optional(expanded_user_path), - default=None, + dataset: DatasetConfig = field( + validator=instance_of(DatasetConfig), ) # "optional" but actually required for frame classification models @@ -177,11 +185,24 @@ class EvalConfig: @classmethod def from_config_dict(cls, config_dict: dict) -> EvalConfig: - """""" - # TODO: check for required keys here, including 'model' and 'dataset' - # use - model_config = ModelConfig(config_dict['model']) - dataset_config = DatasetConfig(config_dict['dataset']) + """Return :class:`EvalConfig` instance from a :class:`dict`. + + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :func:`vak.config.parse.from_toml_path`, + and then using key ``eval``, + i.e., ``EvalConfig.from_config_dict(config_dict['eval'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.eval]` table in a configuration file requires " + f"the option '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary. " + "Please check that the configuration file is formatted correctly." + ) + config_dict['dataset'] = DatasetConfig(**config_dict['dataset']) + config_dict['model'] = ModelConfig(**config_dict['model']) return cls( **config_dict - ) \ No newline at end of file + ) + From aeee2b679c61e99ab87cd3f372138944204b77bc Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 12:50:01 -0400 Subject: [PATCH 064/183] Rewrite LearncurveConfig with from_config_dict method --- src/vak/config/learncurve.py | 45 ++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py index d1489de6f..28d3237a9 100644 --- a/src/vak/config/learncurve.py +++ b/src/vak/config/learncurve.py @@ -8,16 +8,28 @@ from .train import TrainConfig +REQUIRED_KEYS = ( + 'dataset', + 'model', + 'root_results_dir' +) + + @define class LearncurveConfig(TrainConfig): """Class that represents ``[vak.learncurve]`` table in configuration file. Attributes ---------- - model : str - Model name, e.g., ``model = "TweetyNet"`` - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. num_epochs : int number of training epochs. One epoch = one iteration through the entire training set. @@ -58,3 +70,28 @@ class LearncurveConfig(TrainConfig): converter=converters.optional(convert_post_tfm_kwargs), default=None, ) + + # we over-ride this method from TrainConfig mainly so the docstring is correct. 
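+    # Note that, as in the parent class, this method assumes ``ModelConfig`` and
+    # ``DatasetConfig`` are imported at the top of this module, since it uses them
+    # to convert the ``model`` and ``dataset`` sub-tables into config classes.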
+    # TODO: can we do this by just over-writing `__doc__` for the method on this class?
+    @classmethod
+    def from_config_dict(cls, config_dict: dict) -> "LearncurveConfig":
+        """Return :class:`LearncurveConfig` instance from a :class:`dict`.
+
+        The :class:`dict` passed in should be the one found
+        by loading a valid configuration toml file with
+        :func:`vak.config.parse.from_toml_path`,
+        and then using key ``learncurve``,
+        i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``."""
+        for required_key in REQUIRED_KEYS:
+            if required_key not in config_dict:
+                raise KeyError(
+                    "The `[vak.learncurve]` table in a configuration file requires "
+                    f"the option '{required_key}', but it was not found "
+                    "when loading the configuration file into a Python dictionary. "
+                    "Please check that the configuration file is formatted correctly."
+                )
+        config_dict['model'] = ModelConfig(**config_dict['model'])
+        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
+        return cls(
+            **config_dict
+        )
\ No newline at end of file

From 6ab3a5f10481667368b46192ef191ed53deb17ca Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 12:50:10 -0400
Subject: [PATCH 065/183] Rewrite PredictConfig with from_config_dict method

---
 src/vak/config/predict.py | 61 ++++++++++++++++++++++++++++++---------
 1 file changed, 47 insertions(+), 14 deletions(-)

diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py
index 1f5e402b1..d318ee89d 100644
--- a/src/vak/config/predict.py
+++ b/src/vak/config/predict.py
@@ -1,4 +1,4 @@
-"""Class that represents ``[vak.prep]`` section of configuration file."""
+"""Class that represents ``[vak.predict]`` table of configuration file."""
 from __future__ import annotations
 
 import os
@@ -8,25 +8,38 @@
 from attr import converters, validators
 from attr.validators import instance_of
 
+from .dataset import DatasetConfig
+from .model import ModelConfig
 from ..common import device
 from ..common.converters import expanded_user_path
-from .validators import is_valid_model_name
+
+
+REQUIRED_KEYS = (
+    "checkpoint_path",
+    "dataset",
+    "model",
+)
 
 
 @define
 class PredictConfig:
-    """Class that represents ``[vak.prep]`` section of configuration file.
+    """Class that represents ``[vak.predict]`` table of configuration file.
 
     Attributes
    ----------
-    dataset_path : str
-        Path to dataset, e.g., a csv file generated by running ``vak prep``.
+    dataset : vak.config.DatasetConfig
+        The dataset to use: the path to it,
+        and optionally a path to a file representing splits,
+        and the name, if it is a built-in dataset.
+        Must be an instance of :class:`vak.config.DatasetConfig`.
     checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
     labelmap_path : str
        path to 'labelmap.json' file.
-    model : str
-        Model name, e.g., ``model = "TweetyNet"``
+    model : vak.config.ModelConfig
+        The model to use: its name,
+        and the parameters to configure it.
+        Must be an instance of :class:`vak.config.ModelConfig`
    batch_size : int
        number of samples per batch presented to models during training.
    num_workers : int
 
@@ -86,15 +99,11 @@ class PredictConfig:
 
     # required, model / dataloader
     model = field(
-        validator=[instance_of(str), is_valid_model_name],
+        validator=instance_of(ModelConfig),
     )
     batch_size = field(converter=int, validator=instance_of(int))
-
-    # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at
-    # what sections are defined to figure out where to add dataset_path after it creates the csv
-    dataset_path = field(
-        converter=converters.optional(expanded_user_path),
-        default=None,
+    dataset: DatasetConfig = field(
+        validator=instance_of(DatasetConfig),
     )
 
     # optional, transform
@@ -131,3 +140,27 @@ class PredictConfig:
         validator=validators.optional(instance_of(dict)),
         default=None,
     )
+
+
+    @classmethod
+    def from_config_dict(cls, config_dict: dict) -> PredictConfig:
+        """Return :class:`PredictConfig` instance from a :class:`dict`.
+
+        The :class:`dict` passed in should be the one found
+        by loading a valid configuration toml file with
+        :func:`vak.config.parse.from_toml_path`,
+        and then using key ``predict``,
+        i.e., ``PredictConfig.from_config_dict(config_dict['predict'])``."""
+        for required_key in REQUIRED_KEYS:
+            if required_key not in config_dict:
+                raise KeyError(
+                    "The `[vak.predict]` table in a configuration file requires "
+                    f"the option '{required_key}', but it was not found "
+                    "when loading the configuration file into a Python dictionary. "
+                    "Please check that the configuration file is formatted correctly."
+                )
+        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
+        config_dict['model'] = ModelConfig(**config_dict['model'])
+        return cls(
+            **config_dict
+        )

From 4b4fedf113f0ea6d2969f73cbf34437fc5ccd469 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 12:50:24 -0400
Subject: [PATCH 066/183] Rewrite PrepConfig with from_config_dict method

---
 src/vak/config/prep.py | 43 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py
index aa7c98092..bd521c03a 100644
--- a/src/vak/config/prep.py
+++ b/src/vak/config/prep.py
@@ -1,4 +1,4 @@
-"""Class and functions for ``[vak.prep]`` table in configuration file."""
+"""Class and functions for ``[vak.prep]`` table of configuration file."""
 from __future__ import annotations
 
 import inspect
@@ -8,9 +8,10 @@
 from attrs import converters, validators
 from attrs.validators import instance_of
 
+from .spect_params import SpectParamsConfig
+from .validators import is_annot_format, is_audio_format, is_spect_format
 from .. import prep
 from ..common.converters import expanded_user_path, labelset_to_set
-from .validators import is_annot_format, is_audio_format, is_spect_format
 
 
 def duration_from_toml_value(value):
@@ -62,6 +63,12 @@ def are_valid_dask_bag_kwargs(instance, attribute, value):
         )
 
 
+REQUIRED_KEYS = (
+    "data_dir",
+    "output_dir",
+)
+
+
 @define
 class PrepConfig:
     """Class that represents ``[vak.prep]`` table of configuration file.
@@ -86,6 +93,11 @@ class PrepConfig:
     spect_format : str
         format of files containing spectrograms as 2-d matrices.
         One of {'mat', 'npy'}.
+    spect_params: vak.config.SpectParamsConfig, optional
+        Parameters for Short-Time Fourier Transform and post-processing
+        of spectrograms.
+        Instance of :class:`vak.config.SpectParamsConfig` class.
+        Optional, default is None.
     annot_format : str
         format of annotations. Any format that can be used
         with the crowsetta library is valid.
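
To make the effect of this change concrete: with ``spect_params`` now an attribute of ``PrepConfig``, a configuration file nests those parameters as a sub-table of ``[vak.prep]``. A minimal sketch of such a file, reusing key names and values from the test configurations above purely as an illustration:

    [vak.prep]
    data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312"
    output_dir = "./tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet"
    dataset_type = "frame classification"
    input_type = "spect"
    audio_format = "cbin"
    annot_format = "notmat"
    labelset = "iabcdefghjk"
    train_dur = 50
    val_dur = 15
    test_dur = 30

    [vak.prep.spect_params]
    fft_size = 512
    step_size = 64
    thresh = 6.25
    transform_type = "log_spect"
    freq_cutoffs = [500, 10000]

Nesting the parameters this way is what lets ``PrepConfig.from_config_dict``, added later in this diff, convert the ``spect_params`` sub-dict into a ``SpectParamsConfig`` instance.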
@@ -157,6 +169,10 @@ def is_valid_input_type(self, attribute, value): spect_format = field( validator=validators.optional(is_spect_format), default=None ) + spect_params = field( + validator=validators.optional(instance_of(SpectParamsConfig)), + default=None, + ) annot_file = field( converter=converters.optional(expanded_user_path), default=None, @@ -205,3 +221,26 @@ def __attrs_post_init__(self): raise ValueError( "must specify either audio_format or spect_format" ) + + @classmethod + def from_config_dict(cls, config_dict: dict) -> PrepConfig: + """Return :class:`PrepConfig` instance from a :class:`dict`. + + The :class:`dict` passed in should be the one found + by loading a valid configuration toml file with + :func:`vak.config.parse.from_toml_path`, + and then using key ``prep``, + i.e., ``PrepConfig.from_config_dict(config_dict['prep'])``.""" + for required_key in REQUIRED_KEYS: + if required_key not in config_dict: + raise KeyError( + "The `[vak.prep]` table in a configuration file requires " + f"the key '{required_key}', but it was not found " + "when loading the configuration file into a Python dictionary. " + "Please check that the configuration file is formatted correctly." + ) + if 'spect_params' in config_dict: + config_dict['spect_params'] = SpectParamsConfig(**config_dict['spect_params']) + return cls( + **config_dict + ) From 77e38e6ccd3d98a6cbb1afbbfcc4251268875bed Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 12:50:35 -0400 Subject: [PATCH 067/183] Rewrite TrainConfig with from_config_dict method --- src/vak/config/train.py | 57 +++++++++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 21c378aa2..b6b5879d4 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -5,7 +5,15 @@ from ..common import device from ..common.converters import bool_from_str, expanded_user_path -from .validators import is_valid_model_name +from .dataset import DatasetConfig +from .model import ModelConfig + + +REQUIRED_KEYS = ( + 'dataset', + 'model', + 'root_results_dir' +) @define @@ -14,10 +22,15 @@ class TrainConfig: Attributes ---------- - model : str - Model name, e.g., ``model = "TweetyNet"`` - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. + model : vak.config.ModelConfig + The model to use: its name, + and the parameters to configure it. + Must be an instance of :class:`vak.config.ModelConfig` + dataset : vak.config.DatasetConfig + The dataset to use: the path to it, + and optionally a path to a file representing splits, + and the name, if it is a built-in dataset. + Must be an instance of :class:`vak.config.DatasetConfig`. num_epochs : int number of training epochs. One epoch = one iteration through the entire training set. 
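
Before the hunk that adds ``TrainConfig.from_config_dict``, here is a rough sketch of how these classmethods fit together at this point in the series. It assumes a configuration file whose ``[vak.train]`` table has the required ``model``, ``root_results_dir``, and ``dataset`` keys, and that ``DatasetConfig`` exposes the ``path`` attribute described in the docstrings; the file name is hypothetical:

    from vak.config.parse import _load_toml_from_path
    from vak.config.train import TrainConfig

    # as of the earlier patches, `_load_toml_from_path` returns the plain dict
    # under the top-level 'vak' key, so sub-tables are accessed directly
    config_dict = _load_toml_from_path("gy6or6_train.toml")  # hypothetical file

    # convert the 'train' sub-table; its 'model' and 'dataset' sub-tables
    # become ModelConfig and DatasetConfig instances
    train_config = TrainConfig.from_config_dict(config_dict["train"])
    print(train_config.dataset.path)

Note that ``from_config_dict`` mutates the dict passed in, which is why the tests above make a ``copy.deepcopy`` before calling it.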
@@ -65,18 +78,13 @@ class TrainConfig:
 
     # required
     model = field(
-        validator=[instance_of(str), is_valid_model_name],
+        validator=instance_of(ModelConfig),
     )
     num_epochs = field(converter=int, validator=instance_of(int))
     batch_size = field(converter=int, validator=instance_of(int))
     root_results_dir = field(converter=expanded_user_path)
-
-    # optional
-    # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at
-    # what sections are defined to figure out where to add dataset_path after it creates the csv
-    dataset_path = field(
-        converter=converters.optional(expanded_user_path),
-        default=None,
+    dataset: DatasetConfig = field(
+        validator=instance_of(DatasetConfig),
     )
 
     results_dirname = field(
@@ -143,3 +151,26 @@ class TrainConfig:
         validator=validators.optional(instance_of(dict)),
         default=None,
     )
+
+    @classmethod
+    def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
+        """Return :class:`TrainConfig` instance from a :class:`dict`.
+
+        The :class:`dict` passed in should be the one found
+        by loading a valid configuration toml file with
+        :func:`vak.config.parse.from_toml_path`,
+        and then using key ``train``,
+        i.e., ``TrainConfig.from_config_dict(config_dict['train'])``."""
+        for required_key in REQUIRED_KEYS:
+            if required_key not in config_dict:
+                raise KeyError(
+                    "The `[vak.train]` table in a configuration file requires "
+                    f"the option '{required_key}', but it was not found "
+                    "when loading the configuration file into a Python dictionary. "
+                    "Please check that the configuration file is formatted correctly."
+                )
+        config_dict['model'] = ModelConfig(**config_dict['model'])
+        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
+        return cls(
+            **config_dict
+        )
\ No newline at end of file

From f3e37ae3a07a228b1b0079f09e816addb27334ca Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 16:40:44 -0400
Subject: [PATCH 068/183] Remove functions from config/parse.py

---
 src/vak/config/parse.py | 143 +---------------------------------------
 1 file changed, 2 insertions(+), 141 deletions(-)

diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py
index 2b689f166..2b4a489c1 100644
--- a/src/vak/config/parse.py
+++ b/src/vak/config/parse.py
@@ -15,119 +15,6 @@
 from .validators import are_keys_valid, are_tables_valid
 
 
-TABLE_CLASSES_MAP = {
-    "eval": EvalConfig,
-    "learncurve": LearncurveConfig,
-    "predict": PredictConfig,
-    "prep": PrepConfig,
-    "train": TrainConfig,
-}
-
-REQUIRED_KEYS = {
-    "eval": [
-        "checkpoint_path",
-        "output_dir",
-        "model",
-    ],
-    "learncurve": [
-        "model",
-        "root_results_dir",
-    ],
-    "predict": [
-        "checkpoint_path",
-        "model",
-    ],
-    "prep": [
-        "data_dir",
-        "output_dir",
-    ],
-    "train": [
-        "model",
-        "root_results_dir",
-    ],
-}
-
-
-def _validate_tables_to_parse_arg_convert_list(tables_to_parse: str | list[str]) -> list[str]:
-    """Helper function used by :func:`from_toml` that
-    validates the ``tables_to_parse`` argument,
-    and returns it as a list of strings."""
-    if isinstance(tables_to_parse, str):
-        tables_to_parse = [tables_to_parse]
-
-    if not isinstance(tables_to_parse, list):
-        raise TypeError(
-            f"`tables_to_parse` should be a string or list of strings but type was: {type(tables_to_parse)}"
-        )
-
-    if not all(
-        [isinstance(table_name, str) for table_name in tables_to_parse]
-    ):
-        raise ValueError(
-            "All table names in 'tables_to_parse' should be strings"
-        )
-    if not all(
-        [
-            table_name in list(TABLE_CLASSES_MAP.keys())
-            for table_name in tables_to_parse
-        ]
-    ):
-        raise 
ValueError( - "All table names in 'tables_to_parse' should be valid names of tables. " - f"Values for 'tables were: {tables_to_parse}.\n" - f"Valid table names are: {list(TABLE_CLASSES_MAP.keys())}" - ) - return tables_to_parse - - -def from_toml( - config_dict: dict, toml_path: str | pathlib.Path | None = None, tables_to_parse: str | list[str] | None = None - ) -> Config: - """Load a TOML configuration file. - - Parameters - ---------- - config_dict : dict - Python ``dict`` containing a .toml configuration file, - parsed by the ``toml`` library. - toml_path : str, pathlib.Path - path to a configuration file in TOML format. Default is None. - Not required, used only to make any error messages clearer. - tables_to_parse : str, list - Name of top-level table or tables from configuration - file that should be parsed. Can be a string - (single table) or list of strings (multiple - tables). Default is None, - in which case all are validated and parsed. - - Returns - ------- - config : vak.config.parse.Config - instance of Config class, whose attributes correspond to - tables in a config.toml file. - """ - are_tables_valid(config_dict, toml_path) - if tables_to_parse is None: - tables_to_parse = list( - TABLE_CLASSES_MAP.keys() - ) # i.e., parse all tables - else: - tables_to_parse = _validate_tables_to_parse_arg_convert_list(tables_to_parse) - - config_kwargs = {} - for table_name in tables_to_parse: - if table_name in config_dict: - are_keys_valid(config_dict, table_name, toml_path) - table_config_dict = config_dict[table_name] - config_kwargs[table_name] = TABLE_CLASSES_MAP[table_name].from_config_dict(table_config_dict) - else: - raise KeyError( - f"A table specified in `tables_to_parse` was not found in the config: {table_name}" - ) - - return Config(**config_kwargs) - - def _tomlkit_to_popo(d): """Convert tomlkit to "popo" (Plain-Old Python Objects) @@ -197,31 +84,5 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: # and then assigns it to the ``spect_params`` key. # We would get this error if we just return the result of :func:`tomlkit.load`, # which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. - return _tomlkit_to_popo(config_dict['vak']) - - -def from_toml_path(toml_path: str | pathlib.Path, tables_to_parse: list[str] | None = None) -> Config: - """Parse a TOML configuration file and return as a :class:`Config`. - - Parameters - ---------- - toml_path : str, pathlib.Path - Path to a configuration file in TOML format. - Parsed by ``toml`` library, then converted to an - instance of ``vak.config.parse.Config`` by - calling ``vak.parse.from_toml`` - tables_to_parse : str, list - Name of table or tables from configuration - file that should be parsed. Can be a string - (single table) or list of strings (multiple - tables). Default is None, - in which case all are validated and parsed. - - Returns - ------- - config : vak.config.parse.Config - instance of :class:`Config` class, whose attributes correspond to - tables in a config.toml file. 
- """ - config_dict: dict = _load_toml_from_path(toml_path) - return from_toml(config_dict, toml_path, tables_to_parse) + return _tomlkit_to_popo(config_dict) + From 2b598d0ca177b14d9880da5310160f114de758ff Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 19:52:50 -0400 Subject: [PATCH 069/183] Rename config/parse.py -> config/load.py --- src/vak/config/__init__.py | 4 ++-- src/vak/config/{parse.py => load.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename src/vak/config/{parse.py => load.py} (100%) diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 4e2560714..7de1fa9b8 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -4,8 +4,8 @@ dataset, eval, learncurve, + load, model, - parse, predict, prep, spect_params, @@ -29,7 +29,7 @@ "eval", "learncurve", "model", - "parse", + "load", "predict", "prep", "spect_params", diff --git a/src/vak/config/parse.py b/src/vak/config/load.py similarity index 100% rename from src/vak/config/parse.py rename to src/vak/config/load.py From d175436e2d6ecd2cffbf94c462e13501cfa8864a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 16:40:01 -0400 Subject: [PATCH 070/183] Make functions in config/parse.py into classmethods on Config class --- src/vak/config/config.py | 164 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 156 insertions(+), 8 deletions(-) diff --git a/src/vak/config/config.py b/src/vak/config/config.py index 330b83414..e9efbac0e 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -1,16 +1,64 @@ -import attr +"""Class that represents the TOML configuration file used with the vak command-line interface.""" +from __future__ import annotations + +import pathlib + +from attrs import define, field from attr.validators import instance_of, optional +from . import load from .eval import EvalConfig from .learncurve import LearncurveConfig from .predict import PredictConfig from .prep import PrepConfig from .train import TrainConfig +from .validators import are_keys_valid, are_tables_valid + + +TABLE_CLASSES_MAP = { + "eval": EvalConfig, + "learncurve": LearncurveConfig, + "predict": PredictConfig, + "prep": PrepConfig, + "train": TrainConfig, +} + +def _validate_tables_to_parse_arg_convert_list(tables_to_parse: str | list[str]) -> list[str]: + """Helper function used by :func:`from_toml` that + validates the ``tables_to_parse`` argument, + and returns it as a list of strings.""" + if isinstance(tables_to_parse, str): + tables_to_parse = [tables_to_parse] -@attr.s + if not isinstance(tables_to_parse, list): + raise TypeError( + f"`tables_to_parse` should be a string or list of strings but type was: {type(tables_to_parse)}" + ) + + if not all( + [isinstance(table_name, str) for table_name in tables_to_parse] + ): + raise ValueError( + "All table names in 'tables_to_parse' should be strings" + ) + if not all( + [ + table_name in list(TABLE_CLASSES_MAP.keys()) + for table_name in tables_to_parse + ] + ): + raise ValueError( + "All table names in 'tables_to_parse' should be valid names of tables. " + f"Values for 'tables were: {tables_to_parse}.\n" + f"Valid table names are: {list(TABLE_CLASSES_MAP.keys())}" + ) + return tables_to_parse + + +@define class Config: - """Class that represents a configuration file. + """Class that represents the TOML configuration file used with the vak command-line interface. 

     Attributes
     ----------
@@ -25,12 +73,112 @@ class Config:
     learncurve : vak.config.learncurve.LearncurveConfig
         Represents ``[vak.learncurve]`` table of config.toml file
     """
-    prep = attr.ib(validator=optional(instance_of(PrepConfig)), default=None)
-    train = attr.ib(validator=optional(instance_of(TrainConfig)), default=None)
-    eval = attr.ib(validator=optional(instance_of(EvalConfig)), default=None)
-    predict = attr.ib(
+    prep = field(validator=optional(instance_of(PrepConfig)), default=None)
+    train = field(validator=optional(instance_of(TrainConfig)), default=None)
+    eval = field(validator=optional(instance_of(EvalConfig)), default=None)
+    predict = field(
         validator=optional(instance_of(PredictConfig)), default=None
     )
-    learncurve = attr.ib(
+    learncurve = field(
         validator=optional(instance_of(LearncurveConfig)), default=None
     )
+
+    @classmethod
+    def from_config_dict(
+        cls,
+        config_dict: dict,
+        toml_path: str | pathlib.Path | None = None,
+        tables_to_parse: str | list[str] | None = None
+    ) -> "Config":
+        """Return instance of :class:`Config` class,
+        given a :class:`dict` containing the contents of
+        a TOML configuration file.
+
+        This :func:`classmethod` expects the output
+        of :func:`vak.config.load._load_toml_from_path`,
+        that converts a :class:`tomlkit.TOMLDocument`
+        to a :class:`dict`.
+
+        Parameters
+        ----------
+        config_dict : dict
+            Python ``dict`` containing a .toml configuration file,
+            parsed with the ``tomlkit`` library.
+        toml_path : str, pathlib.Path
+            path to a configuration file in TOML format. Default is None.
+            Not required, used only to make any error messages clearer.
+        tables_to_parse : str, list
+            Name of top-level table or tables from configuration
+            file that should be parsed. Can be a string
+            (single table) or list of strings (multiple
+            tables). Default is None,
+            in which case all are validated and parsed.
+
+        Returns
+        -------
+        config : vak.config.Config
+            instance of :class:`Config` class,
+            whose attributes correspond to the
+            top-level tables in a config.toml file.
+        """
+        try:
+            config_dict = config_dict['vak']
+        except KeyError as e:
+            raise KeyError(
+                "Did not find key 'vak' in `config_dict`. "
+                "All top-level tables in toml configuration file must "
+                "use dotted names that begin with 'vak', e.g. ``[vak.eval]``.\n"
+                f"`config_dict`:\n{config_dict}"
+            )
+        are_tables_valid(config_dict, toml_path)
+        if tables_to_parse is None:
+            tables_to_parse = list(
+                TABLE_CLASSES_MAP.keys()
+            )  # i.e., parse all tables
+        else:
+            tables_to_parse = _validate_tables_to_parse_arg_convert_list(tables_to_parse)
+
+        config_kwargs = {}
+        for table_name in tables_to_parse:
+            if table_name in config_dict:
+                are_keys_valid(config_dict, table_name, toml_path)
+                table_config_dict = config_dict[table_name]
+                config_kwargs[table_name] = TABLE_CLASSES_MAP[table_name].from_config_dict(table_config_dict)
+            else:
+                raise KeyError(
+                    f"A table specified in `tables_to_parse` was not found in the config: {table_name}"
+                )
+
+        return cls(**config_kwargs)
+
+    @classmethod
+    def from_toml_path(
+        cls,
+        toml_path: str | pathlib.Path,
+        tables_to_parse: list[str] | None = None
+    ) -> "Config":
+        """Return instance of :class:`Config` class,
+        given the path to a TOML configuration file.
+
+        Parameters
+        ----------
+        toml_path : str, pathlib.Path
+            Path to a configuration file in TOML format.
+            Parsed with the ``tomlkit`` library, then converted to an
+            instance of :class:`Config` by
+            calling :meth:`Config.from_config_dict`.
+        tables_to_parse : str, list
+            Name of table or tables from configuration
+            file that should be parsed. Can be a string
+            (single table) or list of strings (multiple
+            tables). Default is None,
+            in which case all are validated and parsed.
+
+        Returns
+        -------
+        config : vak.config.Config
+            instance of :class:`Config` class, whose attributes correspond to
+            tables in a config.toml file.
+        """
+        config_dict: dict = load._load_toml_from_path(toml_path)
+        return cls.from_config_dict(config_dict, toml_path, tables_to_parse)

From 323c23815709cc2c815772e93062357e9912f35e Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 19:59:15 -0400
Subject: [PATCH 071/183] Use config.Config.from_toml_path everywhere instead of config.parse.from_toml_path

---
 src/vak/cli/eval.py | 2 +-
 src/vak/cli/learncurve.py | 2 +-
 src/vak/cli/predict.py | 2 +-
 src/vak/cli/train.py | 2 +-
 tests/scripts/vaktestdata/source_files.py | 12 ++++++------
 .../test_frame_classification/test_frames_dataset.py | 2 +-
 .../test_frame_classification/test_window_dataset.py | 2 +-
 .../test_parametric_umap/test_parametric_umap.py | 2 +-
 tests/test_eval/test_eval.py | 2 +-
 tests/test_eval/test_frame_classification.py | 6 +++---
 tests/test_eval/test_parametric_umap.py | 6 +++---
 tests/test_learncurve/test_frame_classification.py | 4 ++--
 tests/test_models/test_base.py | 4 ++--
 tests/test_models/test_frame_classification_model.py | 2 +-
 tests/test_models/test_parametric_umap_model.py | 2 +-
 tests/test_predict/test_frame_classification.py | 6 +++---
 tests/test_predict/test_predict.py | 2 +-
 .../test_assign_samples_to_splits.py | 2 +-
 .../test_frame_classification.py | 12 ++++++------
 .../test_get_or_make_source_files.py | 2 +-
 .../test_frame_classification/test_learncurve.py | 4 ++--
 .../test_frame_classification/test_make_splits.py | 2 +-
 tests/test_prep/test_prep.py | 2 +-
 tests/test_train/test_frame_classification.py | 8 ++++----
 tests/test_train/test_parametric_umap.py | 6 +++---
 tests/test_train/test_train.py | 2 +-
 26 files changed, 50 insertions(+), 50 deletions(-)

diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py
index 5914bb41e..736edda91 100644
--- a/src/vak/cli/eval.py
+++ b/src/vak/cli/eval.py
@@ -24,7 +24,7 @@ def eval(toml_path):
         None
     """
     toml_path = Path(toml_path)
-    cfg = config.parse.from_toml_path(toml_path)
+    cfg = config.Config.from_toml_path(toml_path)
 
     if cfg.eval is None:
         raise ValueError(
diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py
index e68f54651..9b6991816 100644
--- a/src/vak/cli/learncurve.py
+++ b/src/vak/cli/learncurve.py
@@ -23,7 +23,7 @@ def learning_curve(toml_path):
        path to a configuration file in TOML format.
""" toml_path = Path(toml_path) - cfg = config.parse.from_toml_path(toml_path) + cfg = config.Config.from_toml_path(toml_path) if cfg.predict is None: raise ValueError( diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index cedac92bb..97fddb133 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -23,7 +23,7 @@ def train(toml_path): path to a configuration file in TOML format. """ toml_path = Path(toml_path) - cfg = config.parse.from_toml_path(toml_path) + cfg = config.Config.from_toml_path(toml_path) if cfg.train is None: raise ValueError( diff --git a/tests/scripts/vaktestdata/source_files.py b/tests/scripts/vaktestdata/source_files.py index 4f72b134a..2f7c032eb 100644 --- a/tests/scripts/vaktestdata/source_files.py +++ b/tests/scripts/vaktestdata/source_files.py @@ -47,7 +47,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): f"\nRunning :func:`vak.prep.frame_classification.get_or_make_source_files` to generate data for tests, " f"using config:\n{config_path.name}" ) - cfg = vak.config.parse.from_toml_path(config_path) + cfg = vak.config.Config.from_toml_path(config_path) source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files( data_dir=cfg.prep.data_dir, @@ -72,7 +72,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): csv_path = constants.GENERATED_SOURCE_FILES_CSV_DIR / f'{config_metadata.filename}-source-files.csv' source_files_df.to_csv(csv_path, index=False) - config_toml: dict = vak.config.parse._load_toml_from_path(config_path) + config_toml: dict = vak.config.load._load_toml_from_path(config_path) purpose = vak.cli.prep.purpose_from_toml(config_toml, config_path) dataset_df: pd.DataFrame = vak.prep.frame_classification.assign_samples_to_splits( purpose, @@ -109,7 +109,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): with config_path.open("w") as fp: tomlkit.dump(config_toml, fp) - cfg = vak.config.parse.from_toml_path(config_path) + cfg = vak.config.Config.from_toml_path(config_path) source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files( data_dir=cfg.prep.data_dir, @@ -127,7 +127,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): csv_path = constants.GENERATED_SOURCE_FILES_CSV_DIR / f'{config_metadata.filename}-source-files.csv' source_files_df.to_csv(csv_path, index=False) - config_toml: dict = vak.config.parse._load_toml_from_path(config_path) + config_toml: dict = vak.config.load._load_toml_from_path(config_path) purpose = vak.cli.prep.purpose_from_toml(config_toml, config_path) dataset_df: pd.DataFrame = vak.prep.frame_classification.assign_samples_to_splits( purpose, @@ -159,7 +159,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): f"\nRunning :func:`vak.prep.frame_classification.get_or_make_source_files` to generate data for tests, " f"using config:\n{config_path.name}" ) - cfg = vak.config.parse.from_toml_path(config_path) + cfg = vak.config.Config.from_toml_path(config_path) source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files( data_dir=cfg.prep.data_dir, input_type=cfg.prep.input_type, @@ -176,7 +176,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models(): csv_path = constants.GENERATED_SOURCE_FILES_CSV_DIR / f'{config_metadata.filename}-source-files.csv' source_files_df.to_csv(csv_path, index=False) - config_toml: dict = vak.config.parse._load_toml_from_path(config_path) + config_toml: dict = 
vak.config.load._load_toml_from_path(config_path) purpose = vak.cli.prep.purpose_from_toml(config_toml, config_path) dataset_df: pd.DataFrame = vak.prep.frame_classification.assign_samples_to_splits( purpose, diff --git a/tests/test_datasets/test_frame_classification/test_frames_dataset.py b/tests/test_datasets/test_frame_classification/test_frames_dataset.py index a7674ec61..52a75495f 100644 --- a/tests/test_datasets/test_frame_classification/test_frames_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_frames_dataset.py @@ -19,7 +19,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo audio_format=audio_format, spect_format=spect_format, annot_format=annot_format) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) item_transform = vak.transforms.defaults.get_default_transform( diff --git a/tests/test_datasets/test_frame_classification/test_window_dataset.py b/tests/test_datasets/test_frame_classification/test_window_dataset.py index 613fd1854..67c6fde50 100644 --- a/tests/test_datasets/test_frame_classification/test_window_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_window_dataset.py @@ -20,7 +20,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo audio_format=audio_format, spect_format=spect_format, annot_format=annot_format) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) transform, target_transform = vak.transforms.defaults.get_default_transform( diff --git a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py index 15eab713f..e2829991d 100644 --- a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py +++ b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py @@ -19,7 +19,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo audio_format=audio_format, spect_format=spect_format, annot_format=annot_format) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) transform = vak.transforms.defaults.get_default_transform( diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index b4e69322b..ba0e38143 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -42,7 +42,7 @@ def test_eval( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) results_path = tmp_path / 'results_path' diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index ce299c0e6..fb2055351 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -67,7 +67,7 @@ def test_eval_frame_classification_model( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) vak.eval.frame_classification.eval_frame_classification_model( @@ -125,7 +125,7 @@ def 
test_eval_frame_classification_model_raises_file_not_found( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(FileNotFoundError): vak.eval.frame_classification.eval_frame_classification_model( @@ -183,7 +183,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(NotADirectoryError): vak.eval.frame_classification.eval_frame_classification_model( diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index 5b803a7e7..925178c1a 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -45,7 +45,7 @@ def test_eval_parametric_umap_model( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) vak.eval.parametric_umap.eval_parametric_umap_model( @@ -97,7 +97,7 @@ def test_eval_frame_classification_model_raises_file_not_found( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(FileNotFoundError): vak.eval.parametric_umap.eval_parametric_umap_model( @@ -153,7 +153,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(NotADirectoryError): vak.eval.parametric_umap.eval_parametric_umap_model( diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index cc3484279..d83ca9eeb 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -62,7 +62,7 @@ def test_learning_curve_for_frame_classification_model( options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() @@ -116,7 +116,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, annot_format="notmat", options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) # mock behavior of cli.learncurve, building `results_path` from config option `root_results_dir` results_path = cfg.learncurve.root_results_dir / 'results-dir-timestamp' diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index 0c3236296..d916e049a 100644 --- 
a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -191,7 +191,7 @@ def test_load_state_dict_from_path(self, """ definition = self.MODEL_DEFINITION_MAP[model_name] train_toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - train_cfg = vak.config.parse.from_toml_path(train_toml_path) + train_cfg = vak.config.Config.from_toml_path(train_toml_path) # stuff we need just to be able to instantiate network labelmap = vak.common.labels.to_map(train_cfg.prep.labelset, map_unlabeled=True) @@ -225,7 +225,7 @@ def test_load_state_dict_from_path(self, model.to(device) eval_toml_path = specific_config_toml_path('eval', model_name, audio_format='cbin', annot_format='notmat') - eval_cfg = vak.config.parse.from_toml_path(eval_toml_path) + eval_cfg = vak.config.Config.from_toml_path(eval_toml_path) checkpoint_path = eval_cfg.eval.checkpoint_path # ---- actually test method diff --git a/tests/test_models/test_frame_classification_model.py b/tests/test_models/test_frame_classification_model.py index e77e84acf..c516b7d54 100644 --- a/tests/test_models/test_frame_classification_model.py +++ b/tests/test_models/test_frame_classification_model.py @@ -86,7 +86,7 @@ def test_from_config(self, definition = vak.models.definition.validate(definition) model_name = definition.__name__.replace('Definition', '') toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # stuff we need just to be able to instantiate network labelmap = vak.common.labels.to_map(cfg.prep.labelset, map_unlabeled=True) diff --git a/tests/test_models/test_parametric_umap_model.py b/tests/test_models/test_parametric_umap_model.py index 0e255fed1..201c3da7f 100644 --- a/tests/test_models/test_parametric_umap_model.py +++ b/tests/test_models/test_parametric_umap_model.py @@ -87,7 +87,7 @@ def test_from_config( definition = vak.models.definition.validate(definition) model_name = definition.__name__.replace('Definition', '') toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) monkeypatch.setattr( vak.models.ParametricUMAPModel, 'definition', definition, raising=False diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 6726ec09b..db77ef7e0 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -49,7 +49,7 @@ def test_predict_with_frame_classification_model( annot_format=annot_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) @@ -124,7 +124,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( annot_format="notmat", options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) @@ -188,7 +188,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( annot_format="notmat", options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = 
vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) with pytest.raises(NotADirectoryError): diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 98051ca80..d25d06017 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -38,7 +38,7 @@ def test_predict( annot_format=annot_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) results_path = tmp_path / 'results_path' diff --git a/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py b/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py index d354dfd6a..87b9489eb 100644 --- a/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py +++ b/tests/test_prep/test_frame_classification/test_assign_samples_to_splits.py @@ -27,7 +27,7 @@ def test_assign_samples_to_splits( spect_format, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # ---- set up ---- tmp_dataset_path = tmp_path / 'dataset_dir' diff --git a/tests/test_prep/test_frame_classification/test_frame_classification.py b/tests/test_prep/test_frame_classification/test_frame_classification.py index 31a264847..6b8c7d580 100644 --- a/tests/test_prep/test_frame_classification/test_frame_classification.py +++ b/tests/test_prep/test_frame_classification/test_frame_classification.py @@ -95,7 +95,7 @@ def test_prep_frame_classification_dataset( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() dataset_df, dataset_path = vak.prep.frame_classification.frame_classification.prep_frame_classification_dataset( @@ -167,7 +167,7 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() with pytest.raises(ValueError): @@ -235,7 +235,7 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = 'eval' dataset_df, dataset_path = vak.prep.frame_classification.frame_classification.prep_frame_classification_dataset( @@ -293,7 +293,7 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = 'eval' dataset_df, dataset_path = vak.prep.frame_classification.frame_classification.prep_frame_classification_dataset( @@ -340,7 +340,7 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( spect_format=None, options_to_change=dir_option_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = "train" with pytest.raises(NotADirectoryError): @@ -388,7 +388,7 @@ def test_prep_frame_classification_dataset_raises_file_not_found( spect_format=None, 
options_to_change=path_option_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = "train" with pytest.raises(FileNotFoundError): diff --git a/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py b/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py index d43ca7380..a58808d35 100644 --- a/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py +++ b/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py @@ -40,7 +40,7 @@ def test_get_or_make_source_files( spect_format, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # ---- set up ---- tmp_dataset_path = tmp_path / 'dataset_dir' diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index 150e6483a..ed800b872 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -36,7 +36,7 @@ def test_make_index_vectors_for_each_subsets( annot_format=annot_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) dataset_path = cfg.learncurve.dataset_path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) @@ -148,7 +148,7 @@ def test_make_subsets_from_dataset_df( annot_format=annot_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) dataset_path = cfg.learncurve.dataset_path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) diff --git a/tests/test_prep/test_frame_classification/test_make_splits.py b/tests/test_prep/test_frame_classification/test_make_splits.py index 5d5ef11cf..11037a5c3 100644 --- a/tests/test_prep/test_frame_classification/test_make_splits.py +++ b/tests/test_prep/test_frame_classification/test_make_splits.py @@ -88,7 +88,7 @@ def test_make_splits(config_type, model_name, audio_format, spect_format, annot_ audio_format, spect_format, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) # ---- set up ---- tmp_dataset_path = tmp_path / 'dataset_dir' diff --git a/tests/test_prep/test_prep.py b/tests/test_prep/test_prep.py index 8e995f8bc..e5bb0664b 100644 --- a/tests/test_prep/test_prep.py +++ b/tests/test_prep/test_prep.py @@ -48,7 +48,7 @@ def test_prep( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() # ---- test diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 9cddd7e17..38642872f 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -58,7 +58,7 @@ def test_train_frame_classification_model( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.train.frame_classification.train_frame_classification_model( @@ -111,7 +111,7 @@ def test_continue_training( spect_format=spect_format, 
options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.train.frame_classification.train_frame_classification_model( @@ -165,7 +165,7 @@ def test_train_raises_file_not_found( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() @@ -220,7 +220,7 @@ def test_train_raises_not_a_directory( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) # mock behavior of cli.train, building `results_path` from config option `root_results_dir` diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index a64516e0a..ab6ea672f 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -51,7 +51,7 @@ def test_train_parametric_umap_model( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.train.parametric_umap.train_parametric_umap_model( @@ -101,7 +101,7 @@ def test_train_parametric_umap_model_raises_file_not_found( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() @@ -153,7 +153,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( spect_format=None, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) # mock behavior of cli.train, building `results_path` from config option `root_results_dir` diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index 559853a24..5fb911e6e 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -46,7 +46,7 @@ def test_train( spect_format=spect_format, options_to_change=options_to_change, ) - cfg = vak.config.parse.from_toml_path(toml_path) + cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = tmp_path / 'results_path' From 500543ea31ce34d7f971c36c2f8e0ac71b135ad9 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 21:23:34 -0400 Subject: [PATCH 072/183] Make fixes in Config classmethods --- src/vak/cli/prep.py | 87 ++++++++++++++++++++++------------------ src/vak/config/config.py | 18 +++------ 2 files changed, 52 insertions(+), 53 deletions(-) diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index e9e993e7d..955c07945 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -9,40 +9,46 @@ from .. import config from .. 
---
 src/vak/cli/prep.py      | 87 ++++++++++++++++++++++------------
 src/vak/config/config.py | 18 +++------
 2 files changed, 52 insertions(+), 53 deletions(-)

diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py
index e9e993e7d..955c07945 100644
--- a/src/vak/cli/prep.py
+++ b/src/vak/cli/prep.py
@@ -9,40 +9,46 @@
 from .. import config
 from .. import prep as prep_module
-from ..config.parse import _load_toml_from_path
+from ..config.load import _load_toml_from_path
 from ..config.validators import are_tables_valid


 def purpose_from_toml(
-    config_toml: dict, toml_path: str | pathlib.Path | None = None
+    config_dict: dict, toml_path: str | pathlib.Path | None = None
 ) -> str:
-    """determine "purpose" from toml config,
+    """Determine "purpose" from toml config,
     i.e., the command that will be run after we ``prep`` the data.

-    By convention this is the other table in the config file
-    that correspond to a cli command besides '[PREP]'
+    By convention this is the other top-level table in the config file
+    that corresponds to a cli command besides ``[vak.prep]``, e.g. ``[vak.train]``.
     """
     # validate, make sure there aren't multiple commands in one config file first
-    are_tables_valid(config_toml, toml_path=toml_path)
+    are_tables_valid(config_dict, toml_path=toml_path)

     from ..cli.cli import CLI_COMMANDS  # avoid circular imports

-    commands_that_are_not_prep = (
+    commands_that_are_not_prep = [
         command for command in CLI_COMMANDS if command != "prep"
-    )
-    for command in commands_that_are_not_prep:
-        table_name = (
-            command.upper()
-        )  # we write table names in uppercase, e.g. `[PREP]`, by convention
-        if table_name in config_toml:
-            return table_name.lower()  # this is the "purpose" of the file
-
+    ]
+    purpose = None
+    for table_name in commands_that_are_not_prep:
+        if table_name in config_dict:
+            purpose = table_name  # this top-level table is the "purpose" of the file
+    if purpose is None:
+        raise ValueError(
+            "Did not find a top-level table in configuration file that corresponds to a CLI command. "
+            f"Configuration file path: {toml_path}\n"
+            f"Found the following top-level tables: {config_dict.keys()}\n"
+            f"Valid CLI commands besides ``prep`` (that correspond to top-level tables) are: {commands_that_are_not_prep}"
+        )
+    return purpose

 # note NO LOGGING -- we configure logger inside `core.prep`
 # so we can save log file inside dataset directory
 # see https://github.com/NickleDave/vak/issues/334
-SECTIONS_PREP_SHOULD_PARSE = ("prep", "spect_params", "dataloader")
+TABLES_PREP_SHOULD_PARSE = "prep"


 def prep(toml_path):
@@ -83,26 +89,25 @@ def prep(toml_path):
     """
     toml_path = pathlib.Path(toml_path)

-    # open here because need to check for `dataset_path` in this function, see #314 & #333
-    config_toml = _load_toml_from_path(toml_path)
-    # ---- figure out purpose of config file from tables; will save csv path in that table -------------------------
-    purpose = purpose_from_toml(config_toml, toml_path)
-    if (
-        "path" in config_toml[purpose.upper()]["path"]
-        and config_toml[purpose.upper()]["dataset"]["path"] is not None
-    ):
+    # open here because need to check whether the `dataset` already has a `path`, see #314 & #333
+    config_dict = _load_toml_from_path(toml_path)
+
+    # ---- figure out purpose of config file from tables; will save path of prep'd dataset in that table -------------------------
+    purpose = purpose_from_toml(config_dict, toml_path)
+    if "dataset" in config_dict[purpose] and "path" in config_dict[purpose]["dataset"]:
         raise ValueError(
-            f"config .toml file already has a 'dataset_path' option in the '{purpose.upper()}' table, "
-            f"and running `prep` would overwrite that value. To `prep` a new dataset, please remove "
-            f"the 'dataset_path' option from the '{purpose.upper()}' table in the config file:\n{toml_path}"
+            f"This configuration file already has a '{purpose}.dataset' table with a 'path' key, "
+            f"and running `prep` would overwrite the value for that key. To `prep` a new dataset, please "
+            "either create a new configuration file, or remove "
+            f"the 'path' key-value pair from the '{purpose}.dataset' table in the file:\n{toml_path}"
         )

-    # now that we've checked that, go ahead and parse the tables we want
-    cfg = config.parse.from_toml_path(
-        toml_path, tables=SECTIONS_PREP_SHOULD_PARSE
-    )
-    # notice we ignore any other option/values in the 'purpose' table,
+    # now that we've checked that, go ahead and parse just the prep table;
+    # we don't load the 'purpose' table into a config, to avoid error messages about e.g. non-existent paths
     # see https://github.com/NickleDave/vak/issues/334 and https://github.com/NickleDave/vak/issues/314
+    cfg = config.Config.from_toml_path(
+        toml_path, tables_to_parse=TABLES_PREP_SHOULD_PARSE
+    )
     if cfg.prep is None:
         raise ValueError(
             f"prep called with a config.toml file that does not have a [vak.prep] table: {toml_path}"
@@ -119,16 +124,14 @@ def prep(toml_path):
     )
     cfg.prep.labelset = None

-    table = purpose.upper()
-
-    dataset_df, dataset_path = prep_module.prep(
+    _, dataset_path = prep_module.prep(
         data_dir=cfg.prep.data_dir,
         purpose=purpose,
         dataset_type=cfg.prep.dataset_type,
         input_type=cfg.prep.input_type,
         audio_format=cfg.prep.audio_format,
         spect_format=cfg.prep.spect_format,
-        spect_params=cfg.spect_params,
+        spect_params=cfg.prep.spect_params,
         annot_format=cfg.prep.annot_format,
         annot_file=cfg.prep.annot_file,
         labelset=cfg.prep.labelset,
@@ -141,11 +144,15 @@ def prep(toml_path):
         num_replicates=cfg.prep.num_replicates,
     )

-    # use config and table from above to add dataset_path to config.toml file
-    config_toml[table]["dataset"]["path"] = str(dataset_path)
-
+    # we re-open config using tomlkit so we can add path to dataset table in style-preserving way
+    with toml_path.open('r') as fp:
+        tomldoc = tomlkit.load(fp)
+    if 'dataset' not in tomldoc['vak'][purpose]:
+        dataset_table = tomlkit.table()
+        tomldoc["vak"][purpose].add("dataset", dataset_table)
+    tomldoc["vak"][purpose]["dataset"].add("path", str(dataset_path))
     with toml_path.open("w") as fp:
-        tomlkit.dump(config_toml, fp)
+        tomlkit.dump(tomldoc, fp)

     # lastly, copy config to dataset directory root
     shutil.copy(src=toml_path, dst=dataset_path)
diff --git a/src/vak/config/config.py b/src/vak/config/config.py
index e9efbac0e..fa7e3f20e 100644
--- a/src/vak/config/config.py
+++ b/src/vak/config/config.py
@@ -97,7 +97,8 @@ def from_config_dict(
         This :func:`classmethod` expects the output
         of :func:`vak.config.load._load_toml_from_path`,
         that converts a :class:`tomlkit.TOMLDocument`
-        to a :class:`dict`.
+        to a :class:`dict`, and returns the :class:`dict`
+        that is accessed by the top-level key ``'vak'``.

         Parameters
         ----------
@@ -121,20 +122,11 @@ def from_config_dict(
         whose attributes correspond to
         the top-level tables in a config.toml file.
         """
-        try:
-            config_dict = config_dict['vak']
-        except KeyError as e:
-            raise KeyError(
-                "Did not find key 'vak' in `config_dict`."
-                "All top-level tables in toml configuration file must "
-                "use dotted names that begin with 'vak', e.g. 
``[vak.eval]``.\n" - f"`config_dict`:\n{config_dict}" - ) are_tables_valid(config_dict, toml_path) if tables_to_parse is None: tables_to_parse = list( - TABLE_CLASSES_MAP.keys() - ) # i.e., parse all tables + config_dict.keys() + ) # i.e., parse all top-level tables else: tables_to_parse = _validate_tables_to_parse_arg_convert_list(tables_to_parse) @@ -181,4 +173,4 @@ def from_toml_path( tables in a config.toml file. """ config_dict: dict = load._load_toml_from_path(toml_path) - return cls.from_toml(config_dict, toml_path, tables_to_parse) + return cls.from_config_dict(config_dict, toml_path, tables_to_parse) From 7ebd92e360478c534d05f719bd5f2a7d2c1750b7 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 21:24:08 -0400 Subject: [PATCH 073/183] Change load._load_toml_from_path again so that it returns config_dict['vak'], to avoid writing ['vak'] everywhere in calling functions --- src/vak/config/load.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/vak/config/load.py b/src/vak/config/load.py index 2b4a489c1..f661c4c78 100644 --- a/src/vak/config/load.py +++ b/src/vak/config/load.py @@ -6,12 +6,6 @@ import tomlkit import tomlkit.exceptions -from .config import Config -from .eval import EvalConfig -from .learncurve import LearncurveConfig -from .predict import PredictConfig -from .prep import PrepConfig -from .train import TrainConfig from .validators import are_keys_valid, are_tables_valid @@ -54,10 +48,22 @@ def _tomlkit_to_popo(d): def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: """Load a toml file from a path, and return as a :class:`dict`. + Notes + ----- Helper function to load toml config file, factored out to use in other modules when needed. Checks if ``toml_path`` exists before opening, - and tries to give a clear message if an error occurs when loading.""" + and tries to give a clear message if an error occurs when loading. + + Note also this function checks that the loaded :class:`dict` + has a single top-level key ``'vak'``, + and that it returns the :class:`dict` one level down + that is accessed with that key. + This avoids the need to write ``['vak']`` everywhere in + calling functions. + However it also means you need to add back that key + if you are *writing* a toml file. + """ toml_path = pathlib.Path(toml_path) if not toml_path.is_file(): raise FileNotFoundError(f".toml config file not found: {toml_path}") @@ -73,7 +79,8 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: if 'vak' not in config_dict: raise ValueError( "Toml file does not contain a top-level table named `vak`. " - f"Please see example configuration files here: " + "Please see example configuration files here:\n" + "https://github.com/vocalpy/vak/tree/main/doc/toml" ) # Next line, convert TOMLDocument returned by tomlkit.load to a dict. @@ -84,5 +91,4 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: # and then assigns it to the ``spect_params`` key. # We would get this error if we just return the result of :func:`tomlkit.load`, # which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. 
-    return _tomlkit_to_popo(config_dict)
-
+    return _tomlkit_to_popo(config_dict)['vak']

From acf4cc3ff289af823f7c684d6b5acda3f96c66db Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:24:27 -0400
Subject: [PATCH 074/183] Add docstring to are_tables_valid in
 config/validators.py

---
 src/vak/config/validators.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py
index da8d55a04..9a1a39ca0 100644
--- a/src/vak/config/validators.py
+++ b/src/vak/config/validators.py
@@ -68,6 +68,11 @@ def is_spect_format(instance, attribute, value):


 def are_tables_valid(config_dict, toml_path=None):
+    """Validate top-level tables in a :class:`dict`.
+
+    This function expects the ``config_dict``
+    returned by :func:`vak.config.load._load_toml_from_path`.
+    """
     tables = list(config_dict.keys())

     from ..cli.cli import CLI_COMMANDS  # avoid circular import

From 42e92c7a717cd40755cfe1e86b47726789df9005 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:24:48 -0400
Subject: [PATCH 075/183] Lowercase config table names in
 tests/scripts/vaktestdata/configs.py

---
 tests/scripts/vaktestdata/configs.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/scripts/vaktestdata/configs.py b/tests/scripts/vaktestdata/configs.py
index 1741e5b1b..869482fc5 100644
--- a/tests/scripts/vaktestdata/configs.py
+++ b/tests/scripts/vaktestdata/configs.py
@@ -62,9 +62,9 @@ def add_dataset_path_from_prepped_configs():
     for config_metadata in configs_to_change:
         config_to_change_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.filename
         if config_metadata.config_type == 'train_continue':
-            section = 'TRAIN'
+            section = 'train'
         else:
-            section = config_metadata.config_type.upper()
+            section = config_metadata.config_type

         config_dataset_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.use_dataset_from_config

@@ -129,7 +129,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr
     # these are the only options whose values we need to change
     # and they are the same for both predict and eval
     checkpoint_path = sorted(results_dir.glob("**/checkpoints/checkpoint.pt"))[0]
-    if 'normalize_spectrograms' in config_toml['TRAIN'] and config_toml['TRAIN']['normalize_spectrograms']:
+    if 'normalize_spectrograms' in config_toml['train'] and config_toml['train']['normalize_spectrograms']:
         spect_scaler_path = sorted(results_dir.glob("StandardizeSpect"))[0]
     else:
         spect_scaler_path = None

From 61ea4e4cf98e0c85114e15f44b35b4f957433db0 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:25:57 -0400
Subject: [PATCH 076/183] In tests/scripts/vaktestdata/source_files.py, change
 cfg.spect_params -> cfg.prep.spect_params, fix how we change values in toml,
 add tables_to_parse arg to call to Config.from_toml_path

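A sketch of the asymmetry this commit works around (made-up path): reading a
config goes through `vak.config.load._load_toml_from_path`, which now returns
the table under the top-level 'vak' key, but re-writing a config re-opens it
with tomlkit, so there we have to index with ['vak'] ourselves:

    config_dict = vak.config.load._load_toml_from_path(config_path)
    config_dict['prep']['data_dir']      # no ['vak'] needed after loading

    with config_path.open('r') as fp:
        tomldoc = tomlkit.load(fp)       # whole TOMLDocument, 'vak' key intact
    tomldoc['vak']['prep']['data_dir'] = str(data_dir)
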
---
 tests/scripts/vaktestdata/source_files.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/scripts/vaktestdata/source_files.py b/tests/scripts/vaktestdata/source_files.py
index 2f7c032eb..43bea23e9 100644
--- a/tests/scripts/vaktestdata/source_files.py
+++ b/tests/scripts/vaktestdata/source_files.py
@@ -47,14 +47,14 @@ def set_up_source_files_and_csv_files_for_frame_classification_models():
             f"\nRunning :func:`vak.prep.frame_classification.get_or_make_source_files` to generate data for tests, "
             f"using config:\n{config_path.name}"
         )
-        cfg = vak.config.Config.from_toml_path(config_path)
+        cfg = vak.config.Config.from_toml_path(config_path, tables_to_parse='prep')

         source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files(
             data_dir=cfg.prep.data_dir,
             input_type=cfg.prep.input_type,
             audio_format=cfg.prep.audio_format,
             spect_format=cfg.prep.spect_format,
-            spect_params=cfg.spect_params,
+            spect_params=cfg.prep.spect_params,
             spect_output_dir=spect_output_dir,
             annot_format=cfg.prep.annot_format,
             annot_file=cfg.prep.annot_file,
@@ -103,9 +103,9 @@ def set_up_source_files_and_csv_files_for_frame_classification_models():
         )

         with config_path.open("r") as fp:
-            config_toml = tomlkit.load(fp)
+            tomldoc = tomlkit.load(fp)
         data_dir = constants.GENERATED_TEST_DATA_ROOT / config_metadata.data_dir
-        config_toml['PREP']['data_dir'] = str(data_dir)
+        tomldoc['vak']['prep']['data_dir'] = str(data_dir)
         with config_path.open("w") as fp:
            tomlkit.dump(config_toml, fp)

@@ -116,7 +116,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models():
             input_type=cfg.prep.input_type,
             audio_format=cfg.prep.audio_format,
             spect_format=cfg.prep.spect_format,
-            spect_params=cfg.spect_params,
+            spect_params=cfg.prep.spect_params,
             spect_output_dir=None,
             annot_format=cfg.prep.annot_format,
             annot_file=cfg.prep.annot_file,
@@ -165,7 +165,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models():
         input_type=cfg.prep.input_type,
         audio_format=cfg.prep.audio_format,
         spect_format=cfg.prep.spect_format,
-        spect_params=cfg.spect_params,
+        spect_params=cfg.prep.spect_params,
         spect_output_dir=None,
         annot_format=cfg.prep.annot_format,
         annot_file=cfg.prep.annot_file,

From dc19398cf86c17603e5aa3ecbe919fe021a56657 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:26:22 -0400
Subject: [PATCH 077/183] in test_cli/test_prep.py, call vak.config.load not
 vak.config.parse

---
 tests/test_cli/test_prep.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_cli/test_prep.py b/tests/test_cli/test_prep.py
index cfdd8453e..3361f266e 100644
--- a/tests/test_cli/test_prep.py
+++ b/tests/test_cli/test_prep.py
@@ -35,7 +35,7 @@ def test_purpose_from_toml(
         annot_format=annot_format,
         spect_format=spect_format,
     )
-    config_toml = vak.config.parse._load_toml_from_path(toml_path)
+    config_toml = vak.config.load._load_toml_from_path(toml_path)
     vak.cli.prep.purpose_from_toml(config_toml)

From 95e424cc8565b8794d2045694ff6efe68d2172 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:47:44 -0400
Subject: [PATCH 078/183] Fix how we instantiate DatasetConfig and ModelConfig
 in EvalConfig.from_config_dict method

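The reason to go through the classmethods (a sketch, with a hypothetical
model table): `ModelConfig.from_config_dict` expects a table keyed by model
name, and pulls the name out itself, which plain ``ModelConfig(**...)``
cannot do:

    model_table = {'TweetyNet': {'network': {...}, 'optimizer': {'lr': 0.001}}}
    ModelConfig(**model_table)                 # fails -- 'TweetyNet' is not an attribute
    ModelConfig.from_config_dict(model_table)  # extracts the name, then the sub-tables

`DatasetConfig.from_config_dict` gets the same treatment so both nested
tables are converted the same way.
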
---
 src/vak/config/eval.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py
index 4e87604c6..64f764dd3 100644
--- a/src/vak/config/eval.py
+++ b/src/vak/config/eval.py
@@ -200,9 +200,8 @@ def from_config_dict(cls, config_dict: dict) -> EvalConfig:
                 "when loading the configuration file into a Python dictionary. "
                 "Please check that the configuration file is formatted correctly."
             )
-        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
-        config_dict['model'] = ModelConfig(**config_dict['model'])
+        config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset'])
+        config_dict['model'] = ModelConfig.from_config_dict(config_dict['model'])
         return cls(
             **config_dict
         )
-

From 5481ec8ada283beb58d114c4754d88ebc1b5c3c9 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:49:12 -0400
Subject: [PATCH 079/183] Fix how we instantiate DatasetConfig and ModelConfig
 in PredictConfig.from_config_dict method

---
 src/vak/config/predict.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py
index d318ee89d..a7f7c826d 100644
--- a/src/vak/config/predict.py
+++ b/src/vak/config/predict.py
@@ -159,8 +159,8 @@ def from_config_dict(cls, config_dict: dict) -> PredictConfig:
                 "when loading the configuration file into a Python dictionary. "
                 "Please check that the configuration file is formatted correctly."
             )
-        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
-        config_dict['model'] = ModelConfig(**config_dict['model'])
+        config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset'])
+        config_dict['model'] = ModelConfig.from_config_dict(config_dict['model'])
         return cls(
             **config_dict
         )

From 1733166b51c3ba02bb1710ce121e18006e776091 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:49:23 -0400
Subject: [PATCH 080/183] Fix how we instantiate DatasetConfig and ModelConfig
 in TrainConfig.from_config_dict method

---
 src/vak/config/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/vak/config/train.py b/src/vak/config/train.py
index b6b5879d4..135a05541 100644
--- a/src/vak/config/train.py
+++ b/src/vak/config/train.py
@@ -169,8 +169,8 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
                 "when loading the configuration file into a Python dictionary. "
                 "Please check that the configuration file is formatted correctly."
             )
-        config_dict['model'] = ModelConfig(**config_dict['model'])
-        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
+        config_dict['model'] = ModelConfig.from_config_dict(config_dict['model'])
+        config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset'])
         return cls(
             **config_dict
-        )
\ No newline at end of file
+        )

From d729abdecef69f7e4344b87d44d623e1bd31a16c Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:49:36 -0400
Subject: [PATCH 081/183] Fix how we instantiate DatasetConfig and ModelConfig
 in LearncurveConfig.from_config_dict method

---
 src/vak/config/learncurve.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py
index 28d3237a9..4df1e0423 100644
--- a/src/vak/config/learncurve.py
+++ b/src/vak/config/learncurve.py
@@ -4,7 +4,9 @@
 from attrs import define, field
 from attrs import converters, validators

+from .dataset import DatasetConfig
 from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
+from .model import ModelConfig
 from .train import TrainConfig


@@ -90,8 +92,8 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig":
                 "when loading the configuration file into a Python dictionary. "
                 "Please check that the configuration file is formatted correctly."
             )
-        config_dict['model'] = ModelConfig(**config_dict['model'])
-        config_dict['dataset'] = DatasetConfig(**config_dict['dataset'])
+        config_dict['model'] = ModelConfig.from_config_dict(config_dict['model'])
+        config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset'])
         return cls(
             **config_dict
-        )
\ No newline at end of file
+        )

From 9185eb9b8cc0453de3775396577264023815b999 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:49:57 -0400
Subject: [PATCH 082/183] Remove breakpoint in src/vak/config/model.py

---
 src/vak/config/model.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/vak/config/model.py b/src/vak/config/model.py
index aacc6484d..c1a4ce0cf 100644
--- a/src/vak/config/model.py
+++ b/src/vak/config/model.py
@@ -59,10 +59,7 @@ def from_config_dict(cls, config_dict: dict):
         config_dict = vak.config.parse.from_toml_path(toml_path)
         model_config = vak.config.Model.from_config_dict(config_dict['train'])
     """
-        try:
-            model_name = list(config_dict.keys())
-        except:
-            breakpoint()
+        model_name = list(config_dict.keys())
         if len(model_name) == 0:
             raise ValueError(
                 "Did not find a single key in `config_dict` corresponding to model name. "

From 3341b08a3442595dc644d68c52e2dc55dfee43 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 21:51:00 -0400
Subject: [PATCH 083/183] Fix wrong variable name so we save configs correctly
 in tests/scripts/vaktestdata/source_files.py, and add tables_to_parse arg to
 Config.from_toml_path, so we don't get 'missing dataset' errors

---
 tests/scripts/vaktestdata/source_files.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/scripts/vaktestdata/source_files.py b/tests/scripts/vaktestdata/source_files.py
index 43bea23e9..d1f25bc4e 100644
--- a/tests/scripts/vaktestdata/source_files.py
+++ b/tests/scripts/vaktestdata/source_files.py
@@ -107,9 +107,9 @@ def set_up_source_files_and_csv_files_for_frame_classification_models():
         data_dir = constants.GENERATED_TEST_DATA_ROOT / config_metadata.data_dir
         tomldoc['vak']['prep']['data_dir'] = str(data_dir)
         with config_path.open("w") as fp:
-            tomlkit.dump(config_toml, fp)
+            tomlkit.dump(tomldoc, fp)

-        cfg = vak.config.Config.from_toml_path(config_path)
+        cfg = vak.config.Config.from_toml_path(config_path, tables_to_parse='prep')

         source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files(
             data_dir=cfg.prep.data_dir,
@@ -159,7 +159,7 @@ def set_up_source_files_and_csv_files_for_frame_classification_models():
         f"\nRunning :func:`vak.prep.frame_classification.get_or_make_source_files` to generate data for tests, "
         f"using config:\n{config_path.name}"
     )
-    cfg = vak.config.Config.from_toml_path(config_path)
+    cfg = vak.config.Config.from_toml_path(config_path, tables_to_parse='prep')
     source_files_df: pd.DataFrame = vak.prep.frame_classification.get_or_make_source_files(
         data_dir=cfg.prep.data_dir,
         input_type=cfg.prep.input_type,

From b05098975f9cdc44a189bc3bb54d5e1b0a7d6c01 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Thu, 2 May 2024 22:09:58 -0400
Subject: [PATCH 084/183] Fix how we re-write configs, in
 tests/scripts/vaktestdata/configs.py

---
 tests/scripts/vaktestdata/configs.py | 52 +++++++++++++++-------------
 1 file changed, 27 insertions(+), 25 deletions(-)

diff --git a/tests/scripts/vaktestdata/configs.py b/tests/scripts/vaktestdata/configs.py
index 869482fc5..98d31a265 100644
--- a/tests/scripts/vaktestdata/configs.py
+++ b/tests/scripts/vaktestdata/configs.py
@@ -48,10 +48,9 @@
def add_dataset_path_from_prepped_configs(): """This helper function goes through all configs in :data:`vaktestdata.constants.CONFIG_METADATA` and for any that have a filename for the attribute - "use_dataset_from_config", it sets the option 'dataset_path' + "use_dataset_from_config", it sets the key 'path' in the 'dataset' table in the config file that the metadata corresponds to - to the same option from the file specified - by the attribute. + to the same value from the file specified by the attribute. """ configs_to_change = [ config_metadata @@ -62,27 +61,30 @@ def add_dataset_path_from_prepped_configs(): for config_metadata in configs_to_change: config_to_change_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.filename if config_metadata.config_type == 'train_continue': - section = 'train' + table_to_add_dataset = 'train' else: - section = config_metadata.config_type + table_to_add_dataset = config_metadata.config_type config_dataset_path = constants.GENERATED_TEST_CONFIGS_ROOT / config_metadata.use_dataset_from_config - with config_dataset_path.open("r") as fp: - dataset_config_toml = tomlkit.load(fp) - purpose = vak.cli.prep.purpose_from_toml(dataset_config_toml) + config_dict = vak.config.load._load_toml_from_path(config_dataset_path) # next line, we can't use `section` here because we could get a KeyError, - # e.g., when the config we are rewriting is an EVAL config, but - # the config we are getting the dataset from is a TRAIN config. + # e.g., when the config we are rewriting is an ``[vak.eval]`` config, but + # the config we are getting the dataset from is a ``[vak.train]`` config. # so instead we use `purpose_from_toml` to get the `purpose` # of the config we are getting the dataset from. - dataset_config_section = purpose.upper() # need to be 'TRAIN', not 'train' - dataset_path = dataset_config_toml[dataset_config_section]['dataset_path'] - with config_to_change_path.open("r") as fp: - config_to_change_toml = tomlkit.load(fp) - config_to_change_toml[section]['dataset_path'] = dataset_path + dataset_config_section = vak.cli.prep.purpose_from_toml(config_dict) + dataset_path = config_dict[dataset_config_section]['dataset']['path'] + + # we open config using tomlkit so we can add path to dataset table in style-preserving way + with config_to_change_path.open('r') as fp: + tomldoc = tomlkit.load(fp) + if 'dataset' not in tomldoc['vak'][table_to_add_dataset]: + dataset_table = tomlkit.table() + tomldoc["vak"][table_to_add_dataset].add("dataset", dataset_table) + tomldoc["vak"][table_to_add_dataset]["dataset"].add("path", str(dataset_path)) with config_to_change_path.open("w") as fp: - tomlkit.dump(config_to_change_toml, fp) + tomlkit.dump(tomldoc, fp) def fix_options_in_configs(config_metadata_list, command, single_train_result=True): @@ -104,7 +106,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # which are checkpoint_path, spect_scaler_path, and labelmap_path with config_to_use_result_from.open("r") as fp: config_toml = tomlkit.load(fp) - root_results_dir = pathlib.Path(config_toml["TRAIN"]["root_results_dir"]) + root_results_dir = pathlib.Path(config_toml["vak"]["train"]["root_results_dir"]) results_dir = sorted(root_results_dir.glob("results_*")) if len(results_dir) > 1: if single_train_result: @@ -129,7 +131,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr # these are the only options whose values we need to change # and they are the same for both predict and eval checkpoint_path = 
sorted(results_dir.glob("**/checkpoints/checkpoint.pt"))[0] - if 'normalize_spectrograms' in config_toml['train'] and config_toml['train']['normalize_spectrograms']: + if 'normalize_spectrograms' in config_toml["vak"]['train'] and config_toml["vak"]['train']['normalize_spectrograms']: spect_scaler_path = sorted(results_dir.glob("StandardizeSpect"))[0] else: spect_scaler_path = None @@ -152,20 +154,20 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr config_toml = tomlkit.load(fp) if command == 'train_continue': - section = 'TRAIN' + table = 'train' else: - section = command.upper() + table = command - config_toml[section]["checkpoint_path"] = str(checkpoint_path) + config_toml["vak"][table]["checkpoint_path"] = str(checkpoint_path) if spect_scaler_path: - config_toml[section]["spect_scaler_path"] = str(spect_scaler_path) + config_toml["vak"][table]["spect_scaler_path"] = str(spect_scaler_path) else: - if 'spect_scaler_path' in config_toml[section]: + if 'spect_scaler_path' in config_toml[table]: # remove any existing 'spect_scaler_path' option - del config_toml[section]["spect_scaler_path"] + del config_toml["vak"][table]["spect_scaler_path"] if command != 'train_continue': # train always gets labelmap from dataset dir, not from a config option if labelmap_path is not None: - config_toml[section]["labelmap_path"] = str(labelmap_path) + config_toml["vak"][table]["labelmap_path"] = str(labelmap_path) with config_to_fix.open("w") as fp: tomlkit.dump(config_toml, fp) From f6fdec69a3bfe6352006bb827eeee88b430c6e25 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 22:44:23 -0400 Subject: [PATCH 085/183] Add model and dataset tables to get those keys in top-level tables, in src/vak/config/valid-version-1.0.toml --- src/vak/config/valid-version-1.0.toml | 36 +++++++++++++++++++++------ 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/src/vak/config/valid-version-1.0.toml b/src/vak/config/valid-version-1.0.toml index 2103c1166..7ebcb1556 100644 --- a/src/vak/config/valid-version-1.0.toml +++ b/src/vak/config/valid-version-1.0.toml @@ -34,9 +34,7 @@ timebins_key = 't' audio_path_key = 'audio_path' [vak.train] -model = 'TweetyNet' root_results_dir = './tests/test_data/results/train' -dataset_path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' num_workers = 4 device = 'cuda' batch_size = 11 @@ -54,12 +52,17 @@ train_dataset_params = {'window_size' = 80} val_transform_params = {'resize' = 128} val_dataset_params = {'window_size' = 80} +[vak.train.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' + +[vak.train.model.TweetyNet] + [vak.eval] -dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' output_dir = './tests/test_data/prep/learncurve' -model = 'TweetyNet' batch_size = 11 num_workers = 4 device = 'cuda' @@ -68,8 +71,14 @@ post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} transform_params = {'resize' = 128} dataset_params = {'window_size' = 80} +[vak.eval.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' + 
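+# As with the 'dataset' table above, each command table now nests its model
+# as a sub-table, e.g. the next table header, instead of a 'model = ...' key.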
+[vak.eval.model.TweetyNet] + [vak.learncurve] -model = 'TweetyNet' root_results_dir = './tests/test_data/results/learncurve' batch_size = 11 num_epochs = 2 @@ -78,7 +87,6 @@ shuffle = true val_step = 1 ckpt_step = 1 patience = 4 -dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' results_dir_made_by_main_script = '/some/path/to/learncurve/' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} num_workers = 4 @@ -88,13 +96,18 @@ train_dataset_params = {'window_size' = 80} val_transform_params = {'resize' = 128} val_dataset_params = {'window_size' = 80} +[vak.learncurve.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' + +[vak.learncurve.model.TweetyNet] + [vak.predict] -dataset_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' labelmap_path = '/home/user/results_181014_194418/labelmap.json' annot_csv_filename = '032312_prep_191224_225910.annot.csv' output_dir = './tests/test_data/prep/learncurve' -model = 'TweetyNet' batch_size = 11 num_workers = 4 device = 'cuda' @@ -104,3 +117,10 @@ majority_vote = false save_net_outputs = false transform_params = {'resize' = 128} dataset_params = {'window_size' = 80} + +[vak.predict.dataset] +name = 'IntlDistributedSongbirdConsortiumPack' +path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' +splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' + +[vak.predict.model.TweetyNet] \ No newline at end of file From 8d49aef301ef685fb1263de46572c31e79924903 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 22:54:01 -0400 Subject: [PATCH 086/183] Change cfg.table.dataset_path -> cfg.table.dataset.path in vak/cli modules (e.g., vak.train.dataset.path) --- src/vak/cli/eval.py | 4 ++-- src/vak/cli/learncurve.py | 4 ++-- src/vak/cli/predict.py | 4 ++-- src/vak/cli/train.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 736edda91..2a8a95bbe 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -39,7 +39,7 @@ def eval(toml_path): logger.info("Logging results to {}".format(cfg.eval.output_dir)) - if cfg.eval.dataset_path is None: + if cfg.eval.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." f"To generate a .csv file that represents the dataset, " @@ -49,7 +49,7 @@ def eval(toml_path): eval_module.eval( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 9b6991816..09f76692e 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -45,7 +45,7 @@ def learning_curve(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - if cfg.learncurve.dataset_path is None: + if cfg.learncurve.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." 
f"To generate a .csv file that represents the dataset, " @@ -55,7 +55,7 @@ def learning_curve(toml_path): learncurve.learning_curve( model_name=cfg.learncurve.model.name, model_config=cfg.learncurve.model.to_dict(), - dataset_path=cfg.learncurve.dataset_path, + dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 19c4648da..feb976357 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -35,7 +35,7 @@ def predict(toml_path): log_version(logger) logger.info("Logging results to {}".format(cfg.prep.output_dir)) - if cfg.predict.dataset_path is None: + if cfg.predict.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." f"To generate a .csv file that represents the dataset, " @@ -45,7 +45,7 @@ def predict(toml_path): predict_module.predict( model_name=cfg.predict.model.name, model_config=cfg.predict.model.to_dict(), - dataset_path=cfg.predict.dataset_path, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 97fddb133..3487958ba 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -45,7 +45,7 @@ def train(toml_path): log_version(logger) logger.info("Logging results to {}".format(results_path)) - if cfg.train.dataset_path is None: + if cfg.train.dataset.path is None: raise ValueError( "No value is specified for 'dataset_path' in this .toml config file." f"To generate a .csv file that represents the dataset, " @@ -55,7 +55,7 @@ def train(toml_path): train_module.train( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, train_transform_params=cfg.train.train_transform_params, train_dataset_params=cfg.train.train_dataset_params, val_transform_params=cfg.train.val_transform_params, From 27982825ae1c53a4748480e57de794f6221ff955 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 22:55:14 -0400 Subject: [PATCH 087/183] Get tests passing for tests/test_config/test_eval.py --- tests/test_config/test_eval.py | 106 +++++++++++++++++++++++++++++++-- 1 file changed, 100 insertions(+), 6 deletions(-) diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index 9b0ce2793..e762b0155 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -1,10 +1,104 @@ """tests for vak.config.eval module""" -import vak.config.eval +import pytest +import vak.config -def test_predict_attrs_class(all_generated_eval_configs_toml): - """test that instantiating EvalConfig class works as expected""" - for config_toml in all_generated_eval_configs_toml: - eval_section = config_toml["EVAL"] - config = vak.config.eval.EvalConfig(**eval_section) + +class TestEval: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': 
'~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + eval_config = vak.config.EvalConfig(**config_dict) + + assert isinstance(eval_config, vak.config.EvalConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_from_config_dict(self, config_dict): + eval_config = vak.config.EvalConfig.from_config_dict(config_dict) + + assert isinstance(eval_config, vak.config.EvalConfig) + + def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): + """test that instantiating EvalConfig class works as expected""" + eval_section = a_generated_eval_config_dict["eval"] + config = vak.config.eval.EvalConfig.from_config_dict(eval_section) assert isinstance(config, vak.config.eval.EvalConfig) From 0f6cc5a06dda6799bde98f28e4515b7aafc8dc83 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 23:05:33 -0400 Subject: [PATCH 088/183] Clean up tests/test_config/test_eval.py --- tests/test_config/test_eval.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index e762b0155..9d603d48d 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -98,7 +98,8 @@ def test_from_config_dict(self, config_dict): assert isinstance(eval_config, vak.config.EvalConfig) def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): - """test that instantiating EvalConfig class works as expected""" - eval_section = a_generated_eval_config_dict["eval"] - config = 
vak.config.eval.EvalConfig.from_config_dict(eval_section) - assert isinstance(config, vak.config.eval.EvalConfig) + eval_table = a_generated_eval_config_dict["eval"] + + eval_config = vak.config.eval.EvalConfig.from_config_dict(eval_table) + + assert isinstance(eval_config, vak.config.eval.EvalConfig) From 311287fe08b74157380baabec4d5531a064205ea Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 2 May 2024 23:14:22 -0400 Subject: [PATCH 089/183] Get tests passing in tests/test_config/test_predict.py --- tests/test_config/test_predict.py | 103 ++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 6 deletions(-) diff --git a/tests/test_config/test_predict.py b/tests/test_config/test_predict.py index c6fcb22f3..7c8a8eef4 100644 --- a/tests/test_config/test_predict.py +++ b/tests/test_config/test_predict.py @@ -1,10 +1,101 @@ """tests for vak.config.predict module""" +import pytest + import vak.config.predict -def test_predict_attrs_class(all_generated_predict_configs_toml): - """test that instantiating PredictConfig class works as expected""" - for config_toml in all_generated_predict_configs_toml: - predict_section = config_toml["PREDICT"] - config = vak.config.predict.PredictConfig(**predict_section) - assert isinstance(config, vak.config.predict.PredictConfig) +class TestPredictConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + predict_config = vak.config.PredictConfig(**config_dict) + + assert isinstance(predict_config, vak.config.PredictConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], 
+ 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ] + ) + def test_from_config_dict(self, config_dict): + predict_config = vak.config.PredictConfig.from_config_dict(config_dict) + + assert isinstance(predict_config, vak.config.PredictConfig) + + def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict): + predict_table = a_generated_predict_config_dict["predict"] + + predict_config = vak.config.predict.PredictConfig.from_config_dict(predict_table) + + assert isinstance(predict_config, vak.config.predict.PredictConfig) From 5c88bb0f47ceaf52b3e718245e76bd9af6a451e0 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:18:12 -0400 Subject: [PATCH 090/183] Fix how we access config_toml in tests/scripts/vaktestdata/configs.py -- missing 'vak' key --- tests/scripts/vaktestdata/configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scripts/vaktestdata/configs.py b/tests/scripts/vaktestdata/configs.py index 98d31a265..3134f7f57 100644 --- a/tests/scripts/vaktestdata/configs.py +++ b/tests/scripts/vaktestdata/configs.py @@ -162,7 +162,7 @@ def fix_options_in_configs(config_metadata_list, command, single_train_result=Tr if spect_scaler_path: config_toml["vak"][table]["spect_scaler_path"] = str(spect_scaler_path) else: - if 'spect_scaler_path' in config_toml[table]: + if 'spect_scaler_path' in config_toml["vak"][table]: # remove any existing 'spect_scaler_path' option del config_toml["vak"][table]["spect_scaler_path"] if command != 'train_continue': # train always gets labelmap from dataset dir, not from a config option From 5db3d2f378b50a5e6e7e9459ba49cffd7a7f70a0 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:18:36 -0400 Subject: [PATCH 091/183] Add pytest.mark.parametrize to tests/test_config/test_learncurve.py --- tests/test_config/test_learncurve.py | 110 +++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 5 deletions(-) diff --git a/tests/test_config/test_learncurve.py b/tests/test_config/test_learncurve.py index 6fa147c75..08e1d5c1d 100644 --- a/tests/test_config/test_learncurve.py +++ b/tests/test_config/test_learncurve.py @@ -1,10 +1,110 @@ """tests for vak.config.learncurve module""" +import pytest + import vak.config.learncurve -def test_learncurve_attrs_class(all_generated_learncurve_configs_toml): - """test that instantiating LearncurveConfig class works as expected""" - for config_toml in all_generated_learncurve_configs_toml: - learncurve_section = config_toml["LEARNCURVE"] - config = vak.config.learncurve.LearncurveConfig(**learncurve_section) +class TestLearncurveConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': 
[4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + learncurve_config = vak.config.LearncurveConfig(**config_dict) + + assert isinstance(learncurve_config, vak.config.LearncurveConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_from_config_dict(self, config_dict): + learncurve_config = vak.config.LearncurveConfig.from_config_dict(config_dict) + + assert isinstance(learncurve_config, vak.config.LearncurveConfig) + + def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_dict): + """test that instantiating LearncurveConfig class works as expected""" + learncurve_table = a_generated_learncurve_config_dict["learncurve"] + + config = vak.config.learncurve.LearncurveConfig.from_config_dict( + learncurve_table + ) + assert isinstance(config, vak.config.learncurve.LearncurveConfig) From 3180d48a71a894757a90a4d3fdb090025fe5a0a8 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:18:54 -0400 Subject: [PATCH 092/183] Rewrite tests in tests/test_config/test_train.py --- tests/test_config/test_train.py | 107 ++++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 6 deletions(-) diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py index 1cde1db3e..5d4b3fee2 100644 --- a/tests/test_config/test_train.py +++ b/tests/test_config/test_train.py @@ -1,10 +1,105 @@ """tests for vak.config.train module""" +import pytest + import vak.config.train -def test_train_attrs_class(all_generated_train_configs_toml_path_pairs): - """test that instantiating TrainConfig class works as expected""" - for config_toml, toml_path in all_generated_train_configs_toml_path_pairs: - train_section = config_toml["TRAIN"] - train_config_obj = vak.config.train.TrainConfig(**train_section) - assert isinstance(train_config_obj, vak.config.train.TrainConfig) +class TestTrainConfig: + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': 
'./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_init(self, config_dict): + config_dict['model'] = vak.config.ModelConfig.from_config_dict(config_dict['model']) + config_dict['dataset'] = vak.config.DatasetConfig.from_config_dict(config_dict['dataset']) + + train_config = vak.config.TrainConfig(**config_dict) + + assert isinstance(train_config, vak.config.TrainConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + } + ] + ) + def test_from_config_dict(self, config_dict): + train_config = vak.config.TrainConfig.from_config_dict(config_dict) + + assert isinstance(train_config, vak.config.TrainConfig) + + def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): + train_table = a_generated_train_config_dict["train"] + + train_config = vak.config.train.TrainConfig.from_config_dict(train_table) + + assert isinstance(train_config, vak.config.train.TrainConfig) From 7c2a1f07783ef023d698bdb5e0b647a0d7bceceb Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:19:09 -0400 Subject: [PATCH 093/183] Rewrite tests in tests/test_config/test_config.py --- tests/test_config/test_config.py | 45 ++++++++++++++++---------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index ebf814896..8d2ced413 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -1,26 +1,27 @@ import vak.config -def test_config_attrs_class( - all_generated_configs_toml_path_pairs, - default_model, -): - """test that instantiating Config class works as expected""" - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if default_model not in str(toml_path): - continue # only need to check configs for one model - # also avoids FileNotFoundError on CI - # this is basically the body of the ``config.parse.from_toml`` function. 
-        config_dict = {}
-        for section_name in list(vak.config.parse.SECTION_CLASSES.keys()):
-            if section_name in config_toml:
-                vak.config.validators.are_options_valid(
-                    config_toml, section_name, toml_path
-                )
-                section = vak.config.parse.parse_config_section(
-                    config_toml, section_name, toml_path
-                )
-                config_dict[section_name.lower()] = section
+class TestConfig:
+    def test_from_config_dict_with_real_config(
+        self,
+        a_generated_config_dict,
+    ):
+        """test that instantiating Config class works as expected"""
+        # this is basically the body of the ``config.load.from_toml`` function.
+        config_kwargs = {}
+        for table_name in a_generated_config_dict:
+            config_kwargs[table_name] = vak.config.load.TABLE_CLASSES_MAP[table_name].from_config_dict(
+                a_generated_config_dict[table_name]
+            )

-        config = vak.config.parse.Config(**config_dict)
-        assert isinstance(config, vak.config.parse.Config)
+        config = vak.config.load.Config(**config_kwargs)
+
+        assert isinstance(config, vak.config.load.Config)
+        # we already test that config loading works for EvalConfig, et al.,
+        # so here we just test that the logic of Config works as expected:
+        # we should get an attribute for each top-level table that we pass in;
+        # if we don't pass one in, then its corresponding attribute should be None
+        for table_name in ('eval', 'learncurve', 'predict', 'prep', 'train'):
+            if table_name in a_generated_config_dict:
+                assert hasattr(config, table_name)
+            else:
+                assert getattr(config, table_name) is None

From 18b6dc2d65879fa04f5bea456db6cf9ec7b9108d Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 07:24:50 -0400
Subject: [PATCH 094/183] Add unit test to tests/test_config/test_model.py

---
 tests/test_config/test_model.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/tests/test_config/test_model.py b/tests/test_config/test_model.py
index 881e31f18..73745df79 100644
--- a/tests/test_config/test_model.py
+++ b/tests/test_config/test_model.py
@@ -4,6 +4,7 @@


 class TestModelConfig:
+
     @pytest.mark.parametrize(
         'config_dict',
         [
@@ -103,3 +104,23 @@ def test_from_config_dict_real_config(self, a_generated_config_dict):
             else:
                 assert getattr(model_config, attr) == {}

+    @pytest.mark.parametrize(
+        'config_dict, expected_exception',
+        [
+            # Raises ValueError because model name not in registry
+            (
+                {
+                    'NonExistentModel': {
+                        'network': {},
+                        'optimizer': {},
+                        'loss': {},
+                        'metrics': {},
+                    }
+                },
+                ValueError,
+            )
+        ]
+    )
+    def test_from_config_dict_raises(self, config_dict, expected_exception):
+        with pytest.raises(expected_exception):
+            vak.config.model.ModelConfig.from_config_dict(config_dict)

From a4c06d59ad9b6b7d9550e6e29e0d54c2afad6ca3 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 07:29:20 -0400
Subject: [PATCH 095/183] Add unit test for exceptions in
 tests/test_config/test_eval.py

---
 tests/test_config/test_eval.py | 66 ++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py
index 9d603d48d..0ee7bfdb2 100644
--- a/tests/test_config/test_eval.py
+++ b/tests/test_config/test_eval.py
@@ -103,3 +103,69 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict):
         eval_config = vak.config.eval.EvalConfig.from_config_dict(eval_table)

         assert isinstance(eval_config, vak.config.eval.EvalConfig)
+
+    @pytest.mark.parametrize(
+        'config_dict, expected_exception',
+        [
+            # missing 'model', should raise KeyError
+            (
+                {
+                    'checkpoint_path': 
'~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + }, + KeyError + ), + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.EvalConfig.from_config_dict(config_dict) \ No newline at end of file From 05032d6238bb76e0b8595aad95d2fdf5f681a8a8 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:30:04 -0400 Subject: [PATCH 096/183] Fix 'cfg.spect_params' -> 'cfg.prep.spect_params' in src/vak/cli/predict.py --- src/vak/cli/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index feb976357..aecd96b7f 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -51,7 +51,7 @@ def predict(toml_path): num_workers=cfg.predict.num_workers, transform_params=cfg.predict.transform_params, dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, From d96d7827c7caaaa4f1681ec2ded718101df8bf1c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:45:49 -0400 Subject: [PATCH 097/183] Add unit test for exceptions in tests/test_config/test_learncurve.py --- tests/test_config/test_learncurve.py | 66 ++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/test_config/test_learncurve.py b/tests/test_config/test_learncurve.py index 08e1d5c1d..4cf6e604b 100644 --- a/tests/test_config/test_learncurve.py +++ b/tests/test_config/test_learncurve.py @@ -108,3 +108,69 @@ def 
test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d ) assert isinstance(config, vak.config.learncurve.LearncurveConfig) + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + # missing 'model', should raise KeyError + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + }, + KeyError + ), + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.LearncurveConfig.from_config_dict(config_dict) From 3927c5b914d89751b14f8f2d391577b4d9b6c85d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 07:46:02 -0400 Subject: [PATCH 098/183] Add unit test for exceptions in tests/test_config/test_train.py --- tests/test_config/test_train.py | 62 +++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py index 5d4b3fee2..3a416db98 100644 --- a/tests/test_config/test_train.py +++ b/tests/test_config/test_train.py @@ -103,3 +103,65 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): train_config = vak.config.train.TrainConfig.from_config_dict(train_table) assert isinstance(train_config, vak.config.train.TrainConfig) + + @pytest.mark.parametrize( + 'config_dict', + [ + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + }, + KeyError + ), + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'root_results_dir': 
'./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + }, + KeyError + ) + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.TrainConfig.from_config_dict(config_dict) \ No newline at end of file From 1996df2131d5886de3aa28cb3772b693ff2105e5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 08:19:42 -0400 Subject: [PATCH 099/183] Add more test cases to TestEvalConfig.test_from_config_dict_raises --- tests/test_config/test_eval.py | 73 ++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index 0ee7bfdb2..cb2a36e8a 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -164,6 +164,79 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): }, KeyError ), + # missing 'checkpoint_path', should raise KeyError + ( + { + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + } + ), + # missing 'output_dir', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ) ] ) def test_from_config_dict_raises(self, config_dict, expected_exception): From 
30a31f864bc7e4d6a046f94a0cf8d104f22a2f61 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 08:21:31 -0400 Subject: [PATCH 100/183] Add more test cases to TestLearncurveConfig.test_from_config_dict_raises --- tests/test_config/test_learncurve.py | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_config/test_learncurve.py b/tests/test_config/test_learncurve.py index 4cf6e604b..581e04c32 100644 --- a/tests/test_config/test_learncurve.py +++ b/tests/test_config/test_learncurve.py @@ -169,6 +169,44 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d }, KeyError ), + # missing 'root_results_dir', should raise KeyError + ( + { + 'normalize_spectrograms': True, + 'batch_size': 11, + 'num_epochs': 2, + 'val_step': 50, + 'ckpt_step': 200, + 'patience': 4, + 'num_workers': 16, + 'device': 'cuda', + 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, + 'train_dataset_params': {'window_size': 88}, + 'val_transform_params': {'window_size': 88}, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': { + 'lr': 0.001 + } + } + }, + 'dataset': { + 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' + } + }, + KeyError + ) ] ) def test_from_config_dict_raises(self, config_dict, expected_exception): From 836e8edf11a906de7dcb7868d08228493591bb8e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 08:25:16 -0400 Subject: [PATCH 101/183] Add unit test for exceptions in tests/test_config/test_predict.py --- tests/test_config/test_predict.py | 97 +++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/tests/test_config/test_predict.py b/tests/test_config/test_predict.py index 7c8a8eef4..9603e0df4 100644 --- a/tests/test_config/test_predict.py +++ b/tests/test_config/test_predict.py @@ -99,3 +99,100 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict predict_config = vak.config.predict.PredictConfig.from_config_dict(predict_table) assert isinstance(predict_config, vak.config.predict.PredictConfig) + + @pytest.mark.parametrize( + 'config_dict, expected_exception', + [ + # missing 'checkpoint_path', should raise KeyError + ( + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 
'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 + }, + 'optimizer': {'lr': 0.001} + } + }, + }, + KeyError + ), + # missing 'model', should raise KeyError + ( + { + 'spect_scaler_path': '/home/user/results_181014_194418/spect_scaler', + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/bl26lb16/results_200620_164245/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', + 'annot_csv_filename': 'bl26lb16.041912.annot.csv', + 'transform_params': { + 'window_size': 88 + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' + }, + }, + KeyError + ), + ] + ) + def test_from_config_dict_raises(self, config_dict, expected_exception): + with pytest.raises(expected_exception): + vak.config.PredictConfig.from_config_dict(config_dict) From 8e2cc9f1ba5de57e663be7c0fb06acfb65aeee99 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 08:33:24 -0400 Subject: [PATCH 102/183] Add two unit tests that PrepConfig raises expected exceptions --- tests/test_config/test_prep.py | 106 +++++++++++++++++++++++++-------- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/tests/test_config/test_prep.py b/tests/test_config/test_prep.py index 6901c2b23..9583b708b 100644 --- a/tests/test_config/test_prep.py +++ b/tests/test_config/test_prep.py @@ -8,11 +8,62 @@ class TestPrepConfig: @pytest.mark.parametrize( - 'config_dict', + 'config_dict', + [ + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', + 'spect_params': {'fft_size': 512, + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 + }, + { + 'annot_format': 'notmat', + 'audio_format': 'cbin', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312', + 'dataset_type': 'frame classification', + 'input_type': 'spect', + 'labelset': 'iabcdefghjk', + 'output_dir': './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet', + 'spect_params': {'fft_size': 512, + 'freq_cutoffs': [500, 10000], + 'step_size': 64, + 'thresh': 6.25, + 'transform_type': 'log_spect'}, + 'test_dur': 30, + 'train_dur': 50, + 'val_dur': 15 + }, + ] + ) + def test_init(self, config_dict): + 
config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) + + prep_config = vak.config.PrepConfig(**config_dict) + + assert isinstance(prep_config, vak.config.prep.PrepConfig) + for key, val in config_dict.items(): + assert hasattr(prep_config, key) + if key == 'data_dir' or key == 'output_dir': + assert getattr(prep_config, key) == vak.common.converters.expanded_user_path(val) + elif key == 'labelset': + assert getattr(prep_config, key) == vak.common.converters.labelset_to_set(val) + else: + assert getattr(prep_config, key) == val + + @pytest.mark.parametrize( + 'config_dict', [ { 'annot_format': 'notmat', 'audio_format': 'cbin', + 'spect_format': 'mat', 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', 'dataset_type': 'parametric umap', 'input_type': 'spect', @@ -23,39 +74,44 @@ class TestPrepConfig: 'transform_type': 'log_spect_plus_one'}, 'test_dur': 0.2 }, + ] + ) + def test_both_audio_and_spect_format_raises( + self, config_dict, + ): + """test that a config with both an audio and a spect format raises a ValueError""" + # need to do this set-up so we don't mask one error with another + config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) + + with pytest.raises(ValueError): + prep_config = vak.config.PrepConfig(**config_dict) + + @pytest.mark.parametrize( + 'config_dict', + [ { 'annot_format': 'notmat', - 'audio_format': 'cbin', - 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312', - 'dataset_type': 'frame classification', + 'data_dir': './tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032412', + 'dataset_type': 'parametric umap', 'input_type': 'spect', 'labelset': 'iabcdefghjk', - 'output_dir': './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet', + 'output_dir': './tests/data_for_tests/generated/prep/eval/audio_cbin_annot_notmat/ConvEncoderUMAP', 'spect_params': {'fft_size': 512, - 'freq_cutoffs': [500, 10000], - 'step_size': 64, - 'thresh': 6.25, - 'transform_type': 'log_spect'}, - 'test_dur': 30, - 'train_dur': 50, - 'val_dur': 15 + 'step_size': 32, + 'transform_type': 'log_spect_plus_one'}, + 'test_dur': 0.2 }, - ] + ] ) - def test_init(self, config_dict): + def test_neither_audio_nor_spect_format_raises( + self, config_dict + ): + """test that a config without either an audio or a spect format raises a ValueError""" + # need to do this set-up so we don't mask one error with another config_dict['spect_params'] = vak.config.SpectParamsConfig(**config_dict['spect_params']) - prep_config = vak.config.PrepConfig(**config_dict) - - assert isinstance(prep_config, vak.config.prep.PrepConfig) - for key, val in config_dict.items(): - assert hasattr(prep_config, key) - if key == 'data_dir' or key == 'output_dir': - assert getattr(prep_config, key) == vak.common.converters.expanded_user_path(val) - elif key == 'labelset': - assert getattr(prep_config, key) == vak.common.converters.labelset_to_set(val) - else: - assert getattr(prep_config, key) == val + with pytest.raises(ValueError): + prep_config = vak.config.PrepConfig(**config_dict) @pytest.mark.parametrize( 'config_dict', From 35c37326574be220d887464dc1b66aa0e00b236e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 08:54:50 -0400 Subject: [PATCH 103/183] Fix/add unit tests in tests/test_config/test_config.py --- tests/test_config/test_config.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git 
a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 8d2ced413..74d96ac6b 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -1,21 +1,23 @@ +import pytest + import vak.config class TestConfig: def test_from_config_dict_with_real_config( - a_generated_config_dict, + self, a_generated_config_dict, ): """test that instantiating Config class works as expected""" # this is basically the body of the ``config.load.from_toml`` function. config_kwargs = {} for table_name in a_generated_config_dict: - config_kwargs[table_name] = vak.config.load.TABLE_CLASSES_MAP[table_name].from_config_dict( + config_kwargs[table_name] = vak.config.config.TABLE_CLASSES_MAP[table_name].from_config_dict( a_generated_config_dict[table_name] ) - config = vak.config.load.Config(**config_kwargs) + config = vak.config.Config(**config_kwargs) - assert isinstance(config, vak.config.load.Config) + assert isinstance(config, vak.config.Config) # we already test that config loading works for EvalConfig, et al., # so here we just test that the logic of Config works as expected: # we should get an attribute for each top-level table that we pass in; @@ -25,3 +27,25 @@ def test_from_config_dict_with_real_config( assert hasattr(config, table_name) else: assert getattr(config, table_name) is None + + def test_from_toml_path(self, a_generated_config_path): + config_toml = vak.config.load._load_toml_from_path(a_generated_config_path) + assert isinstance(config_toml, dict) + + def test_from_toml_path_raises_when_config_doesnt_exist(self, config_that_doesnt_exist): + with pytest.raises(FileNotFoundError): + vak.config.Config.from_toml_path(config_that_doesnt_exist) + + def test_invalid_table_raises(self, invalid_table_config_path): + with pytest.raises(ValueError): + vak.config.Config.from_toml_path(invalid_table_config_path) + + def test_invalid_key_raises(self, invalid_key_config_path): + with pytest.raises(ValueError): + vak.config.Config.from_toml_path(invalid_key_config_path) + + def test_mutiple_top_level_tables_besides_prep_raises(self, invalid_train_and_learncurve_config_toml): + """Test that a .toml config with two top-level tables besides ``[vak.prep]`` raises a ValueError + (in this case ``[vak.train]`` and ``[vak.learncurve]``)""" + with pytest.raises(ValueError): + vak.config.Config.from_toml_path(invalid_train_and_learncurve_config_toml) From e0abf23500b1241a63714d5cb964c686b50ee4f4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:04:24 -0400 Subject: [PATCH 104/183] Change order of parameters for Config.from_config_dict, make toml_path last param --- src/vak/config/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vak/config/config.py b/src/vak/config/config.py index fa7e3f20e..d94450439 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -87,8 +87,8 @@ class Config: def from_config_dict( cls, config_dict: dict, + tables_to_parse: str | list[str] | None = None, toml_path: str | pathlib.Path | None = None, - tables_to_parse: str | list[str] | None = None ) -> "Config": """Return instance of :class:`Config` class, given a :class:`dict` containing the contents of @@ -173,4 +173,4 @@ def from_toml_path( tables in a config.toml file. 
""" config_dict: dict = load._load_toml_from_path(toml_path) - return cls.from_config_dict(config_dict, toml_path, tables_to_parse) + return cls.from_config_dict(config_dict, tables_to_parse, toml_path) From e58509149c03b6c7e437a30388115766ad183c6c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:04:34 -0400 Subject: [PATCH 105/183] Fix/add unit tests in tests/fixtures/config.py --- tests/fixtures/config.py | 118 ++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 65279197b..184504099 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -70,6 +70,11 @@ def invalid_key_config_path(test_configs_root): return test_configs_root.joinpath("invalid_key_config.toml") +@pytest.fixture +def invalid_train_and_learncurve_config_toml(test_configs_root): + return test_configs_root.joinpath("invalid_train_and_learncurve_config.toml") + + GENERATED_TEST_CONFIGS_ROOT = GENERATED_TEST_DATA_ROOT.joinpath("configs") @@ -194,37 +199,6 @@ def a_generated_config_path(request): return request.param -@pytest.fixture -def all_generated_train_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("test_train*toml")) - - -ALL_GENERATED_LEARNCURVE_CONFIGS = sorted(GENERATED_TEST_CONFIGS_ROOT.glob("test_learncurve*toml")) - -@pytest.fixture -def all_generated_learncurve_configs(generated_test_configs_root): - return ALL_GENERATED_LEARNCURVE_CONFIGS - - -ALL_GENERATED_EVAL_CONFIG_PATHS = sorted( - GENERATED_TEST_CONFIGS_ROOT.glob("test_eval*toml") -) - - -@pytest.fixture -def all_generated_eval_configs(): - return ALL_GENERATED_EVAL_CONFIG_PATHS - -@pytest.fixture(params=ALL_GENERATED_EVAL_CONFIG_PATHS) -def a_generated_eval_config_toml(request): - return request.param - - -@pytest.fixture -def all_generated_predict_configs(generated_test_configs_root): - return sorted(generated_test_configs_root.glob("test_predict*toml")) - - def _tomlkit_to_popo(d): """Convert tomlkit to "popo" (Plain-Old Python Objects) @@ -290,16 +264,56 @@ def _specific_config_toml( return _specific_config_toml -ALL_GENERATED_CONFIG_DICTS = [ - _load_config_dict(config) - for config in ALL_GENERATED_CONFIG_PATHS -] - -@pytest.fixture(params=ALL_GENERATED_CONFIG_DICTS) +@pytest.fixture(params=ALL_GENERATED_CONFIG_PATHS) def a_generated_config_dict(request): - return request.param + # we remake dict every time this gets called + # so that we're not returning a ``config_dict`` that was + # already mutated by a `Config.from_config_dict` function, + # e.g. the value for the 'spect_params' key gets mapped to a SpectParamsConfig + # by PrepConfig.from_config_dict + return _load_config_dict(request.param) + + +ALL_GENERATED_EVAL_CONFIG_PATHS = sorted( + GENERATED_TEST_CONFIGS_ROOT.glob("*eval*toml") +) + +ALL_GENERATED_LEARNCURVE_CONFIG_PATHS = sorted( + GENERATED_TEST_CONFIGS_ROOT.glob("*learncurve*toml") +) + +ALL_GENERATED_PREDICT_CONFIG_PATHS = sorted( + GENERATED_TEST_CONFIGS_ROOT.glob("*predict*toml") +) + +ALL_GENERATED_TRAIN_CONFIG_PATHS = sorted( + GENERATED_TEST_CONFIGS_ROOT.glob("*train*toml") +) + +# as above, we remake dict every time these fixutres get called +# so that we're not returning a ``config_dict`` that was +# already mutated by a `Config.from_config_dict` function, +# e.g. 
the value for the 'spect_params' key gets mapped to a SpectParamsConfig
+# by PrepConfig.from_config_dict
+@pytest.fixture(params=ALL_GENERATED_EVAL_CONFIG_PATHS)
+def a_generated_eval_config_dict(request):
+    return _load_config_dict(request.param)
+
+
+@pytest.fixture(params=ALL_GENERATED_LEARNCURVE_CONFIG_PATHS)
+def a_generated_learncurve_config_dict(request):
+    return _load_config_dict(request.param)
 
 
+@pytest.fixture(params=ALL_GENERATED_PREDICT_CONFIG_PATHS)
+def a_generated_predict_config_dict(request):
+    return _load_config_dict(request.param)
+
+
+@pytest.fixture(params=ALL_GENERATED_TRAIN_CONFIG_PATHS)
+def a_generated_train_config_dict(request):
+    return _load_config_dict(request.param)
+
 @pytest.fixture
 def all_generated_learncurve_configs_toml(all_generated_learncurve_configs):
@@ -313,20 +327,20 @@ def all_generated_learncurve_configs_toml(all_generated_learncurve_configs):
 
 
 # ---- config toml + path pairs ----
-# @pytest.fixture
-# def all_generated_configs_toml_path_pairs():
-#     """zip of tuple pairs: (dict, pathlib.Path)
-#     where ``Path`` is path to .toml config file and ``dict`` is
-#     the .toml config from that path
-#     loaded into a dict with the ``toml`` library
-#     """
-#     # we duplicate the constant above because we need to remake
-#     # the variables for each unit test. Otherwise tests that modify values
-#     # for config options cause other tests to fail
-#     return zip(
-#         [_load_config_dict(config) for config in ALL_GENERATED_CONFIGS],
-#         ALL_GENERATED_CONFIGS
-#     )
+@pytest.fixture
+def all_generated_configs_toml_path_pairs():
+    """zip of tuple pairs: (dict, pathlib.Path)
+    where ``Path`` is path to .toml config file and ``dict`` is
+    the .toml config from that path
+    loaded into a dict with the ``toml`` library
+    """
+    # we duplicate the constant above because we need to remake
+    # the variables for each unit test. Otherwise tests that modify values
+    # for config options cause other tests to fail
+    return zip(
+        [_load_config_dict(config) for config in ALL_GENERATED_CONFIG_PATHS],
+        ALL_GENERATED_CONFIG_PATHS
+    )
 
 
 @pytest.fixture
From ffd0a4ea5c2528954b78410ecf1a1c33162f8dd9 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 10:04:42 -0400
Subject: [PATCH 106/183] Fix/add unit tests in tests/test_config/test_config.py

---
 tests/test_config/test_config.py | 107 ++++++++++++++++++++++++++-----
 1 file changed, 92 insertions(+), 15 deletions(-)

diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py
index 74d96ac6b..04622f1b7 100644
--- a/tests/test_config/test_config.py
+++ b/tests/test_config/test_config.py
@@ -4,16 +4,32 @@
 
 class TestConfig:
-    def test_from_config_dict_with_real_config(
-        self, a_generated_config_dict,
+    @pytest.mark.parametrize(
+        'tables_to_parse',
+        [
+            None,
+            'prep',
+            ['prep'],
+        ]
+    )
+    def test_init_with_real_config(
+        self, a_generated_config_dict, tables_to_parse
     ):
-        """test that instantiating Config class works as expected"""
-        # this is basically the body of the ``config.load.from_toml`` function.
+        """Test that instantiating Config class works as expected"""
+        # this is basically the body of the ``Config.from_config_dict`` function. 
         
config_kwargs = {} - for table_name in a_generated_config_dict: - config_kwargs[table_name] = vak.config.config.TABLE_CLASSES_MAP[table_name].from_config_dict( - a_generated_config_dict[table_name] - ) + + if tables_to_parse is None: + for table_name in a_generated_config_dict: + config_kwargs[table_name] = vak.config.config.TABLE_CLASSES_MAP[table_name].from_config_dict( + a_generated_config_dict[table_name] + ) + else: + for table_name in a_generated_config_dict: + if table_name in tables_to_parse: + config_kwargs[table_name] = vak.config.config.TABLE_CLASSES_MAP[table_name].from_config_dict( + a_generated_config_dict[table_name] + ) config = vak.config.Config(**config_kwargs) @@ -22,15 +38,76 @@ def test_from_config_dict_with_real_config( # so here we just test that the logic of Config works as expected: # we should get an attribute for each top-level table that we pass in; # if we don't pass one in, then its corresponding attribute should be None - for table_name in ('eval', 'learncurve', 'predict', 'prep', 'train'): - if table_name in a_generated_config_dict: - assert hasattr(config, table_name) + for attr in ('eval', 'learncurve', 'predict', 'prep', 'train'): + if tables_to_parse is not None: + if attr in a_generated_config_dict and attr in tables_to_parse: + assert hasattr(config, attr) + else: + assert getattr(config, attr) is None else: - assert getattr(config, table_name) is None + if attr in a_generated_config_dict: + assert hasattr(config, attr) + + @pytest.mark.parametrize( + 'tables_to_parse', + [ + None, + 'prep', + ['prep'], + ] + ) + def test_from_config_dict_with_real_config( + self, a_generated_config_dict, tables_to_parse + ): + """Test :meth:`Config.from_config_dict`""" + config = vak.config.Config.from_config_dict( + a_generated_config_dict, tables_to_parse=tables_to_parse + ) - def test_from_toml_path(self, a_generated_config_path): - config_toml = vak.config.load._load_toml_from_path(a_generated_config_path) - assert isinstance(config_toml, dict) + assert isinstance(config, vak.config.Config) + # we already test that config loading works for EvalConfig, et al., + # so here we just test that the logic of Config works as expected: + # we should get an attribute for each top-level table that we pass in; + # if we don't pass one in, then its corresponding attribute should be None + for attr in ('eval', 'learncurve', 'predict', 'prep', 'train'): + if tables_to_parse is not None: + if attr in a_generated_config_dict and attr in tables_to_parse: + assert hasattr(config, attr) + else: + assert getattr(config, attr) is None + else: + if attr in a_generated_config_dict: + assert hasattr(config, attr) + + @pytest.mark.parametrize( + 'tables_to_parse', + [ + None, + 'prep', + ['prep'], + ] + ) + def test_from_toml_path(self, a_generated_config_path, tables_to_parse): + config = vak.config.Config.from_toml_path( + a_generated_config_path, tables_to_parse=tables_to_parse + ) + + assert isinstance(config, vak.config.Config) + + a_generated_config_dict = vak.config.load._load_toml_from_path(a_generated_config_path) + # we already test that config loading works for EvalConfig, et al., + # so here we just test that the logic of Config works as expected: + # we should get an attribute for each top-level table that we pass in; + # if we don't pass one in, then its corresponding attribute should be None + for attr in ('eval', 'learncurve', 'predict', 'prep', 'train'): + if tables_to_parse is not None: + if attr in a_generated_config_dict and attr in tables_to_parse: + assert 
hasattr(config, attr)
 
     def test_from_toml_path_raises_when_config_doesnt_exist(self, config_that_doesnt_exist):
         with pytest.raises(FileNotFoundError):
             vak.config.Config.from_toml_path(config_that_doesnt_exist)

From 63ee0ee15e0dc903b2212b06e1f7b55c85e6072b Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 10:15:20 -0400
Subject: [PATCH 107/183] Rename test_config/test_parse.py -> test_load.py, fix/rewrite tests

---
 tests/test_config/__init__.py   |   2 +-
 tests/test_config/test_load.py  |  24 ++++
 tests/test_config/test_parse.py | 243 --------------------------------
 3 files changed, 25 insertions(+), 244 deletions(-)
 create mode 100644 tests/test_config/test_load.py
 delete mode 100644 tests/test_config/test_parse.py

diff --git a/tests/test_config/__init__.py b/tests/test_config/__init__.py
index eb50111ee..66aaa2a8c 100644
--- a/tests/test_config/__init__.py
+++ b/tests/test_config/__init__.py
@@ -1,4 +1,4 @@
-from . import test_parse
+from . import test_load
 from . import test_predict
 from . import test_prep
 from . import test_spect_params
diff --git a/tests/test_config/test_load.py b/tests/test_config/test_load.py
new file mode 100644
index 000000000..81c0e1809
--- /dev/null
+++ b/tests/test_config/test_load.py
@@ -0,0 +1,24 @@
+"""tests for vak.config.load module"""
+import tomlkit
+
+import vak.config.load
+
+
+def test__tomlkit_to_popo(a_generated_config_path):
+    with a_generated_config_path.open('r') as fp:
+        tomldoc = tomlkit.load(fp)
+    out = vak.config.load._tomlkit_to_popo(tomldoc)
+    assert isinstance(out, dict)
+    assert list(out.keys()) == ["vak"]
+
+
+def test__load_from_toml_path(a_generated_config_path):
+    config_dict = vak.config.load._load_toml_from_path(a_generated_config_path)
+
+    assert isinstance(config_dict, dict)
+
+    with a_generated_config_path.open('r') as fp:
+        tomldoc = tomlkit.load(fp)
+    config_dict_raw = vak.config.load._tomlkit_to_popo(tomldoc)
+
+    assert len(list(config_dict.keys())) == len(list(config_dict_raw["vak"].keys()))
diff --git a/tests/test_config/test_parse.py b/tests/test_config/test_parse.py
deleted file mode 100644
index 70549b34f..000000000
--- a/tests/test_config/test_parse.py
+++ /dev/null
@@ -1,243 +0,0 @@
-"""tests for vak.config.parse module"""
-import copy
-
-import pytest
-
-import vak.config
-import vak.transforms.transforms
-import vak.models
-
-
-@pytest.mark.parametrize(
-    "section_name",
-    [
-        "DATALOADER",
-        "EVAL" "LEARNCURVE",
-        "PREDICT",
-        "PREP",
-        "SPECT_PARAMS",
-        "TRAIN",
-    ],
-)
-def test_parse_config_section_returns_attrs_class(
-    section_name,
-    configs_toml_path_pairs_by_model_factory,
-):
-    """test that ``vak.config.parse.parse_config_section``
-    returns an instance of ``vak.config.learncurve.LearncurveConfig``"""
-    config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet", section_name)
-    for config_toml, toml_path in config_toml_path_pairs:
-        config_section_obj = vak.config.parse.parse_config_section(
-            config_toml=config_toml,
-            section_name=section_name,
-            toml_path=toml_path,
-        )
-        assert isinstance(
-            config_section_obj, vak.config.parse.SECTION_CLASSES[section_name]
-        )
-
-
-@pytest.mark.parametrize(
-    "section_name",
-    [
-        "EVAL",
-        "LEARNCURVE",
-        "PREDICT",
-        "PREP",
-        "SPECT_PARAMS",
-        "TRAIN",
-    ],
-)
-def test_parse_config_section_missing_options_raises(
-    section_name,
-    configs_toml_path_pairs_by_model_factory,
-):
-    """test that configs without the required options in a section raise KeyError"""
-    if 
vak.config.parse.REQUIRED_OPTIONS[section_name] is None: - pytest.skip(f"no required options to test for section: {section_name}") - - configs_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet", section_name) - - for config_toml, toml_path in configs_toml_path_pairs: - for option in vak.config.parse.REQUIRED_OPTIONS[section_name]: - config_copy = copy.deepcopy(config_toml) - config_copy[section_name].pop(option) - with pytest.raises(KeyError): - config = vak.config.parse.parse_config_section( - config_toml=config_copy, - section_name=section_name, - toml_path=toml_path, - ) - - -@pytest.mark.parametrize("section_name", ["EVAL", "LEARNCURVE", "PREDICT", "TRAIN"]) -def test_parse_config_section_model_not_installed_raises( - section_name, - configs_toml_path_pairs_by_model_factory, -): - """test that a ValueError is raised when the ``models`` option - in the section specifies names of models that are not installed""" - # we only need one toml, path pair - # so we just call next on the ``zipped`` iterator that our fixture gives us - configs_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - - for config_toml, toml_path in configs_toml_path_pairs: - if section_name.lower() in toml_path.name: - break # use these. Only need to test on one - - config_toml[section_name]["model"] = "NotInstalledNet" - with pytest.raises(ValueError): - vak.config.parse.parse_config_section( - config_toml=config_toml, section_name=section_name, toml_path=toml_path - ) - - -def test_parse_prep_section_both_audio_and_spect_format_raises( - all_generated_configs_toml_path_pairs, -): - """test that a config with both an audio and a spect format raises a ValueError""" - # iterate through configs til we find one with an `audio_format` option - # and then we'll add a `spect_format` option to it - found_config_to_use = False - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if "audio_format" in config_toml["PREP"]: - found_config_to_use = True - break - assert found_config_to_use # sanity check - - config_toml["PREP"]["spect_format"] = "mat" - with pytest.raises(ValueError): - vak.config.parse.parse_config_section(config_toml, "PREP", toml_path) - - -def test_parse_prep_section_neither_audio_nor_spect_format_raises( - all_generated_configs_toml_path_pairs, -): - """test that a config without either an audio or a spect format raises a ValueError""" - # iterate through configs til we find one with an `audio_format` option - # and then we'll add a `spect_format` option to it - found_config_to_use = False - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if "audio_format" in config_toml["PREP"]: - found_config_to_use = True - break - assert found_config_to_use # sanity check - - config_toml["PREP"].pop("audio_format") - if "spect_format" in config_toml["PREP"]: - # shouldn't be but humor me - config_toml["PREP"].pop("spect_format") - - with pytest.raises(ValueError): - vak.config.parse.parse_config_section(config_toml, "PREP", toml_path) - - -def test_load_from_toml_path(all_generated_configs): - for toml_path in all_generated_configs: - config_toml = vak.config.parse._load_toml_from_path(toml_path) - assert isinstance(config_toml, dict) - - -def test_load_from_toml_path_raises_when_config_doesnt_exist(config_that_doesnt_exist): - with pytest.raises(FileNotFoundError): - vak.config.parse._load_toml_from_path(config_that_doesnt_exist) - - -def test_from_toml_path_returns_instance_of_config( - all_generated_configs, default_model -): - for 
toml_path in all_generated_configs: - if default_model not in str(toml_path): - continue # only need to check configs for one model - # also avoids FileNotFoundError on CI - config_obj = vak.config.parse.from_toml_path(toml_path) - assert isinstance(config_obj, vak.config.parse.Config) - - -def test_from_toml_path_raises_when_config_doesnt_exist(config_that_doesnt_exist): - with pytest.raises(FileNotFoundError): - vak.config.parse.from_toml_path(config_that_doesnt_exist) - - -def test_from_toml(configs_toml_path_pairs_by_model_factory): - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - for config_toml, toml_path in config_toml_path_pairs: - config_obj = vak.config.parse.from_toml(config_toml, toml_path) - assert isinstance(config_obj, vak.config.parse.Config) - - -def test_from_toml_parse_prep_with_sections_not_none( - configs_toml_path_pairs_by_model_factory, -): - """test that we get only the sections we want when we pass in a sections list to - ``from_toml``. Specifically test ``PREP`` since that's what this will be used for.""" - # only use configs from 'default_model') (TeenyTweetyNet) - # so we are sure paths exist, to avoid NotADirectoryErrors that give spurious test failures - config_toml_path_pairs = configs_toml_path_pairs_by_model_factory("TweetyNet") - for config_toml, toml_path in config_toml_path_pairs: - config_obj = vak.config.parse.from_toml( - config_toml, toml_path, sections=["PREP", "SPECT_PARAMS"] - ) - assert isinstance(config_obj, vak.config.parse.Config) - for should_have in ("prep", "spect_params"): - assert hasattr(config_obj, should_have) - for should_be_none in ("eval", "learncurve", "train", "predict"): - assert getattr(config_obj, should_be_none) is None - assert ( - getattr(config_obj, "dataloader") - == vak.config.dataloader.DataLoaderConfig() - ) - - -@pytest.mark.parametrize("section_name", ["EVAL", "LEARNCURVE", "PREDICT", "TRAIN"]) -def test_from_toml_parse_prep_with_sections_not_none( - section_name, all_generated_configs_toml_path_pairs, random_path_factory -): - """Test that ``config.parse.from_toml`` parameter ``sections`` works as expected. 
- - If we pass in a list of section names - specifying that we only want to parse ``PREP`` and ``SPECT_PARAMS``, - other sections should be left as None in the return Config instance.""" - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if section_name.lower() in toml_path.name: - break # use these - - purpose = vak.cli.prep.purpose_from_toml(config_toml, toml_path) - section_name = purpose.upper() - required_options = vak.config.parse.REQUIRED_OPTIONS[section_name] - for required_option in required_options: - # set option to values that **would** cause an error if we parse them - if "path" in required_option: - badval = random_path_factory(f"_{required_option}.exe") - elif "dir" in required_option: - badval = random_path_factory("nonexistent_dir") - else: - continue - config_toml[section_name][required_option] = badval - cfg = vak.config.parse.from_toml( - config_toml, toml_path, sections=["PREP", "SPECT_PARAMS"] - ) - assert hasattr(cfg, 'prep') and getattr(cfg, 'prep') is not None - assert hasattr(cfg, 'spect_params') and getattr(cfg, 'spect_params') is not None - assert getattr(cfg, purpose) is None - - -def test_invalid_section_raises(invalid_section_config_path): - with pytest.raises(ValueError): - vak.config.parse.from_toml_path(invalid_section_config_path) - - -def test_invalid_option_raises(invalid_option_config_path): - with pytest.raises(ValueError): - vak.config.parse.from_toml_path(invalid_option_config_path) - - -@pytest.fixture -def invalid_train_and_learncurve_config_toml(test_configs_root): - return test_configs_root.joinpath("invalid_train_and_learncurve_config.toml") - - -def test_train_and_learncurve_defined_raises(invalid_train_and_learncurve_config_toml): - """test that a .toml config with both a TRAIN and a LEARNCURVE section raises a ValueError""" - with pytest.raises(ValueError): - vak.config.parse.from_toml_path(invalid_train_and_learncurve_config_toml) From f08135922f2a662a7c0e6c1cf0d8bb2933c3c9dc Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:20:18 -0400 Subject: [PATCH 108/183] Fix tests in tests/test_config/test_spect_params.py --- tests/test_config/test_spect_params.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/test_config/test_spect_params.py b/tests/test_config/test_spect_params.py index 10f0fcdeb..5accbb0f1 100644 --- a/tests/test_config/test_spect_params.py +++ b/tests/test_config/test_spect_params.py @@ -45,10 +45,21 @@ def test_freq_cutoffs_wrong_order_raises(): ) -def test_spect_params_attrs_class(all_generated_configs_toml_path_pairs): - """test that instantiating SpectParamsConfig class works as expected""" - for config_toml, toml_path in all_generated_configs_toml_path_pairs: - if "SPECT_PARAMS" in config_toml: - spect_params_section = config_toml["SPECT_PARAMS"] - config = vak.config.spect_params.SpectParamsConfig(**spect_params_section) - assert isinstance(config, vak.config.spect_params.SpectParamsConfig) +class TestSpectParamsConfig: + @pytest.mark.parametrize( + 'config_dict', + [ + {'fft_size': 512, 'step_size': 64, 'freq_cutoffs': [500, 10000], 'thresh': 6.25, 'transform_type': 'log_spect'}, + ] + ) + def test_init(self, config_dict): + spect_params_config = vak.config.SpectParamsConfig(**config_dict) + assert isinstance(spect_params_config, vak.config.spect_params.SpectParamsConfig) + + def test_with_real_config(self, a_generated_config_dict): + if "spect_params" in a_generated_config_dict['prep']: + spect_config_dict = 
a_generated_config_dict['prep']['spect_params'] + else: + pytest.skip("No spect params in config") + spect_params_config = vak.config.spect_params.SpectParamsConfig(**spect_config_dict) + assert isinstance(spect_params_config, vak.config.spect_params.SpectParamsConfig) From dcc98a88796946c0d17d0e76832720c1addd7997 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:28:09 -0400 Subject: [PATCH 109/183] Make fixups in tests/test_config --- tests/test_config/test_eval.py | 251 ++++++++++++++------------- tests/test_config/test_train.py | 2 +- tests/test_config/test_validators.py | 25 ++- 3 files changed, 138 insertions(+), 140 deletions(-) diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index cb2a36e8a..a83a0aca3 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -105,139 +105,140 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): assert isinstance(eval_config, vak.config.eval.EvalConfig) @pytest.mark.parametrize( - 'config_dict, expected_exception', - [ - # missing 'model', should raise KeyError - ( - { - 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', - 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', - 'batch_size': 11, - 'num_workers': 16, - 'device': 'cuda', - 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', - 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', - 'post_tfm_kwargs': { - 'majority_vote': True, 'min_segment_dur': 0.02 - }, - 'transform_params': { - 'window_size': 88 - }, - 'dataset': { - 'path': '~/some/path/I/made/up/for/now' + 'config_dict, expected_exception', + [ + # missing 'model', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 }, + 'transform_params': { + 'window_size': 88 + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' }, - KeyError - ), - # missing 'dataset', should raise KeyError - ( - { - 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', - 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', - 'batch_size': 11, - 'num_workers': 16, - 'device': 'cuda', - 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', - 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', - 'post_tfm_kwargs': { - 'majority_vote': True, 'min_segment_dur': 0.02 - }, - 
'transform_params': { - 'window_size': 88 + }, + KeyError + ), + # missing 'dataset', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 }, - 'model': { - 'TweetyNet': { - 'network': { - 'conv1_filters': 8, - 'conv1_kernel_size': [3, 3], - 'conv2_filters': 16, - 'conv2_kernel_size': [5, 5], - 'pool1_size': [4, 1], - 'pool1_stride': [4, 1], - 'pool2_size': [4, 1], - 'pool2_stride': [4, 1], - 'hidden_size': 32 - }, - 'optimizer': {'lr': 0.001} - } + 'optimizer': {'lr': 0.001} + } + }, + }, + KeyError + ), + # missing 'checkpoint_path', should raise KeyError + ( + { + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 }, + 'optimizer': {'lr': 0.001} + } + }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' }, - KeyError - ), - # missing 'checkpoint_path', should raise KeyError - ( - { - 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', - 'batch_size': 11, - 'num_workers': 16, - 'device': 'cuda', - 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', - 'output_dir': './tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/TweetyNet', - 'post_tfm_kwargs': { - 'majority_vote': True, 'min_segment_dur': 0.02 - }, - 'transform_params': { - 'window_size': 88 - }, - 'model': { - 'TweetyNet': { - 'network': { - 'conv1_filters': 8, - 'conv1_kernel_size': [3, 3], - 'conv2_filters': 16, - 'conv2_kernel_size': [5, 5], - 'pool1_size': [4, 1], - 'pool1_stride': [4, 1], - 'pool2_size': [4, 1], - 'pool2_stride': [4, 1], - 'hidden_size': 32 - }, - 'optimizer': {'lr': 0.001} - } - }, - 'dataset': { - 'path': '~/some/path/I/made/up/for/now' - }, - } - ), - # missing 'output_dir', should raise 
KeyError - ( - { - 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', - 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', - 'batch_size': 11, - 'num_workers': 16, - 'device': 'cuda', - 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', - 'post_tfm_kwargs': { - 'majority_vote': True, 'min_segment_dur': 0.02 - }, - 'transform_params': { - 'window_size': 88 - }, - 'model': { - 'TweetyNet': { - 'network': { - 'conv1_filters': 8, - 'conv1_kernel_size': [3, 3], - 'conv2_filters': 16, - 'conv2_kernel_size': [5, 5], - 'pool1_size': [4, 1], - 'pool1_stride': [4, 1], - 'pool2_size': [4, 1], - 'pool2_stride': [4, 1], - 'hidden_size': 32 - }, - 'optimizer': {'lr': 0.001} - } + }, + KeyError + ), + # missing 'output_dir', should raise KeyError + ( + { + 'checkpoint_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt', + 'labelmap_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/labelmap.json', + 'batch_size': 11, + 'num_workers': 16, + 'device': 'cuda', + 'spect_scaler_path': '~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect', + 'post_tfm_kwargs': { + 'majority_vote': True, 'min_segment_dur': 0.02 + }, + 'transform_params': { + 'window_size': 88 + }, + 'model': { + 'TweetyNet': { + 'network': { + 'conv1_filters': 8, + 'conv1_kernel_size': [3, 3], + 'conv2_filters': 16, + 'conv2_kernel_size': [5, 5], + 'pool1_size': [4, 1], + 'pool1_stride': [4, 1], + 'pool2_size': [4, 1], + 'pool2_stride': [4, 1], + 'hidden_size': 32 }, - 'dataset': { - 'path': '~/some/path/I/made/up/for/now' + 'optimizer': {'lr': 0.001} + } }, + 'dataset': { + 'path': '~/some/path/I/made/up/for/now' }, - KeyError - ) - ] + }, + KeyError + ) + ] ) def test_from_config_dict_raises(self, config_dict, expected_exception): with pytest.raises(expected_exception): diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py index 3a416db98..229542410 100644 --- a/tests/test_config/test_train.py +++ b/tests/test_config/test_train.py @@ -105,7 +105,7 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): assert isinstance(train_config, vak.config.train.TrainConfig) @pytest.mark.parametrize( - 'config_dict', + 'config_dict, expected_exception', [ ( { diff --git a/tests/test_config/test_validators.py b/tests/test_config/test_validators.py index 65771cc14..9d0b14421 100644 --- a/tests/test_config/test_validators.py +++ b/tests/test_config/test_validators.py @@ -1,25 +1,22 @@ import pytest -import tomlkit import vak.config.validators -def test_are_sections_valid(invalid_section_config_path): - """test that invalid section name raises a ValueError""" - with invalid_section_config_path.open("r") as fp: - config_toml = tomlkit.load(fp) +def test_are_tables_valid(invalid_table_config_path): + """test that invalid table name raises a ValueError""" + config_dict = vak.config.load._load_toml_from_path(invalid_table_config_path) with pytest.raises(ValueError): - vak.config.validators.are_sections_valid( - config_toml, invalid_section_config_path + vak.config.validators.are_tables_valid( + config_dict, 
invalid_table_config_path ) -def test_are_options_valid(invalid_option_config_path): - """test that section with an invalid option name raises a ValueError""" - section_with_invalid_option = "PREP" - with invalid_option_config_path.open("r") as fp: - config_toml = tomlkit.load(fp) +def test_are_keys_valid(invalid_key_config_path): + """test that table with an invalid key name raises a ValueError""" + table_with_invalid_key = "prep" + config_dict = vak.config.load._load_toml_from_path(invalid_key_config_path) with pytest.raises(ValueError): - vak.config.validators.are_options_valid( - config_toml, section_with_invalid_option, invalid_option_config_path + vak.config.validators.are_keys_valid( + config_dict, table_with_invalid_key, invalid_key_config_path ) From 82dd4f363956ebee0e6518f024f11baaadb76fbb Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:39:58 -0400 Subject: [PATCH 110/183] Apply fixes from linter --- src/vak/__main__.py | 1 + src/vak/cli/prep.py | 15 +++++++--- src/vak/common/__init__.py | 1 + src/vak/common/constants.py | 1 + src/vak/common/converters.py | 10 ++++--- src/vak/common/labels.py | 7 +++-- src/vak/common/logging.py | 1 + src/vak/common/paths.py | 1 + src/vak/common/tensorboard.py | 1 + src/vak/common/timebins.py | 1 + src/vak/common/validators.py | 1 + src/vak/config/__init__.py | 3 +- src/vak/config/config.py | 29 ++++++++++++------- src/vak/config/dataset.py | 14 +++++---- src/vak/config/eval.py | 14 +++++---- src/vak/config/learncurve.py | 21 ++++++-------- src/vak/config/load.py | 8 +++-- src/vak/config/model.py | 22 ++++++-------- src/vak/config/predict.py | 19 ++++++------ src/vak/config/prep.py | 18 ++++++------ src/vak/config/spect_params.py | 1 + src/vak/config/train.py | 21 ++++++-------- src/vak/config/validators.py | 25 ++++++++++------ .../frame_classification/frames_dataset.py | 1 + .../datasets/frame_classification/helper.py | 1 + .../datasets/frame_classification/metadata.py | 1 + .../frame_classification/window_dataset.py | 1 + src/vak/datasets/parametric_umap/metadata.py | 1 + .../parametric_umap/parametric_umap.py | 1 + src/vak/eval/eval_.py | 1 + src/vak/eval/frame_classification.py | 1 + src/vak/eval/parametric_umap.py | 1 + src/vak/learncurve/frame_classification.py | 1 + src/vak/learncurve/learncurve.py | 1 + src/vak/metrics/util.py | 1 + src/vak/models/base.py | 1 + src/vak/models/convencoder_umap.py | 1 + src/vak/models/decorator.py | 1 + src/vak/models/definition.py | 1 + src/vak/models/ed_tcn.py | 1 + src/vak/models/frame_classification_model.py | 1 + src/vak/models/get.py | 1 + src/vak/models/parametric_umap_model.py | 1 + src/vak/models/registry.py | 1 + src/vak/models/tweetynet.py | 1 + src/vak/nets/tweetynet.py | 1 + src/vak/nn/loss/umap.py | 3 +- src/vak/nn/modules/activation.py | 1 + src/vak/nn/modules/conv.py | 1 + src/vak/plot/annot.py | 1 + src/vak/plot/learncurve.py | 1 + src/vak/plot/spect.py | 1 + src/vak/predict/frame_classification.py | 1 + src/vak/predict/parametric_umap.py | 1 + src/vak/predict/predict_.py | 1 + src/vak/prep/audio_dataset.py | 8 +++-- src/vak/prep/constants.py | 1 + src/vak/prep/dataset_df_helper.py | 1 + .../assign_samples_to_splits.py | 1 + .../frame_classification.py | 1 + .../prep/frame_classification/learncurve.py | 1 + .../prep/frame_classification/make_splits.py | 9 ++++-- .../prep/frame_classification/validators.py | 1 + .../prep/parametric_umap/dataset_arrays.py | 1 + src/vak/prep/sequence_dataset.py | 1 + src/vak/prep/spectrogram_dataset/__init__.py | 1 + 
src/vak/prep/spectrogram_dataset/spect.py | 7 +++-- .../prep/spectrogram_dataset/spect_helper.py | 9 ++++-- src/vak/prep/split/split.py | 1 + src/vak/prep/unit_dataset/unit_dataset.py | 1 + src/vak/train/frame_classification.py | 1 + src/vak/train/parametric_umap.py | 1 + src/vak/train/train_.py | 1 + .../defaults/frame_classification.py | 1 + src/vak/transforms/defaults/get.py | 1 + .../transforms/defaults/parametric_umap.py | 1 + src/vak/transforms/frame_labels/functional.py | 1 + src/vak/transforms/frame_labels/transforms.py | 1 + 78 files changed, 205 insertions(+), 116 deletions(-) diff --git a/src/vak/__main__.py b/src/vak/__main__.py index c3f6c0bac..a25d3f833 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -2,6 +2,7 @@ Invokes __main__ when the module is run as a script. Example: python -m vak --help """ + import argparse from pathlib import Path diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index 955c07945..be7861986 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -1,4 +1,5 @@ """Function called by command-line interface for prep command""" + from __future__ import annotations import pathlib @@ -34,7 +35,9 @@ def purpose_from_toml( purpose = None for table_name in commands_that_are_not_prep: if table_name in config_dict: - purpose = table_name # this top-level table is the "purpose" of the file + purpose = ( + table_name # this top-level table is the "purpose" of the file + ) if purpose is None: raise ValueError( "Did not find a top-level table in configuration file that corresponds to a CLI command. " @@ -44,6 +47,7 @@ def purpose_from_toml( ) return purpose + # note NO LOGGING -- we configure logger inside `core.prep` # so we can save log file inside dataset directory @@ -94,7 +98,10 @@ def prep(toml_path): # ---- figure out purpose of config file from tables; will save path of prep'd dataset in that table ------------------------- purpose = purpose_from_toml(config_dict, toml_path) - if "dataset" in config_dict[purpose] and "path" in config_dict[purpose]["dataset"]: + if ( + "dataset" in config_dict[purpose] + and "path" in config_dict[purpose]["dataset"] + ): raise ValueError( f"This configuration file already has a '{purpose}.dataset' table with a 'path' key, " f"and running `prep` would overwrite the value for that key. To `prep` a new dataset, please " @@ -145,9 +152,9 @@ def prep(toml_path): ) # we re-open config using tomlkit so we can add path to dataset table in style-preserving way - with toml_path.open('r') as fp: + with toml_path.open("r") as fp: tomldoc = tomlkit.load(fp) - if 'dataset' not in tomldoc['vak'][purpose]: + if "dataset" not in tomldoc["vak"][purpose]: dataset_table = tomlkit.table() tomldoc["vak"][purpose].add("dataset", dataset_table) tomldoc["vak"][purpose]["dataset"].add("path", str(dataset_path)) diff --git a/src/vak/common/__init__.py b/src/vak/common/__init__.py index e453adbb6..777bd9afb 100644 --- a/src/vak/common/__init__.py +++ b/src/vak/common/__init__.py @@ -7,6 +7,7 @@ See for example :mod:`vak.prep.prep_helper` or :mod:`vak.datsets.window_dataset._helper`. """ + from . import ( annotation, constants, diff --git a/src/vak/common/constants.py b/src/vak/common/constants.py index fcf2ab7d6..a3a34315f 100644 --- a/src/vak/common/constants.py +++ b/src/vak/common/constants.py @@ -1,6 +1,7 @@ """constants used by multiple modules. Defined here to avoid circular imports. 
""" + from functools import partial import crowsetta diff --git a/src/vak/common/converters.py b/src/vak/common/converters.py index 6b349e182..8d5735649 100644 --- a/src/vak/common/converters.py +++ b/src/vak/common/converters.py @@ -52,10 +52,12 @@ def range_str(range_str, sort=True): subrange, substr ) ) - list_range.extend([int(subrange[0])]) if len( - subrange - ) == 1 else list_range.extend( - range(int(subrange[0]), int(subrange[1]) + 1) + ( + list_range.extend([int(subrange[0])]) + if len(subrange) == 1 + else list_range.extend( + range(int(subrange[0]), int(subrange[1]) + 1) + ) ) if sort: diff --git a/src/vak/common/labels.py b/src/vak/common/labels.py index dd515df40..5f851cdec 100644 --- a/src/vak/common/labels.py +++ b/src/vak/common/labels.py @@ -172,8 +172,7 @@ def multi_char_labels_to_single_char( # which would map it to a new integer and cause us to lose the original integer # from the mapping single_char_labels_not_in_labelmap = [ - lbl for lbl in DUMMY_SINGLE_CHAR_LABELS - if lbl not in labelmap + lbl for lbl in DUMMY_SINGLE_CHAR_LABELS if lbl not in labelmap ] n_needed_to_remap = len( [lbl for lbl in current_str_labels if len(lbl) > 1] @@ -187,7 +186,9 @@ def multi_char_labels_to_single_char( new_labelmap = {} for dummy_label_ind, label_str in enumerate(current_str_labels): label_int = labelmap[label_str] - if len(label_str) > 1 and label_str not in skip: # default for `skip` is ('unlabeled',) + if ( + len(label_str) > 1 and label_str not in skip + ): # default for `skip` is ('unlabeled',) # replace with dummy label new_label_str = single_char_labels_not_in_labelmap[dummy_label_ind] new_labelmap[new_label_str] = label_int diff --git a/src/vak/common/logging.py b/src/vak/common/logging.py index 8ced29688..fa65f272d 100644 --- a/src/vak/common/logging.py +++ b/src/vak/common/logging.py @@ -1,4 +1,5 @@ """utility functions for logging""" + import logging import sys import warnings diff --git a/src/vak/common/paths.py b/src/vak/common/paths.py index 212ad32fc..12648f393 100644 --- a/src/vak/common/paths.py +++ b/src/vak/common/paths.py @@ -1,4 +1,5 @@ """functions for working with paths""" + from pathlib import Path from . import constants, timenow diff --git a/src/vak/common/tensorboard.py b/src/vak/common/tensorboard.py index 6e6b50d88..43db0e53e 100644 --- a/src/vak/common/tensorboard.py +++ b/src/vak/common/tensorboard.py @@ -1,4 +1,5 @@ """Functions dealing with ``tensorboard``""" + from __future__ import annotations from pathlib import Path diff --git a/src/vak/common/timebins.py b/src/vak/common/timebins.py index dd1d8375a..afef34e55 100644 --- a/src/vak/common/timebins.py +++ b/src/vak/common/timebins.py @@ -1,5 +1,6 @@ """module for functions that deal with vector of times from a spectrogram, i.e. where elements are the times at bin centers""" + import numpy as np diff --git a/src/vak/common/validators.py b/src/vak/common/validators.py index ecc7f1d5d..b51399bc1 100644 --- a/src/vak/common/validators.py +++ b/src/vak/common/validators.py @@ -1,4 +1,5 @@ """Functions for input validation""" + import pathlib import warnings diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 7de1fa9b8..0321b3f98 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -1,4 +1,5 @@ """sub-package that parses config.toml files and returns config object""" + from . 
import ( config, dataset, @@ -12,7 +13,6 @@ train, validators, ) - from .config import Config from .dataset import DatasetConfig from .eval import EvalConfig @@ -23,7 +23,6 @@ from .spect_params import SpectParamsConfig from .train import TrainConfig - __all__ = [ "config", "eval", diff --git a/src/vak/config/config.py b/src/vak/config/config.py index d94450439..553afb464 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -1,10 +1,11 @@ """Class that represents the TOML configuration file used with the vak command-line interface.""" + from __future__ import annotations import pathlib -from attrs import define, field from attr.validators import instance_of, optional +from attrs import define, field from . import load from .eval import EvalConfig @@ -14,7 +15,6 @@ from .train import TrainConfig from .validators import are_keys_valid, are_tables_valid - TABLE_CLASSES_MAP = { "eval": EvalConfig, "learncurve": LearncurveConfig, @@ -24,7 +24,9 @@ } -def _validate_tables_to_parse_arg_convert_list(tables_to_parse: str | list[str]) -> list[str]: +def _validate_tables_to_parse_arg_convert_list( + tables_to_parse: str | list[str], +) -> list[str]: """Helper function used by :func:`from_toml` that validates the ``tables_to_parse`` argument, and returns it as a list of strings.""" @@ -73,6 +75,7 @@ class Config: learncurve : vak.config.learncurve.LearncurveConfig Represents ``[vak.learncurve]`` table of config.toml file """ + prep = field(validator=optional(instance_of(PrepConfig)), default=None) train = field(validator=optional(instance_of(TrainConfig)), default=None) eval = field(validator=optional(instance_of(EvalConfig)), default=None) @@ -85,11 +88,11 @@ class Config: @classmethod def from_config_dict( - cls, - config_dict: dict, - tables_to_parse: str | list[str] | None = None, - toml_path: str | pathlib.Path | None = None, - ) -> "Config": + cls, + config_dict: dict, + tables_to_parse: str | list[str] | None = None, + toml_path: str | pathlib.Path | None = None, + ) -> "Config": """Return instance of :class:`Config` class, given a :class:`dict` containing the contents of a TOML configuration file. @@ -128,14 +131,18 @@ def from_config_dict( config_dict.keys() ) # i.e., parse all top-level tables else: - tables_to_parse = _validate_tables_to_parse_arg_convert_list(tables_to_parse) + tables_to_parse = _validate_tables_to_parse_arg_convert_list( + tables_to_parse + ) config_kwargs = {} for table_name in tables_to_parse: if table_name in config_dict: are_keys_valid(config_dict, table_name, toml_path) table_config_dict = config_dict[table_name] - config_kwargs[table_name] = TABLE_CLASSES_MAP[table_name].from_config_dict(table_config_dict) + config_kwargs[table_name] = TABLE_CLASSES_MAP[ + table_name + ].from_config_dict(table_config_dict) else: raise KeyError( f"A table specified in `tables_to_parse` was not found in the config: {table_name}" @@ -147,7 +154,7 @@ def from_config_dict( def from_toml_path( cls, toml_path: str | pathlib.Path, - tables_to_parse: list[str] | None = None + tables_to_parse: list[str] | None = None, ) -> "Config": """Return instance of :class:`Config` class, given the path to a TOML configuration file. 
diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index 60ed15ed1..9ec2d2ff9 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -1,10 +1,11 @@ """Class that represents dataset table in configuration file.""" + from __future__ import annotations import pathlib -from attr import define, field import attr.validators +from attr import define, field @define @@ -23,18 +24,19 @@ class DatasetConfig: Name of dataset. Only required for built-in datasets from the :mod:`~vak.datasets` module. """ + path: pathlib.Path = field(converter=pathlib.Path) splits_path: pathlib.Path | None = field( converter=attr.converters.optional(pathlib.Path), default=None - ) + ) name: str | None = field( converter=attr.converters.optional(str), default=None - ) + ) @classmethod def from_config_dict(cls, dict_: dict) -> DatasetConfig: return cls( - path=dict_.get('path'), - splits_path=dict_.get('splits_path'), - name=dict_.get('name'), + path=dict_.get("path"), + splits_path=dict_.get("splits_path"), + name=dict_.get("name"), ) diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index 64f764dd3..d52aca1f0 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -1,10 +1,10 @@ """Class and functions for ``[vak.eval]`` table in configuration file.""" + from __future__ import annotations import pathlib -from attrs import define, field -from attrs import converters, validators +from attrs import converters, define, field, validators from attrs.validators import instance_of from ..common import device @@ -200,8 +200,10 @@ def from_config_dict(cls, config_dict: dict) -> EvalConfig: "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." ) - config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset']) - config_dict['model'] = ModelConfig.from_config_dict(config_dict['model']) - return cls( - **config_dict + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] + ) + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] ) + return cls(**config_dict) diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py index 4df1e0423..fdf2b883a 100644 --- a/src/vak/config/learncurve.py +++ b/src/vak/config/learncurve.py @@ -1,20 +1,15 @@ """Class that represents ``[vak.learncurve]`` table in configuration file.""" + from __future__ import annotations -from attrs import define, field -from attrs import converters, validators +from attrs import converters, define, field, validators from .dataset import DatasetConfig from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs from .model import ModelConfig from .train import TrainConfig - -REQUIRED_KEYS = ( - 'dataset', - 'model', - 'root_results_dir' -) +REQUIRED_KEYS = ("dataset", "model", "root_results_dir") @define @@ -92,8 +87,10 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig": "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." 
) - config_dict['model'] = ModelConfig.from_config_dict(config_dict['model']) - config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset']) - return cls( - **config_dict + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] + ) + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] ) + return cls(**config_dict) diff --git a/src/vak/config/load.py b/src/vak/config/load.py index f661c4c78..19f0d5abd 100644 --- a/src/vak/config/load.py +++ b/src/vak/config/load.py @@ -1,4 +1,5 @@ """Functions to parse toml config files.""" + from __future__ import annotations import pathlib @@ -31,7 +32,8 @@ def _tomlkit_to_popo(d): result = [_tomlkit_to_popo(x) for x in result] elif isinstance(result, dict): result = { - _tomlkit_to_popo(key): _tomlkit_to_popo(val) for key, val in result.items() + _tomlkit_to_popo(key): _tomlkit_to_popo(val) + for key, val in result.items() } elif isinstance(result, tomlkit.items.Integer): result = int(result) @@ -76,7 +78,7 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: f"Error when parsing .toml config file: {toml_path}" ) from e - if 'vak' not in config_dict: + if "vak" not in config_dict: raise ValueError( "Toml file does not contain a top-level table named `vak`. " "Please see example configuration files here:\n" @@ -91,4 +93,4 @@ def _load_toml_from_path(toml_path: str | pathlib.Path) -> dict: # and then assigns it to the ``spect_params`` key. # We would get this error if we just return the result of :func:`tomlkit.load`, # which is a `tomlkit.TOMLDocument` that tries to ensure that everything is valid toml. - return _tomlkit_to_popo(config_dict)['vak'] + return _tomlkit_to_popo(config_dict)["vak"] diff --git a/src/vak/config/model.py b/src/vak/config/model.py index c1a4ce0cf..cc5259fed 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -1,4 +1,5 @@ """Class representing the model table of a toml configuration file.""" + from __future__ import annotations import pathlib @@ -8,7 +9,6 @@ from .. import models - MODEL_TABLES = [ "network", "optimizer", @@ -36,6 +36,7 @@ class ModelConfig: A :class:`dict` of ``dict``s mapping metric names to keyword arguments. """ + name: str network: dict = field(validator=instance_of(dict)) optimizer: dict = field(validator=instance_of(dict)) @@ -80,9 +81,7 @@ def from_config_dict(cls, config_dict: dict): f"Model names in registry:\n{MODEL_NAMES}" ) model_config = config_dict[model_name] - if not all( - key in MODEL_TABLES for key in model_config.keys() - ): + if not all(key in MODEL_TABLES for key in model_config.keys()): invalid_keys = ( key for key in model_config.keys() if key not in MODEL_TABLES ) @@ -94,10 +93,7 @@ def from_config_dict(cls, config_dict: dict): for model_table in MODEL_TABLES: if model_table not in config_dict: model_config[model_table] = {} - return cls( - name=model_name, - **model_config - ) + return cls(name=model_name, **model_config) def to_dict(self): """Convert this :class:`ModelConfig` instance @@ -109,8 +105,8 @@ def to_dict(self): and returns all other attributes in a :class:`dict`. 
""" return { - 'network': self.network, - 'optimizer': self.optimizer, - 'loss': self.loss, - 'metrics': self.metrics, - } \ No newline at end of file + "network": self.network, + "optimizer": self.optimizer, + "loss": self.loss, + "metrics": self.metrics, + } diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index a7f7c826d..0029cc956 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -1,18 +1,18 @@ """Class that represents ``[vak.predict]`` table of configuration file.""" + from __future__ import annotations import os from pathlib import Path -from attrs import define, field from attr import converters, validators from attr.validators import instance_of +from attrs import define, field -from .dataset import DatasetConfig -from .model import ModelConfig from ..common import device from ..common.converters import expanded_user_path - +from .dataset import DatasetConfig +from .model import ModelConfig REQUIRED_KEYS = ( "checkpoint_path", @@ -141,7 +141,6 @@ class PredictConfig: default=None, ) - @classmethod def from_config_dict(cls, config_dict: dict) -> PredictConfig: """Return :class:`PredictConfig` instance from a :class:`dict`. @@ -159,8 +158,10 @@ def from_config_dict(cls, config_dict: dict) -> PredictConfig: "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." ) - config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset']) - config_dict['model'] = ModelConfig.from_config_dict(config_dict['model']) - return cls( - **config_dict + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] + ) + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] ) + return cls(**config_dict) diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py index bd521c03a..3023ce204 100644 --- a/src/vak/config/prep.py +++ b/src/vak/config/prep.py @@ -1,17 +1,17 @@ """Class and functions for ``[vak.prep]`` table of configuration file.""" + from __future__ import annotations import inspect -from attrs import define, field import dask.bag -from attrs import converters, validators +from attrs import converters, define, field, validators from attrs.validators import instance_of -from .spect_params import SpectParamsConfig -from .validators import is_annot_format, is_audio_format, is_spect_format from .. import prep from ..common.converters import expanded_user_path, labelset_to_set +from .spect_params import SpectParamsConfig +from .validators import is_annot_format, is_audio_format, is_spect_format def duration_from_toml_value(value): @@ -239,8 +239,8 @@ def from_config_dict(cls, config_dict: dict) -> PrepConfig: "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." 
) - if 'spect_params' in config_dict: - config_dict['spect_params'] = SpectParamsConfig(**config_dict['spect_params']) - return cls( - **config_dict - ) + if "spect_params" in config_dict: + config_dict["spect_params"] = SpectParamsConfig( + **config_dict["spect_params"] + ) + return cls(**config_dict) diff --git a/src/vak/config/spect_params.py b/src/vak/config/spect_params.py index 4a61942a6..b570f9e7c 100644 --- a/src/vak/config/spect_params.py +++ b/src/vak/config/spect_params.py @@ -1,4 +1,5 @@ """parses [SPECT_PARAMS] section of config""" + import attr from attr import converters, validators from attr.validators import instance_of diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 135a05541..7ee5cdaff 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -1,6 +1,6 @@ """Class that represents ``[vak.train]`` table of configuration file.""" -from attrs import define, field -from attrs import converters, validators + +from attrs import converters, define, field, validators from attrs.validators import instance_of from ..common import device @@ -8,12 +8,7 @@ from .dataset import DatasetConfig from .model import ModelConfig - -REQUIRED_KEYS = ( - 'dataset', - 'model', - 'root_results_dir' -) +REQUIRED_KEYS = ("dataset", "model", "root_results_dir") @define @@ -169,8 +164,10 @@ def from_config_dict(cls, config_dict: dict) -> "TrainConfig": "when loading the configuration file into a Python dictionary. " "Please check that the configuration file is formatted correctly." ) - config_dict['model'] = ModelConfig.from_config_dict(config_dict['model']) - config_dict['dataset'] = DatasetConfig.from_config_dict(config_dict['dataset']) - return cls( - **config_dict + config_dict["model"] = ModelConfig.from_config_dict( + config_dict["model"] + ) + config_dict["dataset"] = DatasetConfig.from_config_dict( + config_dict["dataset"] ) + return cls(**config_dict) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 9a1a39ca0..40052eb06 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -1,4 +1,5 @@ """validators used by attrs-based classes and by vak.parse.parse_config""" + import pathlib import tomlkit @@ -60,10 +61,11 @@ def is_spect_format(instance, attribute, value): CONFIG_DIR = pathlib.Path(__file__).parent VALID_TOML_PATH = CONFIG_DIR.joinpath("valid-version-1.0.toml") with VALID_TOML_PATH.open("r") as fp: - VALID_DICT = tomlkit.load(fp)['vak'] + VALID_DICT = tomlkit.load(fp)["vak"] VALID_TOP_LEVEL_TABLES = list(VALID_DICT.keys()) VALID_KEYS = { - table_name: list(table_config_dict.keys()) for table_name, table_config_dict in VALID_DICT.items() + table_name: list(table_config_dict.keys()) + for table_name, table_config_dict in VALID_DICT.items() } @@ -80,9 +82,7 @@ def are_tables_valid(config_dict, toml_path=None): command for command in CLI_COMMANDS if command != "prep" ] tables_that_are_commands_besides_prep = [ - table - for table in tables - if table in cli_commands_besides_prep + table for table in tables if table in cli_commands_besides_prep ] if len(tables_that_are_commands_besides_prep) == 0: raise ValueError( @@ -117,10 +117,13 @@ def are_tables_valid(config_dict, toml_path=None): def are_keys_valid( - config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None - ) -> None: + config_dict: dict, + table_name: str, + toml_path: str | pathlib.Path | None = None, +) -> None: """Given a :class:`dict` containing the *entire* configuration loaded from a toml file, - validate the key names 
for a specific top-level table, e.g. ``vak.train`` or ``vak.predict``""" + validate the key names for a specific top-level table, e.g. ``vak.train`` or ``vak.predict`` + """ table_keys = set(config_dict[table_name].keys()) valid_keys = set(VALID_KEYS[table_name]) if not table_keys.issubset(valid_keys): @@ -138,7 +141,11 @@ def are_keys_valid( raise ValueError(err_msg) -def are_table_keys_valid(table_config_dict: dict, table_name: str, toml_path: str | pathlib.Path | None = None) -> None: +def are_table_keys_valid( + table_config_dict: dict, + table_name: str, + toml_path: str | pathlib.Path | None = None, +) -> None: """Given a :class:`dict` containing the configuration for a *specific* top-level table, loaded from a toml file, validate the key names for that table, e.g. ``vak.train`` or ``vak.predict``. diff --git a/src/vak/datasets/frame_classification/frames_dataset.py b/src/vak/datasets/frame_classification/frames_dataset.py index 6d91ad77e..a8f5f6de0 100644 --- a/src/vak/datasets/frame_classification/frames_dataset.py +++ b/src/vak/datasets/frame_classification/frames_dataset.py @@ -1,6 +1,7 @@ """A dataset class used for neural network models with the frame classification task, where the source data consists of audio signals or spectrograms of varying lengths.""" + from __future__ import annotations import pathlib diff --git a/src/vak/datasets/frame_classification/helper.py b/src/vak/datasets/frame_classification/helper.py index 41163cf79..d6a6f19b1 100644 --- a/src/vak/datasets/frame_classification/helper.py +++ b/src/vak/datasets/frame_classification/helper.py @@ -1,4 +1,5 @@ """Helper functions used with frame classification datasets.""" + from __future__ import annotations from ... import common diff --git a/src/vak/datasets/frame_classification/metadata.py b/src/vak/datasets/frame_classification/metadata.py index 61c7cb918..b7a532aae 100644 --- a/src/vak/datasets/frame_classification/metadata.py +++ b/src/vak/datasets/frame_classification/metadata.py @@ -2,6 +2,7 @@ associated with a frame classification dataset, as generated by :func:`vak.core.prep.frame_classification.prep_frame_classification_dataset`""" + from __future__ import annotations import json diff --git a/src/vak/datasets/frame_classification/window_dataset.py b/src/vak/datasets/frame_classification/window_dataset.py index 30fe034dd..37fd6b910 100644 --- a/src/vak/datasets/frame_classification/window_dataset.py +++ b/src/vak/datasets/frame_classification/window_dataset.py @@ -15,6 +15,7 @@ :math:`I` determined by a ``stride`` parameter :math:`s`, :math:`I = (T - w) / s`. 
""" + from __future__ import annotations import pathlib diff --git a/src/vak/datasets/parametric_umap/metadata.py b/src/vak/datasets/parametric_umap/metadata.py index ac0b8a137..a821a223a 100644 --- a/src/vak/datasets/parametric_umap/metadata.py +++ b/src/vak/datasets/parametric_umap/metadata.py @@ -2,6 +2,7 @@ associated with a dimensionality reduction dataset, as generated by :func:`vak.core.prep.frame_classification.prep_dimensionality_reduction_dataset`""" + from __future__ import annotations import json diff --git a/src/vak/datasets/parametric_umap/parametric_umap.py b/src/vak/datasets/parametric_umap/parametric_umap.py index 052975d9c..d95cb0150 100644 --- a/src/vak/datasets/parametric_umap/parametric_umap.py +++ b/src/vak/datasets/parametric_umap/parametric_umap.py @@ -1,4 +1,5 @@ """A dataset class used to train Parametric UMAP models.""" + from __future__ import annotations import pathlib diff --git a/src/vak/eval/eval_.py b/src/vak/eval/eval_.py index 7f57d8f99..8800bb815 100644 --- a/src/vak/eval/eval_.py +++ b/src/vak/eval/eval_.py @@ -1,4 +1,5 @@ """High-level function that evaluates trained models.""" + from __future__ import annotations import logging diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 531c55d6c..cf86670a4 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -1,4 +1,5 @@ """Function that evaluates trained models in the frame classification family.""" + from __future__ import annotations import json diff --git a/src/vak/eval/parametric_umap.py b/src/vak/eval/parametric_umap.py index 09dd8891b..5eeadf19b 100644 --- a/src/vak/eval/parametric_umap.py +++ b/src/vak/eval/parametric_umap.py @@ -1,4 +1,5 @@ """Function that evaluates trained models in the parametric UMAP family.""" + from __future__ import annotations import logging diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py index f363a946a..30690ed3c 100644 --- a/src/vak/learncurve/frame_classification.py +++ b/src/vak/learncurve/frame_classification.py @@ -1,4 +1,5 @@ """Function that generates results for a learning curve for frame classification models.""" + from __future__ import annotations import logging diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py index 781b601aa..63b27a6cb 100644 --- a/src/vak/learncurve/learncurve.py +++ b/src/vak/learncurve/learncurve.py @@ -1,4 +1,5 @@ """High-level function that generates results for a learning curve for all models.""" + from __future__ import annotations import logging diff --git a/src/vak/metrics/util.py b/src/vak/metrics/util.py index 7bbdbc226..e9cf84dd1 100644 --- a/src/vak/metrics/util.py +++ b/src/vak/metrics/util.py @@ -8,6 +8,7 @@ https://setuptools.readthedocs.io/en/latest/setuptools.html#dynamic-discovery-of-services-and-plugins https://amir.rachum.com/blog/2017/07/28/python-entry-points/ """ + from .. import entry_points METRICS_ENTRY_POINT = "vak.models" diff --git a/src/vak/models/base.py b/src/vak/models/base.py index 2aa47022a..fd27b7ae3 100644 --- a/src/vak/models/base.py +++ b/src/vak/models/base.py @@ -1,6 +1,7 @@ """Base class for a model in ``vak``, that other families of models should subclass. 
""" + from __future__ import annotations import inspect diff --git a/src/vak/models/convencoder_umap.py b/src/vak/models/convencoder_umap.py index a7a894b23..7e06efe1c 100644 --- a/src/vak/models/convencoder_umap.py +++ b/src/vak/models/convencoder_umap.py @@ -5,6 +5,7 @@ with changes made by Tim Sainburg: https://github.com/lmcinnes/umap/issues/580#issuecomment-1368649550. """ + from __future__ import annotations import torch diff --git a/src/vak/models/decorator.py b/src/vak/models/decorator.py index a0aa717fe..5a0fff875 100644 --- a/src/vak/models/decorator.py +++ b/src/vak/models/decorator.py @@ -8,6 +8,7 @@ The subclass can then be instantiated and have all model methods. """ + from __future__ import annotations from typing import Type diff --git a/src/vak/models/definition.py b/src/vak/models/definition.py index 14b5435de..b3742d2a8 100644 --- a/src/vak/models/definition.py +++ b/src/vak/models/definition.py @@ -1,6 +1,7 @@ """Code that handles classes that represent the definition of a neural network model; the abstraction of how models are declared with code in vak.""" + from __future__ import annotations import dataclasses diff --git a/src/vak/models/ed_tcn.py b/src/vak/models/ed_tcn.py index 11a195531..38fd9a325 100644 --- a/src/vak/models/ed_tcn.py +++ b/src/vak/models/ed_tcn.py @@ -1,5 +1,6 @@ """ """ + from __future__ import annotations import torch diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index ce35dc401..4ee2088a9 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -2,6 +2,7 @@ where a model predicts a label for each frame in a time series, e.g., each time bin in a window from a spectrogram.""" + from __future__ import annotations import logging diff --git a/src/vak/models/get.py b/src/vak/models/get.py index e4cc3ab06..b6f6a849c 100644 --- a/src/vak/models/get.py +++ b/src/vak/models/get.py @@ -1,5 +1,6 @@ """Function that gets an instance of a model, given its name and a configuration as a dict.""" + from __future__ import annotations import inspect diff --git a/src/vak/models/parametric_umap_model.py b/src/vak/models/parametric_umap_model.py index 67203b71c..4d2c2cb94 100644 --- a/src/vak/models/parametric_umap_model.py +++ b/src/vak/models/parametric_umap_model.py @@ -5,6 +5,7 @@ with changes made by Tim Sainburg: https://github.com/lmcinnes/umap/issues/580#issuecomment-1368649550. """ + from __future__ import annotations import pathlib diff --git a/src/vak/models/registry.py b/src/vak/models/registry.py index e9f01d23c..b187b2480 100644 --- a/src/vak/models/registry.py +++ b/src/vak/models/registry.py @@ -3,6 +3,7 @@ Makes it possible to register a model declared outside of ``vak`` with a decorator, so that the model can be used at runtime. 
""" + from __future__ import annotations import inspect diff --git a/src/vak/models/tweetynet.py b/src/vak/models/tweetynet.py index 62e7b58bf..b9631be59 100644 --- a/src/vak/models/tweetynet.py +++ b/src/vak/models/tweetynet.py @@ -6,6 +6,7 @@ Paper: https://elifesciences.org/articles/63853 Code: https://github.com/yardencsGitHub/tweetynet """ + from __future__ import annotations import torch diff --git a/src/vak/nets/tweetynet.py b/src/vak/nets/tweetynet.py index ed2ec5e7b..ab5f8defc 100644 --- a/src/vak/nets/tweetynet.py +++ b/src/vak/nets/tweetynet.py @@ -1,4 +1,5 @@ """TweetyNet model""" + from __future__ import annotations import torch diff --git a/src/vak/nn/loss/umap.py b/src/vak/nn/loss/umap.py index 8e59be403..077ee0ef2 100644 --- a/src/vak/nn/loss/umap.py +++ b/src/vak/nn/loss/umap.py @@ -1,4 +1,5 @@ """Parametric UMAP loss function.""" + from __future__ import annotations import warnings @@ -77,7 +78,7 @@ def umap_loss( distance_embedding = torch.cat( ( (embedding_to - embedding_from).norm(dim=1), - (embedding_neg_to - embedding_neg_from).norm(dim=1) + (embedding_neg_to - embedding_neg_from).norm(dim=1), # ``to`` method in next line to avoid error `Expected all tensors to be on the same device` ), dim=0, diff --git a/src/vak/nn/modules/activation.py b/src/vak/nn/modules/activation.py index 57a884496..5173ceee8 100644 --- a/src/vak/nn/modules/activation.py +++ b/src/vak/nn/modules/activation.py @@ -1,4 +1,5 @@ """Modules that act as activation functions.""" + import torch diff --git a/src/vak/nn/modules/conv.py b/src/vak/nn/modules/conv.py index 778e5abf3..c249a1b56 100644 --- a/src/vak/nn/modules/conv.py +++ b/src/vak/nn/modules/conv.py @@ -1,4 +1,5 @@ """Modules that perform neural network convolutions.""" + import torch from torch.nn import functional as F diff --git a/src/vak/plot/annot.py b/src/vak/plot/annot.py index fca7294d0..dcb8180f3 100644 --- a/src/vak/plot/annot.py +++ b/src/vak/plot/annot.py @@ -1,4 +1,5 @@ """functions for plotting annotations for vocalizations""" + import matplotlib.pyplot as plt import numpy as np from matplotlib.collections import LineCollection diff --git a/src/vak/plot/learncurve.py b/src/vak/plot/learncurve.py index c78cf0792..b304ec5c1 100644 --- a/src/vak/plot/learncurve.py +++ b/src/vak/plot/learncurve.py @@ -1,4 +1,5 @@ """functions to plot learning curve results""" + import os import pickle from configparser import ConfigParser diff --git a/src/vak/plot/spect.py b/src/vak/plot/spect.py index 286e357dd..99109649a 100644 --- a/src/vak/plot/spect.py +++ b/src/vak/plot/spect.py @@ -1,4 +1,5 @@ """functions for plotting spectrograms""" + import matplotlib.pyplot as plt from .annot import annotation diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index b029ea809..ee09833f4 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -1,4 +1,5 @@ """Function that generates new inferences from trained models in the frame classification family.""" + from __future__ import annotations import json diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py index 66955f2a9..b83c975df 100644 --- a/src/vak/predict/parametric_umap.py +++ b/src/vak/predict/parametric_umap.py @@ -1,4 +1,5 @@ """Function that generates new inferences from trained models in the frame classification family.""" + from __future__ import annotations import logging diff --git a/src/vak/predict/predict_.py b/src/vak/predict/predict_.py index 29208d0f2..5ada31c2f 
100644 --- a/src/vak/predict/predict_.py +++ b/src/vak/predict/predict_.py @@ -1,4 +1,5 @@ """High-level function that generates new inferences from trained models.""" + from __future__ import annotations import logging diff --git a/src/vak/prep/audio_dataset.py b/src/vak/prep/audio_dataset.py index e43444684..04674d00b 100644 --- a/src/vak/prep/audio_dataset.py +++ b/src/vak/prep/audio_dataset.py @@ -176,9 +176,11 @@ def abspath(a_path): [ abspath(audio_path), abspath(annot_path), - annot_format - if annot_format - else constants.NO_ANNOTATION_FORMAT, + ( + annot_format + if annot_format + else constants.NO_ANNOTATION_FORMAT + ), samplerate, sample_dur, audio_dur, diff --git a/src/vak/prep/constants.py b/src/vak/prep/constants.py index 68399dd4c..5657c2f0b 100644 --- a/src/vak/prep/constants.py +++ b/src/vak/prep/constants.py @@ -2,6 +2,7 @@ Defined in a separate module to minimize circular imports. """ + from . import frame_classification, parametric_umap VALID_PURPOSES = frozenset( diff --git a/src/vak/prep/dataset_df_helper.py b/src/vak/prep/dataset_df_helper.py index 81f73f9ba..2146f76e2 100644 --- a/src/vak/prep/dataset_df_helper.py +++ b/src/vak/prep/dataset_df_helper.py @@ -1,4 +1,5 @@ """Helper functions for working with datasets represented as a pandas.DataFrame""" + from __future__ import annotations import pathlib diff --git a/src/vak/prep/frame_classification/assign_samples_to_splits.py b/src/vak/prep/frame_classification/assign_samples_to_splits.py index fc272a93e..ef74eae3f 100644 --- a/src/vak/prep/frame_classification/assign_samples_to_splits.py +++ b/src/vak/prep/frame_classification/assign_samples_to_splits.py @@ -5,6 +5,7 @@ Helper function called by :func:`vak.prep.frame_classification.prep_frame_classification_dataset`. """ + from __future__ import annotations import logging diff --git a/src/vak/prep/frame_classification/frame_classification.py b/src/vak/prep/frame_classification/frame_classification.py index 5bd9469b1..8ce6d29fe 100644 --- a/src/vak/prep/frame_classification/frame_classification.py +++ b/src/vak/prep/frame_classification/frame_classification.py @@ -1,5 +1,6 @@ """Function that prepares datasets for neural network models that perform the frame classification task.""" + from __future__ import annotations import json diff --git a/src/vak/prep/frame_classification/learncurve.py b/src/vak/prep/frame_classification/learncurve.py index 97ba55525..bae335c5d 100644 --- a/src/vak/prep/frame_classification/learncurve.py +++ b/src/vak/prep/frame_classification/learncurve.py @@ -1,5 +1,6 @@ """Functionality to prepare splits of frame classification datasets to generate a learning curve.""" + from __future__ import annotations import logging diff --git a/src/vak/prep/frame_classification/make_splits.py b/src/vak/prep/frame_classification/make_splits.py index e4fd01564..2af4b586d 100644 --- a/src/vak/prep/frame_classification/make_splits.py +++ b/src/vak/prep/frame_classification/make_splits.py @@ -1,4 +1,5 @@ """Helper functions for frame classification dataset prep.""" + from __future__ import annotations import collections @@ -437,9 +438,11 @@ def _save_dataset_arrays_and_return_index_arrays( ] = frames_paths frame_labels_npy_paths = [ - sample.frame_labels_npy_path - if isinstance(sample.frame_labels_npy_path, str) - else None + ( + sample.frame_labels_npy_path + if isinstance(sample.frame_labels_npy_path, str) + else None + ) for sample in samples ] split_df[ diff --git a/src/vak/prep/frame_classification/validators.py 
b/src/vak/prep/frame_classification/validators.py index 91d56be7b..35e023771 100644 --- a/src/vak/prep/frame_classification/validators.py +++ b/src/vak/prep/frame_classification/validators.py @@ -1,4 +1,5 @@ """Validators for frame classification datasets""" + from __future__ import annotations import pandas as pd diff --git a/src/vak/prep/parametric_umap/dataset_arrays.py b/src/vak/prep/parametric_umap/dataset_arrays.py index 67e224ae7..82080ae5d 100644 --- a/src/vak/prep/parametric_umap/dataset_arrays.py +++ b/src/vak/prep/parametric_umap/dataset_arrays.py @@ -1,6 +1,7 @@ """Helper functions for `vak.prep.dimensionality_reduction` module that handle array files. """ + from __future__ import annotations import logging diff --git a/src/vak/prep/sequence_dataset.py b/src/vak/prep/sequence_dataset.py index 11ec2df86..067b7807a 100644 --- a/src/vak/prep/sequence_dataset.py +++ b/src/vak/prep/sequence_dataset.py @@ -1,4 +1,5 @@ """Helper functions for datasets annotated as sequences.""" + from __future__ import annotations import numpy as np diff --git a/src/vak/prep/spectrogram_dataset/__init__.py b/src/vak/prep/spectrogram_dataset/__init__.py index 54491e66f..15f1cc474 100644 --- a/src/vak/prep/spectrogram_dataset/__init__.py +++ b/src/vak/prep/spectrogram_dataset/__init__.py @@ -1,5 +1,6 @@ """Functions for preparing a dataset for neural network models from a dataset of spectrograms.""" + from .prep import prep_spectrogram_dataset __all__ = [ diff --git a/src/vak/prep/spectrogram_dataset/spect.py b/src/vak/prep/spectrogram_dataset/spect.py index d4d84ada0..f07562180 100644 --- a/src/vak/prep/spectrogram_dataset/spect.py +++ b/src/vak/prep/spectrogram_dataset/spect.py @@ -5,6 +5,7 @@ spectrogram adapted from code by Kyle Kastner and Tim Sainburg https://github.com/timsainb/python_spectrograms_and_inversion """ + import numpy as np from matplotlib.mlab import specgram from scipy.signal import butter, lfilter @@ -89,9 +90,9 @@ def spectrogram( spect[spect < thresh] = thresh else: if thresh: - spect[ - spect < thresh - ] = thresh # set anything less than the threshold as the threshold + spect[spect < thresh] = ( + thresh # set anything less than the threshold as the threshold + ) if freq_cutoffs: f_inds = np.nonzero( diff --git a/src/vak/prep/spectrogram_dataset/spect_helper.py b/src/vak/prep/spectrogram_dataset/spect_helper.py index 924e855b5..f0115e922 100644 --- a/src/vak/prep/spectrogram_dataset/spect_helper.py +++ b/src/vak/prep/spectrogram_dataset/spect_helper.py @@ -4,6 +4,7 @@ The columns of the dataframe are specified by :const:`vak.prep.spectrogram_dataset.spect_helper.DF_COLUMNS`. 
""" + from __future__ import annotations import logging @@ -239,9 +240,11 @@ def abspath(a_path): abspath(audio_path), abspath(spect_path), abspath(annot_path), - annot_format - if annot_format - else constants.NO_ANNOTATION_FORMAT, + ( + annot_format + if annot_format + else constants.NO_ANNOTATION_FORMAT + ), spect_dur, timebin_dur, ] diff --git a/src/vak/prep/split/split.py b/src/vak/prep/split/split.py index 23d37dd49..35ace9d26 100644 --- a/src/vak/prep/split/split.py +++ b/src/vak/prep/split/split.py @@ -1,5 +1,6 @@ """Functions for creating splits of datasets used with neural network models, such as the standard train-val-test splits used with supervised learning methods.""" + from __future__ import annotations import logging diff --git a/src/vak/prep/unit_dataset/unit_dataset.py b/src/vak/prep/unit_dataset/unit_dataset.py index 76a0e29b0..7d861b65a 100644 --- a/src/vak/prep/unit_dataset/unit_dataset.py +++ b/src/vak/prep/unit_dataset/unit_dataset.py @@ -1,5 +1,6 @@ """Functions for making a dataset of units from sequences, as used to train dimensionality reduction models.""" + from __future__ import annotations import logging diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index 256daaa84..3f5646721 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -1,4 +1,5 @@ """Function that trains models in the frame classification family.""" + from __future__ import annotations import datetime diff --git a/src/vak/train/parametric_umap.py b/src/vak/train/parametric_umap.py index 675dac90d..c2b7fac64 100644 --- a/src/vak/train/parametric_umap.py +++ b/src/vak/train/parametric_umap.py @@ -1,4 +1,5 @@ """Function that trains models in the Parametric UMAP family.""" + from __future__ import annotations import datetime diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py index 79ee2897f..443f4ae5b 100644 --- a/src/vak/train/train_.py +++ b/src/vak/train/train_.py @@ -1,4 +1,5 @@ """High-level function that trains models.""" + from __future__ import annotations import logging diff --git a/src/vak/transforms/defaults/frame_classification.py b/src/vak/transforms/defaults/frame_classification.py index c4abd0f34..54e8cb642 100644 --- a/src/vak/transforms/defaults/frame_classification.py +++ b/src/vak/transforms/defaults/frame_classification.py @@ -8,6 +8,7 @@ needed for specific neural network models, e.g., whether the returned output includes a mask to crop off padding that was added. """ + from __future__ import annotations from typing import Callable diff --git a/src/vak/transforms/defaults/get.py b/src/vak/transforms/defaults/get.py index 0851d515c..db2cc556e 100644 --- a/src/vak/transforms/defaults/get.py +++ b/src/vak/transforms/defaults/get.py @@ -1,4 +1,5 @@ """Helper function that gets default transforms for a model.""" + from __future__ import annotations from ... 
import models diff --git a/src/vak/transforms/defaults/parametric_umap.py b/src/vak/transforms/defaults/parametric_umap.py index 83c568b06..9dca8bfb4 100644 --- a/src/vak/transforms/defaults/parametric_umap.py +++ b/src/vak/transforms/defaults/parametric_umap.py @@ -1,4 +1,5 @@ """Default transforms for Parametric UMAP models.""" + from __future__ import annotations import torchvision.transforms diff --git a/src/vak/transforms/frame_labels/functional.py b/src/vak/transforms/frame_labels/functional.py index 0ea758840..7fd73ff30 100644 --- a/src/vak/transforms/frame_labels/functional.py +++ b/src/vak/transforms/frame_labels/functional.py @@ -17,6 +17,7 @@ and apply the most "popular" label within each segment to all timebins in that segment - postprocess: combines remove_short_segments and take_majority_vote in one transform """ + from __future__ import annotations import numpy as np diff --git a/src/vak/transforms/frame_labels/transforms.py b/src/vak/transforms/frame_labels/transforms.py index 2734b7da0..bcb81bc48 100644 --- a/src/vak/transforms/frame_labels/transforms.py +++ b/src/vak/transforms/frame_labels/transforms.py @@ -20,6 +20,7 @@ - PostProcess: combines two post-processing transforms applied to frame labels, ``remove_short_segments`` and ``take_majority_vote``, in one class. """ + from __future__ import annotations import numpy as np From a03b3b112286466be3bc14728f2b9eee790e5463 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:42:31 -0400 Subject: [PATCH 111/183] Make more linting fixes --- src/scripts/download_autoannotate_data.py | 1 + src/vak/cli/eval.py | 2 -- src/vak/cli/prep.py | 2 +- src/vak/config/__init__.py | 1 + src/vak/config/load.py | 2 -- src/vak/config/model.py | 2 -- src/vak/config/validators.py | 6 ++++-- 7 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/scripts/download_autoannotate_data.py b/src/scripts/download_autoannotate_data.py index 2d8ea268f..cca3c6e14 100644 --- a/src/scripts/download_autoannotate_data.py +++ b/src/scripts/download_autoannotate_data.py @@ -3,6 +3,7 @@ Adapted from https://github.com/NickleDave/bfsongrepo/blob/main/src/scripts/download_dataset.py """ + from __future__ import annotations import argparse diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 2a8a95bbe..6780d73eb 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -1,8 +1,6 @@ import logging from pathlib import Path -import attrs - from .. import config from .. 
import eval as eval_module from ..common.logging import config_logging_for_cli, log_version diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index be7861986..d86c4c0a9 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -96,7 +96,7 @@ def prep(toml_path): # open here because need to check whether the `dataset` already has a `path`, see #314 & #333 config_dict = _load_toml_from_path(toml_path) - # ---- figure out purpose of config file from tables; will save path of prep'd dataset in that table ------------------------- + # ---- figure out purpose of config file from tables; will save path of prep'd dataset in that table --------------- purpose = purpose_from_toml(config_dict, toml_path) if ( "dataset" in config_dict[purpose] diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 0321b3f98..056c0ef12 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -25,6 +25,7 @@ __all__ = [ "config", + "dataset", "eval", "learncurve", "model", diff --git a/src/vak/config/load.py b/src/vak/config/load.py index 19f0d5abd..3134dc85e 100644 --- a/src/vak/config/load.py +++ b/src/vak/config/load.py @@ -7,8 +7,6 @@ import tomlkit import tomlkit.exceptions -from .validators import are_keys_valid, are_tables_valid - def _tomlkit_to_popo(d): """Convert tomlkit to "popo" (Plain-Old Python Objects) diff --git a/src/vak/config/model.py b/src/vak/config/model.py index cc5259fed..3db4b5344 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pathlib - from attrs import define, field from attrs.validators import instance_of diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index 40052eb06..628656a51 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -105,13 +105,15 @@ def are_tables_valid(config_dict, toml_path=None): err_msg = ( f"Top-level table defined in {toml_path} is not valid: {table}\n" f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n" - "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" + "Please see example toml configuration files here: " + "https://github.com/vocalpy/vak/tree/main/doc/toml" ) else: err_msg = ( f"Table defined in toml config is not valid: {table}\n" f"Valid top-level tables are: {VALID_TOP_LEVEL_TABLES}\n" - "Please see example toml configuration files here: https://github.com/vocalpy/vak/tree/main/doc/toml" + "Please see example toml configuration files here: " + "https://github.com/vocalpy/vak/tree/main/doc/toml" ) raise ValueError(err_msg) From 9b31067aaba31f492918ddb066205939bec854b2 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 10:42:45 -0400 Subject: [PATCH 112/183] Speed up install in nox session 'lint', only install linting tools --- noxfile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 188fe1281..f931e8803 100644 --- a/noxfile.py +++ b/noxfile.py @@ -61,11 +61,11 @@ def lint(session): """ Run the linter. 
""" - session.install(".[dev]") + session.install("isort", "black", "flake8") # run isort first since black disagrees with it session.run("isort", "./src") session.run("black", "./src", "--line-length=79") - session.run("flake8", "./src", "--max-line-length", "120", "--exclude", "./src/crowsetta/_vendor") + session.run("flake8", "./src", "--max-line-length", "120") @nox.session From ee9c4e57cde6babdb09ed98eab446284d9143a2d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 13:08:25 -0400 Subject: [PATCH 113/183] Change names 'section'/'option' -> 'table'/'key' in tests --- tests/fixtures/config.py | 50 ++++++++--------- tests/test_cli/test_eval.py | 14 ++--- tests/test_cli/test_learncurve.py | 22 ++++---- tests/test_cli/test_predict.py | 14 ++--- tests/test_cli/test_prep.py | 16 +++--- tests/test_cli/test_train.py | 18 +++---- tests/test_eval/test_eval.py | 8 +-- tests/test_eval/test_frame_classification.py | 38 ++++++------- tests/test_eval/test_parametric_umap.py | 34 ++++++------ .../test_frame_classification.py | 14 ++--- tests/test_metrics/test_segmentation.py | 0 .../test_predict/test_frame_classification.py | 40 +++++++------- tests/test_predict/test_predict.py | 8 +-- .../test_frame_classification.py | 54 +++++++++---------- .../test_learncurve.py | 16 +++--- tests/test_prep/test_prep.py | 8 +-- tests/test_train/test_frame_classification.py | 36 ++++++------- tests/test_train/test_parametric_umap.py | 26 ++++----- tests/test_train/test_train.py | 10 ++-- 19 files changed, 213 insertions(+), 213 deletions(-) create mode 100644 tests/test_metrics/test_segmentation.py diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index 184504099..a5c3418b9 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -19,7 +19,7 @@ def test_configs_root(): 1) those used by the tests/scripts/generate_data_for_tests.py script. Will be listed in configs.json. See ``specific_config_toml_path`` fixture below for details about types of configs. - 2) those used by tests that are static, e.g., ``invalid_section_config.toml`` + 2) those used by tests that are static, e.g., ``invalid_table_config.toml`` This fixture facilitates access to type (2), e.g. in test_config/test_parse """ @@ -97,7 +97,7 @@ def specific_config_toml_path(generated_test_configs_root, list_of_schematized_c If ``root_results_dir`` argument is specified when calling the factory function, - it will convert the value for that option in the section + it will convert the value for that key in the table corresponding to ``config_type`` to the value specified for ``root_results_dir``. This makes it possible to dynamically change the ``root_results_dir`` @@ -110,7 +110,7 @@ def _specific_config( annot_format, audio_format=None, spect_format=None, - options_to_change=None, + keys_to_change=None, ): """returns path to a specific configuration file, determined by characteristics specified by the caller: @@ -124,18 +124,18 @@ def _specific_config( annotation format, recognized by ``crowsetta`` audio_format : str spect_format : str - options_to_change : list, dict - list of dicts with keys 'section', 'option', and 'value'. - Can be a single dict, in which case only that option is changed. + keys_to_change : list, dict + list of dicts with keys 'table', 'key', and 'value'. + Can be a single dict, in which case only that key is changed. If the 'value' is set to 'DELETE-OPTION', - the option will be removed from the config. - This can be used to test behavior when the option is not set. 
+ the key will be removed from the config. + This can be used to test behavior when the key is not set. Returns ------- config_path : pathlib.Path that points to temporary copy of specified config, - with any options changed as specified + with any keys changed as specified """ original_config_path = None for schematized_config in list_of_schematized_configs: @@ -162,28 +162,28 @@ def _specific_config( config_copy_path = tmp_path.joinpath(original_config_path.name) config_copy_path = shutil.copy(src=original_config_path, dst=config_copy_path) - if options_to_change is not None: - if isinstance(options_to_change, dict): - options_to_change = [options_to_change] - elif isinstance(options_to_change, list): + if keys_to_change is not None: + if isinstance(keys_to_change, dict): + keys_to_change = [keys_to_change] + elif isinstance(keys_to_change, list): pass else: raise TypeError( - f"invalid type for `options_to_change`: {type(options_to_change)}" + f"invalid type for `keys_to_change`: {type(keys_to_change)}" ) with config_copy_path.open("r") as fp: - config_toml = tomlkit.load(fp) + tomldoc = tomlkit.load(fp) - for opt_dict in options_to_change: + for opt_dict in keys_to_change: if opt_dict["value"] == 'DELETE-OPTION': - # e.g., to test behavior of config without this option - del config_toml[opt_dict["section"]][opt_dict["option"]] + # e.g., to test behavior of config without this key + del tomldoc["vak"][opt_dict["table"]][opt_dict["key"]] else: - config_toml[opt_dict["section"]][opt_dict["option"]] = opt_dict["value"] + tomldoc["vak"][opt_dict["table"]][opt_dict["key"]] = opt_dict["value"] with config_copy_path.open("w") as fp: - tomlkit.dump(config_toml, fp) + tomlkit.dump(tomldoc, fp) return config_copy_path @@ -231,7 +231,7 @@ def _tomlkit_to_popo(d): def _load_config_dict(toml_path): """Return config as dict, loaded from toml file. - Used to test functions that parse config sections, taking these dicts as inputs. + Used to test functions that parse config tables, taking these dicts as inputs. Note that we access the topmost table loaded from the toml: config_dict['vak'] """ @@ -336,7 +336,7 @@ def all_generated_configs_toml_path_pairs(): """ # we duplicate the constant above because we need to remake # the variables for each unit test. 
Otherwise tests that modify values - # for config options cause other tests to fail + # for config keys cause other tests to fail return zip( [_load_config_dict(config) for config in ALL_GENERATED_CONFIG_PATHS], ALL_GENERATED_CONFIG_PATHS @@ -350,13 +350,13 @@ def configs_toml_path_pairs_by_model_factory(all_generated_configs_toml_path_pai """ def _wrapped(model, - section_name=None): + table_name=None): out = [] unzipped = list(all_generated_configs_toml_path_pairs) for config_toml, toml_path in unzipped: if toml_path.name.startswith(model): - if section_name: - if section_name.lower() in toml_path.name: + if table_name: + if table_name.lower() in toml_path.name: out.append( (config_toml, toml_path) ) diff --git a/tests/test_cli/test_eval.py b/tests/test_cli/test_eval.py index f94f68f46..ccebf39f6 100644 --- a/tests/test_cli/test_eval.py +++ b/tests/test_cli/test_eval.py @@ -25,9 +25,9 @@ def test_eval( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -36,7 +36,7 @@ def test_eval( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.eval.eval', autospec=True) as mock_core_eval: @@ -54,8 +54,8 @@ def test_eval_dataset_path_none_raises( """Test that cli.eval raises ValueError when dataset_path is None (presumably because `vak prep` was not run yet) """ - options_to_change = [ - {"section": "EVAL", "option": "dataset_path", "value": "DELETE-OPTION"}, + keys_to_change = [ + {"table": "eval", "key": "dataset_path", "value": "DELETE-OPTION"}, ] toml_path = specific_config_toml_path( @@ -64,7 +64,7 @@ def test_eval_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with pytest.raises(ValueError): diff --git a/tests/test_cli/test_learncurve.py b/tests/test_cli/test_learncurve.py index 8dce64302..8c829015c 100644 --- a/tests/test_cli/test_learncurve.py +++ b/tests/test_cli/test_learncurve.py @@ -14,13 +14,13 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, - {"section": "LEARNCURVE", "option": "device", "value": device}, + {"table": "learncurve", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -28,7 +28,7 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.learncurve.learning_curve', autospec=True) as mock_core_learning_curve: @@ -54,15 +54,15 @@ def test_learning_curve_dataset_path_none_raises( root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": 
"root_results_dir", "value": str(root_results_dir), }, { - "section": "LEARNCURVE", - "option": "dataset_path", + "table": "learncurve", + "key": "dataset_path", "value": "DELETE-OPTION"}, ] @@ -72,7 +72,7 @@ def test_learning_curve_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with pytest.raises(ValueError): diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py index 6269c01d9..ceff6edcb 100644 --- a/tests/test_cli/test_predict.py +++ b/tests/test_cli/test_predict.py @@ -24,9 +24,9 @@ def test_predict( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": device}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -34,7 +34,7 @@ def test_predict( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.predict.predict', autospec=True) as mock_core_predict: @@ -51,8 +51,8 @@ def test_predict_dataset_path_none_raises( """Test that cli.predict raises ValueError when dataset_path is None (presumably because `vak prep` was not run yet) """ - options_to_change = [ - {"section": "PREDICT", "option": "dataset_path", "value": "DELETE-OPTION"}, + keys_to_change = [ + {"table": "predict", "key": "dataset_path", "value": "DELETE-OPTION"}, ] toml_path = specific_config_toml_path( @@ -61,7 +61,7 @@ def test_predict_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with pytest.raises(ValueError): diff --git a/tests/test_cli/test_prep.py b/tests/test_cli/test_prep.py index 3361f266e..e5be2c723 100644 --- a/tests/test_cli/test_prep.py +++ b/tests/test_cli/test_prep.py @@ -64,12 +64,12 @@ def test_prep( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREP", "option": "output_dir", "value": str(output_dir)}, + keys_to_change = [ + {"table": "prep", "key": "output_dir", "value": str(output_dir)}, # need to remove dataset_path option from configs we already ran prep on to avoid error { - "section": config_type.upper(), - "option": "dataset_path", + "table": config_type.upper(), + "key": "dataset_path", "value": None, }, ] @@ -79,7 +79,7 @@ def test_prep( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.prep.prep', autospec=True) as mock_core_prep: @@ -113,8 +113,8 @@ def test_prep_dataset_path_raises( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREP", "option": "output_dir", "value": str(output_dir)}, + keys_to_change = [ + {"table": "prep", "key": "output_dir", "value": str(output_dir)}, ] toml_path = specific_config_toml_path( config_type=config_type, @@ -122,7 +122,7 @@ def test_prep_dataset_path_raises( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with pytest.raises(ValueError): diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py index c59716ff2..8c6ec9781 100644 --- a/tests/test_cli/test_train.py +++ 
b/tests/test_cli/test_train.py @@ -24,13 +24,13 @@ def test_train( root_results_dir = tmp_path.joinpath("test_train_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "TRAIN", - "option": "root_results_dir", + "table": "train", + "key": "root_results_dir", "value": str(root_results_dir), }, - {"section": "TRAIN", "option": "device", "value": device}, + {"table": "train", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -39,7 +39,7 @@ def test_train( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with mock.patch('vak.train.train', autospec=True) as mock_core_train: @@ -63,9 +63,9 @@ def test_train_dataset_path_none_raises( root_results_dir = tmp_path.joinpath("test_train_root_results_dir") root_results_dir.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "root_results_dir", "value": str(root_results_dir)}, - {"section": "TRAIN", "option": "dataset_path", "value": "DELETE-OPTION"}, + keys_to_change = [ + {"table": "train", "key": "root_results_dir", "value": str(root_results_dir)}, + {"table": "train", "key": "dataset_path", "value": "DELETE-OPTION"}, ] toml_path = specific_config_toml_path( @@ -74,7 +74,7 @@ def test_train_dataset_path_none_raises( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) with pytest.raises(ValueError): diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index ba0e38143..7158cd393 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -29,9 +29,9 @@ def test_eval( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": 'cpu'}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": 'cpu'}, ] toml_path = specific_config_toml_path( @@ -40,7 +40,7 @@ def test_eval( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index fb2055351..4528fd909 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -54,9 +54,9 @@ def test_eval_frame_classification_model( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -65,7 +65,7 @@ def test_eval_frame_classification_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) @@ -91,9 +91,9 @@ def test_eval_frame_classification_model( @pytest.mark.parametrize( 'path_option_to_change', [ 
- {"section": "EVAL", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, - {"section": "EVAL", "option": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, - {"section": "EVAL", "option": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, + {"table": "eval", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "eval", "key": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, + {"table": "eval", "key": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, ] ) def test_eval_frame_classification_model_raises_file_not_found( @@ -111,9 +111,9 @@ def test_eval_frame_classification_model_raises_file_not_found( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, path_option_to_change, ] @@ -123,7 +123,7 @@ def test_eval_frame_classification_model_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) @@ -146,8 +146,8 @@ def test_eval_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "EVAL", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "eval", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "eval", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) def test_eval_frame_classification_model_raises_not_a_directory( @@ -159,20 +159,20 @@ def test_eval_frame_classification_model_raises_not_a_directory( """Test that core.eval raises NotADirectory when directories don't exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "EVAL", "option": "device", "value": device}, + {"table": "eval", "key": "device", "value": device}, ] - if path_option_to_change["option"] != "output_dir": + if path_option_to_change["key"] != "output_dir": # need to make sure output_dir *does* exist # so we don't detect spurious NotADirectoryError and assume test passes output_dir = tmp_path.joinpath( f"test_eval_raises_not_a_directory" ) output_dir.mkdir() - options_to_change.append( - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)} + keys_to_change.append( + {"table": "eval", "key": "output_dir", "value": str(output_dir)} ) toml_path = specific_config_toml_path( @@ -181,7 +181,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index 925178c1a..824cbf2d8 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -32,9 +32,9 @@ def test_eval_parametric_umap_model( ) output_dir.mkdir() - 
options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -43,7 +43,7 @@ def test_eval_parametric_umap_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) @@ -67,7 +67,7 @@ def test_eval_parametric_umap_model( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "eval", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, ] ) def test_eval_frame_classification_model_raises_file_not_found( @@ -83,9 +83,9 @@ def test_eval_frame_classification_model_raises_file_not_found( ) output_dir.mkdir() - options_to_change = [ - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)}, - {"section": "EVAL", "option": "device", "value": device}, + keys_to_change = [ + {"table": "eval", "key": "output_dir", "value": str(output_dir)}, + {"table": "eval", "key": "device", "value": device}, path_option_to_change, ] @@ -95,7 +95,7 @@ def test_eval_frame_classification_model_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) @@ -117,8 +117,8 @@ def test_eval_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "EVAL", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "EVAL", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "eval", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "eval", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) def test_eval_frame_classification_model_raises_not_a_directory( @@ -129,20 +129,20 @@ def test_eval_frame_classification_model_raises_not_a_directory( ): """Test that :func:`vak.eval.parametric_umap.eval_parametric_umap_model` raises NotADirectoryError when expected""" - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "EVAL", "option": "device", "value": device}, + {"table": "eval", "key": "device", "value": device}, ] - if path_option_to_change["option"] != "output_dir": + if path_option_to_change["key"] != "output_dir": # need to make sure output_dir *does* exist # so we don't detect spurious NotADirectoryError and assume test passes output_dir = tmp_path.joinpath( f"test_eval_raises_not_a_directory" ) output_dir.mkdir() - options_to_change.append( - {"section": "EVAL", "option": "output_dir", "value": str(output_dir)} + keys_to_change.append( + {"table": "eval", "key": "output_dir", "value": str(output_dir)} ) toml_path = specific_config_toml_path( @@ -151,7 +151,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) 
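     # Load the config from the temp .toml copy (written by the fixture above,
     # with the bad path substituted in), then pull out the model-specific
     # config; both are passed to `eval_parametric_umap_model` below, which
     # this test expects to raise NotADirectoryError.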
cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index d83ca9eeb..a985ba422 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -52,14 +52,14 @@ def assert_learncurve_output_matches_expected(cfg, model_name, results_path): ) def test_learning_curve_for_frame_classification_model( model_name, audio_format, annot_format, specific_config_toml_path, tmp_path, device): - options_to_change = {"section": "LEARNCURVE", "option": "device", "value": device} + keys_to_change = {"table": "learncurve", "key": "device", "value": device} toml_path = specific_config_toml_path( config_type="learncurve", model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -94,8 +94,8 @@ def test_learning_curve_for_frame_classification_model( @pytest.mark.parametrize( 'dir_option_to_change', [ - {"section": "LEARNCURVE", "option": "root_results_dir", "value": '/obviously/does/not/exist/results/'}, - {"section": "LEARNCURVE", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "learncurve", "key": "root_results_dir", "value": '/obviously/does/not/exist/results/'}, + {"table": "learncurve", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, ] ) def test_learncurve_raises_not_a_directory(dir_option_to_change, @@ -105,8 +105,8 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, when the following directories do not exist: results_path, previous_run_path, dataset_path """ - options_to_change = [ - {"section": "LEARNCURVE", "option": "device", "value": device}, + keys_to_change = [ + {"table": "learncurve", "key": "device", "value": device}, dir_option_to_change ] toml_path = specific_config_toml_path( @@ -114,7 +114,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) diff --git a/tests/test_metrics/test_segmentation.py b/tests/test_metrics/test_segmentation.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index db77ef7e0..e3d52a38b 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -37,17 +37,17 @@ def test_predict_with_frame_classification_model( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": device}, - {"section": "PREDICT", "option": "save_net_outputs", "value": save_net_outputs}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": device}, + {"table": "predict", "key": "save_net_outputs", "value": save_net_outputs}, ] toml_path = specific_config_toml_path( config_type="predict", model=model_name, audio_format=audio_format, annot_format=annot_format, - 
options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -94,9 +94,9 @@ def test_predict_with_frame_classification_model( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "PREDICT", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, - {"section": "PREDICT", "option": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, - {"section": "PREDICT", "option": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, + {"table": "predict", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "predict", "key": "labelmap_path", "value": '/obviously/doesnt/exist/labelmap.json'}, + {"table": "predict", "key": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, ] ) def test_predict_with_frame_classification_model_raises_file_not_found( @@ -112,9 +112,9 @@ def test_predict_with_frame_classification_model_raises_file_not_found( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": device}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": device}, path_option_to_change, ] toml_path = specific_config_toml_path( @@ -122,7 +122,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -152,8 +152,8 @@ def test_predict_with_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "PREDICT", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "PREDICT", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "predict", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "predict", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) def test_predict_with_frame_classification_model_raises_not_a_directory( @@ -165,20 +165,20 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( """Test that core.eval raises NotADirectory when ``output_dir`` does not exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "PREDICT", "option": "device", "value": device}, + {"table": "predict", "key": "device", "value": device}, ] - if path_option_to_change["option"] != "output_dir": + if path_option_to_change["key"] != "output_dir": # need to make sure output_dir *does* exist # so we don't detect spurious NotADirectoryError and assume test passes output_dir = tmp_path.joinpath( f"test_predict_raises_not_a_directory" ) output_dir.mkdir() - options_to_change.append( - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)} + keys_to_change.append( + {"table": "predict", "key": "output_dir", "value": str(output_dir)} ) toml_path = specific_config_toml_path( @@ -186,7 +186,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( model="TweetyNet", audio_format="cbin", annot_format="notmat", - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, 
cfg.predict.model) diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index d25d06017..6b0768380 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -26,9 +26,9 @@ def test_predict( ) output_dir.mkdir() - options_to_change = [ - {"section": "PREDICT", "option": "output_dir", "value": str(output_dir)}, - {"section": "PREDICT", "option": "device", "value": 'cpu'}, + keys_to_change = [ + {"table": "predict", "key": "output_dir", "value": str(output_dir)}, + {"table": "predict", "key": "device", "value": 'cpu'}, ] toml_path = specific_config_toml_path( @@ -36,7 +36,7 @@ def test_predict( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) diff --git a/tests/test_prep/test_frame_classification/test_frame_classification.py b/tests/test_prep/test_frame_classification/test_frame_classification.py index 6b8c7d580..e4513402f 100644 --- a/tests/test_prep/test_frame_classification/test_frame_classification.py +++ b/tests/test_prep/test_frame_classification/test_frame_classification.py @@ -80,10 +80,10 @@ def test_prep_frame_classification_dataset( ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -93,7 +93,7 @@ def test_prep_frame_classification_dataset( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -149,13 +149,13 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ ) output_dir.mkdir() - options_to_change = [ - {"section": "PREP", - "option": "output_dir", + keys_to_change = [ + {"table": "prep", + "key": "output_dir", "value": str(output_dir), }, - {"section": "PREP", - "option": "labelset", + {"table": "prep", + "key": "labelset", "value": "DELETE-OPTION", }, ] @@ -165,7 +165,7 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -214,15 +214,15 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "PREP", - "option": "data_dir", + "table": "prep", + "key": "data_dir", "value": str(data_dir), }, { - "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -233,7 +233,7 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te audio_format='cbin', annot_format='notmat', spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -272,15 +272,15 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "PREP", - "option": "data_dir", + "table": "prep", + "key": "data_dir", "value": str(data_dir), }, { - "section": "PREP", - "option": "output_dir", + 
"table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -291,7 +291,7 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ audio_format='cbin', annot_format='notmat', spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -318,8 +318,8 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ @pytest.mark.parametrize( "dir_option_to_change", [ - {"section": "PREP", "option": "data_dir", "value": '/obviously/does/not/exist/data'}, - {"section": "PREP", "option": "output_dir", "value": '/obviously/does/not/exist/output'}, + {"table": "prep", "key": "data_dir", "value": '/obviously/does/not/exist/data'}, + {"table": "prep", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ], ) def test_prep_frame_classification_dataset_raises_not_a_directory( @@ -338,7 +338,7 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=dir_option_to_change, + keys_to_change=dir_option_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -364,7 +364,7 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( @pytest.mark.parametrize( "path_option_to_change", [ - {"section": "PREP", "option": "annot_file", "value": '/obviously/does/not/exist/annot.mat'}, + {"table": "prep", "key": "annot_file", "value": '/obviously/does/not/exist/annot.mat'}, ], ) def test_prep_frame_classification_dataset_raises_file_not_found( @@ -386,7 +386,7 @@ def test_prep_frame_classification_dataset_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=path_option_to_change, + keys_to_change=path_option_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index ed800b872..3b93e42da 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -22,10 +22,10 @@ def test_make_index_vectors_for_each_subsets( ): root_results_dir = tmp_path.joinpath("tmp_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, ] @@ -34,7 +34,7 @@ def test_make_index_vectors_for_each_subsets( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) @@ -134,10 +134,10 @@ def test_make_subsets_from_dataset_df( ): root_results_dir = tmp_path.joinpath("tmp_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "LEARNCURVE", - "option": "root_results_dir", + "table": "learncurve", + "key": "root_results_dir", "value": str(root_results_dir), }, ] @@ -146,7 +146,7 @@ def test_make_subsets_from_dataset_df( model=model_name, audio_format=audio_format, annot_format=annot_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) diff --git a/tests/test_prep/test_prep.py b/tests/test_prep/test_prep.py index e5bb0664b..ad673ed3d 100644 --- 
a/tests/test_prep/test_prep.py +++ b/tests/test_prep/test_prep.py @@ -33,10 +33,10 @@ def test_prep( ) output_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "PREP", - "option": "output_dir", + "table": "prep", + "key": "output_dir", "value": str(output_dir), }, ] @@ -46,7 +46,7 @@ def test_prep( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) purpose = config_type.lower() diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 38642872f..8eafd4d31 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -46,9 +46,9 @@ def test_train_frame_classification_model( ): results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, - {"section": "TRAIN", "option": "root_results_dir", "value": results_path} + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "root_results_dir", "value": results_path} ] toml_path = specific_config_toml_path( config_type="train", @@ -56,7 +56,7 @@ def test_train_frame_classification_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) @@ -99,9 +99,9 @@ def test_continue_training( ): results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, - {"section": "TRAIN", "option": "root_results_dir", "value": results_path} + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "root_results_dir", "value": results_path} ] toml_path = specific_config_toml_path( config_type="train_continue", @@ -109,7 +109,7 @@ def test_continue_training( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) @@ -142,8 +142,8 @@ def test_continue_training( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "TRAIN", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, - {"section": "TRAIN", "option": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, + {"table": "train", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "train", "key": "spect_scaler_path", "value": '/obviously/doesnt/exist/SpectScaler'}, ] ) def test_train_raises_file_not_found( @@ -153,8 +153,8 @@ def test_train_raises_file_not_found( when one of the following does not exist: checkpoint_path, dataset_path, spect_scaler_path """ - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, path_option_to_change ] toml_path = specific_config_toml_path( @@ -163,7 +163,7 @@ def test_train_raises_file_not_found( audio_format="cbin", 
annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) @@ -197,8 +197,8 @@ def test_train_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "TRAIN", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "TRAIN", "option": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, + {"table": "train", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "train", "key": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, ] ) def test_train_raises_not_a_directory( @@ -207,9 +207,9 @@ def test_train_raises_not_a_directory( """Test that core.train raises NotADirectory when directory does not exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "TRAIN", "option": "device", "value": device}, + {"table": "train", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -218,7 +218,7 @@ def test_train_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index ab6ea672f..1fee8f303 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -39,9 +39,9 @@ def test_train_parametric_umap_model( ): results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, - {"section": "TRAIN", "option": "root_results_dir", "value": results_path} + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, + {"table": "train", "key": "root_results_dir", "value": results_path} ] toml_path = specific_config_toml_path( config_type="train", @@ -49,7 +49,7 @@ def test_train_parametric_umap_model( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) @@ -79,7 +79,7 @@ def test_train_parametric_umap_model( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "TRAIN", "option": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, + {"table": "train", "key": "checkpoint_path", "value": '/obviously/doesnt/exist/ckpt.pt'}, ] ) def test_train_parametric_umap_model_raises_file_not_found( @@ -89,8 +89,8 @@ def test_train_parametric_umap_model_raises_file_not_found( raise FileNotFoundError when one of the following does not exist: checkpoint_path, dataset_path """ - options_to_change = [ - {"section": "TRAIN", "option": "device", "value": device}, + keys_to_change = [ + {"table": "train", "key": "device", "value": device}, path_option_to_change ] toml_path = specific_config_toml_path( @@ -99,7 +99,7 @@ def test_train_parametric_umap_model_raises_file_not_found( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) 
cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) @@ -130,8 +130,8 @@ def test_train_parametric_umap_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"section": "TRAIN", "option": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, - {"section": "TRAIN", "option": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, + {"table": "train", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "train", "key": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, ] ) def test_train_parametric_umap_model_raises_not_a_directory( @@ -140,9 +140,9 @@ def test_train_parametric_umap_model_raises_not_a_directory( """Test that core.train raises NotADirectory when directory does not exist """ - options_to_change = [ + keys_to_change = [ path_option_to_change, - {"section": "TRAIN", "option": "device", "value": device}, + {"table": "train", "key": "device", "value": device}, ] toml_path = specific_config_toml_path( @@ -151,7 +151,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( audio_format="cbin", annot_format="notmat", spect_format=None, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index 5fb911e6e..71ec3b179 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -29,13 +29,13 @@ def test_train( root_results_dir = tmp_path.joinpath("test_train_root_results_dir") root_results_dir.mkdir() - options_to_change = [ + keys_to_change = [ { - "section": "TRAIN", - "option": "root_results_dir", + "table": "train", + "key": "root_results_dir", "value": str(root_results_dir), }, - {"section": "TRAIN", "option": "device", "value": 'cpu'}, + {"table": "train", "key": "device", "value": 'cpu'}, ] toml_path = specific_config_toml_path( @@ -44,7 +44,7 @@ def test_train( audio_format=audio_format, annot_format=annot_format, spect_format=spect_format, - options_to_change=options_to_change, + keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) From 9466d6c6e4a62ca7048e85f7d6049fe14e123b7f Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 13:11:34 -0400 Subject: [PATCH 114/183] Fix tests in tests/test_cli/test_eval.py --- tests/test_cli/test_eval.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_cli/test_eval.py b/tests/test_cli/test_eval.py index ccebf39f6..c5a00bf48 100644 --- a/tests/test_cli/test_eval.py +++ b/tests/test_cli/test_eval.py @@ -48,14 +48,14 @@ def test_eval( assert cli_asserts.log_file_contains_version(command="eval", output_path=output_dir) -def test_eval_dataset_path_none_raises( - specific_config_toml_path, tmp_path, +def test_eval_dataset_none_raises( + specific_config_toml_path ): - """Test that cli.eval raises ValueError when dataset_path is None + """Test that cli.eval raises ValueError when dataset is None (presumably because `vak prep` was not run yet) """ keys_to_change = [ - {"table": "eval", "key": "dataset_path", "value": "DELETE-OPTION"}, + {"table": "eval", "key": "dataset", "value": "DELETE-OPTION"}, ] toml_path = specific_config_toml_path( @@ -67,5 +67,5 @@ 
def test_eval_dataset_path_none_raises( keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.eval.eval(toml_path) From 02064555e6c42cabad3204f0fbcc276c9a9aa58b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 13:31:07 -0400 Subject: [PATCH 115/183] Finish fixing cli tests, fix renaming --- tests/fixtures/config.py | 4 ++-- tests/test_cli/test_eval.py | 2 +- tests/test_cli/test_learncurve.py | 10 +++++----- tests/test_cli/test_predict.py | 8 ++++---- tests/test_cli/test_prep.py | 12 ++++++------ tests/test_cli/test_train.py | 4 ++-- .../test_frame_classification.py | 2 +- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py index a5c3418b9..8205d7c44 100644 --- a/tests/fixtures/config.py +++ b/tests/fixtures/config.py @@ -127,7 +127,7 @@ def _specific_config( keys_to_change : list, dict list of dicts with keys 'table', 'key', and 'value'. Can be a single dict, in which case only that key is changed. - If the 'value' is set to 'DELETE-OPTION', + If the 'value' is set to 'DELETE-KEY', the key will be removed from the config. This can be used to test behavior when the key is not set. @@ -176,7 +176,7 @@ def _specific_config( tomldoc = tomlkit.load(fp) for opt_dict in keys_to_change: - if opt_dict["value"] == 'DELETE-OPTION': + if opt_dict["value"] == 'DELETE-KEY': # e.g., to test behavior of config without this key del tomldoc["vak"][opt_dict["table"]][opt_dict["key"]] else: diff --git a/tests/test_cli/test_eval.py b/tests/test_cli/test_eval.py index c5a00bf48..0ee9aba65 100644 --- a/tests/test_cli/test_eval.py +++ b/tests/test_cli/test_eval.py @@ -55,7 +55,7 @@ def test_eval_dataset_none_raises( (presumably because `vak prep` was not run yet) """ keys_to_change = [ - {"table": "eval", "key": "dataset", "value": "DELETE-OPTION"}, + {"table": "eval", "key": "dataset", "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( diff --git a/tests/test_cli/test_learncurve.py b/tests/test_cli/test_learncurve.py index 8c829015c..7fd0a3a8b 100644 --- a/tests/test_cli/test_learncurve.py +++ b/tests/test_cli/test_learncurve.py @@ -44,11 +44,11 @@ def test_learncurve(specific_config_toml_path, tmp_path, device): assert cli_asserts.log_file_contains_version(command="learncurve", output_path=results_path) -def test_learning_curve_dataset_path_none_raises( +def test_learning_curve_dataset_none_raises( specific_config_toml_path, tmp_path, ): """Test that cli.learncurve.learning_curve - raises ValueError when dataset_path is None + raises ValueError when dataset is None (presumably because `vak prep` was not run yet) """ root_results_dir = tmp_path.joinpath("test_learncurve_root_results_dir") @@ -62,8 +62,8 @@ def test_learning_curve_dataset_path_none_raises( }, { "table": "learncurve", - "key": "dataset_path", - "value": "DELETE-OPTION"}, + "key": "dataset", + "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -75,5 +75,5 @@ def test_learning_curve_dataset_path_none_raises( keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.learncurve.learning_curve(toml_path) diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py index ceff6edcb..30c78d3c5 100644 --- a/tests/test_cli/test_predict.py +++ b/tests/test_cli/test_predict.py @@ -45,14 +45,14 @@ def test_predict( assert cli_asserts.log_file_contains_version(command="predict", output_path=output_dir) -def 
test_predict_dataset_path_none_raises( - specific_config_toml_path, tmp_path, +def test_predict_dataset_none_raises( + specific_config_toml_path ): """Test that cli.predict raises ValueError when dataset_path is None (presumably because `vak prep` was not run yet) """ keys_to_change = [ - {"table": "predict", "key": "dataset_path", "value": "DELETE-OPTION"}, + {"table": "predict", "key": "dataset", "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -64,5 +64,5 @@ def test_predict_dataset_path_none_raises( keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.predict.predict(toml_path) diff --git a/tests/test_cli/test_prep.py b/tests/test_cli/test_prep.py index e5be2c723..88a7c0a35 100644 --- a/tests/test_cli/test_prep.py +++ b/tests/test_cli/test_prep.py @@ -68,9 +68,9 @@ def test_prep( {"table": "prep", "key": "output_dir", "value": str(output_dir)}, # need to remove dataset_path option from configs we already ran prep on to avoid error { - "table": config_type.upper(), - "key": "dataset_path", - "value": None, + "table": config_type, + "key": "dataset", + "value": "DELETE-KEY", }, ] toml_path = specific_config_toml_path( @@ -98,16 +98,16 @@ def test_prep( ("train", None, "mat", "yarden"), ], ) -def test_prep_dataset_path_raises( +def test_prep_dataset_raises( config_type, audio_format, spect_format, annot_format, - specific_config_toml_path, + specific_config_toml_path, default_model, tmp_path, - ): + """Test that prep raises a ValueError when the config already has a dataset with a path""" output_dir = tmp_path.joinpath( f"test_prep_{config_type}_{audio_format}_{spect_format}_{annot_format}" ) diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py index 8c6ec9781..a23acab3c 100644 --- a/tests/test_cli/test_train.py +++ b/tests/test_cli/test_train.py @@ -65,7 +65,7 @@ def test_train_dataset_path_none_raises( keys_to_change = [ {"table": "train", "key": "root_results_dir", "value": str(root_results_dir)}, - {"table": "train", "key": "dataset_path", "value": "DELETE-OPTION"}, + {"table": "train", "key": "dataset", "value": "DELETE-KEY"}, ] toml_path = specific_config_toml_path( @@ -77,5 +77,5 @@ def test_train_dataset_path_none_raises( keys_to_change=keys_to_change, ) - with pytest.raises(ValueError): + with pytest.raises(KeyError): vak.cli.train.train(toml_path) diff --git a/tests/test_prep/test_frame_classification/test_frame_classification.py b/tests/test_prep/test_frame_classification/test_frame_classification.py index e4513402f..851857c94 100644 --- a/tests/test_prep/test_frame_classification/test_frame_classification.py +++ b/tests/test_prep/test_frame_classification/test_frame_classification.py @@ -156,7 +156,7 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ }, {"table": "prep", "key": "labelset", - "value": "DELETE-OPTION", + "value": "DELETE-KEY", }, ] toml_path = specific_config_toml_path( From 2ed3763db3253cd5785fa413b673e71b09e21e69 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 13:35:11 -0400 Subject: [PATCH 116/183] Fix how we get 'path' from 'dataset' table in configs, in tests/fixtures/csv.py --- tests/fixtures/csv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/fixtures/csv.py b/tests/fixtures/csv.py index e4ef71761..0e21cc196 100644 --- a/tests/fixtures/csv.py +++ b/tests/fixtures/csv.py @@ -25,11 +25,11 @@ def _specific_csv_path( config_toml = specific_config_toml( config_type, model, 
annot_format, audio_format, spect_format ) - dataset_path = Path(config_toml[config_type.upper()]["dataset_path"]) + dataset_path = Path(config_toml[config_type]["dataset"]["path"]) # TODO: make this more general -- dataset registry? - if config_toml['PREP']['dataset_type'] == 'frame classification': + if config_toml['prep']['dataset_type'] == 'frame classification': metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) - elif config_toml['PREP']['dataset_type'] == 'parametric umap': + elif config_toml['prep']['dataset_type'] == 'parametric umap': metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename return dataset_csv_path From fca4ba9aa9dbe5ac9960ae37f1fdb7b919319cc2 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 13:35:18 -0400 Subject: [PATCH 117/183] Fix how we get 'path' from 'dataset' table in configs, in tests/fixtures/dataset.py --- tests/fixtures/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py index 51d1344b0..cb4382269 100644 --- a/tests/fixtures/dataset.py +++ b/tests/fixtures/dataset.py @@ -22,7 +22,7 @@ def _specific_dataset_path( config_toml = specific_config_toml( config_type, model, annot_format, audio_format, spect_format ) - dataset_path = Path(config_toml[config_type.upper()]["dataset_path"]) + dataset_path = Path(config_toml[config_type]["dataset"]["path"]) return dataset_path return _specific_dataset_path From d71d77385c99545db15349693bb3ee4d9e0887ac Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 13:40:06 -0400 Subject: [PATCH 118/183] Change .dataset_path -> .dataset.path in tests/ --- .../test_frame_classification/test_frames_dataset.py | 2 +- .../test_frame_classification/test_window_dataset.py | 2 +- .../test_parametric_umap/test_parametric_umap.py | 2 +- tests/test_eval/test_eval.py | 2 +- tests/test_eval/test_frame_classification.py | 6 +++--- tests/test_eval/test_parametric_umap.py | 6 +++--- tests/test_learncurve/test_frame_classification.py | 4 ++-- tests/test_models/test_base.py | 2 +- tests/test_predict/test_frame_classification.py | 10 +++++----- tests/test_predict/test_predict.py | 2 +- .../test_frame_classification/test_learncurve.py | 4 ++-- tests/test_train/test_frame_classification.py | 8 ++++---- tests/test_train/test_parametric_umap.py | 6 +++--- tests/test_train/test_train.py | 2 +- 14 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/test_datasets/test_frame_classification/test_frames_dataset.py b/tests/test_datasets/test_frame_classification/test_frames_dataset.py index 52a75495f..c165669b7 100644 --- a/tests/test_datasets/test_frame_classification/test_frames_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_frames_dataset.py @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.frame_classification.FramesDataset.from_dataset_path( - dataset_path=cfg_command.dataset_path, + dataset_path=cfg_command.dataset.path, split=split, item_transform=item_transform, ) diff --git a/tests/test_datasets/test_frame_classification/test_window_dataset.py b/tests/test_datasets/test_frame_classification/test_window_dataset.py index 67c6fde50..238f7de12 100644 --- a/tests/test_datasets/test_frame_classification/test_window_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_window_dataset.py @@ -28,7 
+28,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=cfg_command.dataset_path, + dataset_path=cfg_command.dataset.path, split=split, window_size=cfg_command.train_dataset_params['window_size'], transform=transform, diff --git a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py index e2829991d..7f2c0bb38 100644 --- a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py +++ b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.parametric_umap.ParametricUMAPDataset.from_dataset_path( - dataset_path=cfg_command.dataset_path, + dataset_path=cfg_command.dataset.path, split=split, transform=transform, ) diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index 7158cd393..2974af371 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -52,7 +52,7 @@ def test_eval( vak.eval.eval( model_name=model_name, model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index 4528fd909..780004e33 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -73,7 +73,7 @@ def test_eval_frame_classification_model( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model, model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -131,7 +131,7 @@ def test_eval_frame_classification_model_raises_file_not_found( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model, model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -189,7 +189,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model, model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index 824cbf2d8..069ea7e1e 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -51,7 +51,7 @@ def test_eval_parametric_umap_model( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model, model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, @@ -103,7 +103,7 @@ def test_eval_frame_classification_model_raises_file_not_found( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model, 
model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, @@ -159,7 +159,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model, model_config=model_config, - dataset_path=cfg.eval.dataset_path, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index a985ba422..5d7941c74 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -70,7 +70,7 @@ def test_learning_curve_for_frame_classification_model( vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_name=cfg.learncurve.model, model_config=model_config, - dataset_path=cfg.learncurve.dataset_path, + dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, @@ -125,7 +125,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_name=cfg.learncurve.model, model_config=model_config, - dataset_path=cfg.learncurve.dataset_path, + dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index d916e049a..c49031172 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -201,7 +201,7 @@ def test_load_state_dict_from_path(self, transform_kwargs={}, ) train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=train_cfg.train.dataset_path, + dataset_path=train_cfg.train.dataset.path, split="train", window_size=train_cfg.train.train_dataset_params['window_size'], transform=transform, diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index e3d52a38b..c00e53743 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -56,7 +56,7 @@ def test_predict_with_frame_classification_model( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model, model_config=model_config, - dataset_path=cfg.predict.dataset_path, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, @@ -78,8 +78,8 @@ def test_predict_with_frame_classification_model( Path(output_dir).glob(f"*{vak.common.constants.NET_OUTPUT_SUFFIX}") ) - metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict.dataset_path) - dataset_csv_path = cfg.predict.dataset_path / metadata.dataset_csv_filename + metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict.dataset.path) + dataset_csv_path = cfg.predict.dataset.path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) for spect_path in dataset_df.spect_path.values: @@ -132,7 +132,7 @@ def 
test_predict_with_frame_classification_model_raises_file_not_found( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model, model_config=model_config, - dataset_path=cfg.predict.dataset_path, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, @@ -195,7 +195,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model, model_config=model_config, - dataset_path=cfg.predict.dataset_path, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 6b0768380..144e5e852 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -48,7 +48,7 @@ def test_predict( vak.predict.predict( model_name=model_name, model_config=model_config, - dataset_path=cfg.predict.dataset_path, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index 3b93e42da..ef3105144 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -38,7 +38,7 @@ def test_make_index_vectors_for_each_subsets( ) cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve.dataset_path + dataset_path = cfg.learncurve.dataset.path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) @@ -150,7 +150,7 @@ def test_make_subsets_from_dataset_df( ) cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve.dataset_path + dataset_path = cfg.learncurve.dataset.path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 8eafd4d31..d02271d0b 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -64,7 +64,7 @@ def test_train_frame_classification_model( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -117,7 +117,7 @@ def test_continue_training( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -174,7 +174,7 @@ def test_train_raises_file_not_found( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model, 
model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -230,7 +230,7 @@ def test_train_raises_not_a_directory( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index 1fee8f303..d8e46e4fc 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -57,7 +57,7 @@ def test_train_parametric_umap_model( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -110,7 +110,7 @@ def test_train_parametric_umap_model_raises_file_not_found( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -163,7 +163,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index 71ec3b179..502dc63d2 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -56,7 +56,7 @@ def test_train( vak.train.train( model_name=model_name, model_config=model_config, - dataset_path=cfg.train.dataset_path, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, From 4a5122d88e9e3d9746a67f76e48f342521270a2b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 14:01:59 -0400 Subject: [PATCH 119/183] Fix how we get model config and rename config attribute .dataset_path -> .dataset.path throughout tests --- .../test_frames_dataset.py | 2 +- .../test_window_dataset.py | 2 +- .../test_parametric_umap.py | 2 +- tests/test_eval/test_eval.py | 7 ++--- tests/test_eval/test_frame_classification.py | 21 ++++++-------- tests/test_eval/test_parametric_umap.py | 21 ++++++-------- .../test_frame_classification.py | 14 ++++------ tests/test_models/test_base.py | 5 ++-- .../test_predict/test_frame_classification.py | 27 ++++++++---------- tests/test_predict/test_predict.py | 7 ++--- .../test_learncurve.py | 4 +-- tests/test_train/test_frame_classification.py | 28 ++++++++----------- tests/test_train/test_parametric_umap.py | 16 +++++------ tests/test_train/test_train.py | 7 ++--- 14 files changed, 71 insertions(+), 92 deletions(-) diff --git a/tests/test_datasets/test_frame_classification/test_frames_dataset.py b/tests/test_datasets/test_frame_classification/test_frames_dataset.py index c165669b7..386f52836 100644 --- 
a/tests/test_datasets/test_frame_classification/test_frames_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_frames_dataset.py @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.frame_classification.FramesDataset.from_dataset_path( - dataset_path=cfg_command.dataset.path, + dataset_path=cfg_command, split=split, item_transform=item_transform, ) diff --git a/tests/test_datasets/test_frame_classification/test_window_dataset.py b/tests/test_datasets/test_frame_classification/test_window_dataset.py index 238f7de12..94b920a14 100644 --- a/tests/test_datasets/test_frame_classification/test_window_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_window_dataset.py @@ -28,7 +28,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=cfg_command.dataset.path, + dataset_path=cfg_command, split=split, window_size=cfg_command.train_dataset_params['window_size'], transform=transform, diff --git a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py index 7f2c0bb38..533f0366d 100644 --- a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py +++ b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.parametric_umap.ParametricUMAPDataset.from_dataset_path( - dataset_path=cfg_command.dataset.path, + dataset_path=cfg_command, split=split, transform=transform, ) diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index 2974af371..ed1efec3e 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -43,16 +43,15 @@ def test_eval( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) results_path = tmp_path / 'results_path' results_path.mkdir() with mock.patch(eval_function_to_mock, autospec=True) as mock_eval_function: vak.eval.eval( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index 780004e33..49d253f00 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -68,12 +68,11 @@ def test_eval_frame_classification_model( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -126,12 +125,11 @@ def test_eval_frame_classification_model_raises_file_not_found( keys_to_change=keys_to_change, ) cfg = 
vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(FileNotFoundError): vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -184,12 +182,11 @@ def test_eval_frame_classification_model_raises_not_a_directory( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(NotADirectoryError): vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index 069ea7e1e..e9f8f85c3 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -46,12 +46,11 @@ def test_eval_parametric_umap_model( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, @@ -98,12 +97,11 @@ def test_eval_frame_classification_model_raises_file_not_found( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(FileNotFoundError): vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, @@ -154,12 +152,11 @@ def test_eval_frame_classification_model_raises_not_a_directory( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.eval.model) with pytest.raises(NotADirectoryError): vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model, - model_config=model_config, - dataset_path=cfg.eval.dataset.path, + model_name=cfg.eval.model.name, + model_config=cfg.eval.model.to_dict(), + dataset_path=cfg.eval, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index 5d7941c74..e53ed914e 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ 
b/tests/test_learncurve/test_frame_classification.py @@ -63,14 +63,13 @@ def test_learning_curve_for_frame_classification_model( ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( - model_name=cfg.learncurve.model, - model_config=model_config, - dataset_path=cfg.learncurve.dataset.path, + model_name=cfg.learncurve.model.name, + model_config=cfg.learncurve.model.to_dict(), + dataset_path=cfg.learncurve, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, @@ -117,15 +116,14 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.learncurve.model) # mock behavior of cli.learncurve, building `results_path` from config option `root_results_dir` results_path = cfg.learncurve.root_results_dir / 'results-dir-timestamp' with pytest.raises(NotADirectoryError): vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( - model_name=cfg.learncurve.model, - model_config=model_config, - dataset_path=cfg.learncurve.dataset.path, + model_name=cfg.learncurve.model.name, + model_config=cfg.learncurve.model.to_dict(), + dataset_path=cfg.learncurve, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index c49031172..44b01214c 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -201,7 +201,7 @@ def test_load_state_dict_from_path(self, transform_kwargs={}, ) train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=train_cfg.train.dataset.path, + dataset_path=train_cfg.train, split="train", window_size=train_cfg.train.train_dataset_params['window_size'], transform=transform, @@ -216,7 +216,8 @@ def test_load_state_dict_from_path(self, ) # network is the one thing that has required args # and we also need to use its config from the toml file - model_config = vak.config.model.config_from_toml_path(train_toml_path, model_name) + cfg = vak.config.Config.from_toml_path(train_toml_path) + model_config = cfg.train.model.to_dict() network = definition.network(num_classes=len(labelmap), num_input_channels=num_input_channels, num_freqbins=num_freqbins, diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index c00e53743..e7feb7cef 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -51,12 +51,10 @@ def test_predict_with_frame_classification_model( ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) - vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model, - model_config=model_config, - dataset_path=cfg.predict.dataset.path, + model_name=cfg.predict.model.name, + model_config=cfg.predict.model.to_dict(), + dataset_path=cfg.predict, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, 
num_workers=cfg.predict.num_workers, @@ -78,8 +76,8 @@ def test_predict_with_frame_classification_model( Path(output_dir).glob(f"*{vak.common.constants.NET_OUTPUT_SUFFIX}") ) - metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict.dataset.path) - dataset_csv_path = cfg.predict.dataset.path / metadata.dataset_csv_filename + metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict) + dataset_csv_path = cfg.predict / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) for spect_path in dataset_df.spect_path.values: @@ -126,13 +124,11 @@ def test_predict_with_frame_classification_model_raises_file_not_found( ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) - with pytest.raises(FileNotFoundError): vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model, - model_config=model_config, - dataset_path=cfg.predict.dataset.path, + model_name=cfg.predict.model.name, + model_config=cfg.predict.model.to_dict(), + dataset_path=cfg.predict, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, @@ -189,13 +185,12 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) with pytest.raises(NotADirectoryError): vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model, - model_config=model_config, - dataset_path=cfg.predict.dataset.path, + model_name=cfg.predict.model.name, + model_config=cfg.predict.model.to_dict(), + dataset_path=cfg.predict, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 144e5e852..ad697e68d 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -39,16 +39,15 @@ def test_predict( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.predict.model) results_path = tmp_path / 'results_path' results_path.mkdir() with mock.patch(predict_function_to_mock, autospec=True) as mock_predict_function: vak.predict.predict( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.predict.dataset.path, + model_name=cfg.predict.model.name, + model_config=cfg.predict.model.to_dict(), + dataset_path=cfg.predict, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index ef3105144..7a67956da 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -38,7 +38,7 @@ def test_make_index_vectors_for_each_subsets( ) cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve.dataset.path + dataset_path = cfg.learncurve metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = 
pd.read_csv(dataset_csv_path) @@ -150,7 +150,7 @@ def test_make_subsets_from_dataset_df( ) cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve.dataset.path + dataset_path = cfg.learncurve metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index d02271d0b..1d762d2fe 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -59,12 +59,11 @@ def test_train_frame_classification_model( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -112,12 +111,11 @@ def test_continue_training( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -166,15 +164,14 @@ def test_train_raises_file_not_found( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() with pytest.raises(FileNotFoundError): vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -221,16 +218,15 @@ def test_train_raises_not_a_directory( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) # mock behavior of cli.train, building `results_path` from config option `root_results_dir` results_path = cfg.train.root_results_dir / 'results-dir-timestamp' with pytest.raises(NotADirectoryError): vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index d8e46e4fc..f45344203 100644 --- a/tests/test_train/test_parametric_umap.py +++ 
b/tests/test_train/test_parametric_umap.py @@ -52,12 +52,11 @@ def test_train_parametric_umap_model( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) vak.train.parametric_umap.train_parametric_umap_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -102,15 +101,14 @@ def test_train_parametric_umap_model_raises_file_not_found( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = vak.common.paths.generate_results_dir_name_as_path(tmp_path) results_path.mkdir() with pytest.raises(FileNotFoundError): vak.train.parametric_umap.train_parametric_umap_model( - model_name=cfg.train.model, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -163,7 +161,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset.path, + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index 502dc63d2..aef6f547a 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -47,16 +47,15 @@ def test_train( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) results_path = tmp_path / 'results_path' results_path.mkdir() with mock.patch(train_function_to_mock, autospec=True) as mock_train_function: vak.train.train( - model_name=model_name, - model_config=model_config, - dataset_path=cfg.train.dataset.path, + model_name=cfg.train.model.name, + model_config=cfg.train.model.to_dict(), + dataset_path=cfg.train, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, From b8b902ce9e44ecae977cc0090e29efaf27d0e2b3 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 14:13:26 -0400 Subject: [PATCH 120/183] In tests/, fixup change .dataset_path -> .dataset.path, use model.name where we used to use just 'model' attribute of config --- .../test_frame_classification/test_frames_dataset.py | 2 +- .../test_frame_classification/test_window_dataset.py | 2 +- .../test_parametric_umap/test_parametric_umap.py | 2 +- tests/test_eval/test_eval.py | 2 +- tests/test_eval/test_frame_classification.py | 9 +++++---- tests/test_eval/test_parametric_umap.py | 8 ++++---- tests/test_learncurve/test_frame_classification.py | 4 ++-- tests/test_models/test_base.py | 2 +- tests/test_predict/test_frame_classification.py | 6 +++--- tests/test_predict/test_predict.py | 2 +- tests/test_train/test_frame_classification.py | 12 ++++++------ tests/test_train/test_parametric_umap.py | 8 ++++---- tests/test_train/test_train.py | 2 +- 13 files changed, 31 
insertions(+), 30 deletions(-) diff --git a/tests/test_datasets/test_frame_classification/test_frames_dataset.py b/tests/test_datasets/test_frame_classification/test_frames_dataset.py index 386f52836..c165669b7 100644 --- a/tests/test_datasets/test_frame_classification/test_frames_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_frames_dataset.py @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.frame_classification.FramesDataset.from_dataset_path( - dataset_path=cfg_command, + dataset_path=cfg_command.dataset.path, split=split, item_transform=item_transform, ) diff --git a/tests/test_datasets/test_frame_classification/test_window_dataset.py b/tests/test_datasets/test_frame_classification/test_window_dataset.py index 94b920a14..238f7de12 100644 --- a/tests/test_datasets/test_frame_classification/test_window_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_window_dataset.py @@ -28,7 +28,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=cfg_command, + dataset_path=cfg_command.dataset.path, split=split, window_size=cfg_command.train_dataset_params['window_size'], transform=transform, diff --git a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py index 533f0366d..7f2c0bb38 100644 --- a/tests/test_datasets/test_parametric_umap/test_parametric_umap.py +++ b/tests/test_datasets/test_parametric_umap/test_parametric_umap.py @@ -27,7 +27,7 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo ) dataset = vak.datasets.parametric_umap.ParametricUMAPDataset.from_dataset_path( - dataset_path=cfg_command, + dataset_path=cfg_command.dataset.path, split=split, transform=transform, ) diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index ed1efec3e..e2beb1b7e 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -51,7 +51,7 @@ def test_eval( vak.eval.eval( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index 49d253f00..cc28a415e 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -10,6 +10,7 @@ # written as separate function so we can re-use in tests/unit/test_cli/test_eval.py def assert_eval_output_matches_expected(model_name, output_dir): eval_csv = sorted(output_dir.glob(f"eval_{model_name}*csv")) + breakpoint() assert len(eval_csv) == 1 @@ -72,7 +73,7 @@ def test_eval_frame_classification_model( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -84,7 +85,7 @@ def test_eval_frame_classification_model( post_tfm_kwargs=post_tfm_kwargs, ) - assert_eval_output_matches_expected(cfg.eval.model, output_dir) + assert_eval_output_matches_expected(cfg.eval.model.name, output_dir) @pytest.mark.parametrize( 
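A minimal sketch, for reference, of the convention these hunks converge on; the attributes and `to_dict` are as used in the diffs above, while the config path shown and the exact contents of the returned dict are assumptions:

import vak

# any eval config from the test data; from_toml_path may validate paths inside it
toml_path = "tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml"
cfg = vak.config.Config.from_toml_path(toml_path)

model_name = cfg.eval.model.name          # e.g. "TweetyNet"
model_config = cfg.eval.model.to_dict()   # assumed: a plain dict built from the model's sub-tables
dataset_path = cfg.eval.dataset.path      # replaces the old flat cfg.eval.dataset_path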
@@ -129,7 +130,7 @@ def test_eval_frame_classification_model_raises_file_not_found( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, @@ -186,7 +187,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index e9f8f85c3..f450d4ce1 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -50,7 +50,7 @@ def test_eval_parametric_umap_model( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, @@ -60,7 +60,7 @@ def test_eval_parametric_umap_model( device=cfg.eval.device, ) - assert_eval_output_matches_expected(cfg.eval.model, output_dir) + assert_eval_output_matches_expected(cfg.eval.model.name, output_dir) @pytest.mark.parametrize( @@ -101,7 +101,7 @@ def test_eval_frame_classification_model_raises_file_not_found( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, @@ -156,7 +156,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model.name, model_config=cfg.eval.model.to_dict(), - dataset_path=cfg.eval, + dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index e53ed914e..4bdff2593 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -69,7 +69,7 @@ def test_learning_curve_for_frame_classification_model( vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_name=cfg.learncurve.model.name, model_config=cfg.learncurve.model.to_dict(), - dataset_path=cfg.learncurve, + dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, @@ -123,7 +123,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_name=cfg.learncurve.model.name, model_config=cfg.learncurve.model.to_dict(), - dataset_path=cfg.learncurve, + dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, diff --git a/tests/test_models/test_base.py 
b/tests/test_models/test_base.py index 44b01214c..f1c389983 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -201,7 +201,7 @@ def test_load_state_dict_from_path(self, transform_kwargs={}, ) train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=train_cfg.train, + dataset_path=train_cfg.train.dataset.path, split="train", window_size=train_cfg.train.train_dataset_params['window_size'], transform=transform, diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index e7feb7cef..554b6c92a 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -54,7 +54,7 @@ def test_predict_with_frame_classification_model( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model.name, model_config=cfg.predict.model.to_dict(), - dataset_path=cfg.predict, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, @@ -128,7 +128,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model.name, model_config=cfg.predict.model.to_dict(), - dataset_path=cfg.predict, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, @@ -190,7 +190,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model.name, model_config=cfg.predict.model.to_dict(), - dataset_path=cfg.predict, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index ad697e68d..03e484516 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -47,7 +47,7 @@ def test_predict( vak.predict.predict( model_name=cfg.predict.model.name, model_config=cfg.predict.model.to_dict(), - dataset_path=cfg.predict, + dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 1d762d2fe..23de10f19 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -63,7 +63,7 @@ def test_train_frame_classification_model( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -82,7 +82,7 @@ def test_train_frame_classification_model( device=cfg.train.device, ) - assert_train_output_matches_expected(cfg, cfg.train.model, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path) @pytest.mark.slow @@ -115,7 +115,7 @@ def test_continue_training( vak.train.frame_classification.train_frame_classification_model( 
model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -134,7 +134,7 @@ def test_continue_training( device=cfg.train.device, ) - assert_train_output_matches_expected(cfg, cfg.train.model, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path) @pytest.mark.parametrize( @@ -171,7 +171,7 @@ def test_train_raises_file_not_found( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -226,7 +226,7 @@ def test_train_raises_not_a_directory( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index f45344203..8dd9b2efe 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -56,7 +56,7 @@ def test_train_parametric_umap_model( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -72,7 +72,7 @@ def test_train_parametric_umap_model( device=cfg.train.device, ) - assert_train_output_matches_expected(cfg, cfg.train.model, results_path) + assert_train_output_matches_expected(cfg, cfg.train.model.name, results_path) @pytest.mark.parametrize( @@ -108,7 +108,7 @@ def test_train_parametric_umap_model_raises_file_not_found( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, @@ -161,7 +161,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index aef6f547a..b511d0e13 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -55,7 +55,7 @@ def test_train( vak.train.train( model_name=cfg.train.model.name, model_config=cfg.train.model.to_dict(), - dataset_path=cfg.train, + dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, From 631e4039bba5b34fc3719d38830dfe35cafde98a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 14:44:36 -0400 Subject: [PATCH 121/183] Fix fixture specific_config_toml_path in fixtures/config.py to handle case where we need to access sub-table and change a key in 
it--right now this is just ['dataset']['path']
---
 tests/fixtures/config.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/tests/fixtures/config.py b/tests/fixtures/config.py
index 8205d7c44..8fd9b8dd9 100644
--- a/tests/fixtures/config.py
+++ b/tests/fixtures/config.py
@@ -175,12 +175,30 @@ def _specific_config(
     with config_copy_path.open("r") as fp:
         tomldoc = tomlkit.load(fp)
-    for opt_dict in keys_to_change:
-        if opt_dict["value"] == 'DELETE-KEY':
-            # e.g., to test behavior of config without this key
-            del tomldoc["vak"][opt_dict["table"]][opt_dict["key"]]
+    for table_key_val_dict in keys_to_change:
+        table_name = table_key_val_dict["table"]
+        key = table_key_val_dict["key"]
+        value = table_key_val_dict["value"]
+        if isinstance(key, str):
+            if table_key_val_dict["value"] == 'DELETE-KEY':
+                # e.g., to test behavior of config without this key
+                del tomldoc["vak"][table_name][key]
+            else:
+                tomldoc["vak"][table_name][key] = value
+        elif isinstance(key, list) and len(key) == 2 and all([isinstance(el, str) for el in key]):
+            # for the case where we need to access a sub-table
+            # right now this applies mainly to ["vak"][table]["dataset"]["path"]
+            # if we end up having to access more / deeper then we'll need something more general
+            if table_key_val_dict["value"] == 'DELETE-KEY':
+                # e.g., to test behavior of config without this key
+                del tomldoc["vak"][table_name][key[0]][key[1]]
+            else:
+                tomldoc["vak"][table_name][key[0]][key[1]] = value
         else:
-            tomldoc["vak"][opt_dict["table"]][opt_dict["key"]] = opt_dict["value"]
+            raise ValueError(
+                f"Unexpected value for 'key' in `keys_to_change` dict: {key}.\n"
+                f"`keys_to_change` dict: {table_key_val_dict}"
+            )
     with config_copy_path.open("w") as fp:
         tomlkit.dump(tomldoc, fp)

From f717e0042b7d8b84bc696351c79f1e6323e51a4b Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 14:44:57 -0400
Subject: [PATCH 122/183] Fix how we change ['dataset']['path'] value in tests/test_eval/test_frame_classification.py
---
 tests/test_eval/test_frame_classification.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py
index cc28a415e..b573fe85c 100644
--- a/tests/test_eval/test_frame_classification.py
+++ b/tests/test_eval/test_frame_classification.py
@@ -10,7 +10,6 @@
 # written as separate function so we can re-use in tests/unit/test_cli/test_eval.py
 def assert_eval_output_matches_expected(model_name, output_dir):
     eval_csv = sorted(output_dir.glob(f"eval_{model_name}*csv"))
-    breakpoint()
     assert len(eval_csv) == 1
@@ -145,7 +144,7 @@ def test_eval_frame_classification_model_raises_file_not_found(
 @pytest.mark.parametrize(
     'path_option_to_change',
     [
-        {"table": "eval", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'},
+        {"table": "eval", "key": ["dataset","path"], "value": '/obviously/doesnt/exist/dataset-dir'},
         {"table": "eval", "key": "output_dir", "value": '/obviously/does/not/exist/output'},
     ]
 )

From ac1a7e77db500947c56b5587f69975a3599bb5ba Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 14:45:42 -0400
Subject: [PATCH 123/183] Fix how we change ['dataset']['path'] value in config in several tests
---
 tests/test_eval/test_parametric_umap.py | 2 +-
 tests/test_learncurve/test_frame_classification.py | 2 +-
 tests/test_predict/test_frame_classification.py | 2 +-
 tests/test_train/test_frame_classification.py | 2 +-
 tests/test_train/test_parametric_umap.py
| 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index f450d4ce1..69fa2cf44 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -115,7 +115,7 @@ def test_eval_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"table": "eval", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "eval", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, {"table": "eval", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index 4bdff2593..588b06d59 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -94,7 +94,7 @@ def test_learning_curve_for_frame_classification_model( 'dir_option_to_change', [ {"table": "learncurve", "key": "root_results_dir", "value": '/obviously/does/not/exist/results/'}, - {"table": "learncurve", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "learncurve", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, ] ) def test_learncurve_raises_not_a_directory(dir_option_to_change, diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 554b6c92a..74282e796 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -148,7 +148,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"table": "predict", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "predict", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, {"table": "predict", "key": "output_dir", "value": '/obviously/does/not/exist/output'}, ] ) diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 23de10f19..6cdf8d824 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -194,7 +194,7 @@ def test_train_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"table": "train", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "train", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, {"table": "train", "key": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, ] ) diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index 8dd9b2efe..2a2d4bfd8 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -128,7 +128,7 @@ def test_train_parametric_umap_model_raises_file_not_found( @pytest.mark.parametrize( 'path_option_to_change', [ - {"table": "train", "key": "dataset_path", "value": '/obviously/doesnt/exist/dataset-dir'}, + {"table": "train", "key": ["dataset", "path"], "value": '/obviously/doesnt/exist/dataset-dir'}, {"table": "train", "key": "root_results_dir", "value": '/obviously/doesnt/exist/results/'}, ] ) From e4ac360fc453197009984437bfe3f8e3d1ae9fff Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 15:03:45 -0400 Subject: [PATCH 124/183] 
Use ModelConfig attribute name where needed in tests/test_learncurve/test_frame_classification.py --- tests/test_learncurve/test_frame_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index 588b06d59..9f5d87804 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -87,7 +87,7 @@ def test_learning_curve_for_frame_classification_model( device=cfg.learncurve.device, ) - assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model, results_path) + assert_learncurve_output_matches_expected(cfg, cfg.learncurve.model.name, results_path) @pytest.mark.parametrize( From f542c47771cf8b4681cb55a4fdb5c1bdb5b4fe57 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 15:04:06 -0400 Subject: [PATCH 125/183] In tests, replace calls to vak.config.model.config_from_toml_path with calls to ModelConfig method to_dict() --- tests/test_models/test_frame_classification_model.py | 2 +- tests/test_models/test_parametric_umap_model.py | 2 +- tests/test_train/test_parametric_umap.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_models/test_frame_classification_model.py b/tests/test_models/test_frame_classification_model.py index c516b7d54..c69095bae 100644 --- a/tests/test_models/test_frame_classification_model.py +++ b/tests/test_models/test_frame_classification_model.py @@ -95,7 +95,7 @@ def test_from_config(self, vak.models.FrameClassificationModel, 'definition', definition, raising=False ) - config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + config = cfg.train.model.to_dict() num_input_channels, num_freqbins = self.MOCK_INPUT_SHAPE[0], self.MOCK_INPUT_SHAPE[1] config["network"].update( diff --git a/tests/test_models/test_parametric_umap_model.py b/tests/test_models/test_parametric_umap_model.py index 201c3da7f..a933a88b0 100644 --- a/tests/test_models/test_parametric_umap_model.py +++ b/tests/test_models/test_parametric_umap_model.py @@ -93,7 +93,7 @@ def test_from_config( vak.models.ParametricUMAPModel, 'definition', definition, raising=False ) - config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + config = cfg.train.model.to_dict() config["network"].update( encoder=dict(input_shape=input_shape) ) diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index 2a2d4bfd8..dd3d80466 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -152,7 +152,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = vak.config.model.config_from_toml_path(toml_path, cfg.train.model) + model_config = cfg.train.model.to_dict() # mock behavior of cli.train, building `results_path` from config option `root_results_dir` results_path = cfg.train.root_results_dir / 'results-dir-timestamp' From 0ceadfe77dbf49da158d321ab082f5f6fba52323 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 15:16:04 -0400 Subject: [PATCH 126/183] Change cfg.spect_params -> cfg.prep.spect_params in tests --- tests/test_predict/test_frame_classification.py | 6 +++--- tests/test_predict/test_predict.py | 2 +- .../test_frame_classification.py | 12 ++++++------ .../test_get_or_make_source_files.py | 4 ++-- tests/test_prep/test_prep.py | 2 
+- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 74282e796..09c9267ec 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -60,7 +60,7 @@ def test_predict_with_frame_classification_model( num_workers=cfg.predict.num_workers, transform_params=cfg.predict.transform_params, dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, @@ -134,7 +134,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( num_workers=cfg.predict.num_workers, transform_params=cfg.predict.transform_params, dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, @@ -196,7 +196,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( num_workers=cfg.predict.num_workers, transform_params=cfg.predict.transform_params, dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 03e484516..c445dad49 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -53,7 +53,7 @@ def test_predict( num_workers=cfg.predict.num_workers, transform_params=cfg.predict.transform_params, dataset_params=cfg.predict.dataset_params, - timebins_key=cfg.spect_params.timebins_key, + timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, annot_csv_filename=cfg.predict.annot_csv_filename, diff --git a/tests/test_prep/test_frame_classification/test_frame_classification.py b/tests/test_prep/test_frame_classification/test_frame_classification.py index 851857c94..968f4d19b 100644 --- a/tests/test_prep/test_frame_classification/test_frame_classification.py +++ b/tests/test_prep/test_frame_classification/test_frame_classification.py @@ -104,7 +104,7 @@ def test_prep_frame_classification_dataset( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -177,7 +177,7 @@ def test_prep_frame_classification_dataset_raises_when_labelset_required_but_is_ purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -244,7 +244,7 @@ def test_prep_frame_classification_dataset_with_single_audio_and_annot(source_te purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, 
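# In brief, the move this patch applies (names as used in the surrounding
# diffs; the class behind spect_params is not shown in these hunks and is
# left unspecified):
#     cfg = vak.config.Config.from_toml_path(toml_path)
#     cfg.prep.spect_params.timebins_key   # was: cfg.spect_params.timebins_key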
labelset=cfg.prep.labelset, @@ -302,7 +302,7 @@ def test_prep_frame_classification_dataset_when_annot_has_single_segment(source_ purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -350,7 +350,7 @@ def test_prep_frame_classification_dataset_raises_not_a_directory( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, @@ -398,7 +398,7 @@ def test_prep_frame_classification_dataset_raises_file_not_found( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, diff --git a/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py b/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py index a58808d35..a49bbf5ba 100644 --- a/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py +++ b/tests/test_prep/test_frame_classification/test_get_or_make_source_files.py @@ -55,7 +55,7 @@ def test_get_or_make_source_files( cfg.prep.input_type, cfg.prep.audio_format, cfg.prep.spect_format, - cfg.spect_params, + cfg.prep.spect_params, tmp_dataset_path, cfg.prep.annot_format, cfg.prep.annot_file, @@ -77,7 +77,7 @@ def test_get_or_make_source_files( cfg.prep.input_type, cfg.prep.audio_format, cfg.prep.spect_format, - cfg.spect_params, + cfg.prep.spect_params, tmp_dataset_path, cfg.prep.annot_format, cfg.prep.annot_file, diff --git a/tests/test_prep/test_prep.py b/tests/test_prep/test_prep.py index ad673ed3d..370cb6647 100644 --- a/tests/test_prep/test_prep.py +++ b/tests/test_prep/test_prep.py @@ -61,7 +61,7 @@ def test_prep( purpose=purpose, audio_format=cfg.prep.audio_format, spect_format=cfg.prep.spect_format, - spect_params=cfg.spect_params, + spect_params=cfg.prep.spect_params, annot_format=cfg.prep.annot_format, annot_file=cfg.prep.annot_file, labelset=cfg.prep.labelset, From 315294a6acbbd22e100dfddb9c7ebcccc0cab880 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 15:25:56 -0400 Subject: [PATCH 127/183] Fix cfg.predict -> cfg.predict.dataset.path in tests/test_predict/test_frame_classification.py --- tests/test_predict/test_frame_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 09c9267ec..3facf373d 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -76,8 +76,8 @@ def test_predict_with_frame_classification_model( Path(output_dir).glob(f"*{vak.common.constants.NET_OUTPUT_SUFFIX}") ) - metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict) - dataset_csv_path = cfg.predict / metadata.dataset_csv_filename + metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(cfg.predict.dataset.path) + dataset_csv_path = cfg.predict.dataset.path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) for spect_path in dataset_df.spect_path.values: From 
7ad30b015659d1c14273756af47eaedcc4a56d62 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 15:33:11 -0400 Subject: [PATCH 128/183] Fix constant LABELSET_NOTMAT in fixtures/annot.py so it is a list of str, not a Tomlkit.String class --- tests/fixtures/annot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fixtures/annot.py b/tests/fixtures/annot.py index c69c31ca7..d7c63a4af 100644 --- a/tests/fixtures/annot.py +++ b/tests/fixtures/annot.py @@ -76,7 +76,7 @@ def annot_list_notmat(): # doesn't really matter which config, they all have labelset with a_train_notmat_config.open("r") as fp: a_train_notmat_toml = tomlkit.load(fp) -LABELSET_NOTMAT = a_train_notmat_toml["vak"]["prep"]["labelset"] +LABELSET_NOTMAT = list(str(a_train_notmat_toml["vak"]["prep"]["labelset"])) @pytest.fixture From 937192dd2325142a61c6e69e5b2ced45706583f1 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 19:50:19 -0400 Subject: [PATCH 129/183] Fix cfg.learncurve -> cfg.learncurve.dataset.path in tests/test_prep/test_frame/test_learncurve.py --- tests/test_prep/test_frame_classification/test_learncurve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index 7a67956da..ce787736b 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -38,7 +38,7 @@ def test_make_index_vectors_for_each_subsets( ) cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve + dataset_path = cfg.learncurve.dataset.path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) From ef9cd86aae7ea4ebad3709e602ab7272f6c00e09 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 20:43:39 -0400 Subject: [PATCH 130/183] Fix cfg.learncurve -> cfg.learncurve.dataset.path in tests/test_prep/test_frame/test_learncurve.py --- tests/test_prep/test_frame_classification/test_learncurve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_prep/test_frame_classification/test_learncurve.py b/tests/test_prep/test_frame_classification/test_learncurve.py index ce787736b..ef3105144 100644 --- a/tests/test_prep/test_frame_classification/test_learncurve.py +++ b/tests/test_prep/test_frame_classification/test_learncurve.py @@ -150,7 +150,7 @@ def test_make_subsets_from_dataset_df( ) cfg = vak.config.Config.from_toml_path(toml_path) - dataset_path = cfg.learncurve + dataset_path = cfg.learncurve.dataset.path metadata = vak.datasets.frame_classification.Metadata.from_dataset_path(dataset_path) dataset_csv_path = dataset_path / metadata.dataset_csv_filename dataset_df = pd.read_csv(dataset_csv_path) From ee4a9a9a4f6782438fab3a45ea64476a2a742dd7 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 20:43:57 -0400 Subject: [PATCH 131/183] Cast pathlib to str before adding to tomldoc, in tests/test_train/ --- tests/test_train/test_frame_classification.py | 4 ++-- tests/test_train/test_parametric_umap.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 6cdf8d824..4bc1587a3 100644 --- a/tests/test_train/test_frame_classification.py +++ 
b/tests/test_train/test_frame_classification.py @@ -48,7 +48,7 @@ def test_train_frame_classification_model( results_path.mkdir() keys_to_change = [ {"table": "train", "key": "device", "value": device}, - {"table": "train", "key": "root_results_dir", "value": results_path} + {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( config_type="train", @@ -100,7 +100,7 @@ def test_continue_training( results_path.mkdir() keys_to_change = [ {"table": "train", "key": "device", "value": device}, - {"table": "train", "key": "root_results_dir", "value": results_path} + {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( config_type="train_continue", diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index dd3d80466..58ff5fb82 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -41,7 +41,7 @@ def test_train_parametric_umap_model( results_path.mkdir() keys_to_change = [ {"table": "train", "key": "device", "value": device}, - {"table": "train", "key": "root_results_dir", "value": results_path} + {"table": "train", "key": "root_results_dir", "value": str(results_path)} ] toml_path = specific_config_toml_path( config_type="train", From 881f4f855744856ca38df298696c54b24722366d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:01:13 -0400 Subject: [PATCH 132/183] Change transform/dataset params keys in data_for_tests/configs to a dataset table with a params key --- .../configs/TweetyNet_eval_audio_cbin_annot_notmat.toml | 2 +- .../TweetyNet_learncurve_audio_cbin_annot_notmat.toml | 5 +---- .../configs/TweetyNet_predict_audio_cbin_annot_notmat.toml | 2 +- .../configs/TweetyNet_train_audio_cbin_annot_notmat.toml | 5 +---- .../TweetyNet_train_continue_audio_cbin_annot_notmat.toml | 5 +---- .../TweetyNet_train_continue_spect_mat_annot_yarden.toml | 5 +---- .../configs/TweetyNet_train_spect_mat_annot_yarden.toml | 5 +---- 7 files changed, 7 insertions(+), 22 deletions(-) diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml index d5975034b..a629f1832 100644 --- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml @@ -27,7 +27,7 @@ output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_not majority_vote = true min_segment_dur = 0.02 -[vak.eval.transform_params] +[vak.eval.dataset.params] window_size = 88 [vak.eval.model.TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml index 9d0c5b3b8..c169b3bda 100644 --- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml @@ -34,10 +34,7 @@ root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cb majority_vote = true min_segment_dur = 0.02 -[vak.learncurve.train_dataset_params] -window_size = 88 - -[vak.learncurve.val_transform_params] +[vak.learncurve.dataset.params] window_size = 88 [vak.learncurve.model.TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml 
b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml index 3c83d2826..1f2bf74b1 100644 --- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml @@ -22,7 +22,7 @@ device = "cuda" output_dir = "./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet" annot_csv_filename = "bl26lb16.041912.annot.csv" -[vak.predict.transform_params] +[vak.predict.dataset.params] window_size = 88 [vak.predict.model.TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml index aba76e566..3cddb00b6 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml @@ -28,10 +28,7 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet" -[vak.train.train_dataset_params] -window_size = 88 - -[vak.train.val_transform_params] +[vak.train.dataset.params] window_size = 88 [vak.train.model.TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml index 754a8ac4e..09e93b442 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml @@ -30,10 +30,7 @@ root_results_dir = "./tests/data_for_tests/generated/results/train_continue/audi checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" -[vak.train.train_dataset_params] -window_size = 88 - -[vak.train.val_transform_params] +[vak.train.dataset.params] window_size = 88 [vak.train.model.TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml index 59f94124b..adbedc2c6 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml @@ -29,10 +29,7 @@ device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train_continue/spect_mat_annot_yarden/TweetyNet" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" -[vak.train.train_dataset_params] -window_size = 88 - -[vak.train.val_transform_params] +[vak.train.dataset.params] window_size = 88 [vak.train.model.TweetyNet.network] diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml index 588f2dd51..81858e596 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml @@ -28,10 +28,7 @@ num_workers = 16 device = "cuda" root_results_dir = 
"./tests/data_for_tests/generated/results/train/spect_mat_annot_yarden/TweetyNet" -[vak.train.train_dataset_params] -window_size = 88 - -[vak.train.val_transform_params] +[vak.train.dataset.params] window_size = 88 [vak.train.model.TweetyNet.network] From eda99eb5ca16e9df2f2d6e2b6419ba43b62a55a5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:01:35 -0400 Subject: [PATCH 133/183] Add `params` attribute to DatasetConfig --- src/vak/config/dataset.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index 9ec2d2ff9..e0ed822ab 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -20,9 +20,15 @@ class DatasetConfig: datasets. splits_path : pathlib.Path, optional Path to file representing splits. + Default is None. name : str, optional Name of dataset. Only required for built-in datasets - from the :mod:`~vak.datasets` module. + from the :mod:`~vak.datasets` module. Default is None. + params: dict, optional + Parameters for dataset class, + passed in as keyword arguments. + E.g., ``window_size=2000``. + Default is None. """ path: pathlib.Path = field(converter=pathlib.Path) @@ -32,6 +38,9 @@ class DatasetConfig: name: str | None = field( converter=attr.converters.optional(str), default=None ) + params : dict | None = field( + converter=attr.converters.optional(dict), default=None + ) @classmethod def from_config_dict(cls, dict_: dict) -> DatasetConfig: @@ -39,4 +48,5 @@ def from_config_dict(cls, dict_: dict) -> DatasetConfig: path=dict_.get("path"), splits_path=dict_.get("splits_path"), name=dict_.get("name"), + params=dict_.get("params") ) From 513046fd9a3efad27676fe90215409a99ac44883 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:10:03 -0400 Subject: [PATCH 134/183] Change transform/dataset params keys in doc/toml/ to a dataset table with a params key --- doc/toml/gy6or6_eval.toml | 7 +++---- doc/toml/gy6or6_predict.toml | 7 +++---- doc/toml/gy6or6_train.toml | 12 +++--------- 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/doc/toml/gy6or6_eval.toml b/doc/toml/gy6or6_eval.toml index 80a147a0f..71355ed28 100644 --- a/doc/toml/gy6or6_eval.toml +++ b/doc/toml/gy6or6_eval.toml @@ -64,10 +64,9 @@ majority_vote = true # Only applied if this option is specified. min_segment_dur = 0.02 -# transform_params: parameters used when transforming data -# for a frame classification model, we use FrameDataset with the eval_item_transform, -# that reshapes batches into consecutive adjacent windows with a specific `window_size` -[vak.eval.transform_params] +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.eval.dataset.params] window_size = 176 # Note we do not specify any options for the model, and just use the defaults diff --git a/doc/toml/gy6or6_predict.toml b/doc/toml/gy6or6_predict.toml index 1144aac4b..c4c89ef73 100644 --- a/doc/toml/gy6or6_predict.toml +++ b/doc/toml/gy6or6_predict.toml @@ -59,10 +59,9 @@ majority_vote = true min_segment_dur = 0.01 # dataset_path : path to dataset created by prep. 
This will be added when you run `vak prep`, you don't have to add it -# transform_params: parameters used when transforming data -# for a frame classification model, we use FrameDataset with the eval_item_transform, -# that reshapes batches into consecutive adjacent windows with a specific `window_size` -[vak.predict.transform_params] +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.predict.dataset.params] window_size = 176 # Note we do not specify any options for the network, and just use the defaults diff --git a/doc/toml/gy6or6_train.toml b/doc/toml/gy6or6_train.toml index dde4a8926..68202f796 100644 --- a/doc/toml/gy6or6_train.toml +++ b/doc/toml/gy6or6_train.toml @@ -56,15 +56,9 @@ num_workers = 4 device = "cuda" # dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it -# train_dataset_params: parameters used when loading training dataset -# for a frame classification model, we use a WindowDataset with a specific `window_size` -[vak.train.train_dataset_params] -window_size = 176 - -# val_transform_params: parameters used when transforming validation data -# for a frame classification model, we use FrameDataset with the eval_item_transform, -# that reshapes batches into consecutive adjacent windows with a specific `window_size` -[vak.train.val_transform_params] +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.train.dataset.params] window_size = 176 # To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. From 5317dff041fa1d26cd71b2bb470fa2e73ce33f52 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:14:34 -0400 Subject: [PATCH 135/183] Rewrite vak/config/model.py method 'to_dict' as 'asdict', using attrs asdict function. We now return 'name' and will just get it from the dict instead of having a separate 'model_name' parameter for functions that take 'model_config' --- src/vak/config/model.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/vak/config/model.py b/src/vak/config/model.py index 3db4b5344..aaf920866 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -2,7 +2,7 @@ from __future__ import annotations -from attrs import define, field +from attrs import asdict, define, field from attrs.validators import instance_of from .. import models @@ -93,18 +93,10 @@ def from_config_dict(cls, config_dict: dict): model_config[model_table] = {} return cls(name=model_name, **model_config) - def to_dict(self): + def asdict(self): """Convert this :class:`ModelConfig` instance to a :class:`dict` that can be passed into functions that take a ``model_config`` argument, like :func:`vak.train` and :func:`vak.predict`. - - This function drops the ``name`` attribute, - and returns all other attributes in a :class:`dict`. 
""" - return { - "network": self.network, - "optimizer": self.optimizer, - "loss": self.loss, - "metrics": self.metrics, - } + return asdict(self) From a1f0e0a314f86ee8252788185a8d56bdf902bc7a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:15:01 -0400 Subject: [PATCH 136/183] Add asdict method to DatasetConfig class, like ModelConfig.asdict --- src/vak/config/dataset.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index e0ed822ab..d5e242b47 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -5,7 +5,7 @@ import pathlib import attr.validators -from attr import define, field +from attr import asdict, define, field @define @@ -50,3 +50,11 @@ def from_config_dict(cls, dict_: dict) -> DatasetConfig: name=dict_.get("name"), params=dict_.get("params") ) + + def asdict(self): + """Convert this :class:`DatasetConfig` instance + to a :class:`dict` that can be passed + into functions that take a ``dataset_config`` argument, + like :func:`vak.train` and :func:`vak.predict`. + """ + return asdict(self) From bd0d05052a85a0826167a3d3e16769216fc975c7 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:16:18 -0400 Subject: [PATCH 137/183] Fix calls to model.to_dict() -> model.asdict() --- src/vak/cli/eval.py | 2 +- src/vak/cli/learncurve.py | 2 +- src/vak/cli/predict.py | 2 +- src/vak/cli/train.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 6780d73eb..272c3cf3f 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -46,7 +46,7 @@ def eval(toml_path): eval_module.eval( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 09f76692e..810492b28 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -54,7 +54,7 @@ def learning_curve(toml_path): learncurve.learning_curve( model_name=cfg.learncurve.model.name, - model_config=cfg.learncurve.model.to_dict(), + model_config=cfg.learncurve.model.asdict(), dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index aecd96b7f..0625de613 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -44,7 +44,7 @@ def predict(toml_path): predict_module.predict( model_name=cfg.predict.model.name, - model_config=cfg.predict.model.to_dict(), + model_config=cfg.predict.model.asdict(), dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 3487958ba..542a93fbc 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -54,7 +54,7 @@ def train(toml_path): train_module.train( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, train_transform_params=cfg.train.train_transform_params, train_dataset_params=cfg.train.train_dataset_params, From ba788c60ec35ac3f1e9266f94c8dc78b96b20328 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 3 May 2024 22:20:41 -0400 Subject: [PATCH 138/183] Add unit tests for 
DatasetConfig.asdict

---
 tests/test_config/test_dataset.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/test_config/test_dataset.py b/tests/test_config/test_dataset.py
index 3ecff930e..ddb475781 100644
--- a/tests/test_config/test_dataset.py
+++ b/tests/test_config/test_dataset.py
@@ -75,3 +75,33 @@ def test_from_config_dict(self, config_dict):
         else:
             assert dataset_config.name is None
 
+    @pytest.mark.parametrize(
+        'config_dict',
+        [
+            {
+                'path': '~/datasets/BioSoundSegBench',
+                'splits_path': 'splits/Bengalese-Finch-song-gy6or6-replicate-1.json',
+                'name': 'BioSoundSegBench',
+            },
+            {
+                'path': '~/user/prepped/dataset',
+            },
+            {
+                'path': '~/user/prepped/dataset',
+                'splits_path': 'splits/replicate-1.json'
+            },
+        ]
+    )
+    def test_asdict(self, config_dict):
+        dataset_config = vak.config.dataset.DatasetConfig.from_config_dict(config_dict)
+
+        dataset_config_as_dict = dataset_config.asdict()
+
+        for key in ('name', 'path', 'splits_path', 'params'):
+            if key in config_dict:
+                if 'path' in key:
+                    assert dataset_config_as_dict[key] == pathlib.Path(config_dict[key])
+                else:
+                    assert dataset_config_as_dict[key] == config_dict[key]
+            else:
+                assert dataset_config_as_dict[key] is None

From 0738b7700c30b4189f6ec469b2f1dbb67002bbd3 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 22:26:58 -0400
Subject: [PATCH 139/183] Add unit tests for ModelConfig.asdict

---
 tests/test_config/test_model.py | 45 +++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/tests/test_config/test_model.py b/tests/test_config/test_model.py
index 73745df79..3ba45fd48 100644
--- a/tests/test_config/test_model.py
+++ b/tests/test_config/test_model.py
@@ -104,6 +104,51 @@ def test_from_config_dict_real_config(self, a_generated_config_dict):
             else:
                 assert getattr(model_config, attr) == {}
 
+    @pytest.mark.parametrize(
+        'config_dict',
+        [
+            {
+                'TweetyNet': {
+                    'optimizer': {'lr': 1e-3},
+                }
+            },
+            {
+                'TweetyNet': {
+                    'network': {},
+                    'optimizer': {'lr': 1e-3},
+                    'loss': {},
+                    'metrics': {},
+                }
+            },
+            {
+                'ED_TCN': {
+                    'optimizer': {'lr': 1e-3},
+                }
+            },
+            {
+                "ConvEncoderUMAP": {
+                    "optimizer": {"lr": 1e-3}
+                }
+            }
+        ]
+    )
+    def test_asdict(self, config_dict):
+        model_config = vak.config.model.ModelConfig.from_config_dict(config_dict)
+
+        model_config_as_dict = model_config.asdict()
+
+        assert isinstance(model_config_as_dict, dict)
+
+        model_name = list(config_dict.keys())[0]
+        for key in ('name', 'network', 'optimizer', 'loss', 'metrics'):
+            if key == 'name':
+                assert model_config_as_dict[key] == model_name
+            else:
+                if key in config_dict[model_name]:
+                    assert model_config_as_dict[key] == config_dict[model_name][key]
+                else:
+                    assert model_config_as_dict[key] == {}
+
     @pytest.mark.parametrize(
         'config_dict, expected_exception',
         [

From 3df6a5cf70019a02612bc2185d5f2c717fde80c9 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Fri, 3 May 2024 22:27:22 -0400
Subject: [PATCH 140/183] Add an assertion in tests/test_config/test_dataset.py

---
 tests/test_config/test_dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_config/test_dataset.py b/tests/test_config/test_dataset.py
index ddb475781..690882b14 100644
--- a/tests/test_config/test_dataset.py
+++ b/tests/test_config/test_dataset.py
@@ -97,6 +97,7 @@ def test_asdict(self, config_dict):
 
         dataset_config_as_dict = dataset_config.asdict()
 
+        assert isinstance(dataset_config_as_dict, dict)
         for key in ('name', 'path', 'splits_path', 'params'):
             if key in config_dict:
                 if 'path' in key:

From
9e2dab9e9ab16971cd00bc2c8495c7a6d56bd9e5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 07:58:05 -0400 Subject: [PATCH 141/183] Remove transform params and dataset_params from EvalConfig, will just use dataset attribute, a DatasetConfig, with its params attribute --- src/vak/config/eval.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index d52aca1f0..3012da649 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -126,14 +126,6 @@ class EvalConfig: a float value for ``min_segment_dur``. See the docstring of the transform for more details on these arguments and how they work. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. """ # required, external files @@ -171,18 +163,6 @@ class EvalConfig: num_workers = field(validator=instance_of(int), default=2) device = field(validator=instance_of(str), default=device.get_default()) - transform_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - dataset_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - @classmethod def from_config_dict(cls, config_dict: dict) -> EvalConfig: """Return :class:`EvalConfig` instance from a :class:`dict`. From 1b70d547f2c079075532615f8557783f694c2cb8 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 08:01:54 -0400 Subject: [PATCH 142/183] Remove dataset/transform_params key-value pairs in valid-version-1.0.toml, and add params key to dataset tables with in-line table params --- src/vak/config/valid-version-1.0.toml | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/vak/config/valid-version-1.0.toml b/src/vak/config/valid-version-1.0.toml index 7ebcb1556..ded6aa6ae 100644 --- a/src/vak/config/valid-version-1.0.toml +++ b/src/vak/config/valid-version-1.0.toml @@ -47,15 +47,12 @@ patience = 4 results_dir_made_by_main_script = '/some/path/to/learncurve/' checkpoint_path = '/home/user/results_181014_194418/TweetyNet/checkpoints/' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' -train_transform_params = {'resize' = 128} -train_dataset_params = {'window_size' = 80} -val_transform_params = {'resize' = 128} -val_dataset_params = {'window_size' = 80} [vak.train.dataset] name = 'IntlDistributedSongbirdConsortiumPack' path = 'tests/test_data/prep/train/032312_prep_191224_225912.csv' splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' +params = {window_size = 2000} [vak.train.model.TweetyNet] @@ -68,8 +65,6 @@ num_workers = 4 device = 'cuda' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} -transform_params = {'resize' = 128} -dataset_params = {'window_size' = 80} [vak.eval.dataset] name = 'IntlDistributedSongbirdConsortiumPack' @@ -91,15 +86,12 @@ results_dir_made_by_main_script = '/some/path/to/learncurve/' post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} num_workers = 4 device = 'cuda' -train_transform_params = {'resize' = 128} -train_dataset_params = {'window_size' = 80} -val_transform_params = {'resize' = 128} -val_dataset_params = {'window_size' = 80} 
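To make the new shape concrete: everything that previously lived in separate `transform_params` / `dataset_params` keys now travels on the dataset table itself, via the `params` key being added to each dataset table in valid-version-1.0.toml. A minimal round-trip sketch, assuming the `DatasetConfig` class from patch 133 and made-up config values:

    import tomlkit

    from vak.config.dataset import DatasetConfig

    # hypothetical config snippet; key names follow valid-version-1.0.toml
    toml_str = (
        '[vak.eval.dataset]\n'
        'path = "prepped/eval/dataset"\n'
        'params = { window_size = 2000 }\n'
    )
    config_dict = tomlkit.parse(toml_str)["vak"]["eval"]["dataset"]
    dataset_config = DatasetConfig.from_config_dict(config_dict)
    assert dataset_config.params == {"window_size": 2000}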
[vak.learncurve.dataset] name = 'IntlDistributedSongbirdConsortiumPack' path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' +params = {window_size = 2000} [vak.learncurve.model.TweetyNet] @@ -115,12 +107,11 @@ spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' min_segment_dur = 0.004 majority_vote = false save_net_outputs = false -transform_params = {'resize' = 128} -dataset_params = {'window_size' = 80} [vak.predict.dataset] name = 'IntlDistributedSongbirdConsortiumPack' path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' splits_path = 'tests/test_data/prep/train/032312_prep_191224_225912.splits.json' +params = {window_size = 2000} [vak.predict.model.TweetyNet] \ No newline at end of file From 48830d45a8aba3675f03bfe7a75af67d4f6f41f4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 08:03:44 -0400 Subject: [PATCH 143/183] Remove train/val/dataset/transform_params from TrainConfig, will use DatasetConfig attribute params instead --- src/vak/config/train.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/vak/config/train.py b/src/vak/config/train.py index 7ee5cdaff..4c997a8c1 100644 --- a/src/vak/config/train.py +++ b/src/vak/config/train.py @@ -123,30 +123,6 @@ class TrainConfig: default=None, ) - train_transform_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - train_dataset_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - val_transform_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - val_dataset_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - @classmethod def from_config_dict(cls, config_dict: dict) -> "TrainConfig": """Return :class:`TrainConfig` instance from a :class:`dict`. From ad05a9850f398d12f6d9c5b10b592c79d5a847d4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 08:04:48 -0400 Subject: [PATCH 144/183] Remove train/val/dataset/transform_params from PredictConfig, will use DatasetConfig attribute params instead --- src/vak/config/predict.py | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/src/vak/config/predict.py b/src/vak/config/predict.py index 0029cc956..8803d9317 100644 --- a/src/vak/config/predict.py +++ b/src/vak/config/predict.py @@ -32,10 +32,10 @@ class PredictConfig: and optionally a path to a file representing splits, and the name, if it is a built-in dataset. Must be an instance of :class:`vak.config.DatasetConfig`. - checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model - labelmap_path : str - path to 'labelmap.json' file. + checkpoint_path : str + path to directory with checkpoint files saved by Torch, to reload model + labelmap_path : str + path to 'labelmap.json' file. model : vak.config.ModelConfig The model to use: its name, and the parameters to configure it. @@ -83,14 +83,6 @@ class PredictConfig: spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, and the network is `TweetyNet`, then the net output file will be `gy6or6_032312_081416.tweetynet.output.npz`. - transform_params: dict, optional - Parameters for data transform. 
- Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. """ # required, external files @@ -129,18 +121,6 @@ class PredictConfig: majority_vote = field(validator=instance_of(bool), default=True) save_net_outputs = field(validator=instance_of(bool), default=False) - transform_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - - dataset_params = field( - converter=converters.optional(dict), - validator=validators.optional(instance_of(dict)), - default=None, - ) - @classmethod def from_config_dict(cls, config_dict: dict) -> PredictConfig: """Return :class:`PredictConfig` instance from a :class:`dict`. From 7f93b743ed7ec7f76905e9058c949d990fb36a93 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 08:08:13 -0400 Subject: [PATCH 145/183] Revise transforms.defaults.frame_classification.TrainItemTransform and change get_default_frame_classification_transform to return an instance of the TrainItemTransform when 'mode' is 'train' --- .../defaults/frame_classification.py | 36 ++++++------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/src/vak/transforms/defaults/frame_classification.py b/src/vak/transforms/defaults/frame_classification.py index 54e8cb642..36d22161e 100644 --- a/src/vak/transforms/defaults/frame_classification.py +++ b/src/vak/transforms/defaults/frame_classification.py @@ -27,32 +27,32 @@ def __init__( ): if spect_standardizer is not None: if isinstance(spect_standardizer, vak_transforms.StandardizeSpect): - source_transform = [spect_standardizer] + frames_transform = [spect_standardizer] else: raise TypeError( f"invalid type for spect_standardizer: {type(spect_standardizer)}. 
" "Should be an instance of vak.transforms.StandardizeSpect" ) else: - source_transform = [] + frames_transform = [] - source_transform.extend( + frames_transform.extend( [ vak_transforms.ToFloatTensor(), vak_transforms.AddChannel(), ] ) self.source_transform = torchvision.transforms.Compose( - source_transform + frames_transform ) - self.annot_transform = vak_transforms.ToLongTensor() + self.frame_labels_transform = vak_transforms.ToLongTensor() - def __call__(self, source, annot, spect_path=None): - source = self.source_transform(source) - annot = self.annot_transform(annot) + def __call__(self, frames, frame_labels, spect_path=None): + frames = self.frames_transform(frames) + frame_labels = self.frame_labels_transform(frame_labels) item = { - "frames": source, - "frame_labels": annot, + "frames": frames, + "frame_labels": frame_labels, } if spect_path is not None: @@ -239,21 +239,7 @@ def get_default_frame_classification_transform( ) if mode == "train": - if spect_standardizer is not None: - transform = [spect_standardizer] - else: - transform = [] - - transform.extend( - [ - vak_transforms.ToFloatTensor(), - vak_transforms.AddChannel(), - ] - ) - transform = torchvision.transforms.Compose(transform) - - target_transform = vak_transforms.ToLongTensor() - return transform, target_transform + return TrainItemTransform(spect_standardizer) elif mode == "predict": item_transform = PredictItemTransform( From 7152657b3e6eb9cb50f044595f7beaefa8edf1a4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 08:21:59 -0400 Subject: [PATCH 146/183] Make vak.table.dataset.params into an in-line table in toml files in tests/data_for_tests/configs --- .../configs/TweetyNet_eval_audio_cbin_annot_notmat.toml | 4 ++-- .../configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml | 4 ++-- .../configs/TweetyNet_predict_audio_cbin_annot_notmat.toml | 4 ++-- .../configs/TweetyNet_train_audio_cbin_annot_notmat.toml | 4 ++-- .../TweetyNet_train_continue_audio_cbin_annot_notmat.toml | 4 ++-- .../TweetyNet_train_continue_spect_mat_annot_yarden.toml | 4 ++-- .../configs/TweetyNet_train_spect_mat_annot_yarden.toml | 4 ++-- tests/data_for_tests/configs/invalid_key_config.toml | 4 ++-- tests/data_for_tests/configs/invalid_table_config.toml | 2 +- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml index a629f1832..12bfcba84 100644 --- a/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_eval_audio_cbin_annot_notmat.toml @@ -27,8 +27,8 @@ output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_not majority_vote = true min_segment_dur = 0.02 -[vak.eval.dataset.params] -window_size = 88 +[vak.eval.dataset] +params = { window_size = 88 } [vak.eval.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml index c169b3bda..59868a28a 100644 --- a/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_learncurve_audio_cbin_annot_notmat.toml @@ -34,8 +34,8 @@ root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cb majority_vote = true min_segment_dur = 0.02 -[vak.learncurve.dataset.params] -window_size = 88 
+[vak.learncurve.dataset] +params = { window_size = 88 } [vak.learncurve.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml index 1f2bf74b1..3d794f314 100644 --- a/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_predict_audio_cbin_annot_notmat.toml @@ -22,8 +22,8 @@ device = "cuda" output_dir = "./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet" annot_csv_filename = "bl26lb16.041912.annot.csv" -[vak.predict.dataset.params] -window_size = 88 +[vak.predict.dataset] +params = { window_size = 88 } [vak.predict.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml index 3cddb00b6..9b751e7f0 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_audio_cbin_annot_notmat.toml @@ -28,8 +28,8 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet" -[vak.train.dataset.params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } [vak.train.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml index 09e93b442..c7ca91a96 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_audio_cbin_annot_notmat.toml @@ -30,8 +30,8 @@ root_results_dir = "./tests/data_for_tests/generated/results/train_continue/audi checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" spect_scaler_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/StandardizeSpect" -[vak.train.dataset.params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } [vak.train.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml index adbedc2c6..c66e9c34d 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml +++ b/tests/data_for_tests/configs/TweetyNet_train_continue_spect_mat_annot_yarden.toml @@ -29,8 +29,8 @@ device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train_continue/spect_mat_annot_yarden/TweetyNet" checkpoint_path = "~/Documents/repos/coding/birdsong/TweetyNet/results/BFSongRepository/gy6or6/results_200620_165308/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" -[vak.train.dataset.params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } [vak.train.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml index 81858e596..a9aaaf112 100644 --- a/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml 
+++ b/tests/data_for_tests/configs/TweetyNet_train_spect_mat_annot_yarden.toml @@ -28,8 +28,8 @@ num_workers = 16 device = "cuda" root_results_dir = "./tests/data_for_tests/generated/results/train/spect_mat_annot_yarden/TweetyNet" -[vak.train.dataset.params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } [vak.train.model.TweetyNet.network] conv1_filters = 8 diff --git a/tests/data_for_tests/configs/invalid_key_config.toml b/tests/data_for_tests/configs/invalid_key_config.toml index 7e95d9332..0012c6d6c 100644 --- a/tests/data_for_tests/configs/invalid_key_config.toml +++ b/tests/data_for_tests/configs/invalid_key_config.toml @@ -30,8 +30,8 @@ val_error_step = 1 checkpoint_step = 1 save_only_single_checkpoint_file = true -[vak.train.dataset_params] -window_size = 88 +[vak.train.dataset] +params = { window_size = 88 } [vak.train.model.TweetyNet.optimizer] learning_rate = 0.001 diff --git a/tests/data_for_tests/configs/invalid_table_config.toml b/tests/data_for_tests/configs/invalid_table_config.toml index daf0d4e0d..24998129d 100644 --- a/tests/data_for_tests/configs/invalid_table_config.toml +++ b/tests/data_for_tests/configs/invalid_table_config.toml @@ -20,7 +20,7 @@ freq_cutoffs = [500, 10000] thresh = 6.25 transform_type = 'log_spect' -[vak.trian] # <-- invalid section 'TRIAN' (instead of 'vak.train') +[vak.trian] # <-- invalid section 'trian' (instead of 'vak.train') model = 'TweetyNet' root_results_dir = '/home/user/data/subdir/' normalize_spectrograms = true From 149d6eb80eef190264b94d88e8aee2dcabf2f31e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 16:15:19 -0400 Subject: [PATCH 147/183] Fix attribute name in frame_classification.TrainItemTransform.__init__: source_transform -> frames_transform --- src/vak/transforms/defaults/frame_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/transforms/defaults/frame_classification.py b/src/vak/transforms/defaults/frame_classification.py index 36d22161e..0d2af838f 100644 --- a/src/vak/transforms/defaults/frame_classification.py +++ b/src/vak/transforms/defaults/frame_classification.py @@ -42,7 +42,7 @@ def __init__( vak_transforms.AddChannel(), ] ) - self.source_transform = torchvision.transforms.Compose( + self.frames_transform = torchvision.transforms.Compose( frames_transform ) self.frame_labels_transform = vak_transforms.ToLongTensor() From 21f7971b1b576fa69002824957f2855156116d91 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 16:16:05 -0400 Subject: [PATCH 148/183] Rewrite datasets.frame_classification.WindowDataset to require item_transform, and assume that it is an instance of transforms.frame_classification.TrainItemTransform --- .../frame_classification/window_dataset.py | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/src/vak/datasets/frame_classification/window_dataset.py b/src/vak/datasets/frame_classification/window_dataset.py index 37fd6b910..c4781f07f 100644 --- a/src/vak/datasets/frame_classification/window_dataset.py +++ b/src/vak/datasets/frame_classification/window_dataset.py @@ -174,11 +174,10 @@ def __init__( inds_in_sample: npt.NDArray, window_size: int, frame_dur: float, + item_transform: Callable, stride: int = 1, subset: str | None = None, window_inds: npt.NDArray | None = None, - transform: Callable | None = None, - target_transform: Callable | None = None, ): """Initialize a new instance of a WindowDataset. 
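For orientation, the call pattern this rewrite produces, as a minimal sketch: the dataset directory is hypothetical, and `TrainItemTransform` is the class revised in patch 145 (with the attribute-name fix from patch 147):

    from vak.datasets.frame_classification import WindowDataset
    from vak.transforms.defaults.frame_classification import TrainItemTransform

    # no spectrogram standardization here, so frames just get
    # ToFloatTensor + AddChannel, and frame labels get ToLongTensor
    item_transform = TrainItemTransform(spect_standardizer=None)
    dataset = WindowDataset.from_dataset_path(
        dataset_path="prepped/train/dataset",  # hypothetical path made by `vak prep`
        window_size=88,
        item_transform=item_transform,
    )
    item = dataset[0]  # a dict with "frames" and "frame_labels" tensors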
@@ -211,6 +210,9 @@ def __init__( frame_dur: float Duration of a frame, i.e., a single sample in audio or a single timebin in a spectrogram. + item_transform : callable + The transform applied to each item :math:`(x, y)` + that is returned by :meth:`WindowDataset.__getitem__`. stride : int The size of the stride used to determine which windows are included in the dataset. The default is 1. @@ -267,8 +269,7 @@ def __init__( sample_ids.shape[-1], window_size, stride ) self.window_inds = window_inds - self.transform = transform - self.target_transform = target_transform + self.item_transform = item_transform @property def duration(self): @@ -277,10 +278,10 @@ def duration(self): @property def shape(self): tmp_x_ind = 0 - one_x, _ = self.__getitem__(tmp_x_ind) + tmp_item = self.__getitem__(tmp_x_ind) # used by vak functions that need to determine size of window, # e.g. when initializing a neural network model - return one_x.shape + return tmp_item['frames'].shape def _load_frames(self, frames_path): """Helper function that loads "frames", @@ -338,12 +339,8 @@ def __getitem__(self, idx): frame_labels = frame_labels[ inds_in_sample : inds_in_sample + self.window_size # noqa: E203 ] - if self.transform: - frames = self.transform(frames) - if self.target_transform: - frame_labels = self.target_transform(frame_labels) - - return frames, frame_labels + item = self.item_transform(frames, frame_labels) + return item def __len__(self): """number of batches""" @@ -354,11 +351,10 @@ def from_dataset_path( cls, dataset_path: str | pathlib.Path, window_size: int, + item_transform: Callable, stride: int = 1, split: str = "train", subset: str | None = None, - transform: Callable | None = None, - target_transform: Callable | None = None, ): """Make a :class:`WindowDataset` instance, given the path to a frame classification dataset. @@ -441,9 +437,8 @@ def from_dataset_path( inds_in_sample, window_size, frame_dur, + item_transform, stride, subset, window_inds, - transform, - target_transform, ) From be8e6939f69e522650d0e392ae6a436642afd7c9 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 16:16:45 -0400 Subject: [PATCH 149/183] Rewrite datasets.frame_classification.FramesDataset to make item_transform required --- .../frame_classification/frames_dataset.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/vak/datasets/frame_classification/frames_dataset.py b/src/vak/datasets/frame_classification/frames_dataset.py index a8f5f6de0..94ea8169d 100644 --- a/src/vak/datasets/frame_classification/frames_dataset.py +++ b/src/vak/datasets/frame_classification/frames_dataset.py @@ -79,8 +79,8 @@ def __init__( sample_ids: npt.NDArray, inds_in_sample: npt.NDArray, frame_dur: float, + item_transform: Callable, subset: str | None = None, - item_transform: Callable | None = None, ): """Initialize a new instance of a FramesDataset. @@ -115,9 +115,9 @@ def __init__( If specified, this takes precedence over split. Subsets are typically taken from the training data for use when generating a learning curve. - item_transform : callable, optional - Transform applied to each item :math:`(x, y)` - returned by :meth:`FramesDataset.__getitem__`. + item_transform : callable + The transform applied to each item :math:`(x, y)` + that is returned by :meth:`FramesDataset.__getitem__`. """ from ... 
import ( prep, @@ -196,9 +196,9 @@ def __len__(self): def from_dataset_path( cls, dataset_path: str | pathlib.Path, + item_transform: Callable, split: str = "val", subset: str | None = None, - item_transform: Callable | None = None, ): """Make a :class:`FramesDataset` instance, given the path to a frame classification dataset. @@ -210,17 +210,18 @@ def from_dataset_path( frame classification dataset, as created by :func:`vak.prep.prep_frame_classification_dataset`. + item_transform : callable, optional + Transform applied to each item :math:`(x, y)` + returned by :meth:`FramesDataset.__getitem__`. split : str The name of a split from the dataset, one of {'train', 'val', 'test'}. + Default is "val". subset : str, optional Name of subset to use. If specified, this takes precedence over split. Subsets are typically taken from the training data for use when generating a learning curve. - item_transform : callable, optional - Transform applied to each item :math:`(x, y)` - returned by :meth:`FramesDataset.__getitem__`. Returns ------- @@ -263,6 +264,6 @@ def from_dataset_path( sample_ids, inds_in_sample, frame_dur, - subset, item_transform, + subset, ) From 558bb14e0e6af3598b07eb888ea08a5114db3c11 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 16:21:07 -0400 Subject: [PATCH 150/183] Rewrite src/vak/train/frame_classification.py: remove params model_name, train/val_transform_params, train/val_dataset_params, and dataset_path, replace with dataset_config and just have model_config contain name --- src/vak/train/frame_classification.py | 90 ++++++++++----------------- 1 file changed, 32 insertions(+), 58 deletions(-) diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index 3f5646721..fa6243761 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -27,16 +27,11 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float: def train_frame_classification_model( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, checkpoint_path: str | pathlib.Path | None = None, spect_scaler_path: str | pathlib.Path | None = None, results_path: str | pathlib.Path | None = None, @@ -59,14 +54,12 @@ def train_frame_classification_model( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, a directory generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -75,31 +68,7 @@ def train_frame_classification_model( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. - train_transform_params: dict, optional - Parameters for training data transform. - Passed as keyword arguments. Optional, default is None. 
- train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments to - :class:`vak.datasets.frame_classification.WindowDataset`. - Optional, default is None. - val_transform_params: dict, optional - Parameters for validation data transform. - Passed as keyword arguments. - Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments to - :class:`vak.datasets.frame_classification.FramesDataset`. - Optional, default is None. - dataset_csv_path - Path to csv file representing splits of dataset, - e.g., such a file generated by running ``vak prep``. - This parameter is used by :func:`vak.core.learncurve` to specify - different splits to use, when generating results for a learning curve. - If this argument is specified, the csv file must be inside the directory - ``dataset_path``. checkpoint_path : str, pathlib.Path path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. @@ -158,14 +127,14 @@ def train_frame_classification_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config['path']) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) logger.info( - f"Loading dataset from path: {dataset_path}", + f"Loading dataset from `dataset_path`: {dataset_path}", ) metadata = datasets.frame_classification.Metadata.from_dataset_path( dataset_path @@ -240,22 +209,31 @@ def train_frame_classification_model( ) spect_standardizer = None - if train_transform_params is None: - train_transform_params = {} - train_transform_params.update({"spect_standardizer": spect_standardizer}) - transform, target_transform = transforms.defaults.get_default_transform( - model_name, "train", transform_kwargs=train_transform_params + model_name = model_config['name'] + # TODO: move this into datapipe once each datapipe uses a fixed set of transforms + # that will require adding `spect_standardizer`` as a parameter to the datapipe, + # maybe rename to `frames_standardizer`? 
+    try:
+        window_size = dataset_config["params"]["window_size"]
+    except KeyError as e:
+        raise KeyError(
+            f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table "
+            f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}"
+        ) from e
+    transform_kwargs = {
+        "spect_standardizer": spect_standardizer,
+        "window_size": window_size,
+    }
+    train_transform = transforms.defaults.get_default_transform(
         model_name, "train", transform_kwargs=transform_kwargs
     )
-    if train_dataset_params is None:
-        train_dataset_params = {}
     train_dataset = WindowDataset.from_dataset_path(
         dataset_path=dataset_path,
         split="train",
         subset=subset,
-        transform=transform,
-        target_transform=target_transform,
-        **train_dataset_params,
+        item_transform=train_transform,
+        window_size=window_size,
     )
     logger.info(
         f"Duration of WindowDataset used for training, in seconds: {train_dataset.duration}",
@@ -278,19 +256,15 @@
         f"Total duration of validation split from dataset (in s): {val_dur}",
     )
 
-    if val_transform_params is None:
-        val_transform_params = {}
-    val_transform_params.update({"spect_standardizer": spect_standardizer})
-    item_transform = transforms.defaults.get_default_transform(
-        model_name, "eval", val_transform_params
+    # NOTE: we use same `transform_kwargs` here; will need to change to a `dataset_param`
+    # when we factor transform *into* fixed DataPipes as above
+    val_transform = transforms.defaults.get_default_transform(
+        model_name, "eval", transform_kwargs
     )
-    if val_dataset_params is None:
-        val_dataset_params = {}
     val_dataset = FramesDataset.from_dataset_path(
         dataset_path=dataset_path,
         split="val",
-        item_transform=item_transform,
-        **val_dataset_params,
+        item_transform=val_transform,
     )
     logger.info(
         f"Duration of FramesDataset used for evaluation, in seconds: {val_dataset.duration}",

From b54be05b6f016491dc138b4bc7f2ff9eb9e293b1 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Sat, 4 May 2024 16:24:59 -0400
Subject: [PATCH 151/183] Rewrite src/vak/train/_train.py: remove params
 model_name, train/val_transform_params, train/val_dataset_params, and
 dataset_path, replace with dataset_config and just have model_config contain
 name

---
 src/vak/train/train_.py | 52 ++++++++++--------------------------------
 1 file changed, 10 insertions(+), 42 deletions(-)

diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py
index 443f4ae5b..7d3597179 100644
--- a/src/vak/train/train_.py
+++ b/src/vak/train/train_.py
@@ -14,16 +14,11 @@
 
 
 def train(
-    model_name: str,
     model_config: dict,
-    dataset_path: str | pathlib.Path,
+    dataset_config: dict,
     batch_size: int,
     num_epochs: int,
     num_workers: int,
-    train_transform_params: dict | None = None,
-    train_dataset_params: dict | None = None,
-    val_transform_params: dict | None = None,
-    val_dataset_params: dict | None = None,
     checkpoint_path: str | pathlib.Path | None = None,
     spect_scaler_path: str | pathlib.Path | None = None,
     results_path: str | pathlib.Path | None = None,
@@ -43,14 +38,12 @@ def train(
 
     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
-    dataset_path : str
-        Path to dataset, a directory generated by running ``vak prep``.
+        Model configuration in a :class:`dict`.
+ Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. window_size : int size of windows taken from spectrograms, in number of time bins, shown to neural networks @@ -62,22 +55,6 @@ def train( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. - train_transform_params: dict, optional - Parameters for training data transform. - Passed as keyword arguments. - Optional, default is None. - train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments. - Optional, default is None. - val_transform_params: dict, optional - Parameters for validation data transform. - Passed as keyword arguments. - Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments. - Optional, default is None. checkpoint_path : str, pathlib.Path Path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. @@ -153,12 +130,13 @@ def train( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config['path']) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) + model_name = model_config['name'] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: @@ -167,16 +145,11 @@ def train( ) from e if model_family == "FrameClassificationModel": train_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, - train_transform_params=train_transform_params, - train_dataset_params=train_dataset_params, - val_transform_params=val_transform_params, - val_dataset_params=val_dataset_params, checkpoint_path=checkpoint_path, spect_scaler_path=spect_scaler_path, results_path=results_path, @@ -190,16 +163,11 @@ def train( ) elif model_family == "ParametricUMAPModel": train_parametric_umap_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, - train_transform_params=train_transform_params, - train_dataset_params=train_dataset_params, - val_transform_params=val_transform_params, - val_dataset_params=val_dataset_params, checkpoint_path=checkpoint_path, results_path=results_path, shuffle=shuffle, From 88f8b496de48c1bf276666b46dffc09c42d71c5a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 16:25:45 -0400 Subject: [PATCH 152/183] Rewrite vak/cli/train.py to call train._train.train with just model_config and dataset_config, remove model_name, dataset_path, train/val_transform_params and train/val_dataset_params --- src/vak/cli/train.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/vak/cli/train.py b/src/vak/cli/train.py index 542a93fbc..c63096ca2 100644 --- a/src/vak/cli/train.py +++ b/src/vak/cli/train.py @@ -53,13 +53,8 @@ def train(toml_path): ) train_module.train( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, - train_transform_params=cfg.train.train_transform_params, - 
train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, From a8721cc23356993200429f2c0f303cb802e7c4f9 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 16:54:37 -0400 Subject: [PATCH 153/183] Fix how we unpack batch in training_step method of FrameClassificationModel --- src/vak/models/frame_classification_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 4ee2088a9..305018e8c 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -200,9 +200,9 @@ def training_step(self, batch: tuple, batch_idx: int): Scalar loss value computed by the loss function, ``self.loss``. """ - x, y = batch[0], batch[1] - out = self.network(x) - loss = self.loss(out, y) + frames, frame_labels = batch["frames"], batch["frame_labels"] + out = self.network(frames) + loss = self.loss(out, frame_labels) self.log("train_loss", loss, on_step=True) return loss From ad5d98bd1a149516defc6648c0dab30d81fb0b97 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:01:00 -0400 Subject: [PATCH 154/183] Change transform_kwargs parameter of transforms.defaults.parametric_umap.get_default_parametric_umap_transform to default to None, and if None to be an empty dict --- src/vak/transforms/defaults/parametric_umap.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/vak/transforms/defaults/parametric_umap.py b/src/vak/transforms/defaults/parametric_umap.py index 9dca8bfb4..a62a7c29b 100644 --- a/src/vak/transforms/defaults/parametric_umap.py +++ b/src/vak/transforms/defaults/parametric_umap.py @@ -8,18 +8,22 @@ def get_default_parametric_umap_transform( - transform_kwargs, + transform_kwargs: dict | None = None, ) -> torchvision.transforms.Compose: """Get default transform for frame classification model. Parameters ---------- - transform_kwargs : dict + transform_kwargs : dict, optional + Keyword arguments for transform class. + Default is None. Returns ------- transform : Callable """ + if transform_kwargs is None: + transform_kwargs = {} transforms = [ vak_transforms.ToFloatTensor(), vak_transforms.AddChannel(), From 43c525ff83be5bb8304443d1483aa5fa2a3cfae4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:01:23 -0400 Subject: [PATCH 155/183] Change transform_kwargs parameter of transforms.defaults.frame_classification.get_default_frame_classification_transform to default to None, and if None to be an empty dict --- src/vak/transforms/defaults/frame_classification.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/vak/transforms/defaults/frame_classification.py b/src/vak/transforms/defaults/frame_classification.py index 0d2af838f..9d20cfe50 100644 --- a/src/vak/transforms/defaults/frame_classification.py +++ b/src/vak/transforms/defaults/frame_classification.py @@ -200,15 +200,18 @@ def __call__(self, frames, frames_path=None): def get_default_frame_classification_transform( - mode: str, transform_kwargs: dict + mode: str, transform_kwargs: dict | None = None ) -> tuple[Callable, Callable] | Callable: """Get default transform for frame classification model. 
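Patches 154 and 155 both flip ``transform_kwargs`` from a required argument to ``None``-by-default. That is the standard Python guard against mutable default values; a minimal generic sketch of the idiom (plain Python, not vak code):

def get_transform(transform_kwargs=None):
    # a mutable default like `transform_kwargs={}` in the signature would be
    # created once and shared by every call; defaulting to None and creating
    # the dict inside the function gives each call a fresh dict
    if transform_kwargs is None:
        transform_kwargs = {}
    return transform_kwargs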
Parameters ---------- mode : str - transform_kwargs : dict - A dict with the following key-value pairs: + transform_kwargs : dict, optional + Keyword arguments for transform class. + Default is None. + If supplied, should be a :class:`dict`, + that can include the following key-value pairs: spect_standardizer : vak.transforms.StandardizeSpect instance that has already been fit to dataset, using fit_df method. Default is None, in which case no standardization transform is applied. @@ -227,8 +230,10 @@ def get_default_frame_classification_transform( Returns ------- - + transform: TrainItemTransform, EvalItemTransform, or PredictItemTransform """ + if transform_kwargs is None: + transform_kwargs = {} spect_standardizer = transform_kwargs.get("spect_standardizer", None) # regardless of mode, transform always starts with StandardizeSpect, if used if spect_standardizer is not None: From 9e003fc6fff677f81369f0095e05ad59aca6eb78 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:10:22 -0400 Subject: [PATCH 156/183] Change DatasetConfig.params attribute to default to empty dict, so we can unpack with ** operator even when no params are specified --- src/vak/config/dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index d5e242b47..64c615f67 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -39,7 +39,9 @@ class DatasetConfig: converter=attr.converters.optional(str), default=None ) params : dict | None = field( - converter=attr.converters.optional(dict), default=None + # we default to an empty dict instead of None + # so we can still do **['dataset']['params'] everywhere we do when params are specified + converter=attr.converters.optional(dict), default={} ) @classmethod From 0b5bb27301e65fea93bf64f90bd41c04b2944106 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:12:59 -0400 Subject: [PATCH 157/183] Fix DatasetConfig.from_config_dict method to not use dict.get method, so we don't set attributes to None inadvertently --- src/vak/config/dataset.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index 64c615f67..dc86ec963 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -45,12 +45,10 @@ class DatasetConfig: ) @classmethod - def from_config_dict(cls, dict_: dict) -> DatasetConfig: + def from_config_dict(cls, config_dict: dict) -> DatasetConfig: + return cls( - path=dict_.get("path"), - splits_path=dict_.get("splits_path"), - name=dict_.get("name"), - params=dict_.get("params") + **config_dict ) def asdict(self): From b7ca3322de82090a1455ec0b8f221bd090e15974 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:13:58 -0400 Subject: [PATCH 158/183] Modify transforms.defaults.get so that transform_kwargs is None by default. Also revise docstring and type annotations --- src/vak/transforms/defaults/get.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/vak/transforms/defaults/get.py b/src/vak/transforms/defaults/get.py index db2cc556e..3d567bde7 100644 --- a/src/vak/transforms/defaults/get.py +++ b/src/vak/transforms/defaults/get.py @@ -2,16 +2,18 @@ from __future__ import annotations +from typing import Callable, Literal + from ... import models from . 
import frame_classification, parametric_umap def get_default_transform( model_name: str, - mode: str, - transform_kwargs: dict, -): - """Get default transforms for a model, + mode: Literal["eval", "predict", "train"], + transform_kwargs: dict | None = None, +) -> Callable: + """Get default transform for a model, according to its family and what mode the model is being used in. @@ -20,14 +22,13 @@ def get_default_transform( model_name : str Name of model. mode : str - one of {'train', 'eval', 'predict'}. Determines set of transforms. + One of {'eval', 'predict', 'train'}. Returns ------- - transform, target_transform : callable - one or more vak transforms to be applied to inputs x and, during training, the target y. - If more than one transform, they are combined into an instance of torchvision.transforms.Compose. - Note that when mode is 'predict', the target transform is None. + item_transform : callable + Transform to be applied to input :math:`x` to a model and, + during training, the target :math:`y`. """ try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] From bebaa3f62d5c1d3879b4a893311fe51ca54d601b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:14:53 -0400 Subject: [PATCH 159/183] Rewrite src/vak/train/parametric_umap.py to use model_config and dataset_config parameters, removing parameters val/train_transform_params + val/train_dataset_params and dataset_path --- src/vak/train/parametric_umap.py | 53 +++++++++----------------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/src/vak/train/parametric_umap.py b/src/vak/train/parametric_umap.py index c2b7fac64..f9180e5c0 100644 --- a/src/vak/train/parametric_umap.py +++ b/src/vak/train/parametric_umap.py @@ -78,16 +78,11 @@ def get_trainer( def train_parametric_umap_model( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, checkpoint_path: str | pathlib.Path | None = None, root_results_dir: str | pathlib.Path | None = None, results_path: str | pathlib.Path | None = None, @@ -108,14 +103,12 @@ def train_parametric_umap_model( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, a directory generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. batch_size : int number of samples per batch presented to models during training. num_epochs : int @@ -124,16 +117,6 @@ def train_parametric_umap_model( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. - train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments to - :class:`vak.datasets.parametric_umap.ParametricUMAP`. - Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments to - :class:`vak.datasets.parametric_umap.ParametricUMAP`. 
- Optional, default is None. checkpoint_path : str, pathlib.Path, optional path to a checkpoint file, e.g., one generated by a previous run of ``vak.core.train``. @@ -176,7 +159,7 @@ def train_parametric_umap_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -219,20 +202,18 @@ def train_parametric_umap_model( f"Total duration of training split from dataset (in s): {train_dur}", ) - if train_transform_params is None: - train_transform_params = {} - transform = transforms.defaults.get_default_transform( - model_name, "train", train_transform_params + model_name = model_config["name"] + train_transform = transforms.defaults.get_default_transform( + model_name, "train" ) - if train_dataset_params is None: - train_dataset_params = {} + dataset_params = dataset_config["params"] train_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="train", subset=subset, - transform=transform, - **train_dataset_params, + transform=train_transform, + **dataset_params, ) logger.info( f"Duration of ParametricUMAPDataset used for training, in seconds: {train_dataset.duration}", @@ -246,18 +227,14 @@ def train_parametric_umap_model( # ---------------- load validation set (if there is one) ----------------------------------------------------------- if val_step: - if val_transform_params is None: - val_transform_params = {} transform = transforms.defaults.get_default_transform( - model_name, "eval", val_transform_params + model_name, "eval" ) - if val_dataset_params is None: - val_dataset_params = {} val_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="val", transform=transform, - **val_dataset_params, + **dataset_params, ) logger.info( f"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}", From 8f6ef3731fb5a0d859564d5336d3da2685c6c457 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:21:45 -0400 Subject: [PATCH 160/183] Rewrite vak/eval/frame_classification.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- src/vak/eval/frame_classification.py | 50 +++++++++++++--------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index cf86670a4..0cc4bb41f 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -21,15 +21,12 @@ def eval_frame_classification_model( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, - transform_params: dict | None = None, - dataset_params: dict | None = None, split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, @@ -39,14 +36,12 @@ def eval_frame_classification_model( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. 
- dataset_path : str, pathlib.Path - Path to dataset, e.g., a csv file generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. checkpoint_path : str, pathlib.Path Path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -56,14 +51,6 @@ def eval_frame_classification_model( num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. split : str Split of dataset on which model should be evaluated. One of {'train', 'val', 'test'}. Default is 'test'. @@ -108,7 +95,7 @@ def eval_frame_classification_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config['path']) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -142,19 +129,30 @@ def eval_frame_classification_model( logger.info(f"loading labelmap from path: {labelmap_path}") with labelmap_path.open("r") as f: labelmap = json.load(f) - if transform_params is None: - transform_params = {} - transform_params.update({"spect_standardizer": spect_standardizer}) + + model_name = model_config["name"] + # TODO: move this into datapipe once each datapipe uses a fixed set of transforms + # that will require adding `spect_standardizer`` as a parameter to the datapipe, + # maybe rename to `frames_standardizer`? 
+    try:
+        window_size = dataset_config["params"]["window_size"]
+    except KeyError as e:
+        raise KeyError(
+            f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table "
+            f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}"
+        ) from e
+    transform_params = {
+        "spect_standardizer": spect_standardizer,
+        "window_size": window_size
+    }
+
     item_transform = transforms.defaults.get_default_transform(
         model_name, "eval", transform_params
     )
-    if dataset_params is None:
-        dataset_params = {}
     val_dataset = FramesDataset.from_dataset_path(
         dataset_path=dataset_path,
         split=split,
         item_transform=item_transform,
-        **dataset_params,
     )
     val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,

From 9768b5ddd448b4e25aaca1c007ee227d923c1660 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Sat, 4 May 2024 20:34:46 -0400
Subject: [PATCH 161/183] Rewrite vak/eval/parametric_umap.py to use
 model_config and dataset_config parameters, removing parameters
 transform_params + dataset_params and dataset_path

---
 src/vak/eval/parametric_umap.py | 36 +++++++++------------------------
 1 file changed, 10 insertions(+), 26 deletions(-)

diff --git a/src/vak/eval/parametric_umap.py b/src/vak/eval/parametric_umap.py
index 5eeadf19b..107d8d844 100644
--- a/src/vak/eval/parametric_umap.py
+++ b/src/vak/eval/parametric_umap.py
@@ -19,15 +19,12 @@


 def eval_parametric_umap_model(
-    model_name: str,
     model_config: dict,
-    dataset_path: str | pathlib.Path,
+    dataset_config: dict,
     checkpoint_path: str | pathlib.Path,
     output_dir: str | pathlib.Path,
     batch_size: int,
     num_workers: int,
-    transform_params: dict | None = None,
-    dataset_params: dict | None = None,
     split: str = "test",
     device: str | None = None,
 ) -> None:
     """Evaluate a trained model.

     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
-    dataset_path : str, pathlib.Path
-        Path to dataset, e.g., a csv file generated by running ``vak prep``.
+        Model configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
     checkpoint_path : str, pathlib.Path
         Path to directory with checkpoint files saved by Torch, to reload model
     output_dir : str, pathlib.Path
         Path to location where .csv files with evaluation metrics should be saved.
     batch_size : int
         number of samples per batch presented to models during training.
     num_workers : int
         Number of processes to use for parallel loading of data.
         Argument to torch.DataLoader. Default is 2.
-    transform_params: dict, optional
-        Parameters for data transform.
-        Passed as keyword arguments.
-        Optional, default is None.
-    dataset_params: dict, optional
-        Parameters for dataset.
-        Passed as keyword arguments.
-        Optional, default is None.
     split : str
         Split of dataset on which model should be evaluated.
         One of {'train', 'val', 'test'}. Default is 'test'.
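For reference, the ``dataset_config`` these rewritten eval functions receive is just the dataset table of the config file as a :class:`dict`. A sketch with made-up values, with keys following the ``DatasetConfig`` attributes defined in patch 156:

import pathlib

# illustrative values, not from a real config file
dataset_config = {
    "name": None,
    "path": "~/user/prepped/dataset",
    "splits_path": None,
    "params": {"window_size": 2000},
}
dataset_path = pathlib.Path(dataset_config["path"])  # as in the diffs above
window_size = dataset_config["params"]["window_size"]  # satisfies the KeyError guard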
@@ -78,7 +65,7 @@ def eval_parametric_umap_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -96,18 +83,15 @@ def eval_parametric_umap_model( timenow = datetime.now().strftime("%y%m%d_%H%M%S") # ---------------- load data for evaluation ------------------------------------------------------------------------ - if transform_params is None: - transform_params = {} + model_name = model_config["name"] item_transform = transforms.defaults.get_default_transform( - model_name, "eval", transform_params + model_name, "eval" ) - if dataset_params is None: - dataset_params = {} val_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split=split, transform=item_transform, - **dataset_params, + **dataset_config["params"], ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, From 519f7d43e3a69064346cb7f4b70f7d1450ba3da5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:35:18 -0400 Subject: [PATCH 162/183] Rewrite vak/eval/eval_.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- src/vak/eval/eval_.py | 39 +++++++++++---------------------------- 1 file changed, 11 insertions(+), 28 deletions(-) diff --git a/src/vak/eval/eval_.py b/src/vak/eval/eval_.py index 8800bb815..d8a13e663 100644 --- a/src/vak/eval/eval_.py +++ b/src/vak/eval/eval_.py @@ -14,16 +14,13 @@ def eval( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, checkpoint_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, labelmap_path: str | pathlib.Path | None = None, batch_size: int | None = None, - transform_params: dict | None = None, - dataset_params: dict | None = None, split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, @@ -33,14 +30,12 @@ def eval( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str, pathlib.Path - Path to dataset, e.g., a csv file generated by running ``vak prep``. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. checkpoint_path : str, pathlib.Path path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path @@ -54,14 +49,6 @@ def eval( batch_size : int, optional. Number of samples per batch fed into model. Optional, default is None. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. split : str split of dataset on which model should be evaluated. One of {'train', 'val', 'test'}. Default is 'test'. 
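The hunk that follows keeps ``eval`` as a thin dispatcher: it looks up the model family for ``model_config["name"]``, then forwards both config dicts. A condensed sketch of that control flow (the registry entry shown is illustrative; the two callees are the functions rewritten in the patches above):

MODEL_FAMILY_FROM_NAME = {"TweetyNet": "FrameClassificationModel"}  # illustrative entry

def eval(model_config: dict, dataset_config: dict, **kwargs) -> None:
    model_name = model_config["name"]
    try:
        model_family = MODEL_FAMILY_FROM_NAME[model_name]
    except KeyError as e:
        raise ValueError(
            f"No model family found for the model name specified: {model_name}"
        ) from e
    if model_family == "FrameClassificationModel":
        eval_frame_classification_model(model_config, dataset_config, **kwargs)
    elif model_family == "ParametricUMAPModel":
        eval_parametric_umap_model(model_config, dataset_config, **kwargs)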
@@ -106,29 +93,28 @@ def eval( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config['path']) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) + model_name = model_config['name'] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: raise ValueError( f"No model family found for the model name specified: {model_name}" ) from e + if model_family == "FrameClassificationModel": eval_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, checkpoint_path=checkpoint_path, labelmap_path=labelmap_path, output_dir=output_dir, num_workers=num_workers, - transform_params=transform_params, - dataset_params=dataset_params, split=split, spect_scaler_path=spect_scaler_path, device=device, @@ -136,15 +122,12 @@ def eval( ) elif model_family == "ParametricUMAPModel": eval_parametric_umap_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, checkpoint_path=checkpoint_path, output_dir=output_dir, batch_size=batch_size, num_workers=num_workers, - transform_params=transform_params, - dataset_params=dataset_params, split=split, device=device, ) From e68ad492f06cfb3091cb6d693a08f970c4969b15 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:35:57 -0400 Subject: [PATCH 163/183] Rewrite cli.eval to pass model_config and dataset_config into eval_module.eval, remove dataset_path/transform_params/datset_params arguments --- src/vak/cli/eval.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 272c3cf3f..29bee65a5 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -1,5 +1,9 @@ +"""Evaluate a trained model with dataset specified in config.toml file.""" + +from __future__ import annotations + import logging -from pathlib import Path +import pathlib from .. import config from .. import eval as eval_module @@ -8,8 +12,9 @@ logger = logging.getLogger(__name__) -def eval(toml_path): - """evaluate a trained model with dataset specified in config.toml file. +def eval(toml_path: str | pathlib.Path) -> None: + """Evaluate a trained model with dataset specified in config.toml file. + Function called by command-line interface. 
 Parameters
@@ -21,7 +26,7 @@
     -------
     None
     """
-    toml_path = Path(toml_path)
+    toml_path = pathlib.Path(toml_path)
     cfg = config.Config.from_toml_path(toml_path)

     if cfg.eval is None:
@@ -45,16 +50,13 @@
     )

     eval_module.eval(
-        model_name=cfg.eval.model.name,
         model_config=cfg.eval.model.asdict(),
-        dataset_path=cfg.eval.dataset.path,
+        dataset_config=cfg.eval.dataset.asdict(),
         checkpoint_path=cfg.eval.checkpoint_path,
         labelmap_path=cfg.eval.labelmap_path,
         output_dir=cfg.eval.output_dir,
         num_workers=cfg.eval.num_workers,
         batch_size=cfg.eval.batch_size,
-        transform_params=cfg.eval.transform_params,
-        dataset_params=cfg.eval.dataset_params,
         spect_scaler_path=cfg.eval.spect_scaler_path,
         device=cfg.eval.device,
         post_tfm_kwargs=cfg.eval.post_tfm_kwargs,

From dca625f92c23c7a395ad61ae4eaaeaadb31b4b86 Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Sat, 4 May 2024 20:38:13 -0400
Subject: [PATCH 164/183] Unpack dataset_config[params] with ** inside
 train/frame_classification.py, instead of directly getting window_size from
 the params dict

---
 src/vak/train/frame_classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py
index fa6243761..99dd93734 100644
--- a/src/vak/train/frame_classification.py
+++ b/src/vak/train/frame_classification.py
@@ -233,7 +233,7 @@
         split="train",
         subset=subset,
         item_transform=train_transform,
-        window_size=dataset_config['params']['window_size'],
+        **dataset_config['params'],
     )
     logger.info(
         f"Duration of WindowDataset used for training, in seconds: {train_dataset.duration}",

From 43d1a997ebc59330ea856792c86f8016137275ff Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Sat, 4 May 2024 20:38:36 -0400
Subject: [PATCH 165/183] Rewrite vak/learncurve/frame_classification.py to
 use model_config and dataset_config parameters, removing parameters
 transform_params + dataset_params and dataset_path

---
 src/vak/learncurve/frame_classification.py | 49 ++++------------------
 1 file changed, 8 insertions(+), 41 deletions(-)

diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py
index 30690ed3c..fb64be7a5 100644
--- a/src/vak/learncurve/frame_classification.py
+++ b/src/vak/learncurve/frame_classification.py
@@ -17,17 +17,12 @@


 def learning_curve_for_frame_classification_model(
-    model_name: str,
     model_config: dict,
-    dataset_path: str | pathlib.Path,
+    dataset_config: dict,
    batch_size: int,
     num_epochs: int,
     num_workers: int,
     results_path: str | pathlib.Path,
-    train_transform_params: dict | None = None,
-    train_dataset_params: dict | None = None,
-    val_transform_params: dict | None = None,
-    val_dataset_params: dict | None = None,
     post_tfm_kwargs: dict | None = None,
     normalize_spectrograms: bool = True,
     shuffle: bool = True,
@@ -48,12 +43,10 @@

     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
-    dataset_path : str
-        path to where dataset was saved as a csv.
+        Model configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
batch_size : int @@ -66,24 +61,6 @@ def learning_curve_for_frame_classification_model( Argument to torch.DataLoader. results_path : str, pathlib.Path Directory where results will be saved. - train_transform_params: dict, optional - Parameters for training data transform. - Passed as keyword arguments. - Optional, default is None. - train_dataset_params: dict, optional - Parameters for training dataset. - Passed as keyword arguments to - :class:`vak.datasets.frame_classification.WindowDataset`. - Optional, default is None. - val_transform_params: dict, optional - Parameters for validation data transform. - Passed as keyword arguments. - Optional, default is None. - val_dataset_params: dict, optional - Parameters for validation dataset. - Passed as keyword arguments to - :class:`vak.datasets.frame_classification.FramesDataset`. - Optional, default is None. previous_run_path : str, Path Path to directory containing dataset .csv files that represent subsets of training set, created by @@ -206,16 +183,11 @@ def learning_curve_for_frame_classification_model( ) train_frame_classification_model( - model_name, model_config, - dataset_path, + dataset_config, batch_size, num_epochs, num_workers, - train_transform_params, - train_dataset_params, - val_transform_params, - val_dataset_params, results_path=results_path_this_replicate, normalize_spectrograms=normalize_spectrograms, shuffle=shuffle, @@ -261,15 +233,12 @@ def learning_curve_for_frame_classification_model( spect_scaler_path = None eval_frame_classification_model( - model_name, model_config, - dataset_path, + dataset_config, ckpt_path, labelmap_path, results_path_this_replicate, num_workers, - val_transform_params, - val_dataset_params, "test", spect_scaler_path, post_tfm_kwargs, From af971a0d0918dc614c76b4ccb77aff2d13ad5d88 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:38:46 -0400 Subject: [PATCH 166/183] Rewrite vak/learncurve/learncurve.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- src/vak/learncurve/learncurve.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py index 63b27a6cb..39e176109 100644 --- a/src/vak/learncurve/learncurve.py +++ b/src/vak/learncurve/learncurve.py @@ -13,16 +13,11 @@ def learning_curve( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, batch_size: int, num_epochs: int, num_workers: int, - train_transform_params: dict | None = None, - train_dataset_params: dict | None = None, - val_transform_params: dict | None = None, - val_dataset_params: dict | None = None, results_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, normalize_spectrograms: bool = True, @@ -44,14 +39,12 @@ def learning_curve( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - path to where dataset was saved as a csv. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. batch_size : int number of samples per batch presented to models during training. 
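The learning-curve loop that patch 165 slims down keeps its overall shape: one train/eval cycle per (training-set duration, replicate) pair. A runnable schematic with stand-in functions; the durations and directory naming here are illustrative, not taken from vak:

from pathlib import Path

def train_model(results_path: Path) -> None:  # stand-in for train_frame_classification_model
    print("train ->", results_path)

def eval_model(results_path: Path) -> None:  # stand-in for eval_frame_classification_model
    print("eval  ->", results_path)

to_do = [(30, 1), (30, 2), (60, 1), (60, 2)]  # (train_dur, replicate_num) pairs, made up
results_path = Path("results")
for train_dur, replicate_num in to_do:
    results_path_this_replicate = (
        results_path / f"train_dur_{train_dur}s" / f"replicate_{replicate_num}"
    )
    train_model(results_path_this_replicate)
    eval_model(results_path_this_replicate)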
num_epochs : int @@ -122,16 +115,11 @@ def learning_curve( ) from e if model_family == "FrameClassificationModel": learning_curve_for_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, batch_size=batch_size, num_epochs=num_epochs, num_workers=num_workers, - train_transform_params=train_transform_params, - train_dataset_params=train_dataset_params, - val_transform_params=val_transform_params, - val_dataset_params=val_dataset_params, results_path=results_path, post_tfm_kwargs=post_tfm_kwargs, normalize_spectrograms=normalize_spectrograms, From 22571445dec35f05537701f0d4b1915761e5cb01 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:39:01 -0400 Subject: [PATCH 167/183] Rewrite cli.learncurve to pass model_config and dataset_config into learning_curve.learncurve, remove dataset_path/transform_params/datset_params arguments --- src/vak/cli/learncurve.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index 810492b28..2decc5cd8 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -53,16 +53,11 @@ def learning_curve(toml_path): ) learncurve.learning_curve( - model_name=cfg.learncurve.model.name, model_config=cfg.learncurve.model.asdict(), - dataset_path=cfg.learncurve.dataset.path, + dataset_config=cfg.learncurve.dataset.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, - train_transform_params=cfg.learncurve.train_transform_params, - train_dataset_params=cfg.learncurve.train_dataset_params, - val_transform_params=cfg.learncurve.val_transform_params, - val_dataset_params=cfg.learncurve.val_dataset_params, results_path=results_path, post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, From cd60eab2cc9a4c559acb5f67ef11d34e8321fea1 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:50:25 -0400 Subject: [PATCH 168/183] Rewrite vak/predict/frame_classification.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- src/vak/predict/frame_classification.py | 145 ++++++++++++------------ 1 file changed, 72 insertions(+), 73 deletions(-) diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index ee09833f4..79db38221 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -23,14 +23,11 @@ def predict_with_frame_classification_model( - model_name: str, model_config: dict, - dataset_path, + dataset_config: dict, checkpoint_path, labelmap_path, num_workers=2, - transform_params: dict | None = None, - dataset_params: dict | None = None, timebins_key="t", spect_scaler_path=None, device=None, @@ -40,75 +37,66 @@ def predict_with_frame_classification_model( majority_vote=False, save_net_outputs=False, ): - """Make predictions on a dataset with a trained model. + """Make predictions on a dataset with a trained + :class:`~vak.models.FrameClassificationModel`. Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict - Model configuration in a ``dict``, - as loaded from a .toml file, - and used by the model method ``from_config``. - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. 
- checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model - labelmap_path : str - path to 'labelmap.json' file. - num_workers : int - Number of processes to use for parallel loading of data. - Argument to torch.DataLoader. Default is 2. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. - spect_key : str - key for accessing spectrogram in files. Default is 's'. - timebins_key : str - key for accessing vector of time bins in files. Default is 't'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. - spect_scaler_path : str - path to a saved SpectScaler object used to normalize spectrograms. - If spectrograms were normalized and this is not provided, will give - incorrect results. - annot_csv_filename : str - name of .csv file containing predicted annotations. - Default is None, in which case the name of the dataset .csv - is used, with '.annot.csv' appended to it. - output_dir : str, Path - path to location where .csv containing predicted annotation - should be saved. Defaults to current working directory. - min_segment_dur : float - minimum duration of segment, in seconds. If specified, then - any segment with a duration less than min_segment_dur is - removed from lbl_tb. Default is None, in which case no - segments are removed. - majority_vote : bool - if True, transform segments containing multiple labels - into segments with a single label by taking a "majority vote", - i.e. assign all time bins in the segment the most frequently - occurring label in the segment. This transform can only be - applied if the labelmap contains an 'unlabeled' label, - because unlabeled segments makes it possible to identify - the labeled segments. Default is False. + Model configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`. + dataset_config: dict + Dataset configuration in a :class:`dict`. + Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`. + checkpoint_path : str + path to directory with checkpoint files saved by Torch, to reload model + labelmap_path : str + path to 'labelmap.json' file. + num_workers : int + Number of processes to use for parallel loading of data. + Argument to torch.DataLoader. Default is 2. + spect_key : str + key for accessing spectrogram in files. Default is 's'. + timebins_key : str + key for accessing vector of time bins in files. Default is 't'. + device : str + Device on which to work with model + data. + Defaults to 'cuda' if torch.cuda.is_available is True. + spect_scaler_path : str + path to a saved SpectScaler object used to normalize spectrograms. + If spectrograms were normalized and this is not provided, will give + incorrect results. + annot_csv_filename : str + name of .csv file containing predicted annotations. + Default is None, in which case the name of the dataset .csv + is used, with '.annot.csv' appended to it. + output_dir : str, Path + path to location where .csv containing predicted annotation + should be saved. Defaults to current working directory. + min_segment_dur : float + minimum duration of segment, in seconds. If specified, then + any segment with a duration less than min_segment_dur is + removed from lbl_tb. Default is None, in which case no + segments are removed. 
+ majority_vote : bool + if True, transform segments containing multiple labels + into segments with a single label by taking a "majority vote", + i.e. assign all time bins in the segment the most frequently + occurring label in the segment. This transform can only be + applied if the labelmap contains an 'unlabeled' label, + because unlabeled segments makes it possible to identify + the labeled segments. Default is False. save_net_outputs : bool - if True, save 'raw' outputs of neural networks - before they are converted to annotations. Default is False. - Typically the output will be "logits" - to which a softmax transform might be applied. - For each item in the dataset--each row in the `dataset_path` .csv-- - the output will be saved in a separate file in `output_dir`, - with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a - spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, - and the network is `TweetyNet`, then the net output file - will be `gy6or6_032312_081416.tweetynet.output.npz`. + if True, save 'raw' outputs of neural networks + before they are converted to annotations. Default is False. + Typically the output will be "logits" + to which a softmax transform might be applied. + For each item in the dataset--each row in the `dataset_path` .csv-- + the output will be saved in a separate file in `output_dir`, + with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a + spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, + and the network is `TweetyNet`, then the net output file + will be `gy6or6_032312_081416.tweetynet.output.npz`. """ for path, path_name in zip( (checkpoint_path, labelmap_path, spect_scaler_path), @@ -120,7 +108,7 @@ def predict_with_frame_classification_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -147,9 +135,21 @@ def predict_with_frame_classification_model( logger.info("Not loading SpectScaler, no path was specified") spect_standardizer = None - if transform_params is None: - transform_params = {} - transform_params.update({"spect_standardizer": spect_standardizer}) + model_name = model_config["name"] + # TODO: move this into datapipe once each datapipe uses a fixed set of transforms + # that will require adding `spect_standardizer`` as a parameter to the datapipe, + # maybe rename to `frames_standardizer`? 
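As a toy illustration of the majority-vote clean-up this docstring describes (a sketch of the concept only, not vak's actual post-processing transform):

import numpy as np

def majority_vote(lbl_tb, unlabeled=0):
    # assign every time bin in a labeled segment the segment's most frequent
    # label; segments are runs of bins bounded by the 'unlabeled' label
    out = lbl_tb.copy()
    start = None
    for i, lbl in enumerate(np.append(lbl_tb, unlabeled)):
        if lbl != unlabeled and start is None:
            start = i
        elif lbl == unlabeled and start is not None:
            vals, counts = np.unique(lbl_tb[start:i], return_counts=True)
            out[start:i] = vals[counts.argmax()]
            start = None
    return out

print(majority_vote(np.array([0, 1, 1, 2, 1, 0])))  # [0 1 1 1 1 0]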
+    try:
+        window_size = dataset_config["params"]["window_size"]
+    except KeyError as e:
+        raise KeyError(
+            f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table "
+            f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}"
+        ) from e
+    transform_params = {
+        "spect_standardizer": spect_standardizer,
+        "window_size": window_size,
+    }
     item_transform = transforms.defaults.get_default_transform(
         model_name, "predict", transform_params
     )
@@ -172,7 +172,6 @@
         dataset_path=dataset_path,
         split="predict",
         item_transform=item_transform,
-        **dataset_params,
     )

     pred_loader = torch.utils.data.DataLoader(

From 571a9c6b845914e7138dcd4082290f7d32b0bf9e Mon Sep 17 00:00:00 2001
From: David Nicholson
Date: Sat, 4 May 2024 20:50:39 -0400
Subject: [PATCH 169/183] Rewrite vak/predict/parametric_umap.py to use
 model_config and dataset_config parameters, removing parameters
 transform_params + dataset_params and dataset_path

---
 src/vak/predict/parametric_umap.py | 71 ++++++++++++++----------------
 1 file changed, 34 insertions(+), 37 deletions(-)

diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py
index b83c975df..bcc9250e9 100644
--- a/src/vak/predict/parametric_umap.py
+++ b/src/vak/predict/parametric_umap.py
@@ -18,9 +18,8 @@


 def predict_with_parametric_umap_model(
-    model_name: str,
     model_config: dict,
-    dataset_path,
+    dataset_config: dict,
     checkpoint_path,
     num_workers=2,
     transform_params: dict | None = None,
     dataset_params: dict | None = None,
     timebins_key="t",
     device=None,
     output_dir=None,
 ):
-    """Make predictions on a dataset with a trained model.
+    """Make predictions on a dataset with a trained
+    :class:`vak.models.ParametricUMAPModel`.

     Parameters
     ----------
-    model_name : str
-        Model name, must be one of vak.models.registry.MODEL_NAMES.
     model_config : dict
-        Model configuration in a ``dict``,
-        as loaded from a .toml file,
-        and used by the model method ``from_config``.
-    dataset_path : str
-        Path to dataset, e.g., a csv file generated by running ``vak prep``.
-    checkpoint_path : str
-        path to directory with checkpoint files saved by Torch, to reload model
-    num_workers : int
-        Number of processes to use for parallel loading of data.
-        Argument to torch.DataLoader. Default is 2.
+        Model configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.ModelConfig.asdict`.
+    dataset_config: dict
+        Dataset configuration in a :class:`dict`.
+        Can be obtained by calling :meth:`vak.config.DatasetConfig.asdict`.
+    checkpoint_path : str
+        path to directory with checkpoint files saved by Torch, to reload model
+    num_workers : int
+        Number of processes to use for parallel loading of data.
+        Argument to torch.DataLoader. Default is 2.
     transform_params: dict, optional
         Parameters for data transform.
         Passed as keyword arguments.
         Optional, default is None.
     dataset_params: dict, optional
         Parameters for dataset.
         Passed as keyword arguments.
         Optional, default is None.
-    timebins_key : str
-        key for accessing vector of time bins in files. Default is 't'.
-    device : str
-        Device on which to work with model + data.
-        Defaults to 'cuda' if torch.cuda.is_available is True.
-    annot_csv_filename : str
-        name of .csv file containing predicted annotations.
-        Default is None, in which case the name of the dataset .csv
-        is used, with '.annot.csv' appended to it.
- output_dir : str, Path - path to location where .csv containing predicted annotation - should be saved. Defaults to current working directory. + timebins_key : str + key for accessing vector of time bins in files. Default is 't'. + device : str + Device on which to work with model + data. + Defaults to 'cuda' if torch.cuda.is_available is True. + annot_csv_filename : str + name of .csv file containing predicted annotations. + Default is None, in which case the name of the dataset .csv + is used, with '.annot.csv' appended to it. + output_dir : str, Path + path to location where .csv containing predicted annotation + should be saved. Defaults to current working directory. """ for path, path_name in zip( (checkpoint_path,), @@ -77,7 +75,7 @@ def predict_with_parametric_umap_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -103,12 +101,13 @@ def predict_with_parametric_umap_model( device = get_default_device() # ---------------- load data for prediction ------------------------------------------------------------------------ - if transform_params is None: - transform_params = {} - if "padding" not in transform_params and model_name == "ConvEncoderUMAP": - padding = models.convencoder_umap.get_default_padding(metadata.shape) - transform_params["padding"] = padding - + model_name = model_config["name"] + # TODO: fix this when we build transforms into datasets + transform_params = { + "padding": dataset_config["params"].get( + "padding", models.convencoder_umap.get_default_padding(metadata.shape) + ) + } item_transform = transforms.defaults.get_default_transform( model_name, "predict", transform_params ) @@ -118,13 +117,11 @@ def predict_with_parametric_umap_model( f"loading dataset to predict from csv path: {dataset_csv_path}" ) - if dataset_params is None: - dataset_params = {} pred_dataset = ParametricUMAPDataset.from_dataset_path( dataset_path=dataset_path, split="predict", transform=item_transform, - **dataset_params, + **dataset_config["params"], ) pred_loader = torch.utils.data.DataLoader( From aba3353ab2c7101d9b23db02fccaedf64b1df56b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 20:54:13 -0400 Subject: [PATCH 170/183] Rewrite vak/predict/predict.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- src/vak/predict/predict_.py | 125 ++++++++++++++++-------------------- 1 file changed, 55 insertions(+), 70 deletions(-) diff --git a/src/vak/predict/predict_.py b/src/vak/predict/predict_.py index 5ada31c2f..60373b11d 100644 --- a/src/vak/predict/predict_.py +++ b/src/vak/predict/predict_.py @@ -15,14 +15,11 @@ def predict( - model_name: str, model_config: dict, - dataset_path: str | pathlib.Path, + dataset_config: dict, checkpoint_path: str | pathlib.Path, labelmap_path: str | pathlib.Path, num_workers: int = 2, - transform_params: dict | None = None, - dataset_params: dict | None = None, timebins_key: str = "t", spect_scaler_path: str | pathlib.Path | None = None, device: str | None = None, @@ -36,72 +33,62 @@ def predict( Parameters ---------- - model_name : str - Model name, must be one of vak.models.registry.MODEL_NAMES. 
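One detail of the padding logic in patch 169 worth noting: ``dict.get`` always evaluates its second argument, so ``get_default_padding`` runs even when the config already supplies a padding value. If that computation were ever expensive or side-effecting, the lazier form looks like this (a sketch reusing the names from that diff):

padding = dataset_config["params"].get("padding")
if padding is None:  # only compute the default when it is actually needed
    padding = models.convencoder_umap.get_default_padding(metadata.shape)
transform_params = {"padding": padding}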
model_config : dict Model configuration in a ``dict``, as loaded from a .toml file, and used by the model method ``from_config``. - dataset_path : str - Path to dataset, e.g., a csv file generated by running ``vak prep``. - checkpoint_path : str - path to directory with checkpoint files saved by Torch, to reload model - labelmap_path : str - path to 'labelmap.json' file. - window_size : int - size of windows taken from spectrograms, in number of time bins, - shown to neural networks - num_workers : int - Number of processes to use for parallel loading of data. - Argument to torch.DataLoader. Default is 2. - transform_params: dict, optional - Parameters for data transform. - Passed as keyword arguments. - Optional, default is None. - dataset_params: dict, optional - Parameters for dataset. - Passed as keyword arguments. - Optional, default is None. - timebins_key : str - key for accessing vector of time bins in files. Default is 't'. - device : str - Device on which to work with model + data. - Defaults to 'cuda' if torch.cuda.is_available is True. - spect_scaler_path : str - path to a saved SpectScaler object used to normalize spectrograms. - If spectrograms were normalized and this is not provided, will give - incorrect results. - annot_csv_filename : str - name of .csv file containing predicted annotations. - Default is None, in which case the name of the dataset .csv - is used, with '.annot.csv' appended to it. - output_dir : str, Path - path to location where .csv containing predicted annotation - should be saved. Defaults to current working directory. - min_segment_dur : float - minimum duration of segment, in seconds. If specified, then - any segment with a duration less than min_segment_dur is - removed from lbl_tb. Default is None, in which case no - segments are removed. - majority_vote : bool - if True, transform segments containing multiple labels - into segments with a single label by taking a "majority vote", - i.e. assign all time bins in the segment the most frequently - occurring label in the segment. This transform can only be - applied if the labelmap contains an 'unlabeled' label, - because unlabeled segments makes it possible to identify - the labeled segments. Default is False. + dataset_path : str + Path to dataset, e.g., a csv file generated by running ``vak prep``. + checkpoint_path : str + path to directory with checkpoint files saved by Torch, to reload model + labelmap_path : str + path to 'labelmap.json' file. + window_size : int + size of windows taken from spectrograms, in number of time bins, + shown to neural networks + num_workers : int + Number of processes to use for parallel loading of data. + Argument to torch.DataLoader. Default is 2. + timebins_key : str + key for accessing vector of time bins in files. Default is 't'. + device : str + Device on which to work with model + data. + Defaults to 'cuda' if torch.cuda.is_available is True. + spect_scaler_path : str + path to a saved SpectScaler object used to normalize spectrograms. + If spectrograms were normalized and this is not provided, will give + incorrect results. + annot_csv_filename : str + name of .csv file containing predicted annotations. + Default is None, in which case the name of the dataset .csv + is used, with '.annot.csv' appended to it. + output_dir : str, Path + path to location where .csv containing predicted annotation + should be saved. Defaults to current working directory. + min_segment_dur : float + minimum duration of segment, in seconds. 
If specified, then + any segment with a duration less than min_segment_dur is + removed from lbl_tb. Default is None, in which case no + segments are removed. + majority_vote : bool + if True, transform segments containing multiple labels + into segments with a single label by taking a "majority vote", + i.e. assign all time bins in the segment the most frequently + occurring label in the segment. This transform can only be + applied if the labelmap contains an 'unlabeled' label, + because unlabeled segments makes it possible to identify + the labeled segments. Default is False. save_net_outputs : bool - if True, save 'raw' outputs of neural networks - before they are converted to annotations. Default is False. - Typically the output will be "logits" - to which a softmax transform might be applied. - For each item in the dataset--each row in the `dataset_path` .csv-- - the output will be saved in a separate file in `output_dir`, - with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a - spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, - and the network is `TweetyNet`, then the net output file - will be `gy6or6_032312_081416.tweetynet.output.npz`. + If True, save 'raw' outputs of neural networks + before they are converted to annotations. Default is False. + Typically the output will be "logits" + to which a softmax transform might be applied. + For each item in the dataset--each row in the `dataset_path` .csv-- + the output will be saved in a separate file in `output_dir`, + with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a + spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, + and the network is `TweetyNet`, then the net output file + will be `gy6or6_032312_081416.tweetynet.output.npz`. 
""" for path, path_name in zip( (checkpoint_path, labelmap_path, spect_scaler_path), @@ -113,7 +100,7 @@ def predict( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_path) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -132,6 +119,7 @@ def predict( if device is None: device = get_default_device() + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: @@ -140,14 +128,11 @@ def predict( ) from e if model_family == "FrameClassificationModel": predict_with_frame_classification_model( - model_name=model_name, model_config=model_config, - dataset_path=dataset_path, + dataset_config=dataset_config, checkpoint_path=checkpoint_path, labelmap_path=labelmap_path, num_workers=num_workers, - transform_params=transform_params, - dataset_params=dataset_params, timebins_key=timebins_key, spect_scaler_path=spect_scaler_path, device=device, From 3e237b480ccbb98b189b54b6d816e1aedcadccb1 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 21:06:02 -0400 Subject: [PATCH 171/183] Fix dataset_path -> dataset_config[path] and add missing variable model_name in src/vak/learncurve/learncurve.py --- src/vak/learncurve/learncurve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/learncurve/learncurve.py b/src/vak/learncurve/learncurve.py index 39e176109..0b6e443bf 100644 --- a/src/vak/learncurve/learncurve.py +++ b/src/vak/learncurve/learncurve.py @@ -101,12 +101,13 @@ def learning_curve( Default is None, in which case training only stops after the specified number of epochs. """ # ---------------- pre-conditions ---------------------------------------------------------------------------------- - dataset_path = expanded_user_path(dataset_path) + dataset_path = expanded_user_path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: From 9eb7c0c2f5a4d49b9e2f49983f8a2ea65047c89d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 21:06:19 -0400 Subject: [PATCH 172/183] Fix dataset_path -> dataset_config[path] and add missing variable model_name in src/vak/learncurve/frame_classification.py --- src/vak/learncurve/frame_classification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py index fb64be7a5..76364679d 100644 --- a/src/vak/learncurve/frame_classification.py +++ b/src/vak/learncurve/frame_classification.py @@ -107,7 +107,7 @@ def learning_curve_for_frame_classification_model( Default is None, in which case training only stops after the specified number of epochs. 
""" # ---------------- pre-conditions ---------------------------------------------------------------------------------- - dataset_path = expanded_user_path(dataset_path) + dataset_path = expanded_user_path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -158,6 +158,7 @@ def learning_curve_for_frame_classification_model( # ---- main loop that creates "learning curve" --------------------------------------------------------------------- logger.info("Starting training for learning curve.") + model_name = model_config["name"] # used below when getting checkpoint path, etc for train_dur, replicate_num in to_do: logger.info( f"Training model with training set of size: {train_dur}s, replicate number {replicate_num}.", From 32d9604c9d0c11eb7d6eb74fc1b9952418462d72 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 21:18:22 -0400 Subject: [PATCH 173/183] Rewrite vak/cli/predict.py to use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- src/vak/cli/predict.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/vak/cli/predict.py b/src/vak/cli/predict.py index 0625de613..01c0e2612 100644 --- a/src/vak/cli/predict.py +++ b/src/vak/cli/predict.py @@ -43,14 +43,11 @@ def predict(toml_path): ) predict_module.predict( - model_name=cfg.predict.model.name, model_config=cfg.predict.model.asdict(), - dataset_path=cfg.predict.dataset.path, + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, From 87e1a790660f4a2e2572815150a41089532f9ce4 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 21:25:09 -0400 Subject: [PATCH 174/183] Remove non-existent dataset_params variable in vak/predict/frame_classification.py --- src/vak/predict/frame_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index 79db38221..d65924490 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -166,8 +166,8 @@ def predict_with_frame_classification_model( logger.info( f"loading dataset to predict from csv path: {dataset_csv_path}" ) - if dataset_params is None: - dataset_params = {} + + # TODO: fix this when we build transforms into datasets; pass in `window_size` here pred_dataset = FramesDataset.from_dataset_path( dataset_path=dataset_path, split="predict", From bfb4a9b181dc15a7e862aed099790613e9bcbea6 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:15:28 -0400 Subject: [PATCH 175/183] Fix unit tests for DatasetConfig to test 'params' attribute gets handled correctly --- tests/test_config/test_dataset.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/test_config/test_dataset.py b/tests/test_config/test_dataset.py index 690882b14..ecc9fed0d 100644 --- a/tests/test_config/test_dataset.py +++ b/tests/test_config/test_dataset.py @@ -60,6 +60,15 @@ def test_init(self, path, splits_path, name): 'path' 
:'~/user/prepped/dataset', 'splits_path': 'splits/replicate-1.json' }, + { + 'path' :'~/user/prepped/dataset', + 'params': {'window_size': 2000} + }, + { + 'name' : 'BioSoundSegBench', + 'path' :'~/user/prepped/dataset', + 'params': {'window_size': 2000}, + }, ] ) def test_from_config_dict(self, config_dict): @@ -74,6 +83,10 @@ def test_from_config_dict(self, config_dict): assert dataset_config.name == config_dict['name'] else: assert dataset_config.name is None + if 'params' in config_dict: + assert dataset_config.params == config_dict['params'] + else: + assert dataset_config.params == {} @pytest.mark.parametrize( 'config_dict', @@ -90,6 +103,15 @@ def test_from_config_dict(self, config_dict): 'path' :'~/user/prepped/dataset', 'splits_path': 'splits/replicate-1.json' }, + { + 'path' :'~/user/prepped/dataset', + 'params': {'window_size': 2000} + }, + { + 'name' : 'BioSoundSegBench', + 'path' :'~/user/prepped/dataset', + 'params': {'window_size': 2000}, + }, ] ) def test_asdict(self, config_dict): @@ -105,4 +127,7 @@ def test_asdict(self, config_dict): else: assert dataset_config_as_dict[key] == config_dict[key] else: - assert dataset_config_as_dict[key] is None + if key == 'params': + assert dataset_config_as_dict[key] == {} + else: + assert dataset_config_as_dict[key] is None From d70886aaaec3a926da417479c87839d52981de3a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:20:02 -0400 Subject: [PATCH 176/183] Remove train/val_dataset_params and train/val_transform_params from test cases we parametrize with in tests/test_config/ --- tests/test_config/test_eval.py | 18 ------------------ tests/test_config/test_learncurve.py | 10 ---------- tests/test_config/test_predict.py | 15 --------------- tests/test_config/test_train.py | 10 +--------- 4 files changed, 1 insertion(+), 52 deletions(-) diff --git a/tests/test_config/test_eval.py b/tests/test_config/test_eval.py index a83a0aca3..de0ac681e 100644 --- a/tests/test_config/test_eval.py +++ b/tests/test_config/test_eval.py @@ -20,9 +20,6 @@ class TestEval: 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 }, - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -67,9 +64,6 @@ def test_init(self, config_dict): 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 }, - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -120,9 +114,6 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 }, - 'transform_params': { - 'window_size': 88 - }, 'dataset': { 'path': '~/some/path/I/made/up/for/now' }, @@ -142,9 +133,6 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 }, - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -176,9 +164,6 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 }, - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -213,9 +198,6 @@ def test_from_config_dict_with_real_config(self, a_generated_eval_config_dict): 'post_tfm_kwargs': { 'majority_vote': True, 'min_segment_dur': 0.02 }, - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { diff --git a/tests/test_config/test_learncurve.py 
b/tests/test_config/test_learncurve.py index 581e04c32..6d2d65270 100644 --- a/tests/test_config/test_learncurve.py +++ b/tests/test_config/test_learncurve.py @@ -20,8 +20,6 @@ class TestLearncurveConfig: 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { @@ -68,8 +66,6 @@ def test_init(self, config_dict): 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { @@ -125,8 +121,6 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'dataset': { 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' } @@ -146,8 +140,6 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { @@ -181,8 +173,6 @@ def test_from_config_dict_with_real_config(self, a_generated_learncurve_config_d 'num_workers': 16, 'device': 'cuda', 'post_tfm_kwargs': {'majority_vote': True, 'min_segment_dur': 0.02}, - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { diff --git a/tests/test_config/test_predict.py b/tests/test_config/test_predict.py index 9603e0df4..8d81dcf07 100644 --- a/tests/test_config/test_predict.py +++ b/tests/test_config/test_predict.py @@ -18,9 +18,6 @@ class TestPredictConfig: 'device': 'cuda', 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -63,9 +60,6 @@ def test_init(self, config_dict): 'device': 'cuda', 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -113,9 +107,6 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict 'device': 'cuda', 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -149,9 +140,6 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict 'device': 'cuda', 'output_dir': 
'./tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', - 'transform_params': { - 'window_size': 88 - }, 'model': { 'TweetyNet': { 'network': { @@ -182,9 +170,6 @@ def test_from_config_dict_with_real_config(self, a_generated_predict_config_dict 'device': 'cuda', 'output_dir': './tests/data_for_tests/generated/results/predict/audio_cbin_annot_notmat/TweetyNet', 'annot_csv_filename': 'bl26lb16.041912.annot.csv', - 'transform_params': { - 'window_size': 88 - }, 'dataset': { 'path': '~/some/path/I/made/up/for/now' }, diff --git a/tests/test_config/test_train.py b/tests/test_config/test_train.py index 229542410..e5a3127da 100644 --- a/tests/test_config/test_train.py +++ b/tests/test_config/test_train.py @@ -19,8 +19,6 @@ class TestTrainConfig: 'num_workers': 16, 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { @@ -66,8 +64,6 @@ def test_init(self, config_dict): 'num_workers': 16, 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { @@ -118,8 +114,6 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): 'num_workers': 16, 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'dataset': { 'path': 'tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/TweetyNet/032312-vak-frame-classification-dataset-generated-240502_234819' } @@ -137,8 +131,6 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): 'num_workers': 16, 'device': 'cuda', 'root_results_dir': './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/TweetyNet', - 'train_dataset_params': {'window_size': 88}, - 'val_transform_params': {'window_size': 88}, 'model': { 'TweetyNet': { 'network': { @@ -164,4 +156,4 @@ def test_from_config_dict_with_real_config(self, a_generated_train_config_dict): ) def test_from_config_dict_raises(self, config_dict, expected_exception): with pytest.raises(expected_exception): - vak.config.TrainConfig.from_config_dict(config_dict) \ No newline at end of file + vak.config.TrainConfig.from_config_dict(config_dict) From 401f1a333f5b249d3932016cc9f5fd631f66780e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:28:11 -0400 Subject: [PATCH 177/183] Use DatasetConfig.params attribute where we need to in tests/test_datasets --- .../test_frame_classification/test_frames_dataset.py | 5 ++++- .../test_frame_classification/test_window_dataset.py | 7 +++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_datasets/test_frame_classification/test_frames_dataset.py b/tests/test_datasets/test_frame_classification/test_frames_dataset.py index c165669b7..f71c7f9fb 100644 --- a/tests/test_datasets/test_frame_classification/test_frames_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_frames_dataset.py @@ -22,8 +22,11 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo cfg = 
vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) + transform_kwargs = { + "window_size": cfg.eval.dataset.params["window_size"] + } item_transform = vak.transforms.defaults.get_default_transform( - model_name, config_type, cfg.eval.transform_params + model_name, config_type, transform_kwargs ) dataset = vak.datasets.frame_classification.FramesDataset.from_dataset_path( diff --git a/tests/test_datasets/test_frame_classification/test_window_dataset.py b/tests/test_datasets/test_frame_classification/test_window_dataset.py index 238f7de12..430917d2a 100644 --- a/tests/test_datasets/test_frame_classification/test_window_dataset.py +++ b/tests/test_datasets/test_frame_classification/test_window_dataset.py @@ -23,15 +23,14 @@ def test_from_dataset_path(self, config_type, model_name, audio_format, spect_fo cfg = vak.config.Config.from_toml_path(toml_path) cfg_command = getattr(cfg, config_type) - transform, target_transform = vak.transforms.defaults.get_default_transform( + transform = vak.transforms.defaults.get_default_transform( model_name, config_type, transform_kwargs ) dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( dataset_path=cfg_command.dataset.path, split=split, - window_size=cfg_command.train_dataset_params['window_size'], - transform=transform, - target_transform=target_transform, + window_size=cfg_command.dataset.params['window_size'], + item_transform=transform, ) assert isinstance(dataset, vak.datasets.frame_classification.WindowDataset) From 0ae33a8db1c520a3cc3b0a52a2b6ecbd6f80eb12 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:29:55 -0400 Subject: [PATCH 178/183] Fix method name ModelConfig.to_dict -> asdict in tests/ --- tests/test_eval/test_eval.py | 2 +- tests/test_eval/test_frame_classification.py | 6 +++--- tests/test_eval/test_parametric_umap.py | 6 +++--- tests/test_learncurve/test_frame_classification.py | 4 ++-- tests/test_models/test_base.py | 2 +- tests/test_models/test_frame_classification_model.py | 2 +- tests/test_models/test_parametric_umap_model.py | 2 +- tests/test_predict/test_frame_classification.py | 6 +++--- tests/test_predict/test_predict.py | 2 +- tests/test_train/test_frame_classification.py | 8 ++++---- tests/test_train/test_parametric_umap.py | 6 +++--- tests/test_train/test_train.py | 2 +- 12 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index e2beb1b7e..d2d19bea2 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -50,7 +50,7 @@ def test_eval( with mock.patch(eval_function_to_mock, autospec=True) as mock_eval_function: vak.eval.eval( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index b573fe85c..5b6c66a82 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -71,7 +71,7 @@ def test_eval_frame_classification_model( vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, @@ -128,7 +128,7 @@ def 
test_eval_frame_classification_model_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, @@ -185,7 +185,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.eval.frame_classification.eval_frame_classification_model( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index 69fa2cf44..b055920bc 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -49,7 +49,7 @@ def test_eval_parametric_umap_model( vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, @@ -100,7 +100,7 @@ def test_eval_frame_classification_model_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, @@ -155,7 +155,7 @@ def test_eval_frame_classification_model_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.eval.parametric_umap.eval_parametric_umap_model( model_name=cfg.eval.model.name, - model_config=cfg.eval.model.to_dict(), + model_config=cfg.eval.model.asdict(), dataset_path=cfg.eval.dataset.path, checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index 9f5d87804..c23925ac0 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -68,7 +68,7 @@ def test_learning_curve_for_frame_classification_model( vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_name=cfg.learncurve.model.name, - model_config=cfg.learncurve.model.to_dict(), + model_config=cfg.learncurve.model.asdict(), dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, @@ -122,7 +122,7 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, with pytest.raises(NotADirectoryError): vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( model_name=cfg.learncurve.model.name, - model_config=cfg.learncurve.model.to_dict(), + model_config=cfg.learncurve.model.asdict(), dataset_path=cfg.learncurve.dataset.path, batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index f1c389983..bd30f00f9 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -217,7 +217,7 @@ def test_load_state_dict_from_path(self, # network is the one 
thing that has required args # and we also need to use its config from the toml file cfg = vak.config.Config.from_toml_path(train_toml_path) - model_config = cfg.train.model.to_dict() + model_config = cfg.train.model.asdict() network = definition.network(num_classes=len(labelmap), num_input_channels=num_input_channels, num_freqbins=num_freqbins, diff --git a/tests/test_models/test_frame_classification_model.py b/tests/test_models/test_frame_classification_model.py index c69095bae..c66dbcb47 100644 --- a/tests/test_models/test_frame_classification_model.py +++ b/tests/test_models/test_frame_classification_model.py @@ -95,7 +95,7 @@ def test_from_config(self, vak.models.FrameClassificationModel, 'definition', definition, raising=False ) - config = cfg.train.model.to_dict() + config = cfg.train.model.asdict() num_input_channels, num_freqbins = self.MOCK_INPUT_SHAPE[0], self.MOCK_INPUT_SHAPE[1] config["network"].update( diff --git a/tests/test_models/test_parametric_umap_model.py b/tests/test_models/test_parametric_umap_model.py index a933a88b0..eba4f77d1 100644 --- a/tests/test_models/test_parametric_umap_model.py +++ b/tests/test_models/test_parametric_umap_model.py @@ -93,7 +93,7 @@ def test_from_config( vak.models.ParametricUMAPModel, 'definition', definition, raising=False ) - config = cfg.train.model.to_dict() + config = cfg.train.model.asdict() config["network"].update( encoder=dict(input_shape=input_shape) ) diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index 3facf373d..bf1f090ec 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -53,7 +53,7 @@ def test_predict_with_frame_classification_model( vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model.name, - model_config=cfg.predict.model.to_dict(), + model_config=cfg.predict.model.asdict(), dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, @@ -127,7 +127,7 @@ def test_predict_with_frame_classification_model_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model.name, - model_config=cfg.predict.model.to_dict(), + model_config=cfg.predict.model.asdict(), dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, @@ -189,7 +189,7 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.predict.frame_classification.predict_with_frame_classification_model( model_name=cfg.predict.model.name, - model_config=cfg.predict.model.to_dict(), + model_config=cfg.predict.model.asdict(), dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index c445dad49..61c764c66 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -46,7 +46,7 @@ def test_predict( with mock.patch(predict_function_to_mock, autospec=True) as mock_predict_function: vak.predict.predict( model_name=cfg.predict.model.name, - model_config=cfg.predict.model.to_dict(), + model_config=cfg.predict.model.asdict(), dataset_path=cfg.predict.dataset.path, checkpoint_path=cfg.predict.checkpoint_path, 
labelmap_path=cfg.predict.labelmap_path, diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index 4bc1587a3..ef57fa5b3 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -62,7 +62,7 @@ def test_train_frame_classification_model( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, @@ -114,7 +114,7 @@ def test_continue_training( vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, @@ -170,7 +170,7 @@ def test_train_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, @@ -225,7 +225,7 @@ def test_train_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.train.frame_classification.train_frame_classification_model( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index 58ff5fb82..89fb55db4 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -55,7 +55,7 @@ def test_train_parametric_umap_model( vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, @@ -107,7 +107,7 @@ def test_train_parametric_umap_model_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.train.parametric_umap.train_parametric_umap_model( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, @@ -152,7 +152,7 @@ def test_train_parametric_umap_model_raises_not_a_directory( keys_to_change=keys_to_change, ) cfg = vak.config.Config.from_toml_path(toml_path) - model_config = cfg.train.model.to_dict() + model_config = cfg.train.model.asdict() # mock behavior of cli.train, building `results_path` from config option `root_results_dir` results_path = cfg.train.root_results_dir / 'results-dir-timestamp' diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index b511d0e13..4eb513d3c 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -54,7 +54,7 @@ def test_train( with mock.patch(train_function_to_mock, autospec=True) as mock_train_function: vak.train.train( model_name=cfg.train.model.name, - model_config=cfg.train.model.to_dict(), + model_config=cfg.train.model.asdict(), dataset_path=cfg.train.dataset.path, 
batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, From e033b25f00eb3f19758ea4d522a71b7be2838a3c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:39:53 -0400 Subject: [PATCH 179/183] In tests for eval/learncurve/predict/train, use model_config and dataset_config parameters, removing parameters transform_params + dataset_params and dataset_path --- tests/test_eval/test_eval.py | 5 +--- tests/test_eval/test_frame_classification.py | 15 ++-------- tests/test_eval/test_parametric_umap.py | 16 ++--------- .../test_frame_classification.py | 14 ++-------- .../test_predict/test_frame_classification.py | 15 ++-------- tests/test_predict/test_predict.py | 5 +--- tests/test_train/test_frame_classification.py | 28 +++---------------- tests/test_train/test_parametric_umap.py | 21 ++------------ tests/test_train/test_train.py | 7 +---- 9 files changed, 21 insertions(+), 105 deletions(-) diff --git a/tests/test_eval/test_eval.py b/tests/test_eval/test_eval.py index d2d19bea2..7a874afb9 100644 --- a/tests/test_eval/test_eval.py +++ b/tests/test_eval/test_eval.py @@ -49,16 +49,13 @@ def test_eval( with mock.patch(eval_function_to_mock, autospec=True) as mock_eval_function: vak.eval.eval( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, batch_size=cfg.eval.batch_size, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, post_tfm_kwargs=cfg.eval.post_tfm_kwargs, diff --git a/tests/test_eval/test_frame_classification.py b/tests/test_eval/test_frame_classification.py index 5b6c66a82..ee55825a5 100644 --- a/tests/test_eval/test_frame_classification.py +++ b/tests/test_eval/test_frame_classification.py @@ -70,15 +70,12 @@ def test_eval_frame_classification_model( cfg = vak.config.Config.from_toml_path(toml_path) vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, post_tfm_kwargs=post_tfm_kwargs, @@ -127,15 +124,12 @@ def test_eval_frame_classification_model_raises_file_not_found( cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(FileNotFoundError): vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, ) @@ -184,15 +178,12 @@ def test_eval_frame_classification_model_raises_not_a_directory( cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(NotADirectoryError): 
vak.eval.frame_classification.eval_frame_classification_model( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, labelmap_path=cfg.eval.labelmap_path, output_dir=cfg.eval.output_dir, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, spect_scaler_path=cfg.eval.spect_scaler_path, device=cfg.eval.device, ) diff --git a/tests/test_eval/test_parametric_umap.py b/tests/test_eval/test_parametric_umap.py index b055920bc..4c6c7e573 100644 --- a/tests/test_eval/test_parametric_umap.py +++ b/tests/test_eval/test_parametric_umap.py @@ -48,15 +48,12 @@ def test_eval_parametric_umap_model( cfg = vak.config.Config.from_toml_path(toml_path) vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, device=cfg.eval.device, ) @@ -99,15 +96,12 @@ def test_eval_frame_classification_model_raises_file_not_found( cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(FileNotFoundError): vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, device=cfg.eval.device, ) @@ -154,15 +148,11 @@ def test_eval_frame_classification_model_raises_not_a_directory( cfg = vak.config.Config.from_toml_path(toml_path) with pytest.raises(NotADirectoryError): vak.eval.parametric_umap.eval_parametric_umap_model( - model_name=cfg.eval.model.name, model_config=cfg.eval.model.asdict(), - dataset_path=cfg.eval.dataset.path, + dataset_config=cfg.eval.dataset.asdict(), checkpoint_path=cfg.eval.checkpoint_path, output_dir=cfg.eval.output_dir, batch_size=cfg.eval.batch_size, num_workers=cfg.eval.num_workers, - transform_params=cfg.eval.transform_params, - dataset_params=cfg.eval.dataset_params, device=cfg.eval.device, ) - diff --git a/tests/test_learncurve/test_frame_classification.py b/tests/test_learncurve/test_frame_classification.py index c23925ac0..c88f4ebf4 100644 --- a/tests/test_learncurve/test_frame_classification.py +++ b/tests/test_learncurve/test_frame_classification.py @@ -67,16 +67,11 @@ def test_learning_curve_for_frame_classification_model( results_path.mkdir() vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( - model_name=cfg.learncurve.model.name, model_config=cfg.learncurve.model.asdict(), - dataset_path=cfg.learncurve.dataset.path, + dataset_config=cfg.learncurve.dataset.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, - train_transform_params=cfg.learncurve.train_transform_params, - train_dataset_params=cfg.learncurve.train_dataset_params, - val_transform_params=cfg.learncurve.val_transform_params, - 
val_dataset_params=cfg.learncurve.val_dataset_params, results_path=results_path, post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, @@ -121,16 +116,11 @@ def test_learncurve_raises_not_a_directory(dir_option_to_change, with pytest.raises(NotADirectoryError): vak.learncurve.frame_classification.learning_curve_for_frame_classification_model( - model_name=cfg.learncurve.model.name, model_config=cfg.learncurve.model.asdict(), - dataset_path=cfg.learncurve.dataset.path, + dataset_config=cfg.learncurve.dataset.asdict(), batch_size=cfg.learncurve.batch_size, num_epochs=cfg.learncurve.num_epochs, num_workers=cfg.learncurve.num_workers, - train_transform_params=cfg.learncurve.train_transform_params, - train_dataset_params=cfg.learncurve.train_dataset_params, - val_transform_params=cfg.learncurve.val_transform_params, - val_dataset_params=cfg.learncurve.val_dataset_params, results_path=results_path, post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, diff --git a/tests/test_predict/test_frame_classification.py b/tests/test_predict/test_frame_classification.py index bf1f090ec..5b8902f9d 100644 --- a/tests/test_predict/test_frame_classification.py +++ b/tests/test_predict/test_frame_classification.py @@ -52,14 +52,11 @@ def test_predict_with_frame_classification_model( cfg = vak.config.Config.from_toml_path(toml_path) vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model.name, model_config=cfg.predict.model.asdict(), - dataset_path=cfg.predict.dataset.path, + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, @@ -126,14 +123,11 @@ def test_predict_with_frame_classification_model_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model.name, model_config=cfg.predict.model.asdict(), - dataset_path=cfg.predict.dataset.path, + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, @@ -188,14 +182,11 @@ def test_predict_with_frame_classification_model_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.predict.frame_classification.predict_with_frame_classification_model( - model_name=cfg.predict.model.name, model_config=cfg.predict.model.asdict(), - dataset_path=cfg.predict.dataset.path, + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, diff --git a/tests/test_predict/test_predict.py b/tests/test_predict/test_predict.py index 
61c764c66..820613ed4 100644 --- a/tests/test_predict/test_predict.py +++ b/tests/test_predict/test_predict.py @@ -45,14 +45,11 @@ def test_predict( with mock.patch(predict_function_to_mock, autospec=True) as mock_predict_function: vak.predict.predict( - model_name=cfg.predict.model.name, model_config=cfg.predict.model.asdict(), - dataset_path=cfg.predict.dataset.path, + dataset_config=cfg.predict.dataset.asdict(), checkpoint_path=cfg.predict.checkpoint_path, labelmap_path=cfg.predict.labelmap_path, num_workers=cfg.predict.num_workers, - transform_params=cfg.predict.transform_params, - dataset_params=cfg.predict.dataset_params, timebins_key=cfg.prep.spect_params.timebins_key, spect_scaler_path=cfg.predict.spect_scaler_path, device=cfg.predict.device, diff --git a/tests/test_train/test_frame_classification.py b/tests/test_train/test_frame_classification.py index ef57fa5b3..f4e50ef46 100644 --- a/tests/test_train/test_frame_classification.py +++ b/tests/test_train/test_frame_classification.py @@ -61,16 +61,11 @@ def test_train_frame_classification_model( cfg = vak.config.Config.from_toml_path(toml_path) vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, @@ -113,16 +108,11 @@ def test_continue_training( cfg = vak.config.Config.from_toml_path(toml_path) vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, @@ -169,16 +159,11 @@ def test_train_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, @@ -224,16 +209,11 @@ def test_train_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.train.frame_classification.train_frame_classification_model( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - 
dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, diff --git a/tests/test_train/test_parametric_umap.py b/tests/test_train/test_parametric_umap.py index 89fb55db4..509078c21 100644 --- a/tests/test_train/test_parametric_umap.py +++ b/tests/test_train/test_parametric_umap.py @@ -54,16 +54,11 @@ def test_train_parametric_umap_model( cfg = vak.config.Config.from_toml_path(toml_path) vak.train.parametric_umap.train_parametric_umap_model( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, results_path=results_path, shuffle=cfg.train.shuffle, @@ -106,16 +101,11 @@ def test_train_parametric_umap_model_raises_file_not_found( with pytest.raises(FileNotFoundError): vak.train.parametric_umap.train_parametric_umap_model( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, results_path=results_path, shuffle=cfg.train.shuffle, @@ -159,16 +149,11 @@ def test_train_parametric_umap_model_raises_not_a_directory( with pytest.raises(NotADirectoryError): vak.train.parametric_umap.train_parametric_umap_model( - model_name=cfg.train.model, model_config=model_config, - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, results_path=results_path, shuffle=cfg.train.shuffle, diff --git a/tests/test_train/test_train.py b/tests/test_train/test_train.py index 4eb513d3c..b9038007e 100644 --- a/tests/test_train/test_train.py +++ b/tests/test_train/test_train.py @@ -53,16 +53,11 @@ def test_train( with mock.patch(train_function_to_mock, autospec=True) as mock_train_function: vak.train.train( - model_name=cfg.train.model.name, model_config=cfg.train.model.asdict(), - dataset_path=cfg.train.dataset.path, + dataset_config=cfg.train.dataset.asdict(), batch_size=cfg.train.batch_size, num_epochs=cfg.train.num_epochs, num_workers=cfg.train.num_workers, - 
train_transform_params=cfg.train.train_transform_params, - train_dataset_params=cfg.train.train_dataset_params, - val_transform_params=cfg.train.val_transform_params, - val_dataset_params=cfg.train.val_dataset_params, checkpoint_path=cfg.train.checkpoint_path, spect_scaler_path=cfg.train.spect_scaler_path, results_path=results_path, From 9d7ec3e71c08f19754c259449b24d737931d7a06 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:49:08 -0400 Subject: [PATCH 180/183] Fix use of default transform and dataset.params attribute in test_models/test_base.py --- tests/test_models/test_base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py index bd30f00f9..ec9d37ccf 100644 --- a/tests/test_models/test_base.py +++ b/tests/test_models/test_base.py @@ -195,7 +195,7 @@ def test_load_state_dict_from_path(self, # stuff we need just to be able to instantiate network labelmap = vak.common.labels.to_map(train_cfg.prep.labelset, map_unlabeled=True) - transform, target_transform = vak.transforms.defaults.get_default_transform( + item_transform = vak.transforms.defaults.get_default_transform( model_name, "train", transform_kwargs={}, @@ -203,9 +203,8 @@ def test_load_state_dict_from_path(self, train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( dataset_path=train_cfg.train.dataset.path, split="train", - window_size=train_cfg.train.train_dataset_params['window_size'], - transform=transform, - target_transform=target_transform, + window_size=train_cfg.train.dataset.params['window_size'], + item_transform=item_transform, ) input_shape = train_dataset.shape num_input_channels = input_shape[-3] From 33f578b42ceeff2bc5196387ea06063c532f2c65 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:55:50 -0400 Subject: [PATCH 181/183] Fix config snippets in docs --- doc/get_started/autoannotate.md | 78 ++++++++++++++++----------------- doc/reference/config.md | 22 +++------- 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/doc/get_started/autoannotate.md b/doc/get_started/autoannotate.md index eab532e89..37ba9eef9 100644 --- a/doc/get_started/autoannotate.md +++ b/doc/get_started/autoannotate.md @@ -20,10 +20,10 @@ Below is an example of some annotated Bengalese finch song, which is what we'll :::{hint} `vak` has built-in support for widely-used annotation formats. -Even if your data is not annotated with one of these formats, -you can use `vak` by converting your annotations to a simple `.csv` format +Even if your data is not annotated with one of these formats, +you can use `vak` by converting your annotations to a simple `.csv` format that is easy to create with Python libraries like `pandas`. -For more information, please see: +For more information, please see: {ref}`howto-user-annot` ::: @@ -42,39 +42,39 @@ Before going through this tutorial, you'll need to: or [notepad++](https://notepad-plus-plus.org/) 3. Download example data from this dataset: - - one day of birdsong, for training data (click to download) + - one day of birdsong, for training data (click to download) {download}`https://figshare.com/ndownloader/files/41668980` - another day, to use to predict annotations (click to download) {download}`https://figshare.com/ndownloader/files/41668983` - - Be sure to extract the files from these archives! - Please use the program "tar" to extract the archives, + - Be sure to extract the files from these archives! 
+ Please use the program "tar" to extract the archives, on either macOS/Linux or Windows. - Using other programs like WinZIP on Windows + Using other programs like WinZIP on Windows can corrupt the files when extracting them, causing confusing errors. Tar should be available on newer Windows systems - (as described + (as described [here](https://learn.microsoft.com/en-us/virtualization/community/team-blog/2017/20171219-tar-and-curl-come-to-windows)). - - Alternatively you can copy the following command and then - paste it into a terminal to run a Python script - that will download and extract the files for you. + - Alternatively you can copy the following command and then + paste it into a terminal to run a Python script + that will download and extract the files for you. :::{eval-rst} - + .. tabs:: - + .. code-tab:: shell macOS / Linux - + curl -sSL https://raw.githubusercontent.com/vocalpy/vak/main/src/scripts/download_autoannotate_data.py | python3 - - + .. code-tab:: shell Windows - + (Invoke-WebRequest -Uri https://raw.githubusercontent.com/vocalpy/vak/main/src/scripts/download_autoannotate_data.py -UseBasicParsing).Content | py - ::: 4. Download the corresponding configuration files (click to download): {download}`gy6or6_train.toml <../toml/gy6or6_train.toml>`, - {download}`gy6or6_eval.toml <../toml/gy6or6_eval.toml>`, + {download}`gy6or6_eval.toml <../toml/gy6or6_eval.toml>`, and {download}`gy6or6_predict.toml <../toml/gy6or6_predict.toml>` ## Overview @@ -181,7 +181,7 @@ Change the part of the path in capital letters to the actual location on your computer: ```toml -[PREP] +[vak.prep] dataset_type = "frame classification" input_type = "spect" # we change the next line @@ -230,11 +230,11 @@ When you run `prep`, `vak` converts the data from `data_dir` into a special data automatically adds the path to that file to the `[TRAIN]` section of the `config.toml` file, as the option `csv_path`. -You have now prepared a dataset for training a model! -You'll probably have more questions about -how to do this later, -when you start to work with your own data. -When that time comes, please see the how-to page: +You have now prepared a dataset for training a model! +You'll probably have more questions about +how to do this later, +when you start to work with your own data. +When that time comes, please see the how-to page: {ref}`howto-prep-annotate`. For now, let's move on to training a neural network with this dataset. @@ -294,7 +294,7 @@ from that checkpoint later when we predict annotations for new data. (prepare-prediction-dataset)= -An important step when using neural network models is to evaluate the model's performance +An important step when using neural network models is to evaluate the model's performance on a held-out dataset that has never been used during training, often called the "test" set. Here we show you how to evaluate the model we just trained. @@ -356,33 +356,33 @@ This file will also be found in the root `results_{timestamp}` directory. spect_scaler = "/home/users/You/Data/vak_tutorial_data/vak_output/results_{timestamp}/SpectScaler" ``` -The last path you need is actually in the TOML file that we used +The last path you need is actually in the TOML file that we used to train the neural network: `dataset_path`. 
-You should copy that `dataset_path` option exactly as it is -and then paste it at the bottom of the `[EVAL]` table +You should copy that `dataset_path` option exactly as it is +and then paste it at the bottom of the `[EVAL]` table in the configuration file for evaluation. -We do this instead of preparing another dataset, -because we already created a test split when we ran +We do this instead of preparing another dataset, +because we already created a test split when we ran `vak prep` with the training configuration. -This is a good practice, because it helps ensure +This is a good practice, because it helps ensure that we do not mix the training data with the test data; -`vak` makes sure that the data from the `data_dir` option +`vak` makes sure that the data from the `data_dir` option is placed in two separate splits, the train and test splits. -Once you have prepared the configuration file as described, +Once you have prepared the configuration file as described, you can run the following in the terminal: ```shell vak eval gy6o6_eval.toml ``` -You will see output to the console as the network is evaluated. -Notice that for this model we evaluate it *with* and *without* -post-processing transforms that clean up the predictions +You will see output to the console as the network is evaluated. +Notice that for this model we evaluate it *with* and *without* +post-processing transforms that clean up the predictions of the model. -The parameters of the post-processing transform are specified +The parameters of the post-processing transform are specified with the `post_tfm_kwargs` option in the configuration file. -You may find this helpful to understand factors affecting +You may find this helpful to understand factors affecting the performance of your own model. ## 4. Preparing a prediction dataset @@ -400,7 +400,7 @@ Just like before, you're going to modify the `data_dir` option of the This time you'll change it to the path to the directory with the other day of data we downloaded. ```toml -[PREP] +[vak.prep] data_dir = "/home/users/You/Data/vak_tutorial_data/032312" ``` @@ -428,7 +428,7 @@ and then add the path to that file as the option `csv_path` in the `[PREDICT]` s Finally you will use the trained network to predict annotations. This is the part that requires you to find paths to files saved by `vak`. -There's three you need. These are the exact same paths we used above +There's three you need. These are the exact same paths we used above in the configuration file for evaluation, so you can copy them from that file. We explain them again here for completeness. All three paths will be in the `results` directory diff --git a/doc/reference/config.md b/doc/reference/config.md index 687f45e2a..dbe0ec9ba 100644 --- a/doc/reference/config.md +++ b/doc/reference/config.md @@ -19,7 +19,7 @@ for each class. ## Valid section names Following is the set of valid section names: -`{PREP, SPECT_PARAMS, DATALOADER, TRAIN, PREDICT, LEARNCURVE}`. +`{eval, learncurve, predict, prep, train}`. In the code, these names correspond to attributes of the main `Config` class, as shown below. @@ -43,50 +43,42 @@ that are considered valid. Valid options for each section are presented below. (ref-config-prep)= -### `[PREP]` section +### `[vak.prep]` section ```{eval-rst} .. autoclass:: vak.config.prep.PrepConfig ``` (ref-config-spect-params)= -### `[SPECT_PARAMS]` section +### `[vak.prep.spect_params]` section ```{eval-rst} .. 
autoclass:: vak.config.spect_params.SpectParamsConfig ``` -(ref-config-dataloader)= -### `[DATALOADER]` section - -```{eval-rst} -.. autoclass:: vak.config.dataloader.DataLoaderConfig - -``` - (ref-config-train)= -### `[TRAIN]` section +### `[vak.train]` section ```{eval-rst} .. autoclass:: vak.config.train.TrainConfig ``` (ref-config-eval)= -### `[EVAL]` section +### `[vak.eval]` section ```{eval-rst} .. autoclass:: vak.config.eval.EvalConfig ``` (ref-config-predict)= -### `[PREDICT]` section +### `[vak.predict]` section ```{eval-rst} .. autoclass:: vak.config.predict.PredictConfig ``` (ref-config-learncurve)= -### `[LEARNCURVE]` section +### `[vak.learncurve]` section ```{eval-rst} .. autoclass:: vak.config.learncurve.LearncurveConfig From 84c316b84f892d75171e04b43a22f3e0bd442908 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:57:02 -0400 Subject: [PATCH 182/183] Apply linting to src/ --- src/vak/config/dataset.py | 9 ++++----- src/vak/datasets/frame_classification/window_dataset.py | 2 +- src/vak/eval/eval_.py | 4 ++-- src/vak/eval/frame_classification.py | 4 ++-- src/vak/learncurve/frame_classification.py | 4 +++- src/vak/predict/parametric_umap.py | 3 ++- src/vak/train/frame_classification.py | 6 +++--- src/vak/train/train_.py | 4 ++-- 8 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/vak/config/dataset.py b/src/vak/config/dataset.py index dc86ec963..f75c34b73 100644 --- a/src/vak/config/dataset.py +++ b/src/vak/config/dataset.py @@ -38,18 +38,17 @@ class DatasetConfig: name: str | None = field( converter=attr.converters.optional(str), default=None ) - params : dict | None = field( + params: dict | None = field( # we default to an empty dict instead of None # so we can still do **['dataset']['params'] everywhere we do when params are specified - converter=attr.converters.optional(dict), default={} + converter=attr.converters.optional(dict), + default={}, ) @classmethod def from_config_dict(cls, config_dict: dict) -> DatasetConfig: - return cls( - **config_dict - ) + return cls(**config_dict) def asdict(self): """Convert this :class:`DatasetConfig` instance diff --git a/src/vak/datasets/frame_classification/window_dataset.py b/src/vak/datasets/frame_classification/window_dataset.py index c4781f07f..d916a6bcc 100644 --- a/src/vak/datasets/frame_classification/window_dataset.py +++ b/src/vak/datasets/frame_classification/window_dataset.py @@ -281,7 +281,7 @@ def shape(self): tmp_item = self.__getitem__(tmp_x_ind) # used by vak functions that need to determine size of window, # e.g. 
when initializing a neural network model - return tmp_item['frames'].shape + return tmp_item["frames"].shape def _load_frames(self, frames_path): """Helper function that loads "frames", diff --git a/src/vak/eval/eval_.py b/src/vak/eval/eval_.py index d8a13e663..fa1209f1d 100644 --- a/src/vak/eval/eval_.py +++ b/src/vak/eval/eval_.py @@ -93,13 +93,13 @@ def eval( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_config['path']) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) - model_name = model_config['name'] + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 0cc4bb41f..1be124758 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -95,7 +95,7 @@ def eval_frame_classification_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_config['path']) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -143,7 +143,7 @@ def eval_frame_classification_model( ) transform_params = { "spect_standardizer": spect_standardizer, - "window_size": window_size + "window_size": window_size, } item_transform = transforms.defaults.get_default_transform( diff --git a/src/vak/learncurve/frame_classification.py b/src/vak/learncurve/frame_classification.py index 76364679d..8ca2e11b7 100644 --- a/src/vak/learncurve/frame_classification.py +++ b/src/vak/learncurve/frame_classification.py @@ -158,7 +158,9 @@ def learning_curve_for_frame_classification_model( # ---- main loop that creates "learning curve" --------------------------------------------------------------------- logger.info("Starting training for learning curve.") - model_name = model_config["name"] # used below when getting checkpoint path, etc + model_name = model_config[ + "name" + ] # used below when getting checkpoint path, etc for train_dur, replicate_num in to_do: logger.info( f"Training model with training set of size: {train_dur}s, replicate number {replicate_num}.", diff --git a/src/vak/predict/parametric_umap.py b/src/vak/predict/parametric_umap.py index bcc9250e9..4e54336f4 100644 --- a/src/vak/predict/parametric_umap.py +++ b/src/vak/predict/parametric_umap.py @@ -105,7 +105,8 @@ def predict_with_parametric_umap_model( # TODO: fix this when we build transforms into datasets transform_params = { "padding": dataset_config["params"].get( - "padding", models.convencoder_umap.get_default_padding(metadata.shape) + "padding", + models.convencoder_umap.get_default_padding(metadata.shape), ) } item_transform = transforms.defaults.get_default_transform( diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index 99dd93734..da9a13c54 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -127,7 +127,7 @@ def train_frame_classification_model( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_config['path']) + dataset_path = pathlib.Path(dataset_config["path"]) if not 
dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" @@ -209,7 +209,7 @@ def train_frame_classification_model( ) spect_standardizer = None - model_name = model_config['name'] + model_name = model_config["name"] # TODO: move this into datapipe once each datapipe uses a fixed set of transforms # that will require adding `spect_standardizer`` as a parameter to the datapipe, # maybe rename to `frames_standardizer`? @@ -233,7 +233,7 @@ def train_frame_classification_model( split="train", subset=subset, item_transform=train_transform, - **dataset_config['params'], + **dataset_config["params"], ) logger.info( f"Duration of WindowDataset used for training, in seconds: {train_dataset.duration}", diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py index 7d3597179..96926967d 100644 --- a/src/vak/train/train_.py +++ b/src/vak/train/train_.py @@ -130,13 +130,13 @@ def train( f"value for ``{path_name}`` not recognized as a file: {path}" ) - dataset_path = pathlib.Path(dataset_config['path']) + dataset_path = pathlib.Path(dataset_config["path"]) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) - model_name = model_config['name'] + model_name = model_config["name"] try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: From bc0b48752d11e44a7ff96cf0d1e2425a3613fc72 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 4 May 2024 22:58:54 -0400 Subject: [PATCH 183/183] Raise 'from e' with errors in eval/predict/train/frame_classification modules --- src/vak/eval/frame_classification.py | 2 +- src/vak/predict/frame_classification.py | 2 +- src/vak/train/frame_classification.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vak/eval/frame_classification.py b/src/vak/eval/frame_classification.py index 1be124758..9757287e8 100644 --- a/src/vak/eval/frame_classification.py +++ b/src/vak/eval/frame_classification.py @@ -140,7 +140,7 @@ def eval_frame_classification_model( raise KeyError( f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table " f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}" - ) + ) from e transform_params = { "spect_standardizer": spect_standardizer, "window_size": window_size, diff --git a/src/vak/predict/frame_classification.py b/src/vak/predict/frame_classification.py index d65924490..765fb3134 100644 --- a/src/vak/predict/frame_classification.py +++ b/src/vak/predict/frame_classification.py @@ -145,7 +145,7 @@ def predict_with_frame_classification_model( raise KeyError( f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table " f"that sets a value for 'window_size', but received a `dataset_config` that did not:\n{dataset_config}" - ) + ) from e transform_params = { "spect_standardizer": spect_standardizer, "window_size": window_size, diff --git a/src/vak/train/frame_classification.py b/src/vak/train/frame_classification.py index da9a13c54..25e07a062 100644 --- a/src/vak/train/frame_classification.py +++ b/src/vak/train/frame_classification.py @@ -219,7 +219,7 @@ def train_frame_classification_model( raise KeyError( f"The `dataset_config` for frame classification model '{model_name}' must include a 'params' sub-table " f"that sets a value for 
'window_size', but received a `dataset_config` that did not:\n{dataset_config}" - ) + ) from e transform_kwargs = { "spect_standardizer": spect_standardizer, "window_size": window_size,
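
The `raise ... from e` change in this final patch is easy to read past, so a word on what it buys: chaining the new `KeyError` to the one caught in the `except` block preserves the original traceback and marks the clearer message as a deliberate translation of the root cause. Below is a minimal, self-contained sketch of the pattern; the `dataset_config` dict here is a stand-in for the table vak parses from a toml config file, not a real vak API call.

```python
"""Minimal sketch of the ``raise ... from e`` pattern applied in this patch.

The ``dataset_config`` dict is a stand-in for the table parsed from a toml
config file; it deliberately lacks the 'window_size' param so the lookup fails.
"""

dataset_config: dict = {"name": "example-dataset", "params": {}}

try:
    window_size = dataset_config["params"]["window_size"]
except KeyError as e:
    # ``from e`` marks the new KeyError as deliberately derived from the
    # original one: the traceback then reads "The above exception was the
    # direct cause of the following exception" instead of "During handling
    # of the above exception, another exception occurred", which a bare
    # ``raise`` inside an ``except`` block produces and which can look like
    # a bug in the error handler itself.
    raise KeyError(
        "The `dataset_config` must include a 'params' sub-table that sets "
        f"a value for 'window_size', but received:\n{dataset_config}"
    ) from e
```

Running the sketch shows both tracebacks: the bare dict lookup that failed, followed by the message that tells the user exactly which table in their config file to fix.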