From 6ec04e51a193b6025162843b2665d032d1fdd369 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 14 Nov 2024 18:13:34 +0100 Subject: [PATCH] style: Fix ruff issues --- python/nutpie/compile_pymc.py | 2 +- python/nutpie/compile_stan.py | 4 ++-- python/nutpie/transform_adapter.py | 17 +++++++++-------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 58b6b0d..5510053 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -380,7 +380,7 @@ def _compile_pymc_model_jax( def logp_fn_jax_grad(x, *shared): return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x) - static_argnums = list(range(1, len(logp_shared_names) + 1)) + # static_argnums = list(range(1, len(logp_shared_names) + 1)) logp_fn_jax_grad = jax.jit( logp_fn_jax_grad, # static_argnums=static_argnums, diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 70ecd14..138652d 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,10 +1,10 @@ -from functools import partial import json import tempfile from dataclasses import dataclass, replace +from functools import partial from importlib.util import find_spec from pathlib import Path -from typing import Any, Optional, Callable +from typing import Any, Optional import numpy as np import pandas as pd diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index f658ca6..e0e601b 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -12,18 +12,18 @@ def make_transform_adapter( untransformed_dim=None, zero_init=True, ): - import jax + import traceback + from functools import partial + import equinox as eqx - import jax.numpy as jnp import flowjax - import flowjax.train import flowjax.flows + import flowjax.train + import jax + import jax.numpy as jnp + import numpy as np import optax - import traceback from paramax import Parameterize, unwrap - from functools import partial - - import numpy as np class FisherLoss: @eqx.filter_jit @@ -363,7 +363,7 @@ def update(self, seed, positions, gradients): else: base = self._bijection - # make_flow might still only return a single trafo if the for 1d problems + # make_flow might still only return a single trafo for 1d problems if len(base.bijections) == 1: self._bijection = base return @@ -436,6 +436,7 @@ def update(self, seed, positions, gradients): except Exception as e: print("update error:", e) print(traceback.format_exc()) + raise def init_from_transformed_position(self, transformed_position): try: