diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4c44e5..f67fdd3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,19 +1,19 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.15.2 hooks: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/omnilib/ufmt - rev: v2.0.1 + rev: v2.5.1 hooks: - id: ufmt additional_dependencies: @@ -21,6 +21,6 @@ repos: - usort == 1.1.0b2 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.0.0 hooks: - id: flake8 diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py index 7f28e44..43d500e 100644 --- a/dmlcloud/pipeline.py +++ b/dmlcloud/pipeline.py @@ -70,7 +70,7 @@ def register_model( if verbose: msg = f'Model "{name}":\n' - msg += f' - Parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f} kk\n' + msg += f' - Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f} kk\n' msg += f' - DDP: {use_ddp}\n' msg += f' - {model}' self.logger.info(msg) @@ -231,7 +231,7 @@ def _pre_run(self): self._resume_run() diagnostics = general_diagnostics() - diagnostics += '\n* CONFIG:\n' + diagnostics += '\n* CONFIG:\n' diagnostics += '\n'.join(f' {line}' for line in OmegaConf.to_yaml(self.config, resolve=True).splitlines()) self.logger.info(diagnostics) diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py index 0dee40a..772250a 100644 --- a/dmlcloud/util/data.py +++ b/dmlcloud/util/data.py @@ -47,7 +47,7 @@ def sharded_xr_dataset( start = chunk_idx * chunk_size end = start + chunk_size chunk = ds.isel({dim: slice(start, end)}) - + if load: kwargs = load_kwargs or {} chunk.load(**kwargs) @@ -123,7 +123,7 @@ def interleave_batches( """ if num_batches < 1: raise ValueError('num_batches must be greater than 0') - + if num_batches == 1: yield from iterable