Skip to content

Commit

Permalink
Add EnsemblePlot support for rate plotting drawing
Browse files Browse the repository at this point in the history
 - exploits "steps-pre" drawing style
 - tests that draw_style is set correctly
  • Loading branch information
xjules committed Sep 13, 2024
1 parent 45c970a commit d8b17e4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/ert/gui/plottery/plots/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np
import pandas as pd

from ert.gui.plottery.plots.history import plotHistory
from ert.gui.tools.plot.plot_api import EnsembleObject
from ert.shared.storage.summary_key_utils import is_rate

from .observations import plotObservations
from .plot_tools import PlotTools
Expand Down Expand Up @@ -36,6 +37,7 @@ def plot(

plot_context.y_axis = plot_context.VALUE_AXIS
plot_context.x_axis = plot_context.DATE_AXIS
draw_style = "steps-pre" if is_rate(plot_context.key()) else None

for ensemble, data in ensemble_to_data_map.items():
data = data.T
Expand All @@ -50,6 +52,7 @@ def plot(
config,
data,
f"{ensemble.experiment_name} : {ensemble.name}",
draw_style,
)
config.nextColor()

Expand All @@ -71,6 +74,7 @@ def _plotLines(
plot_config: PlotConfig,
data: pd.DataFrame,
ensemble_label: str,
draw_style: Optional[str] = None,
) -> None:
style = plot_config.defaultStyle()

Expand All @@ -86,6 +90,7 @@ def _plotLines(
linewidth=style.width,
linestyle=style.line_style,
markersize=style.size,
drawstyle=draw_style,
)

if len(lines) > 0:
Expand Down
48 changes: 48 additions & 0 deletions tests/unit_tests/gui/plottery/test_ensemble_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from unittest.mock import Mock, patch

import pandas as pd
import pytest
from matplotlib.figure import Figure

from ert.gui.plottery import PlotConfig, PlotContext
from ert.gui.plottery.plots.ensemble import EnsemblePlot
from ert.gui.tools.plot.plot_api import EnsembleObject
from ert.shared.storage.summary_key_utils import is_rate


@pytest.fixture(
params=[
pytest.param("WOPR:OP_4"),
pytest.param("BPR:123"),
]
)
def plot_context(request):
context = Mock(spec=PlotContext)
context.ensembles.return_value = [
EnsembleObject("ensemble_1", "id", False, "experiment_1")
]
context.key.return_value = request.param
context.history_data = None
context.plotConfig.return_value = PlotConfig(title="Ensemble Plot")
return context


def test_ensemble_plot_handles_rate(plot_context: PlotContext):
figure = Figure()
with patch(
"ert.gui.plottery.plots.ensemble.EnsemblePlot._plotLines"
) as mock_plotLines:
EnsemblePlot().plot(
figure,
plot_context,
dict.fromkeys(
plot_context.ensembles(),
pd.DataFrame([[0.1], [0.2], [0.3], [0.4], [0.5]]),
),
pd.DataFrame(),
{},
)
if is_rate(plot_context.key()):
assert mock_plotLines.call_args[0][4] == "steps-pre"
else:
assert mock_plotLines.call_args[0][4] is None

0 comments on commit d8b17e4

Please sign in to comment.