Merge pull request #184 from yining043/new
[Feat] Add DACT and NeuOpt Improvement Models
fedebotu authored Jun 7, 2024
2 parents 7169e74 + afc9bab commit 3ffd855
Showing 25 changed files with 1,730 additions and 83 deletions.
2 changes: 2 additions & 0 deletions rl4co/envs/__init__.py
@@ -21,6 +21,7 @@
SPCTSPEnv,
SVRPEnv,
TSPEnv,
TSPkoptEnv,
)

# Scheduling
@@ -49,6 +50,7 @@
"smtwtp": SMTWTPEnv,
"mdcpdp": MDCPDPEnv,
"mtvrp": MTVRPEnv,
"tsp_kopt": TSPkoptEnv,
}
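To make the new registry entry concrete, here is a brief usage sketch (illustrative only, not part of the diff). The import is taken from the block above; the default constructor arguments and the reset signature are assumptions based on the usual RL4COEnvBase interface.

# Illustrative sketch: instantiate the newly registered k-opt TSP environment.
# Default constructor arguments are assumed to exist.
from rl4co.envs import TSPkoptEnv

env = TSPkoptEnv()              # registered above under the key "tsp_kopt"
td = env.reset(batch_size=[2])  # standard RL4COEnvBase-style reset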


70 changes: 70 additions & 0 deletions rl4co/envs/common/base.py
@@ -291,3 +291,73 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.rng = torch.manual_seed(0)
self.rng.set_state(state["rng"])


class ImprovementEnvBase(RL4COEnvBase, metaclass=abc.ABCMeta):
"""Base class for Improvement environments based on RL4CO EnvBase.
Note that this class assumes that the solution is stored in a linked list format.
Here, if rec[i] = j, it means the node i is connected to node j, i.e., edge i-j is in the solution.
For example, if edge 0-1, edge 1-5, edge 2-10 are in the solution, so we have rec[0]=1, rec[1]=5 and rec[2]=10.
Kindly see https://github.com/yining043/VRP-DACT/blob/new_version/Play_with_DACT.ipynb for an example at the end for TSP.
"""

def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)

@abc.abstractmethod
def _step(self, td: TensorDict, solution_to=None) -> TensorDict:
raise NotImplementedError

def step_to_solution(self, td, solution) -> TensorDict:
return self._step(td, solution_to=solution)

@staticmethod
def _get_reward(td, actions) -> TensorDict:
raise NotImplementedError(
"This function is not used for improvement tasks since the reward is computed per step"
)

@staticmethod
def get_costs(coordinates, rec):
batch_size, size = rec.size()

# calculate the route length value
d1 = coordinates.gather(1, rec.long().unsqueeze(-1).expand(batch_size, size, 2))
d2 = coordinates
length = (d1 - d2).norm(p=2, dim=2).sum(1)

return length

@staticmethod
def _get_real_solution(rec):
batch_size, seq_length = rec.size()
visited_time = torch.zeros((batch_size, seq_length)).to(rec.device)
pre = torch.zeros((batch_size), device=rec.device).long()
for i in range(seq_length):
visited_time[torch.arange(batch_size), rec[torch.arange(batch_size), pre]] = (
i + 1
)
pre = rec[torch.arange(batch_size), pre]

visited_time = visited_time % seq_length
return visited_time.argsort()

@staticmethod
def _get_linked_list_solution(solution):
solution_pre = solution
solution_post = torch.cat((solution[:, 1:], solution[:, :1]), 1)

rec = solution.clone()
rec.scatter_(1, solution_pre, solution_post)
return rec

@classmethod
def get_best_solution(cls, td):
return cls._get_real_solution(td["rec_best"])

@classmethod
def get_current_solution(cls, td):
return cls._get_real_solution(td["rec_current"])
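To make the linked-list convention described in the class docstring concrete, here is a small self-contained sketch (illustrative only, not part of the diff) that mirrors _get_linked_list_solution, _get_real_solution, and get_costs on a single 5-node tour.

import torch

# A single TSP tour over 5 nodes: 0 -> 3 -> 1 -> 4 -> 2 -> back to 0
solution = torch.tensor([[0, 3, 1, 4, 2]])

# Linked-list encoding, as in _get_linked_list_solution:
# rec[i] = j means edge i-j is in the tour.
pre = solution
post = torch.cat((solution[:, 1:], solution[:, :1]), 1)
rec = solution.clone().scatter_(1, pre, post)
# rec == tensor([[3, 4, 0, 1, 2]]): 0->3, 1->4, 2->0, 3->1, 4->2

# Recover the visiting order by walking the links from node 0, as in _get_real_solution.
batch_size, n = rec.size()
visited_time = torch.zeros_like(rec)
cur = torch.zeros(batch_size, dtype=torch.long)
for i in range(n):
    nxt = rec[torch.arange(batch_size), cur]
    visited_time[torch.arange(batch_size), nxt] = i + 1
    cur = nxt
order = (visited_time % n).argsort()
# order == tensor([[0, 3, 1, 4, 2]]): the original tour is recovered

# Tour length, as in get_costs: sum over i of ||coords[rec[i]] - coords[i]||
coords = torch.rand(batch_size, n, 2)
gathered = coords.gather(1, rec.unsqueeze(-1).expand(batch_size, n, 2))
length = (gathered - coords).norm(p=2, dim=2).sum(1)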
2 changes: 1 addition & 1 deletion rl4co/envs/routing/__init__.py
@@ -20,5 +20,5 @@
from rl4co.envs.routing.spctsp.env import SPCTSPEnv
from rl4co.envs.routing.svrp.env import SVRPEnv
from rl4co.envs.routing.svrp.generator import SVRPGenerator
from rl4co.envs.routing.tsp.env import DenseRewardTSPEnv, TSPEnv
from rl4co.envs.routing.tsp.env import DenseRewardTSPEnv, TSPEnv, TSPkoptEnv
from rl4co.envs.routing.tsp.generator import TSPGenerator
92 changes: 28 additions & 64 deletions rl4co/envs/routing/pdp/env.py
@@ -10,7 +10,7 @@
UnboundedDiscreteTensorSpec,
)

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.envs.common.base import ImprovementEnvBase, RL4COEnvBase
from rl4co.utils.ops import gather_by_index, get_tour_length

from .generator import PDPGenerator
@@ -234,7 +234,7 @@ def render(td: TensorDict, actions: torch.Tensor = None, ax=None):
return render(td, actions, ax)


class PDPRuinRepairEnv(RL4COEnvBase):
class PDPRuinRepairEnv(ImprovementEnvBase):
"""Pickup and Delivery Problem (PDP) environment for performing neural rein-repair search.
The environment is made of num_loc + 1 locations (cities):
- 1 depot
@@ -287,23 +287,22 @@ def __init__(
self.generator = generator
self._make_spec(self.generator)

@classmethod
def _step(cls, td: TensorDict) -> TensorDict:
def _step(self, td: TensorDict, solution_to=None) -> TensorDict:
# get state information from td
action = td["action"]
selected = action[:, 0].view(-1, 1)
first = action[:, 1].view(-1, 1)
second = action[:, 2].view(-1, 1)
solution = td["rec_current"]
solution_best = td["rec_best"]
locs = td["locs"]
cost_bsf = td["cost_bsf"]
action_record = td["action_record"]
bs, gs = solution.size()
bs, gs = solution_best.size()

# perform ruin and repair
next_rec = cls._insert_operator(solution, selected + 1, first, second)
new_obj = cls.get_costs(locs, next_rec)
# perform local_operator
if solution_to is None:
action = td["action"]
solution = td["rec_current"]
next_rec = self._local_operator(solution, action)
else:
next_rec = solution_to.clone()
new_obj = self.get_costs(locs, next_rec)

# compute reward and update best-so-far solutions
now_bsf = torch.where(new_obj < cost_bsf, new_obj, cost_bsf)
@@ -322,9 +321,10 @@ def _step(cls, td: TensorDict) -> TensorDict:
visited_time = visited_time.long()

# update action record
action_record[:, :-1] = action_record[:, 1:]
action_record[:, -1] *= 0
action_record[torch.arange(bs), -1, action[:, 0]] = 1
if solution_to is None:
action_record[:, :-1] = action_record[:, 1:]
action_record[:, -1] *= 0
action_record[torch.arange(bs), -1, action[:, 0]] = 1

# Update step
td.update(
@@ -335,7 +335,7 @@
"rec_best": solution_best,
"visited_time": visited_time,
"action_record": action_record,
"i": td["i"] + 1,
"i": td["i"] + 1 if solution_to is None else td["i"],
"reward": reward,
}
)
@@ -347,7 +347,6 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict

locs = torch.cat((td["depot"][:, None, :], td["locs"]), -2)
current_rec = self.generator._get_initial_solutions(locs).to(device)

obj = self.get_costs(locs, current_rec)

# get index according to the solutions in the linked list data structure
@@ -359,12 +358,16 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
for i in range(seq_length):
current_nodes = current_rec[arange, pre]
visited_time[arange, current_nodes] = i + 1
pre = current_rec[arange, pre]
pre = current_nodes
visited_time = visited_time.long()

# get action record and step i
i = torch.zeros((*batch_size, 1), dtype=torch.int64).to(device)
action_record = torch.zeros((bs, seq_length // 2, seq_length // 2))
action_record = (
torch.zeros((bs, seq_length, seq_length // 2))
if self.training
else torch.zeros((bs, seq_length // 2, seq_length // 2))
)

return TensorDict(
{
@@ -381,7 +384,11 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
)

@staticmethod
def _insert_operator(solution, pair_index, first, second):
def _local_operator(solution, action):
# get info
pair_index = action[:, 0].view(-1, 1) + 1
first = action[:, 1].view(-1, 1)
second = action[:, 2].view(-1, 1)
rec = solution.clone()
bs, gs = rec.size()

@@ -447,10 +454,6 @@ def _make_spec(self, generator: PDPGenerator):
shape=(1),
dtype=torch.int64,
),
action_mask=UnboundedDiscreteTensorSpec(
shape=(self.generator.num_loc + 1, self.generator.num_loc + 1),
dtype=torch.bool,
),
shape=(),
)
self.action_spec = BoundedTensorSpec(
@@ -462,12 +465,6 @@
self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)

@staticmethod
def _get_reward(td, actions) -> TensorDict:
raise NotImplementedError(
"This function is not used for improvement tasks since the reward is computed per step"
)

def check_solution_validity(self, td, actions=None):
# This function can be called by the agent to check the validity of the best-found solution.
# Note that the `actions` argument is not used in improvement methods.
@@ -514,17 +511,6 @@ def get_mask(selected_node, td):

return ~mask

@staticmethod
def get_costs(coordinates, rec):
batch_size, size = rec.size()

# calculate the route length value
d1 = coordinates.gather(1, rec.long().unsqueeze(-1).expand(batch_size, size, 2))
d2 = coordinates
length = (d1 - d2).norm(p=2, dim=2).sum(1)

return length

@classmethod
def _random_action(cls, td):
batch_size, graph_size = td["rec_best"].size()
@@ -542,28 +528,6 @@ def _random_action(cls, td):
td["action"] = action
return action

@staticmethod
def _get_real_solution(rec):
batch_size, seq_length = rec.size()
visited_time = torch.zeros((batch_size, seq_length)).to(rec.device)
pre = torch.zeros((batch_size), device=rec.device).long()
for i in range(seq_length):
visited_time[torch.arange(batch_size), rec[torch.arange(batch_size), pre]] = (
i + 1
)
pre = rec[torch.arange(batch_size), pre]

visited_time = visited_time % seq_length
return visited_time.argsort()

@classmethod
def get_best_solution(cls, td):
return cls._get_real_solution(td["rec_best"])

@classmethod
def get_current_solution(cls, td):
return cls._get_real_solution(td["rec_current"])

@classmethod
def render(cls, td: TensorDict, actions: torch.Tensor = None, ax=None):
solution_current = cls.get_current_solution(td)
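The solution_to pathway added to _step above can be driven through step_to_solution from ImprovementEnvBase. The following is a hedged usage sketch (illustrative only, not part of the diff); it assumes PDPRuinRepairEnv is importable from rl4co.envs and works with default constructor arguments.

from rl4co.envs import PDPRuinRepairEnv

env = PDPRuinRepairEnv()        # default generator settings assumed
td = env.reset(batch_size=[4])

# ... run improvement steps that update td["rec_current"] and td["rec_best"] ...

# Jump the state to the best-so-far linked-list solution. Internally this calls
# _step(td, solution_to=...), which clones the given solution and recomputes its
# cost, leaving the step counter td["i"] and the action record unchanged.
td = env.step_to_solution(td, td["rec_best"])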
2 changes: 0 additions & 2 deletions rl4co/envs/routing/pdp/generator.py
@@ -117,8 +117,6 @@ def _get_initial_solutions(self, coordinates):
candidates.scatter_(1, next_selected_node, 0)
selected_node = next_selected_node

return rec

elif self.init_sol_type == "greedy":
candidates = torch.ones(batch_size, self.num_loc + 1).bool()
candidates[:, order_size + 1 :] = 0
(The remaining changed files are not shown in this view.)