diff --git a/lux/vislib/altair/AltairChart.py b/lux/vislib/altair/AltairChart.py index c56d7d2f..8920ab1b 100644 --- a/lux/vislib/altair/AltairChart.py +++ b/lux/vislib/altair/AltairChart.py @@ -36,6 +36,8 @@ def __init__(self, vis): self.tooltip = True # ----- START self.code modification ----- self.code = "" + self.width = 160 + self.height = 150 self.chart = self.initialize_chart() # self.add_tooltip() self.encode_color() @@ -71,7 +73,9 @@ def apply_default_config(self): labelFont="Helvetica Neue", ) plotting_scale = lux.config.plotting_scale - self.chart = self.chart.properties(width=160 * plotting_scale, height=150 * plotting_scale) + self.chart = self.chart.properties( + width=self.width * plotting_scale, height=self.height * plotting_scale + ) self.code += ( "\nchart = chart.configure_title(fontWeight=500,fontSize=13,font='Helvetica Neue')\n" ) @@ -79,9 +83,7 @@ def apply_default_config(self): self.code += "\t\t\t\t\tlabelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue',labelColor='#505050')\n" self.code += "chart = chart.configure_legend(titleFontWeight=500,titleFontSize=10,titleFont='Helvetica Neue',\n" self.code += "\t\t\t\t\tlabelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue')\n" - self.code += ( - f"chart = chart.properties(width={160 * plotting_scale},height={150 * plotting_scale})\n" - ) + self.code += f"chart = chart.properties(width={self.width * plotting_scale},height={self.height * plotting_scale})\n" def encode_color(self): color_attr = self.vis.get_attr_by_channel("color") diff --git a/lux/vislib/altair/Choropleth.py b/lux/vislib/altair/Choropleth.py index 3a0f0b30..b7768c7f 100644 --- a/lux/vislib/altair/Choropleth.py +++ b/lux/vislib/altair/Choropleth.py @@ -40,6 +40,9 @@ def __repr__(self): return f"Choropleth Map <{str(self.vis)}>" def initialize_chart(self): + # Override default width and height + self.width = 200 + x_attr = self.vis.get_attr_by_channel("x")[0] y_attr = self.vis.get_attr_by_channel("y")[0] @@ -148,6 +151,8 @@ def get_geomap(self, feature): def get_us_fips_code(self, attribute): """Returns FIPS code given a US state""" + if not isinstance(attribute, str): + return attribute usa = pd.DataFrame( [ {"fips": 1, "state": "alabama", "abbrev": "al"}, @@ -204,8 +209,6 @@ def get_us_fips_code(self, attribute): ] ) attribute = attribute.lower() - if not isinstance(attribute, str): - return attribute match = usa[(usa.state == attribute) | (usa.abbrev == attribute)] if len(match) == 1: return match["fips"].values[0] @@ -213,7 +216,7 @@ def get_us_fips_code(self, attribute): if attribute in ["washington d.c.", "washington dc", "d.c.", "d.c"]: return 11 else: - return attribute + return 0 # any unmatching value (e.g. nan) def get_country_iso_code(self, attribute): """Returns country ISO code given a country"""