From ac0695fa4663752268e0dec615ce113fd1c660a1 Mon Sep 17 00:00:00 2001 From: mferrera Date: Wed, 22 Nov 2023 07:31:23 +0100 Subject: [PATCH] CLN: Dynamically import matplotlib and pyplot Normally dynamically importing modules is fairly straight forward: you can just import them in the function or method they are used in. However, by default matplotlib will look for display settings on the machine it's being run on and compute cluster nodes do not have this set. The suggested solution to this is to import matplotlib before anything else is imported, and set its backend to `use("Agg")` which doesn't look for display information. xtgeo implemented this by importing matplotlib in the root __init__, and it generally complicates the dynamic loading situation. This solution tries to ensure that using the Agg backend will still be triggered if it xtgeo believes it is in batch mode and wraps a getting around `sys.modules` after importing it. It also tries not to repeat the import logic if it is already imported. --- src/xtgeo/__init__.py | 26 +++------- src/xtgeo/plot/__init__.py | 3 -- src/xtgeo/plot/baseplot.py | 53 +++++++++++++++++---- src/xtgeo/plot/grid3d_slice.py | 58 +++-------------------- src/xtgeo/plot/xsection.py | 22 ++++++--- src/xtgeo/plot/xtmap.py | 20 +++++--- src/xtgeo/surface/_regsurf_oper.py | 5 +- src/xtgeo/xyz/_xyz_oper.py | 5 +- tests/test_plot/test_matplotlib_import.py | 58 +++++++++++++++++++++++ 9 files changed, 149 insertions(+), 101 deletions(-) create mode 100644 tests/test_plot/test_matplotlib_import.py diff --git a/src/xtgeo/__init__.py b/src/xtgeo/__init__.py index 8b37d7bb9..748e79147 100644 --- a/src/xtgeo/__init__.py +++ b/src/xtgeo/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # flake8: noqa # pylint: skip-file # type: ignore @@ -50,34 +49,23 @@ def _xprint(msg): except Exception: ROXAR = False - -# to avoid problems in batch runs when no DISPLAY is set: -_xprint("Import matplotlib etc...") if not ROXAR: - import matplotlib as mplib - - display = os.environ.get("DISPLAY", "") - host1 = os.environ.get("HOSTNAME", "") - host2 = os.environ.get("HOST", "") - dhost = host1 + host2 + display + _display = os.environ.get("DISPLAY", "") + _hostname = os.environ.get("HOSTNAME", "") + _host = os.environ.get("HOST", "") - ertbool = "LSB_JOBID" in os.environ + _dhost = _hostname + _host + _display + _lsf_job = "LSB_JOBID" in os.environ - if display == "" or "grid" in dhost or "lgc" in dhost or ertbool: + if _display == "" or "grid" in _dhost or "lgc" in _dhost or _lsf_job: _xprint("") _xprint("=" * 79) - _xprint( "XTGeo info: No display found or a batch (e.g. ERT) server. " "Using non-interactive Agg backend for matplotlib" ) - mplib.use("Agg") _xprint("=" * 79) - -# -# Order matters! -# -_xprint("Import matplotlib etc...DONE") + os.environ["MPLBACKEND"] = "Agg" from xtgeo._cxtgeo import XTGeoCLibError from xtgeo.common import XTGeoDialog diff --git a/src/xtgeo/plot/__init__.py b/src/xtgeo/plot/__init__.py index b967ebbef..f3b9f0557 100644 --- a/src/xtgeo/plot/__init__.py +++ b/src/xtgeo/plot/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The XTGeo plot package""" @@ -7,5 +6,3 @@ # flake8: noqa from xtgeo.plot.xsection import XSection from xtgeo.plot.xtmap import Map - -# from ._colortables import random, random40, xtgeocolors, colorsfromfile diff --git a/src/xtgeo/plot/baseplot.py b/src/xtgeo/plot/baseplot.py index a95886f28..704b7becb 100644 --- a/src/xtgeo/plot/baseplot.py +++ b/src/xtgeo/plot/baseplot.py @@ -1,7 +1,4 @@ """The baseplot module.""" -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap from packaging.version import parse as versionparse from xtgeo.common import XTGeoDialog, null_logger @@ -14,7 +11,11 @@ def _get_colormap(name): """For matplotlib compatibility.""" + import matplotlib as mpl + if versionparse(mpl.__version__) < versionparse("3.6"): + import matplotlib.plt as plt + return plt.cm.get_cmap(name) else: return mpl.colormaps[name] @@ -56,7 +57,9 @@ def colormap(self): @colormap.setter def colormap(self, cmap): - if isinstance(cmap, LinearSegmentedColormap): + import matplotlib as mpl + + if isinstance(cmap, mpl.colors.LinearSegmentedColormap): self._colormap = cmap elif isinstance(cmap, str): logger.info("Definition of a colormap from string name: %s", cmap) @@ -85,6 +88,9 @@ def define_any_colormap(cfile, colorlist=None): from 0 index. Default is just keep the linear sequence as is. """ + import matplotlib as mpl + import matplotlib.pyplot as plt + valid_maps = sorted(m for m in plt.cm.datad) logger.info("Valid color maps: %s", valid_maps) @@ -99,21 +105,37 @@ def define_any_colormap(cfile, colorlist=None): elif cfile == "xtgeo": colors = _ctable.xtgeocolors() - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "xtgeo" elif cfile == "random40": colors = _ctable.random40() - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "random40" elif cfile == "randomc": colors = _ctable.randomc(256) - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "randomc" elif isinstance(cfile, str) and "rms" in cfile: colors = _ctable.colorsfromfile(cfile) - cmap = LinearSegmentedColormap.from_list("rms", colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + "rms", + colors, + N=len(colors), + ) cmap.name = cfile elif cfile in valid_maps: cmap = _get_colormap(cfile) @@ -138,7 +160,11 @@ def define_any_colormap(cfile, colorlist=None): logger.warning("Color list out of range") ctable.append(colors[0]) - cmap = LinearSegmentedColormap.from_list(ctable, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + ctable, + colors, + N=len(colors), + ) cmap.name = "user" return cmap @@ -182,7 +208,8 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ - # self._fig, (ax1, ax2) = plt.subplots(2, figsize=(11.69, 8.27)) + import matplotlib.pyplot as plt + self._fig, self._ax = plt.subplots( figsize=(11.69 * figscaling, 8.27 * figscaling) ) @@ -204,6 +231,8 @@ def show(self): self._fig.tight_layout() if self._showok: + import matplotlib.pyplot as plt + logger.info("Calling plt show method...") plt.show() return True @@ -218,6 +247,8 @@ def close(self): After close is called, no more operations can be performed on the plot. """ + import matplotlib.pyplot as plt + for fig in self._allfigs: plt.close(fig) @@ -247,6 +278,8 @@ def savefig(self, filename, fformat="png", last=True, **kwargs): self._fig.tight_layout() if self._showok: + import matplotlib.pyplot as plt + plt.savefig(filename, format=fformat, **kwargs) if last: self.close() diff --git a/src/xtgeo/plot/grid3d_slice.py b/src/xtgeo/plot/grid3d_slice.py index d52a998a7..a1d809738 100644 --- a/src/xtgeo/plot/grid3d_slice.py +++ b/src/xtgeo/plot/grid3d_slice.py @@ -1,10 +1,5 @@ """Module for 3D Grid slice plots, using matplotlib.""" - -import matplotlib.pyplot as plt -from matplotlib.collections import PatchCollection -from matplotlib.patches import Polygon - from xtgeo.common import null_logger from xtgeo.plot.baseplot import BasePlot @@ -105,51 +100,6 @@ def plot_gridslice( else: self._plot_layer() - # def _plot_row(self): - - # geomlist = self._geomlist - - # if self._window is None: - # xmin = geomlist[3] - 0.05 * (abs(geomlist[4] - geomlist[3])) - # xmax = geomlist[4] + 0.05 * (abs(geomlist[4] - geomlist[3])) - # zmin = geomlist[7] - 0.05 * (abs(geomlist[8] - geomlist[7])) - # zmax = geomlist[8] + 0.05 * (abs(geomlist[8] - geomlist[7])) - # else: - # xmin, xmax, zmin, zmax = self._window - - # # now some numpy operations, numbering is intended - # clist = self._clist - # xz0 = np.column_stack((clist[0].values1d, clist[2].values1d)) - # xz1 = np.column_stack((clist[3].values1d, clist[5].values1d)) - # xz2 = np.column_stack((clist[15].values1d, clist[17].values1d)) - # xz3 = np.column_stack((clist[12].values1d, clist[14].values1d)) - - # xyc = np.column_stack((xz0, xz1, xz2, xz3)) - # xyc = xyc.reshape(self._grid.nlay, self._grid.ncol * self._grid.nrow, 4, 2) - - # patches = [] - - # for pos in range(self._grid.nrow * self._grid.nlay): - # nppol = xyc[self._index - 1, pos, :, :] - # if nppol.mean() > 0.0: - # polygon = Polygon(nppol, True) - # patches.append(polygon) - - # black = (0, 0, 0, 1) - # patchcoll = PatchCollection(patches, edgecolors=(black,), cmap=self.colormap) - - # # patchcoll.set_array(np.array(pvalues)) - - # # patchcoll.set_clim([minvalue, maxvalue]) - - # im = self._ax.add_collection(patchcoll) - # self._ax.set_xlim((xmin, xmax)) - # self._ax.set_ylim((zmin, zmax)) - # self._ax.invert_yaxis() - # self._fig.colorbar(im) - - # # plt.gca().set_aspect("equal", adjustable="box") - def _plot_layer(self): xyc, ibn = self._grid.get_layer_slice(self._index, activeonly=self._active) @@ -171,13 +121,15 @@ def _plot_layer(self): patches = [] + import matplotlib as mpl + for pos in range(len(ibn)): nppol = xyc[pos, :, :] if nppol.mean() > 0.0: - polygon = Polygon(nppol) + polygon = mpl.patches.Polygon(nppol) patches.append(polygon) - patchcoll = PatchCollection( + patchcoll = mpl.collections.PatchCollection( patches, edgecolors=(self._linecolor,), cmap=self.colormap ) @@ -203,4 +155,6 @@ def _plot_layer(self): self._ax.set_ylim((ymin, ymax)) self._fig.colorbar(im) + import matplotlib.pyplot as plt + plt.gca().set_aspect("equal", adjustable="box") diff --git a/src/xtgeo/plot/xsection.py b/src/xtgeo/plot/xsection.py index d439aa1ec..8946cc900 100644 --- a/src/xtgeo/plot/xsection.py +++ b/src/xtgeo/plot/xsection.py @@ -5,12 +5,9 @@ import warnings from collections import OrderedDict -import matplotlib.pyplot as plt import numpy as np import numpy.ma as ma import pandas as pd -from matplotlib import collections as mc -from matplotlib.lines import Line2D from scipy.ndimage import gaussian_filter from xtgeo.common import XTGeoDialog, null_logger @@ -265,6 +262,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ # overriding the base class canvas + import matplotlib.pyplot as plt plt.rcParams["axes.xmargin"] = 0 # fill the plot margins @@ -445,6 +443,8 @@ def set_xaxis_md(self, gridlines=False): md_start_round = int(math.floor(md_start / 100.0)) * 100 md_start_delta = md_start - md_start_round + import matplotlib.pyplot as plt + auto_ticks = plt.xticks() auto_ticks_delta = auto_ticks[0][1] - auto_ticks[0][0] @@ -566,7 +566,9 @@ def _plot_well_zlog(self, df, ax, bba, zonelogname, logwidth=4, legend=False): df, idx_zshift, ctable, zonelogname, fillnavalue ) - lc = mc.LineCollection( + import matplotlib as mpl + + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=202 ) @@ -610,7 +612,9 @@ def _plot_well_faclog(self, df, ax, bba, facieslogname, logwidth=9, legend=True) df, idx, ctable, facieslogname, fillnavalue ) - lc = mc.LineCollection( + import matplotlib as mpl + + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=201 ) @@ -656,7 +660,9 @@ def _plot_well_perflog(self, df, ax, bba, perflogname, logwidth=12, legend=True) df, idx, ctable, perflogname, fillnavalue ) - lc = mc.LineCollection( + import matplotlib as mpl + + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=200 ) @@ -769,9 +775,11 @@ def _drawproxylegend(self, ax, bba, items, title=None): proxies = [] labels = [] + import matplotlib as mpl + for item in items: color = items[item] - proxies.append(Line2D([0, 1], [0, 1], color=color, linewidth=5)) + proxies.append(mpl.lines.Line2D([0, 1], [0, 1], color=color, linewidth=5)) labels.append(item) ax.legend( diff --git a/src/xtgeo/plot/xtmap.py b/src/xtgeo/plot/xtmap.py index 9416b2280..cdf4cfb57 100644 --- a/src/xtgeo/plot/xtmap.py +++ b/src/xtgeo/plot/xtmap.py @@ -1,11 +1,7 @@ """Module for map plots of surfaces, using matplotlib.""" - -import matplotlib.patches as mplp -import matplotlib.pyplot as plt import numpy as np import numpy.ma as ma -from matplotlib import ticker from xtgeo.common import null_logger @@ -121,6 +117,8 @@ def plot_surface( levels = np.linspace(minvalue, maxvalue, self.contourlevels) logger.debug("Number of contour levels: %s", levels) + import matplotlib.pyplot as plt + plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=xlabelrotation) # zi = ma.masked_where(zimask, zi) @@ -143,7 +141,9 @@ def plot_surface( else: logger.info("use LogLocator") - locator = ticker.LogLocator() + import matplotlib as mpl + + locator = mpl.ticker.LogLocator() ticks = None uselevels = None im = self._ax.contourf(xi, yi, zi, locator=locator, cmap=self.colormap) @@ -176,6 +176,8 @@ def plot_faults( .. _Matplotlib: http://matplotlib.org/api/colors_api.html """ + import matplotlib as mpl + aff = fpoly.dataframe.groupby(idname) for name, _group in aff: @@ -185,7 +187,13 @@ def plot_faults( # make a list [(X,Y) ...]; af = list(zip(myfault["X_UTME"].values, myfault["Y_UTMN"].values)) - px = mplp.Polygon(af, alpha=alpha, color=color, ec=edgecolor, lw=linewidth) + px = mpl.patches.Polygon( + af, + alpha=alpha, + color=color, + ec=edgecolor, + lw=linewidth, + ) if px.get_closed(): self._ax.add_artist(px) diff --git a/src/xtgeo/surface/_regsurf_oper.py b/src/xtgeo/surface/_regsurf_oper.py index deeb22564..c17d5fe85 100644 --- a/src/xtgeo/surface/_regsurf_oper.py +++ b/src/xtgeo/surface/_regsurf_oper.py @@ -7,7 +7,6 @@ import numpy as np import numpy.ma as ma -from matplotlib.path import Path as MPath import xtgeo from xtgeo import XTGeoCLibError, _cxtgeo @@ -565,11 +564,13 @@ def _proxy_map_polygons(surf, poly, inside=True): xvals, yvals = proxy.get_xy_values(asmasked=False) points = np.array([xvals.ravel(), yvals.ravel()]).T + import matplotlib as mpl + for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: singlepoly = np.array([grp[pol.xname].values, grp[pol.yname].values]).T - poly_path = MPath(singlepoly) + poly_path = mpl.path.Path(singlepoly) is_inside = poly_path.contains_points(points) is_inside = is_inside.reshape(proxy.ncol, proxy.nrow) proxy.values = np.where(is_inside, inside_value, proxy.values) diff --git a/src/xtgeo/xyz/_xyz_oper.py b/src/xtgeo/xyz/_xyz_oper.py index 1a68d020a..6cf030a2b 100644 --- a/src/xtgeo/xyz/_xyz_oper.py +++ b/src/xtgeo/xyz/_xyz_oper.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import shapely.geometry as sg -from matplotlib.path import Path as MPath from scipy.interpolate import UnivariateSpline, interp1d import xtgeo @@ -40,11 +39,13 @@ def mark_in_polygons_mpl(self, poly, name, inside_value, outside_value): self.dataframe[name] = outside_value + import matplotlib as mpl + for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: singlepoly = np.array([grp[pol.xname].values, grp[pol.yname].values]).T - poly_path = MPath(singlepoly) + poly_path = mpl.path.Path(singlepoly) is_inside = poly_path.contains_points(points) self.dataframe.loc[is_inside, name] = inside_value diff --git a/tests/test_plot/test_matplotlib_import.py b/tests/test_plot/test_matplotlib_import.py new file mode 100644 index 000000000..f592cfb8f --- /dev/null +++ b/tests/test_plot/test_matplotlib_import.py @@ -0,0 +1,58 @@ +import os +import sys +from unittest import mock + + +def _clear_state(sys, os): + delete = [] + for module, _ in sys.modules.items(): + if module.startswith(("xtgeo", "matplotlib")): + delete.append(module) + + for module in delete: + del sys.modules[module] + + if "MPLBACKEND" in os.environ: + del os.environ["MPLBACKEND"] + + +@mock.patch.dict(sys.modules) +@mock.patch.dict(os.environ) +def test_that_mpl_dynamically_imports(): + _clear_state(sys, os) + import xtgeo # noqa + + assert "matplotlib" not in sys.modules + assert "matplotlib.pyplot" not in sys.modules + + from xtgeo.plot.baseplot import BasePlot + + assert "matplotlib" not in sys.modules + assert "matplotlib.pyplot" not in sys.modules + + baseplot = BasePlot() + + assert "matplotlib" in sys.modules + assert "matplotlib.pyplot" not in sys.modules + + baseplot.close() + + assert "matplotlib.pyplot" in sys.modules + + +@mock.patch.dict(sys.modules) +@mock.patch.dict(os.environ, {"LSB_JOBID": "1"}) +def test_that_agg_backend_set_when_lsf_job(): + _clear_state(sys, os) + import xtgeo # noqa + + assert os.environ.get("MPLBACKEND", "") == "Agg" + + +@mock.patch.dict(sys.modules) +@mock.patch.dict(os.environ, {"DISPLAY": "X"}) +def test_that_agg_backend_set_when_display_set(): + _clear_state(sys, os) + import xtgeo # noqa + + assert os.environ.get("MPLBACKEND", "") == ""