diff --git a/discrete_optimization/rcpsp/solver/cpsat_solver.py b/discrete_optimization/rcpsp/solver/cpsat_solver.py index 1c355c564..a0bf98f49 100644 --- a/discrete_optimization/rcpsp/solver/cpsat_solver.py +++ b/discrete_optimization/rcpsp/solver/cpsat_solver.py @@ -13,20 +13,25 @@ UNKNOWN, CpModel, CpSolver, - IntervalVar, - VarArrayAndObjectiveSolutionPrinter, - VarArraySolutionPrinter, + CpSolverSolutionCallback, ) -from discrete_optimization.generic_tools.cp_tools import ParametersCP, StatusSolver +from discrete_optimization.generic_tools.callbacks.callback import ( + Callback, + CallbackList, +) +from discrete_optimization.generic_tools.cp_tools import ( + CPSolver, + ParametersCP, + StatusSolver, +) from discrete_optimization.generic_tools.do_problem import ( ParamsObjectiveFunction, build_aggreg_function_and_params_objective, ) -from discrete_optimization.generic_tools.do_solver import SolverDO +from discrete_optimization.generic_tools.exceptions import SolveEarlyStop from discrete_optimization.generic_tools.result_storage.result_storage import ( ResultStorage, - from_solutions_to_result_storage, ) from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel, RCPSPSolution from discrete_optimization.rcpsp.rcpsp_utils import create_fake_tasks @@ -34,7 +39,7 @@ logger = logging.getLogger(__name__) -class CPSatRCPSPSolver(SolverDO): +class CPSatRCPSPSolver(CPSolver): def __init__( self, problem: RCPSPModel, @@ -53,7 +58,7 @@ def __init__( self.variables: Optional[Dict[str, Any]] = None self.status_solver: Optional[StatusSolver] = None - def init_model(self): + def init_model(self, **kwargs: Any) -> None: model = CpModel() starts_var = {} ends_var = {} @@ -156,8 +161,13 @@ def init_model(self): } def solve( - self, parameters_cp: Optional[ParametersCP] = None, **kwargs: Any + self, + callbacks: Optional[List[Callback]] = None, + parameters_cp: Optional[ParametersCP] = None, + **kwargs: Any, ) -> ResultStorage: + callbacks_list = CallbackList(callbacks=callbacks) + callbacks_list.on_solve_start(solver=self) if self.cp_model is None: self.init_model() if parameters_cp is None: @@ -165,39 +175,21 @@ def solve( solver = CpSolver() solver.parameters.max_time_in_seconds = parameters_cp.time_limit solver.parameters.num_workers = parameters_cp.nb_process - callback = VarArrayAndObjectiveSolutionPrinter( - variables=list(self.variables["is_present"].values()) - + list(self.variables["start"].values()) - ) - status = solver.Solve(self.cp_model, callback) + ortools_callback = DOCallback(do_solver=self, callback=callbacks_list) + try: + status = solver.Solve(self.cp_model, ortools_callback) + self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status) + except SolveEarlyStop as e: + logger.info(e) + if ortools_callback.nb_solutions > 0: + status = StatusSolver.SATISFIED + else: + status = StatusSolver.UNSATISFIABLE + # ortools_callback.store_current_solution() self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status) - logger.info( - f"Solver finished, status={solver.StatusName(status)}, objective = {solver.ObjectiveValue()}," - f"best obj bound = {solver.BestObjectiveBound()}" - ) - return self.retrieve_solution(solver=solver) - - def retrieve_solution(self, solver: CpSolver): - schedule = {} - modes_dict = {} - for task in self.variables["start"]: - schedule[task] = { - "start_time": solver.Value(self.variables["start"][task]), - "end_time": solver.Value(self.variables["end"][task]), - } - for task, mode in self.variables["is_present"]: - if solver.Value(self.variables["is_present"][task, mode]): - modes_dict[task] = mode - sol = RCPSPSolution( - problem=self.problem, - rcpsp_schedule=schedule, - rcpsp_modes=[modes_dict[t] for t in self.problem.tasks_list_non_dummy], - ) - return from_solutions_to_result_storage( - [sol], - problem=self.problem, - params_objective_function=self.params_objective_function, - ) + res = ortools_callback.res + callbacks_list.on_solve_end(res=res, solver=self) + return res def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver: @@ -214,3 +206,52 @@ def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver: return StatusSolver.OPTIMAL if status_from_cpsat == FEASIBLE: return StatusSolver.SATISFIED + + +class DOCallback(CpSolverSolutionCallback): + def __init__(self, do_solver: CPSatRCPSPSolver, callback: Callback): + super().__init__() + self.do_solver = do_solver + self.callback = callback + self.res = ResultStorage( + [], + mode_optim=self.do_solver.params_objective_function.sense_function, + limit_store=False, + ) + self.nb_solutions = 0 + + def on_solution_callback(self) -> None: + self.store_current_solution() + self.nb_solutions += 1 + # end of step callback: stopping? + stopping = self.callback.on_step_end( + step=self.nb_solutions, res=self.res, solver=self.do_solver + ) + if stopping: + raise SolveEarlyStop( + f"{self.do_solver.__class__.__name__}.solve() stopped by user callback." + ) + + def get_current_solution(self) -> RCPSPSolution: + schedule = {} + modes_dict = {} + for task in self.do_solver.variables["start"]: + schedule[task] = { + "start_time": self.Value(self.do_solver.variables["start"][task]), + "end_time": self.Value(self.do_solver.variables["end"][task]), + } + for task, mode in self.do_solver.variables["is_present"]: + if self.Value(self.do_solver.variables["is_present"][task, mode]): + modes_dict[task] = mode + return RCPSPSolution( + problem=self.do_solver.problem, + rcpsp_schedule=schedule, + rcpsp_modes=[ + modes_dict[t] for t in self.do_solver.problem.tasks_list_non_dummy + ], + ) + + def store_current_solution(self): + sol = self.get_current_solution() + fit = self.do_solver.aggreg_sol(sol) + self.res.add_solution(solution=sol, fitness=fit) diff --git a/tests/rcpsp/solver/test_rcpsp_cp.py b/tests/rcpsp/solver/test_rcpsp_cp.py index 0732b6cc8..fbac412f5 100644 --- a/tests/rcpsp/solver/test_rcpsp_cp.py +++ b/tests/rcpsp/solver/test_rcpsp_cp.py @@ -1,11 +1,16 @@ # Copyright (c) 2022 AIRBUS and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import pytest +from discrete_optimization.generic_tools.callbacks.callback import Callback +from discrete_optimization.generic_tools.callbacks.early_stoppers import TimerStopper +from discrete_optimization.generic_tools.callbacks.loggers import NbIterationTracker from discrete_optimization.generic_tools.cp_tools import CPSolverName, ParametersCP from discrete_optimization.generic_tools.result_storage.result_storage import ( + ResultStorage, result_storage_to_pareto_front, ) from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel @@ -86,6 +91,38 @@ def test_ortools(model): plot_task_gantt(rcpsp_problem, solution) +def test_ortools_with_cb(caplog): + model = "j1201_1.sm" + files_available = get_data_available() + file = [f for f in files_available if model in f][0] + rcpsp_problem = parse_file(file) + solver = CPSatRCPSPSolver(problem=rcpsp_problem) + parameters_cp = ParametersCP.default() + parameters_cp.time_limit = 5 + parameters_cp.nr_solutions = 1 + + class VariablePrinterCallback(Callback): + def __init__(self) -> None: + super().__init__() + self.nb_solution = 0 + + def on_step_end(self, step: int, res: ResultStorage, solver: CPSatRCPSPSolver): + self.nb_solution += 1 + sol: RCPSPSolution + sol, fit = res.list_solution_fits[-1] + logging.debug(f"Solution #{self.nb_solution}:") + logging.debug(sol.rcpsp_schedule) + logging.debug(sol.rcpsp_modes) + + callbacks = [VariablePrinterCallback(), TimerStopper(3)] + + with caplog.at_level(logging.DEBUG): + result_storage = solver.solve(callbacks=callbacks, parameters_cp=parameters_cp) + + assert "Solution #1" in caplog.text + assert "stopped by user callback" in caplog.text + + def test_cp_sm_intermediate_solution(): files_available = get_data_available() file = [f for f in files_available if "j1201_1.sm" in f][0]