Skip to content

Commit

Permalink
[Feat] Adding PCTSP baselinse template
Browse files Browse the repository at this point in the history
  • Loading branch information
cbhua committed Jun 10, 2024
1 parent 730dc05 commit e1c17ea
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
30 changes: 30 additions & 0 deletions rl4co/envs/routing/pctsp/baselines/ortools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
import numpy as np

from torch import Tensor
from tensordict.tensordict import TensorDict


def solve(instance: TensorDict, max_runtime: float, **kwargs) -> tuple[Tensor, Tensor]:
"""
Solves the PCTSP instance with Compass.
Parameters
----------
instance
The PCTSP instance to solve.
max_runtime
Maximum runtime for the solver.
Returns
-------
tuple[Tensor, Tensor]
A tuple consisting of the action and the cost, respectively.
"""
raise NotImplementedError("Compass solver is not implemented yet.")

# TODO
action = None
cost = None

return action, cost
47 changes: 47 additions & 0 deletions rl4co/envs/routing/pctsp/baselines/solve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from functools import partial
from multiprocessing import Pool

from tensordict.tensordict import TensorDict
from torch import Tensor


def solve(
instances: TensorDict,
max_runtime: float,
num_procs: int = 1,
solver: str = "compass",
**kwargs,
) -> tuple[Tensor, Tensor]:
"""
Solves the PCTSP instances with solvers.
Args:
instances: The PCTSP instances to solve.
max_runtime: The maximum runtime for the solver.
num_procs: The number of processes to use.
solver: The solver to use, currently support 'ortools' solver.
Returns:
A tuple containing the action and the cost, respectively.
"""
if solver == "ortools":
from . import ortools
_solve = ortools.solve
else:
raise ValueError(f"Unknown baseline solver: {solver}")

func = partial(_solve, max_runtime=max_runtime, **kwargs)

if num_procs > 1:
with Pool(processes=num_procs) as pool:
results = pool.map(func, instances)
else:
results = [func(instance) for instance in instances]

actions, costs = zip(*results)

# Pad to ensure all actions have the same length.
max_len = max(len(action) for action in actions)
actions = [action + [0] * (max_len - len(action)) for action in actions]

return Tensor(actions).long(), Tensor(costs)

0 comments on commit e1c17ea

Please sign in to comment.