Skip to content

Commit

Permalink
save attr str var to numeric (#4)
Browse files (browse the repository at this point in the history)
  • Loading branch information
OuyangWenyu authored Nov 20, 2023
1 parent ac31b20 commit c53210d
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions hydrodataset/camels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Author: Wenyu Ouyang
Date: 2022-01-05 18:01:11
-LastEditTime: 2023-10-18 17:54:32
+LastEditTime: 2023-11-20 20:08:14
LastEditors: Wenyu Ouyang
Description: Read Camels Series ("AUStralia", "BRazil", "ChiLe", "GreatBritain", "UnitedStates") datasets
-FilePath: \hydrodataset\hydrodataset\camels.py
+FilePath: /hydrodataset/hydrodataset/camels.py
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
"""
import json
Expand Down Expand Up @@ -1523,6 +1523,7 @@ def cache_attributes_xrdataset(self):

attrs_df = pd.concat(attrs.values(), axis=1)

# fix station names
def fix_station_nm(station_nm):
name = station_nm.title().rsplit(" ", 1)
name[0] = name[0] if name[0][-1] == "," else f"{name[0]},"
Expand All @@ -1535,6 +1536,17 @@ def fix_station_nm(station_nm):
obj_cols = attrs_df.columns[attrs_df.dtypes == "object"]
for c in obj_cols:
attrs_df[c] = attrs_df[c].str.strip().astype(str)

# encode object-dtype (string) columns as integer category codes and record each column's code -> label mapping
categorical_mappings = {}
for column in attrs_df.columns:
if attrs_df[column].dtype == "object":
attrs_df[column] = attrs_df[column].astype("category")
categorical_mappings[column] = dict(
enumerate(attrs_df[column].cat.categories)
)
attrs_df[column] = attrs_df[column].cat.codes

# unify the id naming: call the DataFrame index "basin"
attrs_df.index.name = "basin"
# We use xarray dataset to cache all data
Expand Down Expand Up @@ -1605,6 +1617,12 @@ def fix_station_nm(station_nm):
for var_name in units_dict:
if var_name in ds_from_df.data_vars:
ds_from_df[var_name].attrs["units"] = units_dict[var_name]

# Assign categorical mappings to the variables in the Dataset
for column in ds_from_df.data_vars:
if column in categorical_mappings:
mapping_str = categorical_mappings[column]
ds_from_df[column].attrs["category_mapping"] = str(mapping_str)
return ds_from_df

def cache_streamflow_xrdataset(self):
Expand Down

0 comments on commit c53210d

Please sign in to comment.