Skip to content

Commit

Permalink
[WIP] Add callbacks to rcpsp/ortools/cpsat solver
Browse files Browse the repository at this point in the history
wip: store all solutions for now
-> only improving ones ? (as in routing)
-> only x best ones ? (use limit_store but also change resultstorage to
avoid appending to list_solution_fit

- Make CPSatRCPSPSolver derive from CPSolver
- Remove specific callback used to print intermediate variables (to be
  replaced by a d-o callback in examples/tests)
- Branch d-o callbacks in a dedicated CpSolverSolutionCallback
- Manage result storage with the CpSolverSolutionCallback
- Test callbacks usage
  • Loading branch information
nhuet committed Jan 29, 2024
1 parent 03d76f4 commit a043f9c
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 41 deletions.
123 changes: 82 additions & 41 deletions discrete_optimization/rcpsp/solver/cpsat_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,33 @@
UNKNOWN,
CpModel,
CpSolver,
IntervalVar,
VarArrayAndObjectiveSolutionPrinter,
VarArraySolutionPrinter,
CpSolverSolutionCallback,
)

from discrete_optimization.generic_tools.cp_tools import ParametersCP, StatusSolver
from discrete_optimization.generic_tools.callbacks.callback import (
Callback,
CallbackList,
)
from discrete_optimization.generic_tools.cp_tools import (
CPSolver,
ParametersCP,
StatusSolver,
)
from discrete_optimization.generic_tools.do_problem import (
ParamsObjectiveFunction,
build_aggreg_function_and_params_objective,
)
from discrete_optimization.generic_tools.do_solver import SolverDO
from discrete_optimization.generic_tools.exceptions import SolveEarlyStop
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
from_solutions_to_result_storage,
)
from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel, RCPSPSolution
from discrete_optimization.rcpsp.rcpsp_utils import create_fake_tasks

logger = logging.getLogger(__name__)


class CPSatRCPSPSolver(SolverDO):
class CPSatRCPSPSolver(CPSolver):
def __init__(
self,
problem: RCPSPModel,
Expand All @@ -53,7 +58,7 @@ def __init__(
self.variables: Optional[Dict[str, Any]] = None
self.status_solver: Optional[StatusSolver] = None

def init_model(self):
def init_model(self, **kwargs: Any) -> None:
model = CpModel()
starts_var = {}
ends_var = {}
Expand Down Expand Up @@ -156,48 +161,35 @@ def init_model(self):
}

def solve(
self, parameters_cp: Optional[ParametersCP] = None, **kwargs: Any
self,
callbacks: Optional[List[Callback]] = None,
parameters_cp: Optional[ParametersCP] = None,
**kwargs: Any,
) -> ResultStorage:
callbacks_list = CallbackList(callbacks=callbacks)
callbacks_list.on_solve_start(solver=self)
if self.cp_model is None:
self.init_model()
if parameters_cp is None:
parameters_cp = ParametersCP.default()
solver = CpSolver()
solver.parameters.max_time_in_seconds = parameters_cp.time_limit
solver.parameters.num_workers = parameters_cp.nb_process
callback = VarArrayAndObjectiveSolutionPrinter(
variables=list(self.variables["is_present"].values())
+ list(self.variables["start"].values())
)
status = solver.Solve(self.cp_model, callback)
ortools_callback = DOCallback(do_solver=self, callback=callbacks_list)
try:
status = solver.Solve(self.cp_model, ortools_callback)
self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status)
except SolveEarlyStop as e:
logger.info(e)
if ortools_callback.nb_solutions > 0:
status = StatusSolver.SATISFIED
else:
status = StatusSolver.UNSATISFIABLE
# ortools_callback.store_current_solution()
self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status)
logger.info(
f"Solver finished, status={solver.StatusName(status)}, objective = {solver.ObjectiveValue()},"
f"best obj bound = {solver.BestObjectiveBound()}"
)
return self.retrieve_solution(solver=solver)

def retrieve_solution(self, solver: CpSolver):
schedule = {}
modes_dict = {}
for task in self.variables["start"]:
schedule[task] = {
"start_time": solver.Value(self.variables["start"][task]),
"end_time": solver.Value(self.variables["end"][task]),
}
for task, mode in self.variables["is_present"]:
if solver.Value(self.variables["is_present"][task, mode]):
modes_dict[task] = mode
sol = RCPSPSolution(
problem=self.problem,
rcpsp_schedule=schedule,
rcpsp_modes=[modes_dict[t] for t in self.problem.tasks_list_non_dummy],
)
return from_solutions_to_result_storage(
[sol],
problem=self.problem,
params_objective_function=self.params_objective_function,
)
res = ortools_callback.res
callbacks_list.on_solve_end(res=res, solver=self)
return res


def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver:
Expand All @@ -214,3 +206,52 @@ def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver:
return StatusSolver.OPTIMAL
if status_from_cpsat == FEASIBLE:
return StatusSolver.SATISFIED


class DOCallback(CpSolverSolutionCallback):
def __init__(self, do_solver: CPSatRCPSPSolver, callback: Callback):
super().__init__()
self.do_solver = do_solver
self.callback = callback
self.res = ResultStorage(
[],
mode_optim=self.do_solver.params_objective_function.sense_function,
limit_store=False,
)
self.nb_solutions = 0

def on_solution_callback(self) -> None:
self.store_current_solution()
self.nb_solutions += 1
# end of step callback: stopping?
stopping = self.callback.on_step_end(
step=self.nb_solutions, res=self.res, solver=self.do_solver
)
if stopping:
raise SolveEarlyStop(
f"{self.do_solver.__class__.__name__}.solve() stopped by user callback."
)

def get_current_solution(self) -> RCPSPSolution:
schedule = {}
modes_dict = {}
for task in self.do_solver.variables["start"]:
schedule[task] = {
"start_time": self.Value(self.do_solver.variables["start"][task]),
"end_time": self.Value(self.do_solver.variables["end"][task]),
}
for task, mode in self.do_solver.variables["is_present"]:
if self.Value(self.do_solver.variables["is_present"][task, mode]):
modes_dict[task] = mode
return RCPSPSolution(
problem=self.do_solver.problem,
rcpsp_schedule=schedule,
rcpsp_modes=[
modes_dict[t] for t in self.do_solver.problem.tasks_list_non_dummy
],
)

def store_current_solution(self):
sol = self.get_current_solution()
fit = self.do_solver.aggreg_sol(sol)
self.res.add_solution(solution=sol, fitness=fit)
37 changes: 37 additions & 0 deletions tests/rcpsp/solver/test_rcpsp_cp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Copyright (c) 2022 AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging

import pytest

from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.callbacks.early_stoppers import TimerStopper
from discrete_optimization.generic_tools.callbacks.loggers import NbIterationTracker
from discrete_optimization.generic_tools.cp_tools import CPSolverName, ParametersCP
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
result_storage_to_pareto_front,
)
from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel
Expand Down Expand Up @@ -86,6 +91,38 @@ def test_ortools(model):
plot_task_gantt(rcpsp_problem, solution)


def test_ortools_with_cb(caplog):
model = "j1201_1.sm"
files_available = get_data_available()
file = [f for f in files_available if model in f][0]
rcpsp_problem = parse_file(file)
solver = CPSatRCPSPSolver(problem=rcpsp_problem)
parameters_cp = ParametersCP.default()
parameters_cp.time_limit = 5
parameters_cp.nr_solutions = 1

class VariablePrinterCallback(Callback):
def __init__(self) -> None:
super().__init__()
self.nb_solution = 0

def on_step_end(self, step: int, res: ResultStorage, solver: CPSatRCPSPSolver):
self.nb_solution += 1
sol: RCPSPSolution
sol, fit = res.list_solution_fits[-1]
logging.debug(f"Solution #{self.nb_solution}:")
logging.debug(sol.rcpsp_schedule)
logging.debug(sol.rcpsp_modes)

callbacks = [VariablePrinterCallback(), TimerStopper(3)]

with caplog.at_level(logging.DEBUG):
result_storage = solver.solve(callbacks=callbacks, parameters_cp=parameters_cp)

assert "Solution #1" in caplog.text
assert "stopped by user callback" in caplog.text


def test_cp_sm_intermediate_solution():
files_available = get_data_available()
file = [f for f in files_available if "j1201_1.sm" in f][0]
Expand Down

0 comments on commit a043f9c

Please sign in to comment.