Merge commit e039432
ahottung committed Jun 4, 2024
2 parents 3785590 + 3b66124
Showing 4 changed files with 251 additions and 28 deletions.
6 changes: 3 additions & 3 deletions rl4co/models/nn/env_embeddings/context.py
@@ -117,12 +117,12 @@ def forward(self, embeddings, td):
node_dim = (
(-1,) if td["first_node"].dim() == 1 else (td["first_node"].size(-1), -1)
)
if td["i"][(0,) * td["i"].dim()].item() < 1: # get first item fast
if td.batch_dims == 1:
if td["i"][(0,) * td["i"].dim()].item() < 1: # get first item fast
if len(td.batch_size) < 2:
context_embedding = self.W_placeholder[None, :].expand(
batch_size, self.W_placeholder.size(-1)
)
elif td.batch_dims == 2:
else:
context_embedding = self.W_placeholder[None, None, :].expand(
batch_size, td.batch_size[1], self.W_placeholder.size(-1)
)
103 changes: 103 additions & 0 deletions rl4co/tasks/README.md
@@ -0,0 +1,103 @@
# Evaluation

To evaluate your trained model, follow these steps:

**Step 1**. Prepare your *pre-trained model checkpoint* and *test instances data file*, and put them in your preferred location. For example, to test the `AttentionModel` on TSP50:

```
.
├── rl4co/
│   └── ...
├── checkpoints/
│   └── am-tsp50.ckpt
└── data/
    └── tsp/
        └── tsp50_test_seed1234.npz
```

You can generate the test instances data file by running the following command:

```bash
python -c "from rl4co.data.generate_data import generate_default_datasets; generate_default_datasets('data')"
```
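
As a quick sanity check, you can load the generated test file through the environment's dataset interface (a minimal sketch, assuming the default TSP50 settings and the directory layout above):

```python
# Sketch: confirm the generated TSP50 test instances load correctly
from rl4co.envs import get_env

env = get_env("tsp", generator_params={"num_loc": 50})
dataset = env.dataset(filename="data/tsp/tsp50_test_seed1234.npz")
print(len(dataset))  # number of instances in the test file
```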

**Step 2**. Run `eval.py` with your customized settings. For example, let's use the `sampling` method with a `top_p=0.95` sampling strategy:

```bash
python rl4co/tasks/eval.py --problem tsp --data-path data/tsp/tsp50_test_seed1234.npz --model AttentionModel --ckpt-path checkpoints/am-tsp50.ckpt --method sampling --top-p 0.95
```

Argument guidelines:
- `--problem`: the problem name, e.g., `tsp`, `cvrp`, `pdp`, etc. It should be consistent with `env.name`. Default is `tsp`.
- `--generator-params`: the generator parameters for the test instances, e.g., you can specify `num_loc`. Default is `{'num_loc': 50}`.
- `--data-path`: the path to the test instances data file. Default is `data/tsp/tsp50_test_seed1234.npz`.
- `--model`: the model **class name**, e.g., `AttentionModel`, `POMO`, `SymNCO`, etc. It will be dynamically imported and instantiated. Default is `AttentionModel`.
- `--ckpt-path`: the path to the pre-trained model checkpoint. Default is `checkpoints/am-tsp50.ckpt`.
- `--device`: the device to run the evaluation, e.g., `cuda:0`, `cpu`, etc. Default is `cuda:0`.
- `--method`: the evaluation method, e.g., `greedy`, `sampling`, `multistart_greedy`, `augment_dihedral_8`, `augment`, `multistart_greedy_augment_dihedral_8`, and `multistart_greedy_augment`. Default is `greedy`.
- `--save-results`: whether to save the evaluation results as a `.pkl` file. Default is `True`. The results include `actions`, `rewards`, `inference_time`, and `avg_reward`.
- `--save-path`: the path to save the evaluation results. Default is `results/`.
- `--num-instances`: the number of test instances to evaluate. Default is `1000`.

If you use the `sampling` method, you may need to specify the following parameters:
- `--samples`: the number of samples for the sampling method. Default is `1280`.
- `--temperature`: the temperature for the sampling method. Default is `1.0`.
- `--top-p`: the top-p for the sampling method. Default is `0.0`, i.e. not activated.
- `--top-k`: the top-k for the sampling method. Default is `0`, i.e., not activated.
- `--select-best`: whether to report only the best solution per instance among the samples. If `False`, the results include the rewards of all samples, i.e., `[num_instances * num_samples]` values.

If you use the `augment` method, you may need to specify the following parameters:
- `--num-augment`: the number of augmentations per instance for the augment method. Default is `8`.
- `--force-dihedral-8`: whether to use the 8 fixed dihedral-group augmentations (rotations and reflections). Default is `True`.
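
If you prefer to run the evaluation from Python instead of the CLI, the following is a minimal sketch of the equivalent programmatic call, mirroring what `eval.py` does internally (paths and settings are the example values from Steps 1 and 2):

```python
# Sketch of the programmatic equivalent of the eval.py CLI call above
import importlib

from rl4co.envs import get_env
from rl4co.tasks.eval import evaluate_policy

env = get_env("tsp", generator_params={"num_loc": 50})
dataset = env.dataset(filename="data/tsp/tsp50_test_seed1234.npz")

# Dynamically load the model class by name, as eval.py does
model_cls = getattr(importlib.import_module("rl4co.models.zoo"), "AttentionModel")
model = model_cls.load_from_checkpoint("checkpoints/am-tsp50.ckpt", load_baseline=False)
model = model.to("cuda:0")

results = evaluate_policy(
    env=env,
    policy=model.policy,
    dataset=dataset,
    method="sampling",
    samples=1280,
    temperature=1.0,
    top_p=0.95,
    top_k=0,
    select_best=True,
)
print(results["avg_reward"])
```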

**Step 3**. To launch several evaluations with different parameters, you may refer to the following examples:

- Evaluate POMO on TSP50 with sampling, sweeping over different top-p values and temperatures:

```bash
#!/bin/bash

top_p_list=(0.5 0.6 0.7 0.8 0.9 0.95 0.98 0.99 0.995 1.0)
temp_list=(0.1 0.3 0.5 0.7 0.8 0.9 1.0 1.1 1.2 1.5 1.8 2.0 2.2 2.5 2.8 3.0)

device=cuda:0

problem=tsp
model=POMO
ckpt_path=checkpoints/pomo-tsp50.ckpt
data_path=data/tsp/tsp50_test_seed1234.npz

num_instances=1000
save_path=results/tsp50-pomo-topp-1k

for top_p in "${top_p_list[@]}"; do
    for temp in "${temp_list[@]}"; do
        python rl4co/tasks/eval.py --problem ${problem} --model ${model} --ckpt-path ${ckpt_path} --data-path ${data_path} --save-path ${save_path} --num-instances ${num_instances} --method sampling --temperature=${temp} --top-p=${top_p} --top-k=0 --device ${device}
    done
done
```

- Evaluate POMO on CVRP50 with sampling, sweeping over different top-k values and temperatures:

```bash
#!/bin/bash
top_k_list=(5 10 15 20 25)
temp_list=(0.1 0.3 0.5 0.7 0.8 0.9 1.0 1.1 1.2 1.5 1.8 2.0 2.2 2.5 2.8 3.0)
device=cuda:1
problem=cvrp
model=POMO
ckpt_path=checkpoints/pomo-cvrp50.ckpt
data_path=data/vrp/vrp50_test_seed1234.npz
num_instances=1000
save_path=results/cvrp50-pomo-topk-1k
for top_k in "${top_k_list[@]}"; do
    for temp in "${temp_list[@]}"; do
        python rl4co/tasks/eval.py --problem ${problem} --model ${model} --ckpt-path ${ckpt_path} --data-path ${data_path} --save-path ${save_path} --num-instances ${num_instances} --method sampling --temperature=${temp} --top-p=0.0 --top-k=${top_k} --device ${device}
    done
done
```
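
Each run in a sweep is stored as a `.pkl` file under the chosen `--save-path`, named after the problem, model, method, and sampling parameters. The sketch below (assuming the `results/tsp50-pomo-topp-1k` directory from the first script) collects the average reward of every run:

```python
# Sketch: aggregate the average reward across all saved sweep results
import os
import pickle

results_dir = "results/tsp50-pomo-topp-1k"
for fname in sorted(os.listdir(results_dir)):
    if not fname.endswith(".pkl"):
        continue
    with open(os.path.join(results_dir, fname), "rb") as f:
        res = pickle.load(f)
    # Each result dict contains actions, rewards, inference_time, and avg_reward
    print(fname, float(res["avg_reward"]))
```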
142 changes: 126 additions & 16 deletions rl4co/tasks/eval.py
@@ -161,22 +161,29 @@ class SamplingEval(EvalBase):

name = "sampling"

def __init__(self, env, samples, softmax_temp=None, **kwargs):
def __init__(self, env, samples, softmax_temp=None, select_best=True, temperature=1.0, top_p=0.0, top_k=0, **kwargs):
check_unused_kwargs(self, kwargs)
super().__init__(env, kwargs.get("progress", True))

self.samples = samples
self.softmax_temp = softmax_temp
self.temperature = temperature
self.select_best = select_best
self.top_p = top_p
self.top_k = top_k

def _inner(self, policy, td):
out = policy(
td.clone(),
decode_type="sampling",
num_starts=self.samples,
multistart=True,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
multisample=True,
return_actions=True,
softmax_temp=self.softmax_temp,
select_best=True,
select_best=self.select_best,
select_start_nodes_fn=lambda td, _, n: sample_n_random_actions(td, n),
)

@@ -331,8 +338,10 @@ def evaluate_policy(
max_batch_size=4096,
start_batch_size=8192,
auto_batch_size=True,
save_results=False,
save_fname="results.npz",
samples=1280,
softmax_temp=1.0,
num_augment=8,
force_dihedral_8=True,
**kwargs,
):
num_loc = getattr(env.generator, "num_loc", None)
@@ -341,28 +350,28 @@
"greedy": {"func": GreedyEval, "kwargs": {}},
"sampling": {
"func": SamplingEval,
"kwargs": {"samples": 100, "softmax_temp": 1.0},
"kwargs": {"samples": samples, "softmax_temp": softmax_temp},
},
"multistart_greedy": {
"func": GreedyMultiStartEval,
"kwargs": {"num_starts": num_loc},
},
"augment_dihedral_8": {
"func": AugmentationEval,
"kwargs": {"num_augment": 8, "force_dihedral_8": True},
"kwargs": {"num_augment": num_augment, "force_dihedral_8": force_dihedral_8},
},
"augment": {"func": AugmentationEval, "kwargs": {"num_augment": 8}},
"augment": {"func": AugmentationEval, "kwargs": {"num_augment": num_augment}},
"multistart_greedy_augment_dihedral_8": {
"func": GreedyMultiStartAugmentEval,
"kwargs": {
"num_augment": 8,
"force_dihedral_8": True,
"num_augment": num_augment,
"force_dihedral_8": force_dihedral_8,
"num_starts": num_loc,
},
},
"multistart_greedy_augment": {
"func": GreedyMultiStartAugmentEval,
"kwargs": {"num_augment": 8, "num_starts": num_loc},
"kwargs": {"num_augment": num_augment, "num_starts": num_loc},
},
}

@@ -397,9 +406,110 @@ def evaluate_policy(
# Run evaluation
retvals = eval_fn(policy, dataloader)

# Save results
if save_results:
print("Saving results to {}".format(save_fname))
np.savez(save_fname, **retvals)

return retvals


if __name__ == "__main__":
import os
import pickle
import argparse
import importlib
import torch
from rl4co.envs import get_env

parser = argparse.ArgumentParser()

# Environment
parser.add_argument("--problem", type=str, default="tsp", help="Problem to solve")
parser.add_argument("--generator-params", type=dict, default={"num_loc": 50}, help="Generator parameters for the environment")
parser.add_argument("--data-path", type=str, default="data/tsp/tsp50_test_seed1234.npz", help="Path of the test data npz file")

# Model
parser.add_argument("--model", type=str, default="AttentionModel", help="The class name of the valid model")
parser.add_argument("--ckpt-path", type=str, default="checkpoints/am-tsp50.ckpt", help="The path of the checkpoint file")
parser.add_argument("--device", type=str, default="cuda:1", help="Device to run the evaluation")

# Evaluation
parser.add_argument("--method", type=str, default="greedy", help="Evaluation method, support 'greedy', 'sampling',\
'multistart_greedy', 'augment_dihedral_8', 'augment', 'multistart_greedy_augment_dihedral_8',\
'multistart_greedy_augment'")
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
parser.add_argument("--top-p", type=float, default=0.0, help="Top-p for sampling, from 0.0 to 1.0, 0.0 means not activated")
parser.add_argument("--top-k", type=int, default=0, help="Top-k for sampling")
parser.add_argument("--select-best", default=True, action=argparse.BooleanOptionalAction, help="During sampling, whether to select the best action, use --no-select_best to disable")
parser.add_argument("--save-results", default=True, action=argparse.BooleanOptionalAction, help="Whether to save the evaluation results")
parser.add_argument("--save-path", type=str, default="results", help="The root path to save the results")
parser.add_argument("--num-instances", type=int, default=1000, help="Number of instances to test, maximum 10000")

parser.add_argument("--samples", type=int, default=1280, help="Number of samples for sampling method")
parser.add_argument("--softmax-temp", type=float, default=1.0, help="Temperature for softmax in the sampling method")
parser.add_argument("--num-augment", type=int, default=8, help="Number of augmentations for augmentation method")
parser.add_argument("--force-dihedral-8", default=True, action=argparse.BooleanOptionalAction, help="Force the use of 8 augmentations for augmentation method")

opts = parser.parse_args()

# Log the evaluation setting information
print(f"Problem: {opts.problem}-{opts.generator_params['num_loc']}")
print(f"Model: {opts.model}")
print(f"Loading test instances from: {opts.data_path}")
print(f"Loading model checkpoint from: {opts.ckpt_path}")
print(f"Using the device: {opts.device}")
print(f"Evaluation method: {opts.method}")
print(f"Number of instances to test: {opts.num_instances}")

if opts.method == "sampling":
print(f"[Sampling] Number of samples: {opts.samples}")
print(f"[Sampling] Temperature: {opts.temperature}")
print(f"[Sampling] Top-p: {opts.top_p}")
print(f"[Sampling] Top-k: {opts.top_k}")
print(f"[Sampling] Softmax temperature: {opts.softmax_temp}")
print(f"[Sampling] Select best: {opts.select_best}")

if opts.method == "augment" or opts.method == "augment_dihedral_8":
print(f"[Augmentation] Number of augmentations: {opts.num_augment}")
print(f"[Augmentation] Force dihedral 8: {opts.force_dihedral_8}")

if opts.save_results:
print(f"Saving the results to: {opts.save_path}")
else:
print("[Warning] The result will not be saved!")

# Init the environment
env = get_env(opts.problem, generator_params=opts.generator_params)

# Load the test data
dataset = env.dataset(filename=opts.data_path)

# Restrict the instances of testing
dataset.data_len = min(opts.num_instances, len(dataset))

# Load the model from checkpoint
model_root = importlib.import_module("rl4co.models.zoo")
model_cls = getattr(model_root, opts.model)
model = model_cls.load_from_checkpoint(opts.ckpt_path, load_baseline=False)
model = model.to(opts.device)

# Evaluate
result = evaluate_policy(
env=env,
policy=model.policy,
dataset=dataset,
method=opts.method,
temperature=opts.temperature,
top_p=opts.top_p,
top_k=opts.top_k,
samples=opts.samples,
softmax_temp=opts.softmax_temp,
num_augment=opts.num_augment,
select_best=opts.select_best,
force_dihedral_8=opts.force_dihedral_8,
)

# Save the results
if opts.save_results:
if not os.path.exists(opts.save_path):
os.makedirs(opts.save_path)
save_fname = f"{env.name}{env.generator.num_loc}-{opts.model}-{opts.method}-temp-{opts.temperature}-top_p-{opts.top_p}-top_k-{opts.top_k}.pkl"
save_path = os.path.join(opts.save_path, save_fname)
with open(save_path, "wb") as f:
pickle.dump(result, f)
28 changes: 19 additions & 9 deletions rl4co/utils/decoding.py
@@ -202,6 +202,7 @@ class DecodingStrategy(metaclass=abc.ABCMeta):
mask_logits: Whether to mask logits of infeasible actions. Defaults to True.
tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940). Defaults to 0.
multistart: Whether to use multistart decoding. Defaults to False.
multisample: Whether to use sampling decoding. Defaults to False.
num_starts: Number of starts for multistart decoding. Defaults to None.
"""

@@ -215,6 +216,7 @@ def __init__(
mask_logits: bool = True,
tanh_clipping: float = 0,
multistart: bool = False,
multisample: bool = False,
num_starts: Optional[int] = None,
select_start_nodes_fn: Optional[callable] = None,
improvement_method_mode: bool = False,
@@ -228,6 +230,7 @@
self.mask_logits = mask_logits
self.tanh_clipping = tanh_clipping
self.multistart = multistart
self.multisample = multisample
self.num_starts = num_starts
self.select_start_nodes_fn = select_start_nodes_fn
self.improvement_method_mode = improvement_method_mode
@@ -262,9 +265,13 @@ def pre_decoder_hook(
"""Pre decoding hook. This method is called before the main decoding operation."""

# Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
if self.multistart:
if self.multistart or self.multisample:
if self.num_starts is None:
self.num_starts = env.get_num_starts(td)
if self.multisample:
log.warn(
f"num_starts is not provided for sampling, using num_starts={self.num_starts}"
)
else:
if self.num_starts is not None:
if self.num_starts >= 1:
Expand All @@ -276,16 +283,16 @@ def pre_decoder_hook(

# Multi-start decoding: first action is chosen by ad-hoc node selection
if self.num_starts >= 1:
if action is None: # if action is provided, we use it as the first action
if self.select_start_nodes_fn is not None:
action = self.select_start_nodes_fn(td, env, self.num_starts)
else:
action = env.select_start_nodes(td, num_starts=self.num_starts)
if self.multistart:
if action is None: # if action is provided, we use it as the first action
if self.select_start_nodes_fn is not None:
action = self.select_start_nodes_fn(td, env, self.num_starts)
else:
action = env.select_start_nodes(td, num_starts=self.num_starts)

# Expand td to batch_size * num_starts
td = batchify(td, self.num_starts)
# Expand td to batch_size * num_starts
td = batchify(td, self.num_starts)

if action is not None:
td.set("action", action)
td = env.step(td)["next"]
# first logprobs is 0, so p = logprobs.exp() = 1
Expand All @@ -296,6 +303,9 @@ def pre_decoder_hook(

self.logprobs.append(logprobs)
self.actions.append(action)
else:
# Expand td to batch_size * num_samples
td = batchify(td, self.num_starts)

return td, env, self.num_starts
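
For reference, the new `multisample` path is driven by `SamplingEval._inner` in `eval.py` earlier in this commit; a hedged usage sketch (illustration only, assuming a `policy` and a TensorDict `td` are in scope) looks like:

```python
# Illustration only (not part of this commit's diff): draw several samples per
# instance; with multisample=True the td is simply expanded to
# batch_size * num_starts and no ad-hoc start nodes are selected.
out = policy(
    td.clone(),
    decode_type="sampling",
    num_starts=128,      # number of samples per instance
    multisample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=0,
    return_actions=True,
    select_best=True,
)
```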

