Skip to content

Commit

Permalink
Merge pull request #88 from ChristopherMayes/cnsga_load_pop
Browse files Browse the repository at this point in the history
CNSGA load pop
  • Loading branch information
ChristopherMayes authored Jan 18, 2023
2 parents 0132d96 + 52bcb1e commit dafd2dd
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
mamba-version: "*"
miniforge-variant: Mambaforge
channels: conda-forge
activate-environment: xopt-dev
environment-file: environment.yml
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
mamba-version: "*"
miniforge-variant: Mambaforge
channels: conda-forge
activate-environment: xopt-dev
environment-file: environment.yml
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/gh-pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
mamba-version: "*"
miniforge-variant: Mambaforge
channels: conda-forge
activate-environment: xopt-dev
environment-file: environment.yml
Expand Down
4 changes: 2 additions & 2 deletions xopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from xopt.log import configure_logger


def output_notebook():
def output_notebook(**kwargs):
"""
Redirects logging to stdout for use in Jupyter notebooks
"""
configure_logger()
configure_logger(**kwargs)
2 changes: 1 addition & 1 deletion xopt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def dump_state(self):
output = state_to_dict(self)
with open(self.options.dump_file, "w") as f:
yaml.dump(output, f)
logger.debug(f"Dumping state to:{self.options.dump_file}")
logger.debug(f"Dumped state to YAML file: {self.options.dump_file}")

@property
def data(self):
Expand Down
53 changes: 29 additions & 24 deletions xopt/generators/ga/cnsga.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pandas as pd
from deap import algorithms as deap_algorithms, base as deap_base, tools as deap_tools
from pydantic import confloat
from pydantic import confloat, Field

import xopt.utils
from xopt.generator import Generator, GeneratorOptions
Expand All @@ -22,11 +22,11 @@

class CNSGAOptions(GeneratorOptions):

population_size: int = 64
crossover_probability: confloat(ge=0, le=1) = 0.9
mutation_probability: confloat(ge=0, le=1) = 1.0
population_file: str = None
output_path: str = None
population_size: int = Field(64, description="Population size")
crossover_probability: confloat(ge=0, le=1) = Field(0.9, description="Crossover probability")
mutation_probability: confloat(ge=0, le=1) = Field(1.0, description="Mutation probability")
population_file: str = Field(None, description="Population file to load (CSV format)")
output_path: str = Field(None, description="Output path for population files")


class CNSGAGenerator(Generator):
Expand All @@ -44,19 +44,21 @@ def __init__(self, vocs, options: CNSGAOptions = None):
super().__init__(vocs, options)

# Internal data structures
self.children = [] # unevaluated inputs. This should be a list of dicts.
self.population = None # The latest population (fully evaluated)
self.children = (
[]
) # list of unevaluated inputs. This should be a list of dicts.
self.population = None # The latest population data (fully evaluated)
self.offspring = None # Newly evaluated data, but not yet added to population

self._loaded_population = (
None # use these to generate children until the first pop is made
)

# DEAP toolbox (internal)
self.toolbox = cnsga_toolbox(vocs, selection="auto")

if options.population_file is not None:
self.load_population_csv(options.population_file)
# n_here = len(self.population)
# if n_here != self.n_pop:
# warnings.warn(f"Population in {options.population_file}"
# f"does not match n_pop: {n_here} != {self.n_pop}")

if options.output_path is not None:
assert os.path.isdir(options.output_path), "Output directory does not exist"
Expand Down Expand Up @@ -94,11 +96,17 @@ def create_children(self):

# No population, so create random children
if self.population is None:
return [self.vocs.random_inputs() for _ in range(self.n_pop)]
# Special case when pop is loaded from file
if self._loaded_population is None:
return [self.vocs.random_inputs() for _ in range(self.n_pop)]
else:
pop = self._loaded_population
else:
pop = self.population

# Use population to create children
inputs = cnsga_variation(
self.population,
pop,
self.vocs,
self.toolbox,
crossover_probability=self.options.crossover_probability,
Expand All @@ -111,16 +119,12 @@ def add_data(self, new_data: pd.DataFrame):

# Next generation
if len(self.offspring) >= self.n_pop:
if self.population is None:
self.population = self.offspring.iloc[: self.n_pop]
self.offspring = self.offspring.iloc[self.n_pop:]
else:
candidates = pd.concat([self.population, self.offspring])
self.population = cnsga_select(
candidates, self.n_pop, self.vocs, self.toolbox
)
self.children = [] # reset children
self.offspring = None # reset offspring
candidates = pd.concat([self.population, self.offspring])
self.population = cnsga_select(
candidates, self.n_pop, self.vocs, self.toolbox
)
self.children = [] # reset children
self.offspring = None # reset offspring

if self.options.output_path is not None:
self.write_population()
Expand Down Expand Up @@ -154,6 +158,7 @@ def load_population_csv(self, filename):
These will be reverted back to children for re-evaluation.
"""
pop = pd.read_csv(filename, index_col="xopt_index")
self._loaded_population = pop
# This is a list of dicts
self.children = self.vocs.convert_dataframe_to_inputs(pop).to_dict(
orient="records"
Expand Down

0 comments on commit dafd2dd

Please sign in to comment.