Skip to content

Commit

Permalink
finish base gpr
Browse files Browse the repository at this point in the history
  • Loading branch information
pohaoc2 committed Nov 24, 2024
1 parent 521d1b3 commit e8e9276
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 13 deletions.
4 changes: 0 additions & 4 deletions src/conf/cs/models/gpr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,3 @@ gpr:
- 3
kernel:
- DotProduct + WhiteKernel
- RBF
- Matern
- RationalQuadratic
- Exponential
26 changes: 25 additions & 1 deletion src/permutation/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ def pandas_to_csv(
dir_path = f"{self.export_path}/{self.experiment}/{model}/{stage.name}"
self._save_df(dir_path, filename, dataframe)

def make_json_serializable(self, data):
"""
Recursively checks if all elements in a dictionary are JSON-serializable.
If not, converts non-serializable elements to strings.
Parameters:
- data (dict): Dictionary of parameters to check.
Returns:
- dict: JSON-serializable dictionary.
"""
if isinstance(data, dict):
return {key: self.make_json_serializable(value) for key, value in data.items()}
elif isinstance(data, list):
return [self.make_json_serializable(item) for item in data]
try:
# Attempt to serialize to JSON
json.dumps(data)
return data
except (TypeError, OverflowError):
# If not serializable, convert to string
return str(data)

def _save_df(self, dir_path: str, name: str, df: pd.DataFrame | pd.Series) -> None:
"""Validates save path and then saves a Dataframe as a CSV file"""
validate_dir(dir_path)
Expand All @@ -67,7 +90,8 @@ def save_model_json(self, runner: Runner) -> None:
)
with open(dir_path, "w", encoding="UTF-8") as outfile:
if runner.model.hparams:
json.dump(runner.model.hparams.as_dict(), outfile)
serialized_hp = self.make_json_serializable(runner.model.hparams.as_dict())
json.dump(serialized_hp, outfile)

def save_predictions(self, model: str, runner: Runner) -> None:
"""Saves the model predictions as a CSV file"""
Expand Down
4 changes: 1 addition & 3 deletions src/permutation/models/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def set_model(
kernel_name = hparams_dict.get("kernel")
kernel = map_kernel(kernel_name)
hparams_dict['kernel'] = kernel
print(hparams_dict)
print(hparams)
hparams = HParams(**hparams_dict)
hparams = HParams(param_dict=hparams_dict)
else:
hparams_dict = {'kernel': kernel} if kernel else {}

Expand Down
5 changes: 0 additions & 5 deletions src/permutation/models/sklearnmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ def _set_model(
"""
model = cls()
model.hparams = hparams
print(model.hparams)
print(model)
print(model_dependency)
#asd()
pipeline_list = []

if preprocessing_dependencies:
Expand All @@ -70,7 +66,6 @@ def _set_model(
model.model = model_dependency()

pipeline_list.append((model.algorithm_name, model.model)) # pylint: disable=no-member

model.pipeline = Pipeline(pipeline_list)
return model

Expand Down

0 comments on commit e8e9276

Please sign in to comment.