Skip to content

Commit

Permalink
[Feat] Support turning off select_best for sampling evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
cbhua committed Jun 4, 2024
1 parent b1ced3c commit 077df4a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions rl4co/tasks/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,14 @@ class SamplingEval(EvalBase):

name = "sampling"

def __init__(self, env, samples, softmax_temp=None, temperature=1.0, top_p=0.0, top_k=0, **kwargs):
def __init__(self, env, samples, softmax_temp=None, select_best=True, temperature=1.0, top_p=0.0, top_k=0, **kwargs):
check_unused_kwargs(self, kwargs)
super().__init__(env, kwargs.get("progress", True))

self.samples = samples
self.softmax_temp = softmax_temp
self.temperature = temperature
self.select_best = select_best
self.top_p = top_p
self.top_k = top_k

Expand All @@ -182,7 +183,7 @@ def _inner(self, policy, td):
multisample=True,
return_actions=True,
softmax_temp=self.softmax_temp,
select_best=True,
select_best=self.select_best,
select_start_nodes_fn=lambda td, _, n: sample_n_random_actions(td, n),
)

Expand Down Expand Up @@ -435,6 +436,7 @@ def evaluate_policy(
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
parser.add_argument("--top_p", type=float, default=0.0, help="Top-p for sampling, from 0.0 to 1.0, 0.0 means not activated")
parser.add_argument("--top_k", type=int, default=0, help="Top-k for sampling")
parser.add_argument("--select_best", type=bool, default=True, help="During sampling, whether to select the best action")
parser.add_argument("--save_results", type=bool, default=True, help="Whether to save the evaluation results")
parser.add_argument("--save_path", type=str, default="results", help="The root path to save the results")

Expand Down Expand Up @@ -469,11 +471,12 @@ def evaluate_policy(
samples=opts.samples,
softmax_temp=opts.softmax_temp,
num_augment=opts.num_augment,
force_dihedral_8=opts.force_dihedral_8,
select_best=True if opts.select_best == "True" else False,
force_dihedral_8=True if opts.force_dihedral_8 == "True" else False,
)

# Save the results
if opts.save_results:
if opts.save_results == "True":
if not os.path.exists(opts.save_path):
os.makedirs(opts.save_path)
save_fname = f"{env.name}{env.generator.num_loc}-{opts.model}-{opts.method}-temp-{opts.temperature}-top_p-{opts.top_p}-top_k-{opts.top_k}.pkl"
Expand Down

0 comments on commit 077df4a

Please sign in to comment.