From c53210d3857924e266cc1a1479378e1a0aab3828 Mon Sep 17 00:00:00 2001 From: OuyangWenyu Date: Mon, 20 Nov 2023 20:47:01 +0800 Subject: [PATCH] save attr str var to numeric (#4) --- hydrodataset/camels.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/hydrodataset/camels.py b/hydrodataset/camels.py index 706bc80..88d680f 100644 --- a/hydrodataset/camels.py +++ b/hydrodataset/camels.py @@ -1,10 +1,10 @@ """ Author: Wenyu Ouyang Date: 2022-01-05 18:01:11 -LastEditTime: 2023-10-18 17:54:32 +LastEditTime: 2023-11-20 20:08:14 LastEditors: Wenyu Ouyang Description: Read Camels Series ("AUStralia", "BRazil", "ChiLe", "GreatBritain", "UnitedStates") datasets -FilePath: \hydrodataset\hydrodataset\camels.py +FilePath: /hydrodataset/hydrodataset/camels.py Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved. """ import json @@ -1523,6 +1523,7 @@ def cache_attributes_xrdataset(self): attrs_df = pd.concat(attrs.values(), axis=1) + # fix station names def fix_station_nm(station_nm): name = station_nm.title().rsplit(" ", 1) name[0] = name[0] if name[0][-1] == "," else f"{name[0]}," @@ -1535,6 +1536,17 @@ def fix_station_nm(station_nm): obj_cols = attrs_df.columns[attrs_df.dtypes == "object"] for c in obj_cols: attrs_df[c] = attrs_df[c].str.strip().astype(str) + + # transform categorical variables to numeric + categorical_mappings = {} + for column in attrs_df.columns: + if attrs_df[column].dtype == "object": + attrs_df[column] = attrs_df[column].astype("category") + categorical_mappings[column] = dict( + enumerate(attrs_df[column].cat.categories) + ) + attrs_df[column] = attrs_df[column].cat.codes + # unify id to basin attrs_df.index.name = "basin" # We use xarray dataset to cache all data @@ -1605,6 +1617,12 @@ def fix_station_nm(station_nm): for var_name in units_dict: if var_name in ds_from_df.data_vars: ds_from_df[var_name].attrs["units"] = units_dict[var_name] + + # Assign categorical mappings to the variables in the Dataset + for column in ds_from_df.data_vars: + if column in categorical_mappings: + mapping_str = categorical_mappings[column] + ds_from_df[column].attrs["category_mapping"] = str(mapping_str) return ds_from_df def cache_streamflow_xrdataset(self):