Skip to content

Commit

Permalink
add GA example in run.py (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
ppdebreuck authored and ml-evs committed Oct 11, 2021
1 parent acaf38e commit cc36d63
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions benchmarks/matbench_v0.1_modnet_v0.1.10/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from modnet.models import EnsembleMODNetModel
from pymatgen.core import Composition

USE_GA = False # wheter to use the GA or fit_preset (dynamic grid-search) for hyper-paremeter optimization.

mb = MatbenchBenchmark(
autoload=False,
subset=[
Expand All @@ -19,7 +21,11 @@
'matbench_log_gvrh',
'matbench_log_kvrh',
'matbench_glass',
'matbench_expt_is_metal',
'matbench_expt_is_metal',
#'matbench_perovskites', # for the bigger tasks, USE_GA=True is recommended, as training time scales better with larger training sets
#'matbench_mp_e_form',
#'matbench_mp_gap',
#'matbench_mp_is_metal',
],
)

Expand Down Expand Up @@ -65,18 +71,32 @@
model = EnsembleMODNetModel(targets_hierarchy, weights)

# fit model
(
models,
val_losses,
best_learning_curve,
learning_curves,
best_presets,
) = model.fit_preset(
train_data,
classification=classification,
nested=5,
n_jobs=16,
)

if USE_GA:
# you can either use a GA for hyper-parameter optimization or...
from modnet.hyper_opt import FitGenetic
ga = FitGenetic(train_data)
model = ga.run(
size_pop=20,
num_generations=10,
n_jobs=16,
early_stopping=True,
refit=True,
)
else:
# ... a list of presets (kind of dynamic grid search)
(
models,
val_losses,
best_learning_curve,
learning_curves,
best_presets,
) = model.fit_preset(
train_data,
classification=classification,
nested=5,
n_jobs=16,
)

# Load and featurize test dataset
test_df = task.get_test_data(fold, include_target=False, as_type="df")
Expand Down

0 comments on commit cc36d63

Please sign in to comment.