Skip to content

Commit

Permalink
[vis] Fix Matplotlib color handling
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 6, 2024
1 parent 4a71915 commit 751d8d1
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from phi.geom._transform import _EmbeddedGeometry
from phi.math import Tensor, channel, spatial, instance, non_channel, Shape, reshaped_numpy, shape
from phi.vis._vis_base import display_name, PlottingLibrary, Recipe, index_label, only_stored_elements, to_field

from phiml.math import wrap

colormaps = matplotlib.colormaps if hasattr(matplotlib.colormaps, 'get_cmap') else matplotlib.cm

Expand Down Expand Up @@ -400,7 +400,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val
xyz = StaggeredGrid(lambda x: x, math.extrapolation.BOUNDARY, data.geometry.bounds, data.resolution).staggered_tensor().numpy(dims + ('vector',))[:-1, :-1, :-1, :]
xyz = xyz.reshape(-1, 3)
values = data.values.numpy(dims).flatten()
if color == 'cmap':
if wrap(color == 'cmap').all:
color = 0
col = matplotlib.colors.to_rgba(_plt_col(color))
colors = np.zeros_like(values)[..., None] + col
Expand Down Expand Up @@ -451,7 +451,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val
x, y = reshaped_numpy(c_data.center[dims], [vector, c_data.shape.without('vector')])
u, v = reshaped_numpy(c_data.values.vector[dims], [vector, c_data.shape.without('vector')])
color_i = color[idx]
if color[idx] == 'cmap':
if (color[idx] == 'cmap').all:
col = _next_line_color(subplot, kind='collections') # ToDo
elif color[idx].shape:
col = [_plt_col(c) for c in color_i.numpy(c_data.shape.non_channel).reshape(-1)]
Expand Down Expand Up @@ -494,7 +494,7 @@ def plot(self, data: Field, figure, subplot, space: Box, min_val: float, max_val
x = x[:, 0]
y = y[0, :]
u, v = reshaped_numpy(data.values.vector[vector.item_names[0]], [vector, *data.shape.without('vector')])
if color == 'cmap':
if wrap(color == 'cmap').all:
col = reshaped_numpy(math.vec_length(data.values), [*data.shape.without('vector')]).T
elif color.shape:
col = [_plt_col(c) for c in color.numpy(data.shape.non_channel).reshape(-1)]
Expand Down Expand Up @@ -619,7 +619,7 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten
data = Field(sdf_grid, math.NAN, 0)
data = only_stored_elements(data)
x, y = reshaped_numpy(data.points.vector[dims], ['vector', non_channel(data)])
if color == 'cmap':
if wrap(color == 'cmap').all:
values = reshaped_numpy(data.values, [non_channel(data)])
mpl_colors = add_color_bar(axis, values, min_val, max_val)
single_color = False
Expand Down Expand Up @@ -690,7 +690,7 @@ def _plot_points(axis: Axes, data: Field, dims: tuple, vector: Shape, color: Ten
p1, p2 = edges.index
x1, y1 = reshaped_numpy(data.graph.center[p1], ['vector', instance])
x2, y2 = reshaped_numpy(data.graph.center[p2], ['vector', instance])
if color == 'cmap':
if wrap(color == 'cmap').all:
edge_val = reshaped_numpy(edge_val, [instance])
edge_colors = add_color_bar(axis, edge_val, min_val, max_val)
if edge_val.min() == edge_val.max():
Expand Down Expand Up @@ -877,7 +877,7 @@ def _plt_col(col):


def matplotlib_colors(color: Tensor, dims: Shape, default=None) -> Union[list, None]:
if color.rank == 0 and color == 'cmap':
if color.rank == 0 and wrap(color == 'cmap').all:
if default is None:
return None
else:
Expand Down

0 comments on commit 751d8d1

Please sign in to comment.