diff --git a/xdem/spatialstats.py b/xdem/spatialstats.py index d6fb7f31..85257f81 100644 --- a/xdem/spatialstats.py +++ b/xdem/spatialstats.py @@ -3299,6 +3299,17 @@ def plot_1d_binning( if statistic_name not in df.columns.values: raise ValueError(f'The statistic "{statistic_name}" is not part of the provided dataframe column names.') + # Re-format pandas interval if read from CSV as string + if all(isinstance(x, pd.Interval) for x in df[var_name].values): + pass + # Check for any unformatted interval (saving and reading a pd.DataFrame without MultiIndexing transforms + # pd.Interval into strings) + elif any(isinstance(_pandas_str_to_interval(x), pd.Interval) for x in df[var_name].values): + intervalindex_vals = [_pandas_str_to_interval(x) for x in df[var_name].values] + df[var_name] = pd.IntervalIndex(intervalindex_vals) + else: + raise ValueError("The variable columns must be provided as string or pd.Interval values.") + # Hide axes for the main subplot (which will be subdivded) ax.axis("off") @@ -3421,6 +3432,18 @@ def plot_2d_binning( if statistic_name not in df.columns.values: raise ValueError(f'The statistic "{statistic_name}" is not part of the provided dataframe column names.') + # Re-format pandas interval if read from CSV as string + for var in [var_name_1, var_name_2]: + if all(isinstance(x, pd.Interval) for x in df[var].values): + pass + # Check for any unformatted interval (saving and reading a pd.DataFrame without MultiIndexing transforms + # pd.Interval into strings) + elif any(isinstance(_pandas_str_to_interval(x), pd.Interval) for x in df[var].values): + intervalindex_vals = [_pandas_str_to_interval(x) for x in df[var].values] + df[var] = pd.IntervalIndex(intervalindex_vals) + else: + raise ValueError("The variable columns must be provided as string or pd.Interval values.") + # Hide axes for the main subplot (which will be subdivded) ax.axis("off")