Skip to content

Commit

Permalink
style: Fix ruff issues
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Nov 14, 2024
1 parent 7b4dd2e commit 6ec04e5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def _compile_pymc_model_jax(
def logp_fn_jax_grad(x, *shared):
return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x)

static_argnums = list(range(1, len(logp_shared_names) + 1))
# static_argnums = list(range(1, len(logp_shared_names) + 1))
logp_fn_jax_grad = jax.jit(
logp_fn_jax_grad,
# static_argnums=static_argnums,
Expand Down
4 changes: 2 additions & 2 deletions python/nutpie/compile_stan.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from functools import partial
import json
import tempfile
from dataclasses import dataclass, replace
from functools import partial
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Optional, Callable
from typing import Any, Optional

import numpy as np
import pandas as pd
Expand Down
17 changes: 9 additions & 8 deletions python/nutpie/transform_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ def make_transform_adapter(
untransformed_dim=None,
zero_init=True,
):
import jax
import traceback
from functools import partial

import equinox as eqx
import jax.numpy as jnp
import flowjax
import flowjax.train
import flowjax.flows
import flowjax.train
import jax
import jax.numpy as jnp
import numpy as np
import optax
import traceback
from paramax import Parameterize, unwrap
from functools import partial

import numpy as np

class FisherLoss:
@eqx.filter_jit
Expand Down Expand Up @@ -363,7 +363,7 @@ def update(self, seed, positions, gradients):
else:
base = self._bijection

# make_flow might still only return a single trafo if the for 1d problems
# make_flow might still only return a single trafo for 1d problems
if len(base.bijections) == 1:
self._bijection = base
return
Expand Down Expand Up @@ -436,6 +436,7 @@ def update(self, seed, positions, gradients):
except Exception as e:
print("update error:", e)
print(traceback.format_exc())
raise

def init_from_transformed_position(self, transformed_position):
try:
Expand Down

0 comments on commit 6ec04e5

Please sign in to comment.