Skip to content

Commit

Permalink
[Bugfix] Fix the case when the input value of wd_l2norm is using scie…
Browse files Browse the repository at this point in the history
…ntific notation. (#1054)

*Issue #, if available:*
#1051

*Description of changes:*
Fix the bug for wd_l2norm and alpha_l2norm


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Oct 2, 2024
1 parent 0b48f4e commit 2fd1511
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,7 +1723,12 @@ def wd_l2norm(self):
"""
# pylint: disable=no-member
if hasattr(self, "_wd_l2norm"):
return self._wd_l2norm
try:
wd_l2norm = float(self._wd_l2norm)
except:
raise ValueError("wd_l2norm must be a floating point " \
f"but get {self._wd_l2norm}")
return wd_l2norm
return 0

@property
Expand All @@ -1735,7 +1740,12 @@ def alpha_l2norm(self):
"""
# pylint: disable=no-member
if hasattr(self, "_alpha_l2norm"):
return self._alpha_l2norm
try:
alpha_l2norm = float(self._alpha_l2norm)
except:
raise ValueError("alpha_l2norm must be a floating point " \
f"but get {self._alpha_l2norm}")
return alpha_l2norm
return .0

@property
Expand Down
16 changes: 16 additions & 0 deletions tests/unit-tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def create_train_config(tmp_path, file_name):
yaml_object["gsf"]["hyperparam"] = {
"topk_model_to_save": 4,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": 5e-5,
"alpha_l2norm": 5e-5,
}
with open(os.path.join(tmp_path, file_name+"1.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand All @@ -261,6 +263,8 @@ def create_train_config(tmp_path, file_name):
'save_model_frequency': 2000,
"topk_model_to_save": 5,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": "1e-3",
"alpha_l2norm": "1e-3",
}
with open(os.path.join(tmp_path, file_name+"2.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand Down Expand Up @@ -291,6 +295,8 @@ def create_train_config(tmp_path, file_name):
"use_early_stop": True,
"early_stop_burnin_rounds": -1,
"early_stop_rounds": 0,
"wd_l2norm": "NA",
"alpha_l2norm": "NA",
}

with open(os.path.join(tmp_path, file_name+"_fail.yaml"), "w") as f:
Expand All @@ -301,6 +307,8 @@ def create_train_config(tmp_path, file_name):
'save_model_frequency': 2000,
"topk_model_to_save": 3,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": "",
"alpha_l2norm": "",
}
with open(os.path.join(tmp_path, file_name+"_fail1.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand Down Expand Up @@ -350,12 +358,16 @@ def test_train_info():
args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test1.yaml'), local_rank=0)
config = GSConfig(args)
assert config.topk_model_to_save == 4
assert config.wd_l2norm == 5e-5
assert config.alpha_l2norm == 5e-5

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test2.yaml'), local_rank=0)
config = GSConfig(args)
assert config.eval_frequency == 1000
assert config.save_model_frequency == 2000
assert config.topk_model_to_save == 5
assert config.wd_l2norm == 1e-3
assert config.alpha_l2norm == 1e-3

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test3.yaml'), local_rank=0)
config = GSConfig(args)
Expand All @@ -380,12 +392,16 @@ def test_train_info():
check_failure(config, "topk_model_to_save")
check_failure(config, "early_stop_burnin_rounds")
check_failure(config, "early_stop_rounds")
check_failure(config, "wd_l2norm")
check_failure(config, "alpha_l2norm")

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test_fail1.yaml'), local_rank=0)
config = GSConfig(args)
# in PR # 893 we loose the constraints of model saving frequency and eval frequency
# so here we do not check failure, but check the topk model argument
assert config.topk_model_to_save == 3
check_failure(config, "wd_l2norm")
check_failure(config, "alpha_l2norm")

def create_rgcn_config(tmp_path, file_name):
yaml_object = create_dummpy_config_obj()
Expand Down

0 comments on commit 2fd1511

Please sign in to comment.