diff --git a/experiments/tower_defense.jl b/experiments/tower_defense.jl index 6862e97..02e7fac 100644 --- a/experiments/tower_defense.jl +++ b/experiments/tower_defense.jl @@ -101,7 +101,7 @@ end """ Temp. script to calculate and plot surfaces for the terms in Stage 1's cost function """ -function run_stage_1_breakout(;display_controls = 0, dr = 0.05) +function run_stage_1_breakout(;display_controls = 0, dr = 0.05, cost_player = 1) # dr = 0.05 ps = [1/3, 1/3, 1/3] βs = [ @@ -109,30 +109,33 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05) [2.0, 3., 2.0], [2.0, 2.0, 3.0] ] - # initial_guess = vcat(repeat([0.9,0.05,0.05],4),repeat([0.1,0.5,0.4],6),(1/3)*ones(42)) - # primal_guess = vcat(repeat([0.5,0.25,0.25],4), repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2)) - # primal_guess = vcat([0.34,0.33,0.33,0.5,0.25,0.25,0.25,0.5,0.25,0.25,0.25,0.5], repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2)) - primal_guess = (1/3)*ones(30) - initial_guess = vcat(primal_guess,(1/3)*ones(42)) + #### Choose the initial guess for Stage 2 initialization + primal_guess = (1/3)*ones(30) ## Initialization frorm primes + initial_guess = vcat(primal_guess,(1/3)*ones(42)) ## concatenate, assume duals are 1/3 + + if (display_controls in [1,2]) - world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess) - world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess) - world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess) - world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess) - world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess) - world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess) + world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player) + world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player) + world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player) + world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player) + world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player) + world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess, cost_player=cost_player) else - world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr, initial_guess=initial_guess) - world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr, initial_guess=initial_guess) - world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr, initial_guess=initial_guess) - world_1_id_costs = calculate_id_costs(ps, βs, 1; dr, initial_guess=initial_guess) - world_2_id_costs = calculate_id_costs(ps, βs, 2; dr, initial_guess=initial_guess) - world_3_id_costs = calculate_id_costs(ps, βs, 3; dr, initial_guess=initial_guess) + world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr, initial_guess=initial_guess, cost_player=cost_player) + world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr, initial_guess=initial_guess, cost_player=cost_player) + world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr, initial_guess=initial_guess, cost_player=cost_player) + world_1_id_costs = calculate_id_costs(ps, βs, 1; dr, initial_guess=initial_guess, cost_player=cost_player) + world_2_id_costs = calculate_id_costs(ps, βs, 2; dr, initial_guess=initial_guess, cost_player=cost_player) + world_3_id_costs = calculate_id_costs(ps, βs, 3; dr, initial_guess=initial_guess, cost_player=cost_player) end # Normalize using maximum value across all worlds + + maxormin = cost_player == 2 ? minimum : maximum + max_value = - maximum( + maxormin( filter( !isnan, vcat( @@ -145,6 +148,7 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05) ), ), ) + max_value = (-1)^(cost_player+1)*max_value world_1_misid_costs = [isnan(c) ? NaN : c / max_value for c in world_1_misid_costs] world_2_misid_costs = [isnan(c) ? NaN : c / max_value for c in world_2_misid_costs] world_3_misid_costs = [isnan(c) ? NaN : c / max_value for c in world_3_misid_costs] @@ -170,7 +174,9 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05) world_2_misid_controls, world_3_misid_controls, ], - ps, save_file="P"*string(display_controls)*"_" + ps, + save_file="P"*string(display_controls)*"_", + cost_player=cost_player ) else display_stage_1_costs( @@ -231,7 +237,7 @@ function calculate_residuals(ps, βs, world_idx; dr = 0.05) return residuals end -function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, initial_guess=nothing) +function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, initial_guess=nothing, cost_player = 1) @assert sum(ps) ≈ 1.0 "Prior distribution ps must be a probability distribution" game, _ = build_stage_2(ps, βs) rs = 0:dr:1 @@ -245,6 +251,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, in return_controls = 0 end end + J = cost_player == 2 ? J_2 : J_1 for (i, r1) in enumerate(rs) for (j, r2) in enumerate(rs) @@ -257,7 +264,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, in id_cost = r[world_idx] * ps[world_idx] * - J_1( + J( x[Block(world_idx + 1)], x[Block(world_idx + 2 * num_worlds + 1)], βs[world_idx], @@ -277,7 +284,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, in end end -function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0, initial_guess=nothing) +function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0, initial_guess=nothing, cost_player = 1) @assert sum(ps) ≈ 1.0 "Prior distribution ps must be a probability distribution" game, _ = build_stage_2(ps, βs) rs = 0:dr:1 @@ -291,6 +298,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = return_controls = 0 end end + J = cost_player == 2 ? J_2 : J_1 for (i, r1) in enumerate(rs) for (j, r2) in enumerate(rs) @@ -302,7 +310,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = x = compute_stage_2(r, ps, βs, game, initial_guess=initial_guess) defender_signal_0 = x[Block(1)] attacker_signal_0_world_idx = x[Block(world_idx + num_worlds + 1)] - misid_cost = J_1(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx]) + misid_cost = J(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx]) misid_cost = (1 - r[world_idx]) * ps[world_idx] * misid_cost # weight by p(w_k|s¹=0) misid_costs[i, j] = misid_cost if (return_controls > 0) @@ -454,11 +462,11 @@ Input: Output: fig: Figure with simplex heatmap """ -function display_stage_1_costs_controls(costs, controls, ps; save_file = "") +function display_stage_1_costs_controls(costs, controls, ps; save_file = "", cost_player=1) rs = 0:(1 / (size(costs[1])[1] - 1)):1 num_worlds = length(ps) fig = Figure(size = (1500, 1000), title = "test") - max_value = 1.0 + ylims = cost_player == 2 ? (-1.0,0.0) : (0.01, 1.0) ## either graph from y=0,1 (for normalized cost for P1), or else y=-1,0 (for P2) axs = [ [ Axis3( @@ -474,7 +482,7 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "") ylabel = "r₂", zlabel = "Cost", title = "W$world_idx, S$world_idx", - limits = (nothing, nothing, (0.01, max_value)), + limits = (nothing, nothing, ylims), ) for world_idx in 1:num_worlds ], [ @@ -491,7 +499,7 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "") ylabel = "r₂", zlabel = "Cost", title = "W$world_idx, S0", - limits = (nothing, nothing, (0.01, max_value)), + limits = (nothing, nothing, ylims), ) for world_idx in 1:num_worlds ], ] @@ -505,8 +513,8 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "") rs[jj], costs[world_idx][ii,jj], color = colors[ii,jj], - colormap = :viridis, - colorrange = (0, max_value), + # colormap = :viridis, + # colorrange = (0, max_value), ) end end @@ -526,8 +534,8 @@ function display_stage_1_costs_controls(costs, controls, ps; save_file = "") rs[jj], costs[world_idx+num_worlds][ii,jj], color = colors[ii,jj], - colormap = :viridis, - colorrange = (0, max_value), + # colormap = :viridis, + # colorrange = (0, max_value), ) end end