From 751d8d1a90a2e8ab29bd42647ce73f9e304a1669 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 6 Oct 2024 14:01:21 +0200 Subject: [PATCH] [vis] Fix Matplotlib color handling --- phi/vis/_matplotlib/_matplotlib_plots.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/phi/vis/_matplotlib/_matplotlib_plots.py b/phi/vis/_matplotlib/_matplotlib_plots.py index d3896c8b2..e7e6ecbd9 100644 --- a/phi/vis/_matplotlib/_matplotlib_plots.py +++ b/phi/vis/_matplotlib/_matplotlib_plots.py @@ -22,7 +22,7 @@ from phi.geom._transform import _EmbeddedGeometry from phi.math import Tensor, channel, spatial, instance, non_channel, Shape, reshaped_numpy, shape from phi.vis._vis_base import display_name, PlottingLibrary, Recipe, index_label, only_stored_elements, to_field - +from phiml.math import wrap colormaps = matplotlib.colormaps if hasattr(matplotlib.colormaps, 'get_cmap') else matplotlib.cm @@ -400,7 +400,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val xyz = StaggeredGrid(lambda x: x, math.extrapolation.BOUNDARY, data.geometry.bounds, data.resolution).staggered_tensor().numpy(dims + ('vector',))[:-1, :-1, :-1, :] xyz = xyz.reshape(-1, 3) values = data.values.numpy(dims).flatten() - if color == 'cmap': + if wrap(color == 'cmap').all: color = 0 col = matplotlib.colors.to_rgba(_plt_col(color)) colors = np.zeros_like(values)[..., None] + col @@ -451,7 +451,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val x, y = reshaped_numpy(c_data.center[dims], [vector, c_data.shape.without('vector')]) u, v = reshaped_numpy(c_data.values.vector[dims], [vector, c_data.shape.without('vector')]) color_i = color[idx] - if color[idx] == 'cmap': + if (color[idx] == 'cmap').all: col = _next_line_color(subplot, kind='collections') # ToDo elif color[idx].shape: col = [_plt_col(c) for c in color_i.numpy(c_data.shape.non_channel).reshape(-1)] @@ -494,7 +494,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val x = x[:, 0] y = y[0, :] u, v = reshaped_numpy(data.values.vector[vector.item_names[0]], [vector, *data.shape.without('vector')]) - if color == 'cmap': + if wrap(color == 'cmap').all: col = reshaped_numpy(math.vec_length(data.values), [*data.shape.without('vector')]).T elif color.shape: col = [_plt_col(c) for c in color.numpy(data.shape.non_channel).reshape(-1)] @@ -619,7 +619,7 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten data = Field(sdf_grid, math.NAN, 0) data = only_stored_elements(data) x, y = reshaped_numpy(data.points.vector[dims], ['vector', non_channel(data)]) - if color == 'cmap': + if wrap(color == 'cmap').all: values = reshaped_numpy(data.values, [non_channel(data)]) mpl_colors = add_color_bar(axis, values, min_val, max_val) single_color = False @@ -690,7 +690,7 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten p1, p2 = edges.index x1, y1 = reshaped_numpy(data.graph.center[p1], ['vector', instance]) x2, y2 = reshaped_numpy(data.graph.center[p2], ['vector', instance]) - if color == 'cmap': + if wrap(color == 'cmap').all: edge_val = reshaped_numpy(edge_val, [instance]) edge_colors = add_color_bar(axis, edge_val, min_val, max_val) if edge_val.min() == edge_val.max(): @@ -877,7 +877,7 @@ def _plt_col(col): def matplotlib_colors(color: Tensor, dims: Shape, default=None) -> Union[list, None]: - if color.rank == 0 and color == 'cmap': + if color.rank == 0 and wrap(color == 'cmap').all: if default is None: return None else: