diff --git a/.buildinfo b/.buildinfo index 7e147ae..07f5d1a 100644 --- a/.buildinfo +++ b/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 6de73c88159c02f3816efd2634ab99e0 +config: b1a3f6489316c3363b54d50c321eaccd tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/_images/notebooks_05_covid_anomaly_detection_19_0.png b/_images/notebooks_05_covid_anomaly_detection_19_0.png new file mode 100644 index 0000000..3137016 Binary files /dev/null and b/_images/notebooks_05_covid_anomaly_detection_19_0.png differ diff --git a/_images/stanford_data_processing.png b/_images/stanford_data_processing.png new file mode 100644 index 0000000..aa37ff9 Binary files /dev/null and b/_images/stanford_data_processing.png differ diff --git a/_modules/index.html b/_modules/index.html index bc60a2f..bb39481 100644 --- a/_modules/index.html +++ b/_modules/index.html @@ -6,7 +6,7 @@ Overview: module code — SSLTools documentation - + - - - - - - - - - - - - - -
- - -
- -
-
-
- -
-
-
-
- -

Source code for ssl_tools.utils.data

-from bisect import bisect_right
-
-
-[docs] -class ConcatDataset: - """ - Concatenate multiple datasets1 - """ - - def __init__(self, datasets): - self.datasets = datasets - self.slices = self._get_slices(datasets) - -
-[docs] - @staticmethod - def _get_slices(datasets): - i = 0 - slices = [] - for d in datasets: - i += len(d) - slices.append(i) - return slices
- - -
-[docs] - def __getitem__(self, i): - bucket = bisect_right(self.slices, i) - if bucket >= len(self.datasets): - raise IndexError("Index out of range") - - return self.datasets[bucket][i-self.slices[bucket]]
- - - -
-[docs] - def __len__(self): - return self.slices[-1]
-
- -
- -
-
- -
-
-
-
- - - - \ No newline at end of file diff --git a/_sources/autoapi/ssl_tools/analysis/index.rst.txt b/_sources/autoapi/ssl_tools/analysis/index.rst.txt index 38a97a5..72ff1f3 100644 --- a/_sources/autoapi/ssl_tools/analysis/index.rst.txt +++ b/_sources/autoapi/ssl_tools/analysis/index.rst.txt @@ -1,15 +1,16 @@ -:py:mod:`ssl_tools.analysis` -============================ +ssl_tools.analysis +================== .. py:module:: ssl_tools.analysis Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - plot_metrics/index.rst + /autoapi/ssl_tools/analysis/latent_analysis/index + /autoapi/ssl_tools/analysis/plot_metrics/index diff --git a/_sources/autoapi/ssl_tools/analysis/latent_analysis/index.rst.txt b/_sources/autoapi/ssl_tools/analysis/latent_analysis/index.rst.txt new file mode 100644 index 0000000..d77686c --- /dev/null +++ b/_sources/autoapi/ssl_tools/analysis/latent_analysis/index.rst.txt @@ -0,0 +1,40 @@ +ssl_tools.analysis.latent_analysis +================================== + +.. py:module:: ssl_tools.analysis.latent_analysis + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.analysis.latent_analysis.LatentAnalysis + ssl_tools.analysis.latent_analysis.LayerOutputSaverHook + + +Module Contents +--------------- + +.. py:class:: LatentAnalysis(layers, sklearn_cls, output_name_suffix = 'transformed', **sklearn_kwargs) + + .. py:method:: __call__(trainer, model, data_module) + + +.. py:class:: LayerOutputSaverHook + + .. py:method:: _forward_hook(module, inputs, outputs, layer_name) + + + .. py:method:: attach_hooks(model, layer_names) + + + .. py:method:: outputs_from_layer(layer_name, concat = True) + + + .. py:method:: remove_hooks() + + + .. py:method:: run_model_with_hooks(model, layer_names) + + diff --git a/_sources/autoapi/ssl_tools/analysis/plot_metrics/index.rst.txt b/_sources/autoapi/ssl_tools/analysis/plot_metrics/index.rst.txt index 7d9836f..965bd0d 100644 --- a/_sources/autoapi/ssl_tools/analysis/plot_metrics/index.rst.txt +++ b/_sources/autoapi/ssl_tools/analysis/plot_metrics/index.rst.txt @@ -1,35 +1,33 @@ -:py:mod:`ssl_tools.analysis.plot_metrics` -========================================= +ssl_tools.analysis.plot_metrics +=============================== .. py:module:: ssl_tools.analysis.plot_metrics -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.analysis.plot_metrics.PlotMetrics - Functions -~~~~~~~~~ +--------- .. autoapisummary:: ssl_tools.analysis.plot_metrics.main +Module Contents +--------------- .. py:class:: PlotMetrics - Class for plotting metrics from a training/predict run. + .. py:method:: accuracy(root_experiment_dir, results_file = 'results.csv', hyperparams_file = 'hparams.yaml', metric = 'test_acc', title = 'Results') Plot the accuracy of a multiple test runs in a single figure. This @@ -64,6 +62,7 @@ Functions Title of the plot. + .. py:method:: epoch_loss(experiment_dir, losses = ('train_loss', 'val_loss'), metrics_file = 'metrics.csv', title = 'Loss') Plot the loss over epochs. @@ -85,4 +84,3 @@ Functions .. py:function:: main() - diff --git a/_sources/autoapi/ssl_tools/benchmarks/index.rst.txt b/_sources/autoapi/ssl_tools/benchmarks/index.rst.txt new file mode 100644 index 0000000..aa4861b --- /dev/null +++ b/_sources/autoapi/ssl_tools/benchmarks/index.rst.txt @@ -0,0 +1,18 @@ +ssl_tools.benchmarks +==================== + +.. py:module:: ssl_tools.benchmarks + + +Submodules +---------- + +.. toctree:: + :maxdepth: 1 + + /autoapi/ssl_tools/benchmarks/main_mix_style/index + /autoapi/ssl_tools/benchmarks/main_supervised/index + /autoapi/ssl_tools/benchmarks/main_supervised_analysis/index + /autoapi/ssl_tools/benchmarks/simple_trainer/index + + diff --git a/_sources/autoapi/ssl_tools/benchmarks/main_mix_style/index.rst.txt b/_sources/autoapi/ssl_tools/benchmarks/main_mix_style/index.rst.txt new file mode 100644 index 0000000..3fb79e5 --- /dev/null +++ b/_sources/autoapi/ssl_tools/benchmarks/main_mix_style/index.rst.txt @@ -0,0 +1,270 @@ +ssl_tools.benchmarks.main_mix_style +=================================== + +.. py:module:: ssl_tools.benchmarks.main_mix_style + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_1D + ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_1D_Backbone + ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_2D + ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_2D_Backbone + ssl_tools.benchmarks.main_mix_style.ConvolutionalBlock + ssl_tools.benchmarks.main_mix_style.ExperimentArgs + ssl_tools.benchmarks.main_mix_style.ResNet1DBase + ssl_tools.benchmarks.main_mix_style.ResNet1D_8 + ssl_tools.benchmarks.main_mix_style.ResNetBlock + ssl_tools.benchmarks.main_mix_style.ResNetSE1D_5 + ssl_tools.benchmarks.main_mix_style.ResNetSE1D_8 + ssl_tools.benchmarks.main_mix_style.ResNetSEBlock + ssl_tools.benchmarks.main_mix_style.SimpleClassificationNet2 + ssl_tools.benchmarks.main_mix_style.SqueezeAndExcitation1D + ssl_tools.benchmarks.main_mix_style._ResNet1D + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.benchmarks.main_mix_style._run_experiment_wrapper + ssl_tools.benchmarks.main_mix_style.cli_main + ssl_tools.benchmarks.main_mix_style.conv3x3 + ssl_tools.benchmarks.main_mix_style.conv3x3_dynamic + ssl_tools.benchmarks.main_mix_style.main_loo + ssl_tools.benchmarks.main_mix_style.pretty_print_experiment_args + ssl_tools.benchmarks.main_mix_style.run_serial + ssl_tools.benchmarks.main_mix_style.run_using_ray + + +Module Contents +--------------- + +.. py:class:: CNN_HaEtAl_1D(input_shape = (1, 6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`SimpleClassificationNet2` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_shape) + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: CNN_HaEtAl_1D_Backbone(input_channels = 1) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: CNN_HaEtAl_2D(pad_at = (3, ), input_shape = (1, 6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`SimpleClassificationNet2` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_shape) + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: CNN_HaEtAl_2D_Backbone(pad_at, in_channels = 1) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: ConvolutionalBlock(in_channels, activation_cls = None) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: ExperimentArgs + + .. py:attribute:: data_cls + :type: Any + + + .. py:attribute:: mix + :type: bool + :value: True + + + + .. py:attribute:: model_args + :type: Dict[str, Any] + + + .. py:attribute:: model_cls + :type: Any + + + .. py:attribute:: seed + :type: int + :value: 42 + + + + .. py:attribute:: test_data_args + :type: Dict[str, Any] + + + .. py:attribute:: train_data_args + :type: Dict[str, Any] + + + .. py:attribute:: trainer_args + :type: Dict[str, Any] + + + .. py:attribute:: trainer_cls + :type: Any + + +.. py:class:: ResNet1DBase(resnet_block_cls = ResNetBlock, activation_cls = torch.nn.ReLU, input_shape = (6, 60), num_classes = 6, num_residual_blocks = 5, reduction_ratio=2, learning_rate = 0.001) + + Bases: :py:obj:`SimpleClassificationNet2` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + +.. py:class:: ResNet1D_8(*args, **kwargs) + + Bases: :py:obj:`ResNet1DBase` + + +.. py:class:: ResNetBlock(in_channels = 64, activation_cls = torch.nn.ReLU, mix_style_factor=False) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: ResNetSE1D_5(*args, **kwargs) + + Bases: :py:obj:`ResNet1DBase` + + +.. py:class:: ResNetSE1D_8(*args, **kwargs) + + Bases: :py:obj:`ResNet1DBase` + + +.. py:class:: ResNetSEBlock(*args, **kwargs) + + Bases: :py:obj:`ResNetBlock` + + +.. py:class:: SimpleClassificationNet2(backbone, fc, learning_rate = 0.001, flatten = True, loss_fn = None, train_metrics = None, val_metrics = None, test_metrics = None) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: single_step(batch, batch_idx, step_name) + + +.. py:class:: SqueezeAndExcitation1D(in_channels, reduction_ratio = 2) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(input_tensor) + + +.. py:class:: _ResNet1D(input_shape, residual_block_cls=ResNetBlock, activation_cls = torch.nn.ReLU, num_residual_blocks = 5, reduction_ratio=2) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:function:: _run_experiment_wrapper(experiment_args) + +.. py:function:: cli_main(experiment) + +.. py:function:: conv3x3(in_planes, out_planes, stride = 1, groups = 1, dilation = 1) + + 3x3 convolution with padding + + +.. py:function:: conv3x3_dynamic(in_planes, out_planes, stride = 1, attention_in_channels = None) + + 3x3 convolution with padding + + +.. py:function:: main_loo() + +.. py:function:: pretty_print_experiment_args(args, indent = 4) + +.. py:function:: run_serial(experiments) + +.. py:function:: run_using_ray(experiments, ray_address = None) + diff --git a/_sources/autoapi/ssl_tools/benchmarks/main_supervised/index.rst.txt b/_sources/autoapi/ssl_tools/benchmarks/main_supervised/index.rst.txt new file mode 100644 index 0000000..a500d59 --- /dev/null +++ b/_sources/autoapi/ssl_tools/benchmarks/main_supervised/index.rst.txt @@ -0,0 +1,89 @@ +ssl_tools.benchmarks.main_supervised +==================================== + +.. py:module:: ssl_tools.benchmarks.main_supervised + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.benchmarks.main_supervised.ExperimentArgs + ssl_tools.benchmarks.main_supervised.SupervisedConfigParser + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.benchmarks.main_supervised._run_experiment_wrapper + ssl_tools.benchmarks.main_supervised.cli_main + ssl_tools.benchmarks.main_supervised.hack_to_avoid_lightning_cli_sys_argv_warning + ssl_tools.benchmarks.main_supervised.main + ssl_tools.benchmarks.main_supervised.run + ssl_tools.benchmarks.main_supervised.run_serial + ssl_tools.benchmarks.main_supervised.run_using_ray + + +Module Contents +--------------- + +.. py:class:: ExperimentArgs + + .. py:attribute:: data + :type: Dict[str, Any] + + + .. py:attribute:: model + :type: Dict[str, Any] + + + .. py:attribute:: num_classes + :type: int + :value: 7 + + + + .. py:attribute:: seed + :type: int + :value: 42 + + + + .. py:attribute:: test_data + :type: Dict[str, Any] + + + .. py:attribute:: trainer + :type: Dict[str, Any] + + +.. py:class:: SupervisedConfigParser(data_path, default_trainer_config, data_module_configs, model_configs, output_dir = 'benchmarks/', skip_existing = True, seed = 42, leave_one_out = False, data_shapes_file = None, num_classes = 7) + + .. py:method:: __call__() + + + .. py:method:: filter_experiments(experiments) + + + .. py:method:: scan_configs(configs_path) + :staticmethod: + + + +.. py:function:: _run_experiment_wrapper(experiment_args) + +.. py:function:: cli_main(experiment) + +.. py:function:: hack_to_avoid_lightning_cli_sys_argv_warning(func, *args, **kwargs) + +.. py:function:: main(data_path, default_trainer_config_file, data_module_configs_path, model_configs_path, output_path = 'benchmarks/', skip_existing = True, ray_address = None, use_ray = True, seed = 42, dry_run = False, dry_run_limit = 5, leave_one_out = False, data_shapes_file = None, num_classes = 7) + +.. py:function:: run(config_parser, use_ray, ray_address = None, dry_run = False, dry_run_limit = 3) + +.. py:function:: run_serial(experiments) + +.. py:function:: run_using_ray(experiments, ray_address = None) + diff --git a/_sources/autoapi/ssl_tools/benchmarks/main_supervised_analysis/index.rst.txt b/_sources/autoapi/ssl_tools/benchmarks/main_supervised_analysis/index.rst.txt new file mode 100644 index 0000000..dd8e8c3 --- /dev/null +++ b/_sources/autoapi/ssl_tools/benchmarks/main_supervised_analysis/index.rst.txt @@ -0,0 +1,19 @@ +ssl_tools.benchmarks.main_supervised_analysis +============================================= + +.. py:module:: ssl_tools.benchmarks.main_supervised_analysis + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.benchmarks.main_supervised_analysis.analysis + + +Module Contents +--------------- + +.. py:function:: analysis(results_dir, query = None, output_dir = None, result_file = 'results.csv', print_results = True, remove_on_error = False) + diff --git a/_sources/autoapi/ssl_tools/benchmarks/simple_trainer/index.rst.txt b/_sources/autoapi/ssl_tools/benchmarks/simple_trainer/index.rst.txt new file mode 100644 index 0000000..c7b3200 --- /dev/null +++ b/_sources/autoapi/ssl_tools/benchmarks/simple_trainer/index.rst.txt @@ -0,0 +1,19 @@ +ssl_tools.benchmarks.simple_trainer +=================================== + +.. py:module:: ssl_tools.benchmarks.simple_trainer + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.benchmarks.simple_trainer.cli_main + + +Module Contents +--------------- + +.. py:function:: cli_main() + diff --git a/_sources/autoapi/ssl_tools/callbacks/index.rst.txt b/_sources/autoapi/ssl_tools/callbacks/index.rst.txt index 6d60532..776afc6 100644 --- a/_sources/autoapi/ssl_tools/callbacks/index.rst.txt +++ b/_sources/autoapi/ssl_tools/callbacks/index.rst.txt @@ -1,15 +1,16 @@ -:py:mod:`ssl_tools.callbacks` -============================= +ssl_tools.callbacks +=================== .. py:module:: ssl_tools.callbacks Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - performance/index.rst + /autoapi/ssl_tools/callbacks/performance/index + /autoapi/ssl_tools/callbacks/save_best/index diff --git a/_sources/autoapi/ssl_tools/callbacks/performance/index.rst.txt b/_sources/autoapi/ssl_tools/callbacks/performance/index.rst.txt index f133f92..e0673ad 100644 --- a/_sources/autoapi/ssl_tools/callbacks/performance/index.rst.txt +++ b/_sources/autoapi/ssl_tools/callbacks/performance/index.rst.txt @@ -1,46 +1,48 @@ -:py:mod:`ssl_tools.callbacks.performance` -========================================= +ssl_tools.callbacks.performance +=============================== .. py:module:: ssl_tools.callbacks.performance -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: - ssl_tools.callbacks.performance.PerformanceLog + ssl_tools.callbacks.performance.PerformanceLogger +Module Contents +--------------- - -.. py:class:: PerformanceLog - +.. py:class:: PerformanceLogger Bases: :py:obj:`lightning.pytorch.callbacks.Callback` + This callback logs the time taken for each epoch and the overall fit time. + .. py:method:: on_fit_end(trainer, module) Called when fit ends. + .. py:method:: on_fit_start(trainer, module) Called when fit begins. + .. py:method:: on_train_epoch_end(trainer, module) Called when the train epoch ends. + .. py:method:: on_train_epoch_start(trainer, module) Called when the train epoch begins. diff --git a/_sources/autoapi/ssl_tools/callbacks/save_best/index.rst.txt b/_sources/autoapi/ssl_tools/callbacks/save_best/index.rst.txt new file mode 100644 index 0000000..b1400b3 --- /dev/null +++ b/_sources/autoapi/ssl_tools/callbacks/save_best/index.rst.txt @@ -0,0 +1,25 @@ +ssl_tools.callbacks.save_best +============================= + +.. py:module:: ssl_tools.callbacks.save_best + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.callbacks.save_best.PickleBestModelAndLoad + + +Module Contents +--------------- + +.. py:class:: PickleBestModelAndLoad(model_name, filename = 'best_model.pt', model_tags = None, model_description = None) + + Bases: :py:obj:`lightning.Callback` + + + .. py:method:: on_train_end(trainer, module) + + diff --git a/_sources/autoapi/ssl_tools/data/data_modules/base/index.rst.txt b/_sources/autoapi/ssl_tools/data/data_modules/base/index.rst.txt new file mode 100644 index 0000000..e9b68a8 --- /dev/null +++ b/_sources/autoapi/ssl_tools/data/data_modules/base/index.rst.txt @@ -0,0 +1,44 @@ +ssl_tools.data.data_modules.base +================================ + +.. py:module:: ssl_tools.data.data_modules.base + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.data.data_modules.base.SimpleDataModule + + +Module Contents +--------------- + +.. py:class:: SimpleDataModule + + Bases: :py:obj:`lightning.LightningDataModule` + + + .. py:method:: _get_loader(split_name, shuffle) + :abstractmethod: + + + + .. py:method:: _load_dataset(split_name) + :abstractmethod: + + + + .. py:method:: predict_dataloader() + + + .. py:method:: test_dataloader() + + + .. py:method:: train_dataloader() + + + .. py:method:: val_dataloader() + + diff --git a/_sources/autoapi/ssl_tools/data/data_modules/covid_anomaly/index.rst.txt b/_sources/autoapi/ssl_tools/data/data_modules/covid_anomaly/index.rst.txt new file mode 100644 index 0000000..9fa2c8c --- /dev/null +++ b/_sources/autoapi/ssl_tools/data/data_modules/covid_anomaly/index.rst.txt @@ -0,0 +1,43 @@ +ssl_tools.data.data_modules.covid_anomaly +========================================= + +.. py:module:: ssl_tools.data.data_modules.covid_anomaly + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.data.data_modules.covid_anomaly.CovidUserAnomalyDataModule + + +Module Contents +--------------- + +.. py:class:: CovidUserAnomalyDataModule(data_path, participants = None, feature_column_prefix = 'RHR', target_column = 'anomaly', participant_column = 'participant_id', include_recovered_in_test = False, reshape = None, train_transforms = None, batch_size = 32, num_workers = 1, validation_split = 0.2, dataset_transforms = None, shuffle_train = True, discard_last_batch = False, balance = False, train_baseline_only = True) + + Bases: :py:obj:`lightning.LightningDataModule` + + + .. py:method:: __repr__() + + + .. py:method:: __str__() + + + .. py:method:: predict_dataloader() + + + .. py:method:: setup(stage) + + + .. py:method:: test_dataloader() + + + .. py:method:: train_dataloader() + + + .. py:method:: val_dataloader() + + diff --git a/_sources/autoapi/ssl_tools/data/data_modules/har/index.rst.txt b/_sources/autoapi/ssl_tools/data/data_modules/har/index.rst.txt index 6e7cd1c..7138a02 100644 --- a/_sources/autoapi/ssl_tools/data/data_modules/har/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/data_modules/har/index.rst.txt @@ -1,26 +1,23 @@ -:py:mod:`ssl_tools.data.data_modules.har` -========================================= +ssl_tools.data.data_modules.har +=============================== .. py:module:: ssl_tools.data.data_modules.har -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: + ssl_tools.data.data_modules.har.AugmentedMultiModalHARSeriesDataModule ssl_tools.data.data_modules.har.MultiModalHARSeriesDataModule ssl_tools.data.data_modules.har.TFCDataModule ssl_tools.data.data_modules.har.TNCHARDataModule ssl_tools.data.data_modules.har.UserActivityFolderDataModule - Functions -~~~~~~~~~ +--------- .. autoapisummary:: @@ -28,11 +25,13 @@ Functions ssl_tools.data.data_modules.har.parse_transforms +Module Contents +--------------- -.. py:class:: MultiModalHARSeriesDataModule(data_path, feature_prefixes = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = 'standard activity code', features_as_channels = True, transforms = None, cast_to = 'float32', batch_size = 1, num_workers = None) +.. py:class:: AugmentedMultiModalHARSeriesDataModule(train_transforms, validation_transforms = None, test_transforms = None, **kwargs) + Bases: :py:obj:`MultiModalHARSeriesDataModule` - Bases: :py:obj:`lightning.LightningDataModule` Define the dataloaders for train, validation and test splits for @@ -96,6 +95,92 @@ Functions num_workers : int, optional Number of workers to load data. If None, then use all cores + + .. py:method:: _load_dataset(split_name) + + Create a ``MultiModalSeriesCSVDataset`` dataset with the given split. + + Parameters + ---------- + split_name : str + The name of the split. This must be one of: "train", "validation", + "test" or "predict". + + Returns + ------- + MultiModalSeriesCSVDataset + A MultiModalSeriesCSVDataset dataset with the given split. + + + +.. py:class:: MultiModalHARSeriesDataModule(data_path, feature_prefixes = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = 'standard activity code', features_as_channels = True, transforms = None, cast_to = 'float32', batch_size = 1, num_workers = None, data_percentage = 1.0, domain_info = False) + + Bases: :py:obj:`ssl_tools.data.data_modules.base.SimpleDataModule` + + + + Define the dataloaders for train, validation and test splits for + HAR datasets. This datasets assumes that the data is in a single CSV + file with series of data. Each row is a single sample that can be + composed of multiple modalities (series). Each column is a feature of + some series with the prefix indicating the series. The suffix may + indicates the time step. For instance, if we have two series, accel-x + and accel-y, the data will look something like: + + +-----------+-----------+-----------+-----------+--------+ + | accel-x-0 | accel-x-1 | accel-y-0 | accel-y-1 | class | + +-----------+-----------+-----------+-----------+--------+ + | 0.502123 | 0.02123 | 0.502123 | 0.502123 | 0 | + | 0.6820123 | 0.02123 | 0.502123 | 0.502123 | 1 | + | 0.498217 | 0.00001 | 1.414141 | 3.141592 | 2 | + +-----------+-----------+-----------+-----------+--------+ + + The ``feature_prefixes`` parameter is used to select the columns that + will be used as features. For instance, if we want to use only the + accel-x series, we can set ``feature_prefixes=["accel-x"]``. If we want + to use both accel-x and accel-y, we can set + ``feature_prefixes=["accel-x", "accel-y"]``. If None is passed, all + columns will be used as features, except the label column. + The label column is specified by the ``label`` parameter. + + The dataset will return a 2-element tuple with the data and the label, + if the ``label`` parameter is specified, otherwise return only the data. + + If ``features_as_channels`` is ``True``, the data will be returned as a + vector of shape `(C, T)`, where C is the number of channels (features) + and `T` is the number of time steps. Else, the data will be returned as + a vector of shape T*C (a single vector with all the features). + + Parameters + ---------- + data_path : PathLike + The path to the folder with "train.csv", "validation.csv" and + "test.csv" files inside it. + feature_prefixes : Union[str, List[str]], optional + The prefix of the column names in the dataframe that will be used + to become features. If None, all columns except the label will be + used as features. + label : str, optional + The name of the column that will be used as label + features_as_channels : bool, optional + If True, the data will be returned as a vector of shape (C, T), + else the data will be returned as a vector of shape T*C. + cast_to: str, optional + Cast the numpy data to the specified type + transforms : Union[List[Callable], Dict[str, List[Callable]]], optional + This could be: + - None: No transforms will be applied + - List[Callable]: A list of transforms that will be applied to the + data. The same transforms will be applied to all splits. + - Dict[str, List[Callable]]: A dictionary with the split name as + key and a list of transforms as value. The split name must be + one of: "train", "validation", "test" or "predict". + batch_size : int, optional + The size of the batch + num_workers : int, optional + Number of workers to load data. If None, then use all cores + + .. py:method:: __repr__() @@ -120,6 +205,7 @@ Functions A dataloader for the given split. + .. py:method:: _load_dataset(split_name) Create a ``MultiModalSeriesCSVDataset`` dataset with the given split. @@ -136,6 +222,7 @@ Functions A MultiModalSeriesCSVDataset dataset with the given split. + .. py:method:: predict_dataloader() @@ -159,6 +246,7 @@ Functions If the stage is not one of: "fit", "test" or "predict" + .. py:method:: test_dataloader() @@ -168,11 +256,10 @@ Functions .. py:method:: val_dataloader() - .. py:class:: TFCDataModule(data_path, feature_prefixes = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = 'standard activity code', features_as_channels = True, length_alignment = 60, time_transforms = None, frequency_transforms = None, cast_to = 'float32', jitter_ratio = 2, only_time_frequency = False, batch_size = 32, num_workers = None) + Bases: :py:obj:`ssl_tools.data.data_modules.base.SimpleDataModule` - Bases: :py:obj:`lightning.LightningDataModule` Define a dataloader for ``TFCDataset``. This is a wrapper around @@ -205,7 +292,7 @@ Functions - Dict[str, List[Callable]]: A dictionary with the split name as key and a list of transforms as value. The split name must be one of: "train", "validation", "test" or "predict". - If None. an ``AddGaussianNoise`` transform will be used with the + If None. an ``AddGaussianNoise`` transform will be used with the given ``jitter_ratio``. frequency_transforms : Union[List[Callable], Dict[str, List[Callable]]], optional Transforms to be applied to frequency domain data. This could be: @@ -214,7 +301,7 @@ Functions data. The same transforms will be applied to all splits. - Dict[str, List[Callable]]: A dictionary with the split name as key and a list of transforms as value. The split name must be - one of: "train", "validation", "test" or "predict". + one of: "train", "validation", "test" or "predict". If None, an ``AddRemoveFrequency`` transform will be used. cast_to : str, optional Cast the data to the given type, by default "float32" @@ -231,6 +318,7 @@ Functions num_workers : int, optional Number of workers to load data, by default None (use all cores) + .. py:method:: _get_loader(split_name, shuffle) Get a dataloader for the given split. @@ -249,6 +337,7 @@ Functions A dataloader for the given split. + .. py:method:: _load_dataset(split_name) Create a ``TFCDataset`` @@ -258,13 +347,14 @@ Functions split_name : str Name of the split (train, validation or test). This will be used to load the corresponding CSV file. - + Returns ------- TFCDataset A TFC dataset with the given split. + .. py:method:: predict_dataloader() @@ -288,6 +378,7 @@ Functions If the stage is not one of: "fit", "test" or "predict" + .. py:method:: test_dataloader() @@ -297,12 +388,11 @@ Functions .. py:method:: val_dataloader() - .. py:class:: TNCHARDataModule(data_path, features = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = None, pad = False, transforms = None, batch_size = 1, num_workers = None, cast_to = 'float32', window_size = 60, mc_sample_size = 20, significance_level = 0.01, repeat = 1) - Bases: :py:obj:`UserActivityFolderDataModule` + Define the dataloaders for train, validation and test splits for TNC datasets. The data must be in the following folder structure: @@ -390,6 +480,7 @@ Functions repeat : int, optional Simple repeat the element of the dataset ``repeat`` times, + .. py:method:: _load_dataset(split_name) Create a ``TNCDataset`` dataset with the given split. @@ -409,8 +500,8 @@ Functions .. py:class:: UserActivityFolderDataModule(data_path, features = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = None, pad = False, transforms = None, cast_to = 'float32', batch_size = 1, num_workers = None) + Bases: :py:obj:`ssl_tools.data.data_modules.base.SimpleDataModule` - Bases: :py:obj:`lightning.LightningDataModule` Define the dataloaders for train, validation and test splits for @@ -487,6 +578,7 @@ Functions num_workers : int, optional Number of workers to load data. If None, then use all cores + .. py:method:: __repr__() @@ -511,6 +603,7 @@ Functions A dataloader for the given split. + .. py:method:: _load_dataset(split_name) Create a ``SeriesFolderCSVDataset`` dataset with the given split. @@ -527,6 +620,7 @@ Functions The dataset with the given split. + .. py:method:: predict_dataloader() @@ -550,6 +644,7 @@ Functions If the stage is not one of: "fit", "test" or "predict" + .. py:method:: test_dataloader() @@ -559,7 +654,6 @@ Functions .. py:method:: val_dataloader() - .. py:function:: parse_num_workers(num_workers) Parse the num_workers parameter. If None, use all cores. diff --git a/_sources/autoapi/ssl_tools/data/data_modules/index.rst.txt b/_sources/autoapi/ssl_tools/data/data_modules/index.rst.txt index 262a273..0f4a759 100644 --- a/_sources/autoapi/ssl_tools/data/data_modules/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/data_modules/index.rst.txt @@ -1,38 +1,65 @@ -:py:mod:`ssl_tools.data.data_modules` -===================================== +ssl_tools.data.data_modules +=========================== .. py:module:: ssl_tools.data.data_modules Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - har/index.rst - + /autoapi/ssl_tools/data/data_modules/base/index + /autoapi/ssl_tools/data/data_modules/covid_anomaly/index + /autoapi/ssl_tools/data/data_modules/har/index -Package Contents ----------------- Classes -~~~~~~~ +------- .. autoapisummary:: + ssl_tools.data.data_modules.CovidUserAnomalyDataModule ssl_tools.data.data_modules.MultiModalHARSeriesDataModule ssl_tools.data.data_modules.TFCDataModule ssl_tools.data.data_modules.TNCHARDataModule ssl_tools.data.data_modules.UserActivityFolderDataModule +Package Contents +---------------- + +.. py:class:: CovidUserAnomalyDataModule(data_path, participants = None, feature_column_prefix = 'RHR', target_column = 'anomaly', participant_column = 'participant_id', include_recovered_in_test = False, reshape = None, train_transforms = None, batch_size = 32, num_workers = 1, validation_split = 0.2, dataset_transforms = None, shuffle_train = True, discard_last_batch = False, balance = False, train_baseline_only = True) + + Bases: :py:obj:`lightning.LightningDataModule` -.. py:class:: MultiModalHARSeriesDataModule(data_path, feature_prefixes = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = 'standard activity code', features_as_channels = True, transforms = None, cast_to = 'float32', batch_size = 1, num_workers = None) + .. py:method:: __repr__() - Bases: :py:obj:`lightning.LightningDataModule` + .. py:method:: __str__() + + + .. py:method:: predict_dataloader() + + + .. py:method:: setup(stage) + + + .. py:method:: test_dataloader() + + + .. py:method:: train_dataloader() + + + .. py:method:: val_dataloader() + + +.. py:class:: MultiModalHARSeriesDataModule(data_path, feature_prefixes = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = 'standard activity code', features_as_channels = True, transforms = None, cast_to = 'float32', batch_size = 1, num_workers = None, data_percentage = 1.0, domain_info = False) + + Bases: :py:obj:`ssl_tools.data.data_modules.base.SimpleDataModule` + Define the dataloaders for train, validation and test splits for @@ -96,6 +123,7 @@ Classes num_workers : int, optional Number of workers to load data. If None, then use all cores + .. py:method:: __repr__() @@ -120,6 +148,7 @@ Classes A dataloader for the given split. + .. py:method:: _load_dataset(split_name) Create a ``MultiModalSeriesCSVDataset`` dataset with the given split. @@ -136,6 +165,7 @@ Classes A MultiModalSeriesCSVDataset dataset with the given split. + .. py:method:: predict_dataloader() @@ -159,6 +189,7 @@ Classes If the stage is not one of: "fit", "test" or "predict" + .. py:method:: test_dataloader() @@ -168,11 +199,10 @@ Classes .. py:method:: val_dataloader() - .. py:class:: TFCDataModule(data_path, feature_prefixes = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = 'standard activity code', features_as_channels = True, length_alignment = 60, time_transforms = None, frequency_transforms = None, cast_to = 'float32', jitter_ratio = 2, only_time_frequency = False, batch_size = 32, num_workers = None) + Bases: :py:obj:`ssl_tools.data.data_modules.base.SimpleDataModule` - Bases: :py:obj:`lightning.LightningDataModule` Define a dataloader for ``TFCDataset``. This is a wrapper around @@ -205,7 +235,7 @@ Classes - Dict[str, List[Callable]]: A dictionary with the split name as key and a list of transforms as value. The split name must be one of: "train", "validation", "test" or "predict". - If None. an ``AddGaussianNoise`` transform will be used with the + If None. an ``AddGaussianNoise`` transform will be used with the given ``jitter_ratio``. frequency_transforms : Union[List[Callable], Dict[str, List[Callable]]], optional Transforms to be applied to frequency domain data. This could be: @@ -214,7 +244,7 @@ Classes data. The same transforms will be applied to all splits. - Dict[str, List[Callable]]: A dictionary with the split name as key and a list of transforms as value. The split name must be - one of: "train", "validation", "test" or "predict". + one of: "train", "validation", "test" or "predict". If None, an ``AddRemoveFrequency`` transform will be used. cast_to : str, optional Cast the data to the given type, by default "float32" @@ -231,6 +261,7 @@ Classes num_workers : int, optional Number of workers to load data, by default None (use all cores) + .. py:method:: _get_loader(split_name, shuffle) Get a dataloader for the given split. @@ -249,6 +280,7 @@ Classes A dataloader for the given split. + .. py:method:: _load_dataset(split_name) Create a ``TFCDataset`` @@ -258,13 +290,14 @@ Classes split_name : str Name of the split (train, validation or test). This will be used to load the corresponding CSV file. - + Returns ------- TFCDataset A TFC dataset with the given split. + .. py:method:: predict_dataloader() @@ -288,6 +321,7 @@ Classes If the stage is not one of: "fit", "test" or "predict" + .. py:method:: test_dataloader() @@ -297,12 +331,11 @@ Classes .. py:method:: val_dataloader() - .. py:class:: TNCHARDataModule(data_path, features = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = None, pad = False, transforms = None, batch_size = 1, num_workers = None, cast_to = 'float32', window_size = 60, mc_sample_size = 20, significance_level = 0.01, repeat = 1) - Bases: :py:obj:`UserActivityFolderDataModule` + Define the dataloaders for train, validation and test splits for TNC datasets. The data must be in the following folder structure: @@ -390,6 +423,7 @@ Classes repeat : int, optional Simple repeat the element of the dataset ``repeat`` times, + .. py:method:: _load_dataset(split_name) Create a ``TNCDataset`` dataset with the given split. @@ -409,8 +443,8 @@ Classes .. py:class:: UserActivityFolderDataModule(data_path, features = ('accel-x', 'accel-y', 'accel-z', 'gyro-x', 'gyro-y', 'gyro-z'), label = None, pad = False, transforms = None, cast_to = 'float32', batch_size = 1, num_workers = None) + Bases: :py:obj:`ssl_tools.data.data_modules.base.SimpleDataModule` - Bases: :py:obj:`lightning.LightningDataModule` Define the dataloaders for train, validation and test splits for @@ -487,6 +521,7 @@ Classes num_workers : int, optional Number of workers to load data. If None, then use all cores + .. py:method:: __repr__() @@ -511,6 +546,7 @@ Classes A dataloader for the given split. + .. py:method:: _load_dataset(split_name) Create a ``SeriesFolderCSVDataset`` dataset with the given split. @@ -527,6 +563,7 @@ Classes The dataset with the given split. + .. py:method:: predict_dataloader() @@ -550,6 +587,7 @@ Classes If the stage is not one of: "fit", "test" or "predict" + .. py:method:: test_dataloader() @@ -559,4 +597,3 @@ Classes .. py:method:: val_dataloader() - diff --git a/_sources/autoapi/ssl_tools/data/datasets/augmented_dataset/index.rst.txt b/_sources/autoapi/ssl_tools/data/datasets/augmented_dataset/index.rst.txt new file mode 100644 index 0000000..39a32aa --- /dev/null +++ b/_sources/autoapi/ssl_tools/data/datasets/augmented_dataset/index.rst.txt @@ -0,0 +1,45 @@ +ssl_tools.data.datasets.augmented_dataset +========================================= + +.. py:module:: ssl_tools.data.datasets.augmented_dataset + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.data.datasets.augmented_dataset.AugmentedDataset + + +Module Contents +--------------- + +.. py:class:: AugmentedDataset(dataset, transforms) + + Bases: :py:obj:`torch.utils.data.Dataset` + + + Note: this class assumes that dataset is a Dataset object, and that + the __getitem__ method of the dataset returns a tuple of n elements. + + _summary_ + + Parameters + ---------- + dataset : Dataset + _description_ + transforms : Dict[int, Callable] + As each element (result of __getitem__) of the dataset is a + n-element tuple, the transforms are applied to the n-th element + of the tuple. The key of the dictionary is the index of the + element of the tuple to apply the transform (0-indexed), and the + value is the transform to apply. + + + .. py:method:: __getitem__(idx) + + + .. py:method:: __len__() + + diff --git a/_sources/autoapi/ssl_tools/data/datasets/domain_dataset/index.rst.txt b/_sources/autoapi/ssl_tools/data/datasets/domain_dataset/index.rst.txt new file mode 100644 index 0000000..655d692 --- /dev/null +++ b/_sources/autoapi/ssl_tools/data/datasets/domain_dataset/index.rst.txt @@ -0,0 +1,25 @@ +ssl_tools.data.datasets.domain_dataset +====================================== + +.. py:module:: ssl_tools.data.datasets.domain_dataset + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.data.datasets.domain_dataset.DomainDataset + + +Module Contents +--------------- + +.. py:class:: DomainDataset(dataset, domain) + + .. py:method:: __getitem__(idx) + + + .. py:method:: __len__() + + diff --git a/_sources/autoapi/ssl_tools/data/datasets/index.rst.txt b/_sources/autoapi/ssl_tools/data/datasets/index.rst.txt index f2308f0..42e3825 100644 --- a/_sources/autoapi/ssl_tools/data/datasets/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/datasets/index.rst.txt @@ -1,39 +1,65 @@ -:py:mod:`ssl_tools.data.datasets` -================================= +ssl_tools.data.datasets +======================= .. py:module:: ssl_tools.data.datasets Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - series_dataset/index.rst - tfc/index.rst - tnc/index.rst - + /autoapi/ssl_tools/data/datasets/augmented_dataset/index + /autoapi/ssl_tools/data/datasets/domain_dataset/index + /autoapi/ssl_tools/data/datasets/series_dataset/index + /autoapi/ssl_tools/data/datasets/tfc/index + /autoapi/ssl_tools/data/datasets/tnc/index -Package Contents ----------------- Classes -~~~~~~~ +------- .. autoapisummary:: + ssl_tools.data.datasets.MultiModalDataframeDataset ssl_tools.data.datasets.MultiModalSeriesCSVDataset ssl_tools.data.datasets.SeriesFolderCSVDataset ssl_tools.data.datasets.TFCDataset ssl_tools.data.datasets.TNCDataset +Package Contents +---------------- +.. py:class:: MultiModalDataframeDataset(df, feature_column_prefix = 'RHR', target_column = 'anomaly', reshape = None, transforms = None, name = 'participant', dataset_transforms = None, balance = False) -.. py:class:: MultiModalSeriesCSVDataset(data_path, feature_prefixes = None, label = None, features_as_channels = True, cast_to = 'float32', transforms = None) + .. py:method:: __getitem__(index) + + + .. py:method:: __len__() + + + .. py:method:: __repr__() + + Return repr(self). + + + + .. py:method:: __str__() + + Return str(self). + + + + .. py:method:: _balance() + + + .. py:method:: _dataset_transform() +.. py:class:: MultiModalSeriesCSVDataset(data_path, feature_prefixes = None, label = None, features_as_channels = True, cast_to = 'float32', transforms = None) + This datasets assumes that the data is in a single CSV file with series of data. Each row is a single sample that can be composed of @@ -116,6 +142,7 @@ Classes 3 + .. py:method:: __getitem__(index) @@ -127,11 +154,13 @@ Classes Return repr(self). + .. py:method:: __str__() Return str(self). + .. py:method:: _load_data() Load data from the CSV file @@ -146,7 +175,6 @@ Classes .. py:class:: SeriesFolderCSVDataset(data_path, features = None, label = None, pad = False, cast_to = 'float32', transforms = None, lazy = False) - This dataset assumes that the data is in a folder with multiple CSV files. Each CSV file is a single sample that can be composed of @@ -238,6 +266,7 @@ Classes If True, the data will be loaded lazily (i.e. the CSV files will be read only when needed) + .. py:method:: __getitem__(idx) Get a single sample from the dataset @@ -254,6 +283,7 @@ Classes specified, otherwise only the data. + .. py:method:: __len__() @@ -262,16 +292,19 @@ Classes Return repr(self). + .. py:method:: __str__() Return str(self). + .. py:method:: _disable_fix_length() Decorator to disable fix_length when calling a function + .. py:method:: _get_longest_sample_size() Return the size of the longest sample in the dataset @@ -282,6 +315,7 @@ Classes The size of the longest sample in the dataset + .. py:method:: _pad_data(data) Pad the data to the length of the longest sample. In summary, this @@ -298,6 +332,7 @@ Classes The padded data + .. py:method:: _read_all_csv() Read all the CSV files in the data directory @@ -308,6 +343,7 @@ Classes A list of 2-element tuple with the data and the label. If the label is not specified, the second element of the tuples are None. + .. py:method:: _read_csv(path) Read a single CSV file (a single sample) @@ -324,6 +360,7 @@ Classes specified, the second element is None. + .. py:method:: _scan_data() List the CSV files in the data directory @@ -337,9 +374,9 @@ Classes .. py:class:: TFCDataset(data, length_alignment = 178, time_transforms = None, frequency_transforms = None, cast_to = 'float32', only_time_frequency = False) - Bases: :py:obj:`torch.utils.data.Dataset` + Time-Frequency Contrastive (TFC) Dataset. This dataset is intented to be used using TFC technique. Given a dataset with time-domain signal, @@ -395,8 +432,8 @@ Classes torch.Tensor([[-0.5020, -0.5020, -0.5020, ..., -0.5020, -0.5020, -0.5020]]), # frequency augmented ) - .. py:class:: FFT(absolute = True) + .. py:class:: FFT(absolute = True) Simple wrapper to apply the FFT to the data @@ -406,6 +443,7 @@ Classes absolute : bool, optional If True, returns the absolute value of FFT, by default True + .. py:method:: __call__(x) Apply the FFT to the data @@ -422,6 +460,7 @@ Classes + .. py:method:: __getitem__(index) @@ -445,6 +484,7 @@ Classes The transformed data + .. py:method:: _apply_transforms_per_axis(data, transforms) Split the data into channels and apply the transforms to each channel @@ -469,9 +509,9 @@ Classes .. py:class:: TNCDataset(data, window_size, mc_sample_size = 20, significance_level = 0.01, repeat = 1, cast_to = 'float32') - Bases: :py:obj:`torch.utils.data.Dataset` + Temporal Neighbourhood Coding (TNC) dataset. This dataset is used to pre-train self-supervised models. The dataset obtain close and @@ -508,6 +548,7 @@ Classes cast_to : str, optional Cast the data to the given type, by default "float32" + .. py:method:: __getitem__(idx) Get a sample from the dataset. The sample is a tuple with 3 @@ -531,6 +572,7 @@ Classes A tuple with 3 elements (W_t, X_p, X_n). + .. py:method:: __len__() @@ -555,6 +597,7 @@ Classes is the delta, used to adjust the neighbourhood size. + .. py:method:: _find_non_neighours(data, t, delta = 0.0) Find distant samples. The samples will be selected from the diff --git a/_sources/autoapi/ssl_tools/data/datasets/series_dataset/index.rst.txt b/_sources/autoapi/ssl_tools/data/datasets/series_dataset/index.rst.txt index d1e4cd9..f2b50fc 100644 --- a/_sources/autoapi/ssl_tools/data/datasets/series_dataset/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/datasets/series_dataset/index.rst.txt @@ -1,26 +1,50 @@ -:py:mod:`ssl_tools.data.datasets.series_dataset` -================================================ +ssl_tools.data.datasets.series_dataset +====================================== .. py:module:: ssl_tools.data.datasets.series_dataset -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: + ssl_tools.data.datasets.series_dataset.MultiModalDataframeDataset ssl_tools.data.datasets.series_dataset.MultiModalSeriesCSVDataset ssl_tools.data.datasets.series_dataset.SeriesFolderCSVDataset +Module Contents +--------------- +.. py:class:: MultiModalDataframeDataset(df, feature_column_prefix = 'RHR', target_column = 'anomaly', reshape = None, transforms = None, name = 'participant', dataset_transforms = None, balance = False) -.. py:class:: MultiModalSeriesCSVDataset(data_path, feature_prefixes = None, label = None, features_as_channels = True, cast_to = 'float32', transforms = None) + .. py:method:: __getitem__(index) + + + .. py:method:: __len__() + + + .. py:method:: __repr__() + + Return repr(self). + + + + .. py:method:: __str__() + + Return str(self). + + .. py:method:: _balance() + + + .. py:method:: _dataset_transform() + + +.. py:class:: MultiModalSeriesCSVDataset(data_path, feature_prefixes = None, label = None, features_as_channels = True, cast_to = 'float32', transforms = None) + This datasets assumes that the data is in a single CSV file with series of data. Each row is a single sample that can be composed of @@ -103,6 +127,7 @@ Classes 3 + .. py:method:: __getitem__(index) @@ -114,11 +139,13 @@ Classes Return repr(self). + .. py:method:: __str__() Return str(self). + .. py:method:: _load_data() Load data from the CSV file @@ -133,7 +160,6 @@ Classes .. py:class:: SeriesFolderCSVDataset(data_path, features = None, label = None, pad = False, cast_to = 'float32', transforms = None, lazy = False) - This dataset assumes that the data is in a folder with multiple CSV files. Each CSV file is a single sample that can be composed of @@ -225,6 +251,7 @@ Classes If True, the data will be loaded lazily (i.e. the CSV files will be read only when needed) + .. py:method:: __getitem__(idx) Get a single sample from the dataset @@ -241,6 +268,7 @@ Classes specified, otherwise only the data. + .. py:method:: __len__() @@ -249,16 +277,19 @@ Classes Return repr(self). + .. py:method:: __str__() Return str(self). + .. py:method:: _disable_fix_length() Decorator to disable fix_length when calling a function + .. py:method:: _get_longest_sample_size() Return the size of the longest sample in the dataset @@ -269,6 +300,7 @@ Classes The size of the longest sample in the dataset + .. py:method:: _pad_data(data) Pad the data to the length of the longest sample. In summary, this @@ -285,6 +317,7 @@ Classes The padded data + .. py:method:: _read_all_csv() Read all the CSV files in the data directory @@ -295,6 +328,7 @@ Classes A list of 2-element tuple with the data and the label. If the label is not specified, the second element of the tuples are None. + .. py:method:: _read_csv(path) Read a single CSV file (a single sample) @@ -311,6 +345,7 @@ Classes specified, the second element is None. + .. py:method:: _scan_data() List the CSV files in the data directory diff --git a/_sources/autoapi/ssl_tools/data/datasets/tfc/index.rst.txt b/_sources/autoapi/ssl_tools/data/datasets/tfc/index.rst.txt index fd37610..aa1df0d 100644 --- a/_sources/autoapi/ssl_tools/data/datasets/tfc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/datasets/tfc/index.rst.txt @@ -1,27 +1,25 @@ -:py:mod:`ssl_tools.data.datasets.tfc` -===================================== +ssl_tools.data.datasets.tfc +=========================== .. py:module:: ssl_tools.data.datasets.tfc -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.data.datasets.tfc.TFCDataset - +Module Contents +--------------- .. py:class:: TFCDataset(data, length_alignment = 178, time_transforms = None, frequency_transforms = None, cast_to = 'float32', only_time_frequency = False) - Bases: :py:obj:`torch.utils.data.Dataset` + Time-Frequency Contrastive (TFC) Dataset. This dataset is intented to be used using TFC technique. Given a dataset with time-domain signal, @@ -77,8 +75,8 @@ Classes torch.Tensor([[-0.5020, -0.5020, -0.5020, ..., -0.5020, -0.5020, -0.5020]]), # frequency augmented ) - .. py:class:: FFT(absolute = True) + .. py:class:: FFT(absolute = True) Simple wrapper to apply the FFT to the data @@ -88,6 +86,7 @@ Classes absolute : bool, optional If True, returns the absolute value of FFT, by default True + .. py:method:: __call__(x) Apply the FFT to the data @@ -104,6 +103,7 @@ Classes + .. py:method:: __getitem__(index) @@ -127,6 +127,7 @@ Classes The transformed data + .. py:method:: _apply_transforms_per_axis(data, transforms) Split the data into channels and apply the transforms to each channel diff --git a/_sources/autoapi/ssl_tools/data/datasets/tnc/index.rst.txt b/_sources/autoapi/ssl_tools/data/datasets/tnc/index.rst.txt index 88c84a4..a6bf8b9 100644 --- a/_sources/autoapi/ssl_tools/data/datasets/tnc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/datasets/tnc/index.rst.txt @@ -1,27 +1,25 @@ -:py:mod:`ssl_tools.data.datasets.tnc` -===================================== +ssl_tools.data.datasets.tnc +=========================== .. py:module:: ssl_tools.data.datasets.tnc -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.data.datasets.tnc.TNCDataset - +Module Contents +--------------- .. py:class:: TNCDataset(data, window_size, mc_sample_size = 20, significance_level = 0.01, repeat = 1, cast_to = 'float32') - Bases: :py:obj:`torch.utils.data.Dataset` + Temporal Neighbourhood Coding (TNC) dataset. This dataset is used to pre-train self-supervised models. The dataset obtain close and @@ -58,6 +56,7 @@ Classes cast_to : str, optional Cast the data to the given type, by default "float32" + .. py:method:: __getitem__(idx) Get a sample from the dataset. The sample is a tuple with 3 @@ -81,6 +80,7 @@ Classes A tuple with 3 elements (W_t, X_p, X_n). + .. py:method:: __len__() @@ -105,6 +105,7 @@ Classes is the delta, used to adjust the neighbourhood size. + .. py:method:: _find_non_neighours(data, t, delta = 0.0) Find distant samples. The samples will be selected from the diff --git a/_sources/autoapi/ssl_tools/data/index.rst.txt b/_sources/autoapi/ssl_tools/data/index.rst.txt index 177a291..2db45d3 100644 --- a/_sources/autoapi/ssl_tools/data/index.rst.txt +++ b/_sources/autoapi/ssl_tools/data/index.rst.txt @@ -1,16 +1,16 @@ -:py:mod:`ssl_tools.data` -======================== +ssl_tools.data +============== .. py:module:: ssl_tools.data Subpackages ----------- + .. toctree:: - :titlesonly: - :maxdepth: 3 + :maxdepth: 1 - data_modules/index.rst - datasets/index.rst + /autoapi/ssl_tools/data/data_modules/index + /autoapi/ssl_tools/data/datasets/index diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/anomaly_detection_base/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/anomaly_detection_base/index.rst.txt new file mode 100644 index 0000000..b9fca38 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/anomaly_detection_base/index.rst.txt @@ -0,0 +1,130 @@ +ssl_tools.experiments.covid_detection.anomaly_detection_base +============================================================ + +.. py:module:: ssl_tools.experiments.covid_detection.anomaly_detection_base + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator + ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain + ssl_tools.experiments.covid_detection.anomaly_detection_base.RMSELoss + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.anomaly_detection_base.kmeans_threshold + ssl_tools.experiments.covid_detection.anomaly_detection_base.mean_absolute_error + ssl_tools.experiments.covid_detection.anomaly_detection_base.mean_squared_error + ssl_tools.experiments.covid_detection.anomaly_detection_base.root_mean_squared_error + ssl_tools.experiments.covid_detection.anomaly_detection_base.sigma_threshold + ssl_tools.experiments.covid_detection.anomaly_detection_base.zscore_threshold_max + ssl_tools.experiments.covid_detection.anomaly_detection_base.zscore_threshold_std + + +Module Contents +--------------- + +.. py:class:: CovidAnomalyDetectionEvaluator(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix = 'RHR', target_column = 'anomaly', include_recovered_in_test = False, results_dir = 'results', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTest` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:method:: _calc_static_anomaly_thresholds(losses) + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + :abstractmethod: + + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + + .. py:method:: run_model(model, data_module, trainer) + + +.. py:class:: CovidAnomalyDetectionTrain(data, input_shape, participant_ids = None, validation_split = 0.1, augment = False, feature_column_prefix = 'RHR', target_column = 'anomaly', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:method:: _get_transforms() + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + :abstractmethod: + + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: RMSELoss(eps=1e-06, *args, **kwargs) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(y_hat, y) + + +.. py:function:: kmeans_threshold(X_recon, n_clusters=1) + +.. py:function:: mean_absolute_error(X, X_recon) + +.. py:function:: mean_squared_error(X, X_recon) + +.. py:function:: root_mean_squared_error(X, X_recon) + +.. py:function:: sigma_threshold(X_recon, sigma) + +.. py:function:: zscore_threshold_max(X_recon) + +.. py:function:: zscore_threshold_std(X_recon, std) + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/cae/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/cae/index.rst.txt new file mode 100644 index 0000000..c801a53 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/cae/index.rst.txt @@ -0,0 +1,78 @@ +ssl_tools.experiments.covid_detection.cae +========================================= + +.. py:module:: ssl_tools.experiments.covid_detection.cae + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.cae.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.cae.ConvolutionalAutoencoderAnomalyDetectionTest + ssl_tools.experiments.covid_detection.cae.ConvolutionalAutoencoderAnomalyDetectionTrain + + +Module Contents +--------------- + +.. py:class:: ConvolutionalAutoencoderAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix = 'RHR', target_column = 'anomaly', include_recovered_in_test = False, results_dir = 'results', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'cae' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: ConvolutionalAutoencoderAnomalyDetectionTrain(data, input_shape, participant_ids = None, validation_split = 0.1, augment = False, feature_column_prefix = 'RHR', target_column = 'anomaly', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'cae' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/cae2d/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/cae2d/index.rst.txt new file mode 100644 index 0000000..7dc358d --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/cae2d/index.rst.txt @@ -0,0 +1,78 @@ +ssl_tools.experiments.covid_detection.cae2d +=========================================== + +.. py:module:: ssl_tools.experiments.covid_detection.cae2d + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.cae2d.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.cae2d.ConvolutionalAutoencoder2DAnomalyDetectionTest + ssl_tools.experiments.covid_detection.cae2d.ConvolutionalAutoencoder2DAnomalyDetectionTrain + + +Module Contents +--------------- + +.. py:class:: ConvolutionalAutoencoder2DAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix = 'RHR', target_column = 'anomaly', include_recovered_in_test = False, results_dir = 'results', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'cae2d' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: ConvolutionalAutoencoder2DAnomalyDetectionTrain(data, input_shape, participant_ids = None, validation_split = 0.1, augment = False, feature_column_prefix = 'RHR', target_column = 'anomaly', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'cae2d' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/ccae/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/ccae/index.rst.txt new file mode 100644 index 0000000..8e52f16 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/ccae/index.rst.txt @@ -0,0 +1,78 @@ +ssl_tools.experiments.covid_detection.ccae +========================================== + +.. py:module:: ssl_tools.experiments.covid_detection.ccae + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.ccae.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.ccae.ConvolutionalAutoencoderAnomalyDetectionTest + ssl_tools.experiments.covid_detection.ccae.ConvolutionalAutoencoderAnomalyDetectionTrain + + +Module Contents +--------------- + +.. py:class:: ConvolutionalAutoencoderAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix = 'RHR', target_column = 'anomaly', include_recovered_in_test = False, results_dir = 'results', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'ccae' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: ConvolutionalAutoencoderAnomalyDetectionTrain(data, input_shape, participant_ids = None, validation_split = 0.1, augment = False, feature_column_prefix = 'RHR', target_column = 'anomaly', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'ccae' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/classfication_report/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/classfication_report/index.rst.txt new file mode 100644 index 0000000..eabdf0b --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/classfication_report/index.rst.txt @@ -0,0 +1,61 @@ +ssl_tools.experiments.covid_detection.classfication_report +========================================================== + +.. py:module:: ssl_tools.experiments.covid_detection.classfication_report + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.classfication_report._balanced_accuracy_score + ssl_tools.experiments.covid_detection.classfication_report._matthews_corrcoef + ssl_tools.experiments.covid_detection.classfication_report._roc_auc_score + ssl_tools.experiments.covid_detection.classfication_report.accuracy_score + ssl_tools.experiments.covid_detection.classfication_report.classification_report + ssl_tools.experiments.covid_detection.classfication_report.f1_score + ssl_tools.experiments.covid_detection.classfication_report.f2_score + ssl_tools.experiments.covid_detection.classfication_report.f2_score + ssl_tools.experiments.covid_detection.classfication_report.fbeta_score + ssl_tools.experiments.covid_detection.classfication_report.negative_precision_score + ssl_tools.experiments.covid_detection.classfication_report.precision_score + ssl_tools.experiments.covid_detection.classfication_report.recall_score + ssl_tools.experiments.covid_detection.classfication_report.specificity_score + ssl_tools.experiments.covid_detection.classfication_report.uar_score + ssl_tools.experiments.covid_detection.classfication_report.wrap_zero_div + + +Module Contents +--------------- + +.. py:function:: _balanced_accuracy_score(y_true, y_pred, labels) + +.. py:function:: _matthews_corrcoef(y_true, y_pred, labels) + +.. py:function:: _roc_auc_score(y_true, y_pred, labels) + +.. py:function:: accuracy_score(tn, fp, fn, tp) + +.. py:function:: classification_report(y_true, y_pred, labels=None) + +.. py:function:: f1_score(tn, fp, fn, tp) + +.. py:function:: f2_score(tn, fp, fn, tp) + +.. py:function:: f2_score(tn, fp, fn, tp) + +.. py:function:: fbeta_score(tn, fp, fn, tp, beta=0.1) + +.. py:function:: negative_precision_score(tn, fp, fn, tp) + +.. py:function:: precision_score(tn, fp, fn, tp) + +.. py:function:: recall_score(tn, fp, fn, tp) + +.. py:function:: specificity_score(tn, fp, fn, tp) + +.. py:function:: uar_score(tn, fp, fn, tp) + +.. py:function:: wrap_zero_div(func) + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/classification_base/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/classification_base/index.rst.txt new file mode 100644 index 0000000..a50ef23 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/classification_base/index.rst.txt @@ -0,0 +1,61 @@ +ssl_tools.experiments.covid_detection.classification_base +========================================================= + +.. py:module:: ssl_tools.experiments.covid_detection.classification_base + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.classification_base.CovidDetectionEvaluator + ssl_tools.experiments.covid_detection.classification_base.CovidDetectionTrain + + +Module Contents +--------------- + +.. py:class:: CovidDetectionEvaluator(data, feature_column_prefix = 'RHR', target_column = 'anomaly', results_file = 'results.csv', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTest` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: run_model(model, data_module, trainer) + + +.. py:class:: CovidDetectionTrain(data, reshape = None, validation_split = 0.1, balance = False, feature_column_prefix = 'RHR', target_column = 'anomaly', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/index.rst.txt new file mode 100644 index 0000000..2915532 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/index.rst.txt @@ -0,0 +1,22 @@ +ssl_tools.experiments.covid_detection +===================================== + +.. py:module:: ssl_tools.experiments.covid_detection + + +Submodules +---------- + +.. toctree:: + :maxdepth: 1 + + /autoapi/ssl_tools/experiments/covid_detection/anomaly_detection_base/index + /autoapi/ssl_tools/experiments/covid_detection/cae/index + /autoapi/ssl_tools/experiments/covid_detection/cae2d/index + /autoapi/ssl_tools/experiments/covid_detection/ccae/index + /autoapi/ssl_tools/experiments/covid_detection/classfication_report/index + /autoapi/ssl_tools/experiments/covid_detection/classification_base/index + /autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index + /autoapi/ssl_tools/experiments/covid_detection/mlp/index + + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index.rst.txt new file mode 100644 index 0000000..bf5da3d --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index.rst.txt @@ -0,0 +1,78 @@ +ssl_tools.experiments.covid_detection.lstm_ae +============================================= + +.. py:module:: ssl_tools.experiments.covid_detection.lstm_ae + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.lstm_ae.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.lstm_ae.LSTMAutoencoderAnomalyDetectionTest + ssl_tools.experiments.covid_detection.lstm_ae.LSTMAutoencoderAnomalyDetectionTrain + + +Module Contents +--------------- + +.. py:class:: LSTMAutoencoderAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix = 'RHR', target_column = 'anomaly', include_recovered_in_test = False, results_dir = 'results', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'lstm-ae' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: LSTMAutoencoderAnomalyDetectionTrain(data, input_shape, participant_ids = None, validation_split = 0.1, augment = False, feature_column_prefix = 'RHR', target_column = 'anomaly', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'lstm-ae' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/covid_detection/mlp/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/covid_detection/mlp/index.rst.txt new file mode 100644 index 0000000..23401a7 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/covid_detection/mlp/index.rst.txt @@ -0,0 +1,87 @@ +ssl_tools.experiments.covid_detection.mlp +========================================= + +.. py:module:: ssl_tools.experiments.covid_detection.mlp + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.mlp.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.covid_detection.mlp.FlattenBCELoss + ssl_tools.experiments.covid_detection.mlp.MLPClassifierTest + ssl_tools.experiments.covid_detection.mlp.MLPClassifierTrain + + +Module Contents +--------------- + +.. py:class:: FlattenBCELoss + + Bases: :py:obj:`torch.nn.BCELoss` + + + .. py:method:: forward(input, target) + + +.. py:class:: MLPClassifierTest(input_size = 16, hidden_size = 128, num_hidden_layers = 1, num_classes = 1, learning_rate = 0.001, *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.classification_base.CovidDetectionEvaluator` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'mlp' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: MLPClassifierTrain(input_size = 16, hidden_size = 128, num_hidden_layers = 1, num_classes = 1, learning_rate = 0.001, *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.covid_detection.classification_base.CovidDetectionTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'mlp' + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/experiment/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/experiment/index.rst.txt index 4d210c2..2489c7a 100644 --- a/_sources/autoapi/ssl_tools/experiments/experiment/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/experiment/index.rst.txt @@ -1,23 +1,27 @@ -:py:mod:`ssl_tools.experiments.experiment` -========================================== +ssl_tools.experiments.experiment +================================ .. py:module:: ssl_tools.experiments.experiment -Module Contents ---------------- +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.experiment.EXPERIMENT_VERSION_FORMAT + Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.experiments.experiment.Experiment - Functions -~~~~~~~~~ +--------- .. autoapisummary:: @@ -25,31 +29,21 @@ Functions ssl_tools.experiments.experiment.get_parser - -Attributes -~~~~~~~~~~ - -.. autoapisummary:: - - ssl_tools.experiments.experiment.EXPERIMENT_VERSION_FORMAT - +Module Contents +--------------- .. py:data:: EXPERIMENT_VERSION_FORMAT :value: '%Y-%m-%d_%H-%M-%S' - .. py:class:: Experiment(name = 'experiment', run_id = None, log_dir = 'logs', seed = None) - Bases: :py:obj:`abc.ABC` + Helper class that provides a standard way to create an ABC using inheritance. - .. py:property:: experiment_dir - :type: pathlib.Path - .. py:method:: __call__() @@ -59,28 +53,33 @@ Attributes Return repr(self). + .. py:method:: __str__() Return str(self). + .. py:method:: execute() + .. py:property:: experiment_dir + :type: pathlib.Path + + + .. py:method:: run() :abstractmethod: + .. py:method:: setup() .. py:method:: teardown() - -.. py:function:: auto_main(commands) - +.. py:function:: auto_main(commands, print_args = False) .. py:function:: get_parser(commands) - diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/_classification_base/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/_classification_base/index.rst.txt new file mode 100644 index 0000000..d0b10a2 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/_classification_base/index.rst.txt @@ -0,0 +1,75 @@ +ssl_tools.experiments.har_classification._classification_base +============================================================= + +.. py:module:: ssl_tools.experiments.har_classification._classification_base + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification._classification_base.EvaluatorBase + ssl_tools.experiments.har_classification._classification_base.PredictionHeadClassifier + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification._classification_base.full_dataset_from_dataloader + ssl_tools.experiments.har_classification._classification_base.generate_embeddings + ssl_tools.experiments.har_classification._classification_base.get_full_data_split + ssl_tools.experiments.har_classification._classification_base.get_split_dataloader + + +Module Contents +--------------- + +.. py:class:: EvaluatorBase(results_file = 'results.csv', confusion_matrix_file = 'confusion_matrix.csv', confusion_matrix_image_file = 'confusion_matrix.png', tsne_plot_file = 'tsne_embeddings.png', embedding_file = 'embeddings.csv', predictions_file = 'predictions.csv', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTest` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:method:: _compute_classification_metrics(y_hat_logits, y, n_classes) + + + .. py:method:: _compute_embeddings(model, data_module, trainer) + + + .. py:method:: _plot_confusion_matrix(y_hat, y, n_classes, cm_file, cm_image_file) + + + .. py:method:: _plot_tnse_embeddings(embeddings, y, y_hat, n_components = 2, tsne_plot_file = 'tsne_embeddings.png') + + + .. py:method:: evaluate_embeddings(model, data_module, trainer) + + + .. py:method:: evaluate_model_performance(model, data_module, trainer) + + + .. py:method:: predict(model, dataloader, trainer) + + + .. py:method:: run_model(model, data_module, trainer) + + +.. py:class:: PredictionHeadClassifier(prediction_head, num_classes = 6) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + +.. py:function:: full_dataset_from_dataloader(dataloader) + +.. py:function:: generate_embeddings(model, dataloader, trainer) + +.. py:function:: get_full_data_split(data_module, stage) + +.. py:function:: get_split_dataloader(stage, data_module) + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/cpc/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/cpc/index.rst.txt index 6b009cf..3812b65 100644 --- a/_sources/autoapi/ssl_tools/experiments/har_classification/cpc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/cpc/index.rst.txt @@ -1,36 +1,34 @@ -:py:mod:`ssl_tools.experiments.har_classification.cpc` -====================================================== +ssl_tools.experiments.har_classification.cpc +============================================ .. py:module:: ssl_tools.experiments.har_classification.cpc -Module Contents ---------------- - -Classes -~~~~~~~ +Attributes +---------- .. autoapisummary:: - ssl_tools.experiments.har_classification.cpc.CPCTest - ssl_tools.experiments.har_classification.cpc.CPCTrain - - + ssl_tools.experiments.har_classification.cpc.options -Attributes -~~~~~~~~~~ +Classes +------- .. autoapisummary:: - ssl_tools.experiments.har_classification.cpc.options + ssl_tools.experiments.har_classification.cpc.CPCTest + ssl_tools.experiments.har_classification.cpc.CPCTrain -.. py:class:: CPCTest(data, encoding_size = 150, in_channel = 6, window_size = 4, num_classes = 6, *args, **kwargs) +Module Contents +--------------- +.. py:class:: CPCTest(data, encoding_size = 150, in_channel = 6, window_size = 4, num_classes = 6, *args, **kwargs) Bases: :py:obj:`ssl_tools.experiments.LightningTest` + Helper class that provides a standard way to create an ABC using inheritance. @@ -50,10 +48,11 @@ Attributes If True, the backbone will be updated during training. Only used in finetune mode. + .. py:attribute:: _MODEL_NAME :value: 'CPC' - + .. py:method:: get_data_module() @@ -65,6 +64,7 @@ Attributes The datamodule to use for the experiment + .. py:method:: get_model(load_backbone = None) Get the model to use for the experiment. @@ -78,9 +78,9 @@ Attributes .. py:class:: CPCTrain(data, encoding_size = 150, in_channel = 6, window_size = 4, pad_length = False, num_classes = 6, update_backbone = False, *args, **kwargs) - Bases: :py:obj:`ssl_tools.experiments.LightningSSLTrain` + Helper class that provides a standard way to create an ABC using inheritance. @@ -103,10 +103,11 @@ Attributes If True, the backbone will be updated during training. Only used in finetune mode. + .. py:attribute:: _MODEL_NAME :value: 'CPC' - + .. py:method:: get_finetune_data_module() @@ -123,6 +124,7 @@ Attributes _description_ + .. py:method:: get_finetune_model(load_backbone = None) Get the model to use for fine-tuning. @@ -139,6 +141,7 @@ Attributes The model to use for fine-tuning + .. py:method:: get_pretrain_data_module() The data module to use for pre-training. @@ -149,6 +152,7 @@ Attributes The data module to use for pre-training + .. py:method:: get_pretrain_model() Get the model to use for the pretraining phase. @@ -162,5 +166,3 @@ Attributes .. py:data:: options - - diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/gru_encoder/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/gru_encoder/index.rst.txt new file mode 100644 index 0000000..eafa399 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/gru_encoder/index.rst.txt @@ -0,0 +1,106 @@ +ssl_tools.experiments.har_classification.gru_encoder +==================================================== + +.. py:module:: ssl_tools.experiments.har_classification.gru_encoder + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.gru_encoder.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.gru_encoder.GRUClassifier + ssl_tools.experiments.har_classification.gru_encoder.GRUClassifierTest + ssl_tools.experiments.har_classification.gru_encoder.GRUClassifierTrain + + +Module Contents +--------------- + +.. py:class:: GRUClassifier(hidden_size = 100, in_channels = 6, num_classes = 6, encoding_size = 100, num_layers = 1, dropout = 0.0, bidirectional = True) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + +.. py:class:: GRUClassifierTest(data, hidden_size = 100, in_channels = 6, num_classes = 6, encoding_size = 100, num_layers = 1, dropout = 0.0, bidirectional = True, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.har_classification._classification_base.EvaluatorBase` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'GRU' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: GRUClassifierTrain(data, hidden_size = 100, in_channels = 6, num_classes = 6, encoding_size = 100, num_layers = 1, dropout = 0.0, bidirectional = True, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'GRU' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/index.rst.txt index 8b72d27..2221567 100644 --- a/_sources/autoapi/ssl_tools/experiments/har_classification/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/index.rst.txt @@ -1,17 +1,25 @@ -:py:mod:`ssl_tools.experiments.har_classification` -================================================== +ssl_tools.experiments.har_classification +======================================== .. py:module:: ssl_tools.experiments.har_classification Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - cpc/index.rst - tfc/index.rst - tnc/index.rst + /autoapi/ssl_tools/experiments/har_classification/_classification_base/index + /autoapi/ssl_tools/experiments/har_classification/cpc/index + /autoapi/ssl_tools/experiments/har_classification/gru_encoder/index + /autoapi/ssl_tools/experiments/har_classification/mlp_classifier/index + /autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index + /autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index + /autoapi/ssl_tools/experiments/har_classification/tfc/index + /autoapi/ssl_tools/experiments/har_classification/tfc_head_classifier/index + /autoapi/ssl_tools/experiments/har_classification/tnc/index + /autoapi/ssl_tools/experiments/har_classification/tnc_head_classifier/index + /autoapi/ssl_tools/experiments/har_classification/utils/index diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/mlp_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/mlp_classifier/index.rst.txt new file mode 100644 index 0000000..87a20d8 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/mlp_classifier/index.rst.txt @@ -0,0 +1,100 @@ +ssl_tools.experiments.har_classification.mlp_classifier +======================================================= + +.. py:module:: ssl_tools.experiments.har_classification.mlp_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.mlp_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.mlp_classifier.MLPClassifierTest + ssl_tools.experiments.har_classification.mlp_classifier.MLPClassifierTrain + + +Module Contents +--------------- + +.. py:class:: MLPClassifierTest(data, input_size = 360, hidden_size = 64, num_hidden_layers = 1, num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.har_classification._classification_base.EvaluatorBase` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'MLP' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: MLPClassifierTrain(data, input_size = 360, hidden_size = 64, num_hidden_layers = 1, num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'MLP' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index.rst.txt new file mode 100644 index 0000000..25b28cb --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index.rst.txt @@ -0,0 +1,100 @@ +ssl_tools.experiments.har_classification.simple1Dconv_classifier +================================================================ + +.. py:module:: ssl_tools.experiments.har_classification.simple1Dconv_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.simple1Dconv_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.simple1Dconv_classifier.Simple1DConvNetTest + ssl_tools.experiments.har_classification.simple1Dconv_classifier.Simple1DConvNetTrain + + +Module Contents +--------------- + +.. py:class:: Simple1DConvNetTest(data, input_shape = (6, 60), num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.har_classification._classification_base.EvaluatorBase` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'Simple1DConvNet' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: Simple1DConvNetTrain(data, input_shape = (6, 60), num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'Simple1DConvNet' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index.rst.txt new file mode 100644 index 0000000..0edc2fa --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index.rst.txt @@ -0,0 +1,100 @@ +ssl_tools.experiments.har_classification.simple2Dconv_classifier +================================================================ + +.. py:module:: ssl_tools.experiments.har_classification.simple2Dconv_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.simple2Dconv_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.simple2Dconv_classifier.Simple2DConvNetTest + ssl_tools.experiments.har_classification.simple2Dconv_classifier.Simple2DConvNetTrain + + +Module Contents +--------------- + +.. py:class:: Simple2DConvNetTest(data, input_shape = (6, 1, 60), num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.har_classification._classification_base.EvaluatorBase` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'Simple2DConvNet' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: Simple2DConvNetTrain(data, input_shape = (6, 1, 60), num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'Simple2DConvNet' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/tfc/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/tfc/index.rst.txt index 6a1c633..bd1a783 100644 --- a/_sources/autoapi/ssl_tools/experiments/har_classification/tfc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/tfc/index.rst.txt @@ -1,36 +1,34 @@ -:py:mod:`ssl_tools.experiments.har_classification.tfc` -====================================================== +ssl_tools.experiments.har_classification.tfc +============================================ .. py:module:: ssl_tools.experiments.har_classification.tfc -Module Contents ---------------- - -Classes -~~~~~~~ +Attributes +---------- .. autoapisummary:: - ssl_tools.experiments.har_classification.tfc.TFCTest - ssl_tools.experiments.har_classification.tfc.TFCTrain - - + ssl_tools.experiments.har_classification.tfc.options -Attributes -~~~~~~~~~~ +Classes +------- .. autoapisummary:: - ssl_tools.experiments.har_classification.tfc.options + ssl_tools.experiments.har_classification.tfc.TFCTest + ssl_tools.experiments.har_classification.tfc.TFCTrain -.. py:class:: TFCTest(data, label = 'standard activity code', encoding_size = 128, in_channels = 6, length_alignment = 178, use_cosine_similarity = True, temperature = 0.5, features_as_channels = False, num_classes = 6, *args, **kwargs) +Module Contents +--------------- +.. py:class:: TFCTest(data, label = 'standard activity code', encoding_size = 128, in_channels = 6, length_alignment = 178, use_cosine_similarity = True, temperature = 0.5, features_as_channels = False, num_classes = 6, *args, **kwargs) Bases: :py:obj:`ssl_tools.experiments.LightningTest` + Helper class that provides a standard way to create an ABC using inheritance. @@ -72,10 +70,11 @@ Attributes If True, the backbone will be updated during training. Only used in finetune mode. + .. py:attribute:: _MODEL_NAME :value: 'TFC' - + .. py:method:: get_data_module() @@ -87,6 +86,7 @@ Attributes The datamodule to use for the experiment + .. py:method:: get_model() Get the model to use for the experiment. @@ -100,9 +100,9 @@ Attributes .. py:class:: TFCTrain(data, label = 'standard activity code', encoding_size = 128, in_channels = 6, length_alignment = 178, use_cosine_similarity = True, temperature = 0.5, features_as_channels = False, jitter_ratio = 2, num_classes = 6, update_backbone = False, *args, **kwargs) - Bases: :py:obj:`ssl_tools.experiments.LightningSSLTrain` + Helper class that provides a standard way to create an ABC using inheritance. @@ -144,10 +144,11 @@ Attributes If True, the backbone will be updated during training. Only used in finetune mode. + .. py:attribute:: _MODEL_NAME :value: 'TFC' - + .. py:method:: get_finetune_data_module() @@ -164,6 +165,7 @@ Attributes _description_ + .. py:method:: get_finetune_model(load_backbone = None) Get the model to use for fine-tuning. @@ -180,6 +182,7 @@ Attributes The model to use for fine-tuning + .. py:method:: get_pretrain_data_module() The data module to use for pre-training. @@ -190,6 +193,7 @@ Attributes The data module to use for pre-training + .. py:method:: get_pretrain_model() Get the model to use for the pretraining phase. @@ -203,5 +207,3 @@ Attributes .. py:data:: options - - diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/tfc_head_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/tfc_head_classifier/index.rst.txt new file mode 100644 index 0000000..45e0574 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/tfc_head_classifier/index.rst.txt @@ -0,0 +1,100 @@ +ssl_tools.experiments.har_classification.tfc_head_classifier +============================================================ + +.. py:module:: ssl_tools.experiments.har_classification.tfc_head_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.tfc_head_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.tfc_head_classifier.TFCHeadClassifierTest + ssl_tools.experiments.har_classification.tfc_head_classifier.TFCHeadClassifierTrain + + +Module Contents +--------------- + +.. py:class:: TFCHeadClassifierTest(data, input_size = 360, num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.har_classification._classification_base.EvaluatorBase` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'TFCPredictionHead' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: TFCHeadClassifierTrain(data, input_size = 360, num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'TFCPredictionHead' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/tnc/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/tnc/index.rst.txt index e0dba55..4f861dc 100644 --- a/_sources/autoapi/ssl_tools/experiments/har_classification/tnc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/tnc/index.rst.txt @@ -1,36 +1,34 @@ -:py:mod:`ssl_tools.experiments.har_classification.tnc` -====================================================== +ssl_tools.experiments.har_classification.tnc +============================================ .. py:module:: ssl_tools.experiments.har_classification.tnc -Module Contents ---------------- - -Classes -~~~~~~~ +Attributes +---------- .. autoapisummary:: - ssl_tools.experiments.har_classification.tnc.TNCTest - ssl_tools.experiments.har_classification.tnc.TNCTrain - - + ssl_tools.experiments.har_classification.tnc.options -Attributes -~~~~~~~~~~ +Classes +------- .. autoapisummary:: - ssl_tools.experiments.har_classification.tnc.options + ssl_tools.experiments.har_classification.tnc.TNCTest + ssl_tools.experiments.har_classification.tnc.TNCTrain -.. py:class:: TNCTest(data, encoding_size = 10, in_channel = 6, window_size = 60, mc_sample_size = 20, w = 0.05, num_classes = 6, *args, **kwargs) +Module Contents +--------------- +.. py:class:: TNCTest(data, encoding_size = 10, in_channel = 6, window_size = 60, mc_sample_size = 20, w = 0.05, num_classes = 6, *args, **kwargs) Bases: :py:obj:`ssl_tools.experiments.LightningTest` + Helper class that provides a standard way to create an ABC using inheritance. @@ -48,10 +46,11 @@ Attributes If True, the samples are padded to the length of the longest sample in the dataset. + .. py:attribute:: _MODEL_NAME :value: 'TNC' - + .. py:method:: get_data_module() @@ -63,6 +62,7 @@ Attributes The datamodule to use for the experiment + .. py:method:: get_model() Get the model to use for the experiment. @@ -76,9 +76,9 @@ Attributes .. py:class:: TNCTrain(data, encoding_size = 10, in_channel = 6, window_size = 60, mc_sample_size = 20, w = 0.05, significance_level = 0.01, repeat = 5, pad_length = True, num_classes = 6, update_backbone = False, *args, **kwargs) - Bases: :py:obj:`ssl_tools.experiments.LightningSSLTrain` + Helper class that provides a standard way to create an ABC using inheritance. @@ -101,10 +101,11 @@ Attributes If True, the backbone will be updated during training. Only used in finetune mode. + .. py:attribute:: _MODEL_NAME :value: 'TNC' - + .. py:method:: get_finetune_data_module() @@ -121,6 +122,7 @@ Attributes _description_ + .. py:method:: get_finetune_model(load_backbone = None) Get the model to use for fine-tuning. @@ -137,6 +139,7 @@ Attributes The model to use for fine-tuning + .. py:method:: get_pretrain_data_module() The data module to use for pre-training. @@ -147,6 +150,7 @@ Attributes The data module to use for pre-training + .. py:method:: get_pretrain_model() Get the model to use for the pretraining phase. @@ -160,5 +164,3 @@ Attributes .. py:data:: options - - diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/tnc_head_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/tnc_head_classifier/index.rst.txt new file mode 100644 index 0000000..09ded08 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/tnc_head_classifier/index.rst.txt @@ -0,0 +1,100 @@ +ssl_tools.experiments.har_classification.tnc_head_classifier +============================================================ + +.. py:module:: ssl_tools.experiments.har_classification.tnc_head_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.tnc_head_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.tnc_head_classifier.TNCHeadClassifierTest + ssl_tools.experiments.har_classification.tnc_head_classifier.TNCHeadClassifierTrain + + +Module Contents +--------------- + +.. py:class:: TNCHeadClassifierTest(data, input_size = 360, num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.har_classification._classification_base.EvaluatorBase` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'TNCPredictionHead' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:class:: TNCHeadClassifierTrain(data, input_size = 360, num_classes = 6, transforms = 'identity', *args, **kwargs) + + Bases: :py:obj:`ssl_tools.experiments.LightningTrain` + + + Helper class that provides a standard way to create an ABC using + inheritance. + + + .. py:attribute:: _MODEL_NAME + :value: 'TNCPredictionHead' + + + + .. py:method:: get_data_module() + + Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + + + + .. py:method:: get_model() + + Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/experiments/har_classification/utils/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/har_classification/utils/index.rst.txt new file mode 100644 index 0000000..a8948a2 --- /dev/null +++ b/_sources/autoapi/ssl_tools/experiments/har_classification/utils/index.rst.txt @@ -0,0 +1,87 @@ +ssl_tools.experiments.har_classification.utils +============================================== + +.. py:module:: ssl_tools.experiments.har_classification.utils + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.experiments.har_classification.utils.DimensionAdder + ssl_tools.experiments.har_classification.utils.FFT + ssl_tools.experiments.har_classification.utils.Flatten + ssl_tools.experiments.har_classification.utils.Spectrogram + + +Module Contents +--------------- + +.. py:class:: DimensionAdder(dim) + + .. py:method:: __call__(x) + + +.. py:class:: FFT(absolute = True, centered = False) + + .. py:method:: __call__(x) + + Aplly FFT to the input signal. It apply the FFT into each channel + of the input signal. + + Parameters + ---------- + x : np.ndarray + An array with shape (n_channels, n_samples) containing the input + + Returns + ------- + np.ndarray + The FFT of the input signal. The shape of the output is + (n_channels, n_samples) if absolute is False, and + (n_channels, n_samples//2) if absolute is True. + + + +.. py:class:: Flatten + + .. py:method:: __call__(x) + + Flatten the input signal. It apply the flatten into each channel + of the input signal. + + Parameters + ---------- + x : np.ndarray + An array with shape (n_channels, n_samples) containing the input + + Returns + ------- + np.ndarray + The flatten of the input signal. The shape of the output is + (n_channels, n_samples). + + + +.. py:class:: Spectrogram(fs=20, nperseg=16, noverlap=8, nfft=16) + + .. py:method:: __call__(x) + + Aplly Spectrogram to the input signal. It apply the Spectrogram into each channel + of the input signal. + + Parameters + ---------- + x : np.ndarray + An array with shape (n_channels, n_samples) containing the input + + Returns + ------- + np.ndarray + The Spectrogram of the input signal. The shape of the output is + (n_channels, n_samples) if absolute is False, and + (n_channels, n_samples//2) if absolute is True. + + + diff --git a/_sources/autoapi/ssl_tools/experiments/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/index.rst.txt index 3f75e67..e9a29f6 100644 --- a/_sources/autoapi/ssl_tools/experiments/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/index.rst.txt @@ -1,33 +1,31 @@ -:py:mod:`ssl_tools.experiments` -=============================== +ssl_tools.experiments +===================== .. py:module:: ssl_tools.experiments Subpackages ----------- + .. toctree:: - :titlesonly: - :maxdepth: 3 + :maxdepth: 1 - har_classification/index.rst + /autoapi/ssl_tools/experiments/covid_detection/index + /autoapi/ssl_tools/experiments/har_classification/index Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - experiment/index.rst - lightning_experiment/index.rst - + /autoapi/ssl_tools/experiments/experiment/index + /autoapi/ssl_tools/experiments/lightning_experiment/index -Package Contents ----------------- Classes -~~~~~~~ +------- .. autoapisummary:: @@ -38,27 +36,25 @@ Classes ssl_tools.experiments.LightningTrain - Functions -~~~~~~~~~ +--------- .. autoapisummary:: ssl_tools.experiments.auto_main +Package Contents +---------------- .. py:class:: Experiment(name = 'experiment', run_id = None, log_dir = 'logs', seed = None) - Bases: :py:obj:`abc.ABC` + Helper class that provides a standard way to create an ABC using inheritance. - .. py:property:: experiment_dir - :type: pathlib.Path - .. py:method:: __call__() @@ -68,84 +64,82 @@ Functions Return repr(self). + .. py:method:: __str__() Return str(self). + .. py:method:: execute() + .. py:property:: experiment_dir + :type: pathlib.Path + + + .. py:method:: run() :abstractmethod: + .. py:method:: setup() .. py:method:: teardown() - .. py:class:: LightningExperiment(name = None, stage_name = None, batch_size = 1, load = None, accelerator = 'cpu', devices = 1, strategy = 'auto', num_nodes = 1, num_workers = None, log_every_n_steps = 50, *args, **kwargs) - Bases: :py:obj:`ssl_tools.experiments.experiment.Experiment` + Helper class that provides a standard way to create an ABC using inheritance. - .. py:property:: callbacks - :type: List[lightning.Callback] + .. py:attribute:: _MODEL_NAME + :type: str + :value: 'model' - .. py:property:: checkpoint_dir - :type: pathlib.Path - .. py:property:: data_module - :type: lightning.LightningDataModule + .. py:attribute:: _STAGE_NAME + :type: str + :value: 'stage' - .. py:property:: experiment_dir - :type: pathlib.Path + .. py:method:: __str__() - .. py:property:: finished - :type: bool + Return str(self). - .. py:property:: hyperparameters - :type: dict + .. py:property:: callbacks + :type: List[lightning.Callback] - .. py:property:: logger - :type: lightning.pytorch.loggers.Logger - .. py:property:: model - :type: lightning.LightningModule + .. py:property:: checkpoint_dir + :type: pathlib.Path - .. py:property:: trainer - :type: lightning.Trainer + .. py:property:: data_module + :type: lightning.LightningDataModule - .. py:attribute:: _MODEL_NAME - :type: str - :value: 'model' - - .. py:attribute:: _STAGE_NAME - :type: str - :value: 'stage' + .. py:property:: experiment_dir + :type: pathlib.Path - - .. py:method:: __str__() - Return str(self). + .. py:property:: finished + :type: bool + .. py:method:: get_callbacks() @@ -158,9 +152,11 @@ Functions A list of callbacks to use for the experiment + .. py:method:: get_data_module() :abstractmethod: + Get the datamodule to use for the experiment. Returns @@ -169,6 +165,7 @@ Functions The datamodule to use for the experiment + .. py:method:: get_logger() Get the logger to use for the experiment. @@ -179,9 +176,11 @@ Functions The logger to use for the experiment + .. py:method:: get_model() :abstractmethod: + Get the model to use for the experiment. Returns @@ -190,9 +189,11 @@ Functions The model to use for the experiment + .. py:method:: get_trainer(logger, callbacks) :abstractmethod: + Get trainer to use for the experiment. Parameters @@ -208,6 +209,12 @@ Functions The trainer to use for the experiment + + .. py:property:: hyperparameters + :type: dict + + + .. py:method:: load_checkpoint(model, path) Load the model to use for the experiment. @@ -218,6 +225,7 @@ Functions The model to use for the experiment + .. py:method:: log_hyperparams(logger) Log the hyperparameters for reproducibility purposes. @@ -230,6 +238,17 @@ Functions The logger to use for logging the hyperparameters + + .. py:property:: logger + :type: lightning.pytorch.loggers.Logger + + + + .. py:property:: model + :type: lightning.LightningModule + + + .. py:method:: run() Runs the experiment. This method: @@ -241,19 +260,25 @@ Functions 5. Trains/Tests the model + .. py:method:: run_model(model, data_module, trainer) :abstractmethod: + .. py:method:: setup() + .. py:property:: trainer + :type: lightning.Trainer + -.. py:class:: LightningSSLTrain(training_mode = 'pretrain', load_backbone = None, *args, **kwargs) +.. py:class:: LightningSSLTrain(training_mode = 'pretrain', load_backbone = None, *args, **kwargs) Bases: :py:obj:`LightningTrain` + Helper class that provides a standard way to create an ABC using inheritance. @@ -270,6 +295,7 @@ Functions using ``load_backbone``. The ``load`` parameter is used to load the full model (backbone + head). + .. py:method:: get_data_module() Get the datamodule to use for the experiment. @@ -280,9 +306,11 @@ Functions The datamodule to use for the experiment + .. py:method:: get_finetune_data_module() :abstractmethod: + The data module to use for fine-tuning. Returns @@ -296,9 +324,11 @@ Functions _description_ + .. py:method:: get_finetune_model(load_backbone = None) :abstractmethod: + Get the model to use for fine-tuning. Parameters @@ -313,6 +343,7 @@ Functions The model to use for fine-tuning + .. py:method:: get_model() Get the model to use for the experiment. @@ -323,9 +354,11 @@ Functions The model to use for the experiment + .. py:method:: get_pretrain_data_module() :abstractmethod: + The data module to use for pre-training. Returns @@ -334,9 +367,11 @@ Functions The data module to use for pre-training + .. py:method:: get_pretrain_model() :abstractmethod: + Get the model to use for the pretraining phase. Returns @@ -348,16 +383,17 @@ Functions .. py:class:: LightningTest(limit_test_batches = 1.0, *args, **kwargs) - Bases: :py:obj:`LightningExperiment` + Helper class that provides a standard way to create an ABC using inheritance. + .. py:attribute:: _STAGE_NAME :value: 'test' - + .. py:method:: get_callbacks() @@ -369,6 +405,7 @@ Functions The list of callbacks to use for the experiment. + .. py:method:: get_trainer(logger, callbacks) Get trainer to use for the experiment. @@ -386,22 +423,23 @@ Functions The trainer to use for the experiment - .. py:method:: run_model(model, data_module, trainer) - + .. py:method:: run_model(model, data_module, trainer) -.. py:class:: LightningTrain(stage_name = 'train', epochs = 1, learning_rate = 0.001, checkpoint_metric = None, checkpoint_metric_mode = 'min', limit_train_batches = 1.0, limit_val_batches = 1.0, *args, **kwargs) +.. py:class:: LightningTrain(stage_name = 'train', epochs = 1, learning_rate = 0.001, checkpoint_metric = None, checkpoint_metric_mode = 'min', limit_train_batches = 1.0, limit_val_batches = 1.0, patience = None, *args, **kwargs) Bases: :py:obj:`LightningExperiment` + Helper class that provides a standard way to create an ABC using inheritance. + .. py:attribute:: _STAGE_NAME :value: 'train' - + .. py:method:: get_callbacks() @@ -413,6 +451,7 @@ Functions A list of callbacks to use for the experiment + .. py:method:: get_trainer(logger, callbacks) Get trainer to use for the experiment. @@ -430,10 +469,9 @@ Functions The trainer to use for the experiment - .. py:method:: run_model(model, data_module, trainer) - + .. py:method:: run_model(model, data_module, trainer) -.. py:function:: auto_main(commands) +.. py:function:: auto_main(commands, print_args = False) diff --git a/_sources/autoapi/ssl_tools/experiments/lightning_experiment/index.rst.txt b/_sources/autoapi/ssl_tools/experiments/lightning_experiment/index.rst.txt index 7f08952..e9e69a8 100644 --- a/_sources/autoapi/ssl_tools/experiments/lightning_experiment/index.rst.txt +++ b/_sources/autoapi/ssl_tools/experiments/lightning_experiment/index.rst.txt @@ -1,14 +1,11 @@ -:py:mod:`ssl_tools.experiments.lightning_experiment` -==================================================== +ssl_tools.experiments.lightning_experiment +========================================== .. py:module:: ssl_tools.experiments.lightning_experiment -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: @@ -18,67 +15,59 @@ Classes ssl_tools.experiments.lightning_experiment.LightningTrain - +Module Contents +--------------- .. py:class:: LightningExperiment(name = None, stage_name = None, batch_size = 1, load = None, accelerator = 'cpu', devices = 1, strategy = 'auto', num_nodes = 1, num_workers = None, log_every_n_steps = 50, *args, **kwargs) - Bases: :py:obj:`ssl_tools.experiments.experiment.Experiment` + Helper class that provides a standard way to create an ABC using inheritance. - .. py:property:: callbacks - :type: List[lightning.Callback] + .. py:attribute:: _MODEL_NAME + :type: str + :value: 'model' - .. py:property:: checkpoint_dir - :type: pathlib.Path - .. py:property:: data_module - :type: lightning.LightningDataModule + .. py:attribute:: _STAGE_NAME + :type: str + :value: 'stage' - .. py:property:: experiment_dir - :type: pathlib.Path + .. py:method:: __str__() - .. py:property:: finished - :type: bool + Return str(self). - .. py:property:: hyperparameters - :type: dict + .. py:property:: callbacks + :type: List[lightning.Callback] - .. py:property:: logger - :type: lightning.pytorch.loggers.Logger - .. py:property:: model - :type: lightning.LightningModule + .. py:property:: checkpoint_dir + :type: pathlib.Path - .. py:property:: trainer - :type: lightning.Trainer + .. py:property:: data_module + :type: lightning.LightningDataModule - .. py:attribute:: _MODEL_NAME - :type: str - :value: 'model' - - .. py:attribute:: _STAGE_NAME - :type: str - :value: 'stage' + .. py:property:: experiment_dir + :type: pathlib.Path - - .. py:method:: __str__() - Return str(self). + .. py:property:: finished + :type: bool + .. py:method:: get_callbacks() @@ -91,9 +80,11 @@ Classes A list of callbacks to use for the experiment + .. py:method:: get_data_module() :abstractmethod: + Get the datamodule to use for the experiment. Returns @@ -102,6 +93,7 @@ Classes The datamodule to use for the experiment + .. py:method:: get_logger() Get the logger to use for the experiment. @@ -112,9 +104,11 @@ Classes The logger to use for the experiment + .. py:method:: get_model() :abstractmethod: + Get the model to use for the experiment. Returns @@ -123,9 +117,11 @@ Classes The model to use for the experiment + .. py:method:: get_trainer(logger, callbacks) :abstractmethod: + Get trainer to use for the experiment. Parameters @@ -141,6 +137,12 @@ Classes The trainer to use for the experiment + + .. py:property:: hyperparameters + :type: dict + + + .. py:method:: load_checkpoint(model, path) Load the model to use for the experiment. @@ -151,6 +153,7 @@ Classes The model to use for the experiment + .. py:method:: log_hyperparams(logger) Log the hyperparameters for reproducibility purposes. @@ -163,6 +166,17 @@ Classes The logger to use for logging the hyperparameters + + .. py:property:: logger + :type: lightning.pytorch.loggers.Logger + + + + .. py:property:: model + :type: lightning.LightningModule + + + .. py:method:: run() Runs the experiment. This method: @@ -174,19 +188,25 @@ Classes 5. Trains/Tests the model + .. py:method:: run_model(model, data_module, trainer) :abstractmethod: + .. py:method:: setup() + .. py:property:: trainer + :type: lightning.Trainer -.. py:class:: LightningSSLTrain(training_mode = 'pretrain', load_backbone = None, *args, **kwargs) +.. py:class:: LightningSSLTrain(training_mode = 'pretrain', load_backbone = None, *args, **kwargs) + Bases: :py:obj:`LightningTrain` + Helper class that provides a standard way to create an ABC using inheritance. @@ -203,6 +223,7 @@ Classes using ``load_backbone``. The ``load`` parameter is used to load the full model (backbone + head). + .. py:method:: get_data_module() Get the datamodule to use for the experiment. @@ -213,9 +234,11 @@ Classes The datamodule to use for the experiment + .. py:method:: get_finetune_data_module() :abstractmethod: + The data module to use for fine-tuning. Returns @@ -229,9 +252,11 @@ Classes _description_ + .. py:method:: get_finetune_model(load_backbone = None) :abstractmethod: + Get the model to use for fine-tuning. Parameters @@ -246,6 +271,7 @@ Classes The model to use for fine-tuning + .. py:method:: get_model() Get the model to use for the experiment. @@ -256,9 +282,11 @@ Classes The model to use for the experiment + .. py:method:: get_pretrain_data_module() :abstractmethod: + The data module to use for pre-training. Returns @@ -267,9 +295,11 @@ Classes The data module to use for pre-training + .. py:method:: get_pretrain_model() :abstractmethod: + Get the model to use for the pretraining phase. Returns @@ -281,16 +311,17 @@ Classes .. py:class:: LightningTest(limit_test_batches = 1.0, *args, **kwargs) - Bases: :py:obj:`LightningExperiment` + Helper class that provides a standard way to create an ABC using inheritance. + .. py:attribute:: _STAGE_NAME :value: 'test' - + .. py:method:: get_callbacks() @@ -302,6 +333,7 @@ Classes The list of callbacks to use for the experiment. + .. py:method:: get_trainer(logger, callbacks) Get trainer to use for the experiment. @@ -319,22 +351,23 @@ Classes The trainer to use for the experiment - .. py:method:: run_model(model, data_module, trainer) - + .. py:method:: run_model(model, data_module, trainer) -.. py:class:: LightningTrain(stage_name = 'train', epochs = 1, learning_rate = 0.001, checkpoint_metric = None, checkpoint_metric_mode = 'min', limit_train_batches = 1.0, limit_val_batches = 1.0, *args, **kwargs) +.. py:class:: LightningTrain(stage_name = 'train', epochs = 1, learning_rate = 0.001, checkpoint_metric = None, checkpoint_metric_mode = 'min', limit_train_batches = 1.0, limit_val_batches = 1.0, patience = None, *args, **kwargs) Bases: :py:obj:`LightningExperiment` + Helper class that provides a standard way to create an ABC using inheritance. + .. py:attribute:: _STAGE_NAME :value: 'train' - + .. py:method:: get_callbacks() @@ -346,6 +379,7 @@ Classes A list of callbacks to use for the experiment + .. py:method:: get_trainer(logger, callbacks) Get trainer to use for the experiment. @@ -363,7 +397,7 @@ Classes The trainer to use for the experiment - .. py:method:: run_model(model, data_module, trainer) + .. py:method:: run_model(model, data_module, trainer) diff --git a/_sources/autoapi/ssl_tools/index.rst.txt b/_sources/autoapi/ssl_tools/index.rst.txt index 4115582..51fe6a3 100644 --- a/_sources/autoapi/ssl_tools/index.rst.txt +++ b/_sources/autoapi/ssl_tools/index.rst.txt @@ -1,22 +1,24 @@ -:py:mod:`ssl_tools` -=================== +ssl_tools +========= .. py:module:: ssl_tools Subpackages ----------- + .. toctree:: - :titlesonly: - :maxdepth: 3 + :maxdepth: 1 - analysis/index.rst - callbacks/index.rst - data/index.rst - experiments/index.rst - losses/index.rst - models/index.rst - transforms/index.rst - utils/index.rst + /autoapi/ssl_tools/analysis/index + /autoapi/ssl_tools/benchmarks/index + /autoapi/ssl_tools/callbacks/index + /autoapi/ssl_tools/data/index + /autoapi/ssl_tools/experiments/index + /autoapi/ssl_tools/losses/index + /autoapi/ssl_tools/models/index + /autoapi/ssl_tools/pipelines/index + /autoapi/ssl_tools/transforms/index + /autoapi/ssl_tools/utils/index diff --git a/_sources/autoapi/ssl_tools/losses/contrastive_loss/index.rst.txt b/_sources/autoapi/ssl_tools/losses/contrastive_loss/index.rst.txt new file mode 100644 index 0000000..31e1f84 --- /dev/null +++ b/_sources/autoapi/ssl_tools/losses/contrastive_loss/index.rst.txt @@ -0,0 +1,25 @@ +ssl_tools.losses.contrastive_loss +================================= + +.. py:module:: ssl_tools.losses.contrastive_loss + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.losses.contrastive_loss.ContrastiveLoss + + +Module Contents +--------------- + +.. py:class:: ContrastiveLoss(margin = 1.0) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(y_true, y_pred) + + diff --git a/_sources/autoapi/ssl_tools/losses/index.rst.txt b/_sources/autoapi/ssl_tools/losses/index.rst.txt index 905d1a6..b354bfb 100644 --- a/_sources/autoapi/ssl_tools/losses/index.rst.txt +++ b/_sources/autoapi/ssl_tools/losses/index.rst.txt @@ -1,15 +1,16 @@ -:py:mod:`ssl_tools.losses` -========================== +ssl_tools.losses +================ .. py:module:: ssl_tools.losses Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - nxtent/index.rst + /autoapi/ssl_tools/losses/contrastive_loss/index + /autoapi/ssl_tools/losses/nxtent/index diff --git a/_sources/autoapi/ssl_tools/losses/nxtent/index.rst.txt b/_sources/autoapi/ssl_tools/losses/nxtent/index.rst.txt index c3e369a..10bcfdb 100644 --- a/_sources/autoapi/ssl_tools/losses/nxtent/index.rst.txt +++ b/_sources/autoapi/ssl_tools/losses/nxtent/index.rst.txt @@ -1,27 +1,25 @@ -:py:mod:`ssl_tools.losses.nxtent` -================================= +ssl_tools.losses.nxtent +======================= .. py:module:: ssl_tools.losses.nxtent -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.losses.nxtent.NTXentLoss_poly - +Module Contents +--------------- .. py:class:: NTXentLoss_poly(temperature = 0.2, use_cosine_similarity = True) - Bases: :py:obj:`lightning.LightningModule` + .. py:method:: _cosine_simililarity(x, y) @@ -29,6 +27,7 @@ Classes :staticmethod: + .. py:method:: _get_correlated_mask(batch_size) @@ -38,4 +37,3 @@ Classes .. py:method:: forward(zis, zjs) - diff --git a/_sources/autoapi/ssl_tools/models/index.rst.txt b/_sources/autoapi/ssl_tools/models/index.rst.txt index e84c716..f2ef03c 100644 --- a/_sources/autoapi/ssl_tools/models/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/index.rst.txt @@ -1,17 +1,26 @@ -:py:mod:`ssl_tools.models` -========================== +ssl_tools.models +================ .. py:module:: ssl_tools.models Subpackages ----------- + +.. toctree:: + :maxdepth: 1 + + /autoapi/ssl_tools/models/layers/index + /autoapi/ssl_tools/models/nets/index + /autoapi/ssl_tools/models/ssl/index + + +Submodules +---------- + .. toctree:: - :titlesonly: - :maxdepth: 3 + :maxdepth: 1 - layers/index.rst - nets/index.rst - ssl/index.rst + /autoapi/ssl_tools/models/utils/index diff --git a/_sources/autoapi/ssl_tools/models/layers/gru/index.rst.txt b/_sources/autoapi/ssl_tools/models/layers/gru/index.rst.txt index e47f3b1..fdcb487 100644 --- a/_sources/autoapi/ssl_tools/models/layers/gru/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/layers/gru/index.rst.txt @@ -1,27 +1,25 @@ -:py:mod:`ssl_tools.models.layers.gru` -===================================== +ssl_tools.models.layers.gru +=========================== .. py:module:: ssl_tools.models.layers.gru -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.models.layers.gru.GRUEncoder - +Module Contents +--------------- .. py:class:: GRUEncoder(hidden_size = 100, in_channels = 6, encoding_size = 10, num_layers = 1, dropout = 0.0, bidirectional = True) - Bases: :py:obj:`torch.nn.Module` + Gate Recurrent Unit (GRU) Encoder. This class is a wrapper for the GRU layer (torch.nn.GRU) followed by a @@ -61,7 +59,7 @@ Classes bidirectional : bool, optional If ``True``, becomes a bidirectional GRU, by default True - .. py:method:: forward(x) + .. py:method:: forward(x) diff --git a/_sources/autoapi/ssl_tools/models/layers/index.rst.txt b/_sources/autoapi/ssl_tools/models/layers/index.rst.txt index 129abcd..141e027 100644 --- a/_sources/autoapi/ssl_tools/models/layers/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/layers/index.rst.txt @@ -1,15 +1,15 @@ -:py:mod:`ssl_tools.models.layers` -================================= +ssl_tools.models.layers +======================= .. py:module:: ssl_tools.models.layers Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - gru/index.rst + /autoapi/ssl_tools/models/layers/gru/index diff --git a/_sources/autoapi/ssl_tools/models/nets/cnn_ha_etal/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/cnn_ha_etal/index.rst.txt new file mode 100644 index 0000000..a882318 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/cnn_ha_etal/index.rst.txt @@ -0,0 +1,78 @@ +ssl_tools.models.nets.cnn_ha_etal +================================= + +.. py:module:: ssl_tools.models.nets.cnn_ha_etal + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.cnn_ha_etal.CNN_HaEtAl_1D + ssl_tools.models.nets.cnn_ha_etal.CNN_HaEtAl_2D + + +Module Contents +--------------- + +.. py:class:: CNN_HaEtAl_1D(input_shape = (1, 6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_shape) + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: CNN_HaEtAl_2D(pad_at = (3, ), input_shape = (1, 6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_shape) + + + .. py:method:: _create_fc(input_features, num_classes) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/cnn_pf/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/cnn_pf/index.rst.txt new file mode 100644 index 0000000..c2214d9 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/cnn_pf/index.rst.txt @@ -0,0 +1,59 @@ +ssl_tools.models.nets.cnn_pf +============================ + +.. py:module:: ssl_tools.models.nets.cnn_pf + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.cnn_pf.CNN_PFF_2D + ssl_tools.models.nets.cnn_pf.CNN_PF_2D + ssl_tools.models.nets.cnn_pf.CNN_PF_Backbone + + +Module Contents +--------------- + +.. py:class:: CNN_PFF_2D(*args, **kwargs) + + Bases: :py:obj:`CNN_PF_2D` + + +.. py:class:: CNN_PF_2D(pad_at, input_shape = (1, 6, 60), out_channels = 16, num_classes = 6, learning_rate = 0.001, include_middle = False) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: CNN_PF_Backbone(pad_at, input_shape, out_channels = 16, include_middle = False) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/convae/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/convae/index.rst.txt new file mode 100644 index 0000000..a924245 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/convae/index.rst.txt @@ -0,0 +1,58 @@ +ssl_tools.models.nets.convae +============================ + +.. py:module:: ssl_tools.models.nets.convae + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.convae.ContrastiveConvolutionalAutoEncoder + ssl_tools.models.nets.convae.ContrastiveConvolutionalAutoEncoder2D + ssl_tools.models.nets.convae.ConvolutionalAutoEncoder + ssl_tools.models.nets.convae.ConvolutionalAutoEncoder2D + ssl_tools.models.nets.convae._ConvolutionalAutoEncoder + ssl_tools.models.nets.convae._ConvolutionalAutoEncoder2D + + +Module Contents +--------------- + +.. py:class:: ContrastiveConvolutionalAutoEncoder(input_shape = (1, 16), learning_rate = 0.001, margin = 1.0) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleReconstructionNet` + + +.. py:class:: ContrastiveConvolutionalAutoEncoder2D(input_shape = (4, 4, 1), learning_rate = 0.001, margin = 1.0) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleReconstructionNet` + + +.. py:class:: ConvolutionalAutoEncoder(input_shape = (1, 16), learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleReconstructionNet` + + +.. py:class:: ConvolutionalAutoEncoder2D(input_shape = (1, 4, 4), learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleReconstructionNet` + + +.. py:class:: _ConvolutionalAutoEncoder(input_shape = (1, 16)) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: _ConvolutionalAutoEncoder2D(input_shape = (1, 4, 4)) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/convnet/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/convnet/index.rst.txt index 4a3ca4c..f8f165e 100644 --- a/_sources/autoapi/ssl_tools/models/nets/convnet/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/nets/convnet/index.rst.txt @@ -1,14 +1,11 @@ -:py:mod:`ssl_tools.models.nets.convnet` -======================================= +ssl_tools.models.nets.convnet +============================= .. py:module:: ssl_tools.models.nets.convnet -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: @@ -16,113 +13,96 @@ Classes ssl_tools.models.nets.convnet.Simple2DConvNetwork +Module Contents +--------------- +.. py:class:: Simple1DConvNetwork(input_shape = (6, 60), num_classes = 6, learning_rate = 0.001) -.. py:class:: Simple1DConvNetwork(input_channels = 6, num_classes = 6, time_steps = 60, learning_rate = 0.001) - - - Bases: :py:obj:`lightning.LightningModule` - - Model for human-activity-recognition. + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` - .. py:method:: _calculate_fc_input_features(input_channels) + + Create a simple 1D Convolutional Network with 3 layers and 2 fully + connected layers. - .. py:method:: _common_step(batch, batch_idx, prefix) + Parameters + ---------- + input_shape : Tuple[int, int], optional + A 2-tuple containing the number of input channels and the number of + features, by default (6, 60). + num_classes : int, optional + Number of output classes, by default 6 + learning_rate : float, optional + Learning rate for Adam optimizer, by default 1e-3 - .. py:method:: _compute_metrics(y_hat, y, stage) + .. py:method:: _calculate_fc_input_features(backbone, input_shape) - Compute the metrics. + Run a single forward pass with a random input to get the number of + features after the convolutional layers. Parameters ---------- - y_hat : torch.Tensor - The predictions of the model - y : _type_ - The ground truth labels - stage : str - The stage of the training loop (train, val or test) + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int] + The input shape of the network. Returns ------- - Dict[str, float] - A dictionary containing the metrics. The keys are the names of the - metrics, and the values are the values of the metrics. - - - .. py:method:: configure_optimizers() - - - .. py:method:: forward(x) - - - .. py:method:: loss_function(X, y) - + int + The number of features after the convolutional layers. - .. py:method:: predict_step(batch, batch_idx, dataloader_idx=None) - .. py:method:: test_step(batch, batch_idx) + .. py:method:: _create_backbone(input_channels) - .. py:method:: training_step(batch, batch_idx) + .. py:method:: _create_fc(input_features, num_classes) - .. py:method:: validation_step(batch, batch_idx) +.. py:class:: Simple2DConvNetwork(input_shape = (6, 1, 60), num_classes = 6, learning_rate = 0.001) + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` -.. py:class:: Simple2DConvNetwork(input_channels = 10, num_classes = 6, time_steps = 60, learning_rate = 0.001) + + Create a simple 2D Convolutional Network with 3 layers and 2 fully + connected layers. + Parameters + ---------- + input_shape : Tuple[int, int, int], optional + A 3-tuple containing the number of input channels, and the number of + the 2D input shape, by default (6, 1, 60). + num_classes : int, optional + Number of output classes, by default 6 + learning_rate : float, optional + Learning rate for Adam optimizer, by default 1e-3 - Bases: :py:obj:`lightning.LightningModule` - .. py:method:: _calculate_fc_input_features(input_channels) + .. py:method:: _calculate_fc_input_features(backbone, input_shape) - - .. py:method:: _common_step(batch, batch_idx, prefix) - - - .. py:method:: _compute_metrics(y_hat, y, stage) - - Compute the metrics. + Run a single forward pass with a random input to get the number of + features after the convolutional layers. Parameters ---------- - y_hat : torch.Tensor - The predictions of the model - y : _type_ - The ground truth labels - stage : str - The stage of the training loop (train, val or test) + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. Returns ------- - Dict[str, float] - A dictionary containing the metrics. The keys are the names of the - metrics, and the values are the values of the metrics. - - - .. py:method:: configure_optimizers() - - - .. py:method:: forward(x) - - - .. py:method:: loss_function(X, y) - - - .. py:method:: predict_step(batch, batch_idx, dataloader_idx=None) - - - .. py:method:: test_step(batch, batch_idx) + int + The number of features after the convolutional layers. - .. py:method:: training_step(batch, batch_idx) + .. py:method:: _create_backbone(input_channels) - .. py:method:: validation_step(batch, batch_idx) + .. py:method:: _create_fc(input_features, num_classes) diff --git a/_sources/autoapi/ssl_tools/models/nets/deep_conv_lstm/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/deep_conv_lstm/index.rst.txt new file mode 100644 index 0000000..9c561a1 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/deep_conv_lstm/index.rst.txt @@ -0,0 +1,59 @@ +ssl_tools.models.nets.deep_conv_lstm +==================================== + +.. py:module:: ssl_tools.models.nets.deep_conv_lstm + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.deep_conv_lstm.ConvLSTMCell + ssl_tools.models.nets.deep_conv_lstm.DeepConvLSTM + + +Module Contents +--------------- + +.. py:class:: ConvLSTMCell(input_shape) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: _calculate_conv_output_shape(backbone, input_shape) + + + .. py:method:: forward(x) + + +.. py:class:: DeepConvLSTM(input_shape = (1, 6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_shape) + + + .. py:method:: _create_fc(input_features, num_classes) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/deep_convnet/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/deep_convnet/index.rst.txt new file mode 100644 index 0000000..74fe642 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/deep_convnet/index.rst.txt @@ -0,0 +1,95 @@ +ssl_tools.models.nets.deep_convnet +================================== + +.. py:module:: ssl_tools.models.nets.deep_convnet + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.deep_convnet.DeepConvNet + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.models.nets.deep_convnet.main + + +Module Contents +--------------- + +.. py:class:: DeepConvNet(input_channels = 6, time_steps = 60, num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`lightning.LightningModule` + + + .. py:method:: _calculate_fc_input_features(input_channels, time_steps) + + Calculate the number of input features of the fully connected layer. + Basically, it performs a forward pass with a dummy input to get the + output shape after the convolutional layers. + + Parameters + ---------- + input_channels : int + The number of input channels. + + Returns + ------- + int + The number of input features of the fully connected layer. + + + + .. py:method:: _common_step(batch, batch_idx, prefix) + + + .. py:method:: _compute_metrics(y_hat, y, stage) + + Compute the metrics. + + Parameters + ---------- + y_hat : torch.Tensor + The predictions of the model + y : torch.Tensor + The ground truth labels + stage : str + The stage of the training loop (train, val or test) + + Returns + ------- + Dict[str, float] + A dictionary containing the metrics. The keys are the names of the + metrics, and the values are the values of the metrics. + + + + .. py:method:: configure_optimizers() + + + .. py:method:: forward(x) + + + .. py:method:: loss_function(X, y) + + + .. py:method:: predict_step(batch, batch_idx, dataloader_idx=None) + + + .. py:method:: test_step(batch, batch_idx) + + + .. py:method:: training_step(batch, batch_idx) + + + .. py:method:: validation_step(batch, batch_idx) + + +.. py:function:: main() + diff --git a/_sources/autoapi/ssl_tools/models/nets/imu_transformer/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/imu_transformer/index.rst.txt new file mode 100644 index 0000000..6898bff --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/imu_transformer/index.rst.txt @@ -0,0 +1,86 @@ +ssl_tools.models.nets.imu_transformer +===================================== + +.. py:module:: ssl_tools.models.nets.imu_transformer + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.imu_transformer.IMUCNN + ssl_tools.models.nets.imu_transformer.IMUTransformerEncoder + ssl_tools.models.nets.imu_transformer._IMUTransformerEncoder + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.models.nets.imu_transformer.test_imu_cnn + ssl_tools.models.nets.imu_transformer.test_imu_transformer + + +Module Contents +--------------- + +.. py:class:: IMUCNN(input_shape = (6, 60), hidden_dim = 64, num_classes = 6, dropout_factor = 0.1, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + + .. py:method:: _create_backbone(input_shape, hidden_dim, dropout_factor) + + + .. py:method:: _create_fc(input_features, hidden_dim, num_classes) + + +.. py:class:: IMUTransformerEncoder(input_shape = (6, 60), transformer_dim = 64, encode_position = True, nhead = 8, dim_feedforward = 128, transformer_dropout = 0.1, transformer_activation = 'gelu', num_encoder_layers = 6, num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _create_backbone(input_shape, transformer_dim, encode_position, nhead, dim_feedforward, transformer_dropout, transformer_activation, num_encoder_layers) + + + .. py:method:: _create_fc(transform_dim, num_classes) + + +.. py:class:: _IMUTransformerEncoder(input_shape = (6, 60), transformer_dim = 64, encode_position = True, nhead = 8, dim_feedforward = 128, transformer_dropout = 0.1, transformer_activation = 'gelu', num_encoder_layers = 6) + + Bases: :py:obj:`torch.nn.Module` + + + + input_shape: (tuple) shape of the input data + transformer_dim: (int) dimension of the transformer + encode_position: (bool) whether to encode position or not + nhead: (int) number of attention heads + dim_feedforward: (int) dimension of the feedforward network + transformer_dropout: (float) dropout rate for the transformer + transformer_activation: (str) activation function for the transformer + num_encoder_layers: (int) number of transformer encoder layers + num_classes: (int) number of output classes + + + .. py:method:: forward(x) + + Forward + + Parameters + ---------- + x : _type_ + A tensor of shape (B, C, S) with B = batch size, C = channels, S = sequence length + + + + +.. py:function:: test_imu_cnn() + +.. py:function:: test_imu_transformer() + diff --git a/_sources/autoapi/ssl_tools/models/nets/inception_time/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/inception_time/index.rst.txt new file mode 100644 index 0000000..96f3284 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/inception_time/index.rst.txt @@ -0,0 +1,77 @@ +ssl_tools.models.nets.inception_time +==================================== + +.. py:module:: ssl_tools.models.nets.inception_time + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.inception_time.InceptionModule + ssl_tools.models.nets.inception_time.InceptionTime + ssl_tools.models.nets.inception_time.ShortcutLayer + ssl_tools.models.nets.inception_time._InceptionTime + + +Module Contents +--------------- + +.. py:class:: InceptionModule(input_shape = (6, 60), stride = 1, kernel_size = 41, nb_filters = 32, use_bottleneck = True, bottleneck_size = 32) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: build_model() + + + .. py:method:: forward(input_tensor) + + +.. py:class:: InceptionTime(input_shape = (6, 60), nb_filters=32, use_residual = True, use_bottleneck = True, depth = 6, kernel_size = 41, num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: ShortcutLayer(input_tensor_shape, out_tensor_shape) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(input_tensor, output_tensor) + + +.. py:class:: _InceptionTime(input_shape = (6, 60), nb_filters=32, use_residual = True, use_bottleneck = True, depth = 6, kernel_size = 41) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: build_model() + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/index.rst.txt index c080469..eb2d9b1 100644 --- a/_sources/autoapi/ssl_tools/models/nets/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/nets/index.rst.txt @@ -1,16 +1,137 @@ -:py:mod:`ssl_tools.models.nets` -=============================== +ssl_tools.models.nets +===================== .. py:module:: ssl_tools.models.nets Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - convnet/index.rst - wisenet/index.rst + /autoapi/ssl_tools/models/nets/cnn_ha_etal/index + /autoapi/ssl_tools/models/nets/cnn_pf/index + /autoapi/ssl_tools/models/nets/convae/index + /autoapi/ssl_tools/models/nets/convnet/index + /autoapi/ssl_tools/models/nets/deep_conv_lstm/index + /autoapi/ssl_tools/models/nets/deep_convnet/index + /autoapi/ssl_tools/models/nets/imu_transformer/index + /autoapi/ssl_tools/models/nets/inception_time/index + /autoapi/ssl_tools/models/nets/lstm_ae/index + /autoapi/ssl_tools/models/nets/multi_channel_cnn/index + /autoapi/ssl_tools/models/nets/resnet1d/index + /autoapi/ssl_tools/models/nets/resnet_1d/index + /autoapi/ssl_tools/models/nets/simple/index + /autoapi/ssl_tools/models/nets/transformer/index + /autoapi/ssl_tools/models/nets/wisenet/index + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.MLPClassifier + ssl_tools.models.nets.Simple1DConvNetwork + ssl_tools.models.nets.Simple2DConvNetwork + + +Package Contents +---------------- + +.. py:class:: MLPClassifier(input_size, hidden_size, num_hidden_layers, output_size, learning_rate = 0.001, flatten = True, loss_fn = None, train_metrics = None, val_metrics = None, test_metrics = None) + + Bases: :py:obj:`SimpleClassificationNet` + + +.. py:class:: Simple1DConvNetwork(input_shape = (6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + + Create a simple 1D Convolutional Network with 3 layers and 2 fully + connected layers. + + Parameters + ---------- + input_shape : Tuple[int, int], optional + A 2-tuple containing the number of input channels and the number of + features, by default (6, 60). + num_classes : int, optional + Number of output classes, by default 6 + learning_rate : float, optional + Learning rate for Adam optimizer, by default 1e-3 + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_channels) + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: Simple2DConvNetwork(input_shape = (6, 1, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + + Create a simple 2D Convolutional Network with 3 layers and 2 fully + connected layers. + + Parameters + ---------- + input_shape : Tuple[int, int, int], optional + A 3-tuple containing the number of input channels, and the number of + the 2D input shape, by default (6, 1, 60). + num_classes : int, optional + Number of output classes, by default 6 + learning_rate : float, optional + Learning rate for Adam optimizer, by default 1e-3 + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_channels) + + + .. py:method:: _create_fc(input_features, num_classes) diff --git a/_sources/autoapi/ssl_tools/models/nets/lstm_ae/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/lstm_ae/index.rst.txt new file mode 100644 index 0000000..812b5af --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/lstm_ae/index.rst.txt @@ -0,0 +1,42 @@ +ssl_tools.models.nets.lstm_ae +============================= + +.. py:module:: ssl_tools.models.nets.lstm_ae + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.lstm_ae.LSTMAutoencoder + ssl_tools.models.nets.lstm_ae._LSTMAutoEncoder + + +Module Contents +--------------- + +.. py:class:: LSTMAutoencoder(input_shape = (16, 1), learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleReconstructionNet` + + + + Create a LSTM Autoencoder model + + Parameters + ---------- + input_shape : Tuple[int, int], optional + The shape of the input. The first element is the sequence length and the second is the number of features, by default (16, 1) + learning_rate : float, optional + Learning rate for Adam optimizer, by default 1e-3 + + +.. py:class:: _LSTMAutoEncoder(input_shape = (16, 1)) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/multi_channel_cnn/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/multi_channel_cnn/index.rst.txt new file mode 100644 index 0000000..65bfdce --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/multi_channel_cnn/index.rst.txt @@ -0,0 +1,71 @@ +ssl_tools.models.nets.multi_channel_cnn +======================================= + +.. py:module:: ssl_tools.models.nets.multi_channel_cnn + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.multi_channel_cnn.MultiChannelCNN_HAR + ssl_tools.models.nets.multi_channel_cnn._MultiChannelCNN_HAR + + +Module Contents +--------------- + +.. py:class:: MultiChannelCNN_HAR(input_shape = (1, 6, 60), num_classes = 6, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + + Create a simple 1D Convolutional Network with 3 layers and 2 fully + connected layers. + + Parameters + ---------- + input_shape : Tuple[int, int], optional + A 2-tuple containing the number of input channels and the number of + features, by default (6, 60). + num_classes : int, optional + Number of output classes, by default 6 + learning_rate : float, optional + Learning rate for Adam optimizer, by default 1e-3 + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_backbone(input_channels) + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: _MultiChannelCNN_HAR(input_channels = 1, concatenate = True) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/resnet1d/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/resnet1d/index.rst.txt new file mode 100644 index 0000000..3209947 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/resnet1d/index.rst.txt @@ -0,0 +1,124 @@ +ssl_tools.models.nets.resnet1d +============================== + +.. py:module:: ssl_tools.models.nets.resnet1d + +.. autoapi-nested-parse:: + + resnet for 1-d signal data, pytorch version + + Shenda Hong, Oct 2019 + + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.resnet1d.BasicBlock + ssl_tools.models.nets.resnet1d.MyConv1dPadSame + ssl_tools.models.nets.resnet1d.MyMaxPool1dPadSame + ssl_tools.models.nets.resnet1d.ResNet1D + ssl_tools.models.nets.resnet1d._ResNet1D + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.models.nets.resnet1d.main + + +Module Contents +--------------- + +.. py:class:: BasicBlock(in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False) + + Bases: :py:obj:`torch.nn.Module` + + + ResNet Basic Block + + + .. py:method:: forward(x) + + +.. py:class:: MyConv1dPadSame(in_channels, out_channels, kernel_size, stride, groups=1) + + Bases: :py:obj:`torch.nn.Module` + + + extend nn.Conv1d to support SAME padding + + + .. py:method:: forward(x) + + +.. py:class:: MyMaxPool1dPadSame(kernel_size) + + Bases: :py:obj:`torch.nn.Module` + + + extend nn.MaxPool1d to support SAME padding + + + .. py:method:: forward(x) + + +.. py:class:: ResNet1D(input_shape, base_filters=128, kernel_size=16, stride=2, groups=32, n_block=48, num_classes=6, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + + .. py:method:: _create_fc(input_features, num_classes) + + +.. py:class:: _ResNet1D(in_channels, base_filters=64, kernel_size=16, stride=2, groups=32, n_block=48, n_classes=6, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, verbose=False) + + Bases: :py:obj:`torch.nn.Module` + + + Input: + X: (n_samples, n_channel, n_length) + Y: (n_samples) + + Output: + out: (n_samples) + + Pararmetes: + in_channels: dim of input, the same as n_channel + base_filters: number of filters in the first several Conv layer, it will double at every 4 layers + kernel_size: width of kernel + stride: stride of kernel moving + groups: set larget to 1 as ResNeXt + n_block: number of blocks + n_classes: number of classes + + + + .. py:method:: forward(x) + + +.. py:function:: main() + diff --git a/_sources/autoapi/ssl_tools/models/nets/resnet_1d/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/resnet_1d/index.rst.txt new file mode 100644 index 0000000..d01d9a5 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/resnet_1d/index.rst.txt @@ -0,0 +1,101 @@ +ssl_tools.models.nets.resnet_1d +=============================== + +.. py:module:: ssl_tools.models.nets.resnet_1d + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.resnet_1d.ConvolutionalBlock + ssl_tools.models.nets.resnet_1d.ResNet1DBase + ssl_tools.models.nets.resnet_1d.ResNet1D_8 + ssl_tools.models.nets.resnet_1d.ResNetBlock + ssl_tools.models.nets.resnet_1d.ResNetSE1D_5 + ssl_tools.models.nets.resnet_1d.ResNetSE1D_8 + ssl_tools.models.nets.resnet_1d.ResNetSEBlock + ssl_tools.models.nets.resnet_1d.SqueezeAndExcitation1D + ssl_tools.models.nets.resnet_1d._ResNet1D + + +Module Contents +--------------- + +.. py:class:: ConvolutionalBlock(in_channels, activation_cls = None) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: ResNet1DBase(resnet_block_cls = ResNetBlock, activation_cls = torch.nn.ReLU, input_shape = (6, 60), num_classes = 6, num_residual_blocks = 5, reduction_ratio=2, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: _calculate_fc_input_features(backbone, input_shape) + + Run a single forward pass with a random input to get the number of + features after the convolutional layers. + + Parameters + ---------- + backbone : torch.nn.Module + The backbone of the network + input_shape : Tuple[int, int, int] + The input shape of the network. + + Returns + ------- + int + The number of features after the convolutional layers. + + + +.. py:class:: ResNet1D_8(*args, **kwargs) + + Bases: :py:obj:`ResNet1DBase` + + +.. py:class:: ResNetBlock(in_channels = 64, activation_cls = torch.nn.ReLU) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: ResNetSE1D_5(*args, **kwargs) + + Bases: :py:obj:`ResNet1DBase` + + +.. py:class:: ResNetSE1D_8(*args, **kwargs) + + Bases: :py:obj:`ResNet1DBase` + + +.. py:class:: ResNetSEBlock(*args, **kwargs) + + Bases: :py:obj:`ResNetBlock` + + +.. py:class:: SqueezeAndExcitation1D(in_channels, reduction_ratio = 2) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(input_tensor) + + +.. py:class:: _ResNet1D(input_shape, residual_block_cls=ResNetBlock, activation_cls = torch.nn.ReLU, num_residual_blocks = 5, reduction_ratio=2) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/simple/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/simple/index.rst.txt new file mode 100644 index 0000000..ace0d34 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/simple/index.rst.txt @@ -0,0 +1,85 @@ +ssl_tools.models.nets.simple +============================ + +.. py:module:: ssl_tools.models.nets.simple + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.simple.MLPClassifier + ssl_tools.models.nets.simple.SimpleClassificationNet + ssl_tools.models.nets.simple.SimpleReconstructionNet + + +Module Contents +--------------- + +.. py:class:: MLPClassifier(input_size, hidden_size, num_hidden_layers, output_size, learning_rate = 0.001, flatten = True, loss_fn = None, train_metrics = None, val_metrics = None, test_metrics = None) + + Bases: :py:obj:`SimpleClassificationNet` + + +.. py:class:: SimpleClassificationNet(backbone, fc, learning_rate = 0.001, flatten = True, loss_fn = None, train_metrics = None, val_metrics = None, test_metrics = None) + + Bases: :py:obj:`lightning.LightningModule` + + + .. py:method:: compute_metrics(y_hat, y, step_name) + + + .. py:method:: configure_optimizers() + + + .. py:method:: forward(x) + + + .. py:method:: loss_func(y_hat, y) + + + .. py:method:: predict_step(batch, batch_idx, dataloader_idx=None) + + + .. py:method:: single_step(batch, batch_idx, step_name) + + + .. py:method:: test_step(batch, batch_idx) + + + .. py:method:: training_step(batch, batch_idx) + + + .. py:method:: validation_step(batch, batch_idx) + + +.. py:class:: SimpleReconstructionNet(backbone, learning_rate = 0.001, loss_fn = None) + + Bases: :py:obj:`lightning.LightningModule` + + + .. py:method:: configure_optimizers() + + + .. py:method:: forward(x) + + + .. py:method:: loss_func(y_hat, y) + + + .. py:method:: predict_step(batch, batch_idx, dataloader_idx=None) + + + .. py:method:: single_step(batch, batch_idx, step_name) + + + .. py:method:: test_step(batch, batch_idx) + + + .. py:method:: training_step(batch, batch_idx) + + + .. py:method:: validation_step(batch, batch_idx) + + diff --git a/_sources/autoapi/ssl_tools/models/nets/transformer/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/transformer/index.rst.txt new file mode 100644 index 0000000..6f2c672 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/nets/transformer/index.rst.txt @@ -0,0 +1,25 @@ +ssl_tools.models.nets.transformer +================================= + +.. py:module:: ssl_tools.models.nets.transformer + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.nets.transformer.SimpleTransformer + + +Module Contents +--------------- + +.. py:class:: SimpleTransformer(in_channels = 6, dim_feedforward=60, num_classes = 6, heads = 2, num_layers = 2, learning_rate = 0.001) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + + .. py:method:: configure_optimizers() + + diff --git a/_sources/autoapi/ssl_tools/models/nets/wisenet/index.rst.txt b/_sources/autoapi/ssl_tools/models/nets/wisenet/index.rst.txt index 5d600cc..c1b8020 100644 --- a/_sources/autoapi/ssl_tools/models/nets/wisenet/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/nets/wisenet/index.rst.txt @@ -1,27 +1,25 @@ -:py:mod:`ssl_tools.models.nets.wisenet` -======================================= +ssl_tools.models.nets.wisenet +============================= .. py:module:: ssl_tools.models.nets.wisenet -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.models.nets.wisenet.WiseNet - +Module Contents +--------------- .. py:class:: WiseNet(learning_rate = 0.0001) - Bases: :py:obj:`pytorch_lightning.LightningModule` + .. py:method:: _common_step(batch, batch_idx) @@ -43,4 +41,3 @@ Classes .. py:method:: validation_step(batch, batch_idx) - diff --git a/_sources/autoapi/ssl_tools/models/ssl/classifier/index.rst.txt b/_sources/autoapi/ssl_tools/models/ssl/classifier/index.rst.txt index a826550..dd65863 100644 --- a/_sources/autoapi/ssl_tools/models/ssl/classifier/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/ssl/classifier/index.rst.txt @@ -1,27 +1,25 @@ -:py:mod:`ssl_tools.models.ssl.classifier` -========================================= +ssl_tools.models.ssl.classifier +=============================== .. py:module:: ssl_tools.models.ssl.classifier -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.models.ssl.classifier.SSLDiscriminator - +Module Contents +--------------- .. py:class:: SSLDiscriminator(backbone, head, loss_fn, learning_rate = 0.001, update_backbone = True, metrics = None, optimizer_cls = None) - Bases: :py:obj:`lightning.LightningModule` + A generic SSL Discriminator model. It takes a backbone and a head and trains them jointly (or not, depending on the ``update_backbone`` @@ -52,6 +50,7 @@ Classes The metrics to use during training. The keys are the names of the metrics, and the values are the metrics themselves. + .. py:method:: _compute_metrics(y_hat, y, stage) Compute the metrics. @@ -72,6 +71,7 @@ Classes metrics, and the values are the values of the metrics. + .. py:method:: _freeze(model) Freezes the model, i.e. sets the requires_grad parameter of all the @@ -83,6 +83,7 @@ Classes The model to freeze + .. py:method:: _loss_func(y_hat, y) Calculates the loss function. @@ -95,6 +96,7 @@ Classes The ground truth labels + .. py:method:: configure_optimizers() Configures the optimizer. If ``update_backbone`` is True, it will @@ -102,6 +104,7 @@ Classes only update the parameters of the head. + .. py:method:: forward(x) Performs a forward pass through the model. It first passes the input @@ -120,6 +123,7 @@ Classes The predictions of the model. + .. py:method:: predict_step(batch, batch_idx) Performs a prediction step. It only performs a forward pass through @@ -138,6 +142,7 @@ Classes The predictions of the model + .. py:method:: test_step(batch, batch_idx) Performs a test step. It first performs a forward pass through @@ -157,6 +162,7 @@ Classes The loss of the model + .. py:method:: training_step(batch, batch_idx) Performs a training step. It first performs a forward pass through @@ -176,6 +182,7 @@ Classes The loss of the model + .. py:method:: validation_step(batch, batch_idx) Performs a validation step. It first performs a forward pass through diff --git a/_sources/autoapi/ssl_tools/models/ssl/cpc/index.rst.txt b/_sources/autoapi/ssl_tools/models/ssl/cpc/index.rst.txt index 32d5043..e54e599 100644 --- a/_sources/autoapi/ssl_tools/models/ssl/cpc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/ssl/cpc/index.rst.txt @@ -1,35 +1,33 @@ -:py:mod:`ssl_tools.models.ssl.cpc` -================================== +ssl_tools.models.ssl.cpc +======================== .. py:module:: ssl_tools.models.ssl.cpc -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.models.ssl.cpc.CPC - Functions -~~~~~~~~~ +--------- .. autoapisummary:: ssl_tools.models.ssl.cpc.build_cpc +Module Contents +--------------- .. py:class:: CPC(encoder, density_estimator, auto_regressor, lr = 0.001, weight_decay = 0.0, window_size = 4, n_size = 5) - Bases: :py:obj:`lightning.LightningModule`, :py:obj:`ssl_tools.utils.configurable.Configurable` + Implements the Contrastive Predictive Coding (CPC) model, as described in https://arxiv.org/abs/1807.03748. The implementation was adapted from https://github.com/sanatonek/TNC_representation_learning @@ -65,6 +63,7 @@ Functions Number of negative samples to be used in the contrastive loss (steps to predict) + .. py:method:: _shared_step(batch, batch_idx, prefix) @@ -86,6 +85,7 @@ Functions A tensor of size (B, encoder_output_size), with the samples encoded + .. py:method:: configure_optimizers() @@ -107,6 +107,7 @@ Functions A tensor of size (B, encoder_output_size), with the samples encoded. + .. py:method:: get_config() @@ -122,7 +123,6 @@ Functions .. py:method:: validation_step(batch, batch_idx) - .. py:function:: build_cpc(encoding_size = 150, in_channels = 6, gru_hidden_size = 100, gru_num_layers = 1, gru_bidirectional = True, dropout = 0.0, learning_rate = 0.001, weight_decay = 0.0, window_size = 4, n_size = 5) Builds a default CPC model. This function aid in the creation of a CPC diff --git a/_sources/autoapi/ssl_tools/models/ssl/index.rst.txt b/_sources/autoapi/ssl_tools/models/ssl/index.rst.txt index 3074247..05aeb07 100644 --- a/_sources/autoapi/ssl_tools/models/ssl/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/ssl/index.rst.txt @@ -1,18 +1,18 @@ -:py:mod:`ssl_tools.models.ssl` -============================== +ssl_tools.models.ssl +==================== .. py:module:: ssl_tools.models.ssl Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - classifier/index.rst - cpc/index.rst - tfc/index.rst - tnc/index.rst + /autoapi/ssl_tools/models/ssl/classifier/index + /autoapi/ssl_tools/models/ssl/cpc/index + /autoapi/ssl_tools/models/ssl/tfc/index + /autoapi/ssl_tools/models/ssl/tnc/index diff --git a/_sources/autoapi/ssl_tools/models/ssl/modules/heads/index.rst.txt b/_sources/autoapi/ssl_tools/models/ssl/modules/heads/index.rst.txt index 41c77fa..71ec438 100644 --- a/_sources/autoapi/ssl_tools/models/ssl/modules/heads/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/ssl/modules/heads/index.rst.txt @@ -1,14 +1,11 @@ -:py:mod:`ssl_tools.models.ssl.modules.heads` -============================================ +ssl_tools.models.ssl.modules.heads +================================== .. py:module:: ssl_tools.models.ssl.modules.heads -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: @@ -18,29 +15,26 @@ Classes ssl_tools.models.ssl.modules.heads.TNCPredictionHead - +Module Contents +--------------- .. py:class:: CPCPredictionHead(input_dim = 10, hidden_dim1 = 64, hidden_dim2 = 64, output_dim = 6, dropout_prob = 0) - Bases: :py:obj:`TNCPredictionHead` .. py:class:: TFCPredictionHead(input_dim = 2 * 128, hidden_dim = 64, output_dim = 2) - Bases: :py:obj:`lightly.models.modules.heads.ProjectionHead` .. py:class:: TFCProjectionHead(input_dim, hidden_dim = 256, output_dim = 128) - Bases: :py:obj:`lightly.models.modules.heads.ProjectionHead` .. py:class:: TNCPredictionHead(input_dim = 10, hidden_dim1 = 64, hidden_dim2 = 64, output_dim = 6, dropout_prob = 0) - Bases: :py:obj:`lightly.models.modules.heads.ProjectionHead` diff --git a/_sources/autoapi/ssl_tools/models/ssl/tfc/index.rst.txt b/_sources/autoapi/ssl_tools/models/ssl/tfc/index.rst.txt index 8cf7bd6..0f5217f 100644 --- a/_sources/autoapi/ssl_tools/models/ssl/tfc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/ssl/tfc/index.rst.txt @@ -1,14 +1,11 @@ -:py:mod:`ssl_tools.models.ssl.tfc` -================================== +ssl_tools.models.ssl.tfc +======================== .. py:module:: ssl_tools.models.ssl.tfc -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: @@ -16,21 +13,22 @@ Classes ssl_tools.models.ssl.tfc.TFCHead - Functions -~~~~~~~~~ +--------- .. autoapisummary:: ssl_tools.models.ssl.tfc.build_tfc_transformer +Module Contents +--------------- .. py:class:: TFC(time_encoder, frequency_encoder, time_projector, frequency_projector, nxtent_criterion, learning_rate = 0.001, loss_lambda = 0.2, permute_input = None) - Bases: :py:obj:`lightning.LightningModule`, :py:obj:`ssl_tools.utils.configurable.Configurable` + Configurable interface for models and other objects that can be configured with a dictionary. For now, this interface is used to save the hyperparameters of the models. @@ -60,6 +58,7 @@ Functions loss_lambda : float, optional The consistency threshold, by default 0.2 + .. py:method:: _generate_representations(x_in_t, x_in_f) Returns the intermediate representations of the model. @@ -82,6 +81,7 @@ Functions (h_time, z_time, h_freq, z_freq). + .. py:method:: _shared_step(data, aug1, data_f, aug1_f, stage) Compute the representations and the loss. @@ -107,6 +107,7 @@ Functions z_time, h_freq, z_freq). The second element is the loss. + .. py:method:: configure_optimizers() @@ -127,6 +128,7 @@ Functions The final representation of the model (z_t, z_f concatenated) + .. py:method:: get_config() @@ -139,12 +141,11 @@ Functions .. py:method:: validation_step(batch, batch_idx) - .. py:class:: TFCHead(input_size = 2 * 128, num_classes = 2) - Bases: :py:obj:`torch.nn.Module` + Simple discriminator network, used as the head of the TFC model. @@ -155,8 +156,8 @@ Functions n_classes : int, optional Number of output classes (output_size), by default 2 - .. py:method:: forward(x) + .. py:method:: forward(x) .. py:function:: build_tfc_transformer(encoding_size = 128, in_channels = 1, length_alignment = 360, use_cosine_similarity = True, learning_rate = 0.001, temperature = 0.5) diff --git a/_sources/autoapi/ssl_tools/models/ssl/tnc/index.rst.txt b/_sources/autoapi/ssl_tools/models/ssl/tnc/index.rst.txt index 1645b65..dd20f7f 100644 --- a/_sources/autoapi/ssl_tools/models/ssl/tnc/index.rst.txt +++ b/_sources/autoapi/ssl_tools/models/ssl/tnc/index.rst.txt @@ -1,14 +1,11 @@ -:py:mod:`ssl_tools.models.ssl.tnc` -================================== +ssl_tools.models.ssl.tnc +======================== .. py:module:: ssl_tools.models.ssl.tnc -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: @@ -16,21 +13,22 @@ Classes ssl_tools.models.ssl.tnc.TNCDiscriminator - Functions -~~~~~~~~~ +--------- .. autoapisummary:: ssl_tools.models.ssl.tnc.build_tnc +Module Contents +--------------- .. py:class:: TNC(discriminator, encoder, mc_sample_size = 20, w = 0.05, learning_rate=0.001) - Bases: :py:obj:`lightning.LightningModule`, :py:obj:`ssl_tools.utils.configurable.Configurable` + Configurable interface for models and other objects that can be configured with a dictionary. For now, this interface is used to save the hyperparameters of the models. @@ -57,6 +55,7 @@ Functions learning_rate : _type_, optional The learning rate of the optimizer, by default 1e-3 + .. py:method:: _shared_step(x_t, x_p, x_n, stage) Runs TNC and returns the representation and the loss. @@ -85,6 +84,7 @@ Functions respectively. + .. py:method:: configure_optimizers() @@ -106,6 +106,7 @@ Functions The predicted labels. + .. py:method:: test_step(batch, batch_idx) @@ -115,12 +116,11 @@ Functions .. py:method:: validation_step(batch, batch_idx) - .. py:class:: TNCDiscriminator(input_size = 10, n_classes = 1) - Bases: :py:obj:`torch.nn.Module` + Simple discriminator network. As usued by `Tonekaboni et al.` at "Unsupervised Representation Learning for Time Series with Temporal @@ -138,6 +138,7 @@ Functions n_classes : int, optional Number of output classes (output_size), by default 1 + .. py:method:: forward(x) Predict the probability of the two inputs belonging to the same diff --git a/_sources/autoapi/ssl_tools/models/utils/index.rst.txt b/_sources/autoapi/ssl_tools/models/utils/index.rst.txt new file mode 100644 index 0000000..15d4f10 --- /dev/null +++ b/_sources/autoapi/ssl_tools/models/utils/index.rst.txt @@ -0,0 +1,58 @@ +ssl_tools.models.utils +====================== + +.. py:module:: ssl_tools.models.utils + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.models.utils.RandomDataModule + ssl_tools.models.utils.RandomDataset + ssl_tools.models.utils.ShapePrinter + ssl_tools.models.utils.ZeroPadder2D + + +Module Contents +--------------- + +.. py:class:: RandomDataModule(num_samples, num_classes, input_shape, transforms = None, batch_size = 1) + + Bases: :py:obj:`lightning.LightningDataModule` + + + .. py:method:: train_dataloader() + + +.. py:class:: RandomDataset(num_samples = 64, num_classes = 6, input_shape = (6, 60), transforms = None) + + .. py:method:: __getitem__(idx) + + + .. py:method:: __len__() + + +.. py:class:: ShapePrinter(name='') + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: forward(x) + + +.. py:class:: ZeroPadder2D(pad_at, padding_size) + + Bases: :py:obj:`torch.nn.Module` + + + .. py:method:: __repr__() + + + .. py:method:: __str__() + + + .. py:method:: forward(x) + + diff --git a/_sources/autoapi/ssl_tools/pipelines/base/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/base/index.rst.txt new file mode 100644 index 0000000..fb14a6c --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/base/index.rst.txt @@ -0,0 +1,30 @@ +ssl_tools.pipelines.base +======================== + +.. py:module:: ssl_tools.pipelines.base + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.base.Pipeline + + +Module Contents +--------------- + +.. py:class:: Pipeline + + Bases: :py:obj:`lightning.pytorch.core.mixins.HyperparametersMixin` + + + .. py:method:: __call__() + + + .. py:method:: run() + :abstractmethod: + + + diff --git a/_sources/autoapi/ssl_tools/pipelines/cli/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/cli/index.rst.txt new file mode 100644 index 0000000..e358fb5 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/cli/index.rst.txt @@ -0,0 +1,22 @@ +ssl_tools.pipelines.cli +======================= + +.. py:module:: ssl_tools.pipelines.cli + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.pipelines.cli.auto_main + ssl_tools.pipelines.cli.get_parser + + +Module Contents +--------------- + +.. py:function:: auto_main(commands, print_args = False) + +.. py:function:: get_parser(commands) + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index.rst.txt new file mode 100644 index 0000000..4133858 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index.rst.txt @@ -0,0 +1,53 @@ +ssl_tools.pipelines.har_classification.conv1d_conss +=================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.conv1d_conss + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.conv1d_conss.experiment + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.conv1d_conss.PartialEmbeddingEvaluator + ssl_tools.pipelines.har_classification.conv1d_conss.PartialEmbeddingEvaluatorCallback + ssl_tools.pipelines.har_classification.conv1d_conss.Simple1DConvNetFineTune2 + + +Module Contents +--------------- + +.. py:class:: PartialEmbeddingEvaluator(experiment_name, model, data_module, trainer, **kwargs) + + Bases: :py:obj:`evaluator.EmbeddingEvaluator` + + + .. py:method:: run() + + +.. py:class:: PartialEmbeddingEvaluatorCallback(experiment_name, frequency = 1, **partal_embedding_evaluator_kwargs) + + Bases: :py:obj:`lightning.pytorch.callbacks.Callback` + + + .. py:method:: on_validation_end(trainer, pl_module) + + +.. py:class:: Simple1DConvNetFineTune2 + + Bases: :py:obj:`simple1Dconv_classifier.Simple1DConvNetFineTune` + + + .. py:method:: get_callbacks() + + +.. py:data:: experiment + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/cpc/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/cpc/index.rst.txt new file mode 100644 index 0000000..056a8c2 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/cpc/index.rst.txt @@ -0,0 +1,128 @@ +ssl_tools.pipelines.har_classification.cpc +========================================== + +.. py:module:: ssl_tools.pipelines.har_classification.cpc + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.cpc.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.cpc.CPCFineTune + ssl_tools.pipelines.har_classification.cpc.CPCPreTrain + + +Module Contents +--------------- + +.. py:class:: CPCFineTune(data, encoding_size = 128, num_classes = 6, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: CPCPreTrain(data, encoding_size = 128, in_channel = 6, window_size = 4, pad_length = False, num_classes = 6, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/evaluator/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/evaluator/index.rst.txt new file mode 100644 index 0000000..64fe1c0 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/evaluator/index.rst.txt @@ -0,0 +1,152 @@ +ssl_tools.pipelines.har_classification.evaluator +================================================ + +.. py:module:: ssl_tools.pipelines.har_classification.evaluator + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.evaluator.options + ssl_tools.pipelines.har_classification.evaluator.transforms_map + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.evaluator.CSVGenerator + ssl_tools.pipelines.har_classification.evaluator.EmbeddingEvaluator + ssl_tools.pipelines.har_classification.evaluator.EvaluateAll + ssl_tools.pipelines.har_classification.evaluator.HAREmbeddingEvaluator + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.evaluator.full_dataset_from_dataloader + ssl_tools.pipelines.har_classification.evaluator.generate_embeddings + ssl_tools.pipelines.har_classification.evaluator.get_full_data_split + ssl_tools.pipelines.har_classification.evaluator.get_split_dataloader + ssl_tools.pipelines.har_classification.evaluator.run_evaluator_wrapper + + +Module Contents +--------------- + +.. py:class:: CSVGenerator(experiments, log_dir = './mlruns', results_file = 'results.csv') + + Bases: :py:obj:`ssl_tools.pipelines.base.Pipeline` + + + .. py:property:: client + + + .. py:method:: run() + + +.. py:class:: EmbeddingEvaluator(experiment_name, registered_model_name, registered_model_tags = None, experiment_tags = None, n_classes = 7, run_name = None, accelerator = 'cpu', devices = 1, num_nodes = 1, num_workers = None, strategy = 'auto', batch_size = 1, limit_predict_batches = 1.0, log_dir = './mlruns', results_file = 'results.csv', confusion_matrix_file = 'confusion_matrix.csv', confusion_matrix_image_file = 'confusion_matrix.png', tsne_plot_file = 'tsne_embeddings.png', embedding_file = 'embeddings.csv', predictions_file = 'predictions.csv', add_epoch_info = False) + + Bases: :py:obj:`ssl_tools.pipelines.base.Pipeline` + + + .. py:method:: _compute_classification_metrics(y_hat_logits, y, n_classes) + + + .. py:method:: _confusion_matrix(y_hat, y, n_classes) + + + .. py:method:: _evaluate_embeddings(model, y_hat, y, n_classes, run_id, artifact_path) + + + .. py:method:: _plot_confusion_matrix(cm, classes) + + + .. py:method:: _plot_tnse_embeddings(embeddings, y, y_hat, n_components = 2) + + + .. py:property:: client + + + .. py:method:: evaluate_embeddings(model, data_module, trainer) + + + .. py:method:: evaluate_model_performance(model, data_module, trainer) + + + .. py:method:: get_callbacks() + + + .. py:method:: get_data_module() + :abstractmethod: + + + + .. py:method:: get_logger() + + + .. py:method:: get_trainer(logger, callbacks) + + + .. py:method:: load_model() + + + .. py:method:: predict(model, dataloader, trainer) + + + .. py:method:: run() + + + .. py:method:: run_task(model, data_module, trainer) + + +.. py:class:: EvaluateAll(root_dataset_dir, experiment_id, experiment_names, config_dir = None, log_dir = './mlruns', skip_existing = True, accelerator = 'cpu', devices = 1, num_nodes = 1, num_workers = None, strategy = 'auto', batch_size = 1, use_ray = False, ray_address = None) + + Bases: :py:obj:`ssl_tools.pipelines.base.Pipeline` + + + .. py:property:: client + + + .. py:method:: filter_runs(runs) + + + .. py:method:: get_runs(experiment_ids, search_string = '') + + + .. py:method:: locate_config(model_name) + + + .. py:method:: run() + + + .. py:method:: summarize(runs) + + +.. py:class:: HAREmbeddingEvaluator(data, transforms = 'identity', **kwargs) + + Bases: :py:obj:`EmbeddingEvaluator` + + + .. py:method:: get_data_module() + + +.. py:function:: full_dataset_from_dataloader(dataloader) + +.. py:function:: generate_embeddings(model, dataloader, trainer) + +.. py:function:: get_full_data_split(data_module, stage) + +.. py:function:: get_split_dataloader(stage, data_module) + +.. py:data:: options + +.. py:function:: run_evaluator_wrapper(evaluator) + +.. py:data:: transforms_map + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index.rst.txt new file mode 100644 index 0000000..09dbc5e --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index.rst.txt @@ -0,0 +1,144 @@ +ssl_tools.pipelines.har_classification.gru_encoder +================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.gru_encoder + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.gru_encoder.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.gru_encoder.GRUClassifier + ssl_tools.pipelines.har_classification.gru_encoder.GRUClassifierFineTune + ssl_tools.pipelines.har_classification.gru_encoder.GRUClassifierTrain + + +Module Contents +--------------- + +.. py:class:: GRUClassifier(hidden_size = 100, in_channels = 6, num_classes = 6, encoding_size = 100, num_layers = 1, dropout = 0.0, bidirectional = True) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + +.. py:class:: GRUClassifierFineTune(data, num_classes = 6, encoding_size = 128, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'GRU' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: GRUClassifierTrain(data, hidden_size = 100, in_channels = 6, num_classes = 6, encoding_size = 100, num_layers = 1, dropout = 0.0, bidirectional = True, num_workers = None, transforms = 'identity', **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'GRU' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/index.rst.txt new file mode 100644 index 0000000..7a82b3e --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/index.rst.txt @@ -0,0 +1,27 @@ +ssl_tools.pipelines.har_classification +====================================== + +.. py:module:: ssl_tools.pipelines.har_classification + + +Submodules +---------- + +.. toctree:: + :maxdepth: 1 + + /autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index + /autoapi/ssl_tools/pipelines/har_classification/cpc/index + /autoapi/ssl_tools/pipelines/har_classification/evaluator/index + /autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index + /autoapi/ssl_tools/pipelines/har_classification/mlp/index + /autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index + /autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index + /autoapi/ssl_tools/pipelines/har_classification/tfc/index + /autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index + /autoapi/ssl_tools/pipelines/har_classification/tnc/index + /autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index + /autoapi/ssl_tools/pipelines/har_classification/transformer/index + /autoapi/ssl_tools/pipelines/har_classification/utils/index + + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/mlp/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/mlp/index.rst.txt new file mode 100644 index 0000000..e2b3115 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/mlp/index.rst.txt @@ -0,0 +1,128 @@ +ssl_tools.pipelines.har_classification.mlp +========================================== + +.. py:module:: ssl_tools.pipelines.har_classification.mlp + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.mlp.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.mlp.MLPClassifierFineTune + ssl_tools.pipelines.har_classification.mlp.MLPClassifierTrain + + +Module Contents +--------------- + +.. py:class:: MLPClassifierFineTune(data, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: MLPClassifierTrain(data, input_size = 360, hidden_size = 64, num_hidden_layers = 1, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/scripts/evaluate_all/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/scripts/evaluate_all/index.rst.txt new file mode 100644 index 0000000..1c53c4d --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/scripts/evaluate_all/index.rst.txt @@ -0,0 +1,29 @@ +ssl_tools.pipelines.har_classification.scripts.evaluate_all +=========================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.scripts.evaluate_all + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.scripts.evaluate_all.options + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.scripts.evaluate_all.EvaluateAll + + +Module Contents +--------------- + +.. py:function:: EvaluateAll(Pipeline) + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index.rst.txt new file mode 100644 index 0000000..46f70b4 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index.rst.txt @@ -0,0 +1,138 @@ +ssl_tools.pipelines.har_classification.simple1Dconv_classifier +============================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.simple1Dconv_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.simple1Dconv_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.simple1Dconv_classifier.Simple1DConvNetFineTune + ssl_tools.pipelines.har_classification.simple1Dconv_classifier.Simple1DConvNetTrain + + +Module Contents +--------------- + +.. py:class:: Simple1DConvNetFineTune(data, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'Simple1DConvNet' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: Simple1DConvNetTrain(data, input_shape = (6, 60), num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'Simple1DConvNet' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index.rst.txt new file mode 100644 index 0000000..711374c --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index.rst.txt @@ -0,0 +1,138 @@ +ssl_tools.pipelines.har_classification.simple2Dconv_classifier +============================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.simple2Dconv_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.simple2Dconv_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.simple2Dconv_classifier.Simple2DConvNetFineTune + ssl_tools.pipelines.har_classification.simple2Dconv_classifier.Simple2DConvNetTrain + + +Module Contents +--------------- + +.. py:class:: Simple2DConvNetFineTune(data, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'Simple2DConvNet' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: Simple2DConvNetTrain(data, input_shape = (6, 1, 60), num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'Simple2DConvNet' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/tfc/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/tfc/index.rst.txt new file mode 100644 index 0000000..4dcd85e --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/tfc/index.rst.txt @@ -0,0 +1,128 @@ +ssl_tools.pipelines.har_classification.tfc +========================================== + +.. py:module:: ssl_tools.pipelines.har_classification.tfc + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tfc.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tfc.TFCFineTune + ssl_tools.pipelines.har_classification.tfc.TFCTrain + + +Module Contents +--------------- + +.. py:class:: TFCFineTune(data, num_classes = 6, num_workers = None, length_alignment = 60, encoding_size = 128, features_as_channels = True, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: TFCTrain(data, label = 'standard activity code', encoding_size = 128, in_channels = 6, length_alignment = 60, use_cosine_similarity = True, temperature = 0.5, features_as_channels = True, jitter_ratio = 2, num_classes = 6, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index.rst.txt new file mode 100644 index 0000000..bdcbf4c --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index.rst.txt @@ -0,0 +1,138 @@ +ssl_tools.pipelines.har_classification.tfc_head_classifier +========================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.tfc_head_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tfc_head_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tfc_head_classifier.TFCHeadClassifierFineTune + ssl_tools.pipelines.har_classification.tfc_head_classifier.TFCHeadClassifierTrain + + +Module Contents +--------------- + +.. py:class:: TFCHeadClassifierFineTune(data, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'TFCPredictionHead' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: TFCHeadClassifierTrain(data, input_size = 360, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'TFCPredictionHead' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/tnc/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/tnc/index.rst.txt new file mode 100644 index 0000000..df43956 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/tnc/index.rst.txt @@ -0,0 +1,128 @@ +ssl_tools.pipelines.har_classification.tnc +========================================== + +.. py:module:: ssl_tools.pipelines.har_classification.tnc + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tnc.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tnc.TNCFineTune + ssl_tools.pipelines.har_classification.tnc.TNCPreTrain + + +Module Contents +--------------- + +.. py:class:: TNCFineTune(data, num_classes = 6, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: TNCPreTrain(data, encoding_size = 10, in_channel = 6, window_size = 60, mc_sample_size = 20, w = 0.05, significance_level = 0.01, repeat = 5, pad_length = True, num_classes = 6, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index.rst.txt new file mode 100644 index 0000000..9a07b65 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index.rst.txt @@ -0,0 +1,138 @@ +ssl_tools.pipelines.har_classification.tnc_head_classifier +========================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.tnc_head_classifier + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tnc_head_classifier.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.tnc_head_classifier.TNCHeadClassifierFineTune + ssl_tools.pipelines.har_classification.tnc_head_classifier.TNCHeadClassifierTrain + + +Module Contents +--------------- + +.. py:class:: TNCHeadClassifierFineTune(data, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'TNCPredictionHead' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: TNCHeadClassifierTrain(data, input_size = 360, num_classes = 6, transforms = 'identity', num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'TNCPredictionHead' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/transformer/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/transformer/index.rst.txt new file mode 100644 index 0000000..56ee2ff --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/transformer/index.rst.txt @@ -0,0 +1,133 @@ +ssl_tools.pipelines.har_classification.transformer +================================================== + +.. py:module:: ssl_tools.pipelines.har_classification.transformer + + +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.transformer.options + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.transformer.SimpleTransformerFineTune + ssl_tools.pipelines.har_classification.transformer.SimpleTransformerTrain + + +Module Contents +--------------- + +.. py:class:: SimpleTransformerFineTune(data, num_classes = 6, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:class:: SimpleTransformerTrain(data, in_channels = 6, dim_feedforward=60, num_classes = 6, heads = 1, num_layers = 1, num_workers = None, **kwargs) + + Bases: :py:obj:`ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:attribute:: MODEL + :value: 'Transformer' + + + + .. py:method:: get_data_module() + + + .. py:method:: get_model() + + +.. py:data:: options + diff --git a/_sources/autoapi/ssl_tools/pipelines/har_classification/utils/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/har_classification/utils/index.rst.txt new file mode 100644 index 0000000..e808958 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/har_classification/utils/index.rst.txt @@ -0,0 +1,99 @@ +ssl_tools.pipelines.har_classification.utils +============================================ + +.. py:module:: ssl_tools.pipelines.har_classification.utils + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.har_classification.utils.DimensionAdder + ssl_tools.pipelines.har_classification.utils.FFT + ssl_tools.pipelines.har_classification.utils.Flatten + ssl_tools.pipelines.har_classification.utils.PredictionHeadClassifier + ssl_tools.pipelines.har_classification.utils.Spectrogram + ssl_tools.pipelines.har_classification.utils.SwapAxes + + +Module Contents +--------------- + +.. py:class:: DimensionAdder(dim) + + .. py:method:: __call__(x) + + +.. py:class:: FFT(absolute = True, centered = False) + + .. py:method:: __call__(x) + + Aplly FFT to the input signal. It apply the FFT into each channel + of the input signal. + + Parameters + ---------- + x : np.ndarray + An array with shape (n_channels, n_samples) containing the input + + Returns + ------- + np.ndarray + The FFT of the input signal. The shape of the output is + (n_channels, n_samples) if absolute is False, and + (n_channels, n_samples//2) if absolute is True. + + + +.. py:class:: Flatten + + .. py:method:: __call__(x) + + Flatten the input signal. It apply the flatten into each channel + of the input signal. + + Parameters + ---------- + x : np.ndarray + An array with shape (n_channels, n_samples) containing the input + + Returns + ------- + np.ndarray + The flatten of the input signal. The shape of the output is + (n_channels, n_samples). + + + +.. py:class:: PredictionHeadClassifier(prediction_head, num_classes = 6) + + Bases: :py:obj:`ssl_tools.models.nets.simple.SimpleClassificationNet` + + +.. py:class:: Spectrogram(fs=20, nperseg=16, noverlap=8, nfft=16) + + .. py:method:: __call__(x) + + Aplly Spectrogram to the input signal. It apply the Spectrogram into each channel + of the input signal. + + Parameters + ---------- + x : np.ndarray + An array with shape (n_channels, n_samples) containing the input + + Returns + ------- + np.ndarray + The Spectrogram of the input signal. The shape of the output is + (n_channels, n_samples) if absolute is False, and + (n_channels, n_samples//2) if absolute is True. + + + +.. py:class:: SwapAxes(axis1, axis2) + + .. py:method:: __call__(x) + + diff --git a/_sources/autoapi/ssl_tools/pipelines/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/index.rst.txt new file mode 100644 index 0000000..d210881 --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/index.rst.txt @@ -0,0 +1,27 @@ +ssl_tools.pipelines +=================== + +.. py:module:: ssl_tools.pipelines + + +Subpackages +----------- + +.. toctree:: + :maxdepth: 1 + + /autoapi/ssl_tools/pipelines/har_classification/index + + +Submodules +---------- + +.. toctree:: + :maxdepth: 1 + + /autoapi/ssl_tools/pipelines/base/index + /autoapi/ssl_tools/pipelines/cli/index + /autoapi/ssl_tools/pipelines/mlflow_train/index + /autoapi/ssl_tools/pipelines/utils/index + + diff --git a/_sources/autoapi/ssl_tools/pipelines/mlflow_train/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/mlflow_train/index.rst.txt new file mode 100644 index 0000000..394488c --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/mlflow_train/index.rst.txt @@ -0,0 +1,134 @@ +ssl_tools.pipelines.mlflow_train +================================ + +.. py:module:: ssl_tools.pipelines.mlflow_train + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow + ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow + + +Module Contents +--------------- + +.. py:class:: LightningFineTuneMLFlow(registered_model_name, registered_model_tags = None, update_backbone = False, **kwargs) + + Bases: :py:obj:`LightningTrainMLFlow` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:property:: client + + + .. py:method:: load_model() + + +.. py:class:: LightningTrainMLFlow(experiment_name, model_name, run_name = None, accelerator = 'cpu', devices = 1, num_nodes = 1, strategy = 'auto', max_epochs = 1, batch_size = 1, limit_train_batches = 1.0, limit_val_batches = 1.0, checkpoint_monitor_metric = None, checkpoint_monitor_mode = 'min', patience = None, log_dir = './mlruns', model_tags = None) + + Bases: :py:obj:`ssl_tools.pipelines.base.Pipeline` + + + + Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + + + .. py:method:: get_callbacks() + + + .. py:method:: get_data_module() + :abstractmethod: + + + + .. py:method:: get_logger() + + + .. py:method:: get_model() + :abstractmethod: + + + + .. py:method:: get_trainer(logger, callacks) + + + .. py:method:: run() + + diff --git a/_sources/autoapi/ssl_tools/pipelines/utils/index.rst.txt b/_sources/autoapi/ssl_tools/pipelines/utils/index.rst.txt new file mode 100644 index 0000000..702329c --- /dev/null +++ b/_sources/autoapi/ssl_tools/pipelines/utils/index.rst.txt @@ -0,0 +1,31 @@ +ssl_tools.pipelines.utils +========================= + +.. py:module:: ssl_tools.pipelines.utils + + +Functions +--------- + +.. autoapisummary:: + + ssl_tools.pipelines.utils.load_model_mlflow + ssl_tools.pipelines.utils.tags2str + + +Module Contents +--------------- + +.. py:function:: load_model_mlflow(client, registered_model_name, registered_model_tags = None) + +.. py:function:: tags2str(d) + + Convert a dictionary of tags to a search string compatible with MLflow's search_model_versions method. + + Parameters: + - d: A dictionary containing tags where keys are tag names and values are tag values. + + Returns: + - search_str: A search string formatted for MLflow's search_model_versions method. + + diff --git a/_sources/autoapi/ssl_tools/transforms/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/index.rst.txt index c6c39fc..a4b8663 100644 --- a/_sources/autoapi/ssl_tools/transforms/index.rst.txt +++ b/_sources/autoapi/ssl_tools/transforms/index.rst.txt @@ -1,17 +1,20 @@ -:py:mod:`ssl_tools.transforms` -============================== +ssl_tools.transforms +==================== .. py:module:: ssl_tools.transforms Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - signal_1d/index.rst - time_1d/index.rst - utils/index.rst + /autoapi/ssl_tools/transforms/pad/index + /autoapi/ssl_tools/transforms/signal_1d/index + /autoapi/ssl_tools/transforms/time_1d/index + /autoapi/ssl_tools/transforms/time_1d_full/index + /autoapi/ssl_tools/transforms/utils/index + /autoapi/ssl_tools/transforms/window/index diff --git a/_sources/autoapi/ssl_tools/transforms/pad/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/pad/index.rst.txt new file mode 100644 index 0000000..a05d55c --- /dev/null +++ b/_sources/autoapi/ssl_tools/transforms/pad/index.rst.txt @@ -0,0 +1,22 @@ +ssl_tools.transforms.pad +======================== + +.. py:module:: ssl_tools.transforms.pad + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.transforms.pad.ZeroPaddingBetween + + +Module Contents +--------------- + +.. py:class:: ZeroPaddingBetween(pad_every = 3, padding_size = 2, discard_last = True) + + .. py:method:: __call__(x) + + diff --git a/_sources/autoapi/ssl_tools/transforms/signal_1d/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/signal_1d/index.rst.txt index a8a84fa..2fba978 100644 --- a/_sources/autoapi/ssl_tools/transforms/signal_1d/index.rst.txt +++ b/_sources/autoapi/ssl_tools/transforms/signal_1d/index.rst.txt @@ -1,28 +1,27 @@ -:py:mod:`ssl_tools.transforms.signal_1d` -======================================== +ssl_tools.transforms.signal_1d +============================== .. py:module:: ssl_tools.transforms.signal_1d -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.transforms.signal_1d.AddRemoveFrequency ssl_tools.transforms.signal_1d.FFT + ssl_tools.transforms.signal_1d.WelchPowerSpectralDensity - +Module Contents +--------------- .. py:class:: AddRemoveFrequency(add_pertub_ratio=0.1, remove_pertub_ratio=0.1) - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) @@ -35,16 +34,71 @@ Classes .. py:method:: transform(sample) +.. py:class:: FFT(absolute = True) -.. py:class:: FFT + + Simple wrapper to apply the FFT to the data + Parameters + ---------- + absolute : bool, optional + If True, returns the absolute value of FFT, by default True - Bases: :py:obj:`librep.base.Transform` .. py:method:: __call__(sample) - .. py:method:: transform(sample) + .. py:method:: transform(x) + + Apply the FFT to the data + + Parameters + ---------- + x : np.ndarray + A 1-D array with the data + + Returns + ------- + np.ndarray + The FFT of the data + + + +.. py:class:: WelchPowerSpectralDensity(fs = 1 / 20, nperseg = None, noverlap = None, return_onesided=False, absolute = True) + + + Simple wrapper to apply the Welch Power Spectral Density to the data + + Parameters + ---------- + fs : int, optional + The sampling frequency, by default 20 + nperseg : int, optional + The number of data points in each segment, by default 30 + noverlap : int, optional + The number of points of overlap between segments, by default 15 + return_onesided : bool, optional + If True, return the one-sided PSD, by default False + absolute : bool, optional + If True, returns the absolute value of PSD, by default True + + + .. py:method:: __call__(sample) + + + .. py:method:: transform(x) + + Apply the Welch Power Spectral Density to the data + + Parameters + ---------- + x : np.ndarray + A 1-D array with the data + + Returns + ------- + np.ndarray + The Welch Power Spectral Density of the data diff --git a/_sources/autoapi/ssl_tools/transforms/time_1d/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/time_1d/index.rst.txt index 2a59ab5..bd72484 100644 --- a/_sources/autoapi/ssl_tools/transforms/time_1d/index.rst.txt +++ b/_sources/autoapi/ssl_tools/transforms/time_1d/index.rst.txt @@ -1,14 +1,11 @@ -:py:mod:`ssl_tools.transforms.time_1d` -====================================== +ssl_tools.transforms.time_1d +============================ .. py:module:: ssl_tools.transforms.time_1d -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: @@ -21,89 +18,83 @@ Classes ssl_tools.transforms.time_1d.TimeAmplitudeModulation - +Module Contents +--------------- .. py:class:: AddGaussianNoise(mean=0.0, std=0.1) - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(sample) - .. py:class:: LeftToRightFlip - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(sample) - .. py:class:: MagnitudeWrap(max_magnitude=1.0) - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(sample) - .. py:class:: RandomSmoothing(sigma_range=(1, 1)) - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(sample) - .. py:class:: Rotate - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(dataset) - .. py:class:: Scale(mean = 1.0, sigma = 0.5) - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(sample) - .. py:class:: TimeAmplitudeModulation(modulation_factor=0.1) - Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(sample) .. py:method:: transform(sample) - diff --git a/_sources/autoapi/ssl_tools/transforms/time_1d_full/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/time_1d_full/index.rst.txt new file mode 100644 index 0000000..329e4d7 --- /dev/null +++ b/_sources/autoapi/ssl_tools/transforms/time_1d_full/index.rst.txt @@ -0,0 +1,130 @@ +ssl_tools.transforms.time_1d_full +================================= + +.. py:module:: ssl_tools.transforms.time_1d_full + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.transforms.time_1d_full.Composer + ssl_tools.transforms.time_1d_full.ConcatComposer + ssl_tools.transforms.time_1d_full.Identity + ssl_tools.transforms.time_1d_full.MagnitudeWarp + ssl_tools.transforms.time_1d_full.Permutate + ssl_tools.transforms.time_1d_full.Rotate + ssl_tools.transforms.time_1d_full.Scale + ssl_tools.transforms.time_1d_full.TimeWarp + ssl_tools.transforms.time_1d_full.WindowSlice + ssl_tools.transforms.time_1d_full.WindowWarp + + +Module Contents +--------------- + +.. py:class:: Composer(transforms) + + .. py:method:: __call__(dataset, labels = None) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: ConcatComposer(transforms, axis = 0) + + .. py:method:: __call__(dataset, labels = None) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: Identity + + .. py:method:: __call__(dataset) + + +.. py:class:: MagnitudeWarp(sigma = 0.2, knot = 4) + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: Permutate(max_segments = 5, segment_mode = 'equal') + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: Rotate + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: Scale(sigma = 0.1) + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: TimeWarp(sigma = 0.2, knot = 4) + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: WindowSlice(reduce_ratio = 0.9) + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + +.. py:class:: WindowWarp(window_ratio=0.1, scales=[0.5, 2.0]) + + .. py:method:: __call__(dataset) + + + .. py:method:: __str__() + + Return str(self). + + + diff --git a/_sources/autoapi/ssl_tools/transforms/utils/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/utils/index.rst.txt index 5500211..b5fd6eb 100644 --- a/_sources/autoapi/ssl_tools/transforms/utils/index.rst.txt +++ b/_sources/autoapi/ssl_tools/transforms/utils/index.rst.txt @@ -1,59 +1,140 @@ -:py:mod:`ssl_tools.transforms.utils` -==================================== +ssl_tools.transforms.utils +========================== .. py:module:: ssl_tools.transforms.utils -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: + ssl_tools.transforms.utils.Cast ssl_tools.transforms.utils.Composer ssl_tools.transforms.utils.Flatten ssl_tools.transforms.utils.Identity + ssl_tools.transforms.utils.PerChannelTransform ssl_tools.transforms.utils.Reshape + ssl_tools.transforms.utils.Squeeze + ssl_tools.transforms.utils.StackComposer + ssl_tools.transforms.utils.Unsqueeze +Module Contents +--------------- +.. py:class:: Cast(dtype) -.. py:class:: Composer(transforms) + Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(X) + + + .. py:method:: transform(X) + + +.. py:class:: Composer(transforms) + Bases: :py:obj:`librep.base.Transform` + .. py:method:: __call__(X) .. py:method:: transform(X) - .. py:class:: Flatten + Bases: :py:obj:`librep.base.Transform` + + + .. py:method:: __call__(X) + + + .. py:method:: transform(X) + + +.. py:class:: Identity Bases: :py:obj:`librep.base.Transform` + + .. py:method:: __call__(X) + + .. py:method:: transform(X) +.. py:class:: PerChannelTransform(transform) -.. py:class:: Identity + Bases: :py:obj:`librep.base.Transform` + + + .. py:method:: __call__(X) .. py:method:: transform(X) + Split the data into channels and apply the transforms to each channel + separately. + + Parameters + ---------- + data : np.ndarray + The data to be transformed. It must be a 2-D array with the shape + (C, T), where C is the number of channels and T is the number of + time steps. + transforms : List[Transform] + A sequence of transforms to apply in the data + + Returns + ------- + np.ndarray + An 2-D array with the transformed data. The array has the number of + channels as the first dimension. + .. py:class:: Reshape(shape) + Bases: :py:obj:`librep.base.Transform` + + + .. py:method:: __call__(X) + + + .. py:method:: transform(X) + + +.. py:class:: Squeeze(axis=None) Bases: :py:obj:`librep.base.Transform` + + .. py:method:: __call__(X) + + + .. py:method:: transform(X) + + +.. py:class:: StackComposer(transforms) + + .. py:method:: __call__(x) + + .. py:method:: transform(X) +.. py:class:: Unsqueeze(axis) + + Bases: :py:obj:`librep.base.Transform` + + + .. py:method:: __call__(X) + + + .. py:method:: transform(X) + diff --git a/_sources/autoapi/ssl_tools/transforms/window/index.rst.txt b/_sources/autoapi/ssl_tools/transforms/window/index.rst.txt new file mode 100644 index 0000000..09d42b0 --- /dev/null +++ b/_sources/autoapi/ssl_tools/transforms/window/index.rst.txt @@ -0,0 +1,22 @@ +ssl_tools.transforms.window +=========================== + +.. py:module:: ssl_tools.transforms.window + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.transforms.window.Windowize + + +Module Contents +--------------- + +.. py:class:: Windowize(time_segments = 15, stride = None) + + .. py:method:: __call__(x) + + diff --git a/_sources/autoapi/ssl_tools/utils/configurable/index.rst.txt b/_sources/autoapi/ssl_tools/utils/configurable/index.rst.txt index b29a53f..5187de3 100644 --- a/_sources/autoapi/ssl_tools/utils/configurable/index.rst.txt +++ b/_sources/autoapi/ssl_tools/utils/configurable/index.rst.txt @@ -1,28 +1,26 @@ -:py:mod:`ssl_tools.utils.configurable` -====================================== +ssl_tools.utils.configurable +============================ .. py:module:: ssl_tools.utils.configurable -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.utils.configurable.Configurable - +Module Contents +--------------- .. py:class:: Configurable - Configurable interface for models and other objects that can be configured with a dictionary. For now, this interface is used to save the hyperparameters of the models. + .. py:method:: get_config() :abstractmethod: diff --git a/_sources/autoapi/ssl_tools/utils/data/index.rst.txt b/_sources/autoapi/ssl_tools/utils/data/index.rst.txt index 9e3048c..9a10004 100644 --- a/_sources/autoapi/ssl_tools/utils/data/index.rst.txt +++ b/_sources/autoapi/ssl_tools/utils/data/index.rst.txt @@ -1,27 +1,35 @@ -:py:mod:`ssl_tools.utils.data` -============================== +ssl_tools.utils.data +==================== .. py:module:: ssl_tools.utils.data -Module Contents ---------------- - Classes -~~~~~~~ +------- .. autoapisummary:: ssl_tools.utils.data.ConcatDataset +Functions +--------- +.. autoapisummary:: -.. py:class:: ConcatDataset(datasets) + ssl_tools.utils.data.full_dataset_from_dataloader + ssl_tools.utils.data.get_full_data_split + ssl_tools.utils.data.get_split_dataloader + + +Module Contents +--------------- +.. py:class:: ConcatDataset(datasets) Concatenate multiple datasets1 + .. py:method:: __getitem__(i) @@ -33,3 +41,9 @@ Classes +.. py:function:: full_dataset_from_dataloader(dataloader) + +.. py:function:: get_full_data_split(data_module, stage) + +.. py:function:: get_split_dataloader(stage, data_module) + diff --git a/_sources/autoapi/ssl_tools/utils/index.rst.txt b/_sources/autoapi/ssl_tools/utils/index.rst.txt index 412b046..bdfa5bc 100644 --- a/_sources/autoapi/ssl_tools/utils/index.rst.txt +++ b/_sources/autoapi/ssl_tools/utils/index.rst.txt @@ -1,17 +1,18 @@ -:py:mod:`ssl_tools.utils` -========================= +ssl_tools.utils +=============== .. py:module:: ssl_tools.utils Submodules ---------- + .. toctree:: - :titlesonly: :maxdepth: 1 - configurable/index.rst - data/index.rst - types/index.rst + /autoapi/ssl_tools/utils/configurable/index + /autoapi/ssl_tools/utils/data/index + /autoapi/ssl_tools/utils/layers/index + /autoapi/ssl_tools/utils/types/index diff --git a/_sources/autoapi/ssl_tools/utils/layers/index.rst.txt b/_sources/autoapi/ssl_tools/utils/layers/index.rst.txt new file mode 100644 index 0000000..5b549d6 --- /dev/null +++ b/_sources/autoapi/ssl_tools/utils/layers/index.rst.txt @@ -0,0 +1,31 @@ +ssl_tools.utils.layers +====================== + +.. py:module:: ssl_tools.utils.layers + + +Classes +------- + +.. autoapisummary:: + + ssl_tools.utils.layers.OutputLoggerCallback + + +Module Contents +--------------- + +.. py:class:: OutputLoggerCallback(layers) + + Bases: :py:obj:`lightning.Callback` + + + .. py:method:: count(module, input, output, layer_name) + + + .. py:method:: setup(trainer, pl_module, stage) + + + .. py:method:: teardown(trainer, pl_module, stage) + + diff --git a/_sources/autoapi/ssl_tools/utils/types/index.rst.txt b/_sources/autoapi/ssl_tools/utils/types/index.rst.txt index 3840c63..09f7ca3 100644 --- a/_sources/autoapi/ssl_tools/utils/types/index.rst.txt +++ b/_sources/autoapi/ssl_tools/utils/types/index.rst.txt @@ -1,13 +1,19 @@ -:py:mod:`ssl_tools.utils.types` -=============================== +ssl_tools.utils.types +===================== .. py:module:: ssl_tools.utils.types +Attributes +---------- + +.. autoapisummary:: + + ssl_tools.utils.types.PathLike + + Module Contents --------------- .. py:data:: PathLike - - diff --git a/_sources/notebooks/02_training_model.ipynb.txt b/_sources/notebooks/02_training_model.ipynb.txt index 4ce9a64..94ebe48 100644 --- a/_sources/notebooks/02_training_model.ipynb.txt +++ b/_sources/notebooks/02_training_model.ipynb.txt @@ -208,9 +208,8 @@ "from ssl_tools.models.nets.convnet import Simple1DConvNetwork\n", "\n", "model = Simple1DConvNetwork(\n", - " input_channels=6, # The number of input channels (accel-x, accel-y, ...)\n", + " input_shape=(6,60), # (The number of input channels, input size of FC layers)\n", " num_classes=6, # The number of output classes\n", - " time_steps=60, # Used to auto calculate the input size of FC layers\n", " learning_rate=1e-3, # The learning rate of the Adam optimizer\n", ")\n", "\n", diff --git a/_sources/notebooks/05_covid_anomaly_detection.ipynb.txt b/_sources/notebooks/05_covid_anomaly_detection.ipynb.txt new file mode 100644 index 0000000..a7091e7 --- /dev/null +++ b/_sources/notebooks/05_covid_anomaly_detection.ipynb.txt @@ -0,0 +1,2372 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Training an Anomaly Detection Model for Covid Anomaly Detection\n", + "\n", + "In this tutorial, we will train an anomaly detection model using a simple [LSTM-AutoEncoder model](https://www.medrxiv.org/content/10.1101/2021.01.08.21249474v1).\n", + "Data can be obtained from [this link](https://iscteiul365-my.sharepoint.com/:u:/g/personal/oonia_iscte-iul_pt/ERZLm1ruUNpMqkSwjpqhE9wB_7loVWAC4yZWuIH2RKGOlQ?e=kD4HlI). This is a processed version of data from original Stanford dataset-Phase 2. The overall pre-processing pipeline used is illustrated in Figure below.\n", + "\n", + "![preprocessing](stanford_data_processing.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data was aquired from diferent sources (Germin, FitBit, Apple Watch) and pre-processed to have a common format. In this form, data has two columns: heart rate and number of user steps in last minute. \n", + "Then the processing pipeline was applied to the data. The pipeline is composed of the following steps:\n", + "1. Once data was standardized, the resting heart rate was extracted (``Resting Heart Rate Extractor``, in Figure). This process takes as input `min_minutes_rest` that is the number of minutes that the user has to be at rest to consider the heart rate as resting. This variable looks at user steps and, when user steps is 0 for `min_minutes_rest` minutes, the heart rate is considered as resting. At the end of this process, we will have a new dataframe with: the date and the resting heart rate of the last minute.\n", + "2. The second step is adding labels." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from ssl_tools.data.data_modules.covid_anomaly import CovidUserAnomalyDataModule\n", + "from ssl_tools.utils.data import get_full_data_split\n", + "from ssl_tools.models.nets.lstm_ae import LSTMAutoencoder\n", + "import lightning as L\n", + "import torch\n", + "import numpy as np\n", + "from torchmetrics import MeanSquaredError" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
datetimeRHR-0RHR-1RHR-2RHR-3RHR-4RHR-5RHR-6RHR-7RHR-8...RHR-10RHR-11RHR-12RHR-13RHR-14RHR-15anomalybaselinelabelparticipant_id
02027-01-14 21:00:001.1701750.653752-0.392374-1.431553-2.129013-2.755962-3.681322-4.674443-5.668570...-6.937363-7.102118-6.975790-6.554774-6.112156-5.396099FalseTruenormalP110465
12027-01-15 05:00:00-5.668570-6.373289-6.937363-7.102118-6.975790-6.554774-6.112156-5.396099-4.415848...-2.656756-1.305630-0.0727561.0461951.5304671.829053FalseFalsenormalP110465
22027-01-15 13:00:00-4.415848-3.467073-2.656756-1.305630-0.0727561.0461951.5304671.8290531.223064...-0.424000-1.145581-1.355121-2.321206-3.124961-3.928738FalseFalsenormalP110465
32027-01-15 21:00:001.2230640.472444-0.424000-1.145581-1.355121-2.321206-3.124961-3.928738-4.802627...-6.067744-5.460156-4.671143-3.408943-2.237883-1.187843FalseFalsenormalP110465
42027-01-16 05:00:00-4.802627-5.831013-6.067744-5.460156-4.671143-3.408943-2.237883-1.187843-0.062360...2.2669443.7944654.6257454.8277564.7200004.677464FalseFalsenormalP110465
..................................................................
317322024-12-13 00:00:00-0.180702-0.499793-0.749829-0.868485-0.966754-1.004670-0.888210-0.580762-0.467943...0.0920000.3478400.6363950.9581951.1705141.301841FalseFalserecoveredP992022
317332024-12-13 08:00:00-0.467943-0.1627400.0920000.3478400.6363950.9581951.1705141.3018411.477526...1.6603441.6566001.6856521.7472521.7673291.793616FalseFalserecoveredP992022
317342024-12-13 16:00:001.4775261.6573211.6603441.6566001.6856521.7472521.7673291.7936161.728615...1.5098331.3807491.2637441.1399971.0242050.946663FalseFalserecoveredP992022
317352024-12-14 00:00:001.7286151.6162651.5098331.3807491.2637441.1399971.0242050.9466631.136868...1.6421531.9093812.1144392.2822382.4536912.587843FalseFalserecoveredP992022
317362024-12-14 08:00:001.1368681.3804181.6421531.9093812.1144392.2822382.4536912.5878432.437232...2.3598402.1734002.0981401.9676691.7845121.561848FalseFalserecoveredP992022
\n", + "

31737 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " datetime RHR-0 RHR-1 RHR-2 RHR-3 RHR-4 \\\n", + "0 2027-01-14 21:00:00 1.170175 0.653752 -0.392374 -1.431553 -2.129013 \n", + "1 2027-01-15 05:00:00 -5.668570 -6.373289 -6.937363 -7.102118 -6.975790 \n", + "2 2027-01-15 13:00:00 -4.415848 -3.467073 -2.656756 -1.305630 -0.072756 \n", + "3 2027-01-15 21:00:00 1.223064 0.472444 -0.424000 -1.145581 -1.355121 \n", + "4 2027-01-16 05:00:00 -4.802627 -5.831013 -6.067744 -5.460156 -4.671143 \n", + "... ... ... ... ... ... ... \n", + "31732 2024-12-13 00:00:00 -0.180702 -0.499793 -0.749829 -0.868485 -0.966754 \n", + "31733 2024-12-13 08:00:00 -0.467943 -0.162740 0.092000 0.347840 0.636395 \n", + "31734 2024-12-13 16:00:00 1.477526 1.657321 1.660344 1.656600 1.685652 \n", + "31735 2024-12-14 00:00:00 1.728615 1.616265 1.509833 1.380749 1.263744 \n", + "31736 2024-12-14 08:00:00 1.136868 1.380418 1.642153 1.909381 2.114439 \n", + "\n", + " RHR-5 RHR-6 RHR-7 RHR-8 ... RHR-10 RHR-11 \\\n", + "0 -2.755962 -3.681322 -4.674443 -5.668570 ... -6.937363 -7.102118 \n", + "1 -6.554774 -6.112156 -5.396099 -4.415848 ... -2.656756 -1.305630 \n", + "2 1.046195 1.530467 1.829053 1.223064 ... -0.424000 -1.145581 \n", + "3 -2.321206 -3.124961 -3.928738 -4.802627 ... -6.067744 -5.460156 \n", + "4 -3.408943 -2.237883 -1.187843 -0.062360 ... 2.266944 3.794465 \n", + "... ... ... ... ... ... ... ... \n", + "31732 -1.004670 -0.888210 -0.580762 -0.467943 ... 0.092000 0.347840 \n", + "31733 0.958195 1.170514 1.301841 1.477526 ... 1.660344 1.656600 \n", + "31734 1.747252 1.767329 1.793616 1.728615 ... 1.509833 1.380749 \n", + "31735 1.139997 1.024205 0.946663 1.136868 ... 1.642153 1.909381 \n", + "31736 2.282238 2.453691 2.587843 2.437232 ... 2.359840 2.173400 \n", + "\n", + " RHR-12 RHR-13 RHR-14 RHR-15 anomaly baseline label \\\n", + "0 -6.975790 -6.554774 -6.112156 -5.396099 False True normal \n", + "1 -0.072756 1.046195 1.530467 1.829053 False False normal \n", + "2 -1.355121 -2.321206 -3.124961 -3.928738 False False normal \n", + "3 -4.671143 -3.408943 -2.237883 -1.187843 False False normal \n", + "4 4.625745 4.827756 4.720000 4.677464 False False normal \n", + "... ... ... ... ... ... ... ... \n", + "31732 0.636395 0.958195 1.170514 1.301841 False False recovered \n", + "31733 1.685652 1.747252 1.767329 1.793616 False False recovered \n", + "31734 1.263744 1.139997 1.024205 0.946663 False False recovered \n", + "31735 2.114439 2.282238 2.453691 2.587843 False False recovered \n", + "31736 2.098140 1.967669 1.784512 1.561848 False False recovered \n", + "\n", + " participant_id \n", + "0 P110465 \n", + "1 P110465 \n", + "2 P110465 \n", + "3 P110465 \n", + "4 P110465 \n", + "... ... \n", + "31732 P992022 \n", + "31733 P992022 \n", + "31734 P992022 \n", + "31735 P992022 \n", + "31736 P992022 \n", + "\n", + "[31737 rows x 21 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Read CSV data\n", + "data_path = \"/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv\"\n", + "df = pd.read_csv(data_path)\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CovidUserAnomalyDataModule (Data=/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv, 1 participant selected)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dm = CovidUserAnomalyDataModule(\n", + " data_path,\n", + " participants=[\"P992022\"],\n", + " batch_size=32,\n", + " num_workers=0,\n", + " reshape=(16, 1),\n", + ")\n", + "dm" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LSTMAutoencoder(\n", + " (backbone): _LSTMAutoEncoder(\n", + " (lstm1): LSTM(1, 128, batch_first=True)\n", + " (lstm2): LSTM(128, 64, batch_first=True)\n", + " (repeat_vector): Linear(in_features=64, out_features=1024, bias=True)\n", + " (lstm3): LSTM(64, 64, batch_first=True)\n", + " (lstm4): LSTM(64, 128, batch_first=True)\n", + " (time_distributed): Linear(in_features=128, out_features=1, bias=True)\n", + " )\n", + " (loss_fn): MSELoss()\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = LSTMAutoencoder(input_shape=(16, 1))\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer = L.Trainer(max_epochs=100, devices=1, accelerator=\"cpu\")\n", + "trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "----------------------------------------------\n", + "0 | backbone | _LSTMAutoEncoder | 316 K \n", + "1 | loss_fn | MSELoss | 0 \n", + "----------------------------------------------\n", + "316 K Trainable params\n", + "0 Non-trainable params\n", + "316 K Total params\n", + "1.264 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "122a71df981c48c183eb2b4e7585103d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00 anomaly_threshold else 0 for loss in losses]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
truepredictedlossanomaly_threshold
0000.0237000.374275
1000.0914130.374275
2000.0542990.374275
3000.0074860.374275
4000.0246010.374275
...............
89100.0898330.374275
90100.0515620.374275
91100.1327480.374275
92100.1586100.374275
93100.0255220.374275
\n", + "

94 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " true predicted loss anomaly_threshold\n", + "0 0 0 0.023700 0.374275\n", + "1 0 0 0.091413 0.374275\n", + "2 0 0 0.054299 0.374275\n", + "3 0 0 0.007486 0.374275\n", + "4 0 0 0.024601 0.374275\n", + ".. ... ... ... ...\n", + "89 1 0 0.089833 0.374275\n", + "90 1 0 0.051562 0.374275\n", + "91 1 0 0.132748 0.374275\n", + "92 1 0 0.158610 0.374275\n", + "93 1 0 0.025522 0.374275\n", + "\n", + "[94 rows x 4 columns]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_dataframe = pd.DataFrame(\n", + " {\n", + " \"true\": y_test,\n", + " \"predicted\": y_test_hat,\n", + " \"loss\": losses,\n", + " \"anomaly_threshold\": anomaly_threshold,\n", + " }\n", + ")\n", + "\n", + "results_dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "F1-score: 0.0\n", + "Recall: 0.0\n", + "Balanced Accuracy: 0.5\n", + "ROC AUC: 0.5\n" + ] + } + ], + "source": [ + "from sklearn.metrics import f1_score, recall_score, balanced_accuracy_score, roc_auc_score\n", + "\n", + "# Extract true and predicted labels from the results_dataframe\n", + "true_labels = results_dataframe['true']\n", + "predicted_labels = results_dataframe['predicted']\n", + "\n", + "# Calculate the F1-score\n", + "f1 = f1_score(true_labels, predicted_labels)\n", + "\n", + "# Calculate the recall\n", + "recall = recall_score(true_labels, predicted_labels)\n", + "\n", + "# Calculate the balanced accuracy\n", + "balanced_acc = balanced_accuracy_score(true_labels, predicted_labels)\n", + "\n", + "# Calculate the ROC AUC\n", + "roc_auc = roc_auc_score(true_labels, predicted_labels)\n", + "\n", + "# Print the results\n", + "print(\"F1-score:\", f1)\n", + "print(\"Recall:\", recall)\n", + "print(\"Balanced Accuracy:\", balanced_acc)\n", + "print(\"ROC AUC:\", roc_auc)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAApUAAAJOCAYAAADmqPxLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABTMUlEQVR4nO3dd3hUZdrH8d8kkAmQBhESkBJ6L9ID0iOIgtQFRJeuoIhgABWVDguLNBVExFBkZRGkKOiCSFV6L4JIFRQSiiQhQArJef/AzOsY0IyTw0yY74frXDLPOfOc++Qy5M79lLEYhmEIAAAAcIKXqwMAAABA9kdSCQAAAKeRVAIAAMBpJJUAAABwGkklAAAAnEZSCQAAAKeRVAIAAMBpJJUAAABwGkklAAAAnEZSCcCtnThxQs2bN1dgYKAsFotWrlyZpf2fPXtWFotF8+fPz9J+s7PGjRurcePGrg4DQDZDUgngL506dUp9+/ZViRIl5Ovrq4CAANWvX1/vvPOObt26Zeq9u3fvrsOHD2v8+PFauHChatasaer97qcePXrIYrEoICDgrl/HEydOyGKxyGKxaPLkyQ73f+HCBY0aNUoHDhzIgmgB4M/lcHUAANzbl19+qX/84x+yWq3q1q2bKlWqpOTkZH333XcaOnSovv/+e3344Yem3PvWrVvavn273nzzTb300kum3KNYsWK6deuWcubMaUr/fyVHjhy6efOmVq1apU6dOtmd++STT+Tr66vExMS/1feFCxc0evRohYWFqVq1apl+39dff/237gfAs5FUArinM2fOqEuXLipWrJg2bNigggUL2s71799fJ0+e1Jdffmna/S9fvixJCgoKMu0eFotFvr6+pvX/V6xWq+rXr6///ve/GZLKRYsW6cknn9SyZcvuSyw3b95U7ty55ePjc1/uB+DBwvA3gHuaNGmSEhISFBUVZZdQpitVqpQGDhxoe3379m2NHTtWJUuWlNVqVVhYmN544w0lJSXZvS8sLEytWrXSd999p9q1a8vX11clSpTQxx9/bLtm1KhRKlasmCRp6NChslgsCgsLk3Rn2Dj97783atQoWSwWu7Z169bp0UcfVVBQkPz8/FS2bFm98cYbtvP3mlO5YcMGNWjQQHny5FFQUJDatGmjY8eO3fV+J0+eVI8ePRQUFKTAwED17NlTN2/evPcX9g+6du2q//3vf4qNjbW17d69WydOnFDXrl0zXP/rr79qyJAhqly5svz8/BQQEKCWLVvq4MGDtms2bdqkWrVqSZJ69uxpG0ZPf87GjRurUqVK2rt3rxo2bKjcuXPbvi5/nFPZvXt3+fr6Znj+Fi1aKG/evLpw4UKmnxXAg4ukEsA9rVq1SiVKlFC9evUydX2fPn00YsQIVa9eXdOmTVOjRo00YcIEdenSJcO1J0+eVMeOHfXYY49pypQpyps3r3r06KHvv/9ektS+fXtNmzZNkvT0009r4cKFmj59ukPxf//992rVqpWSkpI0ZswYTZkyRU899ZS2bt36p+/75ptv1KJFC126dEmjRo1SZGSktm3bpvr16+vs2bMZru/UqZOuX7+uCRMmqFOnTpo/f75Gjx6d6Tjbt28vi8Wi5cuX29oWLVqkcuXKqXr16hmuP336tFauXKlWrVpp6tSpGjp0qA4fPqxGjRrZErzy5ctrzJgxkqTnn39eCxcu1MKFC9WwYUNbP1evXlXLli1VrVo1TZ8+XU2aNLlrfO+8847y58+v7t27KzU1VZI0e/Zsff3113rvvfdUqFChTD8rgAeYAQB3ERcXZ0gy2rRpk6nrDxw4YEgy+vTpY9c+ZMgQQ5KxYcMGW1uxYsUMScaWLVtsbZcuXTKsVqsxePBgW9uZM2cMScbbb79t12f37t2NYsWKZYhh5MiRxu//WZs2bZohybh8+fI9406/x7x582xt1apVMwoUKGBcvXrV1nbw4EHDy8vL6NatW4b79erVy67Pdu3aGcHBwfe85++fI0+ePIZhGEbHjh2NZs2aGYZhGKmpqUZoaKgxevTou34NEhMTjdTU1AzPYbVajTFjxtjadu/eneHZ0jVq1MiQZHzwwQd3PdeoUSO7trVr1xqSjHHjxhmnT582/Pz8jLZt2/7lMwLwHFQqAdxVfHy8JMnf3z9T13/11VeSpMjISLv2wYMHS1KGuZcVKlRQgwYNbK/z58+vsmXL6vTp03875j9Kn4v5+eefKy0tLVPvuXjxog4cOKAePXooX758tvYqVarosccesz3n7/Xr18/udYMGDXT16lXb1zAzunbtqk2bNik6OlobNmxQdHT0XYe+pTvzML287vzznZqaqqtXr9qG9vft25fpe1qtVvXs2TNT1zZv3lx9+/bVmDFj1L59e/n6+mr27NmZvheABx9JJYC7CggIkCRdv349U9f/9NNP8vLyUqlSpezaQ0NDFRQUpJ9++smuvWjRohn6yJs3r65du/Y3I86oc+fOql+/vvr06aOQkBB16dJFS5Ys+dMEMz3OsmXLZjhXvnx5XblyRTdu3LBr/+Oz5M2bV5IcepYnnnhC/v7++vTTT/XJJ5+oVq1aGb6W6dLS0jRt2jSVLl1aVqtVDz30kPLnz69Dhw4pLi4u0/d8+OGHHVqUM3nyZOXLl08HDhzQu+++qwIFCmT6vQAefCSVAO4qICBAhQoV0pEjRxx63x8XytyLt7f3XdsNw/jb90if75cuV65c2rJli7755hv985//1KFDh9S5c2c99thjGa51hjPPks5qtap9+/ZasGCBVqxYcc8qpST961//UmRkpBo2bKj//Oc/Wrt2rdatW6eKFStmuiIr3fn6OGL//v26dOmSJOnw4cMOvRfAg4+kEsA9tWrVSqdOndL27dv/8tpixYopLS1NJ06csGuPiYlRbGysbSV3VsibN6/dSul0f6yGSpKXl5eaNWumqVOn6ujRoxo/frw2bNigjRs33rXv9DiPHz+e4dwPP/yghx56SHny5HHuAe6ha9eu2r9/v65fv37XxU3pPvvsMzVp0kRRUVHq0qWLmjdvroiIiAxfk8wm+Jlx48YN9ezZUxUqVNDzzz+vSZMmaffu3VnWP4Dsj6QSwD29+uqrypMnj/r06aOYmJgM50+dOqV33nlH0p3hW0kZVmhPnTpVkvTkk09mWVwlS5ZUXFycDh06ZGu7ePGiVqxYYXfdr7/+muG96ZuA/3Gbo3QFCxZUtWrVtGDBArsk7ciRI/r6669tz2mGJk2aaOzYsZoxY4ZCQ0PveZ23t3eGKujSpUv1yy+/2LWlJ793S8Ad9dprr+ncuXNasGCBpk6dqrCwMHXv3v2eX0cAnofNzwHcU8mSJbVo0SJ17txZ5cuXt/tEnW3btmnp0qXq0aOHJKlq1arq3r27PvzwQ8XGxqpRo0batWuXFixYoLZt295zu5q/o0uXLnrttdfUrl07vfzyy7p586ZmzZqlMmXK2C1UGTNmjLZs2aInn3xSxYoV06VLl/T++++rcOHCevTRR+/Z/9tvv62WLVsqPDxcvXv31q1bt/Tee+8pMDBQo0aNyrLn+CMvLy+99dZbf3ldq1atNGbMGPXs2VP16tXT4cOH9cknn6hEiRJ215UsWVJBQUH64IMP5O/vrzx58qhOnToqXry4Q3Ft2LBB77//vkaOHGnb4mjevHlq3Lixhg8frkmTJjnUH4AHE5VKAH/qqaee0qFDh9SxY0d9/vnn6t+/v15//XWdPXtWU6ZM0bvvvmu79qOPPtLo0aO1e/duDRo0SBs2bNCwYcO0ePHiLI0pODhYK1asUO7cufXqq69qwYIFmjBhglq3bp0h9qJFi2ru3Lnq37+/Zs6cqYYNG2rDhg0KDAy8Z/8RERFas2aNgoODNWLECE2ePFl169bV1q1bHU7IzPDGG29o8ODBWrt2rQYOHKh9+/bpyy+/VJEiReyuy5kzpxYsWCBvb2/169dPTz/9tDZv3uzQva5fv65evXrpkUce0Ztvvmlrb9CggQYOHKgpU6Zox44dWfJcALI3i+HITHIAAADgLqhUAgAAwGkklQAAAHAaSSUAAACcRlIJAAAAp5FUAgAAwGkklQAAAHAam5+7kbS0NF24cEH+/v5Z+vFqAAB4AsMwdP36dRUqVEheXq6vmyUmJio5Odm0/n18fOTr62ta/44iqXQjFy5cyLB5MQAAcMz58+dVuHBhl8aQmJioXP7B0u2bpt0jNDRUZ86ccZvEkqTSjfj7+0uS6o9ZqRy+eVwcDYC7WdK7tqtDAHAP1+PjVap4EdvPU1dKTk6Wbt+UtUJ3ydsn62+QmqzoowuUnJxMUomM0oe8c/jmUY5cJJWAOwoICHB1CAD+gltNIcvhK4sJSaVhcf3w/h+5X0QAAADIdqhUAgAAmMUiyYzKqRsVY9NRqQQAAIDTqFQCAACYxeJ15zCjXzfjfhEBAAAg26FSCQAAYBaLxaQ5le43qZKkEgAAwCwMfwMAAACZR6USAADALB40/E2lEgAAAE6jUgkAAGAak+ZUumFd0P0iAgAAQLZDpRIAAMAszKkEAAAAMo9KJQAAgFnYpxIAAADIPCqVAAAAZvGgOZUklQAAAGZh+BsAAADIPCqVAAAAZvGg4W8qlQAAAHAalUoAAACzMKcSAAAAyDwqlQAAAGaxWEyqVDKnEgAAAA8gKpUAAABm8bLcOczo182QVAIAAJiFhToAAABA5lGpBAAAMAubnwMAAACZR6USAADALMypBAAAADKPSiUAAIBZmFMJAAAAZB6VSgAAALMwpxIAAADIPCqVAAAAZvGgOZUklQAAAGZh+BsAAADIPCqVAAAAZvGg4W8qlQAAAHAalUoAAADTmDSn0g3rgu4XEQAAALIdKpUAAABmYU4lAAAAkHlUKgEAAMxisZi0T6X7VSpJKgEAAMzC5ucAAABA5lGpBAAAMAsLdQAAAIDMo1IJAABgFuZUAgAAAJlHpRIAAMAszKkEAAAAMo9KJQAAgFmYUwkAAABkHpVKAAAAszCnEgAAAM6yWCymHZk1atSoDO8tV66c7XxiYqL69++v4OBg+fn5qUOHDoqJiXH4WUkqAQAAHnAVK1bUxYsXbcd3331nO/fKK69o1apVWrp0qTZv3qwLFy6offv2Dt+D4W8AAACTOFpVdKBjhy7PkSOHQkNDM7THxcUpKipKixYtUtOmTSVJ8+bNU/ny5bVjxw7VrVs30/egUgkAAPCAO3HihAoVKqQSJUromWee0blz5yRJe/fuVUpKiiIiImzXlitXTkWLFtX27dsdugeVSgAAALNYfjvM6FdSfHy8XbPVapXVarVrq1OnjubPn6+yZcvq4sWLGj16tBo0aKAjR44oOjpaPj4+CgoKsntPSEiIoqOjHQqJpBIAACCbKlKkiN3rkSNHatSoUXZtLVu2tP29SpUqqlOnjooVK6YlS5YoV65cWRYLSSUAAIBJzJ5Tef78eQUEBNia/1ilvJugoCCVKVNGJ0+e1GOPPabk5GTFxsbaVStjYmLuOgfzzzCnEgAAIJsKCAiwOzKTVCYkJOjUqVMqWLCgatSooZw5c2r9+vW288ePH9e5c+cUHh7uUCxUKgEAAEziDqu/hwwZotatW6tYsWK6cOGCRo4cKW9vbz399NMKDAxU7969FRkZqXz58ikgIEADBgxQeHi4Qyu/JZJKAAAA07hDUvnzzz/r6aef1tWrV5U/f349+uij2rFjh/Lnzy9JmjZtmry8vNShQwclJSWpRYsWev/99x0OiaQSAADgAbZ48eI/Pe/r66uZM2dq5syZTt2HpBIAAMAk7lCpvF9YqAMAAACnUakEAAAwi8mbn7sTKpUAAABwGpVKAAAAkzCnEgAAAHAAlUoAAACTWCwyqVKZ9V06i6QSAADAJBaZNPzthlklw98AAABwGpVKAAAAk7BQBwAAAHAAlUoAAACzsPk5AAAAkHlUKgEAAMxi0pxKgzmVAAAAeBBRqQQAADCJWau/zdn70jlUKgEAAOA0KpUAAAAm8aRKJUklAACAWdhSCAAAAMg8KpUAAAAm8aThbyqVAAAAcBqVSgAAAJNQqQQAAAAcQKUSAADAJFQqAQAAAAdQqQQAADCJJ1UqSSoBAADMwubnAAAAQOZRqQQAADCJJw1/U6kEAACA06hUAgAAmIRKJQAAAOAAKpUAAAAmoVIJAAAAOIBKJQAAgFk8aJ9KkkrASa0rhah1pRCFBFglST/9eksLd/2s3edibdeUD/VTr7pFVS7ET2mGoVOXb+r1L44pOTXNRVED+OD9mZo29W3FREercpWqmjr9PdWqXdvVYQHZFkkl4KTLCcn6aPs5/RKbKFmk5uXya8yTZdXv00P66ddbKh/qp4mty+u/e3/RjC1nlJpmqORDeWQYhqtDBzzW0iWf6rWhkXpv5geqVbuOZrw7XU892UIHvz+uAgUKuDo8PECYUwkg03acvaZdP8Xql7hE/RKbqHk7zutWSprKh/hLkl58NEwrDkVr8b4L+unXW/o5NlGbT15VShpJJeAq706fqp69n1O3Hj1VvkIFvff+B8qVO7cWzJ/r6tDwgElPKs043A1JJZCFvCxS49LB8s3ppaPR1xWUK4fKh/or9laK3ulQSUt71dCUdhVVqaC/q0MFPFZycrL279urps0ibG1eXl5q2jRCu3Zsd2FkQPZGUmmiTZs2yWKxKDY21tWhwGTFg3Nr1fO19b8X6mpQ4xIa9dVxnbt2SwUDfCVJ3WoX1ldHYzTsi2M6eTlBk9pW0MOBvi6OGvBMV65cUWpqqgoUCLFrLxASoujoaBdFhQeVRSZVKt1wpU62SSp79Oghi8WiiRMn2rWvXLnSLUvA8Cznr91S308P6aWlh7XqSIxejSilonlzKf1/zdVHYrT22GWdvHJTs777ST9fu6XHKzBvCwDw4Mg2SaUk+fr66t///reuXbuWZX0mJydnWV/wXLfTDF2IS9SJyzcUtf2cTl+5ofZVC+rXGymS7qwI/71z126pgJ+PK0IFPN5DDz0kb29vXboUY9d+KSZGoaGhLooKDyrmVLqpiIgIhYaGasKECfe8ZtmyZapYsaKsVqvCwsI0ZcoUu/NhYWEaO3asunXrpoCAAD3//POaP3++goKCtHr1apUtW1a5c+dWx44ddfPmTS1YsEBhYWHKmzevXn75ZaWmptr6WrhwoWrWrCl/f3+Fhoaqa9euunTpkmnPj+zDYrEop7dF0deTdCUhWUXy5rI7Xzgol2KuJ7koOsCz+fj46JHqNbRxw3pbW1pamjZuXK/adcNdGBmQvWWrpNLb21v/+te/9N577+nnn3/OcH7v3r3q1KmTunTposOHD2vUqFEaPny45s+fb3fd5MmTVbVqVe3fv1/Dhw+XJN28eVPvvvuuFi9erDVr1mjTpk1q166dvvrqK3311VdauHChZs+erc8++8zWT0pKisaOHauDBw9q5cqVOnv2rHr06GHmlwBuqHd4UVUu5K8Qf6uKB+dW7/CiqvpwgNb/eEWStGT/L2pXJVQNSuZToUBf9ahTREXy5tL/jvILCOAqLw+K1LyoOfrPxwv0w7Fjern/C7p544a6de/p6tDwoLGYeLiZbLdPZbt27VStWjWNHDlSUVFRduemTp2qZs2a2RLFMmXK6OjRo3r77bftkr2mTZtq8ODBttfffvutUlJSNGvWLJUsWVKS1LFjRy1cuFAxMTHy8/NThQoV1KRJE23cuFGdO3eWJPXq1cvWR4kSJfTuu++qVq1aSkhIkJ+f318+S1JSkpKS/r9aFR8f7/gXBC4XlCunXosopXx5fHQjKVVnrt7Q618c077zcZKk5Qej5ePtpRceDZO/bw6dvnJTr31+VBfjqVQCrvKPTp115fJljRk9QjHR0apStZo+X71GISEhf/1mAHeV7ZJKSfr3v/+tpk2basiQIXbtx44dU5s2beza6tevr+nTpys1NVXe3t6SpJo1a2boM3fu3LaEUpJCQkIUFhZmlxyGhITYDW/v3btXo0aN0sGDB3Xt2jWlpd35dJRz586pQoUKf/kcEyZM0OjRozPxxHBnUzac+strFu+7oMX7LtyHaABk1gv9X9IL/V9ydRh4wLH5uZtr2LChWrRooWHDhv2t9+fJkydDW86cOe1eWyyWu7alJ443btxQixYtFBAQoE8++US7d+/WihUrJGV+8c+wYcMUFxdnO86fP/93HgcAALgpT1qoky0rlZI0ceJEVatWTWXLlrW1lS9fXlu3brW7buvWrSpTpoytSplVfvjhB129elUTJ05UkSJFJEl79uxxqA+r1Sqr1ZqlcQEAALhCtqxUSlLlypX1zDPP6N1337W1DR48WOvXr9fYsWP1448/asGCBZoxY0aGYfKsULRoUfn4+Oi9997T6dOn9cUXX2js2LFZfh8AAJB9WSzmHe4m2yaVkjRmzBjbcLQkVa9eXUuWLNHixYtVqVIljRgxQmPGjDFlRXb+/Pk1f/58LV26VBUqVNDEiRM1efLkLL8PAABAdmAxDMNwdRC4Iz4+XoGBgWo0aZ1y5Mo47xOA663uxz6GgLuKj49XSHCg4uLiFBAQ4PJYAgMDVWLAZ/KyZv3P9LSkGzr9Xke3eNZ02bpSCQAAAPeQbRfqAAAAuD2z5j8ypxIAAAAPIiqVAAAAJmHzcwAAAMABVCoBAABMYtaekm5YqCSpBAAAMIuXl0VeXlmfARom9Okshr8BAADgNCqVAAAAJvGk4W8qlQAAAHAalUoAAACTsKUQAAAA4AAqlQAAACZhTiUAAADgACqVAAAAJvGkOZUklQAAACbxpKSS4W8AAAA4jaQSAADAJOkLdcw4/q6JEyfKYrFo0KBBtrbExET1799fwcHB8vPzU4cOHRQTE+NQvySVAAAAHmL37t2aPXu2qlSpYtf+yiuvaNWqVVq6dKk2b96sCxcuqH379g71TVIJAABgEosstnmVWXrI8VJlQkKCnnnmGc2ZM0d58+a1tcfFxSkqKkpTp05V06ZNVaNGDc2bN0/btm3Tjh07Mt0/SSUAAIAH6N+/v5588klFRETYte/du1cpKSl27eXKlVPRokW1ffv2TPfP6m8AAACTmL35eXx8vF271WqV1WrNcP3ixYu1b98+7d69O8O56Oho+fj4KCgoyK49JCRE0dHRmY6JSiUAAEA2VaRIEQUGBtqOCRMmZLjm/PnzGjhwoD755BP5+vqaFguVSgAAAJOYvU/l+fPnFRAQYGu/W5Vy7969unTpkqpXr25rS01N1ZYtWzRjxgytXbtWycnJio2NtatWxsTEKDQ0NNMxkVQCAABkUwEBAXZJ5d00a9ZMhw8ftmvr2bOnypUrp9dee01FihRRzpw5tX79enXo0EGSdPz4cZ07d07h4eGZjoWkEgAAwCRmz6nMDH9/f1WqVMmuLU+ePAoODra19+7dW5GRkcqXL58CAgI0YMAAhYeHq27dupm+D0klAACASbLLxzROmzZNXl5e6tChg5KSktSiRQu9//77DvVBUgkAAOBhNm3aZPfa19dXM2fO1MyZM/92nySVAAAAJnGH4e/7hS2FAAAA4DQqlQAAACbJLnMqswKVSgAAADiNSiUAAIBZTJpTKfcrVFKpBAAAgPOoVAIAAJjEk+ZUklQCAACYhC2FAAAAAAdQqQQAADCJJw1/U6kEAACA06hUAgAAmIQ5lQAAAIADqFQCAACYhDmVAAAAgAOoVAIAAJiESiUAAADgACqVAAAAJvGk1d8klQAAACZh+BsAAABwAJVKAAAAk3jS8DeVSgAAADiNSiUAAIBJmFMJAAAAOIBKJQAAgEksMmlOZdZ36TQqlQAAAHAalUoAAACTeFks8jKhVGlGn84iqQQAADAJWwoBAAAADqBSCQAAYBK2FAIAAAAcQKUSAADAJF6WO4cZ/bobKpUAAABwGpVKAAAAs1hMmv9IpRIAAAAPIiqVAAAAJmGfSgAAAMABVCoBAABMYvntjxn9uhuSSgAAAJOwpRAAAADgACqVAAAAJuFjGgEAAAAHUKkEAAAwCVsKAQAAAA6gUgkAAGASL4tFXiaUFc3o01lUKgEAAOA0KpUAAAAm8aQ5lSSVAAAAJvGkLYUylVQeOnQo0x1WqVLlbwcDAACA7ClTSWW1atVksVhkGMZdz6efs1gsSk1NzdIAAQAAsiuGv//gzJkzZscBAACAbCxTSWWxYsXMjgMAAOCBw5ZCf2HhwoWqX7++ChUqpJ9++kmSNH36dH3++edZGhwAAACyB4eTylmzZikyMlJPPPGEYmNjbXMog4KCNH369KyODwAAINuymHi4G4eTyvfee09z5szRm2++KW9vb1t7zZo1dfjw4SwNDgAAANmDw/tUnjlzRo888kiGdqvVqhs3bmRJUAAAAA8CT9qn0uFKZfHixXXgwIEM7WvWrFH58uWzIiYAAABkMw5XKiMjI9W/f38lJibKMAzt2rVL//3vfzVhwgR99NFHZsQIAACQLXlZ7hxm9OtuHE4q+/Tpo1y5cumtt97SzZs31bVrVxUqVEjvvPOOunTpYkaMAAAA2ZInDX//rc/+fuaZZ/TMM8/o5s2bSkhIUIECBbI6LgAAAGQjfyuplKRLly7p+PHjku5ky/nz58+yoAAAAB4UblhUNIXDC3WuX7+uf/7znypUqJAaNWqkRo0aqVChQnr22WcVFxdnRowAAABwcw4nlX369NHOnTv15ZdfKjY2VrGxsVq9erX27Nmjvn37mhEjAABAtpQ+p9KMw904PPy9evVqrV27Vo8++qitrUWLFpozZ44ef/zxLA0OAAAA2YPDSWVwcLACAwMztAcGBipv3rxZEhQAAMCDwJO2FHJ4+Putt95SZGSkoqOjbW3R0dEaOnSohg8fnqXBAQAAIHvIVKXykUcesRu7P3HihIoWLaqiRYtKks6dOyer1arLly8zrxIAAOA37FP5B23btjU5DAAAgAeP5bfDjH7dTaaSypEjR5odBwAAALIxh+dUAgAAIHO8LBbTjsyaNWuWqlSpooCAAAUEBCg8PFz/+9//bOcTExPVv39/BQcHy8/PTx06dFBMTIzjz+roG1JTUzV58mTVrl1boaGhypcvn90BAAAA91G4cGFNnDhRe/fu1Z49e9S0aVO1adNG33//vSTplVde0apVq7R06VJt3rxZFy5cUPv27R2+j8NJ5ejRozV16lR17txZcXFxioyMVPv27eXl5aVRo0Y5HAAAAMCDymIx78is1q1b64knnlDp0qVVpkwZjR8/Xn5+ftqxY4fi4uIUFRWlqVOnqmnTpqpRo4bmzZunbdu2aceOHQ49q8NJ5SeffKI5c+Zo8ODBypEjh55++ml99NFHGjFihMM3BwAAwP2TmpqqxYsX68aNGwoPD9fevXuVkpKiiIgI2zXlypVT0aJFtX37dof6dnjz8+joaFWuXFmS5OfnZ/u871atWrFPJQAAwO+YvaVQfHy8XbvVapXVas1w/eHDhxUeHq7ExET5+flpxYoVqlChgg4cOCAfHx8FBQXZXR8SEmK3J3lmOFypLFy4sC5evChJKlmypL7++mtJ0u7du+/6EAAAADBHkSJFFBgYaDsmTJhw1+vKli2rAwcOaOfOnXrhhRfUvXt3HT16NEtjcbhS2a5dO61fv1516tTRgAED9OyzzyoqKkrnzp3TK6+8kqXBAQAAZGeOzn90pF9JOn/+vAICAmzt9yrw+fj4qFSpUpKkGjVqaPfu3XrnnXfUuXNnJScnKzY21q5aGRMTo9DQUIdicjipnDhxou3vnTt3VrFixbRt2zaVLl1arVu3drQ7AAAA/E3p2wQ5Ki0tTUlJSapRo4Zy5syp9evXq0OHDpKk48eP69y5cwoPD3eoT4eTyj+qW7eu6tatq0uXLulf//qX3njjDWe7BAAAeCA4uqekI/1m1rBhw9SyZUsVLVpU169f16JFi7Rp0yatXbtWgYGB6t27tyIjI5UvXz4FBARowIABCg8PV926dR2KyemkMt3Fixc1fPhwkkoAAIDfmD38nRmXLl1St27ddPHiRQUGBqpKlSpau3atHnvsMUnStGnT5OXlpQ4dOigpKUktWrTQ+++/73BMWZZUAgAAwP1ERUX96XlfX1/NnDlTM2fOdOo+JJUAAAAmMXtLIXfCZ38DAADAaZmuVEZGRv7p+cuXLzsdDO7YsXCJLN4+rg4DwN30c2w1JADP5iVzKnjuWBXMdFK5f//+v7ymYcOGTgUDAACA7CnTSeXGjRvNjAMAAOCBw5xKAAAAwAGs/gYAADCJxSJ5uXifyvuFpBIAAMAkXiYllWb06SyGvwEAAOA0KpUAAAAmYaHOX/j222/17LPPKjw8XL/88oskaeHChfruu++yNDgAAABkDw4nlcuWLVOLFi2UK1cu7d+/X0lJSZKkuLg4/etf/8ryAAEAALKr9DmVZhzuxuGkcty4cfrggw80Z84c5cyZ09Zev3597du3L0uDAwAAQPbg8JzK48eP3/WTcwIDAxUbG5sVMQEAADwQLBZztv9xwymVjlcqQ0NDdfLkyQzt3333nUqUKJElQQEAACB7cTipfO655zRw4EDt3LlTFotFFy5c0CeffKIhQ4bohRdeMCNGAACAbMnLYjHtcDcOD3+//vrrSktLU7NmzXTz5k01bNhQVqtVQ4YM0YABA8yIEQAAAG7O4aTSYrHozTff1NChQ3Xy5EklJCSoQoUK8vPzMyM+AACAbMtL5nzSjDt+es3f3vzcx8dHFSpUyMpYAAAAHiietFDH4aSySZMmf7qL+4YNG5wKCAAAANmPw0lltWrV7F6npKTowIEDOnLkiLp3755VcQEAAGR7XjJnUY2X3K9U6XBSOW3atLu2jxo1SgkJCU4HBAAAgOwny+Z5Pvvss5o7d25WdQcAAJDtpc+pNONwN1mWVG7fvl2+vr5Z1R0AAACyEYeHv9u3b2/32jAMXbx4UXv27NHw4cOzLDAAAIDszsty5zCjX3fjcFIZGBho99rLy0tly5bVmDFj1Lx58ywLDAAAANmHQ0llamqqevbsqcqVKytv3rxmxQQAAPBAsFhkyurvbD+n0tvbW82bN1dsbKxJ4QAAADw4WKjzJypVqqTTp0+bEQsAAACyKYeTynHjxmnIkCFavXq1Ll68qPj4eLsDAAAAd6Qv1DHjcDeZnlM5ZswYDR48WE888YQk6amnnrL7uEbDMGSxWJSampr1UQIAAMCtZTqpHD16tPr166eNGzeaGQ8AAMADw/LbHzP6dTeZTioNw5AkNWrUyLRgAAAAkD05tKWQxR2XGgEAALgpNj+/hzJlyvxlYvnrr786FRAAAACyH4eSytGjR2f4RB0AAADcHZXKe+jSpYsKFChgViwAAADIpjKdVDKfEgAAwDEWi8WUHMod8zKHV38DAAAgcxj+vou0tDQz4wAAAEA25tCcSgAAAGSexXLnMKNfd+PwZ38DAAAAf0SlEgAAwCReFou8TCgrmtGns6hUAgAAwGlUKgEAAEziSau/qVQCAADAaVQqAQAAzGLS6m+5YaWSpBIAAMAkXrLIy4QM0Iw+ncXwNwAAAJxGpRIAAMAkbH4OAAAAOIBKJQAAgEnYUggAAABwAJVKAAAAk/AxjQAAAIADqFQCAACYhNXfAAAAgAOoVAIAAJjESybNqXTDT9QhqQQAADAJw98AAACAA6hUAgAAmMRL5lTw3LEq6I4xAQAAIJuhUgkAAGASi8UiiwkTIM3o01lUKgEAAOA0KpUAAAAmsfx2mNGvu6FSCQAAAKdRqQQAADCJl8Wkzc/dcE4lSSUAAICJ3C/9MwfD3wAAAHAalUoAAACT8DGNAAAAeCBMmDBBtWrVkr+/vwoUKKC2bdvq+PHjdtckJiaqf//+Cg4Olp+fnzp06KCYmBiH7kNSCQAAYJL0zc/NODJr8+bN6t+/v3bs2KF169YpJSVFzZs3140bN2zXvPLKK1q1apWWLl2qzZs368KFC2rfvr1Dz8rwNwAAwANszZo1dq/nz5+vAgUKaO/evWrYsKHi4uIUFRWlRYsWqWnTppKkefPmqXz58tqxY4fq1q2bqftQqQQAADCJl4mHJMXHx9sdSUlJfxlTXFycJClfvnySpL179yolJUURERG2a8qVK6eiRYtq+/btDj0rAAAAsqEiRYooMDDQdkyYMOFPr09LS9OgQYNUv359VapUSZIUHR0tHx8fBQUF2V0bEhKi6OjoTMfC8DcAAIBJHJ3/6Ei/knT+/HkFBATY2q1W65++r3///jpy5Ii+++67LI+JpBIAACCbCggIsEsq/8xLL72k1atXa8uWLSpcuLCtPTQ0VMnJyYqNjbWrVsbExCg0NDTTsTD8DQAAYBKLiUdmGYahl156SStWrNCGDRtUvHhxu/M1atRQzpw5tX79elvb8ePHde7cOYWHh2f6PlQqAQAATGL28Hdm9O/fX4sWLdLnn38uf39/2zzJwMBA5cqVS4GBgerdu7ciIyOVL18+BQQEaMCAAQoPD8/0ym+JpBIAAOCBNmvWLElS48aN7drnzZunHj16SJKmTZsmLy8vdejQQUlJSWrRooXef/99h+5DUgkAAGCS32//k9X9ZpZhGH95ja+vr2bOnKmZM2fel5gAAACAu6JSCQAAYBJ3mFN5v1CpBAAAgNOoVAIAAJjE0e1/HOnX3VCpBAAAgNOoVAIAAJjEYrlzmNGvuyGpBAAAMImXLPIyYbDajD6dxfA3AAAAnEalEgAAwCSeNPxNpRIAAABOo1IJAABgEstvf8zo191QqQQAAIDTqFQCAACYhDmVAAAAgAOoVAIAAJjEYtI+lcypBAAAwAOJSiUAAIBJPGlOJUklAACASTwpqWT4GwAAAE6jUgkAAGASNj8HAAAAHEClEgAAwCReljuHGf26GyqVAAAAcBqVSgAAAJMwpxIAAABwAJVKAAAAk3jSPpUklQAAACaxyJyhajfMKRn+Bpz1Zt8ndGv/DLvjwPK3bOdDgv0VNbabzqz7l65sm6Jti15T22bVXBcwAEnSB+/PVNlSYQry81WDenW0e9cuV4cEZGtUKoEs8P3JC3qy33u217dT02x//2hsNwX559I/Bs3WldgEdW5ZU//5dy/Vf2aSDh7/2RXhAh5v6ZJP9drQSL038wPVql1HM96drqeebKGD3x9XgQIFXB0eHiBsKQTAIbdT0xRz9brtuBp7w3aubtUSen/xZu35/ied/eWq/v3RWsVev6VHKhRxYcSAZ3t3+lT17P2cuvXoqfIVKui99z9Qrty5tWD+XFeHBmRbJJVAFihVNL9Ofz1eR1eN0rzx3VUkNK/t3I6Dp9WxeQ3lDcgti8Wif7SoIV9rDm3Zc8KFEQOeKzk5Wfv37VXTZhG2Ni8vLzVtGqFdO7a7MDI8iCwm/nE3DH8DTtp95KyeH/Ef/fhTjEIfCtSbfVvqm7mvqEbH8Uq4maRnX52rhf/upQubJyklJVU3E5PVOXKOTp+/4urQAY905coVpaamqkCBELv2AiEhOn78BxdFBWR/VCqdEBYWpunTp7s6DLjY11uPavk3+3XkxAV9s/2Y2r40S4F+udSheXVJ0sj+rRTkn0st+76r+s9O0rv/2aD/TOqliqUKuThyAIDZ0rcUMuNwN26RVG7fvl3e3t568sknXR0K4LS4hFs6ee6SShbJr+KFH9ILXRqp76j/aNOuH3X4x1/0rw//p31Hz6lv54auDhXwSA899JC8vb116VKMXfulmBiFhoa6KCog+3OLpDIqKkoDBgzQli1bdOHCBVeHAzglTy4fFS/8kKKvxCm3r48kKc0w7K5JTTXk5Y6/ZgIewMfHR49Ur6GNG9bb2tLS0rRx43rVrhvuwsjwILKYeLgblyeVCQkJ+vTTT/XCCy/oySef1Pz5823nNm3aJIvFovXr16tmzZrKnTu36tWrp+PHj9v1MWvWLJUsWVI+Pj4qW7asFi5caHfeYrFo9uzZatWqlXLnzq3y5ctr+/btOnnypBo3bqw8efKoXr16OnXqlO09p06dUps2bRQSEiI/Pz/VqlVL33zzzT2fo1evXmrVqpVdW0pKigoUKKCoqCgnvkJwdxNeaadHa5RS0YL5VLdqcX069XmlpqVpyZq9On42WifPXdKMt55WzYrFVLzwQxr4z6ZqVresVm066OrQAY/18qBIzYuao/98vEA/HDuml/u/oJs3bqhb956uDg3ItlyeVC5ZskTlypVT2bJl9eyzz2ru3Lky/lDVefPNNzVlyhTt2bNHOXLkUK9evWznVqxYoYEDB2rw4ME6cuSI+vbtq549e2rjxo12fYwdO1bdunXTgQMHVK5cOXXt2lV9+/bVsGHDtGfPHhmGoZdeesl2fUJCgp544gmtX79e+/fv1+OPP67WrVvr3Llzd32OPn36aM2aNbp48aKtbfXq1bp586Y6d+581/ckJSUpPj7e7kD283BIkD6e0FOHVg7Xf/7dS7/G3VCjblN05VqCbt9OU9sBs3TlWoI+e6evdi8Zpq6taqvPiIVa+91RV4cOeKx/dOqsCf+erDGjR6hOzWo6ePCAPl+9RiEhIX/9ZsABXrLIy2LC4Ya1SovxxwzuPqtfv746deqkgQMH6vbt2ypYsKCWLl2qxo0ba9OmTWrSpIm++eYbNWvWTJL01Vdf6cknn9StW7fk6+ur+vXrq2LFivrwww9tfXbq1Ek3btzQl19+KelOpfKtt97S2LFjJUk7duxQeHi4oqKibAnq4sWL1bNnT926deuesVaqVEn9+vWzJZ9hYWEaNGiQBg0aJEmqWLGiunfvrldffVWS9NRTTyk4OFjz5s27a3+jRo3S6NGjM7RbKz8ni7ePI19GAPfJtd0zXB0CgHuIj49XSHCg4uLiFBAQ4PJYAgMD9c2+n5THP+tjuXE9XhHVi7nFs6ZzaaXy+PHj2rVrl55++mlJUo4cOdS5c+cMw8VVqlSx/b1gwYKSpEuXLkmSjh07pvr169tdX79+fR07duyefaT/Jlq5cmW7tsTERFu1MCEhQUOGDFH58uUVFBQkPz8/HTt27J6VSulOtTI9gYyJidH//vc/u6rqHw0bNkxxcXG24/z58/e8FgAAwJ25dJ/KqKgo3b59W4UK/f/WKoZhyGq1asaM/68G5MyZ0/Z3y2+LG9LS/v9j8DLjbn38Wb9DhgzRunXrNHnyZJUqVUq5cuVSx44dlZycfM97dOvWTa+//rq2b9+ubdu2qXjx4mrQoME9r7darbJarQ49BwAAyEbMWlXjfqPfrksqb9++rY8//lhTpkxR8+bN7c61bdtW//3vf1WuXLm/7Kd8+fLaunWrunfvbmvbunWrKlSo4FR8W7duVY8ePdSuXTtJdyqXZ8+e/dP3BAcHq23btpo3b562b9+unj2Z8A0AADyDy5LK1atX69q1a+rdu7cCAwPtznXo0EFRUVF6++23/7KfoUOHqlOnTnrkkUcUERGhVatWafny5X+6UjszSpcureXLl6t169ayWCwaPnx4pqqjffr0UatWrZSammqX6AIAAM9j1kcquuPHNLpsTmVUVJQiIiIyJJTSnaRyz549OnTo0F/207ZtW73zzjuaPHmyKlasqNmzZ2vevHlq3LixU/FNnTpVefPmVb169dS6dWu1aNFC1atX/8v3RUREqGDBgmrRooXdsD4AAMCDzOWrvx80CQkJevjhhzVv3jy1b9/eofemrxRj9Tfgvlj9Dbgvd1z9vf7AOfmZsPo74Xq8mlUr6hbPms6lC3UeJGlpabpy5YqmTJmioKAgPfXUU64OCQAA4L4hqcwi586dU/HixVW4cGHNnz9fOXLwpQUAwNN50OJvksqsEhYWluGTgAAAgIfzoKzS5R/TCAAAgOyPSiUAAIBJ2FIIAAAAcACVSgAAAJNYLHcOM/p1N1QqAQAA4DQqlQAAACbxoMXfVCoBAADgPCqVAAAAZvGgUiWVSgAAADiNSiUAAIBJPGmfSpJKAAAAk7ClEAAAAOAAKpUAAAAm8aB1OlQqAQAA4DwqlQAAAGbxoFIllUoAAAA4jUolAACASTxpSyEqlQAAAHAalUoAAACTeNI+lSSVAAAAJvGgdToMfwMAAMB5VCoBAADM4kGlSiqVAAAAD7gtW7aodevWKlSokCwWi1auXGl33jAMjRgxQgULFlSuXLkUERGhEydOOHQPkkoAAACTWEz844gbN26oatWqmjlz5l3PT5o0Se+++64++OAD7dy5U3ny5FGLFi2UmJiY6Xsw/A0AAPCAa9mypVq2bHnXc4ZhaPr06XrrrbfUpk0bSdLHH3+skJAQrVy5Ul26dMnUPahUAgAAmCR9SyEzDkmKj4+3O5KSkhyO8cyZM4qOjlZERIStLTAwUHXq1NH27dsz3Q9JJQAAQDZVpEgRBQYG2o4JEyY43Ed0dLQkKSQkxK49JCTEdi4zGP4GAAAwidmLv8+fP6+AgABbu9VqNeFumUOlEgAAIJsKCAiwO/5OUhkaGipJiomJsWuPiYmxncsMkkoAAACzWEw8skjx4sUVGhqq9evX29ri4+O1c+dOhYeHZ7ofhr8BAABM8ne2/8lsv45ISEjQyZMnba/PnDmjAwcOKF++fCpatKgGDRqkcePGqXTp0ipevLiGDx+uQoUKqW3btpm+B0klAADAA27Pnj1q0qSJ7XVkZKQkqXv37po/f75effVV3bhxQ88//7xiY2P16KOPas2aNfL19c30PUgqAQAATPL77X+yul9HNG7cWIZh/El/Fo0ZM0Zjxoz52zExpxIAAABOo1IJAABgErO3FHInVCoBAADgNCqVAAAAZvGgUiWVSgAAADiNSiUAAIBJ3GWfyvuBpBIAAMAsJm0p5IY5JcPfAAAAcB6VSgAAAJN40DodKpUAAABwHpVKAAAAs3hQqZJKJQAAAJxGpRIAAMAknrSlEJVKAAAAOI1KJQAAgEksJu1Tacrel06iUgkAAACnUakEAAAwiQct/iapBAAAMI0HZZUMfwMAAMBpVCoBAABMwpZCAAAAgAOoVAIAAJjEIpO2FMr6Lp1GpRIAAABOo1IJAABgEg9a/E2lEgAAAM6jUgkAAGAST/qYRpJKAAAA03jOADjD3wAAAHAalUoAAACTeNLwN5VKAAAAOI1KJQAAgEk8Z0YllUoAAABkASqVAAAAJmFOJQAAAOAAKpUAAAAmsfz2x4x+3Q2VSgAAADiNSiUAAIBZPGj5N0klAACASTwop2T4GwAAAM6jUgkAAGASthQCAAAAHEClEgAAwCRsKQQAAAA4gEolAACAWTxo+TeVSgAAADiNSiUAAIBJPKhQSVIJAABgFrYUAgAAABxApRIAAMA05mwp5I4D4FQqAQAA4DQqlQAAACZhTiUAAADgAJJKAAAAOI2kEgAAAE5jTiUAAIBJmFMJAAAAOIBKJQAAgEksJu1Tac7el84hqQQAADAJw98AAACAA6hUAgAAmMQicz5Q0Q0LlVQqAQAA4DwqlQAAAGbxoFIllUoAAAA4jUolAACASTxpSyEqlQAAAHAalUoAAACTeNI+lSSVAAAAJvGgdToMfwMAAMB5VCoBAADM4kGlSiqVAAAAHmDmzJkKCwuTr6+v6tSpo127dmVp/ySVAAAAJrGY+McRn376qSIjIzVy5Ejt27dPVatWVYsWLXTp0qUse1aSSgAAgAfc1KlT9dxzz6lnz56qUKGCPvjgA+XOnVtz587NsnuQVAIAAJgkfUshM47MSk5O1t69exUREWFr8/LyUkREhLZv355lz8pCHTdiGMad/6YmuzgSAPcSHx/v6hAA3MP1374/03+eugOz/s1I7/eP/VutVlmtVru2K1euKDU1VSEhIXbtISEh+uGHH7IsJpJKN3L9+nVJUvLRBS6OBMC9hATPcXUIAP7C9evXFRgY6NIYfHx8FBoaqtLFi5h2Dz8/PxUpYt//yJEjNWrUKNPu+WdIKt1IoUKFdP78efn7+8vijlvlw2Hx8fEqUqSIzp8/r4CAAFeHA+B3+P588BiGoevXr6tQoUKuDkW+vr46c+aMkpPNG300DCNDvvDHKqUkPfTQQ/L29lZMTIxde0xMjEJDQ7MsHpJKN+Ll5aXChQu7OgyYICAggB9agJvi+/PB4uoK5e/5+vrK19fX1WHIx8dHNWrU0Pr169W2bVtJUlpamtavX6+XXnopy+5DUgkAAPCAi4yMVPfu3VWzZk3Vrl1b06dP140bN9SzZ88suwdJJQAAwAOuc+fOunz5skaMGKHo6GhVq1ZNa9asybB4xxkklYCJrFarRo4cedc5LgBci+9PeJqXXnopS4e7/8hiuNO6ewAAAGRLbH4OAAAAp5FUAgAAwGkklQAAAHAaSSUAAACcRlIJuJm0tDRXhwAgE/heBeyRVAJuYvr06Tp8+LC8vLz4YQVkA15ed36Ebt68WbGxsWIzFXg6kkrADSQkJGj58uVq2LChjh07RmIJZAOGYWjXrl1q0qSJoqOjZbFYSCzh0dinEnATv/zyi/r376+tW7dq8+bNqlChgtLS0mzVEADuqUWLFgoODtb8+fPl4+Pj6nAAl+GnFeAmHn74Yc2cOVN169ZVo0aNdPToUSqWgBu5ffu23euUlBRJUvv27XX69GnFxMRIYq4lPBdJJeAG0gcMHn74Yc2aNYvEEnAjZ86ckSTlyHHnk423bt2qpKQk5cyZU5LUtWtX/fLLL5o+fbokMboAj8X/+YALpSeTFovF1la4cGHNmjVLderUIbEEXOyFF17Qiy++qP3790uSvvnmG3Xr1k2VK1fWZ599piNHjsjf31+jR4/Wzp07dezYMRdHDLgOSSXgIoZhyGKxaMuWLXr99dc1YMAALVmyRNKdxPLDDz+0JZYs3gFco127dvrxxx81depUHT16VI0aNdKXX36pZs2aadKkSWrXrp0mT56snDlz6tKlSzpx4oQksWAHHomFOoALrVixQs8995zq1aunhx56SPPnz9fEiRM1aNAg+fj46MKFC3rxxRf1xRdf6NixYypbtqyrQwY8xu9/8evRo4dq166tN954Q1WqVJEkHT58WHv37tX48eNVo0YNLVmyRFWqVNHXX3+tAgUKuDh64P4jqQRcZM+ePWrbtq1GjBih559/XtHR0SpdurRu3LihwYMHa8KECcqRI4fOnz+voUOHasyYMSpTpoyrwwY8SnpiuXnzZvXs2VPh4eEaNGiQatWqZbvm3LlzOnLkiBYsWKANGzZo4cKFevzxx9m9AR6HpBJwgbS0NP33v//VsWPHNG7cOJ0/f14NGjRQq1atVKNGDfXu3Vvjxo3TkCFD5OPjo9TUVHl7e7s6bMAj3CsZ3LRpk3r16qXw8HANHjxY1atXz3BNq1atlJKSorVr196PUAG3wq9QwH2U/jucl5eXmjRporZt2yo5OVm9e/dWs2bN9M477+iJJ55QoUKF9NZbb2ns2LGSREIJ3Ce/Tyh/+uknHTlyRKmpqbp9+7YaN26sjz76SNu3b9eUKVNsi3ckKSkpSZLUt29fxcbG6urVqy6JH3AlkkrgPkhPJm/evGl7XahQIdWsWVNXrlzRlStX1LlzZ3l7e8tqteqJJ57QggUL9Mwzz7gybMCj/D6hHDFihFq1aqV69erpscce08KFC3Xjxg01bdpUH330kXbs2KGpU6dq586dkiSr1SpJWr16ta5cuWLbbgjwJCSVwH1gsVj05Zdf6h//+IfatWunjz/+WPHx8ZKk69ev6+DBg/rxxx8VExOjyZMna8eOHWrTpo3KlSvn4sgBz5GeUI4ePVpz5szRuHHjdObMGaWmpmry5MmaNWuWEhISbInlsmXLtGbNGtv7b9++LS8vLy1atEgBAQGuegzAZZhTCdwHO3fuVEREhPr166ddu3YpOTlZ1atX15gxYxQcHKyJEyfqjTfeUKlSpfTrr79q3bp1euSRR1wdNuBx9u3bp759+2r8+PFq3ry5Nm3apFatWqlKlSq6cuWKXnzxRT333HPKkyeP9u3bp6pVqzI9BfgNSSVgkvRVo5K0fPlyHThwQGPGjJEkTZo0SStXrlTlypU1ceJE5c2bV9u3b1dcXJwqVqyoIkWKuDJ0wGPFxMRo7dq16ty5s7Zv365OnTppwoQJ6t27t6pXr67ExER17NhRr7/+unLnzi1JLKQDfpPD1QEAD6L0hHL37t26cOGC9uzZI39/f9v5wYMHy2KxaPny5Xrrrbc0atQohYeHuzBiwPPcbZV3/vz51bp1a/n4+OjDDz9Ut27d1KNHD0lSmTJltHv3bl27dk25cuWyvYeEEriDpBIwgcVi0bJly9S9e3cFBQXp119/VdmyZTVw4EDlzp1b3t7eGjx4sLy8vBQVFSUfHx9NmTJFFovF7iMbAZjj9wnlxo0blTNnTuXNm1cVK1ZU3rx5lZaWpsuXLyt//vy2pDFHjhyaNWuWIiIiZLFY7EYjADD8DWSp9B8yN27c0MCBA/Xoo4/qiSee0IoVKzR79mwVK1ZMH3/8sa1qmZaWppkzZ6p169YKCwtzbfCAB3r99dc1a9YsBQcHKyEhQe+//746duyo5ORkPf/88/rhhx9UsWJFnTx5UlevXtXBgwfl7e3NxubAXfAdAWSh9CHv2rVr68KFC6pfv74KFCigPn36aNCgQbp48aL++c9/6vr165LurDYdMGAACSVwn/y+jnLs2DF98803Wr9+vRYtWqQ+ffqoU6dOmjt3rnx8fDRt2jRVqlRJsbGxKly4sPbv309CCfwJhr+BLJBeody3b59Onz6twMBAffvtt8qTJ4+kO3OuunbtKovFog8//FBPPfWUVq1aJT8/PxdHDniO3yeDiYmJunXrlho1aqSaNWtKkipVqiQfHx/16dNHaWlp6tOnj2bNmmW35+Tt27eVIwc/OoG74VctIAuk70PZoUMHBQQEaPTo0SpcuLDatGmjlJQUSXfmYz399NPq1q2bcubMqdjYWNcGDXiY9IRy1KhRatWqlXr16qW9e/cqLi5OkuTn56chQ4ZoxIgRevHFFzVjxgy7hNIwDBJK4E8wpxJwQnqFMiYmRkOGDFGtWrX08ssvKy0tTRs3btTgwYOVK1cubdq0yfaJG7dv39bNmzfZHBm4T35foZw5c6bGjRun7t27KyYmRgsWLNDkyZP1yiuv2Bbd3LhxQ8OHD9euXbv07bffshgHyCSSSsBJW7du1fjx4/Xrr79q+vTpqlu3rqQ7yeOmTZs0dOhQ+fv7a926dbbEEsD9t2/fPi1btkx169ZV69atJUlTp07V0KFDNXXqVL388su2BDIxMVFWq5VV3oADGP4GnBQaGqozZ85o165d2r9/v609R44catKkiaZMmaJz587pqaeecmGUgOcyDEP79u1TzZo1NWnSJLupJ5GRkXr77bc1ePBgzZgxw7aQx9fXl4QScBBJJeCkkiVLas2aNapWrZo++eQTbdiwwXbO29tbjRo10oIFCzRr1iwXRgl4LovFourVq+vjjz9Wamqqtm7dqitXrtjOR0ZGavLkyRo4cKA+++yzDO8FkDkMfwMOSK9aHD9+XOfPn1dQUJBCQ0NVuHBhnThxQh06dFDBggU1bNgwNW7c2NXhAh7p93Mo/1hpnD17tl544QWNHDlSL7/8svLmzWs79+mnn6pDhw4sxgH+Jr5zgExK/+G0bNkyDRw4UDlz5pRhGPL19dWHH36ohg0b6rPPPlPHjh319ttvKzk5Wc2bN3d12IBH+X1COWfOHB06dEjJycmqU6eO/vnPf6pv375KTU3VSy+9JEl2iWXnzp0lsW0Q8Hcx/A3cQ1pamu3vt2/flsVi0a5du9SzZ08NHz5c3333nRYsWKBatWqpRYsW+vbbb1WmTBktX75chw8f1uzZs3Xz5k0XPgHgedITyldffVWvv/66UlJSdOjQIb3zzjt66qmnlJycrBdffFHvv/++xo4dq3Hjxtk+jCAdCSXw9/CdA9yDl5eXfvrpJxUtWlQ5cuRQamqqDh8+rJo1a+q5556Tl5eXHn74YZUtW1ZpaWkaOHCgvvrqK5UqVUpbtmxRWlqacufO7erHADzO9u3btWTJEq1cuVINGjSQYRhavny5/v3vf6tLly5asmSJ+vXrp6SkJC1ZsoQPIQCyCJVK4B6SkpLUpUsXlShRQoZhyNvbW/Hx8Tpw4IDi4+Ml3RkSDw0NVdeuXXXlyhVdu3ZNkhQWFqYSJUq4MnzAY126dEm3bt1S6dKlJd1ZbPPkk0+qX79+On36tI4dOyZJGjhwoL777jvbKm8AziGpBO7Bx8dHb7/9tvz8/FS9enUZhqE2bdqoYMGCmjdvnmJjY20LAEqXLq2cOXNmGEYDYK7fJ4Ppfy9cuLCCgoLstvjy9fVVy5Yt9eOPP+r777+3tbNtEJB1SCqB3/x+DqV054dNvXr1NGfOHN26dUt16tRRiRIl1K5dO82bN09z5sxRTEyMEhISNHfuXHl5eSksLMw1wQMeKC0tzS4ZTE1NlSQVKVJEAQEBmjlzpo4cOWI77+3trXLlyikoKMiuHxJKIGuwpRCg/18xGh0drbNnz9o+FUeSUlJStH//fnXp0kVFihTR5s2bNWLECK1YsUInT55UtWrVdOrUKa1du1aPPPKIC58C8EyTJ0/W7t27lZqaqsjISNWrV0/Hjx9XRESEypUrpyZNmqhSpUqaMWOGLl++rD179sjb29vVYQMPHJJK4Dfnz5/XI488ol9//VWNGjVSeHi4IiIiVLNmTQUEBGj37t3q3bu3AgIC9N133yk6OlpfffWV8ubNq+rVq6tYsWKufgTAI/x+26AxY8ZoxowZatOmjU6dOqXNmzfr448/1jPPPKOTJ09qxIgROnDggKxWqwoXLqzly5crZ86cSk1NJbEEshhJJfCbn376SW3bttWtW7fk7++vihUr6tNPP1W5cuVUuXJltWrVShaLRcOGDVOJEiW0du1ahs0AF/rll18UFRWlpk2b6tFHH9WtW7c0evRoTZkyRfPmzdOzzz6rpKQkpaSk6Pr16woNDZXFYmEfSsAkJJXA75w8eVKvvvqq0tLSNGzYMBUsWFDbtm3TjBkzlJKSoiNHjqhkyZI6cuSI2rRpoxUrVjDJH3CBzz//XO3atVNYWJgWL16s2rVrS7ozXWX48OGaOnWqPv74Y3Xp0sXufb+vcgLIWiSVwB8cP35cAwcOVFpamsaPH69atWpJkmJjY7Vq1Sr98MMP+t///qeoqCjmUAL3SXoymP7fCxcuaPz48Zo9e7aWLVumNm3a2M7dvn1bI0eO1IQJE7Ru3To1a9bM1eEDHoGkEriLEydOaMCAAZKkYcOGqVGjRnbnGT4D7p/Fixfr66+/1uuvv66HH35YefLkkSTFxMRo6NChWrZsmdatW6d69erZRg5SUlIUFRWlPn368L0K3CcklcA9nDhxQi+//LIMw9CIESNUr149V4cEeJz4+HhVr15d8fHxCg0NVe3atfXoo4+qR48ekqSbN2+qd+/e+uKLL/T111+rfv36Gaak8EsgcH+QVAJ/4sSJE4qMjNSVK1c0bdo0u62GAJgvNTVVw4cPV7FixVSrVi1t2LBB48ePV8uWLVWlShUNHjxYcXFxGjFihBYuXKgvvvhCTZo0cXXYgEditjLwJ0qXLq23335bhQsXVqFChVwdDuBxvL291aBBAw0dOlQ5cuTQkCFDdPHiRZUqVUpvvPGGwsPDNXfuXLVv314tW7bU+PHjXR0y4LGoVAKZkJycLB8fH1eHAXis/v37S5JmzpwpSapYsaLKlCmjkiVL6vvvv9fatWs1efJkDRo0iNXdgIswyQTIBBJKwLWqV6+uefPm6dq1a2rWrJny5s2rBQsWKCAgQD///LO2bdum9u3b260QB3B/UakEAGQLtWvX1p49e9SwYUMtX75c+fLly3ANi3IA1+FXOQCAW0uvfbz88suqWLGipkyZonz58uluNRESSsB1SCoBAG4tfXugJk2a6OrVq1q3bp1dOwD3QFIJAMgWHn74YQ0bNkyTJ0/W0aNHXR0OgD9gnAAAkG088cQT2rNnj8qVK+fqUAD8AQt1AADZSvon5qSmpsrb29vV4QD4DUklAAAAnMacSgAAADiNpBIAAABOI6kEAACA00gqAQAA4DSSSgAAADiNpBIAAABOI6kE8EDo0aOH2rZta3vduHFjDRo06L7HsWnTJlksFsXGxpp2jz8+699xP+IE4FlIKgGYpkePHrJYLLJYLPLx8VGpUqU0ZswY3b592/R7L1++XGPHjs3Utfc7wQoLC9P06dPvy70A4H7hYxoBmOrxxx/XvHnzlJSUpK+++kr9+/dXzpw5NWzYsAzXJicny8fHJ0vumy9fvizpBwCQOVQqAZjKarUqNDRUxYoV0wsvvKCIiAh98cUXkv5/GHf8+PEqVKiQypYtK0k6f/68OnXqpKCgIOXLl09t2rTR2bNnbX2mpqYqMjJSQUFBCg4O1quvvqo/fjjYH4e/k5KS9Nprr6lIkSKyWq0qVaqUoqKidPbsWTVp0kSSlDdvXlksFvXo0UOSlJaWpgkTJqh48eLKlSuXqlatqs8++8zuPl999ZXKlCmjXLlyqUmTJnZx/h2pqanq3bu37Z5ly5bVO++8c9drR48erfz58ysgIED9+vVTcnKy7VxmYgeArESlEsB9lStXLl29etX2ev369QoICNC6deskSSkpKWrRooXCw8P17bffKkeOHBo3bpwef/xxHTp0SD4+PpoyZYrmz5+vuXPnqnz58poyZYpWrFihpk2b3vO+3bp10/bt2/Xuu++qatWqOnPmjK5cuaIiRYpo2bJl6tChg44fP66AgADlypVLkjRhwgT95z//0QcffKDSpUtry5YtevbZZ5U/f341atRI58+fV/v27dW/f389//zz2rNnjwYPHuzU1yctLU2FCxfW0qVLFRwcrG3btun5559XwYIF1alTJ7uvm6+vrzZt2qSzZ8+qZ8+eCg4O1vjx4zMVOwBkOQMATNK9e3ejTZs2hmEYRlpamrFu3TrDarUaQ4YMsZ0PCQkxkpKSbO9ZuHChUbZsWSMtLc3WlpSUZOTKlctYu3atYRiGUbBgQWPSpEm28ykpKUbhwoVt9zIMw2jUqJExcOBAwzAM4/jx44YkY926dXeNc+PGjYYk49q1a7a2xMREI3fu3Ma2bdvsru3du7fx9NNPG4ZhGMOGDTMqVKhgd/61117L0NcfFStWzJg2bdo9z/9R//79jQ4dOthed+/e3ciXL59x48YNW9usWbMMPz8/IzU1NVOx3+2ZAcAZVCoBmGr16tXy8/NTSkqK0tLS1LVrV40aNcp2vnLlynbzKA8ePKiTJ0/K39/frp/ExESdOnVKcXFxunjxourUqWM7lyNHDtWsWTPDEHi6AwcOyNvb26EK3cmTJ3Xz5k099thjdu3Jycl65JFHJEnHjh2zi0OSwsPDM32Pe5k5c6bmzp2rc+fO6datW0pOTla1atXsrqlatapy585td9+EhASdP39eCQkJfxk7AGQ1kkoApmrSpIlmzZolHx8fFSpUSDly2P+zkydPHrvXCQkJqlGjhj755JMMfeXPn/9vxZA+nO2IhIQESdKXX36phx9+2O6c1Wr9W3FkxuLFizVkyBBNmTJF4eHh8vf319tvv62dO3dmug9XxQ7As5FUAjBVnjx5VKpUqUxfX716dX366acqUKCAAgIC7npNwYIFtXPnTjVs2FCSdPv2be3du1fVq1e/6/WVK1dWWlqaNm/erIiIiAzn0yulqamptrYKFSrIarXq3Llz96xwli9f3rboKN2OHTv++iH/xNatW1WvXj29+OKLtrZTp05luO7gwYO6deuWLWHesWOH/Pz8VKRIEeXLl+8vYweArMbqbwBu5ZlnntFDDz2kNm3a6Ntvv9WZM2e0adMmvfzyy/r5558lSQMHDtTEiRO1cuVK/fDDD3rxxRf/dI/JsLAwde/eXb169dLKlSttfS5ZskSSVKxYMVksFq1evVqXL19WQkKC/P39NWTIEL3yyitasGCBTp06pX379um9997TggULJEn9+vXTiRMnNHToUB0/flyLFi3S/PnzM/Wcv/zyiw4cOGB3XLt2TaVLl9aePXu0du1a/fjjjxo+fLh2796d4f3Jycnq3bu3jh49qq+++kojR47USy+9JC8vr0zFDgBZztWTOgE8uH6/UMeR8xcvXjS6detmPPTQQ4bVajVKlChhPPfcc0ZcXJxhGHcW5gwcONAICAgwgoKCjMjISKNbt273XKhjGIZx69Yt45VXXjEKFixo+Pj4GKVKlTLmzp1rOz9mzBgjNDTUsFgsRvfu3Q3DuLO4aPr06UbZsmWNnDlzGvnz5zdatGhhbN682fa+VatWGaVKlTKsVqvRoEEDY+7cuZlaqCMpw7Fw4UIjMTHR6NGjhxEYGGgEBQUZL7zwgvH6668bVatWzfB1GzFihBEcHGz4+fkZzz33nJGYmGi75q9iZ6EOgKxmMYx7zGwHAAAAMonhbwAAADiNpBIAAABOI6kEAACA00gqAQAA4DSSSgAAADiNpBIAAABOI6kEAACA00gqAQAA4DSSSgAAADiNpBIAAABOI6kEAACA00gqAQAA4LT/A1LoQss7DPyPAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "from sklearn.metrics import confusion_matrix\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Get the true and predicted labels from the results_dataframe\n", + "true_labels = results_dataframe['true']\n", + "predicted_labels = results_dataframe['predicted']\n", + "\n", + "# Compute the confusion matrix\n", + "cm = confusion_matrix(true_labels, predicted_labels)\n", + "\n", + "# Define the class labels\n", + "class_labels = ['Normal', 'Anomaly']\n", + "\n", + "# Plot the confusion matrix\n", + "plt.figure(figsize=(8, 6))\n", + "plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", + "plt.title('Confusion Matrix')\n", + "plt.colorbar()\n", + "tick_marks = np.arange(len(class_labels))\n", + "plt.xticks(tick_marks, class_labels, rotation=45)\n", + "plt.yticks(tick_marks, class_labels)\n", + "plt.xlabel('Predicted Label')\n", + "plt.ylabel('True Label')\n", + "\n", + "# Add the values to the confusion matrix plot\n", + "thresh = cm.max() / 2.\n", + "for i in range(cm.shape[0]):\n", + " for j in range(cm.shape[1]):\n", + " plt.text(j, i, format(cm[i, j], 'd'),\n", + " horizontalalignment=\"center\",\n", + " color=\"white\" if cm[i, j] > thresh else \"black\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/_static/basic.css b/_static/basic.css index 30fee9d..f316efc 100644 --- a/_static/basic.css +++ b/_static/basic.css @@ -4,7 +4,7 @@ * * Sphinx stylesheet -- basic theme. * - * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ diff --git a/_static/doctools.js b/_static/doctools.js index d06a71d..4d67807 100644 --- a/_static/doctools.js +++ b/_static/doctools.js @@ -4,7 +4,7 @@ * * Base JavaScript utilities for all Sphinx HTML documentation. * - * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ diff --git a/_static/graphviz.css b/_static/graphviz.css index 8d81c02..027576e 100644 --- a/_static/graphviz.css +++ b/_static/graphviz.css @@ -4,7 +4,7 @@ * * Sphinx stylesheet -- graphviz extension. * - * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ diff --git a/_static/language_data.js b/_static/language_data.js index 250f566..367b8ed 100644 --- a/_static/language_data.js +++ b/_static/language_data.js @@ -5,7 +5,7 @@ * This script contains the language-specific data used by searchtools.js, * namely the list of stopwords, stemmer, scorer and splitter. * - * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ @@ -13,7 +13,7 @@ var stopwords = ["a", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it", "near", "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"]; -/* Non-minified version is copied as a separate JS file, is available */ +/* Non-minified version is copied as a separate JS file, if available */ /** * Porter Stemmer diff --git a/_static/searchtools.js b/_static/searchtools.js index 7918c3f..92da3f8 100644 --- a/_static/searchtools.js +++ b/_static/searchtools.js @@ -4,7 +4,7 @@ * * Sphinx JavaScript utilities for the full-text search. * - * :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS. + * :copyright: Copyright 2007-2024 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ @@ -99,7 +99,7 @@ const _displayItem = (item, searchTerms, highlightTerms) => { .then((data) => { if (data) listItem.appendChild( - Search.makeSearchSummary(data, searchTerms) + Search.makeSearchSummary(data, searchTerms, anchor) ); // highlight search terms in the summary if (SPHINX_HIGHLIGHT_ENABLED) // set in sphinx_highlight.js @@ -116,8 +116,8 @@ const _finishSearch = (resultCount) => { ); else Search.status.innerText = _( - `Search finished, found ${resultCount} page(s) matching the search query.` - ); + "Search finished, found ${resultCount} page(s) matching the search query." + ).replace('${resultCount}', resultCount); }; const _displayNextItem = ( results, @@ -137,6 +137,22 @@ const _displayNextItem = ( // search finished, update title and status message else _finishSearch(resultCount); }; +// Helper function used by query() to order search results. +// Each input is an array of [docname, title, anchor, descr, score, filename]. +// Order the results by score (in opposite order of appearance, since the +// `_displayNextItem` function uses pop() to retrieve items) and then alphabetically. +const _orderResultsByScoreThenName = (a, b) => { + const leftScore = a[4]; + const rightScore = b[4]; + if (leftScore === rightScore) { + // same score: sort alphabetically + const leftTitle = a[1].toLowerCase(); + const rightTitle = b[1].toLowerCase(); + if (leftTitle === rightTitle) return 0; + return leftTitle > rightTitle ? -1 : 1; // inverted is intentional + } + return leftScore > rightScore ? 1 : -1; +}; /** * Default splitQuery function. Can be overridden in ``sphinx.search`` with a @@ -160,13 +176,26 @@ const Search = { _queued_query: null, _pulse_status: -1, - htmlToText: (htmlString) => { + htmlToText: (htmlString, anchor) => { const htmlElement = new DOMParser().parseFromString(htmlString, 'text/html'); - htmlElement.querySelectorAll(".headerlink").forEach((el) => { el.remove() }); + for (const removalQuery of [".headerlinks", "script", "style"]) { + htmlElement.querySelectorAll(removalQuery).forEach((el) => { el.remove() }); + } + if (anchor) { + const anchorContent = htmlElement.querySelector(`[role="main"] ${anchor}`); + if (anchorContent) return anchorContent.textContent; + + console.warn( + `Anchored content block not found. Sphinx search tries to obtain it via DOM query '[role=main] ${anchor}'. Check your theme or template.` + ); + } + + // if anchor not specified or not found, fall back to main content const docContent = htmlElement.querySelector('[role="main"]'); - if (docContent !== undefined) return docContent.textContent; + if (docContent) return docContent.textContent; + console.warn( - "Content block not found. Sphinx search tries to obtain it via '[role=main]'. Could you check your theme or template." + "Content block not found. Sphinx search tries to obtain it via DOM query '[role=main]'. Check your theme or template." ); return ""; }, @@ -239,16 +268,7 @@ const Search = { else Search.deferQuery(query); }, - /** - * execute search (requires search index to be loaded) - */ - query: (query) => { - const filenames = Search._index.filenames; - const docNames = Search._index.docnames; - const titles = Search._index.titles; - const allTitles = Search._index.alltitles; - const indexEntries = Search._index.indexentries; - + _parseQuery: (query) => { // stem the search terms and add them to the correct list const stemmer = new Stemmer(); const searchTerms = new Set(); @@ -284,16 +304,32 @@ const Search = { // console.info("required: ", [...searchTerms]); // console.info("excluded: ", [...excludedTerms]); - // array of [docname, title, anchor, descr, score, filename] - let results = []; + return [query, searchTerms, excludedTerms, highlightTerms, objectTerms]; + }, + + /** + * execute search (requires search index to be loaded) + */ + _performSearch: (query, searchTerms, excludedTerms, highlightTerms, objectTerms) => { + const filenames = Search._index.filenames; + const docNames = Search._index.docnames; + const titles = Search._index.titles; + const allTitles = Search._index.alltitles; + const indexEntries = Search._index.indexentries; + + // Collect multiple result groups to be sorted separately and then ordered. + // Each is an array of [docname, title, anchor, descr, score, filename]. + const normalResults = []; + const nonMainIndexResults = []; + _removeChildren(document.getElementById("search-progress")); - const queryLower = query.toLowerCase(); + const queryLower = query.toLowerCase().trim(); for (const [title, foundTitles] of Object.entries(allTitles)) { - if (title.toLowerCase().includes(queryLower) && (queryLower.length >= title.length/2)) { + if (title.toLowerCase().trim().includes(queryLower) && (queryLower.length >= title.length/2)) { for (const [file, id] of foundTitles) { let score = Math.round(100 * queryLower.length / title.length) - results.push([ + normalResults.push([ docNames[file], titles[file] !== title ? `${titles[file]} > ${title}` : title, id !== null ? "#" + id : "", @@ -308,46 +344,47 @@ const Search = { // search for explicit entries in index directives for (const [entry, foundEntries] of Object.entries(indexEntries)) { if (entry.includes(queryLower) && (queryLower.length >= entry.length/2)) { - for (const [file, id] of foundEntries) { - let score = Math.round(100 * queryLower.length / entry.length) - results.push([ + for (const [file, id, isMain] of foundEntries) { + const score = Math.round(100 * queryLower.length / entry.length); + const result = [ docNames[file], titles[file], id ? "#" + id : "", null, score, filenames[file], - ]); + ]; + if (isMain) { + normalResults.push(result); + } else { + nonMainIndexResults.push(result); + } } } } // lookup as object objectTerms.forEach((term) => - results.push(...Search.performObjectSearch(term, objectTerms)) + normalResults.push(...Search.performObjectSearch(term, objectTerms)) ); // lookup as search terms in fulltext - results.push(...Search.performTermsSearch(searchTerms, excludedTerms)); + normalResults.push(...Search.performTermsSearch(searchTerms, excludedTerms)); // let the scorer override scores with a custom scoring function - if (Scorer.score) results.forEach((item) => (item[4] = Scorer.score(item))); - - // now sort the results by score (in opposite order of appearance, since the - // display function below uses pop() to retrieve items) and then - // alphabetically - results.sort((a, b) => { - const leftScore = a[4]; - const rightScore = b[4]; - if (leftScore === rightScore) { - // same score: sort alphabetically - const leftTitle = a[1].toLowerCase(); - const rightTitle = b[1].toLowerCase(); - if (leftTitle === rightTitle) return 0; - return leftTitle > rightTitle ? -1 : 1; // inverted is intentional - } - return leftScore > rightScore ? 1 : -1; - }); + if (Scorer.score) { + normalResults.forEach((item) => (item[4] = Scorer.score(item))); + nonMainIndexResults.forEach((item) => (item[4] = Scorer.score(item))); + } + + // Sort each group of results by score and then alphabetically by name. + normalResults.sort(_orderResultsByScoreThenName); + nonMainIndexResults.sort(_orderResultsByScoreThenName); + + // Combine the result groups in (reverse) order. + // Non-main index entries are typically arbitrary cross-references, + // so display them after other results. + let results = [...nonMainIndexResults, ...normalResults]; // remove duplicate search results // note the reversing of results, so that in the case of duplicates, the highest-scoring entry is kept @@ -361,7 +398,12 @@ const Search = { return acc; }, []); - results = results.reverse(); + return results.reverse(); + }, + + query: (query) => { + const [searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms] = Search._parseQuery(query); + const results = Search._performSearch(searchQuery, searchTerms, excludedTerms, highlightTerms, objectTerms); // for debugging //Search.lastresults = results.slice(); // a copy @@ -466,14 +508,18 @@ const Search = { // add support for partial matches if (word.length > 2) { const escapedWord = _escapeRegExp(word); - Object.keys(terms).forEach((term) => { - if (term.match(escapedWord) && !terms[word]) - arr.push({ files: terms[term], score: Scorer.partialTerm }); - }); - Object.keys(titleTerms).forEach((term) => { - if (term.match(escapedWord) && !titleTerms[word]) - arr.push({ files: titleTerms[word], score: Scorer.partialTitle }); - }); + if (!terms.hasOwnProperty(word)) { + Object.keys(terms).forEach((term) => { + if (term.match(escapedWord)) + arr.push({ files: terms[term], score: Scorer.partialTerm }); + }); + } + if (!titleTerms.hasOwnProperty(word)) { + Object.keys(titleTerms).forEach((term) => { + if (term.match(escapedWord)) + arr.push({ files: titleTerms[term], score: Scorer.partialTitle }); + }); + } } // no match but word was a required one @@ -496,9 +542,8 @@ const Search = { // create the mapping files.forEach((file) => { - if (fileMap.has(file) && fileMap.get(file).indexOf(word) === -1) - fileMap.get(file).push(word); - else fileMap.set(file, [word]); + if (!fileMap.has(file)) fileMap.set(file, [word]); + else if (fileMap.get(file).indexOf(word) === -1) fileMap.get(file).push(word); }); }); @@ -549,8 +594,8 @@ const Search = { * search summary for a given text. keywords is a list * of stemmed words. */ - makeSearchSummary: (htmlText, keywords) => { - const text = Search.htmlToText(htmlText); + makeSearchSummary: (htmlText, keywords, anchor) => { + const text = Search.htmlToText(htmlText, anchor); if (text === "") return null; const textLower = text.toLowerCase(); diff --git a/api.html b/api.html index 19197fc..c51f542 100644 --- a/api.html +++ b/api.html @@ -7,7 +7,7 @@ Programming Reference — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.analysis.latent_analysis

+
+

Classes

+ + + + + + + + + +

LatentAnalysis

LayerOutputSaverHook

+
+
+

Module Contents

+
+
+class ssl_tools.analysis.latent_analysis.LatentAnalysis(layers, sklearn_cls, output_name_suffix='transformed', **sklearn_kwargs)
+
+
Parameters:
+
    +
  • layers (List[str])

  • +
  • output_name_suffix (str)

  • +
+
+
+
+
+__call__(trainer, model, data_module)
+
+
Parameters:
+
    +
  • trainer (lightning.Trainer)

  • +
  • model (lightning.LightningModule)

  • +
  • data_module (lightning.LightningDataModule)

  • +
+
+
+
+ +
+ +
+
+class ssl_tools.analysis.latent_analysis.LayerOutputSaverHook
+
+
+_forward_hook(module, inputs, outputs, layer_name)
+
+
Parameters:
+

layer_name (str)

+
+
+
+ +
+
+attach_hooks(model, layer_names)
+
+
Parameters:
+
    +
  • model (lightning.LightningModule)

  • +
  • layer_names (List[str])

  • +
+
+
+
+ +
+
+outputs_from_layer(layer_name, concat=True)
+
+
Parameters:
+
    +
  • layer_name (str)

  • +
  • concat (bool)

  • +
+
+
+
+ +
+
+remove_hooks()
+
+ +
+
+run_model_with_hooks(model, layer_names)
+
+
Parameters:
+
    +
  • model (lightning.LightningModule)

  • +
  • layer_names (List[str])

  • +
+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/analysis/plot_metrics/index.html b/autoapi/ssl_tools/analysis/plot_metrics/index.html index 0515213..043bb1f 100644 --- a/autoapi/ssl_tools/analysis/plot_metrics/index.html +++ b/autoapi/ssl_tools/analysis/plot_metrics/index.html @@ -7,7 +7,7 @@ ssl_tools.analysis.plot_metrics — SSLTools documentation - + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/benchmarks/main_mix_style/index.html b/autoapi/ssl_tools/benchmarks/main_mix_style/index.html new file mode 100644 index 0000000..7f2ef4d --- /dev/null +++ b/autoapi/ssl_tools/benchmarks/main_mix_style/index.html @@ -0,0 +1,756 @@ + + + + + + + ssl_tools.benchmarks.main_mix_style — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.benchmarks.main_mix_style

+
+

Classes

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

CNN_HaEtAl_1D

CNN_HaEtAl_1D_Backbone

CNN_HaEtAl_2D

CNN_HaEtAl_2D_Backbone

ConvolutionalBlock

ExperimentArgs

ResNet1DBase

ResNet1D_8

ResNetBlock

ResNetSE1D_5

ResNetSE1D_8

ResNetSEBlock

SimpleClassificationNet2

SqueezeAndExcitation1D

_ResNet1D

+
+
+

Functions

+ + + + + + + + + + + + + + + + + + + + + + + + + + + +

_run_experiment_wrapper(experiment_args)

cli_main(experiment)

conv3x3(in_planes, out_planes[, stride, groups, dilation])

3x3 convolution with padding

conv3x3_dynamic(in_planes, out_planes[, stride, ...])

3x3 convolution with padding

main_loo()

pretty_print_experiment_args(args[, indent])

run_serial(experiments)

run_using_ray(experiments[, ray_address])

+
+
+

Module Contents

+
+
+class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_1D(input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
+

Bases: SimpleClassificationNet2

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_shape)
+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_1D_Backbone(input_channels=1)
+

Bases: torch.nn.Module

+
+
Parameters:
+

input_channels (int)

+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_2D(pad_at=(3,), input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
+

Bases: SimpleClassificationNet2

+
+
Parameters:
+
    +
  • pad_at (List[int])

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_shape)
+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.CNN_HaEtAl_2D_Backbone(pad_at, in_channels=1)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • pad_at (int)

  • +
  • in_channels (int)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ConvolutionalBlock(in_channels, activation_cls=None)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • activation_cls (torch.nn.Module)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ExperimentArgs
+
+
+data_cls: Any
+
+ +
+
+mix: bool = True
+
+ +
+
+model_args: Dict[str, Any]
+
+ +
+
+model_cls: Any
+
+ +
+
+seed: int = 42
+
+ +
+
+test_data_args: Dict[str, Any]
+
+ +
+
+train_data_args: Dict[str, Any]
+
+ +
+
+trainer_args: Dict[str, Any]
+
+ +
+
+trainer_cls: Any
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ResNet1DBase(resnet_block_cls=ResNetBlock, activation_cls=torch.nn.ReLU, input_shape=(6, 60), num_classes=6, num_residual_blocks=5, reduction_ratio=2, learning_rate=0.001)
+

Bases: SimpleClassificationNet2

+
+
Parameters:
+
    +
  • resnet_block_cls (type)

  • +
  • activation_cls (type)

  • +
  • input_shape (Tuple[int, int])

  • +
  • num_classes (int)

  • +
  • num_residual_blocks (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ResNet1D_8(*args, **kwargs)
+

Bases: ResNet1DBase

+
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ResNetBlock(in_channels=64, activation_cls=torch.nn.ReLU, mix_style_factor=False)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • activation_cls (torch.nn.Module)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ResNetSE1D_5(*args, **kwargs)
+

Bases: ResNet1DBase

+
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ResNetSE1D_8(*args, **kwargs)
+

Bases: ResNet1DBase

+
+ +
+
+class ssl_tools.benchmarks.main_mix_style.ResNetSEBlock(*args, **kwargs)
+

Bases: ResNetBlock

+
+ +
+
+class ssl_tools.benchmarks.main_mix_style.SimpleClassificationNet2(backbone, fc, learning_rate=0.001, flatten=True, loss_fn=None, train_metrics=None, val_metrics=None, test_metrics=None)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • fc (torch.nn.Module)

  • +
  • learning_rate (float)

  • +
  • flatten (bool)

  • +
  • loss_fn (torch.nn.Module)

  • +
  • train_metrics (Dict[str, torch.Tensor])

  • +
  • val_metrics (Dict[str, torch.Tensor])

  • +
  • test_metrics (Dict[str, torch.Tensor])

  • +
+
+
+
+
+single_step(batch, batch_idx, step_name)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
  • step_name (str)

  • +
+
+
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style.SqueezeAndExcitation1D(in_channels, reduction_ratio=2)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • reduction_ratio (int)

  • +
+
+
+
+
+forward(input_tensor)
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_mix_style._ResNet1D(input_shape, residual_block_cls=ResNetBlock, activation_cls=torch.nn.ReLU, num_residual_blocks=5, reduction_ratio=2)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • activation_cls (torch.nn.Module)

  • +
  • num_residual_blocks (int)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+ssl_tools.benchmarks.main_mix_style._run_experiment_wrapper(experiment_args)
+
+
Parameters:
+

experiment_args (ExperimentArgs)

+
+
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.cli_main(experiment)
+
+
Parameters:
+

experiment (ExperimentArgs)

+
+
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)
+

3x3 convolution with padding

+
+
Parameters:
+
    +
  • in_planes (int)

  • +
  • out_planes (int)

  • +
  • stride (int)

  • +
  • groups (int)

  • +
  • dilation (int)

  • +
+
+
Return type:
+

torch.nn.Conv2d

+
+
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.conv3x3_dynamic(in_planes, out_planes, stride=1, attention_in_channels=None)
+

3x3 convolution with padding

+
+
Parameters:
+
    +
  • in_planes (int)

  • +
  • out_planes (int)

  • +
  • stride (int)

  • +
  • attention_in_channels (int)

  • +
+
+
Return type:
+

dassl.modeling.ops.Conv2dDynamic

+
+
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.main_loo()
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.pretty_print_experiment_args(args, indent=4)
+
+
Parameters:
+
+
+
Return type:
+

str

+
+
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.run_serial(experiments)
+
+
Parameters:
+

experiments (List[ExperimentArgs])

+
+
+
+ +
+
+ssl_tools.benchmarks.main_mix_style.run_using_ray(experiments, ray_address=None)
+
+
Parameters:
+
+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/benchmarks/main_supervised/index.html b/autoapi/ssl_tools/benchmarks/main_supervised/index.html new file mode 100644 index 0000000..6441562 --- /dev/null +++ b/autoapi/ssl_tools/benchmarks/main_supervised/index.html @@ -0,0 +1,360 @@ + + + + + + + ssl_tools.benchmarks.main_supervised — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.benchmarks.main_supervised

+
+

Classes

+ + + + + + + + + +

ExperimentArgs

SupervisedConfigParser

+
+
+

Functions

+ + + + + + + + + + + + + + + + + + + + + + + + +

_run_experiment_wrapper(experiment_args)

cli_main(experiment)

hack_to_avoid_lightning_cli_sys_argv_warning(func, ...)

main(data_path, default_trainer_config_file, ...[, ...])

run(config_parser, use_ray[, ray_address, dry_run, ...])

run_serial(experiments)

run_using_ray(experiments[, ray_address])

+
+
+

Module Contents

+
+
+class ssl_tools.benchmarks.main_supervised.ExperimentArgs
+
+
+data: Dict[str, Any]
+
+ +
+
+model: Dict[str, Any]
+
+ +
+
+num_classes: int = 7
+
+ +
+
+seed: int = 42
+
+ +
+
+test_data: Dict[str, Any]
+
+ +
+
+trainer: Dict[str, Any]
+
+ +
+ +
+
+class ssl_tools.benchmarks.main_supervised.SupervisedConfigParser(data_path, default_trainer_config, data_module_configs, model_configs, output_dir='benchmarks/', skip_existing=True, seed=42, leave_one_out=False, data_shapes_file=None, num_classes=7)
+
+
Parameters:
+
    +
  • data_path (str)

  • +
  • default_trainer_config (str)

  • +
  • data_module_configs (str | List[str])

  • +
  • model_configs (str | List[str])

  • +
  • output_dir (str)

  • +
  • skip_existing (bool)

  • +
  • seed (int)

  • +
  • leave_one_out (bool)

  • +
  • data_shapes_file (str)

  • +
  • num_classes (int)

  • +
+
+
+
+
+__call__()
+
+
Return type:
+

List[ExperimentArgs]

+
+
+
+ +
+
+filter_experiments(experiments)
+
+
Parameters:
+

experiments (List[ExperimentArgs])

+
+
+
+ +
+
+static scan_configs(configs_path)
+
+
Parameters:
+

configs_path (pathlib.Path)

+
+
Return type:
+

List[pathlib.Path]

+
+
+
+ +
+ +
+
+ssl_tools.benchmarks.main_supervised._run_experiment_wrapper(experiment_args)
+
+
Parameters:
+

experiment_args (ExperimentArgs)

+
+
+
+ +
+
+ssl_tools.benchmarks.main_supervised.cli_main(experiment)
+
+
Parameters:
+

experiment (ExperimentArgs)

+
+
+
+ +
+
+ssl_tools.benchmarks.main_supervised.hack_to_avoid_lightning_cli_sys_argv_warning(func, *args, **kwargs)
+
+ +
+
+ssl_tools.benchmarks.main_supervised.main(data_path, default_trainer_config_file, data_module_configs_path, model_configs_path, output_path='benchmarks/', skip_existing=True, ray_address=None, use_ray=True, seed=42, dry_run=False, dry_run_limit=5, leave_one_out=False, data_shapes_file=None, num_classes=7)
+
+
Parameters:
+
    +
  • data_path (str)

  • +
  • default_trainer_config_file (str)

  • +
  • data_module_configs_path (str | List[str])

  • +
  • model_configs_path (str | List[str])

  • +
  • output_path (str)

  • +
  • skip_existing (bool)

  • +
  • ray_address (str)

  • +
  • use_ray (bool)

  • +
  • seed (int)

  • +
  • dry_run (bool)

  • +
  • dry_run_limit (int)

  • +
  • leave_one_out (bool)

  • +
  • data_shapes_file (str)

  • +
  • num_classes (int)

  • +
+
+
+
+ +
+
+ssl_tools.benchmarks.main_supervised.run(config_parser, use_ray, ray_address=None, dry_run=False, dry_run_limit=3)
+
+
Parameters:
+
    +
  • config_parser (SupervisedConfigParser)

  • +
  • use_ray (bool)

  • +
  • ray_address (str)

  • +
  • dry_run (bool)

  • +
  • dry_run_limit (int)

  • +
+
+
+
+ +
+
+ssl_tools.benchmarks.main_supervised.run_serial(experiments)
+
+
Parameters:
+

experiments (List[ExperimentArgs])

+
+
+
+ +
+
+ssl_tools.benchmarks.main_supervised.run_using_ray(experiments, ray_address=None)
+
+
Parameters:
+
+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/benchmarks/main_supervised_analysis/index.html b/autoapi/ssl_tools/benchmarks/main_supervised_analysis/index.html new file mode 100644 index 0000000..b8a9c3d --- /dev/null +++ b/autoapi/ssl_tools/benchmarks/main_supervised_analysis/index.html @@ -0,0 +1,168 @@ + + + + + + + ssl_tools.benchmarks.main_supervised_analysis — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.benchmarks.main_supervised_analysis

+
+

Functions

+ + + + + + +

analysis(results_dir[, query, output_dir, ...])

+
+
+

Module Contents

+
+
+ssl_tools.benchmarks.main_supervised_analysis.analysis(results_dir, query=None, output_dir=None, result_file='results.csv', print_results=True, remove_on_error=False)
+
+
Parameters:
+
    +
  • results_dir (str)

  • +
  • query (str)

  • +
  • output_dir (str)

  • +
  • result_file (str)

  • +
  • print_results (bool)

  • +
  • remove_on_error (bool)

  • +
+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/benchmarks/simple_trainer/index.html b/autoapi/ssl_tools/benchmarks/simple_trainer/index.html new file mode 100644 index 0000000..1510c56 --- /dev/null +++ b/autoapi/ssl_tools/benchmarks/simple_trainer/index.html @@ -0,0 +1,156 @@ + + + + + + + ssl_tools.benchmarks.simple_trainer — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.benchmarks.simple_trainer

+
+

Functions

+ + + + + + +

cli_main()

+
+
+

Module Contents

+
+
+ssl_tools.benchmarks.simple_trainer.cli_main()
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/callbacks/index.html b/autoapi/ssl_tools/callbacks/index.html index eb6a9f4..f04c139 100644 --- a/autoapi/ssl_tools/callbacks/index.html +++ b/autoapi/ssl_tools/callbacks/index.html @@ -7,7 +7,7 @@ ssl_tools.callbacks — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.callbacks.save_best

+
+

Classes

+ + + + + + +

PickleBestModelAndLoad

+
+
+

Module Contents

+
+
+class ssl_tools.callbacks.save_best.PickleBestModelAndLoad(model_name, filename='best_model.pt', model_tags=None, model_description=None)
+

Bases: lightning.Callback

+
+
Parameters:
+
    +
  • model_name (str)

  • +
  • filename (str)

  • +
  • model_tags (Optional[Dict[str, Any]])

  • +
  • model_description (Optional[str])

  • +
+
+
+
+
+on_train_end(trainer, module)
+
+
Parameters:
+
    +
  • trainer (lightning.Trainer)

  • +
  • module (lightning.LightningModule)

  • +
+
+
Return type:
+

None

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/data/data_modules/base/index.html b/autoapi/ssl_tools/data/data_modules/base/index.html new file mode 100644 index 0000000..810a3ab --- /dev/null +++ b/autoapi/ssl_tools/data/data_modules/base/index.html @@ -0,0 +1,227 @@ + + + + + + + ssl_tools.data.data_modules.base — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.data.data_modules.base

+
+

Classes

+ + + + + + +

SimpleDataModule

+
+
+

Module Contents

+
+
+class ssl_tools.data.data_modules.base.SimpleDataModule
+

Bases: lightning.LightningDataModule

+
+
+abstract _get_loader(split_name, shuffle)
+
+
Parameters:
+
    +
  • split_name (str)

  • +
  • shuffle (bool)

  • +
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+
+abstract _load_dataset(split_name)
+
+
Parameters:
+

split_name (str)

+
+
Return type:
+

torch.utils.data.Dataset

+
+
+
+ +
+
+predict_dataloader()
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+
+test_dataloader()
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+
+train_dataloader()
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+
+val_dataloader()
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/data/data_modules/covid_anomaly/index.html b/autoapi/ssl_tools/data/data_modules/covid_anomaly/index.html new file mode 100644 index 0000000..650f795 --- /dev/null +++ b/autoapi/ssl_tools/data/data_modules/covid_anomaly/index.html @@ -0,0 +1,230 @@ + + + + + + + ssl_tools.data.data_modules.covid_anomaly — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.data.data_modules.covid_anomaly

+
+

Classes

+ + + + + + +

CovidUserAnomalyDataModule

+
+
+

Module Contents

+
+
+class ssl_tools.data.data_modules.covid_anomaly.CovidUserAnomalyDataModule(data_path, participants=None, feature_column_prefix='RHR', target_column='anomaly', participant_column='participant_id', include_recovered_in_test=False, reshape=None, train_transforms=None, batch_size=32, num_workers=1, validation_split=0.2, dataset_transforms=None, shuffle_train=True, discard_last_batch=False, balance=False, train_baseline_only=True)
+

Bases: lightning.LightningDataModule

+
+
Parameters:
+
    +
  • data_path (pathlib.Path)

  • +
  • participants (Union[str, int, List[Union[str, int]]])

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • participant_column (str)

  • +
  • include_recovered_in_test (bool)

  • +
  • reshape (tuple)

  • +
  • train_transforms (List[callable])

  • +
  • batch_size (int)

  • +
  • num_workers (int)

  • +
  • validation_split (float)

  • +
  • dataset_transforms (List[callable])

  • +
  • shuffle_train (bool)

  • +
  • discard_last_batch (bool)

  • +
  • balance (bool)

  • +
  • train_baseline_only (bool)

  • +
+
+
+
+
+__repr__()
+
+
Return type:
+

str

+
+
+
+ +
+
+__str__()
+
+
Return type:
+

str

+
+
+
+ +
+
+predict_dataloader()
+
+ +
+
+setup(stage)
+
+
Parameters:
+

stage (str)

+
+
+
+ +
+
+test_dataloader()
+
+ +
+
+train_dataloader()
+
+ +
+
+val_dataloader()
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/data/data_modules/har/index.html b/autoapi/ssl_tools/data/data_modules/har/index.html index 5b43352..75f1854 100644 --- a/autoapi/ssl_tools/data/data_modules/har/index.html +++ b/autoapi/ssl_tools/data/data_modules/har/index.html @@ -7,7 +7,7 @@ ssl_tools.data.data_modules.har — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.data.datasets.augmented_dataset

+
+

Classes

+ + + + + + +

AugmentedDataset

Note: this class assumes that dataset is a Dataset object, and that

+
+
+

Module Contents

+
+
+class ssl_tools.data.datasets.augmented_dataset.AugmentedDataset(dataset, transforms)
+

Bases: torch.utils.data.Dataset

+

Note: this class assumes that dataset is a Dataset object, and that +the __getitem__ method of the dataset returns a tuple of n elements.

+

_summary_

+
+

Parameters

+
+
datasetDataset

_description_

+
+
transformsDict[int, Callable]

As each element (result of __getitem__) of the dataset is a +n-element tuple, the transforms are applied to the n-th element +of the tuple. The key of the dictionary is the index of the +element of the tuple to apply the transform (0-indexed), and the +value is the transform to apply.

+
+
+
+
+__getitem__(idx)
+
+ +
+
+__len__()
+
+ +
+
+
Parameters:
+
    +
  • dataset (torch.utils.data.Dataset)

  • +
  • transforms (List[Union[Callable, Dict[int, Callable]]])

  • +
+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/data/datasets/domain_dataset/index.html b/autoapi/ssl_tools/data/datasets/domain_dataset/index.html new file mode 100644 index 0000000..169566f --- /dev/null +++ b/autoapi/ssl_tools/data/datasets/domain_dataset/index.html @@ -0,0 +1,167 @@ + + + + + + + ssl_tools.data.datasets.domain_dataset — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.data.datasets.domain_dataset

+
+

Classes

+ + + + + + +

DomainDataset

+
+
+

Module Contents

+
+
+class ssl_tools.data.datasets.domain_dataset.DomainDataset(dataset, domain)
+
+
+__getitem__(idx)
+
+ +
+
+__len__()
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/data/datasets/index.html b/autoapi/ssl_tools/data/datasets/index.html index 108c628..d7fb175 100644 --- a/autoapi/ssl_tools/data/datasets/index.html +++ b/autoapi/ssl_tools/data/datasets/index.html @@ -7,7 +7,7 @@ ssl_tools.data.datasets — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.anomaly_detection_base

+
+

Classes

+ + + + + + + + + + + + +

CovidAnomalyDetectionEvaluator

Helper class that provides a standard way to create an ABC using

CovidAnomalyDetectionTrain

Helper class that provides a standard way to create an ABC using

RMSELoss

+
+
+

Functions

+ + + + + + + + + + + + + + + + + + + + + + + + +

kmeans_threshold(X_recon[, n_clusters])

mean_absolute_error(X, X_recon)

mean_squared_error(X, X_recon)

root_mean_squared_error(X, X_recon)

sigma_threshold(X_recon, sigma)

zscore_threshold_max(X_recon)

zscore_threshold_std(X_recon, std)

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix='RHR', target_column='anomaly', include_recovered_in_test=False, results_dir='results', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTest

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • train_data (str)

  • +
  • test_data (str)

  • +
  • train_participant (int)

  • +
  • test_participant (int)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • include_recovered_in_test (bool)

  • +
  • results_dir (str)

  • +
+
+
+
+
+_calc_static_anomaly_thresholds(losses)
+
+
Parameters:
+

losses (numpy.ndarray)

+
+
Return type:
+

Dict[str, float]

+
+
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

ssl_tools.data.data_modules.CovidUserAnomalyDataModule

+
+
+
+ +
+
+abstract get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+run_model(model, data_module, trainer)
+
+
Parameters:
+
    +
  • model (lightning.LightningModule)

  • +
  • data_module (lightning.LightningDataModule)

  • +
  • trainer (lightning.Trainer)

  • +
+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain(data, input_shape, participant_ids=None, validation_split=0.1, augment=False, feature_column_prefix='RHR', target_column='anomaly', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • participant_ids (Optional[Union[int, List[int]]])

  • +
  • validation_split (float)

  • +
  • augment (bool)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
+
+
+
+
+_get_transforms()
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

ssl_tools.data.data_modules.CovidUserAnomalyDataModule

+
+
+
+ +
+
+abstract get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.anomaly_detection_base.RMSELoss(eps=1e-06, *args, **kwargs)
+

Bases: torch.nn.Module

+
+
+forward(y_hat, y)
+
+
Parameters:
+
    +
  • y_hat (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
+
+
Return type:
+

torch.Tensor

+
+
+
+ +
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.kmeans_threshold(X_recon, n_clusters=1)
+
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.mean_absolute_error(X, X_recon)
+
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.mean_squared_error(X, X_recon)
+
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.root_mean_squared_error(X, X_recon)
+
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.sigma_threshold(X_recon, sigma)
+
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.zscore_threshold_max(X_recon)
+
+ +
+
+ssl_tools.experiments.covid_detection.anomaly_detection_base.zscore_threshold_std(X_recon, std)
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/cae/index.html b/autoapi/ssl_tools/experiments/covid_detection/cae/index.html new file mode 100644 index 0000000..3dbf880 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/cae/index.html @@ -0,0 +1,260 @@ + + + + + + + ssl_tools.experiments.covid_detection.cae — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.cae

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

ConvolutionalAutoencoderAnomalyDetectionTest

Helper class that provides a standard way to create an ABC using

ConvolutionalAutoencoderAnomalyDetectionTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.cae.ConvolutionalAutoencoderAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix='RHR', target_column='anomaly', include_recovered_in_test=False, results_dir='results', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • train_data (str)

  • +
  • test_data (str)

  • +
  • train_participant (int)

  • +
  • test_participant (int)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • include_recovered_in_test (bool)

  • +
  • results_dir (str)

  • +
+
+
+
+
+_MODEL_NAME = 'cae'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.cae.ConvolutionalAutoencoderAnomalyDetectionTrain(data, input_shape, participant_ids=None, validation_split=0.1, augment=False, feature_column_prefix='RHR', target_column='anomaly', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • participant_ids (Optional[Union[int, List[int]]])

  • +
  • validation_split (float)

  • +
  • augment (bool)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
+
+
+
+
+_MODEL_NAME = 'cae'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.covid_detection.cae.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/cae2d/index.html b/autoapi/ssl_tools/experiments/covid_detection/cae2d/index.html new file mode 100644 index 0000000..2c8b703 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/cae2d/index.html @@ -0,0 +1,260 @@ + + + + + + + ssl_tools.experiments.covid_detection.cae2d — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.cae2d

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

ConvolutionalAutoencoder2DAnomalyDetectionTest

Helper class that provides a standard way to create an ABC using

ConvolutionalAutoencoder2DAnomalyDetectionTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.cae2d.ConvolutionalAutoencoder2DAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix='RHR', target_column='anomaly', include_recovered_in_test=False, results_dir='results', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • train_data (str)

  • +
  • test_data (str)

  • +
  • train_participant (int)

  • +
  • test_participant (int)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • include_recovered_in_test (bool)

  • +
  • results_dir (str)

  • +
+
+
+
+
+_MODEL_NAME = 'cae2d'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.cae2d.ConvolutionalAutoencoder2DAnomalyDetectionTrain(data, input_shape, participant_ids=None, validation_split=0.1, augment=False, feature_column_prefix='RHR', target_column='anomaly', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • participant_ids (Optional[Union[int, List[int]]])

  • +
  • validation_split (float)

  • +
  • augment (bool)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
+
+
+
+
+_MODEL_NAME = 'cae2d'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.covid_detection.cae2d.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/ccae/index.html b/autoapi/ssl_tools/experiments/covid_detection/ccae/index.html new file mode 100644 index 0000000..fb061a5 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/ccae/index.html @@ -0,0 +1,260 @@ + + + + + + + ssl_tools.experiments.covid_detection.ccae — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.ccae

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

ConvolutionalAutoencoderAnomalyDetectionTest

Helper class that provides a standard way to create an ABC using

ConvolutionalAutoencoderAnomalyDetectionTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.ccae.ConvolutionalAutoencoderAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix='RHR', target_column='anomaly', include_recovered_in_test=False, results_dir='results', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • train_data (str)

  • +
  • test_data (str)

  • +
  • train_participant (int)

  • +
  • test_participant (int)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • include_recovered_in_test (bool)

  • +
  • results_dir (str)

  • +
+
+
+
+
+_MODEL_NAME = 'ccae'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.ccae.ConvolutionalAutoencoderAnomalyDetectionTrain(data, input_shape, participant_ids=None, validation_split=0.1, augment=False, feature_column_prefix='RHR', target_column='anomaly', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • participant_ids (Optional[Union[int, List[int]]])

  • +
  • validation_split (float)

  • +
  • augment (bool)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
+
+
+
+
+_MODEL_NAME = 'ccae'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.covid_detection.ccae.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/classfication_report/index.html b/autoapi/ssl_tools/experiments/covid_detection/classfication_report/index.html new file mode 100644 index 0000000..8fd9b22 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/classfication_report/index.html @@ -0,0 +1,269 @@ + + + + + + + ssl_tools.experiments.covid_detection.classfication_report — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.classfication_report

+
+

Functions

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

_balanced_accuracy_score(y_true, y_pred, labels)

_matthews_corrcoef(y_true, y_pred, labels)

_roc_auc_score(y_true, y_pred, labels)

accuracy_score(tn, fp, fn, tp)

classification_report(y_true, y_pred[, labels])

f1_score(tn, fp, fn, tp)

f2_score(tn, fp, fn, tp)

f2_score(tn, fp, fn, tp)

fbeta_score(tn, fp, fn, tp[, beta])

negative_precision_score(tn, fp, fn, tp)

precision_score(tn, fp, fn, tp)

recall_score(tn, fp, fn, tp)

specificity_score(tn, fp, fn, tp)

uar_score(tn, fp, fn, tp)

wrap_zero_div(func)

+
+
+

Module Contents

+
+
+ssl_tools.experiments.covid_detection.classfication_report._balanced_accuracy_score(y_true, y_pred, labels)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report._matthews_corrcoef(y_true, y_pred, labels)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report._roc_auc_score(y_true, y_pred, labels)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.accuracy_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.classification_report(y_true, y_pred, labels=None)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.f1_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.f2_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.f2_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.fbeta_score(tn, fp, fn, tp, beta=0.1)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.negative_precision_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.precision_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.recall_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.specificity_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.uar_score(tn, fp, fn, tp)
+
+ +
+
+ssl_tools.experiments.covid_detection.classfication_report.wrap_zero_div(func)
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/classification_base/index.html b/autoapi/ssl_tools/experiments/covid_detection/classification_base/index.html new file mode 100644 index 0000000..1b209b5 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/classification_base/index.html @@ -0,0 +1,243 @@ + + + + + + + ssl_tools.experiments.covid_detection.classification_base — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.classification_base

+
+

Classes

+ + + + + + + + + +

CovidDetectionEvaluator

Helper class that provides a standard way to create an ABC using

CovidDetectionTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.classification_base.CovidDetectionEvaluator(data, feature_column_prefix='RHR', target_column='anomaly', results_file='results.csv', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTest

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • results_file (str)

  • +
+
+
+
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

ssl_tools.data.data_modules.CovidUserAnomalyDataModule

+
+
+
+ +
+
+run_model(model, data_module, trainer)
+
+
Parameters:
+
    +
  • model (lightning.LightningModule)

  • +
  • data_module (lightning.LightningDataModule)

  • +
  • trainer (lightning.Trainer)

  • +
+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.classification_base.CovidDetectionTrain(data, reshape=None, validation_split=0.1, balance=False, feature_column_prefix='RHR', target_column='anomaly', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • reshape (Optional[Tuple[int, Ellipsis]])

  • +
  • validation_split (float)

  • +
  • balance (bool)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
+
+
+
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

ssl_tools.data.data_modules.CovidUserAnomalyDataModule

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/index.html b/autoapi/ssl_tools/experiments/covid_detection/index.html new file mode 100644 index 0000000..872ceb7 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/index.html @@ -0,0 +1,153 @@ + + + + + + + ssl_tools.experiments.covid_detection — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index.html b/autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index.html new file mode 100644 index 0000000..48c8467 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/lstm_ae/index.html @@ -0,0 +1,260 @@ + + + + + + + ssl_tools.experiments.covid_detection.lstm_ae — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.lstm_ae

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

LSTMAutoencoderAnomalyDetectionTest

Helper class that provides a standard way to create an ABC using

LSTMAutoencoderAnomalyDetectionTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.lstm_ae.LSTMAutoencoderAnomalyDetectionTest(train_data, test_data, train_participant, test_participant, input_shape, feature_column_prefix='RHR', target_column='anomaly', include_recovered_in_test=False, results_dir='results', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionEvaluator

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • train_data (str)

  • +
  • test_data (str)

  • +
  • train_participant (int)

  • +
  • test_participant (int)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
  • include_recovered_in_test (bool)

  • +
  • results_dir (str)

  • +
+
+
+
+
+_MODEL_NAME = 'lstm-ae'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.lstm_ae.LSTMAutoencoderAnomalyDetectionTrain(data, input_shape, participant_ids=None, validation_split=0.1, augment=False, feature_column_prefix='RHR', target_column='anomaly', *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.anomaly_detection_base.CovidAnomalyDetectionTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, Ellipsis])

  • +
  • participant_ids (Optional[Union[int, List[int]]])

  • +
  • validation_split (float)

  • +
  • augment (bool)

  • +
  • feature_column_prefix (str)

  • +
  • target_column (str)

  • +
+
+
+
+
+_MODEL_NAME = 'lstm-ae'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.covid_detection.lstm_ae.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/covid_detection/mlp/index.html b/autoapi/ssl_tools/experiments/covid_detection/mlp/index.html new file mode 100644 index 0000000..6c7dc98 --- /dev/null +++ b/autoapi/ssl_tools/experiments/covid_detection/mlp/index.html @@ -0,0 +1,279 @@ + + + + + + + ssl_tools.experiments.covid_detection.mlp — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.covid_detection.mlp

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + + + + +

FlattenBCELoss

MLPClassifierTest

Helper class that provides a standard way to create an ABC using

MLPClassifierTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.covid_detection.mlp.FlattenBCELoss
+

Bases: torch.nn.BCELoss

+
+
+forward(input, target)
+
+
Parameters:
+
    +
  • input (torch.Tensor)

  • +
  • target (torch.Tensor)

  • +
+
+
Return type:
+

torch.Tensor

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.mlp.MLPClassifierTest(input_size=16, hidden_size=128, num_hidden_layers=1, num_classes=1, learning_rate=0.001, *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.classification_base.CovidDetectionEvaluator

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • input_size (int)

  • +
  • hidden_size (int)

  • +
  • num_hidden_layers (int)

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_MODEL_NAME = 'mlp'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.covid_detection.mlp.MLPClassifierTrain(input_size=16, hidden_size=128, num_hidden_layers=1, num_classes=1, learning_rate=0.001, *args, **kwargs)
+

Bases: ssl_tools.experiments.covid_detection.classification_base.CovidDetectionTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • input_size (int)

  • +
  • hidden_size (int)

  • +
  • num_hidden_layers (int)

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_MODEL_NAME = 'mlp'
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.covid_detection.mlp.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/experiment/index.html b/autoapi/ssl_tools/experiments/experiment/index.html index 94121eb..accbd3c 100644 --- a/autoapi/ssl_tools/experiments/experiment/index.html +++ b/autoapi/ssl_tools/experiments/experiment/index.html @@ -7,7 +7,7 @@ ssl_tools.experiments.experiment — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification._classification_base

+
+

Classes

+ + + + + + + + + +

EvaluatorBase

Helper class that provides a standard way to create an ABC using

PredictionHeadClassifier

+
+
+

Functions

+ + + + + + + + + + + + + + + +

full_dataset_from_dataloader(dataloader)

generate_embeddings(model, dataloader, trainer)

get_full_data_split(data_module, stage)

get_split_dataloader(stage, data_module)

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification._classification_base.EvaluatorBase(results_file='results.csv', confusion_matrix_file='confusion_matrix.csv', confusion_matrix_image_file='confusion_matrix.png', tsne_plot_file='tsne_embeddings.png', embedding_file='embeddings.csv', predictions_file='predictions.csv', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTest

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • results_file (str)

  • +
  • confusion_matrix_file (str)

  • +
  • confusion_matrix_image_file (str)

  • +
  • tsne_plot_file (str)

  • +
  • embedding_file (str)

  • +
  • predictions_file (str)

  • +
+
+
+
+
+_compute_classification_metrics(y_hat_logits, y, n_classes)
+
+
Parameters:
+
    +
  • y_hat_logits (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • n_classes (int)

  • +
+
+
Return type:
+

pandas.DataFrame

+
+
+
+ +
+
+_compute_embeddings(model, data_module, trainer)
+
+ +
+
+_plot_confusion_matrix(y_hat, y, n_classes, cm_file, cm_image_file)
+
+
Parameters:
+
    +
  • y_hat (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • n_classes (int)

  • +
  • cm_file (str)

  • +
  • cm_image_file (str)

  • +
+
+
+
+ +
+
+_plot_tnse_embeddings(embeddings, y, y_hat, n_components=2, tsne_plot_file='tsne_embeddings.png')
+
+
Parameters:
+
    +
  • embeddings (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • y_hat (torch.Tensor)

  • +
  • n_components (int)

  • +
  • tsne_plot_file (str)

  • +
+
+
+
+ +
+
+evaluate_embeddings(model, data_module, trainer)
+
+ +
+
+evaluate_model_performance(model, data_module, trainer)
+
+ +
+
+predict(model, dataloader, trainer)
+
+ +
+
+run_model(model, data_module, trainer)
+
+
Parameters:
+
+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification._classification_base.PredictionHeadClassifier(prediction_head, num_classes=6)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • prediction_head (torch.nn.Module)

  • +
  • num_classes (int)

  • +
+
+
+
+ +
+
+ssl_tools.experiments.har_classification._classification_base.full_dataset_from_dataloader(dataloader)
+
+
Parameters:
+

dataloader (torch.utils.data.DataLoader)

+
+
+
+ +
+
+ssl_tools.experiments.har_classification._classification_base.generate_embeddings(model, dataloader, trainer)
+
+
Parameters:
+
+
+
+
+ +
+
+ssl_tools.experiments.har_classification._classification_base.get_full_data_split(data_module, stage)
+
+
Parameters:
+
    +
  • data_module (lightning.LightningDataModule)

  • +
  • stage (str)

  • +
+
+
+
+ +
+
+ssl_tools.experiments.har_classification._classification_base.get_split_dataloader(stage, data_module)
+
+
Parameters:
+
    +
  • stage (str)

  • +
  • data_module (lightning.LightningDataModule)

  • +
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/cpc/index.html b/autoapi/ssl_tools/experiments/har_classification/cpc/index.html index 7c44590..c6d79b8 100644 --- a/autoapi/ssl_tools/experiments/har_classification/cpc/index.html +++ b/autoapi/ssl_tools/experiments/har_classification/cpc/index.html @@ -7,7 +7,7 @@ ssl_tools.experiments.har_classification.cpc — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.gru_encoder

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + + + + +

GRUClassifier

GRUClassifierTest

Helper class that provides a standard way to create an ABC using

GRUClassifierTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.gru_encoder.GRUClassifier(hidden_size=100, in_channels=6, num_classes=6, encoding_size=100, num_layers=1, dropout=0.0, bidirectional=True)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • hidden_size (int)

  • +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • encoding_size (int)

  • +
  • num_layers (int)

  • +
  • dropout (float)

  • +
  • bidirectional (bool)

  • +
+
+
+
+ +
+
+class ssl_tools.experiments.har_classification.gru_encoder.GRUClassifierTest(data, hidden_size=100, in_channels=6, num_classes=6, encoding_size=100, num_layers=1, dropout=0.0, bidirectional=True, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.har_classification._classification_base.EvaluatorBase

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • hidden_size (int)

  • +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • encoding_size (int)

  • +
  • num_layers (int)

  • +
  • dropout (float)

  • +
  • bidirectional (bool)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'GRU'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.gru_encoder.GRUClassifierTrain(data, hidden_size=100, in_channels=6, num_classes=6, encoding_size=100, num_layers=1, dropout=0.0, bidirectional=True, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • hidden_size (int)

  • +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • encoding_size (int)

  • +
  • num_layers (int)

  • +
  • dropout (float)

  • +
  • bidirectional (bool)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'GRU'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.har_classification.gru_encoder.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/index.html b/autoapi/ssl_tools/experiments/har_classification/index.html index 494cb3c..a78960b 100644 --- a/autoapi/ssl_tools/experiments/har_classification/index.html +++ b/autoapi/ssl_tools/experiments/har_classification/index.html @@ -7,7 +7,7 @@ ssl_tools.experiments.har_classification — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.mlp_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

MLPClassifierTest

Helper class that provides a standard way to create an ABC using

MLPClassifierTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.mlp_classifier.MLPClassifierTest(data, input_size=360, hidden_size=64, num_hidden_layers=1, num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.har_classification._classification_base.EvaluatorBase

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_size (int)

  • +
  • hidden_size (int)

  • +
  • num_hidden_layers (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'MLP'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.mlp_classifier.MLPClassifierTrain(data, input_size=360, hidden_size=64, num_hidden_layers=1, num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_size (int)

  • +
  • hidden_size (int)

  • +
  • num_hidden_layers (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'MLP'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.har_classification.mlp_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index.html b/autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index.html new file mode 100644 index 0000000..d2ac2e9 --- /dev/null +++ b/autoapi/ssl_tools/experiments/har_classification/simple1Dconv_classifier/index.html @@ -0,0 +1,288 @@ + + + + + + + ssl_tools.experiments.har_classification.simple1Dconv_classifier — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.simple1Dconv_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

Simple1DConvNetTest

Helper class that provides a standard way to create an ABC using

Simple1DConvNetTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.simple1Dconv_classifier.Simple1DConvNetTest(data, input_shape=(6, 60), num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.har_classification._classification_base.EvaluatorBase

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, int])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'Simple1DConvNet'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.simple1Dconv_classifier.Simple1DConvNetTrain(data, input_shape=(6, 60), num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, int])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'Simple1DConvNet'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.har_classification.simple1Dconv_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index.html b/autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index.html new file mode 100644 index 0000000..4c28414 --- /dev/null +++ b/autoapi/ssl_tools/experiments/har_classification/simple2Dconv_classifier/index.html @@ -0,0 +1,288 @@ + + + + + + + ssl_tools.experiments.har_classification.simple2Dconv_classifier — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.simple2Dconv_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

Simple2DConvNetTest

Helper class that provides a standard way to create an ABC using

Simple2DConvNetTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.simple2Dconv_classifier.Simple2DConvNetTest(data, input_shape=(6, 1, 60), num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.har_classification._classification_base.EvaluatorBase

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'Simple2DConvNet'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.simple2Dconv_classifier.Simple2DConvNetTrain(data, input_shape=(6, 1, 60), num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'Simple2DConvNet'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.har_classification.simple2Dconv_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/tfc/index.html b/autoapi/ssl_tools/experiments/har_classification/tfc/index.html index 34edae7..f193ff5 100644 --- a/autoapi/ssl_tools/experiments/har_classification/tfc/index.html +++ b/autoapi/ssl_tools/experiments/har_classification/tfc/index.html @@ -7,7 +7,7 @@ ssl_tools.experiments.har_classification.tfc — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.tfc_head_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

TFCHeadClassifierTest

Helper class that provides a standard way to create an ABC using

TFCHeadClassifierTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.tfc_head_classifier.TFCHeadClassifierTest(data, input_size=360, num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.har_classification._classification_base.EvaluatorBase

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_size (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'TFCPredictionHead'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.tfc_head_classifier.TFCHeadClassifierTrain(data, input_size=360, num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_size (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'TFCPredictionHead'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.har_classification.tfc_head_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/tnc/index.html b/autoapi/ssl_tools/experiments/har_classification/tnc/index.html index c3f68ba..6855ec8 100644 --- a/autoapi/ssl_tools/experiments/har_classification/tnc/index.html +++ b/autoapi/ssl_tools/experiments/har_classification/tnc/index.html @@ -7,7 +7,7 @@ ssl_tools.experiments.har_classification.tnc — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.tnc_head_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

TNCHeadClassifierTest

Helper class that provides a standard way to create an ABC using

TNCHeadClassifierTrain

Helper class that provides a standard way to create an ABC using

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.tnc_head_classifier.TNCHeadClassifierTest(data, input_size=360, num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.har_classification._classification_base.EvaluatorBase

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_size (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'TNCPredictionHead'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.tnc_head_classifier.TNCHeadClassifierTrain(data, input_size=360, num_classes=6, transforms='identity', *args, **kwargs)
+

Bases: ssl_tools.experiments.LightningTrain

+

Helper class that provides a standard way to create an ABC using +inheritance.

+
+
Parameters:
+
    +
  • data (str)

  • +
  • input_size (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
+
+
+
+
+_MODEL_NAME = 'TNCPredictionHead'
+
+ +
+
+get_data_module()
+

Get the datamodule to use for the experiment.

+
+

Returns

+
+
L.LightningDataModule

The datamodule to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+

Get the model to use for the experiment.

+
+

Returns

+
+
L.LightningModule

The model to use for the experiment

+
+
+
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+ +
+
+ssl_tools.experiments.har_classification.tnc_head_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/har_classification/utils/index.html b/autoapi/ssl_tools/experiments/har_classification/utils/index.html new file mode 100644 index 0000000..323cdd5 --- /dev/null +++ b/autoapi/ssl_tools/experiments/har_classification/utils/index.html @@ -0,0 +1,291 @@ + + + + + + + ssl_tools.experiments.har_classification.utils — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.experiments.har_classification.utils

+
+

Classes

+ + + + + + + + + + + + + + + +

DimensionAdder

FFT

Flatten

Spectrogram

+
+
+

Module Contents

+
+
+class ssl_tools.experiments.har_classification.utils.DimensionAdder(dim)
+
+
Parameters:
+

dim (int)

+
+
+
+
+__call__(x)
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.utils.FFT(absolute=True, centered=False)
+
+
Parameters:
+
    +
  • absolute (bool)

  • +
  • centered (bool)

  • +
+
+
+
+
+__call__(x)
+

Aplly FFT to the input signal. It apply the FFT into each channel +of the input signal.

+
+

Parameters

+
+
xnp.ndarray

An array with shape (n_channels, n_samples) containing the input

+
+
+
+
+

Returns

+
+
np.ndarray

The FFT of the input signal. The shape of the output is +(n_channels, n_samples) if absolute is False, and +(n_channels, n_samples//2) if absolute is True.

+
+
+
+
+
Parameters:
+

x (numpy.ndarray)

+
+
Return type:
+

numpy.ndarray

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.utils.Flatten
+
+
+__call__(x)
+

Flatten the input signal. It apply the flatten into each channel +of the input signal.

+
+

Parameters

+
+
xnp.ndarray

An array with shape (n_channels, n_samples) containing the input

+
+
+
+
+

Returns

+
+
np.ndarray

The flatten of the input signal. The shape of the output is +(n_channels, n_samples).

+
+
+
+
+
Parameters:
+

x (numpy.ndarray)

+
+
Return type:
+

numpy.ndarray

+
+
+
+ +
+ +
+
+class ssl_tools.experiments.har_classification.utils.Spectrogram(fs=20, nperseg=16, noverlap=8, nfft=16)
+
+
+__call__(x)
+

Aplly Spectrogram to the input signal. It apply the Spectrogram into each channel +of the input signal.

+
+

Parameters

+
+
xnp.ndarray

An array with shape (n_channels, n_samples) containing the input

+
+
+
+
+

Returns

+
+
np.ndarray

The Spectrogram of the input signal. The shape of the output is +(n_channels, n_samples) if absolute is False, and +(n_channels, n_samples//2) if absolute is True.

+
+
+
+
+
Parameters:
+

x (numpy.ndarray)

+
+
Return type:
+

numpy.ndarray

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/experiments/index.html b/autoapi/ssl_tools/experiments/index.html index cdcb01a..f24b015 100644 --- a/autoapi/ssl_tools/experiments/index.html +++ b/autoapi/ssl_tools/experiments/index.html @@ -7,7 +7,7 @@ ssl_tools.experiments — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.losses.contrastive_loss

+
+

Classes

+ + + + + + +

ContrastiveLoss

+
+
+

Module Contents

+
+
+class ssl_tools.losses.contrastive_loss.ContrastiveLoss(margin=1.0)
+

Bases: torch.nn.Module

+
+
Parameters:
+

margin (float)

+
+
+
+
+forward(y_true, y_pred)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/losses/index.html b/autoapi/ssl_tools/losses/index.html index 8901c09..4482f76 100644 --- a/autoapi/ssl_tools/losses/index.html +++ b/autoapi/ssl_tools/losses/index.html @@ -7,7 +7,7 @@ ssl_tools.losses — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.cnn_ha_etal

+
+

Classes

+ + + + + + + + + +

CNN_HaEtAl_1D

CNN_HaEtAl_2D

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.cnn_ha_etal.CNN_HaEtAl_1D(input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_shape)
+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+class ssl_tools.models.nets.cnn_ha_etal.CNN_HaEtAl_2D(pad_at=(3,), input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • pad_at (List[int])

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_shape)
+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/cnn_pf/index.html b/autoapi/ssl_tools/models/nets/cnn_pf/index.html new file mode 100644 index 0000000..c7fe40c --- /dev/null +++ b/autoapi/ssl_tools/models/nets/cnn_pf/index.html @@ -0,0 +1,253 @@ + + + + + + + ssl_tools.models.nets.cnn_pf — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.cnn_pf

+
+

Classes

+ + + + + + + + + + + + +

CNN_PFF_2D

CNN_PF_2D

CNN_PF_Backbone

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.cnn_pf.CNN_PFF_2D(*args, **kwargs)
+

Bases: CNN_PF_2D

+
+ +
+
+class ssl_tools.models.nets.cnn_pf.CNN_PF_2D(pad_at, input_shape=(1, 6, 60), out_channels=16, num_classes=6, learning_rate=0.001, include_middle=False)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • pad_at (int)

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • out_channels (int)

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
  • include_middle (bool)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+class ssl_tools.models.nets.cnn_pf.CNN_PF_Backbone(pad_at, input_shape, out_channels=16, include_middle=False)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • pad_at (int)

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • out_channels (int)

  • +
  • include_middle (bool)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/convae/index.html b/autoapi/ssl_tools/models/nets/convae/index.html new file mode 100644 index 0000000..6d06095 --- /dev/null +++ b/autoapi/ssl_tools/models/nets/convae/index.html @@ -0,0 +1,257 @@ + + + + + + + ssl_tools.models.nets.convae — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.convae

+
+

Classes

+ + + + + + + + + + + + + + + + + + + + + +

ContrastiveConvolutionalAutoEncoder

ContrastiveConvolutionalAutoEncoder2D

ConvolutionalAutoEncoder

ConvolutionalAutoEncoder2D

_ConvolutionalAutoEncoder

_ConvolutionalAutoEncoder2D

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.convae.ContrastiveConvolutionalAutoEncoder(input_shape=(1, 16), learning_rate=0.001, margin=1.0)
+

Bases: ssl_tools.models.nets.simple.SimpleReconstructionNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • learning_rate (float)

  • +
  • margin (float)

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.convae.ContrastiveConvolutionalAutoEncoder2D(input_shape=(4, 4, 1), learning_rate=0.001, margin=1.0)
+

Bases: ssl_tools.models.nets.simple.SimpleReconstructionNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int, int])

  • +
  • learning_rate (float)

  • +
  • margin (float)

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.convae.ConvolutionalAutoEncoder(input_shape=(1, 16), learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleReconstructionNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • learning_rate (float)

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.convae.ConvolutionalAutoEncoder2D(input_shape=(1, 4, 4), learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleReconstructionNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int, int])

  • +
  • learning_rate (float)

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.convae._ConvolutionalAutoEncoder(input_shape=(1, 16))
+

Bases: torch.nn.Module

+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.convae._ConvolutionalAutoEncoder2D(input_shape=(1, 4, 4))
+

Bases: torch.nn.Module

+
+
Parameters:
+

input_shape (Tuple[int, int, int])

+
+
+
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/convnet/index.html b/autoapi/ssl_tools/models/nets/convnet/index.html index 359b567..b9d5558 100644 --- a/autoapi/ssl_tools/models/nets/convnet/index.html +++ b/autoapi/ssl_tools/models/nets/convnet/index.html @@ -7,7 +7,7 @@ ssl_tools.models.nets.convnet — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.deep_conv_lstm

+
+

Classes

+ + + + + + + + + +

ConvLSTMCell

DeepConvLSTM

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.deep_conv_lstm.ConvLSTMCell(input_shape)
+

Bases: torch.nn.Module

+
+
Parameters:
+

input_shape (tuple)

+
+
+
+
+_calculate_conv_output_shape(backbone, input_shape)
+
+
Parameters:
+

input_shape (Tuple[int, int, int])

+
+
Return type:
+

int

+
+
+
+ +
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.deep_conv_lstm.DeepConvLSTM(input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_shape)
+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/deep_convnet/index.html b/autoapi/ssl_tools/models/nets/deep_convnet/index.html new file mode 100644 index 0000000..5aff0d6 --- /dev/null +++ b/autoapi/ssl_tools/models/nets/deep_convnet/index.html @@ -0,0 +1,293 @@ + + + + + + + ssl_tools.models.nets.deep_convnet — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.deep_convnet

+
+

Classes

+ + + + + + +

DeepConvNet

+
+
+

Functions

+ + + + + + +

main()

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.deep_convnet.DeepConvNet(input_channels=6, time_steps=60, num_classes=6, learning_rate=0.001)
+

Bases: lightning.LightningModule

+
+
Parameters:
+
    +
  • input_channels (int)

  • +
  • time_steps (int)

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(input_channels, time_steps)
+

Calculate the number of input features of the fully connected layer. +Basically, it performs a forward pass with a dummy input to get the +output shape after the convolutional layers.

+
+

Parameters

+
+
input_channelsint

The number of input channels.

+
+
+
+
+

Returns

+
+
int

The number of input features of the fully connected layer.

+
+
+
+
+
Parameters:
+
    +
  • input_channels (int)

  • +
  • time_steps (int)

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_common_step(batch, batch_idx, prefix)
+
+ +
+
+_compute_metrics(y_hat, y, stage)
+

Compute the metrics.

+
+

Parameters

+
+
y_hattorch.Tensor

The predictions of the model

+
+
ytorch.Tensor

The ground truth labels

+
+
stagestr

The stage of the training loop (train, val or test)

+
+
+
+
+

Returns

+
+
Dict[str, float]

A dictionary containing the metrics. The keys are the names of the +metrics, and the values are the values of the metrics.

+
+
+
+
+
Parameters:
+
    +
  • y_hat (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • stage (str)

  • +
+
+
Return type:
+

Dict[str, float]

+
+
+
+ +
+
+configure_optimizers()
+
+ +
+
+forward(x)
+
+ +
+
+loss_function(X, y)
+
+ +
+
+predict_step(batch, batch_idx, dataloader_idx=None)
+
+ +
+
+test_step(batch, batch_idx)
+
+ +
+
+training_step(batch, batch_idx)
+
+ +
+
+validation_step(batch, batch_idx)
+
+ +
+ +
+
+ssl_tools.models.nets.deep_convnet.main()
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/imu_transformer/index.html b/autoapi/ssl_tools/models/nets/imu_transformer/index.html new file mode 100644 index 0000000..a71ca57 --- /dev/null +++ b/autoapi/ssl_tools/models/nets/imu_transformer/index.html @@ -0,0 +1,298 @@ + + + + + + + ssl_tools.models.nets.imu_transformer — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.imu_transformer

+
+

Classes

+ + + + + + + + + + + + +

IMUCNN

IMUTransformerEncoder

_IMUTransformerEncoder

input_shape: (tuple) shape of the input data

+
+
+

Functions

+ + + + + + + + + +

test_imu_cnn()

test_imu_transformer()

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.imu_transformer.IMUCNN(input_shape=(6, 60), hidden_dim=64, num_classes=6, dropout_factor=0.1, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_shape (tuple)

  • +
  • hidden_dim (int)

  • +
  • num_classes (int)

  • +
  • dropout_factor (float)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_shape, hidden_dim, dropout_factor)
+
+ +
+
+_create_fc(input_features, hidden_dim, num_classes)
+
+ +
+ +
+
+class ssl_tools.models.nets.imu_transformer.IMUTransformerEncoder(input_shape=(6, 60), transformer_dim=64, encode_position=True, nhead=8, dim_feedforward=128, transformer_dropout=0.1, transformer_activation='gelu', num_encoder_layers=6, num_classes=6, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_shape (tuple)

  • +
  • transformer_dim (int)

  • +
  • encode_position (bool)

  • +
  • nhead (int)

  • +
  • dim_feedforward (int)

  • +
  • transformer_dropout (float)

  • +
  • transformer_activation (str)

  • +
  • num_encoder_layers (int)

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_create_backbone(input_shape, transformer_dim, encode_position, nhead, dim_feedforward, transformer_dropout, transformer_activation, num_encoder_layers)
+
+ +
+
+_create_fc(transform_dim, num_classes)
+
+ +
+ +
+
+class ssl_tools.models.nets.imu_transformer._IMUTransformerEncoder(input_shape=(6, 60), transformer_dim=64, encode_position=True, nhead=8, dim_feedforward=128, transformer_dropout=0.1, transformer_activation='gelu', num_encoder_layers=6)
+

Bases: torch.nn.Module

+

input_shape: (tuple) shape of the input data +transformer_dim: (int) dimension of the transformer +encode_position: (bool) whether to encode position or not +nhead: (int) number of attention heads +dim_feedforward: (int) dimension of the feedforward network +transformer_dropout: (float) dropout rate for the transformer +transformer_activation: (str) activation function for the transformer +num_encoder_layers: (int) number of transformer encoder layers +num_classes: (int) number of output classes

+
+
Parameters:
+
    +
  • input_shape (tuple)

  • +
  • transformer_dim (int)

  • +
  • encode_position (bool)

  • +
  • nhead (int)

  • +
  • dim_feedforward (int)

  • +
  • transformer_dropout (float)

  • +
  • transformer_activation (str)

  • +
  • num_encoder_layers (int)

  • +
+
+
+
+
+forward(x)
+

Forward

+
+

Parameters

+
+
x_type_

A tensor of shape (B, C, S) with B = batch size, C = channels, S = sequence length

+
+
+
+
+ +
+ +
+
+ssl_tools.models.nets.imu_transformer.test_imu_cnn()
+
+ +
+
+ssl_tools.models.nets.imu_transformer.test_imu_transformer()
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/inception_time/index.html b/autoapi/ssl_tools/models/nets/inception_time/index.html new file mode 100644 index 0000000..eca07b5 --- /dev/null +++ b/autoapi/ssl_tools/models/nets/inception_time/index.html @@ -0,0 +1,296 @@ + + + + + + + ssl_tools.models.nets.inception_time — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.inception_time

+
+

Classes

+ + + + + + + + + + + + + + + +

InceptionModule

InceptionTime

ShortcutLayer

_InceptionTime

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.inception_time.InceptionModule(input_shape=(6, 60), stride=1, kernel_size=41, nb_filters=32, use_bottleneck=True, bottleneck_size=32)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • stride (int)

  • +
  • kernel_size (int)

  • +
  • nb_filters (int)

  • +
  • use_bottleneck (bool)

  • +
  • bottleneck_size (int)

  • +
+
+
+
+
+build_model()
+
+ +
+
+forward(input_tensor)
+
+ +
+ +
+
+class ssl_tools.models.nets.inception_time.InceptionTime(input_shape=(6, 60), nb_filters=32, use_residual=True, use_bottleneck=True, depth=6, kernel_size=41, num_classes=6, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • use_residual (bool)

  • +
  • use_bottleneck (bool)

  • +
  • depth (int)

  • +
  • kernel_size (int)

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+class ssl_tools.models.nets.inception_time.ShortcutLayer(input_tensor_shape, out_tensor_shape)
+

Bases: torch.nn.Module

+
+
+forward(input_tensor, output_tensor)
+
+ +
+ +
+
+class ssl_tools.models.nets.inception_time._InceptionTime(input_shape=(6, 60), nb_filters=32, use_residual=True, use_bottleneck=True, depth=6, kernel_size=41)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • use_residual (bool)

  • +
  • use_bottleneck (bool)

  • +
  • depth (int)

  • +
  • kernel_size (int)

  • +
+
+
+
+
+build_model()
+
+ +
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/index.html b/autoapi/ssl_tools/models/nets/index.html index 00f8854..61ec68a 100644 --- a/autoapi/ssl_tools/models/nets/index.html +++ b/autoapi/ssl_tools/models/nets/index.html @@ -7,7 +7,7 @@ ssl_tools.models.nets — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.lstm_ae

+
+

Classes

+ + + + + + + + + +

LSTMAutoencoder

Create a LSTM Autoencoder model

_LSTMAutoEncoder

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.lstm_ae.LSTMAutoencoder(input_shape=(16, 1), learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleReconstructionNet

+

Create a LSTM Autoencoder model

+
+

Parameters

+
+
input_shapeTuple[int, int], optional

The shape of the input. The first element is the sequence length and the second is the number of features, by default (16, 1)

+
+
learning_ratefloat, optional

Learning rate for Adam optimizer, by default 1e-3

+
+
+
+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • learning_rate (float)

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.lstm_ae._LSTMAutoEncoder(input_shape=(16, 1))
+

Bases: torch.nn.Module

+
+
Parameters:
+

input_shape (Tuple[int, int])

+
+
+
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/multi_channel_cnn/index.html b/autoapi/ssl_tools/models/nets/multi_channel_cnn/index.html new file mode 100644 index 0000000..e57760b --- /dev/null +++ b/autoapi/ssl_tools/models/nets/multi_channel_cnn/index.html @@ -0,0 +1,266 @@ + + + + + + + ssl_tools.models.nets.multi_channel_cnn — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.multi_channel_cnn

+
+

Classes

+ + + + + + + + + +

MultiChannelCNN_HAR

Create a simple 1D Convolutional Network with 3 layers and 2 fully

_MultiChannelCNN_HAR

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.multi_channel_cnn.MultiChannelCNN_HAR(input_shape=(1, 6, 60), num_classes=6, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+

Create a simple 1D Convolutional Network with 3 layers and 2 fully +connected layers.

+
+

Parameters

+
+
input_shapeTuple[int, int], optional

A 2-tuple containing the number of input channels and the number of +features, by default (6, 60).

+
+
num_classesint, optional

Number of output classes, by default 6

+
+
learning_ratefloat, optional

Learning rate for Adam optimizer, by default 1e-3

+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_backbone(input_channels)
+
+
Parameters:
+

input_channels (int)

+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.multi_channel_cnn._MultiChannelCNN_HAR(input_channels=1, concatenate=True)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • input_channels (int)

  • +
  • concatenate (bool)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/resnet1d/index.html b/autoapi/ssl_tools/models/nets/resnet1d/index.html new file mode 100644 index 0000000..7bda4ef --- /dev/null +++ b/autoapi/ssl_tools/models/nets/resnet1d/index.html @@ -0,0 +1,307 @@ + + + + + + + ssl_tools.models.nets.resnet1d — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.resnet1d

+

resnet for 1-d signal data, pytorch version

+

Shenda Hong, Oct 2019

+
+

Classes

+ + + + + + + + + + + + + + + + + + +

BasicBlock

ResNet Basic Block

MyConv1dPadSame

extend nn.Conv1d to support SAME padding

MyMaxPool1dPadSame

extend nn.MaxPool1d to support SAME padding

ResNet1D

_ResNet1D

Input:

+
+
+

Functions

+ + + + + + +

main()

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.resnet1d.BasicBlock(in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False)
+

Bases: torch.nn.Module

+

ResNet Basic Block

+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet1d.MyConv1dPadSame(in_channels, out_channels, kernel_size, stride, groups=1)
+

Bases: torch.nn.Module

+

extend nn.Conv1d to support SAME padding

+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet1d.MyMaxPool1dPadSame(kernel_size)
+

Bases: torch.nn.Module

+

extend nn.MaxPool1d to support SAME padding

+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet1d.ResNet1D(input_shape, base_filters=128, kernel_size=16, stride=2, groups=32, n_block=48, num_classes=6, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+
+_create_fc(input_features, num_classes)
+
+
Parameters:
+
    +
  • input_features (int)

  • +
  • num_classes (int)

  • +
+
+
Return type:
+

torch.nn.Module

+
+
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet1d._ResNet1D(in_channels, base_filters=64, kernel_size=16, stride=2, groups=32, n_block=48, n_classes=6, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, verbose=False)
+

Bases: torch.nn.Module

+
+
Input:

X: (n_samples, n_channel, n_length) +Y: (n_samples)

+
+
Output:

out: (n_samples)

+
+
Pararmetes:

in_channels: dim of input, the same as n_channel +base_filters: number of filters in the first several Conv layer, it will double at every 4 layers +kernel_size: width of kernel +stride: stride of kernel moving +groups: set larget to 1 as ResNeXt +n_block: number of blocks +n_classes: number of classes

+
+
+
+
+forward(x)
+
+ +
+ +
+
+ssl_tools.models.nets.resnet1d.main()
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/resnet_1d/index.html b/autoapi/ssl_tools/models/nets/resnet_1d/index.html new file mode 100644 index 0000000..e8b527a --- /dev/null +++ b/autoapi/ssl_tools/models/nets/resnet_1d/index.html @@ -0,0 +1,329 @@ + + + + + + + ssl_tools.models.nets.resnet_1d — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.resnet_1d

+
+

Classes

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

ConvolutionalBlock

ResNet1DBase

ResNet1D_8

ResNetBlock

ResNetSE1D_5

ResNetSE1D_8

ResNetSEBlock

SqueezeAndExcitation1D

_ResNet1D

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.resnet_1d.ConvolutionalBlock(in_channels, activation_cls=None)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • activation_cls (torch.nn.Module)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet_1d.ResNet1DBase(resnet_block_cls=ResNetBlock, activation_cls=torch.nn.ReLU, input_shape=(6, 60), num_classes=6, num_residual_blocks=5, reduction_ratio=2, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • resnet_block_cls (type)

  • +
  • activation_cls (type)

  • +
  • input_shape (Tuple[int, int])

  • +
  • num_classes (int)

  • +
  • num_residual_blocks (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+_calculate_fc_input_features(backbone, input_shape)
+

Run a single forward pass with a random input to get the number of +features after the convolutional layers.

+
+

Parameters

+
+
backbonetorch.nn.Module

The backbone of the network

+
+
input_shapeTuple[int, int, int]

The input shape of the network.

+
+
+
+
+

Returns

+
+
int

The number of features after the convolutional layers.

+
+
+
+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • input_shape (Tuple[int, int, int])

  • +
+
+
Return type:
+

int

+
+
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet_1d.ResNet1D_8(*args, **kwargs)
+

Bases: ResNet1DBase

+
+ +
+
+class ssl_tools.models.nets.resnet_1d.ResNetBlock(in_channels=64, activation_cls=torch.nn.ReLU)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • activation_cls (torch.nn.Module)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet_1d.ResNetSE1D_5(*args, **kwargs)
+

Bases: ResNet1DBase

+
+ +
+
+class ssl_tools.models.nets.resnet_1d.ResNetSE1D_8(*args, **kwargs)
+

Bases: ResNet1DBase

+
+ +
+
+class ssl_tools.models.nets.resnet_1d.ResNetSEBlock(*args, **kwargs)
+

Bases: ResNetBlock

+
+ +
+
+class ssl_tools.models.nets.resnet_1d.SqueezeAndExcitation1D(in_channels, reduction_ratio=2)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • reduction_ratio (int)

  • +
+
+
+
+
+forward(input_tensor)
+
+ +
+ +
+
+class ssl_tools.models.nets.resnet_1d._ResNet1D(input_shape, residual_block_cls=ResNetBlock, activation_cls=torch.nn.ReLU, num_residual_blocks=5, reduction_ratio=2)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • input_shape (Tuple[int, int])

  • +
  • activation_cls (torch.nn.Module)

  • +
  • num_residual_blocks (int)

  • +
+
+
+
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/simple/index.html b/autoapi/ssl_tools/models/nets/simple/index.html new file mode 100644 index 0000000..915c26d --- /dev/null +++ b/autoapi/ssl_tools/models/nets/simple/index.html @@ -0,0 +1,376 @@ + + + + + + + ssl_tools.models.nets.simple — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.simple

+
+

Classes

+ + + + + + + + + + + + +

MLPClassifier

SimpleClassificationNet

SimpleReconstructionNet

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.simple.MLPClassifier(input_size, hidden_size, num_hidden_layers, output_size, learning_rate=0.001, flatten=True, loss_fn=None, train_metrics=None, val_metrics=None, test_metrics=None)
+

Bases: SimpleClassificationNet

+
+
Parameters:
+
    +
  • input_size (int)

  • +
  • hidden_size (int)

  • +
  • num_hidden_layers (int)

  • +
  • output_size (int)

  • +
  • learning_rate (float)

  • +
  • flatten (bool)

  • +
  • loss_fn (torch.nn.Module)

  • +
  • train_metrics (Dict[str, torch.Tensor])

  • +
  • val_metrics (Dict[str, torch.Tensor])

  • +
  • test_metrics (Dict[str, torch.Tensor])

  • +
+
+
+
+ +
+
+class ssl_tools.models.nets.simple.SimpleClassificationNet(backbone, fc, learning_rate=0.001, flatten=True, loss_fn=None, train_metrics=None, val_metrics=None, test_metrics=None)
+

Bases: lightning.LightningModule

+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • fc (torch.nn.Module)

  • +
  • learning_rate (float)

  • +
  • flatten (bool)

  • +
  • loss_fn (torch.nn.Module)

  • +
  • train_metrics (Dict[str, torch.Tensor])

  • +
  • val_metrics (Dict[str, torch.Tensor])

  • +
  • test_metrics (Dict[str, torch.Tensor])

  • +
+
+
+
+
+compute_metrics(y_hat, y, step_name)
+
+ +
+
+configure_optimizers()
+
+ +
+
+forward(x)
+
+
Parameters:
+

x (torch.Tensor)

+
+
+
+ +
+
+loss_func(y_hat, y)
+
+ +
+
+predict_step(batch, batch_idx, dataloader_idx=None)
+
+ +
+
+single_step(batch, batch_idx, step_name)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
  • step_name (str)

  • +
+
+
+
+ +
+
+test_step(batch, batch_idx)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
+
+
+
+ +
+
+training_step(batch, batch_idx)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
+
+
+
+ +
+
+validation_step(batch, batch_idx)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
+
+
+
+ +
+ +
+
+class ssl_tools.models.nets.simple.SimpleReconstructionNet(backbone, learning_rate=0.001, loss_fn=None)
+

Bases: lightning.LightningModule

+
+
Parameters:
+
    +
  • backbone (torch.nn.Module)

  • +
  • learning_rate (float)

  • +
  • loss_fn (torch.nn.Module)

  • +
+
+
+
+
+configure_optimizers()
+
+ +
+
+forward(x)
+
+
Parameters:
+

x (torch.Tensor)

+
+
+
+ +
+
+loss_func(y_hat, y)
+
+ +
+
+predict_step(batch, batch_idx, dataloader_idx=None)
+
+ +
+
+single_step(batch, batch_idx, step_name)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
  • step_name (str)

  • +
+
+
+
+ +
+
+test_step(batch, batch_idx)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
+
+
+
+ +
+
+training_step(batch, batch_idx)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
+
+
+
+ +
+
+validation_step(batch, batch_idx)
+
+
Parameters:
+
    +
  • batch (torch.Tensor)

  • +
  • batch_idx (int)

  • +
+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/transformer/index.html b/autoapi/ssl_tools/models/nets/transformer/index.html new file mode 100644 index 0000000..4f00097 --- /dev/null +++ b/autoapi/ssl_tools/models/nets/transformer/index.html @@ -0,0 +1,174 @@ + + + + + + + ssl_tools.models.nets.transformer — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.nets.transformer

+
+

Classes

+ + + + + + +

SimpleTransformer

+
+
+

Module Contents

+
+
+class ssl_tools.models.nets.transformer.SimpleTransformer(in_channels=6, dim_feedforward=60, num_classes=6, heads=2, num_layers=2, learning_rate=0.001)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • heads (int)

  • +
  • num_layers (int)

  • +
  • learning_rate (float)

  • +
+
+
+
+
+configure_optimizers()
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/models/nets/wisenet/index.html b/autoapi/ssl_tools/models/nets/wisenet/index.html index f5f2eb4..dfa711d 100644 --- a/autoapi/ssl_tools/models/nets/wisenet/index.html +++ b/autoapi/ssl_tools/models/nets/wisenet/index.html @@ -7,7 +7,7 @@ ssl_tools.models.nets.wisenet — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.models.utils

+
+

Classes

+ + + + + + + + + + + + + + + +

RandomDataModule

RandomDataset

ShapePrinter

ZeroPadder2D

+
+
+

Module Contents

+
+
+class ssl_tools.models.utils.RandomDataModule(num_samples, num_classes, input_shape, transforms=None, batch_size=1)
+

Bases: lightning.LightningDataModule

+
+
Parameters:
+
    +
  • transforms (list)

  • +
  • batch_size (int)

  • +
+
+
+
+
+train_dataloader()
+
+ +
+ +
+
+class ssl_tools.models.utils.RandomDataset(num_samples=64, num_classes=6, input_shape=(6, 60), transforms=None)
+
+
Parameters:
+
    +
  • num_samples (int)

  • +
  • num_classes (int)

  • +
  • input_shape (tuple)

  • +
  • transforms (list)

  • +
+
+
+
+
+__getitem__(idx)
+
+ +
+
+__len__()
+
+ +
+ +
+
+class ssl_tools.models.utils.ShapePrinter(name='')
+

Bases: torch.nn.Module

+
+
+forward(x)
+
+ +
+ +
+
+class ssl_tools.models.utils.ZeroPadder2D(pad_at, padding_size)
+

Bases: torch.nn.Module

+
+
Parameters:
+
    +
  • pad_at (List[int])

  • +
  • padding_size (int)

  • +
+
+
+
+
+__repr__()
+
+
Return type:
+

str

+
+
+
+ +
+
+__str__()
+
+
Return type:
+

str

+
+
+
+ +
+
+forward(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/base/index.html b/autoapi/ssl_tools/pipelines/base/index.html new file mode 100644 index 0000000..68a899e --- /dev/null +++ b/autoapi/ssl_tools/pipelines/base/index.html @@ -0,0 +1,172 @@ + + + + + + + ssl_tools.pipelines.base — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.base

+
+

Classes

+ + + + + + +

Pipeline

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.base.Pipeline
+

Bases: lightning.pytorch.core.mixins.HyperparametersMixin

+
+
+__call__()
+
+ +
+
+abstract run()
+
+
Return type:
+

Any

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/cli/index.html b/autoapi/ssl_tools/pipelines/cli/index.html new file mode 100644 index 0000000..cc5e292 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/cli/index.html @@ -0,0 +1,180 @@ + + + + + + + ssl_tools.pipelines.cli — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.cli

+
+

Functions

+ + + + + + + + + +

auto_main(commands[, print_args])

get_parser(commands)

+
+
+

Module Contents

+
+
+ssl_tools.pipelines.cli.auto_main(commands, print_args=False)
+
+
Parameters:
+
+
+
Return type:
+

Any

+
+
+
+ +
+
+ssl_tools.pipelines.cli.get_parser(commands)
+
+
Parameters:
+

commands (Dict[str, ssl_tools.pipelines.base.Pipeline] | ssl_tools.pipelines.base.Pipeline)

+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index.html b/autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index.html new file mode 100644 index 0000000..4f1901e --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/conv1d_conss/index.html @@ -0,0 +1,232 @@ + + + + + + + ssl_tools.pipelines.har_classification.conv1d_conss — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.conv1d_conss

+
+

Attributes

+ + + + + + +

experiment

+
+
+

Classes

+ + + + + + + + + + + + +

PartialEmbeddingEvaluator

PartialEmbeddingEvaluatorCallback

Simple1DConvNetFineTune2

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.conv1d_conss.PartialEmbeddingEvaluator(experiment_name, model, data_module, trainer, **kwargs)
+

Bases: evaluator.EmbeddingEvaluator

+
+
Parameters:
+
    +
  • experiment_name (str)

  • +
  • trainer (lightning.Trainer)

  • +
+
+
+
+
+run()
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.conv1d_conss.PartialEmbeddingEvaluatorCallback(experiment_name, frequency=1, **partal_embedding_evaluator_kwargs)
+

Bases: lightning.pytorch.callbacks.Callback

+
+
Parameters:
+
    +
  • experiment_name (str)

  • +
  • frequency (int)

  • +
+
+
+
+
+on_validation_end(trainer, pl_module)
+
+
Parameters:
+

trainer (lightning.Trainer)

+
+
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.conv1d_conss.Simple1DConvNetFineTune2
+

Bases: simple1Dconv_classifier.Simple1DConvNetFineTune

+
+
+get_callbacks()
+
+
Return type:
+

List[lightning.pytorch.callbacks.Callback]

+
+
+
+ +
+ +
+
+ssl_tools.pipelines.har_classification.conv1d_conss.experiment
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/cpc/index.html b/autoapi/ssl_tools/pipelines/har_classification/cpc/index.html new file mode 100644 index 0000000..05049ae --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/cpc/index.html @@ -0,0 +1,321 @@ + + + + + + + ssl_tools.pipelines.har_classification.cpc — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.cpc

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

CPCFineTune

Train a model using Lightning framework.

CPCPreTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.cpc.CPCFineTune(data, encoding_size=128, num_classes=6, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • encoding_size (int)

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.cpc.CPCPreTrain(data, encoding_size=128, in_channel=6, window_size=4, pad_length=False, num_classes=6, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str)

  • +
  • encoding_size (int)

  • +
  • in_channel (int)

  • +
  • window_size (int)

  • +
  • pad_length (bool)

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.cpc.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/evaluator/index.html b/autoapi/ssl_tools/pipelines/har_classification/evaluator/index.html new file mode 100644 index 0000000..24cff5b --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/evaluator/index.html @@ -0,0 +1,583 @@ + + + + + + + ssl_tools.pipelines.har_classification.evaluator — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.evaluator

+
+

Attributes

+ + + + + + + + + +

options

transforms_map

+
+
+

Classes

+ + + + + + + + + + + + + + + +

CSVGenerator

EmbeddingEvaluator

EvaluateAll

HAREmbeddingEvaluator

+
+
+

Functions

+ + + + + + + + + + + + + + + + + + +

full_dataset_from_dataloader(dataloader)

generate_embeddings(model, dataloader, trainer)

get_full_data_split(data_module, stage)

get_split_dataloader(stage, data_module)

run_evaluator_wrapper(evaluator)

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.evaluator.CSVGenerator(experiments, log_dir='./mlruns', results_file='results.csv')
+

Bases: ssl_tools.pipelines.base.Pipeline

+
+
Parameters:
+
    +
  • experiments (str | List[str])

  • +
  • log_dir (str)

  • +
  • results_file (str)

  • +
+
+
+
+
+property client
+
+ +
+
+run()
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.evaluator.EmbeddingEvaluator(experiment_name, registered_model_name, registered_model_tags=None, experiment_tags=None, n_classes=7, run_name=None, accelerator='cpu', devices=1, num_nodes=1, num_workers=None, strategy='auto', batch_size=1, limit_predict_batches=1.0, log_dir='./mlruns', results_file='results.csv', confusion_matrix_file='confusion_matrix.csv', confusion_matrix_image_file='confusion_matrix.png', tsne_plot_file='tsne_embeddings.png', embedding_file='embeddings.csv', predictions_file='predictions.csv', add_epoch_info=False)
+

Bases: ssl_tools.pipelines.base.Pipeline

+
+
Parameters:
+
    +
  • experiment_name (str)

  • +
  • registered_model_name (str)

  • +
  • registered_model_tags (Dict[str, str])

  • +
  • experiment_tags (Dict[str, str])

  • +
  • n_classes (int)

  • +
  • run_name (str)

  • +
  • accelerator (str)

  • +
  • devices (int)

  • +
  • num_nodes (int)

  • +
  • num_workers (int)

  • +
  • strategy (str)

  • +
  • batch_size (int)

  • +
  • limit_predict_batches (int | float)

  • +
  • log_dir (str)

  • +
  • results_file (str)

  • +
  • confusion_matrix_file (str)

  • +
  • confusion_matrix_image_file (str)

  • +
  • tsne_plot_file (str)

  • +
  • embedding_file (str)

  • +
  • predictions_file (str)

  • +
  • add_epoch_info (bool)

  • +
+
+
+
+
+_compute_classification_metrics(y_hat_logits, y, n_classes)
+
+
Parameters:
+
    +
  • y_hat_logits (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • n_classes (int)

  • +
+
+
Return type:
+

pandas.DataFrame

+
+
+
+ +
+
+_confusion_matrix(y_hat, y, n_classes)
+
+
Parameters:
+
    +
  • y_hat (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • n_classes (int)

  • +
+
+
Return type:
+

pandas.DataFrame

+
+
+
+ +
+
+_evaluate_embeddings(model, y_hat, y, n_classes, run_id, artifact_path)
+
+ +
+
+_plot_confusion_matrix(cm, classes)
+
+
Parameters:
+
    +
  • cm (pandas.DataFrame)

  • +
  • classes (List[int])

  • +
+
+
Return type:
+

plotly.graph_objects.Figure

+
+
+
+ +
+
+_plot_tnse_embeddings(embeddings, y, y_hat, n_components=2)
+
+
Parameters:
+
    +
  • embeddings (torch.Tensor)

  • +
  • y (torch.Tensor)

  • +
  • y_hat (torch.Tensor)

  • +
  • n_components (int)

  • +
+
+
Return type:
+

plotly.graph_objects.Figure

+
+
+
+ +
+
+property client
+
+ +
+
+evaluate_embeddings(model, data_module, trainer)
+
+ +
+
+evaluate_model_performance(model, data_module, trainer)
+
+ +
+
+get_callbacks()
+
+
Return type:
+

List[lightning.Callback]

+
+
+
+ +
+
+abstract get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_logger()
+
+
Return type:
+

lightning.pytorch.loggers.Logger

+
+
+
+ +
+
+get_trainer(logger, callbacks)
+
+
Parameters:
+
    +
  • logger (lightning.pytorch.loggers.Logger)

  • +
  • callbacks (List[lightning.Callback])

  • +
+
+
Return type:
+

lightning.Trainer

+
+
+
+ +
+
+load_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+predict(model, dataloader, trainer)
+
+ +
+
+run()
+
+ +
+
+run_task(model, data_module, trainer)
+
+
Parameters:
+
+
+
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.evaluator.EvaluateAll(root_dataset_dir, experiment_id, experiment_names, config_dir=None, log_dir='./mlruns', skip_existing=True, accelerator='cpu', devices=1, num_nodes=1, num_workers=None, strategy='auto', batch_size=1, use_ray=False, ray_address=None)
+

Bases: ssl_tools.pipelines.base.Pipeline

+
+
Parameters:
+
    +
  • root_dataset_dir (str)

  • +
  • experiment_id (str | List[str])

  • +
  • experiment_names (str | List[str])

  • +
  • config_dir (str)

  • +
  • log_dir (str)

  • +
  • skip_existing (bool)

  • +
  • accelerator (str)

  • +
  • devices (int)

  • +
  • num_nodes (int)

  • +
  • num_workers (int)

  • +
  • strategy (str)

  • +
  • batch_size (int)

  • +
  • use_ray (bool)

  • +
  • ray_address (str)

  • +
+
+
+
+
+property client
+
+ +
+
+filter_runs(runs)
+
+ +
+
+get_runs(experiment_ids, search_string='')
+
+
Parameters:
+

search_string (str)

+
+
Return type:
+

pandas.DataFrame

+
+
+
+ +
+
+locate_config(model_name)
+
+ +
+
+run()
+
+ +
+
+summarize(runs)
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.evaluator.HAREmbeddingEvaluator(data, transforms='identity', **kwargs)
+

Bases: EmbeddingEvaluator

+
+
Parameters:
+
    +
  • data (str)

  • +
  • transforms (str | List[str])

  • +
+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.full_dataset_from_dataloader(dataloader)
+
+
Parameters:
+

dataloader (torch.utils.data.DataLoader)

+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.generate_embeddings(model, dataloader, trainer)
+
+
Parameters:
+
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.get_full_data_split(data_module, stage)
+
+
Parameters:
+
    +
  • data_module (lightning.LightningDataModule)

  • +
  • stage (str)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.get_split_dataloader(stage, data_module)
+
+
Parameters:
+
    +
  • stage (str)

  • +
  • data_module (lightning.LightningDataModule)

  • +
+
+
Return type:
+

torch.utils.data.DataLoader

+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.options
+
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.run_evaluator_wrapper(evaluator)
+
+
Parameters:
+

evaluator (EmbeddingEvaluator)

+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.evaluator.transforms_map
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index.html b/autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index.html new file mode 100644 index 0000000..a743ae3 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/gru_encoder/index.html @@ -0,0 +1,357 @@ + + + + + + + ssl_tools.pipelines.har_classification.gru_encoder — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.gru_encoder

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + + + + +

GRUClassifier

GRUClassifierFineTune

Train a model using Lightning framework.

GRUClassifierTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.gru_encoder.GRUClassifier(hidden_size=100, in_channels=6, num_classes=6, encoding_size=100, num_layers=1, dropout=0.0, bidirectional=True)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • hidden_size (int)

  • +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • encoding_size (int)

  • +
  • num_layers (int)

  • +
  • dropout (float)

  • +
  • bidirectional (bool)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.gru_encoder.GRUClassifierFineTune(data, num_classes=6, encoding_size=128, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'GRU'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • encoding_size (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.gru_encoder.GRUClassifierTrain(data, hidden_size=100, in_channels=6, num_classes=6, encoding_size=100, num_layers=1, dropout=0.0, bidirectional=True, num_workers=None, transforms='identity', **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'GRU'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • hidden_size (int)

  • +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • encoding_size (int)

  • +
  • num_layers (int)

  • +
  • dropout (float)

  • +
  • bidirectional (bool)

  • +
  • num_workers (int)

  • +
  • transforms (str)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.gru_encoder.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/index.html b/autoapi/ssl_tools/pipelines/har_classification/index.html new file mode 100644 index 0000000..a8c16cd --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/index.html @@ -0,0 +1,158 @@ + + + + + + + ssl_tools.pipelines.har_classification — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/mlp/index.html b/autoapi/ssl_tools/pipelines/har_classification/mlp/index.html new file mode 100644 index 0000000..c1c0231 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/mlp/index.html @@ -0,0 +1,321 @@ + + + + + + + ssl_tools.pipelines.har_classification.mlp — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.mlp

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

MLPClassifierFineTune

Train a model using Lightning framework.

MLPClassifierTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.mlp.MLPClassifierFineTune(data, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.mlp.MLPClassifierTrain(data, input_size=360, hidden_size=64, num_hidden_layers=1, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • input_size (int)

  • +
  • hidden_size (int)

  • +
  • num_hidden_layers (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.mlp.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/scripts/evaluate_all/index.html b/autoapi/ssl_tools/pipelines/har_classification/scripts/evaluate_all/index.html new file mode 100644 index 0000000..75bc20e --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/scripts/evaluate_all/index.html @@ -0,0 +1,145 @@ + + + + + + + ssl_tools.pipelines.har_classification.scripts.evaluate_all — SSLTools documentation + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+
    +
  • + +
  • + View page source +
  • +
+
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.scripts.evaluate_all

+
+

Attributes

+ + + + + + +

options

+
+
+

Functions

+ + + + + + +

EvaluateAll(Pipeline)

+
+
+

Module Contents

+
+
+ssl_tools.pipelines.har_classification.scripts.evaluate_all.EvaluateAll(Pipeline)
+
+ +
+
+ssl_tools.pipelines.har_classification.scripts.evaluate_all.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index.html b/autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index.html new file mode 100644 index 0000000..7bc15b2 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/simple1Dconv_classifier/index.html @@ -0,0 +1,329 @@ + + + + + + + ssl_tools.pipelines.har_classification.simple1Dconv_classifier — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.simple1Dconv_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

Simple1DConvNetFineTune

Train a model using Lightning framework.

Simple1DConvNetTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.simple1Dconv_classifier.Simple1DConvNetFineTune(data, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'Simple1DConvNet'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.simple1Dconv_classifier.Simple1DConvNetTrain(data, input_shape=(6, 60), num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'Simple1DConvNet'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • input_shape (Tuple[int, int])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.simple1Dconv_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index.html b/autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index.html new file mode 100644 index 0000000..948d90d --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/simple2Dconv_classifier/index.html @@ -0,0 +1,329 @@ + + + + + + + ssl_tools.pipelines.har_classification.simple2Dconv_classifier — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.simple2Dconv_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

Simple2DConvNetFineTune

Train a model using Lightning framework.

Simple2DConvNetTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.simple2Dconv_classifier.Simple2DConvNetFineTune(data, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'Simple2DConvNet'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.simple2Dconv_classifier.Simple2DConvNetTrain(data, input_shape=(6, 1, 60), num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'Simple2DConvNet'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • input_shape (Tuple[int, int, int])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.simple2Dconv_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/tfc/index.html b/autoapi/ssl_tools/pipelines/har_classification/tfc/index.html new file mode 100644 index 0000000..fa31309 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/tfc/index.html @@ -0,0 +1,327 @@ + + + + + + + ssl_tools.pipelines.har_classification.tfc — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.tfc

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

TFCFineTune

Train a model using Lightning framework.

TFCTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.tfc.TFCFineTune(data, num_classes=6, num_workers=None, length_alignment=60, encoding_size=128, features_as_channels=True, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
  • length_alignment (int)

  • +
  • encoding_size (int)

  • +
  • features_as_channels (bool)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.tfc.TFCTrain(data, label='standard activity code', encoding_size=128, in_channels=6, length_alignment=60, use_cosine_similarity=True, temperature=0.5, features_as_channels=True, jitter_ratio=2, num_classes=6, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • label (str)

  • +
  • encoding_size (int)

  • +
  • in_channels (int)

  • +
  • length_alignment (int)

  • +
  • use_cosine_similarity (bool)

  • +
  • temperature (float)

  • +
  • features_as_channels (bool)

  • +
  • jitter_ratio (float)

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.tfc.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index.html b/autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index.html new file mode 100644 index 0000000..043e469 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/tfc_head_classifier/index.html @@ -0,0 +1,329 @@ + + + + + + + ssl_tools.pipelines.har_classification.tfc_head_classifier — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.tfc_head_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

TFCHeadClassifierFineTune

Train a model using Lightning framework.

TFCHeadClassifierTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.tfc_head_classifier.TFCHeadClassifierFineTune(data, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'TFCPredictionHead'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.tfc_head_classifier.TFCHeadClassifierTrain(data, input_size=360, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'TFCPredictionHead'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • input_size (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.tfc_head_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/tnc/index.html b/autoapi/ssl_tools/pipelines/har_classification/tnc/index.html new file mode 100644 index 0000000..1d82829 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/tnc/index.html @@ -0,0 +1,324 @@ + + + + + + + ssl_tools.pipelines.har_classification.tnc — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.tnc

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

TNCFineTune

Train a model using Lightning framework.

TNCPreTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.tnc.TNCFineTune(data, num_classes=6, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.tnc.TNCPreTrain(data, encoding_size=10, in_channel=6, window_size=60, mc_sample_size=20, w=0.05, significance_level=0.01, repeat=5, pad_length=True, num_classes=6, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str)

  • +
  • encoding_size (int)

  • +
  • in_channel (int)

  • +
  • window_size (int)

  • +
  • mc_sample_size (int)

  • +
  • w (float)

  • +
  • significance_level (float)

  • +
  • repeat (int)

  • +
  • pad_length (bool)

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.tnc.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index.html b/autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index.html new file mode 100644 index 0000000..fb6c3a3 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/tnc_head_classifier/index.html @@ -0,0 +1,329 @@ + + + + + + + ssl_tools.pipelines.har_classification.tnc_head_classifier — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.tnc_head_classifier

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

TNCHeadClassifierFineTune

Train a model using Lightning framework.

TNCHeadClassifierTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.tnc_head_classifier.TNCHeadClassifierFineTune(data, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'TNCPredictionHead'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.tnc_head_classifier.TNCHeadClassifierTrain(data, input_size=360, num_classes=6, transforms='identity', num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'TNCPredictionHead'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • input_size (int)

  • +
  • num_classes (int)

  • +
  • transforms (str)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.tnc_head_classifier.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/transformer/index.html b/autoapi/ssl_tools/pipelines/har_classification/transformer/index.html new file mode 100644 index 0000000..efed92b --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/transformer/index.html @@ -0,0 +1,324 @@ + + + + + + + ssl_tools.pipelines.har_classification.transformer — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.transformer

+
+

Attributes

+ + + + + + +

options

+
+
+

Classes

+ + + + + + + + + +

SimpleTransformerFineTune

Train a model using Lightning framework.

SimpleTransformerTrain

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.transformer.SimpleTransformerFineTune(data, num_classes=6, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • num_classes (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.transformer.SimpleTransformerTrain(data, in_channels=6, dim_feedforward=60, num_classes=6, heads=1, num_layers=1, num_workers=None, **kwargs)
+

Bases: ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+MODEL = 'Transformer'
+
+ +
+
+get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • data (str | List[str])

  • +
  • in_channels (int)

  • +
  • num_classes (int)

  • +
  • heads (int)

  • +
  • num_layers (int)

  • +
  • num_workers (int)

  • +
+
+
+
+ +
+
+ssl_tools.pipelines.har_classification.transformer.options
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/har_classification/utils/index.html b/autoapi/ssl_tools/pipelines/har_classification/utils/index.html new file mode 100644 index 0000000..1c28632 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/har_classification/utils/index.html @@ -0,0 +1,329 @@ + + + + + + + ssl_tools.pipelines.har_classification.utils — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.har_classification.utils

+
+

Classes

+ + + + + + + + + + + + + + + + + + + + + +

DimensionAdder

FFT

Flatten

PredictionHeadClassifier

Spectrogram

SwapAxes

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.har_classification.utils.DimensionAdder(dim)
+
+
Parameters:
+

dim (int)

+
+
+
+
+__call__(x)
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.utils.FFT(absolute=True, centered=False)
+
+
Parameters:
+
    +
  • absolute (bool)

  • +
  • centered (bool)

  • +
+
+
+
+
+__call__(x)
+

Aplly FFT to the input signal. It apply the FFT into each channel +of the input signal.

+
+

Parameters

+
+
xnp.ndarray

An array with shape (n_channels, n_samples) containing the input

+
+
+
+
+

Returns

+
+
np.ndarray

The FFT of the input signal. The shape of the output is +(n_channels, n_samples) if absolute is False, and +(n_channels, n_samples//2) if absolute is True.

+
+
+
+
+
Parameters:
+

x (numpy.ndarray)

+
+
Return type:
+

numpy.ndarray

+
+
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.utils.Flatten
+
+
+__call__(x)
+

Flatten the input signal. It apply the flatten into each channel +of the input signal.

+
+

Parameters

+
+
xnp.ndarray

An array with shape (n_channels, n_samples) containing the input

+
+
+
+
+

Returns

+
+
np.ndarray

The flatten of the input signal. The shape of the output is +(n_channels, n_samples).

+
+
+
+
+
Parameters:
+

x (numpy.ndarray)

+
+
Return type:
+

numpy.ndarray

+
+
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.utils.PredictionHeadClassifier(prediction_head, num_classes=6)
+

Bases: ssl_tools.models.nets.simple.SimpleClassificationNet

+
+
Parameters:
+
    +
  • prediction_head (torch.nn.Module)

  • +
  • num_classes (int)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.har_classification.utils.Spectrogram(fs=20, nperseg=16, noverlap=8, nfft=16)
+
+
+__call__(x)
+

Aplly Spectrogram to the input signal. It apply the Spectrogram into each channel +of the input signal.

+
+

Parameters

+
+
xnp.ndarray

An array with shape (n_channels, n_samples) containing the input

+
+
+
+
+

Returns

+
+
np.ndarray

The Spectrogram of the input signal. The shape of the output is +(n_channels, n_samples) if absolute is False, and +(n_channels, n_samples//2) if absolute is True.

+
+
+
+
+
Parameters:
+

x (numpy.ndarray)

+
+
Return type:
+

numpy.ndarray

+
+
+
+ +
+ +
+
+class ssl_tools.pipelines.har_classification.utils.SwapAxes(axis1, axis2)
+
+
Parameters:
+
    +
  • axis1 (int)

  • +
  • axis2 (int)

  • +
+
+
+
+
+__call__(x)
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/index.html b/autoapi/ssl_tools/pipelines/index.html new file mode 100644 index 0000000..86ef327 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/index.html @@ -0,0 +1,156 @@ + + + + + + + ssl_tools.pipelines — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+ + +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/mlflow_train/index.html b/autoapi/ssl_tools/pipelines/mlflow_train/index.html new file mode 100644 index 0000000..8441878 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/mlflow_train/index.html @@ -0,0 +1,354 @@ + + + + + + + ssl_tools.pipelines.mlflow_train — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.mlflow_train

+
+

Classes

+ + + + + + + + + +

LightningFineTuneMLFlow

Train a model using Lightning framework.

LightningTrainMLFlow

Train a model using Lightning framework.

+
+
+

Module Contents

+
+
+class ssl_tools.pipelines.mlflow_train.LightningFineTuneMLFlow(registered_model_name, registered_model_tags=None, update_backbone=False, **kwargs)
+

Bases: LightningTrainMLFlow

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+property client
+
+ +
+
+load_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • registered_model_name (str)

  • +
  • registered_model_tags (Dict[str, str])

  • +
  • update_backbone (bool)

  • +
+
+
+
+ +
+
+class ssl_tools.pipelines.mlflow_train.LightningTrainMLFlow(experiment_name, model_name, run_name=None, accelerator='cpu', devices=1, num_nodes=1, strategy='auto', max_epochs=1, batch_size=1, limit_train_batches=1.0, limit_val_batches=1.0, checkpoint_monitor_metric=None, checkpoint_monitor_mode='min', patience=None, log_dir='./mlruns', model_tags=None)
+

Bases: ssl_tools.pipelines.base.Pipeline

+

Train a model using Lightning framework.

+
+

Parameters

+
+
experiment_namestr

Name of the experiment.

+
+
model_namestr

Name of the model.

+
+
dataset_namestr

Name of the dataset.

+
+
run_namestr, optional

The name of the run, by default None

+
+
acceleratorstr, optional

The accelerator to use, by default “cpu”

+
+
devicesint, optional

Number of accelerators to use, by default 1

+
+
num_nodesint, optional

Number of nodes, by default 1

+
+
strategystr, optional

Training strategy, by default “auto”

+
+
max_epochsint, optional

Maximium number of epochs, by default 1

+
+
batch_sizeint, optional

Batch size, by default 1

+
+
limit_train_batchesint | float, optional

Limit the number of batches to train, by default 1.0

+
+
limit_val_batchesint | float, optional

Limit the number of batches to test, by default 1.0

+
+
checkpoint_monitor_metricstr, optional

The metric to monitor for checkpointing, by default None

+
+
checkpoint_monitor_modestr, optional

The mode for checkpointing, by default “min”

+
+
patienceint, optional

The patience for early stopping, by default None

+
+
log_dirstr, optional

Location where logs will be saved, by default “./runs”

+
+
+
+
+get_callbacks()
+
+
Return type:
+

List[lightning.Callback]

+
+
+
+ +
+
+abstract get_data_module()
+
+
Return type:
+

lightning.LightningDataModule

+
+
+
+ +
+
+get_logger()
+
+
Return type:
+

lightning.pytorch.loggers.Logger

+
+
+
+ +
+
+abstract get_model()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+get_trainer(logger, callacks)
+
+
Parameters:
+
    +
  • logger (lightning.pytorch.loggers.Logger)

  • +
  • callacks (List[lightning.Callback])

  • +
+
+
Return type:
+

lightning.Trainer

+
+
+
+ +
+
+run()
+
+
Return type:
+

lightning.LightningModule

+
+
+
+ +
+
+
Parameters:
+
    +
  • experiment_name (str)

  • +
  • model_name (str)

  • +
  • run_name (str)

  • +
  • accelerator (str)

  • +
  • devices (int)

  • +
  • num_nodes (int)

  • +
  • strategy (str)

  • +
  • max_epochs (int)

  • +
  • batch_size (int)

  • +
  • limit_train_batches (int | float)

  • +
  • limit_val_batches (int | float)

  • +
  • checkpoint_monitor_metric (str)

  • +
  • checkpoint_monitor_mode (str)

  • +
  • patience (int)

  • +
  • log_dir (str)

  • +
  • model_tags (Dict[str, str])

  • +
+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/pipelines/utils/index.html b/autoapi/ssl_tools/pipelines/utils/index.html new file mode 100644 index 0000000..fd217b0 --- /dev/null +++ b/autoapi/ssl_tools/pipelines/utils/index.html @@ -0,0 +1,189 @@ + + + + + + + ssl_tools.pipelines.utils — SSLTools documentation + + + + + + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.pipelines.utils

+
+

Functions

+ + + + + + + + + +

load_model_mlflow(client, registered_model_name[, ...])

tags2str(d)

Convert a dictionary of tags to a search string compatible with MLflow's search_model_versions method.

+
+
+

Module Contents

+
+
+ssl_tools.pipelines.utils.load_model_mlflow(client, registered_model_name, registered_model_tags=None)
+
+
Parameters:
+
    +
  • client (mlflow.client.MlflowClient)

  • +
  • registered_model_name (str)

  • +
  • registered_model_tags (Dict[str, str])

  • +
+
+
Return type:
+

Dict[lightning.LightningModule, Dict[str, str]]

+
+
+
+ +
+
+ssl_tools.pipelines.utils.tags2str(d)
+

Convert a dictionary of tags to a search string compatible with MLflow’s search_model_versions method.

+

Parameters: +- d: A dictionary containing tags where keys are tag names and values are tag values.

+

Returns: +- search_str: A search string formatted for MLflow’s search_model_versions method.

+
+
Parameters:
+

d (Dict[str, str])

+
+
Return type:
+

str

+
+
+
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/transforms/index.html b/autoapi/ssl_tools/transforms/index.html index 87b0857..860181b 100644 --- a/autoapi/ssl_tools/transforms/index.html +++ b/autoapi/ssl_tools/transforms/index.html @@ -7,7 +7,7 @@ ssl_tools.transforms — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.transforms.pad

+
+

Classes

+ + + + + + +

ZeroPaddingBetween

+
+
+

Module Contents

+
+
+class ssl_tools.transforms.pad.ZeroPaddingBetween(pad_every=3, padding_size=2, discard_last=True)
+
+
Parameters:
+
    +
  • pad_every (int)

  • +
  • padding_size (int)

  • +
  • discard_last (bool)

  • +
+
+
+
+
+__call__(x)
+
+
Parameters:
+

x (numpy.ndarray)

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/transforms/signal_1d/index.html b/autoapi/ssl_tools/transforms/signal_1d/index.html index 2d572af..ce8cc73 100644 --- a/autoapi/ssl_tools/transforms/signal_1d/index.html +++ b/autoapi/ssl_tools/transforms/signal_1d/index.html @@ -7,7 +7,7 @@ ssl_tools.transforms.signal_1d — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.transforms.time_1d_full

+
+

Classes

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Composer

ConcatComposer

Identity

MagnitudeWarp

Permutate

Rotate

Scale

TimeWarp

WindowSlice

WindowWarp

+
+
+

Module Contents

+
+
+class ssl_tools.transforms.time_1d_full.Composer(transforms)
+
+
Parameters:
+

transforms (List[Callable])

+
+
+
+
+__call__(dataset, labels=None)
+
+
Parameters:
+
    +
  • dataset (numpy.ndarray)

  • +
  • labels (numpy.ndarray)

  • +
+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.ConcatComposer(transforms, axis=0)
+
+
Parameters:
+
    +
  • transforms (List[Callable])

  • +
  • axis (int)

  • +
+
+
+
+
+__call__(dataset, labels=None)
+
+
Parameters:
+
    +
  • dataset (numpy.ndarray)

  • +
  • labels (numpy.ndarray)

  • +
+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.Identity
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.MagnitudeWarp(sigma=0.2, knot=4)
+
+
Parameters:
+
    +
  • sigma (float)

  • +
  • knot (int)

  • +
+
+
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.Permutate(max_segments=5, segment_mode='equal')
+
+
Parameters:
+
    +
  • max_segments (int)

  • +
  • segment_mode (str)

  • +
+
+
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.Rotate
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.Scale(sigma=0.1)
+
+
Parameters:
+

sigma (float)

+
+
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.TimeWarp(sigma=0.2, knot=4)
+
+
Parameters:
+
    +
  • sigma (float)

  • +
  • knot (int)

  • +
+
+
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.WindowSlice(reduce_ratio=0.9)
+
+
Parameters:
+

reduce_ratio (float)

+
+
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+class ssl_tools.transforms.time_1d_full.WindowWarp(window_ratio=0.1, scales=[0.5, 2.0])
+
+
+__call__(dataset)
+
+
Parameters:
+

dataset (numpy.ndarray)

+
+
+
+ +
+
+__str__()
+

Return str(self).

+
+
Return type:
+

str

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/transforms/utils/index.html b/autoapi/ssl_tools/transforms/utils/index.html index 86ae4e2..7bf77cd 100644 --- a/autoapi/ssl_tools/transforms/utils/index.html +++ b/autoapi/ssl_tools/transforms/utils/index.html @@ -7,7 +7,7 @@ ssl_tools.transforms.utils — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.transforms.window

+
+

Classes

+ + + + + + +

Windowize

+
+
+

Module Contents

+
+
+class ssl_tools.transforms.window.Windowize(time_segments=15, stride=None)
+
+
Parameters:
+
    +
  • time_segments (int)

  • +
  • stride (int)

  • +
+
+
+
+
+__call__(x)
+
+
Parameters:
+

x (numpy.ndarray)

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/utils/configurable/index.html b/autoapi/ssl_tools/utils/configurable/index.html index 32b833e..1a035ab 100644 --- a/autoapi/ssl_tools/utils/configurable/index.html +++ b/autoapi/ssl_tools/utils/configurable/index.html @@ -7,7 +7,7 @@ ssl_tools.utils.configurable — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+ +
+
+
+
+ +
+

ssl_tools.utils.layers

+
+

Classes

+ + + + + + +

OutputLoggerCallback

+
+
+

Module Contents

+
+
+class ssl_tools.utils.layers.OutputLoggerCallback(layers)
+

Bases: lightning.Callback

+
+
Parameters:
+

layers (List[str])

+
+
+
+
+count(module, input, output, layer_name)
+
+
Parameters:
+

layer_name (str)

+
+
+
+ +
+
+setup(trainer, pl_module, stage)
+
+
Parameters:
+
    +
  • trainer (lightning.Trainer)

  • +
  • pl_module (lightning.LightningModule)

  • +
  • stage (str)

  • +
+
+
+
+ +
+
+teardown(trainer, pl_module, stage)
+
+
Parameters:
+
    +
  • trainer (lightning.Trainer)

  • +
  • pl_module (lightning.LightningModule)

  • +
  • stage (str)

  • +
+
+
Return type:
+

None

+
+
+
+ +
+ +
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/autoapi/ssl_tools/utils/types/index.html b/autoapi/ssl_tools/utils/types/index.html index 2b30c75..7ac63e7 100644 --- a/autoapi/ssl_tools/utils/types/index.html +++ b/autoapi/ssl_tools/utils/types/index.html @@ -7,7 +7,7 @@ ssl_tools.utils.types — SSLTools documentation - + + + + + + + + + + + + + + + + +
+ + +
+ +
+
+
+
    +
  • + +
  • + View page source +
  • +
+
+
+
+
+ +
+

5. Training an Anomaly Detection Model for Covid Anomaly Detection

+

In this tutorial, we will train an anomaly detection model using a simple LSTM-AutoEncoder model. Data can be obtained from this link. This is a processed version of data from original Stanford dataset-Phase 2. The overall pre-processing pipeline used is illustrated in Figure below.

+

preprocessing

+

Data was aquired from diferent sources (Germin, FitBit, Apple Watch) and pre-processed to have a common format. In this form, data has two columns: heart rate and number of user steps in last minute. Then the processing pipeline was applied to the data. The pipeline is composed of the following steps: 1. Once data was standardized, the resting heart rate was extracted (Resting Heart Rate Extractor, in Figure). This process takes as input min_minutes_rest that is the number of minutes +that the user has to be at rest to consider the heart rate as resting. This variable looks at user steps and, when user steps is 0 for min_minutes_rest minutes, the heart rate is considered as resting. At the end of this process, we will have a new dataframe with: the date and the resting heart rate of the last minute. 2. The second step is adding labels.

+
+
[1]:
+
+
+
import pandas as pd
+from ssl_tools.data.data_modules.covid_anomaly import CovidUserAnomalyDataModule
+from ssl_tools.utils.data import get_full_data_split
+from ssl_tools.models.nets.lstm_ae import LSTMAutoencoder
+import lightning as L
+import torch
+import numpy as np
+from torchmetrics import MeanSquaredError
+
+
+
+
+
[2]:
+
+
+
# Read CSV data
+data_path = "/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv"
+df = pd.read_csv(data_path)
+df
+
+
+
+
+
[2]:
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
datetimeRHR-0RHR-1RHR-2RHR-3RHR-4RHR-5RHR-6RHR-7RHR-8...RHR-10RHR-11RHR-12RHR-13RHR-14RHR-15anomalybaselinelabelparticipant_id
02027-01-14 21:00:001.1701750.653752-0.392374-1.431553-2.129013-2.755962-3.681322-4.674443-5.668570...-6.937363-7.102118-6.975790-6.554774-6.112156-5.396099FalseTruenormalP110465
12027-01-15 05:00:00-5.668570-6.373289-6.937363-7.102118-6.975790-6.554774-6.112156-5.396099-4.415848...-2.656756-1.305630-0.0727561.0461951.5304671.829053FalseFalsenormalP110465
22027-01-15 13:00:00-4.415848-3.467073-2.656756-1.305630-0.0727561.0461951.5304671.8290531.223064...-0.424000-1.145581-1.355121-2.321206-3.124961-3.928738FalseFalsenormalP110465
32027-01-15 21:00:001.2230640.472444-0.424000-1.145581-1.355121-2.321206-3.124961-3.928738-4.802627...-6.067744-5.460156-4.671143-3.408943-2.237883-1.187843FalseFalsenormalP110465
42027-01-16 05:00:00-4.802627-5.831013-6.067744-5.460156-4.671143-3.408943-2.237883-1.187843-0.062360...2.2669443.7944654.6257454.8277564.7200004.677464FalseFalsenormalP110465
..................................................................
317322024-12-13 00:00:00-0.180702-0.499793-0.749829-0.868485-0.966754-1.004670-0.888210-0.580762-0.467943...0.0920000.3478400.6363950.9581951.1705141.301841FalseFalserecoveredP992022
317332024-12-13 08:00:00-0.467943-0.1627400.0920000.3478400.6363950.9581951.1705141.3018411.477526...1.6603441.6566001.6856521.7472521.7673291.793616FalseFalserecoveredP992022
317342024-12-13 16:00:001.4775261.6573211.6603441.6566001.6856521.7472521.7673291.7936161.728615...1.5098331.3807491.2637441.1399971.0242050.946663FalseFalserecoveredP992022
317352024-12-14 00:00:001.7286151.6162651.5098331.3807491.2637441.1399971.0242050.9466631.136868...1.6421531.9093812.1144392.2822382.4536912.587843FalseFalserecoveredP992022
317362024-12-14 08:00:001.1368681.3804181.6421531.9093812.1144392.2822382.4536912.5878432.437232...2.3598402.1734002.0981401.9676691.7845121.561848FalseFalserecoveredP992022
+

31737 rows × 21 columns

+
+
+
+
[3]:
+
+
+
dm = CovidUserAnomalyDataModule(
+    data_path,
+    participants=["P992022"],
+    batch_size=32,
+    num_workers=0,
+    reshape=(16, 1),
+)
+dm
+
+
+
+
+
[3]:
+
+
+
+
+CovidUserAnomalyDataModule (Data=/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv, 1 participant selected)
+
+
+
+
[4]:
+
+
+
model = LSTMAutoencoder(input_shape=(16, 1))
+model
+
+
+
+
+
[4]:
+
+
+
+
+LSTMAutoencoder(
+  (backbone): _LSTMAutoEncoder(
+    (lstm1): LSTM(1, 128, batch_first=True)
+    (lstm2): LSTM(128, 64, batch_first=True)
+    (repeat_vector): Linear(in_features=64, out_features=1024, bias=True)
+    (lstm3): LSTM(64, 64, batch_first=True)
+    (lstm4): LSTM(64, 128, batch_first=True)
+    (time_distributed): Linear(in_features=128, out_features=1, bias=True)
+  )
+  (loss_fn): MSELoss()
+)
+
+
+
+
[5]:
+
+
+
trainer = L.Trainer(max_epochs=100, devices=1, accelerator="cpu")
+trainer
+
+
+
+
+
+
+
+
+GPU available: True (cuda), used: False
+TPU available: False, using: 0 TPU cores
+IPU available: False, using: 0 IPUs
+HPU available: False, using: 0 HPUs
+/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
+
+
+
+
[5]:
+
+
+
+
+<lightning.pytorch.trainer.trainer.Trainer at 0x7f98860215a0>
+
+
+
+
[6]:
+
+
+
trainer.fit(model, dm)
+
+
+
+
+
+
+
+
+
+  | Name     | Type             | Params
+----------------------------------------------
+0 | backbone | _LSTMAutoEncoder | 316 K
+1 | loss_fn  | MSELoss          | 0
+----------------------------------------------
+316 K     Trainable params
+0         Non-trainable params
+316 K     Total params
+1.264     Total estimated model params size (MB)
+
+
+
+
+
+
+
+
+
+
+
+
+
+/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
+/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
+/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+`Trainer.fit` stopped: `max_epochs=100` reached.
+
+
+
+
[7]:
+
+
+
def compute_losses(y, y_pred, loss_fn):
+    losses = []
+    for _y, _y_pred in zip(y, y_pred):
+        loss = loss_fn(_y, _y_pred)
+        loss = loss.detach().numpy().item()
+        losses.append(loss)
+    return losses
+
+
+
+
+

Predict

+
+
[12]:
+
+
+
dm.setup("fit")
+train_dataset = dm.train_dataloader().dataset
+
+dm.setup("test")
+test_dataset = dm.test_dataloader().dataset
+
+
+
+
+

Defining Anomaly Threshold

+
+
[13]:
+
+
+
x_train = torch.stack([torch.Tensor(x) for x, y in train_dataset])
+y_train = np.array([y for x, y in train_dataset])
+x_train_hat = model(x_train)
+
+
+
+
+
[14]:
+
+
+
mse = MeanSquaredError()
+losses = compute_losses(x_train, x_train_hat, mse)
+
+anomaly_threshold = max(losses)
+anomaly_threshold
+
+
+
+
+
[14]:
+
+
+
+
+0.3742748498916626
+
+
+
+
+

Predicting on Test set

+
+
[15]:
+
+
+
x_test = torch.stack([torch.Tensor(x) for x, y in test_dataset])
+y_test = np.array([y for x, y in test_dataset])
+
+x_test_hat = model(x_test)
+
+
+
+
+
[16]:
+
+
+
mse = MeanSquaredError()
+losses = compute_losses(x_test, x_test_hat, mse)
+
+y_test_hat = [1 if loss > anomaly_threshold else 0 for loss in losses]
+
+
+
+
+
[17]:
+
+
+
results_dataframe = pd.DataFrame(
+    {
+        "true": y_test,
+        "predicted": y_test_hat,
+        "loss": losses,
+        "anomaly_threshold": anomaly_threshold,
+    }
+)
+
+results_dataframe
+
+
+
+
+
[17]:
+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
truepredictedlossanomaly_threshold
0000.0237000.374275
1000.0914130.374275
2000.0542990.374275
3000.0074860.374275
4000.0246010.374275
...............
89100.0898330.374275
90100.0515620.374275
91100.1327480.374275
92100.1586100.374275
93100.0255220.374275
+

94 rows × 4 columns

+
+
+
+
[18]:
+
+
+
from sklearn.metrics import f1_score, recall_score, balanced_accuracy_score, roc_auc_score
+
+# Extract true and predicted labels from the results_dataframe
+true_labels = results_dataframe['true']
+predicted_labels = results_dataframe['predicted']
+
+# Calculate the F1-score
+f1 = f1_score(true_labels, predicted_labels)
+
+# Calculate the recall
+recall = recall_score(true_labels, predicted_labels)
+
+# Calculate the balanced accuracy
+balanced_acc = balanced_accuracy_score(true_labels, predicted_labels)
+
+# Calculate the ROC AUC
+roc_auc = roc_auc_score(true_labels, predicted_labels)
+
+# Print the results
+print("F1-score:", f1)
+print("Recall:", recall)
+print("Balanced Accuracy:", balanced_acc)
+print("ROC AUC:", roc_auc)
+
+
+
+
+
+
+
+
+F1-score: 0.0
+Recall: 0.0
+Balanced Accuracy: 0.5
+ROC AUC: 0.5
+
+
+
+
[19]:
+
+
+
import numpy as np
+from sklearn.metrics import confusion_matrix
+
+import matplotlib.pyplot as plt
+
+# Get the true and predicted labels from the results_dataframe
+true_labels = results_dataframe['true']
+predicted_labels = results_dataframe['predicted']
+
+# Compute the confusion matrix
+cm = confusion_matrix(true_labels, predicted_labels)
+
+# Define the class labels
+class_labels = ['Normal', 'Anomaly']
+
+# Plot the confusion matrix
+plt.figure(figsize=(8, 6))
+plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
+plt.title('Confusion Matrix')
+plt.colorbar()
+tick_marks = np.arange(len(class_labels))
+plt.xticks(tick_marks, class_labels, rotation=45)
+plt.yticks(tick_marks, class_labels)
+plt.xlabel('Predicted Label')
+plt.ylabel('True Label')
+
+# Add the values to the confusion matrix plot
+thresh = cm.max() / 2.
+for i in range(cm.shape[0]):
+    for j in range(cm.shape[1]):
+        plt.text(j, i, format(cm[i, j], 'd'),
+                 horizontalalignment="center",
+                 color="white" if cm[i, j] > thresh else "black")
+
+plt.tight_layout()
+plt.show()
+
+
+
+
+
+
+
+../_images/notebooks_05_covid_anomaly_detection_19_0.png +
+
+
+
[ ]:
+
+
+

+
+
+
+
+
+
+ + +
+
+ +
+
+
+
+ + + + \ No newline at end of file diff --git a/notebooks/05_covid_anomaly_detection.ipynb b/notebooks/05_covid_anomaly_detection.ipynb new file mode 100644 index 0000000..a7091e7 --- /dev/null +++ b/notebooks/05_covid_anomaly_detection.ipynb @@ -0,0 +1,2372 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Training an Anomaly Detection Model for Covid Anomaly Detection\n", + "\n", + "In this tutorial, we will train an anomaly detection model using a simple [LSTM-AutoEncoder model](https://www.medrxiv.org/content/10.1101/2021.01.08.21249474v1).\n", + "Data can be obtained from [this link](https://iscteiul365-my.sharepoint.com/:u:/g/personal/oonia_iscte-iul_pt/ERZLm1ruUNpMqkSwjpqhE9wB_7loVWAC4yZWuIH2RKGOlQ?e=kD4HlI). This is a processed version of data from original Stanford dataset-Phase 2. The overall pre-processing pipeline used is illustrated in Figure below.\n", + "\n", + "![preprocessing](stanford_data_processing.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data was aquired from diferent sources (Germin, FitBit, Apple Watch) and pre-processed to have a common format. In this form, data has two columns: heart rate and number of user steps in last minute. \n", + "Then the processing pipeline was applied to the data. The pipeline is composed of the following steps:\n", + "1. Once data was standardized, the resting heart rate was extracted (``Resting Heart Rate Extractor``, in Figure). This process takes as input `min_minutes_rest` that is the number of minutes that the user has to be at rest to consider the heart rate as resting. This variable looks at user steps and, when user steps is 0 for `min_minutes_rest` minutes, the heart rate is considered as resting. At the end of this process, we will have a new dataframe with: the date and the resting heart rate of the last minute.\n", + "2. The second step is adding labels." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from ssl_tools.data.data_modules.covid_anomaly import CovidUserAnomalyDataModule\n", + "from ssl_tools.utils.data import get_full_data_split\n", + "from ssl_tools.models.nets.lstm_ae import LSTMAutoencoder\n", + "import lightning as L\n", + "import torch\n", + "import numpy as np\n", + "from torchmetrics import MeanSquaredError" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
datetimeRHR-0RHR-1RHR-2RHR-3RHR-4RHR-5RHR-6RHR-7RHR-8...RHR-10RHR-11RHR-12RHR-13RHR-14RHR-15anomalybaselinelabelparticipant_id
02027-01-14 21:00:001.1701750.653752-0.392374-1.431553-2.129013-2.755962-3.681322-4.674443-5.668570...-6.937363-7.102118-6.975790-6.554774-6.112156-5.396099FalseTruenormalP110465
12027-01-15 05:00:00-5.668570-6.373289-6.937363-7.102118-6.975790-6.554774-6.112156-5.396099-4.415848...-2.656756-1.305630-0.0727561.0461951.5304671.829053FalseFalsenormalP110465
22027-01-15 13:00:00-4.415848-3.467073-2.656756-1.305630-0.0727561.0461951.5304671.8290531.223064...-0.424000-1.145581-1.355121-2.321206-3.124961-3.928738FalseFalsenormalP110465
32027-01-15 21:00:001.2230640.472444-0.424000-1.145581-1.355121-2.321206-3.124961-3.928738-4.802627...-6.067744-5.460156-4.671143-3.408943-2.237883-1.187843FalseFalsenormalP110465
42027-01-16 05:00:00-4.802627-5.831013-6.067744-5.460156-4.671143-3.408943-2.237883-1.187843-0.062360...2.2669443.7944654.6257454.8277564.7200004.677464FalseFalsenormalP110465
..................................................................
317322024-12-13 00:00:00-0.180702-0.499793-0.749829-0.868485-0.966754-1.004670-0.888210-0.580762-0.467943...0.0920000.3478400.6363950.9581951.1705141.301841FalseFalserecoveredP992022
317332024-12-13 08:00:00-0.467943-0.1627400.0920000.3478400.6363950.9581951.1705141.3018411.477526...1.6603441.6566001.6856521.7472521.7673291.793616FalseFalserecoveredP992022
317342024-12-13 16:00:001.4775261.6573211.6603441.6566001.6856521.7472521.7673291.7936161.728615...1.5098331.3807491.2637441.1399971.0242050.946663FalseFalserecoveredP992022
317352024-12-14 00:00:001.7286151.6162651.5098331.3807491.2637441.1399971.0242050.9466631.136868...1.6421531.9093812.1144392.2822382.4536912.587843FalseFalserecoveredP992022
317362024-12-14 08:00:001.1368681.3804181.6421531.9093812.1144392.2822382.4536912.5878432.437232...2.3598402.1734002.0981401.9676691.7845121.561848FalseFalserecoveredP992022
\n", + "

31737 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " datetime RHR-0 RHR-1 RHR-2 RHR-3 RHR-4 \\\n", + "0 2027-01-14 21:00:00 1.170175 0.653752 -0.392374 -1.431553 -2.129013 \n", + "1 2027-01-15 05:00:00 -5.668570 -6.373289 -6.937363 -7.102118 -6.975790 \n", + "2 2027-01-15 13:00:00 -4.415848 -3.467073 -2.656756 -1.305630 -0.072756 \n", + "3 2027-01-15 21:00:00 1.223064 0.472444 -0.424000 -1.145581 -1.355121 \n", + "4 2027-01-16 05:00:00 -4.802627 -5.831013 -6.067744 -5.460156 -4.671143 \n", + "... ... ... ... ... ... ... \n", + "31732 2024-12-13 00:00:00 -0.180702 -0.499793 -0.749829 -0.868485 -0.966754 \n", + "31733 2024-12-13 08:00:00 -0.467943 -0.162740 0.092000 0.347840 0.636395 \n", + "31734 2024-12-13 16:00:00 1.477526 1.657321 1.660344 1.656600 1.685652 \n", + "31735 2024-12-14 00:00:00 1.728615 1.616265 1.509833 1.380749 1.263744 \n", + "31736 2024-12-14 08:00:00 1.136868 1.380418 1.642153 1.909381 2.114439 \n", + "\n", + " RHR-5 RHR-6 RHR-7 RHR-8 ... RHR-10 RHR-11 \\\n", + "0 -2.755962 -3.681322 -4.674443 -5.668570 ... -6.937363 -7.102118 \n", + "1 -6.554774 -6.112156 -5.396099 -4.415848 ... -2.656756 -1.305630 \n", + "2 1.046195 1.530467 1.829053 1.223064 ... -0.424000 -1.145581 \n", + "3 -2.321206 -3.124961 -3.928738 -4.802627 ... -6.067744 -5.460156 \n", + "4 -3.408943 -2.237883 -1.187843 -0.062360 ... 2.266944 3.794465 \n", + "... ... ... ... ... ... ... ... \n", + "31732 -1.004670 -0.888210 -0.580762 -0.467943 ... 0.092000 0.347840 \n", + "31733 0.958195 1.170514 1.301841 1.477526 ... 1.660344 1.656600 \n", + "31734 1.747252 1.767329 1.793616 1.728615 ... 1.509833 1.380749 \n", + "31735 1.139997 1.024205 0.946663 1.136868 ... 1.642153 1.909381 \n", + "31736 2.282238 2.453691 2.587843 2.437232 ... 2.359840 2.173400 \n", + "\n", + " RHR-12 RHR-13 RHR-14 RHR-15 anomaly baseline label \\\n", + "0 -6.975790 -6.554774 -6.112156 -5.396099 False True normal \n", + "1 -0.072756 1.046195 1.530467 1.829053 False False normal \n", + "2 -1.355121 -2.321206 -3.124961 -3.928738 False False normal \n", + "3 -4.671143 -3.408943 -2.237883 -1.187843 False False normal \n", + "4 4.625745 4.827756 4.720000 4.677464 False False normal \n", + "... ... ... ... ... ... ... ... \n", + "31732 0.636395 0.958195 1.170514 1.301841 False False recovered \n", + "31733 1.685652 1.747252 1.767329 1.793616 False False recovered \n", + "31734 1.263744 1.139997 1.024205 0.946663 False False recovered \n", + "31735 2.114439 2.282238 2.453691 2.587843 False False recovered \n", + "31736 2.098140 1.967669 1.784512 1.561848 False False recovered \n", + "\n", + " participant_id \n", + "0 P110465 \n", + "1 P110465 \n", + "2 P110465 \n", + "3 P110465 \n", + "4 P110465 \n", + "... ... \n", + "31732 P992022 \n", + "31733 P992022 \n", + "31734 P992022 \n", + "31735 P992022 \n", + "31736 P992022 \n", + "\n", + "[31737 rows x 21 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Read CSV data\n", + "data_path = \"/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv\"\n", + "df = pd.read_csv(data_path)\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CovidUserAnomalyDataModule (Data=/workspaces/hiaac-m4/data/Stanford-COVID/processed/windowed_16_overlap_8_df_scaled.csv, 1 participant selected)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dm = CovidUserAnomalyDataModule(\n", + " data_path,\n", + " participants=[\"P992022\"],\n", + " batch_size=32,\n", + " num_workers=0,\n", + " reshape=(16, 1),\n", + ")\n", + "dm" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LSTMAutoencoder(\n", + " (backbone): _LSTMAutoEncoder(\n", + " (lstm1): LSTM(1, 128, batch_first=True)\n", + " (lstm2): LSTM(128, 64, batch_first=True)\n", + " (repeat_vector): Linear(in_features=64, out_features=1024, bias=True)\n", + " (lstm3): LSTM(64, 64, batch_first=True)\n", + " (lstm4): LSTM(64, 128, batch_first=True)\n", + " (time_distributed): Linear(in_features=128, out_features=1, bias=True)\n", + " )\n", + " (loss_fn): MSELoss()\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = LSTMAutoencoder(input_shape=(16, 1))\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer = L.Trainer(max_epochs=100, devices=1, accelerator=\"cpu\")\n", + "trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "----------------------------------------------\n", + "0 | backbone | _LSTMAutoEncoder | 316 K \n", + "1 | loss_fn | MSELoss | 0 \n", + "----------------------------------------------\n", + "316 K Trainable params\n", + "0 Non-trainable params\n", + "316 K Total params\n", + "1.264 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "122a71df981c48c183eb2b4e7585103d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00 anomaly_threshold else 0 for loss in losses]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
truepredictedlossanomaly_threshold
0000.0237000.374275
1000.0914130.374275
2000.0542990.374275
3000.0074860.374275
4000.0246010.374275
...............
89100.0898330.374275
90100.0515620.374275
91100.1327480.374275
92100.1586100.374275
93100.0255220.374275
\n", + "

94 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " true predicted loss anomaly_threshold\n", + "0 0 0 0.023700 0.374275\n", + "1 0 0 0.091413 0.374275\n", + "2 0 0 0.054299 0.374275\n", + "3 0 0 0.007486 0.374275\n", + "4 0 0 0.024601 0.374275\n", + ".. ... ... ... ...\n", + "89 1 0 0.089833 0.374275\n", + "90 1 0 0.051562 0.374275\n", + "91 1 0 0.132748 0.374275\n", + "92 1 0 0.158610 0.374275\n", + "93 1 0 0.025522 0.374275\n", + "\n", + "[94 rows x 4 columns]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_dataframe = pd.DataFrame(\n", + " {\n", + " \"true\": y_test,\n", + " \"predicted\": y_test_hat,\n", + " \"loss\": losses,\n", + " \"anomaly_threshold\": anomaly_threshold,\n", + " }\n", + ")\n", + "\n", + "results_dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "F1-score: 0.0\n", + "Recall: 0.0\n", + "Balanced Accuracy: 0.5\n", + "ROC AUC: 0.5\n" + ] + } + ], + "source": [ + "from sklearn.metrics import f1_score, recall_score, balanced_accuracy_score, roc_auc_score\n", + "\n", + "# Extract true and predicted labels from the results_dataframe\n", + "true_labels = results_dataframe['true']\n", + "predicted_labels = results_dataframe['predicted']\n", + "\n", + "# Calculate the F1-score\n", + "f1 = f1_score(true_labels, predicted_labels)\n", + "\n", + "# Calculate the recall\n", + "recall = recall_score(true_labels, predicted_labels)\n", + "\n", + "# Calculate the balanced accuracy\n", + "balanced_acc = balanced_accuracy_score(true_labels, predicted_labels)\n", + "\n", + "# Calculate the ROC AUC\n", + "roc_auc = roc_auc_score(true_labels, predicted_labels)\n", + "\n", + "# Print the results\n", + "print(\"F1-score:\", f1)\n", + "print(\"Recall:\", recall)\n", + "print(\"Balanced Accuracy:\", balanced_acc)\n", + "print(\"ROC AUC:\", roc_auc)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAApUAAAJOCAYAAADmqPxLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABTMUlEQVR4nO3dd3hUZdrH8d8kkAmQBhESkBJ6L9ID0iOIgtQFRJeuoIhgABWVDguLNBVExFBkZRGkKOiCSFV6L4JIFRQSiiQhQArJef/AzOsY0IyTw0yY74frXDLPOfOc++Qy5M79lLEYhmEIAAAAcIKXqwMAAABA9kdSCQAAAKeRVAIAAMBpJJUAAABwGkklAAAAnEZSCQAAAKeRVAIAAMBpJJUAAABwGkklAAAAnEZSCcCtnThxQs2bN1dgYKAsFotWrlyZpf2fPXtWFotF8+fPz9J+s7PGjRurcePGrg4DQDZDUgngL506dUp9+/ZViRIl5Ovrq4CAANWvX1/vvPOObt26Zeq9u3fvrsOHD2v8+PFauHChatasaer97qcePXrIYrEoICDgrl/HEydOyGKxyGKxaPLkyQ73f+HCBY0aNUoHDhzIgmgB4M/lcHUAANzbl19+qX/84x+yWq3q1q2bKlWqpOTkZH333XcaOnSovv/+e3344Yem3PvWrVvavn273nzzTb300kum3KNYsWK6deuWcubMaUr/fyVHjhy6efOmVq1apU6dOtmd++STT+Tr66vExMS/1feFCxc0evRohYWFqVq1apl+39dff/237gfAs5FUArinM2fOqEuXLipWrJg2bNigggUL2s71799fJ0+e1Jdffmna/S9fvixJCgoKMu0eFotFvr6+pvX/V6xWq+rXr6///ve/GZLKRYsW6cknn9SyZcvuSyw3b95U7ty55ePjc1/uB+DBwvA3gHuaNGmSEhISFBUVZZdQpitVqpQGDhxoe3379m2NHTtWJUuWlNVqVVhYmN544w0lJSXZvS8sLEytWrXSd999p9q1a8vX11clSpTQxx9/bLtm1KhRKlasmCRp6NChslgsCgsLk3Rn2Dj97783atQoWSwWu7Z169bp0UcfVVBQkPz8/FS2bFm98cYbtvP3mlO5YcMGNWjQQHny5FFQUJDatGmjY8eO3fV+J0+eVI8ePRQUFKTAwED17NlTN2/evPcX9g+6du2q//3vf4qNjbW17d69WydOnFDXrl0zXP/rr79qyJAhqly5svz8/BQQEKCWLVvq4MGDtms2bdqkWrVqSZJ69uxpG0ZPf87GjRurUqVK2rt3rxo2bKjcuXPbvi5/nFPZvXt3+fr6Znj+Fi1aKG/evLpw4UKmnxXAg4ukEsA9rVq1SiVKlFC9evUydX2fPn00YsQIVa9eXdOmTVOjRo00YcIEdenSJcO1J0+eVMeOHfXYY49pypQpyps3r3r06KHvv/9ektS+fXtNmzZNkvT0009r4cKFmj59ukPxf//992rVqpWSkpI0ZswYTZkyRU899ZS2bt36p+/75ptv1KJFC126dEmjRo1SZGSktm3bpvr16+vs2bMZru/UqZOuX7+uCRMmqFOnTpo/f75Gjx6d6Tjbt28vi8Wi5cuX29oWLVqkcuXKqXr16hmuP336tFauXKlWrVpp6tSpGjp0qA4fPqxGjRrZErzy5ctrzJgxkqTnn39eCxcu1MKFC9WwYUNbP1evXlXLli1VrVo1TZ8+XU2aNLlrfO+8847y58+v7t27KzU1VZI0e/Zsff3113rvvfdUqFChTD8rgAeYAQB3ERcXZ0gy2rRpk6nrDxw4YEgy+vTpY9c+ZMgQQ5KxYcMGW1uxYsUMScaWLVtsbZcuXTKsVqsxePBgW9uZM2cMScbbb79t12f37t2NYsWKZYhh5MiRxu//WZs2bZohybh8+fI9406/x7x582xt1apVMwoUKGBcvXrV1nbw4EHDy8vL6NatW4b79erVy67Pdu3aGcHBwfe85++fI0+ePIZhGEbHjh2NZs2aGYZhGKmpqUZoaKgxevTou34NEhMTjdTU1AzPYbVajTFjxtjadu/eneHZ0jVq1MiQZHzwwQd3PdeoUSO7trVr1xqSjHHjxhmnT582/Pz8jLZt2/7lMwLwHFQqAdxVfHy8JMnf3z9T13/11VeSpMjISLv2wYMHS1KGuZcVKlRQgwYNbK/z58+vsmXL6vTp03875j9Kn4v5+eefKy0tLVPvuXjxog4cOKAePXooX758tvYqVarosccesz3n7/Xr18/udYMGDXT16lXb1zAzunbtqk2bNik6OlobNmxQdHT0XYe+pTvzML287vzznZqaqqtXr9qG9vft25fpe1qtVvXs2TNT1zZv3lx9+/bVmDFj1L59e/n6+mr27NmZvheABx9JJYC7CggIkCRdv349U9f/9NNP8vLyUqlSpezaQ0NDFRQUpJ9++smuvWjRohn6yJs3r65du/Y3I86oc+fOql+/vvr06aOQkBB16dJFS5Ys+dMEMz3OsmXLZjhXvnx5XblyRTdu3LBr/+Oz5M2bV5IcepYnnnhC/v7++vTTT/XJJ5+oVq1aGb6W6dLS0jRt2jSVLl1aVqtVDz30kPLnz69Dhw4pLi4u0/d8+OGHHVqUM3nyZOXLl08HDhzQu+++qwIFCmT6vQAefCSVAO4qICBAhQoV0pEjRxx63x8XytyLt7f3XdsNw/jb90if75cuV65c2rJli7755hv985//1KFDh9S5c2c99thjGa51hjPPks5qtap9+/ZasGCBVqxYcc8qpST961//UmRkpBo2bKj//Oc/Wrt2rdatW6eKFStmuiIr3fn6OGL//v26dOmSJOnw4cMOvRfAg4+kEsA9tWrVSqdOndL27dv/8tpixYopLS1NJ06csGuPiYlRbGysbSV3VsibN6/dSul0f6yGSpKXl5eaNWumqVOn6ujRoxo/frw2bNigjRs33rXv9DiPHz+e4dwPP/yghx56SHny5HHuAe6ha9eu2r9/v65fv37XxU3pPvvsMzVp0kRRUVHq0qWLmjdvroiIiAxfk8wm+Jlx48YN9ezZUxUqVNDzzz+vSZMmaffu3VnWP4Dsj6QSwD29+uqrypMnj/r06aOYmJgM50+dOqV33nlH0p3hW0kZVmhPnTpVkvTkk09mWVwlS5ZUXFycDh06ZGu7ePGiVqxYYXfdr7/+muG96ZuA/3Gbo3QFCxZUtWrVtGDBArsk7ciRI/r6669tz2mGJk2aaOzYsZoxY4ZCQ0PveZ23t3eGKujSpUv1yy+/2LWlJ793S8Ad9dprr+ncuXNasGCBpk6dqrCwMHXv3v2eX0cAnofNzwHcU8mSJbVo0SJ17txZ5cuXt/tEnW3btmnp0qXq0aOHJKlq1arq3r27PvzwQ8XGxqpRo0batWuXFixYoLZt295zu5q/o0uXLnrttdfUrl07vfzyy7p586ZmzZqlMmXK2C1UGTNmjLZs2aInn3xSxYoV06VLl/T++++rcOHCevTRR+/Z/9tvv62WLVsqPDxcvXv31q1bt/Tee+8pMDBQo0aNyrLn+CMvLy+99dZbf3ldq1atNGbMGPXs2VP16tXT4cOH9cknn6hEiRJ215UsWVJBQUH64IMP5O/vrzx58qhOnToqXry4Q3Ft2LBB77//vkaOHGnb4mjevHlq3Lixhg8frkmTJjnUH4AHE5VKAH/qqaee0qFDh9SxY0d9/vnn6t+/v15//XWdPXtWU6ZM0bvvvmu79qOPPtLo0aO1e/duDRo0SBs2bNCwYcO0ePHiLI0pODhYK1asUO7cufXqq69qwYIFmjBhglq3bp0h9qJFi2ru3Lnq37+/Zs6cqYYNG2rDhg0KDAy8Z/8RERFas2aNgoODNWLECE2ePFl169bV1q1bHU7IzPDGG29o8ODBWrt2rQYOHKh9+/bpyy+/VJEiReyuy5kzpxYsWCBvb2/169dPTz/9tDZv3uzQva5fv65evXrpkUce0Ztvvmlrb9CggQYOHKgpU6Zox44dWfJcALI3i+HITHIAAADgLqhUAgAAwGkklQAAAHAaSSUAAACcRlIJAAAAp5FUAgAAwGkklQAAAHAam5+7kbS0NF24cEH+/v5Z+vFqAAB4AsMwdP36dRUqVEheXq6vmyUmJio5Odm0/n18fOTr62ta/44iqXQjFy5cyLB5MQAAcMz58+dVuHBhl8aQmJioXP7B0u2bpt0jNDRUZ86ccZvEkqTSjfj7+0uS6o9ZqRy+eVwcDYC7WdK7tqtDAHAP1+PjVap4EdvPU1dKTk6Wbt+UtUJ3ydsn62+QmqzoowuUnJxMUomM0oe8c/jmUY5cJJWAOwoICHB1CAD+gltNIcvhK4sJSaVhcf3w/h+5X0QAAADIdqhUAgAAmMUiyYzKqRsVY9NRqQQAAIDTqFQCAACYxeJ15zCjXzfjfhEBAAAg26FSCQAAYBaLxaQ5le43qZKkEgAAwCwMfwMAAACZR6USAADALB40/E2lEgAAAE6jUgkAAGAak+ZUumFd0P0iAgAAQLZDpRIAAMAszKkEAAAAMo9KJQAAgFnYpxIAAADIPCqVAAAAZvGgOZUklQAAAGZh+BsAAADIPCqVAAAAZvGg4W8qlQAAAHAalUoAAACzMKcSAAAAyDwqlQAAAGaxWEyqVDKnEgAAAA8gKpUAAABm8bLcOczo182QVAIAAJiFhToAAABA5lGpBAAAMAubnwMAAACZR6USAADALMypBAAAADKPSiUAAIBZmFMJAAAAZB6VSgAAALMwpxIAAADIPCqVAAAAZvGgOZUklQAAAGZh+BsAAADIPCqVAAAAZvGg4W8qlQAAAHAalUoAAADTmDSn0g3rgu4XEQAAALIdKpUAAABmYU4lAAAAkHlUKgEAAMxisZi0T6X7VSpJKgEAAMzC5ucAAABA5lGpBAAAMAsLdQAAAIDMo1IJAABgFuZUAgAAAJlHpRIAAMAszKkEAAAAMo9KJQAAgFmYUwkAAABkHpVKAAAAszCnEgAAAM6yWCymHZk1atSoDO8tV66c7XxiYqL69++v4OBg+fn5qUOHDoqJiXH4WUkqAQAAHnAVK1bUxYsXbcd3331nO/fKK69o1apVWrp0qTZv3qwLFy6offv2Dt+D4W8AAACTOFpVdKBjhy7PkSOHQkNDM7THxcUpKipKixYtUtOmTSVJ8+bNU/ny5bVjxw7VrVs30/egUgkAAPCAO3HihAoVKqQSJUromWee0blz5yRJe/fuVUpKiiIiImzXlitXTkWLFtX27dsdugeVSgAAALNYfjvM6FdSfHy8XbPVapXVarVrq1OnjubPn6+yZcvq4sWLGj16tBo0aKAjR44oOjpaPj4+CgoKsntPSEiIoqOjHQqJpBIAACCbKlKkiN3rkSNHatSoUXZtLVu2tP29SpUqqlOnjooVK6YlS5YoV65cWRYLSSUAAIBJzJ5Tef78eQUEBNia/1ilvJugoCCVKVNGJ0+e1GOPPabk5GTFxsbaVStjYmLuOgfzzzCnEgAAIJsKCAiwOzKTVCYkJOjUqVMqWLCgatSooZw5c2r9+vW288ePH9e5c+cUHh7uUCxUKgEAAEziDqu/hwwZotatW6tYsWK6cOGCRo4cKW9vbz399NMKDAxU7969FRkZqXz58ikgIEADBgxQeHi4Qyu/JZJKAAAA07hDUvnzzz/r6aef1tWrV5U/f349+uij2rFjh/Lnzy9JmjZtmry8vNShQwclJSWpRYsWev/99x0OiaQSAADgAbZ48eI/Pe/r66uZM2dq5syZTt2HpBIAAMAk7lCpvF9YqAMAAACnUakEAAAwi8mbn7sTKpUAAABwGpVKAAAAkzCnEgAAAHAAlUoAAACTWCwyqVKZ9V06i6QSAADAJBaZNPzthlklw98AAABwGpVKAAAAk7BQBwAAAHAAlUoAAACzsPk5AAAAkHlUKgEAAMxi0pxKgzmVAAAAeBBRqQQAADCJWau/zdn70jlUKgEAAOA0KpUAAAAm8aRKJUklAACAWdhSCAAAAMg8KpUAAAAm8aThbyqVAAAAcBqVSgAAAJNQqQQAAAAcQKUSAADAJFQqAQAAAAdQqQQAADCJJ1UqSSoBAADMwubnAAAAQOZRqQQAADCJJw1/U6kEAACA06hUAgAAmIRKJQAAAOAAKpUAAAAmoVIJAAAAOIBKJQAAgFk8aJ9KkkrASa0rhah1pRCFBFglST/9eksLd/2s3edibdeUD/VTr7pFVS7ET2mGoVOXb+r1L44pOTXNRVED+OD9mZo29W3FREercpWqmjr9PdWqXdvVYQHZFkkl4KTLCcn6aPs5/RKbKFmk5uXya8yTZdXv00P66ddbKh/qp4mty+u/e3/RjC1nlJpmqORDeWQYhqtDBzzW0iWf6rWhkXpv5geqVbuOZrw7XU892UIHvz+uAgUKuDo8PECYUwkg03acvaZdP8Xql7hE/RKbqHk7zutWSprKh/hLkl58NEwrDkVr8b4L+unXW/o5NlGbT15VShpJJeAq706fqp69n1O3Hj1VvkIFvff+B8qVO7cWzJ/r6tDwgElPKs043A1JJZCFvCxS49LB8s3ppaPR1xWUK4fKh/or9laK3ulQSUt71dCUdhVVqaC/q0MFPFZycrL279urps0ibG1eXl5q2jRCu3Zsd2FkQPZGUmmiTZs2yWKxKDY21tWhwGTFg3Nr1fO19b8X6mpQ4xIa9dVxnbt2SwUDfCVJ3WoX1ldHYzTsi2M6eTlBk9pW0MOBvi6OGvBMV65cUWpqqgoUCLFrLxASoujoaBdFhQeVRSZVKt1wpU62SSp79Oghi8WiiRMn2rWvXLnSLUvA8Cznr91S308P6aWlh7XqSIxejSilonlzKf1/zdVHYrT22GWdvHJTs777ST9fu6XHKzBvCwDw4Mg2SaUk+fr66t///reuXbuWZX0mJydnWV/wXLfTDF2IS9SJyzcUtf2cTl+5ofZVC+rXGymS7qwI/71z126pgJ+PK0IFPN5DDz0kb29vXboUY9d+KSZGoaGhLooKDyrmVLqpiIgIhYaGasKECfe8ZtmyZapYsaKsVqvCwsI0ZcoUu/NhYWEaO3asunXrpoCAAD3//POaP3++goKCtHr1apUtW1a5c+dWx44ddfPmTS1YsEBhYWHKmzevXn75ZaWmptr6WrhwoWrWrCl/f3+Fhoaqa9euunTpkmnPj+zDYrEop7dF0deTdCUhWUXy5rI7Xzgol2KuJ7koOsCz+fj46JHqNbRxw3pbW1pamjZuXK/adcNdGBmQvWWrpNLb21v/+te/9N577+nnn3/OcH7v3r3q1KmTunTposOHD2vUqFEaPny45s+fb3fd5MmTVbVqVe3fv1/Dhw+XJN28eVPvvvuuFi9erDVr1mjTpk1q166dvvrqK3311VdauHChZs+erc8++8zWT0pKisaOHauDBw9q5cqVOnv2rHr06GHmlwBuqHd4UVUu5K8Qf6uKB+dW7/CiqvpwgNb/eEWStGT/L2pXJVQNSuZToUBf9ahTREXy5tL/jvILCOAqLw+K1LyoOfrPxwv0w7Fjern/C7p544a6de/p6tDwoLGYeLiZbLdPZbt27VStWjWNHDlSUVFRduemTp2qZs2a2RLFMmXK6OjRo3r77bftkr2mTZtq8ODBttfffvutUlJSNGvWLJUsWVKS1LFjRy1cuFAxMTHy8/NThQoV1KRJE23cuFGdO3eWJPXq1cvWR4kSJfTuu++qVq1aSkhIkJ+f318+S1JSkpKS/r9aFR8f7/gXBC4XlCunXosopXx5fHQjKVVnrt7Q618c077zcZKk5Qej5ePtpRceDZO/bw6dvnJTr31+VBfjqVQCrvKPTp115fJljRk9QjHR0apStZo+X71GISEhf/1mAHeV7ZJKSfr3v/+tpk2basiQIXbtx44dU5s2beza6tevr+nTpys1NVXe3t6SpJo1a2boM3fu3LaEUpJCQkIUFhZmlxyGhITYDW/v3btXo0aN0sGDB3Xt2jWlpd35dJRz586pQoUKf/kcEyZM0OjRozPxxHBnUzac+strFu+7oMX7LtyHaABk1gv9X9IL/V9ydRh4wLH5uZtr2LChWrRooWHDhv2t9+fJkydDW86cOe1eWyyWu7alJ443btxQixYtFBAQoE8++US7d+/WihUrJGV+8c+wYcMUFxdnO86fP/93HgcAALgpT1qoky0rlZI0ceJEVatWTWXLlrW1lS9fXlu3brW7buvWrSpTpoytSplVfvjhB129elUTJ05UkSJFJEl79uxxqA+r1Sqr1ZqlcQEAALhCtqxUSlLlypX1zDPP6N1337W1DR48WOvXr9fYsWP1448/asGCBZoxY0aGYfKsULRoUfn4+Oi9997T6dOn9cUXX2js2LFZfh8AAJB9WSzmHe4m2yaVkjRmzBjbcLQkVa9eXUuWLNHixYtVqVIljRgxQmPGjDFlRXb+/Pk1f/58LV26VBUqVNDEiRM1efLkLL8PAABAdmAxDMNwdRC4Iz4+XoGBgWo0aZ1y5Mo47xOA663uxz6GgLuKj49XSHCg4uLiFBAQ4PJYAgMDVWLAZ/KyZv3P9LSkGzr9Xke3eNZ02bpSCQAAAPeQbRfqAAAAuD2z5j8ypxIAAAAPIiqVAAAAJmHzcwAAAMABVCoBAABMYtaekm5YqCSpBAAAMIuXl0VeXlmfARom9Okshr8BAADgNCqVAAAAJvGk4W8qlQAAAHAalUoAAACTsKUQAAAA4AAqlQAAACZhTiUAAADgACqVAAAAJvGkOZUklQAAACbxpKSS4W8AAAA4jaQSAADAJOkLdcw4/q6JEyfKYrFo0KBBtrbExET1799fwcHB8vPzU4cOHRQTE+NQvySVAAAAHmL37t2aPXu2qlSpYtf+yiuvaNWqVVq6dKk2b96sCxcuqH379g71TVIJAABgEosstnmVWXrI8VJlQkKCnnnmGc2ZM0d58+a1tcfFxSkqKkpTp05V06ZNVaNGDc2bN0/btm3Tjh07Mt0/SSUAAIAH6N+/v5588klFRETYte/du1cpKSl27eXKlVPRokW1ffv2TPfP6m8AAACTmL35eXx8vF271WqV1WrNcP3ixYu1b98+7d69O8O56Oho+fj4KCgoyK49JCRE0dHRmY6JSiUAAEA2VaRIEQUGBtqOCRMmZLjm/PnzGjhwoD755BP5+vqaFguVSgAAAJOYvU/l+fPnFRAQYGu/W5Vy7969unTpkqpXr25rS01N1ZYtWzRjxgytXbtWycnJio2NtatWxsTEKDQ0NNMxkVQCAABkUwEBAXZJ5d00a9ZMhw8ftmvr2bOnypUrp9dee01FihRRzpw5tX79enXo0EGSdPz4cZ07d07h4eGZjoWkEgAAwCRmz6nMDH9/f1WqVMmuLU+ePAoODra19+7dW5GRkcqXL58CAgI0YMAAhYeHq27dupm+D0klAACASbLLxzROmzZNXl5e6tChg5KSktSiRQu9//77DvVBUgkAAOBhNm3aZPfa19dXM2fO1MyZM/92nySVAAAAJnGH4e/7hS2FAAAA4DQqlQAAACbJLnMqswKVSgAAADiNSiUAAIBZTJpTKfcrVFKpBAAAgPOoVAIAAJjEk+ZUklQCAACYhC2FAAAAAAdQqQQAADCJJw1/U6kEAACA06hUAgAAmIQ5lQAAAIADqFQCAACYhDmVAAAAgAOoVAIAAJiESiUAAADgACqVAAAAJvGk1d8klQAAACZh+BsAAABwAJVKAAAAk3jS8DeVSgAAADiNSiUAAIBJmFMJAAAAOIBKJQAAgEksMmlOZdZ36TQqlQAAAHAalUoAAACTeFks8jKhVGlGn84iqQQAADAJWwoBAAAADqBSCQAAYBK2FAIAAAAcQKUSAADAJF6WO4cZ/bobKpUAAABwGpVKAAAAs1hMmv9IpRIAAAAPIiqVAAAAJmGfSgAAAMABVCoBAABMYvntjxn9uhuSSgAAAJOwpRAAAADgACqVAAAAJuFjGgEAAAAHUKkEAAAwCVsKAQAAAA6gUgkAAGASL4tFXiaUFc3o01lUKgEAAOA0KpUAAAAm8aQ5lSSVAAAAJvGkLYUylVQeOnQo0x1WqVLlbwcDAACA7ClTSWW1atVksVhkGMZdz6efs1gsSk1NzdIAAQAAsiuGv//gzJkzZscBAACAbCxTSWWxYsXMjgMAAOCBw5ZCf2HhwoWqX7++ChUqpJ9++kmSNH36dH3++edZGhwAAACyB4eTylmzZikyMlJPPPGEYmNjbXMog4KCNH369KyODwAAINuymHi4G4eTyvfee09z5szRm2++KW9vb1t7zZo1dfjw4SwNDgAAANmDw/tUnjlzRo888kiGdqvVqhs3bmRJUAAAAA8CT9qn0uFKZfHixXXgwIEM7WvWrFH58uWzIiYAAABkMw5XKiMjI9W/f38lJibKMAzt2rVL//3vfzVhwgR99NFHZsQIAACQLXlZ7hxm9OtuHE4q+/Tpo1y5cumtt97SzZs31bVrVxUqVEjvvPOOunTpYkaMAAAA2ZInDX//rc/+fuaZZ/TMM8/o5s2bSkhIUIECBbI6LgAAAGQjfyuplKRLly7p+PHjku5ky/nz58+yoAAAAB4UblhUNIXDC3WuX7+uf/7znypUqJAaNWqkRo0aqVChQnr22WcVFxdnRowAAABwcw4nlX369NHOnTv15ZdfKjY2VrGxsVq9erX27Nmjvn37mhEjAABAtpQ+p9KMw904PPy9evVqrV27Vo8++qitrUWLFpozZ44ef/zxLA0OAAAA2YPDSWVwcLACAwMztAcGBipv3rxZEhQAAMCDwJO2FHJ4+Putt95SZGSkoqOjbW3R0dEaOnSohg8fnqXBAQAAIHvIVKXykUcesRu7P3HihIoWLaqiRYtKks6dOyer1arLly8zrxIAAOA37FP5B23btjU5DAAAgAeP5bfDjH7dTaaSypEjR5odBwAAALIxh+dUAgAAIHO8LBbTjsyaNWuWqlSpooCAAAUEBCg8PFz/+9//bOcTExPVv39/BQcHy8/PTx06dFBMTIzjz+roG1JTUzV58mTVrl1boaGhypcvn90BAAAA91G4cGFNnDhRe/fu1Z49e9S0aVO1adNG33//vSTplVde0apVq7R06VJt3rxZFy5cUPv27R2+j8NJ5ejRozV16lR17txZcXFxioyMVPv27eXl5aVRo0Y5HAAAAMCDymIx78is1q1b64knnlDp0qVVpkwZjR8/Xn5+ftqxY4fi4uIUFRWlqVOnqmnTpqpRo4bmzZunbdu2aceOHQ49q8NJ5SeffKI5c+Zo8ODBypEjh55++ml99NFHGjFihMM3BwAAwP2TmpqqxYsX68aNGwoPD9fevXuVkpKiiIgI2zXlypVT0aJFtX37dof6dnjz8+joaFWuXFmS5OfnZ/u871atWrFPJQAAwO+YvaVQfHy8XbvVapXVas1w/eHDhxUeHq7ExET5+flpxYoVqlChgg4cOCAfHx8FBQXZXR8SEmK3J3lmOFypLFy4sC5evChJKlmypL7++mtJ0u7du+/6EAAAADBHkSJFFBgYaDsmTJhw1+vKli2rAwcOaOfOnXrhhRfUvXt3HT16NEtjcbhS2a5dO61fv1516tTRgAED9OyzzyoqKkrnzp3TK6+8kqXBAQAAZGeOzn90pF9JOn/+vAICAmzt9yrw+fj4qFSpUpKkGjVqaPfu3XrnnXfUuXNnJScnKzY21q5aGRMTo9DQUIdicjipnDhxou3vnTt3VrFixbRt2zaVLl1arVu3drQ7AAAA/E3p2wQ5Ki0tTUlJSapRo4Zy5syp9evXq0OHDpKk48eP69y5cwoPD3eoT4eTyj+qW7eu6tatq0uXLulf//qX3njjDWe7BAAAeCA4uqekI/1m1rBhw9SyZUsVLVpU169f16JFi7Rp0yatXbtWgYGB6t27tyIjI5UvXz4FBARowIABCg8PV926dR2KyemkMt3Fixc1fPhwkkoAAIDfmD38nRmXLl1St27ddPHiRQUGBqpKlSpau3atHnvsMUnStGnT5OXlpQ4dOigpKUktWrTQ+++/73BMWZZUAgAAwP1ERUX96XlfX1/NnDlTM2fOdOo+JJUAAAAmMXtLIXfCZ38DAADAaZmuVEZGRv7p+cuXLzsdDO7YsXCJLN4+rg4DwN30c2w1JADP5iVzKnjuWBXMdFK5f//+v7ymYcOGTgUDAACA7CnTSeXGjRvNjAMAAOCBw5xKAAAAwAGs/gYAADCJxSJ5uXifyvuFpBIAAMAkXiYllWb06SyGvwEAAOA0KpUAAAAmYaHOX/j222/17LPPKjw8XL/88oskaeHChfruu++yNDgAAABkDw4nlcuWLVOLFi2UK1cu7d+/X0lJSZKkuLg4/etf/8ryAAEAALKr9DmVZhzuxuGkcty4cfrggw80Z84c5cyZ09Zev3597du3L0uDAwAAQPbg8JzK48eP3/WTcwIDAxUbG5sVMQEAADwQLBZztv9xwymVjlcqQ0NDdfLkyQzt3333nUqUKJElQQEAACB7cTipfO655zRw4EDt3LlTFotFFy5c0CeffKIhQ4bohRdeMCNGAACAbMnLYjHtcDcOD3+//vrrSktLU7NmzXTz5k01bNhQVqtVQ4YM0YABA8yIEQAAAG7O4aTSYrHozTff1NChQ3Xy5EklJCSoQoUK8vPzMyM+AACAbMtL5nzSjDt+es3f3vzcx8dHFSpUyMpYAAAAHiietFDH4aSySZMmf7qL+4YNG5wKCAAAANmPw0lltWrV7F6npKTowIEDOnLkiLp3755VcQEAAGR7XjJnUY2X3K9U6XBSOW3atLu2jxo1SgkJCU4HBAAAgOwny+Z5Pvvss5o7d25WdQcAAJDtpc+pNONwN1mWVG7fvl2+vr5Z1R0AAACyEYeHv9u3b2/32jAMXbx4UXv27NHw4cOzLDAAAIDszsty5zCjX3fjcFIZGBho99rLy0tly5bVmDFj1Lx58ywLDAAAANmHQ0llamqqevbsqcqVKytv3rxmxQQAAPBAsFhkyurvbD+n0tvbW82bN1dsbKxJ4QAAADw4WKjzJypVqqTTp0+bEQsAAACyKYeTynHjxmnIkCFavXq1Ll68qPj4eLsDAAAAd6Qv1DHjcDeZnlM5ZswYDR48WE888YQk6amnnrL7uEbDMGSxWJSampr1UQIAAMCtZTqpHD16tPr166eNGzeaGQ8AAMADw/LbHzP6dTeZTioNw5AkNWrUyLRgAAAAkD05tKWQxR2XGgEAALgpNj+/hzJlyvxlYvnrr786FRAAAACyH4eSytGjR2f4RB0AAADcHZXKe+jSpYsKFChgViwAAADIpjKdVDKfEgAAwDEWi8WUHMod8zKHV38DAAAgcxj+vou0tDQz4wAAAEA25tCcSgAAAGSexXLnMKNfd+PwZ38DAAAAf0SlEgAAwCReFou8TCgrmtGns6hUAgAAwGlUKgEAAEziSau/qVQCAADAaVQqAQAAzGLS6m+5YaWSpBIAAMAkXrLIy4QM0Iw+ncXwNwAAAJxGpRIAAMAkbH4OAAAAOIBKJQAAgEnYUggAAABwAJVKAAAAk/AxjQAAAIADqFQCAACYhNXfAAAAgAOoVAIAAJjESybNqXTDT9QhqQQAADAJw98AAACAA6hUAgAAmMRL5lTw3LEq6I4xAQAAIJuhUgkAAGASi8UiiwkTIM3o01lUKgEAAOA0KpUAAAAmsfx2mNGvu6FSCQAAAKdRqQQAADCJl8Wkzc/dcE4lSSUAAICJ3C/9MwfD3wAAAHAalUoAAACT8DGNAAAAeCBMmDBBtWrVkr+/vwoUKKC2bdvq+PHjdtckJiaqf//+Cg4Olp+fnzp06KCYmBiH7kNSCQAAYJL0zc/NODJr8+bN6t+/v3bs2KF169YpJSVFzZs3140bN2zXvPLKK1q1apWWLl2qzZs368KFC2rfvr1Dz8rwNwAAwANszZo1dq/nz5+vAgUKaO/evWrYsKHi4uIUFRWlRYsWqWnTppKkefPmqXz58tqxY4fq1q2bqftQqQQAADCJl4mHJMXHx9sdSUlJfxlTXFycJClfvnySpL179yolJUURERG2a8qVK6eiRYtq+/btDj0rAAAAsqEiRYooMDDQdkyYMOFPr09LS9OgQYNUv359VapUSZIUHR0tHx8fBQUF2V0bEhKi6OjoTMfC8DcAAIBJHJ3/6Ei/knT+/HkFBATY2q1W65++r3///jpy5Ii+++67LI+JpBIAACCbCggIsEsq/8xLL72k1atXa8uWLSpcuLCtPTQ0VMnJyYqNjbWrVsbExCg0NDTTsTD8DQAAYBKLiUdmGYahl156SStWrNCGDRtUvHhxu/M1atRQzpw5tX79elvb8ePHde7cOYWHh2f6PlQqAQAATGL28Hdm9O/fX4sWLdLnn38uf39/2zzJwMBA5cqVS4GBgerdu7ciIyOVL18+BQQEaMCAAQoPD8/0ym+JpBIAAOCBNmvWLElS48aN7drnzZunHj16SJKmTZsmLy8vdejQQUlJSWrRooXef/99h+5DUgkAAGCS32//k9X9ZpZhGH95ja+vr2bOnKmZM2fel5gAAACAu6JSCQAAYBJ3mFN5v1CpBAAAgNOoVAIAAJjE0e1/HOnX3VCpBAAAgNOoVAIAAJjEYrlzmNGvuyGpBAAAMImXLPIyYbDajD6dxfA3AAAAnEalEgAAwCSeNPxNpRIAAABOo1IJAABgEstvf8zo191QqQQAAIDTqFQCAACYhDmVAAAAgAOoVAIAAJjEYtI+lcypBAAAwAOJSiUAAIBJPGlOJUklAACASTwpqWT4GwAAAE6jUgkAAGASNj8HAAAAHEClEgAAwCReljuHGf26GyqVAAAAcBqVSgAAAJMwpxIAAABwAJVKAAAAk3jSPpUklQAAACaxyJyhajfMKRn+Bpz1Zt8ndGv/DLvjwPK3bOdDgv0VNbabzqz7l65sm6Jti15T22bVXBcwAEnSB+/PVNlSYQry81WDenW0e9cuV4cEZGtUKoEs8P3JC3qy33u217dT02x//2hsNwX559I/Bs3WldgEdW5ZU//5dy/Vf2aSDh7/2RXhAh5v6ZJP9drQSL038wPVql1HM96drqeebKGD3x9XgQIFXB0eHiBsKQTAIbdT0xRz9brtuBp7w3aubtUSen/xZu35/ied/eWq/v3RWsVev6VHKhRxYcSAZ3t3+lT17P2cuvXoqfIVKui99z9Qrty5tWD+XFeHBmRbJJVAFihVNL9Ofz1eR1eN0rzx3VUkNK/t3I6Dp9WxeQ3lDcgti8Wif7SoIV9rDm3Zc8KFEQOeKzk5Wfv37VXTZhG2Ni8vLzVtGqFdO7a7MDI8iCwm/nE3DH8DTtp95KyeH/Ef/fhTjEIfCtSbfVvqm7mvqEbH8Uq4maRnX52rhf/upQubJyklJVU3E5PVOXKOTp+/4urQAY905coVpaamqkCBELv2AiEhOn78BxdFBWR/VCqdEBYWpunTp7s6DLjY11uPavk3+3XkxAV9s/2Y2r40S4F+udSheXVJ0sj+rRTkn0st+76r+s9O0rv/2aD/TOqliqUKuThyAIDZ0rcUMuNwN26RVG7fvl3e3t568sknXR0K4LS4hFs6ee6SShbJr+KFH9ILXRqp76j/aNOuH3X4x1/0rw//p31Hz6lv54auDhXwSA899JC8vb116VKMXfulmBiFhoa6KCog+3OLpDIqKkoDBgzQli1bdOHCBVeHAzglTy4fFS/8kKKvxCm3r48kKc0w7K5JTTXk5Y6/ZgIewMfHR49Ur6GNG9bb2tLS0rRx43rVrhvuwsjwILKYeLgblyeVCQkJ+vTTT/XCCy/oySef1Pz5823nNm3aJIvFovXr16tmzZrKnTu36tWrp+PHj9v1MWvWLJUsWVI+Pj4qW7asFi5caHfeYrFo9uzZatWqlXLnzq3y5ctr+/btOnnypBo3bqw8efKoXr16OnXqlO09p06dUps2bRQSEiI/Pz/VqlVL33zzzT2fo1evXmrVqpVdW0pKigoUKKCoqCgnvkJwdxNeaadHa5RS0YL5VLdqcX069XmlpqVpyZq9On42WifPXdKMt55WzYrFVLzwQxr4z6ZqVresVm066OrQAY/18qBIzYuao/98vEA/HDuml/u/oJs3bqhb956uDg3ItlyeVC5ZskTlypVT2bJl9eyzz2ru3Lky/lDVefPNNzVlyhTt2bNHOXLkUK9evWznVqxYoYEDB2rw4ME6cuSI+vbtq549e2rjxo12fYwdO1bdunXTgQMHVK5cOXXt2lV9+/bVsGHDtGfPHhmGoZdeesl2fUJCgp544gmtX79e+/fv1+OPP67WrVvr3Llzd32OPn36aM2aNbp48aKtbfXq1bp586Y6d+581/ckJSUpPj7e7kD283BIkD6e0FOHVg7Xf/7dS7/G3VCjblN05VqCbt9OU9sBs3TlWoI+e6evdi8Zpq6taqvPiIVa+91RV4cOeKx/dOqsCf+erDGjR6hOzWo6ePCAPl+9RiEhIX/9ZsABXrLIy2LC4Ya1SovxxwzuPqtfv746deqkgQMH6vbt2ypYsKCWLl2qxo0ba9OmTWrSpIm++eYbNWvWTJL01Vdf6cknn9StW7fk6+ur+vXrq2LFivrwww9tfXbq1Ek3btzQl19+KelOpfKtt97S2LFjJUk7duxQeHi4oqKibAnq4sWL1bNnT926deuesVaqVEn9+vWzJZ9hYWEaNGiQBg0aJEmqWLGiunfvrldffVWS9NRTTyk4OFjz5s27a3+jRo3S6NGjM7RbKz8ni7ePI19GAPfJtd0zXB0CgHuIj49XSHCg4uLiFBAQ4PJYAgMD9c2+n5THP+tjuXE9XhHVi7nFs6ZzaaXy+PHj2rVrl55++mlJUo4cOdS5c+cMw8VVqlSx/b1gwYKSpEuXLkmSjh07pvr169tdX79+fR07duyefaT/Jlq5cmW7tsTERFu1MCEhQUOGDFH58uUVFBQkPz8/HTt27J6VSulOtTI9gYyJidH//vc/u6rqHw0bNkxxcXG24/z58/e8FgAAwJ25dJ/KqKgo3b59W4UK/f/WKoZhyGq1asaM/68G5MyZ0/Z3y2+LG9LS/v9j8DLjbn38Wb9DhgzRunXrNHnyZJUqVUq5cuVSx44dlZycfM97dOvWTa+//rq2b9+ubdu2qXjx4mrQoME9r7darbJarQ49BwAAyEbMWlXjfqPfrksqb9++rY8//lhTpkxR8+bN7c61bdtW//3vf1WuXLm/7Kd8+fLaunWrunfvbmvbunWrKlSo4FR8W7duVY8ePdSuXTtJdyqXZ8+e/dP3BAcHq23btpo3b562b9+unj2Z8A0AADyDy5LK1atX69q1a+rdu7cCAwPtznXo0EFRUVF6++23/7KfoUOHqlOnTnrkkUcUERGhVatWafny5X+6UjszSpcureXLl6t169ayWCwaPnx4pqqjffr0UatWrZSammqX6AIAAM9j1kcquuPHNLpsTmVUVJQiIiIyJJTSnaRyz549OnTo0F/207ZtW73zzjuaPHmyKlasqNmzZ2vevHlq3LixU/FNnTpVefPmVb169dS6dWu1aNFC1atX/8v3RUREqGDBgmrRooXdsD4AAMCDzOWrvx80CQkJevjhhzVv3jy1b9/eofemrxRj9Tfgvlj9Dbgvd1z9vf7AOfmZsPo74Xq8mlUr6hbPms6lC3UeJGlpabpy5YqmTJmioKAgPfXUU64OCQAA4L4hqcwi586dU/HixVW4cGHNnz9fOXLwpQUAwNN50OJvksqsEhYWluGTgAAAgIfzoKzS5R/TCAAAgOyPSiUAAIBJ2FIIAAAAcACVSgAAAJNYLHcOM/p1N1QqAQAA4DQqlQAAACbxoMXfVCoBAADgPCqVAAAAZvGgUiWVSgAAADiNSiUAAIBJPGmfSpJKAAAAk7ClEAAAAOAAKpUAAAAm8aB1OlQqAQAA4DwqlQAAAGbxoFIllUoAAAA4jUolAACASTxpSyEqlQAAAHAalUoAAACTeNI+lSSVAAAAJvGgdToMfwMAAMB5VCoBAADM4kGlSiqVAAAAD7gtW7aodevWKlSokCwWi1auXGl33jAMjRgxQgULFlSuXLkUERGhEydOOHQPkkoAAACTWEz844gbN26oatWqmjlz5l3PT5o0Se+++64++OAD7dy5U3ny5FGLFi2UmJiY6Xsw/A0AAPCAa9mypVq2bHnXc4ZhaPr06XrrrbfUpk0bSdLHH3+skJAQrVy5Ul26dMnUPahUAgAAmCR9SyEzDkmKj4+3O5KSkhyO8cyZM4qOjlZERIStLTAwUHXq1NH27dsz3Q9JJQAAQDZVpEgRBQYG2o4JEyY43Ed0dLQkKSQkxK49JCTEdi4zGP4GAAAwidmLv8+fP6+AgABbu9VqNeFumUOlEgAAIJsKCAiwO/5OUhkaGipJiomJsWuPiYmxncsMkkoAAACzWEw8skjx4sUVGhqq9evX29ri4+O1c+dOhYeHZ7ofhr8BAABM8ne2/8lsv45ISEjQyZMnba/PnDmjAwcOKF++fCpatKgGDRqkcePGqXTp0ipevLiGDx+uQoUKqW3btpm+B0klAADAA27Pnj1q0qSJ7XVkZKQkqXv37po/f75effVV3bhxQ88//7xiY2P16KOPas2aNfL19c30PUgqAQAATPL77X+yul9HNG7cWIZh/El/Fo0ZM0Zjxoz52zExpxIAAABOo1IJAABgErO3FHInVCoBAADgNCqVAAAAZvGgUiWVSgAAADiNSiUAAIBJ3GWfyvuBpBIAAMAsJm0p5IY5JcPfAAAAcB6VSgAAAJN40DodKpUAAABwHpVKAAAAs3hQqZJKJQAAAJxGpRIAAMAknrSlEJVKAAAAOI1KJQAAgEksJu1Tacrel06iUgkAAACnUakEAAAwiQct/iapBAAAMI0HZZUMfwMAAMBpVCoBAABMwpZCAAAAgAOoVAIAAJjEIpO2FMr6Lp1GpRIAAABOo1IJAABgEg9a/E2lEgAAAM6jUgkAAGAST/qYRpJKAAAA03jOADjD3wAAAHAalUoAAACTeNLwN5VKAAAAOI1KJQAAgEk8Z0YllUoAAABkASqVAAAAJmFOJQAAAOAAKpUAAAAmsfz2x4x+3Q2VSgAAADiNSiUAAIBZPGj5N0klAACASTwop2T4GwAAAM6jUgkAAGASthQCAAAAHEClEgAAwCRsKQQAAAA4gEolAACAWTxo+TeVSgAAADiNSiUAAIBJPKhQSVIJAABgFrYUAgAAABxApRIAAMA05mwp5I4D4FQqAQAA4DQqlQAAACZhTiUAAADgAJJKAAAAOI2kEgAAAE5jTiUAAIBJmFMJAAAAOIBKJQAAgEksJu1Tac7el84hqQQAADAJw98AAACAA6hUAgAAmMQicz5Q0Q0LlVQqAQAA4DwqlQAAAGbxoFIllUoAAAA4jUolAACASTxpSyEqlQAAAHAalUoAAACTeNI+lSSVAAAAJvGgdToMfwMAAMB5VCoBAADM4kGlSiqVAAAAHmDmzJkKCwuTr6+v6tSpo127dmVp/ySVAAAAJrGY+McRn376qSIjIzVy5Ejt27dPVatWVYsWLXTp0qUse1aSSgAAgAfc1KlT9dxzz6lnz56qUKGCPvjgA+XOnVtz587NsnuQVAIAAJgkfUshM47MSk5O1t69exUREWFr8/LyUkREhLZv355lz8pCHTdiGMad/6YmuzgSAPcSHx/v6hAA3MP1374/03+eugOz/s1I7/eP/VutVlmtVru2K1euKDU1VSEhIXbtISEh+uGHH7IsJpJKN3L9+nVJUvLRBS6OBMC9hATPcXUIAP7C9evXFRgY6NIYfHx8FBoaqtLFi5h2Dz8/PxUpYt//yJEjNWrUKNPu+WdIKt1IoUKFdP78efn7+8vijlvlw2Hx8fEqUqSIzp8/r4CAAFeHA+B3+P588BiGoevXr6tQoUKuDkW+vr46c+aMkpPNG300DCNDvvDHKqUkPfTQQ/L29lZMTIxde0xMjEJDQ7MsHpJKN+Ll5aXChQu7OgyYICAggB9agJvi+/PB4uoK5e/5+vrK19fX1WHIx8dHNWrU0Pr169W2bVtJUlpamtavX6+XXnopy+5DUgkAAPCAi4yMVPfu3VWzZk3Vrl1b06dP140bN9SzZ88suwdJJQAAwAOuc+fOunz5skaMGKHo6GhVq1ZNa9asybB4xxkklYCJrFarRo4cedc5LgBci+9PeJqXXnopS4e7/8hiuNO6ewAAAGRLbH4OAAAAp5FUAgAAwGkklQAAAHAaSSUAAACcRlIJuJm0tDRXhwAgE/heBeyRVAJuYvr06Tp8+LC8vLz4YQVkA15ed36Ebt68WbGxsWIzFXg6kkrADSQkJGj58uVq2LChjh07RmIJZAOGYWjXrl1q0qSJoqOjZbFYSCzh0dinEnATv/zyi/r376+tW7dq8+bNqlChgtLS0mzVEADuqUWLFgoODtb8+fPl4+Pj6nAAl+GnFeAmHn74Yc2cOVN169ZVo0aNdPToUSqWgBu5ffu23euUlBRJUvv27XX69GnFxMRIYq4lPBdJJeAG0gcMHn74Yc2aNYvEEnAjZ86ckSTlyHHnk423bt2qpKQk5cyZU5LUtWtX/fLLL5o+fbokMboAj8X/+YALpSeTFovF1la4cGHNmjVLderUIbEEXOyFF17Qiy++qP3790uSvvnmG3Xr1k2VK1fWZ599piNHjsjf31+jR4/Wzp07dezYMRdHDLgOSSXgIoZhyGKxaMuWLXr99dc1YMAALVmyRNKdxPLDDz+0JZYs3gFco127dvrxxx81depUHT16VI0aNdKXX36pZs2aadKkSWrXrp0mT56snDlz6tKlSzpx4oQksWAHHomFOoALrVixQs8995zq1aunhx56SPPnz9fEiRM1aNAg+fj46MKFC3rxxRf1xRdf6NixYypbtqyrQwY8xu9/8evRo4dq166tN954Q1WqVJEkHT58WHv37tX48eNVo0YNLVmyRFWqVNHXX3+tAgUKuDh64P4jqQRcZM+ePWrbtq1GjBih559/XtHR0SpdurRu3LihwYMHa8KECcqRI4fOnz+voUOHasyYMSpTpoyrwwY8SnpiuXnzZvXs2VPh4eEaNGiQatWqZbvm3LlzOnLkiBYsWKANGzZo4cKFevzxx9m9AR6HpBJwgbS0NP33v//VsWPHNG7cOJ0/f14NGjRQq1atVKNGDfXu3Vvjxo3TkCFD5OPjo9TUVHl7e7s6bMAj3CsZ3LRpk3r16qXw8HANHjxY1atXz3BNq1atlJKSorVr196PUAG3wq9QwH2U/jucl5eXmjRporZt2yo5OVm9e/dWs2bN9M477+iJJ55QoUKF9NZbb2ns2LGSREIJ3Ce/Tyh/+uknHTlyRKmpqbp9+7YaN26sjz76SNu3b9eUKVNsi3ckKSkpSZLUt29fxcbG6urVqy6JH3AlkkrgPkhPJm/evGl7XahQIdWsWVNXrlzRlStX1LlzZ3l7e8tqteqJJ57QggUL9Mwzz7gybMCj/D6hHDFihFq1aqV69erpscce08KFC3Xjxg01bdpUH330kXbs2KGpU6dq586dkiSr1SpJWr16ta5cuWLbbgjwJCSVwH1gsVj05Zdf6h//+IfatWunjz/+WPHx8ZKk69ev6+DBg/rxxx8VExOjyZMna8eOHWrTpo3KlSvn4sgBz5GeUI4ePVpz5szRuHHjdObMGaWmpmry5MmaNWuWEhISbInlsmXLtGbNGtv7b9++LS8vLy1atEgBAQGuegzAZZhTCdwHO3fuVEREhPr166ddu3YpOTlZ1atX15gxYxQcHKyJEyfqjTfeUKlSpfTrr79q3bp1euSRR1wdNuBx9u3bp759+2r8+PFq3ry5Nm3apFatWqlKlSq6cuWKXnzxRT333HPKkyeP9u3bp6pVqzI9BfgNSSVgkvRVo5K0fPlyHThwQGPGjJEkTZo0SStXrlTlypU1ceJE5c2bV9u3b1dcXJwqVqyoIkWKuDJ0wGPFxMRo7dq16ty5s7Zv365OnTppwoQJ6t27t6pXr67ExER17NhRr7/+unLnzi1JLKQDfpPD1QEAD6L0hHL37t26cOGC9uzZI39/f9v5wYMHy2KxaPny5Xrrrbc0atQohYeHuzBiwPPcbZV3/vz51bp1a/n4+OjDDz9Ut27d1KNHD0lSmTJltHv3bl27dk25cuWyvYeEEriDpBIwgcVi0bJly9S9e3cFBQXp119/VdmyZTVw4EDlzp1b3t7eGjx4sLy8vBQVFSUfHx9NmTJFFovF7iMbAZjj9wnlxo0blTNnTuXNm1cVK1ZU3rx5lZaWpsuXLyt//vy2pDFHjhyaNWuWIiIiZLFY7EYjADD8DWSp9B8yN27c0MCBA/Xoo4/qiSee0IoVKzR79mwVK1ZMH3/8sa1qmZaWppkzZ6p169YKCwtzbfCAB3r99dc1a9YsBQcHKyEhQe+//746duyo5ORkPf/88/rhhx9UsWJFnTx5UlevXtXBgwfl7e3NxubAXfAdAWSh9CHv2rVr68KFC6pfv74KFCigPn36aNCgQbp48aL++c9/6vr165LurDYdMGAACSVwn/y+jnLs2DF98803Wr9+vRYtWqQ+ffqoU6dOmjt3rnx8fDRt2jRVqlRJsbGxKly4sPbv309CCfwJhr+BLJBeody3b59Onz6twMBAffvtt8qTJ4+kO3OuunbtKovFog8//FBPPfWUVq1aJT8/PxdHDniO3yeDiYmJunXrlho1aqSaNWtKkipVqiQfHx/16dNHaWlp6tOnj2bNmmW35+Tt27eVIwc/OoG74VctIAuk70PZoUMHBQQEaPTo0SpcuLDatGmjlJQUSXfmYz399NPq1q2bcubMqdjYWNcGDXiY9IRy1KhRatWqlXr16qW9e/cqLi5OkuTn56chQ4ZoxIgRevHFFzVjxgy7hNIwDBJK4E8wpxJwQnqFMiYmRkOGDFGtWrX08ssvKy0tTRs3btTgwYOVK1cubdq0yfaJG7dv39bNmzfZHBm4T35foZw5c6bGjRun7t27KyYmRgsWLNDkyZP1yiuv2Bbd3LhxQ8OHD9euXbv07bffshgHyCSSSsBJW7du1fjx4/Xrr79q+vTpqlu3rqQ7yeOmTZs0dOhQ+fv7a926dbbEEsD9t2/fPi1btkx169ZV69atJUlTp07V0KFDNXXqVL388su2BDIxMVFWq5VV3oADGP4GnBQaGqozZ85o165d2r9/v609R44catKkiaZMmaJz587pqaeecmGUgOcyDEP79u1TzZo1NWnSJLupJ5GRkXr77bc1ePBgzZgxw7aQx9fXl4QScBBJJeCkkiVLas2aNapWrZo++eQTbdiwwXbO29tbjRo10oIFCzRr1iwXRgl4LovFourVq+vjjz9Wamqqtm7dqitXrtjOR0ZGavLkyRo4cKA+++yzDO8FkDkMfwMOSK9aHD9+XOfPn1dQUJBCQ0NVuHBhnThxQh06dFDBggU1bNgwNW7c2NXhAh7p93Mo/1hpnD17tl544QWNHDlSL7/8svLmzWs79+mnn6pDhw4sxgH+Jr5zgExK/+G0bNkyDRw4UDlz5pRhGPL19dWHH36ohg0b6rPPPlPHjh319ttvKzk5Wc2bN3d12IBH+X1COWfOHB06dEjJycmqU6eO/vnPf6pv375KTU3VSy+9JEl2iWXnzp0lsW0Q8Hcx/A3cQ1pamu3vt2/flsVi0a5du9SzZ08NHz5c3333nRYsWKBatWqpRYsW+vbbb1WmTBktX75chw8f1uzZs3Xz5k0XPgHgedITyldffVWvv/66UlJSdOjQIb3zzjt66qmnlJycrBdffFHvv/++xo4dq3Hjxtk+jCAdCSXw9/CdA9yDl5eXfvrpJxUtWlQ5cuRQamqqDh8+rJo1a+q5556Tl5eXHn74YZUtW1ZpaWkaOHCgvvrqK5UqVUpbtmxRWlqacufO7erHADzO9u3btWTJEq1cuVINGjSQYRhavny5/v3vf6tLly5asmSJ+vXrp6SkJC1ZsoQPIQCyCJVK4B6SkpLUpUsXlShRQoZhyNvbW/Hx8Tpw4IDi4+Ml3RkSDw0NVdeuXXXlyhVdu3ZNkhQWFqYSJUq4MnzAY126dEm3bt1S6dKlJd1ZbPPkk0+qX79+On36tI4dOyZJGjhwoL777jvbKm8AziGpBO7Bx8dHb7/9tvz8/FS9enUZhqE2bdqoYMGCmjdvnmJjY20LAEqXLq2cOXNmGEYDYK7fJ4Ppfy9cuLCCgoLstvjy9fVVy5Yt9eOPP+r777+3tbNtEJB1SCqB3/x+DqV054dNvXr1NGfOHN26dUt16tRRiRIl1K5dO82bN09z5sxRTEyMEhISNHfuXHl5eSksLMw1wQMeKC0tzS4ZTE1NlSQVKVJEAQEBmjlzpo4cOWI77+3trXLlyikoKMiuHxJKIGuwpRCg/18xGh0drbNnz9o+FUeSUlJStH//fnXp0kVFihTR5s2bNWLECK1YsUInT55UtWrVdOrUKa1du1aPPPKIC58C8EyTJ0/W7t27lZqaqsjISNWrV0/Hjx9XRESEypUrpyZNmqhSpUqaMWOGLl++rD179sjb29vVYQMPHJJK4Dfnz5/XI488ol9//VWNGjVSeHi4IiIiVLNmTQUEBGj37t3q3bu3AgIC9N133yk6OlpfffWV8ubNq+rVq6tYsWKufgTAI/x+26AxY8ZoxowZatOmjU6dOqXNmzfr448/1jPPPKOTJ09qxIgROnDggKxWqwoXLqzly5crZ86cSk1NJbEEshhJJfCbn376SW3bttWtW7fk7++vihUr6tNPP1W5cuVUuXJltWrVShaLRcOGDVOJEiW0du1ahs0AF/rll18UFRWlpk2b6tFHH9WtW7c0evRoTZkyRfPmzdOzzz6rpKQkpaSk6Pr16woNDZXFYmEfSsAkJJXA75w8eVKvvvqq0tLSNGzYMBUsWFDbtm3TjBkzlJKSoiNHjqhkyZI6cuSI2rRpoxUrVjDJH3CBzz//XO3atVNYWJgWL16s2rVrS7ozXWX48OGaOnWqPv74Y3Xp0sXufb+vcgLIWiSVwB8cP35cAwcOVFpamsaPH69atWpJkmJjY7Vq1Sr98MMP+t///qeoqCjmUAL3SXoymP7fCxcuaPz48Zo9e7aWLVumNm3a2M7dvn1bI0eO1IQJE7Ru3To1a9bM1eEDHoGkEriLEydOaMCAAZKkYcOGqVGjRnbnGT4D7p/Fixfr66+/1uuvv66HH35YefLkkSTFxMRo6NChWrZsmdatW6d69erZRg5SUlIUFRWlPn368L0K3CcklcA9nDhxQi+//LIMw9CIESNUr149V4cEeJz4+HhVr15d8fHxCg0NVe3atfXoo4+qR48ekqSbN2+qd+/e+uKLL/T111+rfv36Gaak8EsgcH+QVAJ/4sSJE4qMjNSVK1c0bdo0u62GAJgvNTVVw4cPV7FixVSrVi1t2LBB48ePV8uWLVWlShUNHjxYcXFxGjFihBYuXKgvvvhCTZo0cXXYgEditjLwJ0qXLq23335bhQsXVqFChVwdDuBxvL291aBBAw0dOlQ5cuTQkCFDdPHiRZUqVUpvvPGGwsPDNXfuXLVv314tW7bU+PHjXR0y4LGoVAKZkJycLB8fH1eHAXis/v37S5JmzpwpSapYsaLKlCmjkiVL6vvvv9fatWs1efJkDRo0iNXdgIswyQTIBBJKwLWqV6+uefPm6dq1a2rWrJny5s2rBQsWKCAgQD///LO2bdum9u3b260QB3B/UakEAGQLtWvX1p49e9SwYUMtX75c+fLly3ANi3IA1+FXOQCAW0uvfbz88suqWLGipkyZonz58uluNRESSsB1SCoBAG4tfXugJk2a6OrVq1q3bp1dOwD3QFIJAMgWHn74YQ0bNkyTJ0/W0aNHXR0OgD9gnAAAkG088cQT2rNnj8qVK+fqUAD8AQt1AADZSvon5qSmpsrb29vV4QD4DUklAAAAnMacSgAAADiNpBIAAABOI6kEAACA00gqAQAA4DSSSgAAADiNpBIAAABOI6kE8EDo0aOH2rZta3vduHFjDRo06L7HsWnTJlksFsXGxpp2jz8+699xP+IE4FlIKgGYpkePHrJYLLJYLPLx8VGpUqU0ZswY3b592/R7L1++XGPHjs3Utfc7wQoLC9P06dPvy70A4H7hYxoBmOrxxx/XvHnzlJSUpK+++kr9+/dXzpw5NWzYsAzXJicny8fHJ0vumy9fvizpBwCQOVQqAZjKarUqNDRUxYoV0wsvvKCIiAh98cUXkv5/GHf8+PEqVKiQypYtK0k6f/68OnXqpKCgIOXLl09t2rTR2bNnbX2mpqYqMjJSQUFBCg4O1quvvqo/fjjYH4e/k5KS9Nprr6lIkSKyWq0qVaqUoqKidPbsWTVp0kSSlDdvXlksFvXo0UOSlJaWpgkTJqh48eLKlSuXqlatqs8++8zuPl999ZXKlCmjXLlyqUmTJnZx/h2pqanq3bu37Z5ly5bVO++8c9drR48erfz58ysgIED9+vVTcnKy7VxmYgeArESlEsB9lStXLl29etX2ev369QoICNC6deskSSkpKWrRooXCw8P17bffKkeOHBo3bpwef/xxHTp0SD4+PpoyZYrmz5+vuXPnqnz58poyZYpWrFihpk2b3vO+3bp10/bt2/Xuu++qatWqOnPmjK5cuaIiRYpo2bJl6tChg44fP66AgADlypVLkjRhwgT95z//0QcffKDSpUtry5YtevbZZ5U/f341atRI58+fV/v27dW/f389//zz2rNnjwYPHuzU1yctLU2FCxfW0qVLFRwcrG3btun5559XwYIF1alTJ7uvm6+vrzZt2qSzZ8+qZ8+eCg4O1vjx4zMVOwBkOQMATNK9e3ejTZs2hmEYRlpamrFu3TrDarUaQ4YMsZ0PCQkxkpKSbO9ZuHChUbZsWSMtLc3WlpSUZOTKlctYu3atYRiGUbBgQWPSpEm28ykpKUbhwoVt9zIMw2jUqJExcOBAwzAM4/jx44YkY926dXeNc+PGjYYk49q1a7a2xMREI3fu3Ma2bdvsru3du7fx9NNPG4ZhGMOGDTMqVKhgd/61117L0NcfFStWzJg2bdo9z/9R//79jQ4dOthed+/e3ciXL59x48YNW9usWbMMPz8/IzU1NVOx3+2ZAcAZVCoBmGr16tXy8/NTSkqK0tLS1LVrV40aNcp2vnLlynbzKA8ePKiTJ0/K39/frp/ExESdOnVKcXFxunjxourUqWM7lyNHDtWsWTPDEHi6AwcOyNvb26EK3cmTJ3Xz5k099thjdu3Jycl65JFHJEnHjh2zi0OSwsPDM32Pe5k5c6bmzp2rc+fO6datW0pOTla1atXsrqlatapy585td9+EhASdP39eCQkJfxk7AGQ1kkoApmrSpIlmzZolHx8fFSpUSDly2P+zkydPHrvXCQkJqlGjhj755JMMfeXPn/9vxZA+nO2IhIQESdKXX36phx9+2O6c1Wr9W3FkxuLFizVkyBBNmTJF4eHh8vf319tvv62dO3dmug9XxQ7As5FUAjBVnjx5VKpUqUxfX716dX366acqUKCAAgIC7npNwYIFtXPnTjVs2FCSdPv2be3du1fVq1e/6/WVK1dWWlqaNm/erIiIiAzn0yulqamptrYKFSrIarXq3Llz96xwli9f3rboKN2OHTv++iH/xNatW1WvXj29+OKLtrZTp05luO7gwYO6deuWLWHesWOH/Pz8VKRIEeXLl+8vYweArMbqbwBu5ZlnntFDDz2kNm3a6Ntvv9WZM2e0adMmvfzyy/r5558lSQMHDtTEiRO1cuVK/fDDD3rxxRf/dI/JsLAwde/eXb169dLKlSttfS5ZskSSVKxYMVksFq1evVqXL19WQkKC/P39NWTIEL3yyitasGCBTp06pX379um9997TggULJEn9+vXTiRMnNHToUB0/flyLFi3S/PnzM/Wcv/zyiw4cOGB3XLt2TaVLl9aePXu0du1a/fjjjxo+fLh2796d4f3Jycnq3bu3jh49qq+++kojR47USy+9JC8vr0zFDgBZztWTOgE8uH6/UMeR8xcvXjS6detmPPTQQ4bVajVKlChhPPfcc0ZcXJxhGHcW5gwcONAICAgwgoKCjMjISKNbt273XKhjGIZx69Yt45VXXjEKFixo+Pj4GKVKlTLmzp1rOz9mzBgjNDTUsFgsRvfu3Q3DuLO4aPr06UbZsmWNnDlzGvnz5zdatGhhbN682fa+VatWGaVKlTKsVqvRoEEDY+7cuZlaqCMpw7Fw4UIjMTHR6NGjhxEYGGgEBQUZL7zwgvH6668bVatWzfB1GzFihBEcHGz4+fkZzz33nJGYmGi75q9iZ6EOgKxmMYx7zGwHAAAAMonhbwAAADiNpBIAAABOI6kEAACA00gqAQAA4DSSSgAAADiNpBIAAABOI6kEAACA00gqAQAA4DSSSgAAADiNpBIAAABOI6kEAACA00gqAQAA4LT/A1LoQss7DPyPAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "from sklearn.metrics import confusion_matrix\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Get the true and predicted labels from the results_dataframe\n", + "true_labels = results_dataframe['true']\n", + "predicted_labels = results_dataframe['predicted']\n", + "\n", + "# Compute the confusion matrix\n", + "cm = confusion_matrix(true_labels, predicted_labels)\n", + "\n", + "# Define the class labels\n", + "class_labels = ['Normal', 'Anomaly']\n", + "\n", + "# Plot the confusion matrix\n", + "plt.figure(figsize=(8, 6))\n", + "plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n", + "plt.title('Confusion Matrix')\n", + "plt.colorbar()\n", + "tick_marks = np.arange(len(class_labels))\n", + "plt.xticks(tick_marks, class_labels, rotation=45)\n", + "plt.yticks(tick_marks, class_labels)\n", + "plt.xlabel('Predicted Label')\n", + "plt.ylabel('True Label')\n", + "\n", + "# Add the values to the confusion matrix plot\n", + "thresh = cm.max() / 2.\n", + "for i in range(cm.shape[0]):\n", + " for j in range(cm.shape[1]):\n", + " plt.text(j, i, format(cm[i, j], 'd'),\n", + " horizontalalignment=\"center\",\n", + " color=\"white\" if cm[i, j] > thresh else \"black\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/objects.inv b/objects.inv index 75d1ed9..c708810 100644 Binary files a/objects.inv and b/objects.inv differ diff --git a/py-modindex.html b/py-modindex.html index 3033af8..421cc5b 100644 --- a/py-modindex.html +++ b/py-modindex.html @@ -6,7 +6,7 @@ Python Module Index — SSLTools documentation - +