diff --git a/lux/vislib/matplotlib/BarChart.py b/lux/vislib/matplotlib/BarChart.py index 6f706f7d..4c6a960b 100644 --- a/lux/vislib/matplotlib/BarChart.py +++ b/lux/vislib/matplotlib/BarChart.py @@ -79,9 +79,9 @@ def initialize_chart(self): ) df = self.data - - bars = df[bar_attr].apply(lambda x: str(x)) - measurements = df[measure_attr] + bar = df[bar_attr].apply(lambda x: str(x)) + bars = list(bar) + measurements = list(df[measure_attr]) plot_code = "" @@ -89,7 +89,6 @@ def initialize_chart(self): if len(color_attr) == 1: self.fig, self.ax = matplotlib_setup(6, 4) color_attr_name = color_attr[0].attribute - color_attr_type = color_attr[0].data_type colors = df[color_attr_name].values unique = list(set(colors)) d_x = {} @@ -101,22 +100,22 @@ def initialize_chart(self): d_x[colors[i]].append(bars[i]) d_y[colors[i]].append(measurements[i]) for i in range(len(unique)): - self.ax.barh(d_x[unique[i]], d_y[unique[i]], label=unique[i]) - plot_code += ( - f"ax.barh({d_x}[{unique}[{i}]], {d_y}[{unique}[{i}]], label={unique}[{i}])\n" - ) + xval = d_x[unique[i]] + yval = d_y[unique[i]] + l = unique[i] + self.ax.barh(xval, yval, label=l) + plot_code += f"ax.barh({xval},{yval}, label='{l}')\n" self.ax.legend( title=color_attr_name, bbox_to_anchor=(1.05, 1), loc="upper left", ncol=1, frameon=False ) - plot_code += f"""ax.legend( - title='{color_attr_name}', - bbox_to_anchor=(1.05, 1), - loc='upper left', - ncol=1, - frameon=False,)\n""" + plot_code += f"""ax.legend(title='{color_attr_name}', + bbox_to_anchor=(1.05, 1), + loc='upper left', + ncol=1, + frameon=False)\n""" else: - self.ax.barh(bars, measurements, align="center") - plot_code += f"ax.barh(bars, measurements, align='center')\n" + self.ax.barh(bar, df[measure_attr], align="center") + plot_code += f"ax.barh({bar}, {df[measure_attr]}, align='center')\n" y_ticks_abbev = df[bar_attr].apply(lambda x: str(x)[:10] + "..." if len(str(x)) > 10 else str(x)) self.ax.set_yticks(bars) @@ -128,7 +127,7 @@ def initialize_chart(self): self.code += "import numpy as np\n" self.code += "from math import nan\n" - + self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n" self.code += f"fig, ax = plt.subplots()\n" self.code += f"bars = df['{bar_attr}']\n" self.code += f"measurements = df['{measure_attr}']\n" diff --git a/lux/vislib/matplotlib/LineChart.py b/lux/vislib/matplotlib/LineChart.py index b2fd9f11..920aa9c4 100644 --- a/lux/vislib/matplotlib/LineChart.py +++ b/lux/vislib/matplotlib/LineChart.py @@ -110,7 +110,7 @@ def initialize_chart(self): self.code += "import numpy as np\n" self.code += "from math import nan\n" - + self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n" self.code += f"fig, ax = plt.subplots()\n" self.code += f"x_pts = df['{x_attr.attribute}']\n" self.code += f"y_pts = df['{y_attr.attribute}']\n" diff --git a/lux/vislib/matplotlib/ScatterChart.py b/lux/vislib/matplotlib/ScatterChart.py index 66dc8297..2c33c887 100644 --- a/lux/vislib/matplotlib/ScatterChart.py +++ b/lux/vislib/matplotlib/ScatterChart.py @@ -140,7 +140,7 @@ def initialize_chart(self): self.code += "import numpy as np\n" self.code += "from math import nan\n" self.code += "from matplotlib.cm import ScalarMappable\n" - + self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n" self.code += set_fig_code self.code += f"x_pts = df['{x_attr.attribute}']\n" self.code += f"y_pts = df['{y_attr.attribute}']\n" diff --git a/tests/test_vis.py b/tests/test_vis.py index 8f49d292..601f9ceb 100644 --- a/tests/test_vis.py +++ b/tests/test_vis.py @@ -218,7 +218,6 @@ def test_bar_chart(global_var): lux.config.plotting_backend = "matplotlib" vis = Vis(["Origin", "Acceleration"], df) vis_code = vis.to_matplotlib() - assert "ax.barh(bars, measurements, align='center')" in vis_code assert "ax.set_xlabel('Acceleration')" in vis_code assert "ax.set_ylabel('Origin')" in vis_code