Skip to content

Commit

Permalink
Add pd.interval reformatting to nd_binning plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Nov 14, 2023
1 parent 9b3ecfa commit d53d5c4
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions xdem/spatialstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,17 @@ def plot_1d_binning(
if statistic_name not in df.columns.values:
raise ValueError(f'The statistic "{statistic_name}" is not part of the provided dataframe column names.')

# Re-format pandas interval if read from CSV as string
if all(isinstance(x, pd.Interval) for x in df[var_name].values):
pass
# Check for any unformatted interval (saving and reading a pd.DataFrame without MultiIndexing transforms
# pd.Interval into strings)
elif any(isinstance(_pandas_str_to_interval(x), pd.Interval) for x in df[var_name].values):
intervalindex_vals = [_pandas_str_to_interval(x) for x in df[var_name].values]
df[var_name] = pd.IntervalIndex(intervalindex_vals)
else:
raise ValueError("The variable columns must be provided as string or pd.Interval values.")

# Hide axes for the main subplot (which will be subdivded)
ax.axis("off")

Expand Down Expand Up @@ -3421,6 +3432,18 @@ def plot_2d_binning(
if statistic_name not in df.columns.values:
raise ValueError(f'The statistic "{statistic_name}" is not part of the provided dataframe column names.')

# Re-format pandas interval if read from CSV as string
for var in [var_name_1, var_name_2]:
if all(isinstance(x, pd.Interval) for x in df[var].values):
pass
# Check for any unformatted interval (saving and reading a pd.DataFrame without MultiIndexing transforms
# pd.Interval into strings)
elif any(isinstance(_pandas_str_to_interval(x), pd.Interval) for x in df[var].values):
intervalindex_vals = [_pandas_str_to_interval(x) for x in df[var].values]
df[var] = pd.IntervalIndex(intervalindex_vals)
else:
raise ValueError("The variable columns must be provided as string or pd.Interval values.")

# Hide axes for the main subplot (which will be subdivded)
ax.axis("off")

Expand Down

0 comments on commit d53d5c4

Please sign in to comment.