Skip to content

Commit

Permalink
[Feat] Update CVRP baselines
Browse files Browse the repository at this point in the history
  • Loading branch information
cbhua committed Jun 10, 2024
1 parent 4edd29c commit 96d68e3
Show file tree
Hide file tree
Showing 6 changed files with 481 additions and 30 deletions.
187 changes: 187 additions & 0 deletions rl4co/envs/routing/cvrp/baselines/lkh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import lkh
import numpy as np

from tensordict import TensorDict
from torch import Tensor

from .utils import scale

LKH_SCALING_FACTOR = 100_000


def solve(
instance: TensorDict,
max_runtime: float,
problem_type: str,
num_runs: int,
solver_loc: str,
) -> tuple[Tensor, Tensor]:
"""
Solves an CVRP instance with OR-Tools.
Parameters
----------
instance
The CVRP instance to solve.
max_runtime
The maximum runtime for the solver.
problem_type
The problem type for LKH3.
num_runs
The number of runs to perform and returns the best result.
solver_loc
The location of the LKH3 solver executable.
Returns
-------
tuple[Tensor, Tensor]
A tuple consisting of the action and the cost, respectively.
"""
problem = instance2problem(instance, problem_type, LKH_SCALING_FACTOR)
action, cost = _solve(problem, max_runtime, num_runs, solver_loc)
cost /= -LKH_SCALING_FACTOR

return action, cost


def _solve(
problem: lkh.LKHProblem,
max_runtime: float,
num_runs: int,
solver_loc: str,
) -> tuple[Tensor, Tensor]:
"""
Solves an instance with LKH3.
Parameters
----------
problem
The LKHProblem instance.
max_runtime
The maximum runtime for each solver run.
num_runs
The number of runs to perform and returns the best result.
Note: Each run uses a different initial solution. LKH has difficulty
finding feasible solutions, so performing more runs can help to find
solutions that are feasible.
solver_loc
The location of the LKH3 solver executable.
"""
routes, cost = lkh.solve(
solver_loc,
problem=problem,
time_limit=max_runtime,
runs=num_runs,
)

action = routes2action(routes)
return action, cost


def instance2problem(
instance: TensorDict,
problem_type: str,
scaling_factor,
) -> lkh.LKHProblem:
"""
Converts an CVRP instance to an LKHProblem instance.
Parameters
----------
instance
The CVRP instance to convert.
problem_type
The problem type for LKH3. See ``constants.ROUTEFINDER2LKH`` for
supported problem types.
scaling_factor
The scaling factor to apply to the instance data.
"""
num_locations = instance["demand_linehaul"].size()[0]

# Data specifications
specs = {}
specs["DIMENSION"] = num_locations
specs["CAPACITY"] = scale(instance["vehicle_capacity"], scaling_factor)

specs["EDGE_WEIGHT_TYPE"] = "EXPLICIT"
specs["EDGE_WEIGHT_FORMAT"] = "FULL_MATRIX"
specs["NODE_COORD_TYPE"] = "TWOD_COORDS"

# LKH can only solve VRP variants that are explicitly supported (so no
# arbitrary combinations between individual supported features). We can
# support some open variants with some modeling tricks.
lkh_problem_type = "CVRP"
specs["TYPE"] = lkh_problem_type

# Data sections
sections = {}
sections["NODE_COORD_SECTION"] = scale(instance["locs"], scaling_factor)

demand = scale(instance["demand"], scaling_factor)
sections["DEMAND_SECTION"] = demand

distances = instance["cost_matrix"]

sections["EDGE_WEIGHT_SECTION"] = scale(distances, scaling_factor)

# Convert to VRPLIB-like string.
problem = "\n".join(f"{k} : {v}" for k, v in specs.items())
problem += "\n" + "\n".join(_format(name, data) for name, data in sections.items())
problem += "\n" + "\n".join(["DEPOT_SECTION", "1", "-1", "EOF"])

return lkh.LKHProblem.parse(problem)


def _is_1D(data) -> bool:
for elt in data:
if isinstance(elt, (list, tuple, np.ndarray)):
return False
return True


def _format(name: str, data) -> str:
"""
Formats a data section.
Parameters
----------
name
The name of the section.
data
The data to be formatted.
Returns
-------
str
A VRPLIB-formatted data section.
"""
section = [name]
include_idx = name not in ["EDGE_WEIGHT_SECTION", "BACKHAUL_SECTION"]

if name == "BACKHAUL_SECTION":
# Treat backhaul section as row vector.
section.append("\t".join(str(val) for val in data))

elif _is_1D(data):
# Treat 1D arrays as column vectors, so each element is a row.
for idx, elt in enumerate(data, 1):
prefix = f"{idx}\t" if include_idx else ""
section.append(prefix + str(elt))
else:
for idx, row in enumerate(data, 1):
prefix = f"{idx}\t" if include_idx else ""
rest = "\t".join([str(elt) for elt in row])
section.append(prefix + rest)

return "\n".join(section)


def routes2action(routes: list[list[int]]) -> list[int]:
"""
Converts LKH routes to an action.
"""
# LKH routes are location-indexed, which in turn are 1-indexed. The first
# location is always the depot, so we subtract 2 to get client indices.
# LKH routes are 1-indexed, so we subtract 1 to get client indices.
routes_ = [[client - 1 for client in route] for route in routes]
return [visit for route in routes_ for visit in route + [0]]
174 changes: 174 additions & 0 deletions rl4co/envs/routing/cvrp/baselines/ortools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from dataclasses import dataclass
from typing import Optional

import numpy as np

from ortools.constraint_solver import pywrapcp, routing_enums_pb2
from tensordict import TensorDict
from torch import Tensor

from . import pyvrp

ORTOOLS_SCALING_FACTOR = 100_000


def solve(instance: TensorDict, max_runtime: float, **kwargs) -> tuple[Tensor, Tensor]:
"""
Solves an CVRP instance with OR-Tools.
Parameters
----------
instance
The CVRP instance to solve.
max_runtime
The maximum runtime for the solver.
Returns
-------
tuple[Tensor, Tensor]
A tuple consisting of the action and the cost, respectively.
Notes
-----
This function depends on PyVRP's data converter to convert the CVRP
instance to an OR-Tools compatible format. Future versions should
implement a direct conversion.
"""
data = instance2data(instance)
action, cost = _solve(data, max_runtime)
cost /= ORTOOLS_SCALING_FACTOR
cost *= -1

return action, cost


@dataclass
class ORToolsData:
"""
Convenient dataclass for instance data when using OR-Tools as solver.
Parameters
----------
depot
The depot index.
distance_matrix
The distance matrix between locations.
vehicle_capacities
The capacity of each vehicle.
demands
The demands of each location.
"""

depot: int
distance_matrix: list[list[int]]
vehicle_capacities: list[int]
demands: list[int]

@property
def num_locations(self) -> int:
return len(self.distance_matrix)


def instance2data(instance: TensorDict) -> ORToolsData:
"""
Converts an CVRP instance to an ORToolsData instance.
"""
# TODO: Do not use PyVRP's data converter.
data = pyvrp.instance2data(instance, ORTOOLS_SCALING_FACTOR)

capacities = [
veh_type.capacity
for veh_type in data.vehicle_types()
for _ in range(veh_type.num_available)
]

demands = [0] + [client.delivery for client in data.clients()]
distances = data.distance_matrix().copy()

return ORToolsData(
depot=0,
distance_matrix=distances.tolist(),
vehicle_capacities=capacities,
demands=demands,
)


def _solve(data: ORToolsData, max_runtime: float, log: bool = False):
"""
Solves an instance with OR-Tools.
Parameters
----------
data
The instance data.
max_runtime
The maximum runtime in seconds.
log
Whether to log the search.
Returns
-------
tuple[list[list[int]], int]
A tuple containing the routes and the objective value.
"""
# Manager for converting between nodes (location indices) and index
# (internal CP variable indices).
manager = pywrapcp.RoutingIndexManager(
data.num_locations, data.num_vehicles, data.depot
)
routing = pywrapcp.RoutingModel(manager)

# Set arc costs equal to distances.
distance_transit_idx = routing.RegisterTransitMatrix(data.distance_matrix)
routing.SetArcCostEvaluatorOfAllVehicles(distance_transit_idx)

# Vehicle capacity constraint.
routing.AddDimensionWithVehicleCapacity(
routing.RegisterUnaryTransitVector(data.demands),
0, # null capacity slack
data.vehicle_capacities, # vehicle maximum capacities
True, # start cumul to zero
"Demand",
)

# Setup search parameters.
params = pywrapcp.DefaultRoutingSearchParameters()

gls = routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
params.local_search_metaheuristic = gls

params.time_limit.FromSeconds(int(max_runtime)) # only accepts int
params.log_search = log

solution = routing.SolveWithParameters(params)
action = solution2action(data, manager, routing, solution)
objective = solution.ObjectiveValue()

return action, objective


def solution2action(data, manager, routing, solution) -> list[list[int]]:
"""
Converts an OR-Tools solution to routes.
"""
routes = []
distance = 0 # for debugging

for vehicle_idx in range(data.num_vehicles):
index = routing.Start(vehicle_idx)
route = []
route_cost = 0

while not routing.IsEnd(index):
node = manager.IndexToNode(index)
route.append(node)

prev_index = index
index = solution.Value(routing.NextVar(index))
route_cost += routing.GetArcCostForVehicle(prev_index, index, vehicle_idx)

if clients := route[1:]: # ignore depot
routes.append(clients)
distance += route_cost

return [visit for route in routes for visit in route + [0]]
Loading

0 comments on commit 96d68e3

Please sign in to comment.