diff --git a/hr_edl/src/bin/run_cfr_corr_dist.cc b/hr_edl/src/bin/run_cfr_corr_dist.cc index c0cdf3f..13f5fe4 100644 --- a/hr_edl/src/bin/run_cfr_corr_dist.cc +++ b/hr_edl/src/bin/run_cfr_corr_dist.cc @@ -25,18 +25,13 @@ // Environment ABSL_FLAG(std::string, efg_file, "efg/extended_shapleys.efg", "The EFG file representing the game."); -ABSL_FLAG(int32_t, t, 1000, "The number of iterations to run."); - -inline constexpr int kSeed = 23894982; - -// Reporting -ABSL_FLAG( - double, report_gap_factor, 2.0, - "The factor increasing the number of iterations to finish between " - "reporting exploitability. E.g. 1 prints after each iteration, 2 prints " - "after 2^i iterations, i >= 1."); -ABSL_FLAG(int32_t, report_skip, 0, - "Number of initial iterations for which to skip reporting."); +ABSL_FLAG(int32_t, t, 10000000, "The number of iterations to run."); +ABSL_FLAG(int32_t, report_freq, 100, "Number of iterations between reports."); +ABSL_FLAG(bool, random_initial_regrets, false, + "Initialize CFR tables with random initial regret?"); +ABSL_FLAG(int32_t, seed, 23894971, "Seed for the random initial regrets."); +ABSL_FLAG(bool, alternating_updates, false, + "Use alternating regret updates in CFR?"); using open_spiel::algorithms::CorrDevBuilder; using open_spiel::algorithms::CFRSolverBase; @@ -56,22 +51,24 @@ void run_experiment() { "efg_game", {{"filename", open_spiel::GameParameter( absl::GetFlag(FLAGS_efg_file))}}); const int iterations = absl::GetFlag(FLAGS_t); + const int report_freq = absl::GetFlag(FLAGS_report_freq); + const bool alternating_updates = absl::GetFlag(FLAGS_alternating_updates); + const bool random_initial_regrets = + absl::GetFlag(FLAGS_random_initial_regrets); + const int seed = absl::GetFlag(FLAGS_seed); CorrDevBuilder cd_builder; - CFRSolverBase solver(*game, - /*alternating_updates=*/false, - /*linear_averaging=*/false, - /*regret_matching_plus=*/false, - /*random_initial_regrets*/ true, - /*seed*/kSeed); + CFRSolverBase solver(*game, alternating_updates, /*linear_averaging=*/false, + /*regret_matching_plus=*/false, random_initial_regrets, + seed); CorrDistConfig config; - for (int i = 0; i < 10000000; i++) { + for (int i = 0; i < iterations; i++) { solver.EvaluateAndUpdatePolicy(); TabularPolicy current_policy = static_cast(solver.CurrentPolicy().get()) ->AsTabular(); cd_builder.AddMixedJointPolicy(current_policy); - if (i < 100 || i % 100 == 0) { + if (i < 100 || i % report_freq == 0) { CorrelationDevice mu = cd_builder.GetCorrelationDevice(); double afcce_dist = AFCCEDist(*game, config, mu); double afce_dist = AFCEDist(*game, config, mu);