From e5e7c5bef1be8f85d71e0f1636cca971323d2704 Mon Sep 17 00:00:00 2001 From: Jesse Milzman Date: Sun, 25 Feb 2024 16:21:16 -0500 Subject: [PATCH] allow for top-level manipulation of stage 2 initialization --- experiments/tower_defense.jl | 58 ++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/experiments/tower_defense.jl b/experiments/tower_defense.jl index 0adf0fc..c5cdabc 100644 --- a/experiments/tower_defense.jl +++ b/experiments/tower_defense.jl @@ -92,7 +92,7 @@ Temp. script to calculate and plot heatmap of Stage 1 cost function function run_visualization() dr = 0.01 ps = [1/3, 1 / 3, 1 / 3] - βs = [[2, 1, 1], [1, 2, 1], [1, 1, 2]] + βs = [[4.,2.,2.], [2., 3., 2.], [2., 2., 3.]] Ks = calculate_stage_1_costs(ps, βs; dr) fig = display_surface(ps, Ks) fig @@ -105,25 +105,30 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05) # dr = 0.05 ps = [1/3, 1/3, 1/3] βs = [ - [2.1, 2.0, 2.0], - [2.0, 2.1, 2.0], - [2.0, 2.0, 2.1] + [4.0, 2.0, 2.0], + [2.0, 3., 2.0], + [2.0, 2.0, 3.0] ] + # initial_guess = vcat(repeat([0.9,0.05,0.05],4),repeat([0.1,0.5,0.4],6),(1/3)*ones(42)) + # primal_guess = vcat(repeat([0.5,0.25,0.25],4), repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2)) + # primal_guess = vcat([0.34,0.33,0.33,0.5,0.25,0.25,0.25,0.5,0.25,0.25,0.25,0.5], repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2)) + primal_guess = (1/3)*ones(30) + initial_guess = vcat(primal_guess,(1/3)*ones(42)) if (display_controls in [1,2]) - world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls) - world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls) - world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls) - world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls) - world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls) - world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls) + world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess) + world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess) + world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess) + world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess) + world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess) + world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess) else - world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr) - world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr) - world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr) - world_1_id_costs = calculate_id_costs(ps, βs, 1; dr) - world_2_id_costs = calculate_id_costs(ps, βs, 2; dr) - world_3_id_costs = calculate_id_costs(ps, βs, 3; dr) + world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr, initial_guess=initial_guess) + world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr, initial_guess=initial_guess) + world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr, initial_guess=initial_guess) + world_1_id_costs = calculate_id_costs(ps, βs, 1; dr, initial_guess=initial_guess) + world_2_id_costs = calculate_id_costs(ps, βs, 2; dr, initial_guess=initial_guess) + world_3_id_costs = calculate_id_costs(ps, βs, 3; dr, initial_guess=initial_guess) end # Normalize using maximum value across all worlds max_value = @@ -226,7 +231,7 @@ function calculate_residuals(ps, βs, world_idx; dr = 0.05) return residuals end -function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0) +function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, initial_guess=nothing) @assert sum(ps) ≈ 1.0 "Prior distribution ps must be a probability distribution" game, _ = build_stage_2(ps, βs) rs = 0:dr:1 @@ -248,7 +253,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0) end r3 = 1 - r1 - r2 r = [r1, r2, r3] - x = compute_stage_2(r, ps, βs, game) + x = compute_stage_2(r, ps, βs, game, initial_guess=initial_guess) id_cost = r[world_idx] * ps[world_idx] * @@ -272,7 +277,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0) end end -function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0) +function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0, initial_guess=nothing) @assert sum(ps) ≈ 1.0 "Prior distribution ps must be a probability distribution" game, _ = build_stage_2(ps, βs) rs = 0:dr:1 @@ -294,7 +299,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = end r3 = 1 - r1 - r2 r = [r1, r2, r3] - x = compute_stage_2(r, ps, βs, game) + x = compute_stage_2(r, ps, βs, game, initial_guess=initial_guess) defender_signal_0 = x[Block(1)] attacker_signal_0_world_idx = x[Block(world_idx + num_worlds + 1)] misid_cost = J_1(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx]) @@ -669,7 +674,7 @@ function J_2(u, v, β) end -function activate(δ; k=10.0) +function activate(δ; k=1.0) return 1/(1 + exp(-2 * δ * k)) end @@ -817,7 +822,16 @@ function compute_stage_2(r, ps, βs, game; initial_guess = nothing, verbose = fa solution = solve( game, r; - initial_guess = isnothing(initial_guess) ? 1/3 * ones(total_dim(game)) : initial_guess, + initial_guess = isnothing(initial_guess) ? 1/3 * ones(total_dim(game)) : initial_guess, ### gives smooth cost surfaces + # initial_guess = isnothing(initial_guess) ? repeat([1.0,0.0,0.0],24) : initial_guess, + # initial_guess = isnothing(initial_guess) ? vcat(repeat([1.0,0.0,0.0],10),zeros(14*3)) : initial_guess, + # initial_guess = isnothing(initial_guess) ? vcat(repeat([1.0,0.0,0.0],10),(1/3) * ones(14*3)) : initial_guess, + # initial_guess = isnothing(initial_guess) ? vcat((1/3)*ones(30),(1/3)*ones(10),zeros(32)) : initial_guess, + # initial_guess = isnothing(initial_guess) ? vcat((1/3)*ones(30),(0.0)*ones(10),(1/3)*ones(30),zeros(2)) : initial_guess, ### gives smooth cost surfaces + # initial_guess = isnothing(initial_guess) ? vcat((1/3)*ones(30),(0.0)*ones(10),repeat([0.0, 0.5,0.5],10),zeros(2)) : initial_guess, ## also smooth + # initial_guess = isnothing(initial_guess) ? vcat(repeat([1.0, 0.0,0.0],10),(0.0)*ones(10),repeat([0.0, 0.5,0.5],10),zeros(2)) : initial_guess, + # initial_guess = isnothing(initial_guess) ? vcat(repeat([0.9,0.05,0.05],4),repeat([0.1,0.5,0.4],6),(1/3)*ones(14*3)) : initial_guess, + # initial_guess = initial_guess, verbose = verbose, return_primals = false, )