Skip to content

Commit

Permalink
[BugFix] Measure evaluation time on any device using time.time
Browse files: browse the repository at this point in the history
  • Loading branch information
fedebotu committed Jun 8, 2024
1 parent 0602fb8 commit 8c3c2a0
Showing 1 changed file with 111 additions and 31 deletions.
142 changes: 111 additions & 31 deletions rl4co/tasks/eval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import numpy as np
import torch

Expand Down Expand Up @@ -33,11 +35,7 @@ def __call__(self, policy, dataloader, **kwargs):
"""Evaluate the policy on the given dataloader with **kwargs parameter
self._inner is implemented in subclasses and returns actions and rewards
"""

# Collect timings for evaluation (more accurate than timeit)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start = time.time()

with torch.inference_mode():
rewards_list = []
Expand All @@ -64,15 +62,14 @@ def __call__(self, policy, dataloader, **kwargs):
0,
)

end_event.record()
torch.cuda.synchronize()
inference_time = start_event.elapsed_time(end_event)
inference_time = time.time() - start

tqdm.write(f"Mean reward for {self.name}: {rewards.mean():.4f}")
tqdm.write(f"Time: {inference_time/1000:.4f}s")
tqdm.write(f"Time: {inference_time:.4f}s")

# Empty cache
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

return {
"actions": actions.cpu(),
Expand Down Expand Up @@ -161,7 +158,17 @@ class SamplingEval(EvalBase):

name = "sampling"

def __init__(self, env, samples, softmax_temp=None, select_best=True, temperature=1.0, top_p=0.0, top_k=0, **kwargs):
def __init__(
self,
env,
samples,
softmax_temp=None,
select_best=True,
temperature=1.0,
top_p=0.0,
top_k=0,
**kwargs,
):
check_unused_kwargs(self, kwargs)
super().__init__(env, kwargs.get("progress", True))

Expand Down Expand Up @@ -410,41 +417,114 @@ def evaluate_policy(


if __name__ == "__main__":
import os
import pickle
import argparse
import importlib
import os
import pickle

import torch

from rl4co.envs import get_env

parser = argparse.ArgumentParser()

# Environment
parser.add_argument("--problem", type=str, default="tsp", help="Problem to solve")
parser.add_argument("--generator-params", type=dict, default={"num_loc": 50}, help="Generator parameters for the environment")
parser.add_argument("--data-path", type=str, default="data/tsp/tsp50_test_seed1234.npz", help="Path of the test data npz file")
parser.add_argument(
"--generator-params",
type=dict,
default={"num_loc": 50},
help="Generator parameters for the environment",
)
parser.add_argument(
"--data-path",
type=str,
default="data/tsp/tsp50_test_seed1234.npz",
help="Path of the test data npz file",
)

# Model
parser.add_argument("--model", type=str, default="AttentionModel", help="The class name of the valid model")
parser.add_argument("--ckpt-path", type=str, default="checkpoints/am-tsp50.ckpt", help="The path of the checkpoint file")
parser.add_argument("--device", type=str, default="cuda:1", help="Device to run the evaluation")
parser.add_argument(
"--model",
type=str,
default="AttentionModel",
help="The class name of the valid model",
)
parser.add_argument(
"--ckpt-path",
type=str,
default="checkpoints/am-tsp50.ckpt",
help="The path of the checkpoint file",
)
parser.add_argument(
"--device", type=str, default="cuda:1", help="Device to run the evaluation"
)

# Evaluation
parser.add_argument("--method", type=str, default="greedy", help="Evaluation method, support 'greedy', 'sampling',\
parser.add_argument(
"--method",
type=str,
default="greedy",
help="Evaluation method, support 'greedy', 'sampling',\
'multistart_greedy', 'augment_dihedral_8', 'augment', 'multistart_greedy_augment_dihedral_8',\
'multistart_greedy_augment'")
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
parser.add_argument("--top-p", type=float, default=0.0, help="Top-p for sampling, from 0.0 to 1.0, 0.0 means not activated")
'multistart_greedy_augment'",
)
parser.add_argument(
"--temperature", type=float, default=1.0, help="Temperature for sampling"
)
parser.add_argument(
"--top-p",
type=float,
default=0.0,
help="Top-p for sampling, from 0.0 to 1.0, 0.0 means not activated",
)
parser.add_argument("--top-k", type=int, default=0, help="Top-k for sampling")
parser.add_argument("--select-best", default=True, action=argparse.BooleanOptionalAction, help="During sampling, whether to select the best action, use --no-select_best to disable")
parser.add_argument("--save-results", default=True, action=argparse.BooleanOptionalAction, help="Whether to save the evaluation results")
parser.add_argument("--save-path", type=str, default="results", help="The root path to save the results")
parser.add_argument("--num-instances", type=int, default=1000, help="Number of instances to test, maximum 10000")

parser.add_argument("--samples", type=int, default=1280, help="Number of samples for sampling method")
parser.add_argument("--softmax-temp", type=float, default=1.0, help="Temperature for softmax in the sampling method")
parser.add_argument("--num-augment", type=int, default=8, help="Number of augmentations for augmentation method")
parser.add_argument("--force-dihedral-8", default=True, action=argparse.BooleanOptionalAction, help="Force the use of 8 augmentations for augmentation method")
parser.add_argument(
"--select-best",
default=True,
action=argparse.BooleanOptionalAction,
help="During sampling, whether to select the best action, use --no-select_best to disable",
)
parser.add_argument(
"--save-results",
default=True,
action=argparse.BooleanOptionalAction,
help="Whether to save the evaluation results",
)
parser.add_argument(
"--save-path",
type=str,
default="results",
help="The root path to save the results",
)
parser.add_argument(
"--num-instances",
type=int,
default=1000,
help="Number of instances to test, maximum 10000",
)

parser.add_argument(
"--samples", type=int, default=1280, help="Number of samples for sampling method"
)
parser.add_argument(
"--softmax-temp",
type=float,
default=1.0,
help="Temperature for softmax in the sampling method",
)
parser.add_argument(
"--num-augment",
type=int,
default=8,
help="Number of augmentations for augmentation method",
)
parser.add_argument(
"--force-dihedral-8",
default=True,
action=argparse.BooleanOptionalAction,
help="Force the use of 8 augmentations for augmentation method",
)

opts = parser.parse_args()

Expand Down

0 comments on commit 8c3c2a0

Please sign in to comment.