diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e09d665c3ff..d4d221e3c0d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.0 + rev: v0.9.1 hooks: - id: ruff types_or: diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index 689b2eebd33..e148945afbb 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -707,7 +707,7 @@ "sample = dataset[idx]\n", "rgb = sample['image'][0, 1:4]\n", "image = T.ToPILImage()(rgb)\n", - "print(f\"Class Label: {dataset.classes[sample['label']]}\")\n", + "print(f'Class Label: {dataset.classes[sample[\"label\"]]}')\n", "image.resize((256, 256), resample=Image.BILINEAR)" ] }, diff --git a/experiments/torchgeo/run_resisc45_experiments.py b/experiments/torchgeo/run_resisc45_experiments.py index 6897ea12772..9ed69b03968 100755 --- a/experiments/torchgeo/run_resisc45_experiments.py +++ b/experiments/torchgeo/run_resisc45_experiments.py @@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" + experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}' output_dir = os.path.join('output', 'resisc45_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/experiments/torchgeo/run_so2sat_byol_experiments.py b/experiments/torchgeo/run_so2sat_byol_experiments.py index 169a010cef8..4ae78601fbd 100755 --- a/experiments/torchgeo/run_so2sat_byol_experiments.py +++ b/experiments/torchgeo/run_so2sat_byol_experiments.py @@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights, bands in itertools.product( model_options, lr_options, loss_options, weight_options, bands_options ): - experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}" + experiment_name = f'{model}_{lr}_{loss}_byol_{bands}-{weights.split("/")[-2]}' output_dir = os.path.join('output', 'so2sat_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/experiments/torchgeo/run_so2sat_experiments.py b/experiments/torchgeo/run_so2sat_experiments.py index 41e2fc04b5f..44ba5c7aaf3 100755 --- a/experiments/torchgeo/run_so2sat_experiments.py +++ b/experiments/torchgeo/run_so2sat_experiments.py @@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" + experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}' output_dir = os.path.join('output', 'so2sat_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/experiments/torchgeo/run_so2sat_seed_experiments.py b/experiments/torchgeo/run_so2sat_seed_experiments.py index 2d2efe1e248..4f71770917a 100755 --- a/experiments/torchgeo/run_so2sat_seed_experiments.py +++ b/experiments/torchgeo/run_so2sat_seed_experiments.py @@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights, seed in itertools.product( model_options, lr_options, loss_options, weight_options, seeds ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}_{seed}" + experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}_{seed}' output_dir = os.path.join('output', 'so2sat_seed_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/tests/data/inria/data.py b/tests/data/inria/data.py index 96626dea5fa..76ea3aa9c60 100755 --- a/tests/data/inria/data.py +++ b/tests/data/inria/data.py @@ -68,9 +68,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str: lbl = np.random.randint(2, size=size, dtype=dtype) timg = np.random.randint(dtype_max, size=size, dtype=dtype) - img_path = os.path.join(img_dir, f'austin{i+1}.tif') - lbl_path = os.path.join(lbl_dir, f'austin{i+1}.tif') - timg_path = os.path.join(timg_dir, f'austin{i+10}.tif') + img_path = os.path.join(img_dir, f'austin{i + 1}.tif') + lbl_path = os.path.join(lbl_dir, f'austin{i + 1}.tif') + timg_path = os.path.join(timg_dir, f'austin{i + 10}.tif') write_data(img_path, img, driver, crs, transform) write_data(lbl_path, lbl, driver, crs, transform) diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py index e3197ddde12..86f6210636b 100644 --- a/tests/data/seasonet/data.py +++ b/tests/data/seasonet/data.py @@ -63,7 +63,7 @@ os.remove(archive) for grid, comp in zip(grids, name_comps): - file_name = f"{comp[0]}_{''.join(comp[1:8])}_{'_'.join(comp[8:])}" + file_name = f'{comp[0]}_{"".join(comp[1:8])}_{"_".join(comp[8:])}' dir = os.path.join(season, f'grid{grid}', file_name) os.makedirs(dir) diff --git a/tests/data/ssl4eo_benchmark_landsat/data.py b/tests/data/ssl4eo_benchmark_landsat/data.py index 177ed7d7954..5470aacef05 100755 --- a/tests/data/ssl4eo_benchmark_landsat/data.py +++ b/tests/data/ssl4eo_benchmark_landsat/data.py @@ -193,7 +193,7 @@ def create_tarballs(directories: str) -> None: # mask directory cdl mask_keep = ['tm_toa', 'etm_sr', 'oli_sr'] mask_filenames = { - f"ssl4eo_l_{key.split('_')[0]}_cdl": val + f'ssl4eo_l_{key.split("_")[0]}_cdl': val for key, val in filenames.items() if key in mask_keep } @@ -203,7 +203,7 @@ def create_tarballs(directories: str) -> None: # mask directory nlcd mask_filenames = { - f"ssl4eo_l_{key.split('_')[0]}_nlcd": val + f'ssl4eo_l_{key.split("_")[0]}_nlcd': val for key, val in filenames.items() if key in mask_keep } diff --git a/tests/datamodules/test_digital_typhoon.py b/tests/datamodules/test_digital_typhoon.py index 0ecd85f5ec7..dd61eb26933 100644 --- a/tests/datamodules/test_digital_typhoon.py +++ b/tests/datamodules/test_digital_typhoon.py @@ -57,14 +57,14 @@ def find_max_time_per_id( # Assert that each max value in train_max_values is lower # than in val_max_values for each key id for id, max_value in train_max_values.items(): - assert ( - id not in val_max_values or max_value < val_max_values[id] - ), f'Max value for id {id} in train is not lower than in validation.' + assert id not in val_max_values or max_value < val_max_values[id], ( + f'Max value for id {id} in train is not lower than in validation.' + ) else: train_ids = {seq['id'] for seq in train_sequences} val_ids = {seq['id'] for seq in val_sequences} # Assert that the intersection between train_ids and val_ids is empty - assert ( - len(train_ids & val_ids) == 0 - ), 'Train and validation datasets have overlapping ids.' + assert len(train_ids & val_ids) == 0, ( + 'Train and validation datasets have overlapping ids.' + ) diff --git a/torchgeo/datamodules/digital_typhoon.py b/torchgeo/datamodules/digital_typhoon.py index ce799bf3d52..9ebc1643255 100644 --- a/torchgeo/datamodules/digital_typhoon.py +++ b/torchgeo/datamodules/digital_typhoon.py @@ -43,9 +43,9 @@ def __init__( """ super().__init__(DigitalTyphoon, batch_size, num_workers, **kwargs) - assert ( - split_by in self.valid_split_types - ), f'Please choose from {self.valid_split_types}' + assert split_by in self.valid_split_types, ( + f'Please choose from {self.valid_split_types}' + ) self.split_by = split_by def _split_dataset( diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py index a197a789c48..19128cdbb3d 100644 --- a/torchgeo/datamodules/ftw.py +++ b/torchgeo/datamodules/ftw.py @@ -44,9 +44,9 @@ def __init__( Raises: AssertionError: If 'countries' are specified in kwargs """ - assert ( - 'countries' not in kwargs - ), "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs" + assert 'countries' not in kwargs, ( + "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs" + ) super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 3624c1e193e..ba116377878 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -149,9 +149,9 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 38669cd6ff1..ef62ac1a280 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -565,9 +565,9 @@ def plot( ax.imshow(image) ax.axis('off') if show_titles: - title = f"Labels: {', '.join(labels)}" + title = f'Labels: {", ".join(labels)}' if showing_predictions: - title += f"\nPredictions: {', '.join(predictions)}" + title += f'\nPredictions: {", ".join(predictions)}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index 70a53a4220a..2531c96dd23 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -81,14 +81,14 @@ def __init__( """ self.root = root - assert ( - split in self.valid_splits - ), f'Please choose one of the valid splits: {self.valid_splits}.' + assert split in self.valid_splits, ( + f'Please choose one of the valid splits: {self.valid_splits}.' + ) self.split = split - assert set(sensors).issubset( - set(self.valid_sensors) - ), f'Please choose a subset of valid sensors: {self.valid_sensors}.' + assert set(sensors).issubset(set(self.valid_sensors)), ( + f'Please choose a subset of valid sensors: {self.valid_sensors}.' + ) self.sensors = sensors self.as_time_series = as_time_series diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 0b0f6ac5b3d..2de5719beb0 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -248,9 +248,9 @@ def __init__( 'CDL data product only exists for the following years: ' f'{list(self.md5s.keys())}.' ) - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index f9db256238d..681d5026f25 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -204,15 +204,15 @@ def __init__( self.checksum = checksum assert isinstance(country, str), 'Country argument must be a str.' - assert ( - country in self.all_countries - ), f'You have selected an invalid country, please choose one of {self.all_countries}' + assert country in self.all_countries, ( + f'You have selected an invalid country, please choose one of {self.all_countries}' + ) self.country = country assert isinstance(measurement, str), 'Measurement must be a string.' - assert ( - measurement in self.measurements - ), f'You have entered an invalid measurement, please choose one of {self.measurements}.' + assert measurement in self.measurements, ( + f'You have entered an invalid measurement, please choose one of {self.measurements}.' + ) self.measurement = measurement self.filename_glob = f'**/Mangrove_{self.measurement}_{self.country}*' diff --git a/torchgeo/datasets/digital_typhoon.py b/torchgeo/datasets/digital_typhoon.py index 42bb4caa1bd..dfa47966440 100644 --- a/torchgeo/datasets/digital_typhoon.py +++ b/torchgeo/datasets/digital_typhoon.py @@ -139,9 +139,9 @@ def __init__( self.min_feature_value = min_feature_value self.max_feature_value = max_feature_value - assert ( - task in self.valid_tasks - ), f'Please choose one of {self.valid_tasks}, you provided {task}.' + assert task in self.valid_tasks, ( + f'Please choose one of {self.valid_tasks}, you provided {task}.' + ) self.task = task assert set(features).issubset(set(self.valid_features)) diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 93b6b18e455..e3066058371 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -115,9 +115,9 @@ def __init__( DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits - assert set(scene).intersection( - set(self.scenes) - ), "The possible scenes are 'rural' and/or 'urban'" + assert set(scene).intersection(set(self.scenes)), ( + "The possible scenes are 'rural' and/or 'urban'" + ) assert len(scene) <= 2, "There are no other scenes than 'rural' or 'urban'" self.root = root diff --git a/torchgeo/datasets/mdas.py b/torchgeo/datasets/mdas.py index 1ee020ab953..25a61a72396 100644 --- a/torchgeo/datasets/mdas.py +++ b/torchgeo/datasets/mdas.py @@ -162,13 +162,13 @@ def __init__( """ self.root = root self.download = download - assert all( - sub in self.valid_subareas for sub in subareas - ), f'Subareas must be one of {self.valid_subareas}' + assert all(sub in self.valid_subareas for sub in subareas), ( + f'Subareas must be one of {self.valid_subareas}' + ) self.subareas = subareas - assert all( - mod in self.valid_modalities for mod in modalities - ), f'Modalities must be one of {self.valid_modalities}' + assert all(mod in self.valid_modalities for mod in modalities), ( + f'Modalities must be one of {self.valid_modalities}' + ) self.modalities = modalities self.transforms = transforms self.checksum = checksum diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py index f363276c40a..b940537d8b4 100644 --- a/torchgeo/datasets/mmearth.py +++ b/torchgeo/datasets/mmearth.py @@ -206,12 +206,12 @@ def __init__( """ lazy_import('h5py') - assert ( - normalization_mode in self.norm_modes - ), f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}' - assert ( - subset in self.subsets - ), f'Invalid dataset version: {subset}, please choose from {self.subsets}' + assert normalization_mode in self.norm_modes, ( + f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}' + ) + assert subset in self.subsets, ( + f'Invalid dataset version: {subset}, please choose from {self.subsets}' + ) self._validate_modalities(modalities) self.modalities = modalities diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 501fd6db8f9..681f0e242bc 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -167,9 +167,9 @@ def __init__( 'NLCD data product only exists for the following years: ' f'{list(self.md5s.keys())}.' ) - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 3e47a8ec491..ebbb9036374 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -450,7 +450,7 @@ def plot( axs[ax].imshow(image) axs[ax].axis('off') if show_titles: - axs[ax].set_title(f'Image {ax+1}') + axs[ax].set_title(f'Image {ax + 1}') axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation='none') axs[ax + 1].axis('off') diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 3accf32d2af..d5a68f48488 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -104,14 +104,14 @@ def __init__( """ lazy_import('h5py') - assert ( - split in self.valid_splits - ), f'Please choose one of these valid data splits {self.valid_splits}.' + assert split in self.valid_splits, ( + f'Please choose one of these valid data splits {self.valid_splits}.' + ) self.split = split - assert ( - task in self.valid_tasks - ), f'Please choose one of these valid tasks {self.valid_tasks}.' + assert task in self.valid_tasks, ( + f'Please choose one of these valid tasks {self.valid_tasks}.' + ) self.task = task self.root = root diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index a8643873c5b..841cd7173de 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -131,9 +131,9 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 13c5a8474c4..111fe487e09 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -138,26 +138,26 @@ def __init__( AssertionError: if any arguments are invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert ( - sensor in self.valid_sensors - ), f'Only supports one of {self.valid_sensors}, but found {sensor}.' + assert sensor in self.valid_sensors, ( + f'Only supports one of {self.valid_sensors}, but found {sensor}.' + ) self.sensor = sensor - assert ( - product in self.valid_products - ), f'Only supports one of {self.valid_products}, but found {product}.' + assert product in self.valid_products, ( + f'Only supports one of {self.valid_products}, but found {product}.' + ) self.product = product - assert ( - split in self.valid_splits - ), f'Only supports one of {self.valid_splits}, but found {split}.' + assert split in self.valid_splits, ( + f'Only supports one of {self.valid_splits}, but found {split}.' + ) self.split = split self.cmap = self.cmaps[product] if classes is None: classes = list(self.cmap.keys()) - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.root = root diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index eec9be57ab3..4d3a0b4de9f 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -82,14 +82,14 @@ def __init__( is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert set(countries).issubset( - self.valid_countries - ), f'Please choose a subset of these valid countried: {self.valid_countries}.' + assert set(countries).issubset(self.valid_countries), ( + f'Please choose a subset of these valid countried: {self.valid_countries}.' + ) self.countries = countries - assert ( - split in self.valid_splits - ), f'Pleas choose one of these valid data splits {self.valid_splits}.' + assert split in self.valid_splits, ( + f'Pleas choose one of these valid data splits {self.valid_splits}.' + ) self.split = split self.root = root diff --git a/torchgeo/models/croma.py b/torchgeo/models/croma.py index 475c32fd3a9..de57a835936 100644 --- a/torchgeo/models/croma.py +++ b/torchgeo/models/croma.py @@ -56,9 +56,9 @@ def __init__( """ super().__init__() for modality in modalities: - assert ( - modality in self.valid_modalities - ), f'{modality} is not a valid modality' + assert modality in self.valid_modalities, ( + f'{modality} is not a valid modality' + ) assert image_size % 8 == 0, 'image_size must be a multiple of 8' assert num_heads % 2 == 0, 'num_heads must be a power of 2'