diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml index 5dd19985..5a5f6ed1 100644 --- a/.github/workflows/build_and_publish.yml +++ b/.github/workflows/build_and_publish.yml @@ -41,13 +41,13 @@ jobs: pip install -e . - name: Run mypy - run: mypy src + run: mypy src/hssm - - name: Check styling - run: black . --check + - name: Check formatting + run: ruff format --check . - name: Linting - run: ruff check . + run: ruff check src/hssm - name: Run tests run: pytest -n auto -s diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index a143656e..3708c2c5 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -39,13 +39,13 @@ jobs: pip install -e . - name: Run mypy - run: mypy src + run: mypy src/hssm - - name: Check styling - run: black . --check + - name: Check formatting + run: ruff format --check . - name: Linting - run: ruff check . + run: ruff check src/hssm - name: Run tests run: pytest -n auto -s diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2f91c585..191a2ec0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,31 +5,13 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.4.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/psf/black - rev: 23.10.1 - hooks: - - id: black-jupyter - args: - - --line-length=88 - - --include='\.pyi?$' - - # these folders wont be formatted by black - - --exclude="""\.git | - \.__pycache__| - \.hg| - \.mypy_cache| - \.tox| - \.venv| - _build| - buck-out| - build| - dist""" + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 # Use the sha / tag you want to point at + rev: v1.10.0 # Use the sha / tag you want to point at hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] diff --git a/pyproject.toml b/pyproject.toml index d53be14b..c0dd632d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,58 +16,42 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"] [tool.poetry.dependencies] python = ">=3.10,<3.12" -pymc = ">=5.12" +pymc = "^5.14.0" arviz = "^0.18.0" -onnx = "^1.12.0" +onnx = "^1.16.0" jax = "^0.4.25" jaxlib = "^0.4.25" -ssm-simulators = "^0.7.0" -huggingface-hub = "^0.15.1" +ssm-simulators = "^0.7.2" +huggingface-hub = "^0.23.0" bambi = "^0.13.0" numpyro = "^0.14.0" -hddm-wfpt = "^0.1.1" +hddm-wfpt = "^0.1.4" seaborn = "^0.13.2" [tool.poetry.group.dev.dependencies] -pytest = "^7.3.1" -black = { extras = ["jupyter"], version = "^23.10.1" } -mypy = "^1.6.1" +pytest = "^8.2.0" +mypy = "^1.10.0" pre-commit = "^2.20.0" jupyterlab = "^4.0.2" ipykernel = "^6.16.0" ipywidgets = "^8.0.3" -ruff = "^0.1.3" -mkdocs = "^1.4.3" -mkdocs-material = "^9.1.17" -mkdocstrings-python = "^1.1.2" -mkdocs-jupyter = "^0.24.1" +ruff = "^0.4.3" graphviz = "^0.20.1" -pytest-xdist = "^3.5.0" +pytest-xdist = "^3.6.1" onnxruntime = "^1.17.1" - -[tool.black] -line-length = 88 -include = '\.pyi?$' -exclude = ''' -/( - \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist -)/ -''' - -[tool.isort] -profile = "black" +mkdocs = "^1.6.0" +mkdocs-material = "^9.5.21" +mkdocstrings-python = "^1.10.0" +mkdocs-jupyter = "^0.24.7" [tool.ruff] line-length = 88 target-version = "py310" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] unfixable = ["E711"] select = [ @@ -166,9 +150,9 @@ ignore = [ "TID252", ] -exclude = [".github", "docs", "notebook", "tests"] +exclude = [".github", "docs", "notebook", "tests/*"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "numpy" [tool.mypy] diff --git a/src/hssm/distribution_utils/onnx/onnx2pt.py b/src/hssm/distribution_utils/onnx/onnx2pt.py index cd562f9c..9c04bb33 100644 --- a/src/hssm/distribution_utils/onnx/onnx2pt.py +++ b/src/hssm/distribution_utils/onnx/onnx2pt.py @@ -17,9 +17,7 @@ def onnx_add(a, b, axis=None, broadcast=True): return [pt.add(a, b)] -def pytensor_gemm( - a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0 -): # pylint: disable=C0103 +def pytensor_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0): # pylint: disable=C0103 """Perform the GEMM op. Numpy-backed implementatio, of ONNX General Matrix Multiply (GeMM) op. diff --git a/src/hssm/distribution_utils/onnx/onnx2xla.py b/src/hssm/distribution_utils/onnx/onnx2xla.py index 76c386b7..a2142be1 100644 --- a/src/hssm/distribution_utils/onnx/onnx2xla.py +++ b/src/hssm/distribution_utils/onnx/onnx2xla.py @@ -102,9 +102,7 @@ def onnx_add(a, b, axis=None, broadcast=True): # Added by HSSM Developers -def onnx_gemm( - a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0 -): # pylint: disable=C0103 +def onnx_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0): # pylint: disable=C0103 """Numpy-backed implementatio of Onnx Gemm op.""" a = jnp.transpose(a) if transA else a b = jnp.transpose(b) if transB else b