From 3fc5962452a2401f498a4ded3e5707254ac57a02 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Tue, 12 Mar 2024 18:05:59 +0100 Subject: [PATCH] remove conditional_jit from plot_forest label_idx (#2319) * remove conditional_jit from plot_forest label_idx * update changelog * fix ci --- .azure-pipelines/azure-pipelines-benchmarks.yml | 2 +- CHANGELOG.md | 4 ++-- arviz/data/inference_data.py | 3 +-- arviz/data/io_pystan.py | 3 +-- arviz/plots/backends/bokeh/forestplot.py | 2 -- arviz/plots/backends/matplotlib/forestplot.py | 2 -- 6 files changed, 5 insertions(+), 11 deletions(-) diff --git a/.azure-pipelines/azure-pipelines-benchmarks.yml b/.azure-pipelines/azure-pipelines-benchmarks.yml index 412b926f84..91b1e08b8b 100644 --- a/.azure-pipelines/azure-pipelines-benchmarks.yml +++ b/.azure-pipelines/azure-pipelines-benchmarks.yml @@ -29,7 +29,7 @@ jobs: python -m pip install wheel python -m pip install --no-cache-dir -r requirements.txt python -m pip install --no-cache-dir -r requirements-optional.txt - python -m pip install asv!=0.6.2 + python -m pip install asv==0.6.1 displayName: 'Install requirements' - script: | diff --git a/CHANGELOG.md b/CHANGELOG.md index d25dfa5ace..a5b807ff4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,8 @@ ### Maintenance and fixes - Fix deprecations introduced in latest pandas and xarray versions, and prepare for numpy 2.0 ones ([2315](https://github.com/arviz-devs/arviz/pull/2315))) - -- Refactor ECDF code ([2311](https://github.com/arviz-devs/arviz/pull/2311)) +- Refactor ECDF code ([2311](https://github.com/arviz-devs/arviz/pull/2311)) +- Fix `plot_forest` when Numba is installed ([2319](https://github.com/arviz-devs/arviz/pull/2319)) ### Deprecation diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 3740fca1e0..c921c0a9dd 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -254,8 +254,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: """Iterate over groups in InferenceData object.""" - for group in self._groups_all: - yield group + yield from self._groups_all def __contains__(self, key: object) -> bool: """Return True if the named item is present, and False otherwise.""" diff --git a/arviz/data/io_pystan.py b/arviz/data/io_pystan.py index 0b899c5e9d..480a1be638 100644 --- a/arviz/data/io_pystan.py +++ b/arviz/data/io_pystan.py @@ -676,8 +676,7 @@ def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None): for item in par_keys: _, shape = item.replace("]", "").split("[") shape_idx_min = min(int(shape_value) for shape_value in shape.split(",")) - if shape_idx_min < shift: - shift = shape_idx_min + shift = min(shift, shape_idx_min) # If shift is higher than 1, this will probably mean that Stan # has implemented sparse structure (saves only non-zero parts), # but let's hope that dims are still corresponding to the full shape diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 9406641fbe..467aa44718 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -15,7 +15,6 @@ from ....stats import hdi from ....stats.density_utils import get_bins, histogram, kde from ....stats.diagnostics import _ess, _rhat -from ....utils import conditional_jit from ...plot_utils import _scale_fig_size from .. import show_layout from . import backend_kwarg_defaults @@ -277,7 +276,6 @@ def labels_and_ticks(self): """Collect labels and ticks from plotters.""" val = self.plotters.values() - @conditional_jit(forceobj=True, nopython=False) def label_idxs(): labels, idxs = [], [] for plotter in val: diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index 67b710bf48..3993f497e6 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -11,7 +11,6 @@ from ....stats.density_utils import get_bins, histogram, kde from ....stats.diagnostics import _ess, _rhat from ....sel_utils import xarray_var_iter -from ....utils import conditional_jit from ...plot_utils import _scale_fig_size from . import backend_kwarg_defaults, backend_show @@ -236,7 +235,6 @@ def labels_and_ticks(self): """Collect labels and ticks from plotters.""" val = self.plotters.values() - @conditional_jit(forceobj=True, nopython=False) def label_idxs(): labels, idxs = [], [] for plotter in val: