Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualize: fix flatten of observable mapping with one observable #1515

Merged
merged 4 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions pypesto/visualize/observable_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def visualize_estimated_observable_mapping(
n_axes = n_relative_observables + n_semiquant_observables
n_rows = int(np.ceil(np.sqrt(n_axes)))
n_cols = int(np.ceil(n_axes / n_rows))
_, axes = plt.subplots(n_rows, n_cols, **kwargs)
_, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)
axes = axes.flatten()

# Plot the estimated observable mapping for relative observables.
Expand Down Expand Up @@ -246,8 +246,7 @@ def plot_linear_observable_mappings_from_pypesto_result(
n_cols = int(np.ceil(n_relative_observables / n_rows))

# Make as many subplots as there are relative observables
_, axes = plt.subplots(n_rows, n_cols, **kwargs)

_, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)
# Flatten the axes array
axes = axes.flatten()

Expand Down Expand Up @@ -590,8 +589,7 @@ def plot_splines_from_inner_result(
n_cols = int(np.ceil(n_groups / n_rows))

# Make as many subplots as there are groups
_, axes = plt.subplots(n_rows, n_cols, **kwargs)

_, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)
# Flatten the axes array
axes = axes.flatten()

Expand Down
26 changes: 9 additions & 17 deletions pypesto/visualize/ordinal_categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,26 +612,18 @@ def _get_data_for_plotting(

def _get_default_axes(n_groups, **kwargs):
"""Return a list of axes with the default layout."""
# If there is only one group, make a figure with only one plot
if n_groups == 1:
# Make figure with only one plot
fig, ax = plt.subplots(1, 1, **kwargs)
# Choose number of rows and columns to be used for the subplots
n_rows = int(np.ceil(np.sqrt(n_groups)))
n_cols = int(np.ceil(n_groups / n_rows))

axes = [ax]
# If there are multiple groups, make a figure with multiple plots
else:
# Choose number of rows and columns to be used for the subplots
n_rows = int(np.ceil(np.sqrt(n_groups)))
n_cols = int(np.ceil(n_groups / n_rows))

# Make as many subplots as there are groups
fig, axes = plt.subplots(n_rows, n_cols, **kwargs)
# Make as many subplots as there are groups
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)

# Increase the spacing between the subplots
fig.subplots_adjust(hspace=0.35, wspace=0.25)
# Increase the spacing between the subplots
fig.subplots_adjust(hspace=0.35, wspace=0.25)

# Flatten the axes array
axes = axes.flatten()
# Flatten the axes array
axes = axes.flatten()
return axes


Expand Down
52 changes: 52 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from collections.abc import Sequence
from functools import wraps
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1240,3 +1241,54 @@ def test_parameters_correlation_matrix(result_creation):
result = result_creation()

visualize.parameters_correlation_matrix(result)


@close_fig
def test_plot_ordinal_categories():
example_ordinal_yaml = (
Path(__file__).parent
/ ".."
/ ".."
/ "doc"
/ "example"
/ "example_ordinal"
/ "example_ordinal.yaml"
)
petab_problem = petab.Problem.from_yaml(example_ordinal_yaml)
# Set seed for reproducibility.
np.random.seed(0)
optimizer = pypesto.optimize.ScipyOptimizer(
method="L-BFGS-B", options={"maxiter": 10}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For just testing that visualization does not produce an error, a single iteration would suffice, right?

)
importer = pypesto.petab.PetabImporter(petab_problem, hierarchical=True)
problem = importer.create_problem()
result = pypesto.optimize.minimize(
problem=problem, n_starts=1, optimizer=optimizer
)
visualize.plot_categories_from_pypesto_result(result)


@close_fig
def test_visualize_estimated_observable_mapping():
example_semiquantitative_yaml = (
Path(__file__).parent
/ ".."
/ ".."
/ "doc"
/ "example"
/ "example_semiquantitative"
/ "example_semiquantitative_linear.yaml"
)
petab_problem = petab.Problem.from_yaml(example_semiquantitative_yaml)
# Set seed for reproducibility.
np.random.seed(0)
optimizer = pypesto.optimize.ScipyOptimizer(
method="L-BFGS-B",
options={"disp": None, "ftol": 2.220446049250313e-09, "gtol": 1e-5},
)
importer = pypesto.petab.PetabImporter(petab_problem, hierarchical=True)
problem = importer.create_problem()
result = pypesto.optimize.minimize(
problem=problem, n_starts=1, optimizer=optimizer
)
visualize.visualize_estimated_observable_mapping(result, problem)