Skip to content

Commit

Permalink
allow for top-level manipulation of stage 2 initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jessemilzman committed Feb 25, 2024
1 parent 45cac9b commit e5e7c5b
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions experiments/tower_defense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Temp. script to calculate and plot heatmap of Stage 1 cost function
function run_visualization()
dr = 0.01
ps = [1/3, 1 / 3, 1 / 3]
βs = [[2, 1, 1], [1, 2, 1], [1, 1, 2]]
βs = [[4.,2.,2.], [2., 3., 2.], [2., 2., 3.]]
Ks = calculate_stage_1_costs(ps, βs; dr)
fig = display_surface(ps, Ks)
fig
Expand All @@ -105,25 +105,30 @@ function run_stage_1_breakout(;display_controls = 0, dr = 0.05)
# dr = 0.05
ps = [1/3, 1/3, 1/3]
βs = [
[2.1, 2.0, 2.0],
[2.0, 2.1, 2.0],
[2.0, 2.0, 2.1]
[4.0, 2.0, 2.0],
[2.0, 3., 2.0],
[2.0, 2.0, 3.0]
]
# initial_guess = vcat(repeat([0.9,0.05,0.05],4),repeat([0.1,0.5,0.4],6),(1/3)*ones(42))
# primal_guess = vcat(repeat([0.5,0.25,0.25],4), repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2))
# primal_guess = vcat([0.34,0.33,0.33,0.5,0.25,0.25,0.25,0.5,0.25,0.25,0.25,0.5], repeat([0.9,0.05,0.05,0.05,0.9,0.05,0.05,0.05,0.9],2))
primal_guess = (1/3)*ones(30)
initial_guess = vcat(primal_guess,(1/3)*ones(42))

if (display_controls in [1,2])
world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls)
world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls)
world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls)
world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls)
world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls)
world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls)
world_1_misid_costs, world_1_misid_controls = calculate_misid_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess)
world_2_misid_costs, world_2_misid_controls = calculate_misid_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess)
world_3_misid_costs, world_3_misid_controls = calculate_misid_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess)
world_1_id_costs, world_1_id_controls = calculate_id_costs(ps, βs, 1; dr, return_controls=display_controls, initial_guess=initial_guess)
world_2_id_costs, world_2_id_controls = calculate_id_costs(ps, βs, 2; dr, return_controls=display_controls, initial_guess=initial_guess)
world_3_id_costs, world_3_id_controls = calculate_id_costs(ps, βs, 3; dr, return_controls=display_controls, initial_guess=initial_guess)
else
world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr)
world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr)
world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr)
world_1_id_costs = calculate_id_costs(ps, βs, 1; dr)
world_2_id_costs = calculate_id_costs(ps, βs, 2; dr)
world_3_id_costs = calculate_id_costs(ps, βs, 3; dr)
world_1_misid_costs = calculate_misid_costs(ps, βs, 1; dr, initial_guess=initial_guess)
world_2_misid_costs = calculate_misid_costs(ps, βs, 2; dr, initial_guess=initial_guess)
world_3_misid_costs = calculate_misid_costs(ps, βs, 3; dr, initial_guess=initial_guess)
world_1_id_costs = calculate_id_costs(ps, βs, 1; dr, initial_guess=initial_guess)
world_2_id_costs = calculate_id_costs(ps, βs, 2; dr, initial_guess=initial_guess)
world_3_id_costs = calculate_id_costs(ps, βs, 3; dr, initial_guess=initial_guess)
end
# Normalize using maximum value across all worlds
max_value =
Expand Down Expand Up @@ -226,7 +231,7 @@ function calculate_residuals(ps, βs, world_idx; dr = 0.05)
return residuals
end

function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0)
function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0, initial_guess=nothing)
@assert sum(ps) 1.0 "Prior distribution ps must be a probability distribution"
game, _ = build_stage_2(ps, βs)
rs = 0:dr:1
Expand All @@ -248,7 +253,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0)
end
r3 = 1 - r1 - r2
r = [r1, r2, r3]
x = compute_stage_2(r, ps, βs, game)
x = compute_stage_2(r, ps, βs, game, initial_guess=initial_guess)
id_cost =
r[world_idx] *
ps[world_idx] *
Expand All @@ -272,7 +277,7 @@ function calculate_id_costs(ps, βs, world_idx; dr = 0.05, return_controls=0)
end
end

function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0)
function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls = 0, initial_guess=nothing)
@assert sum(ps) 1.0 "Prior distribution ps must be a probability distribution"
game, _ = build_stage_2(ps, βs)
rs = 0:dr:1
Expand All @@ -294,7 +299,7 @@ function calculate_misid_costs(ps, βs, world_idx; dr = 0.05, return_controls =
end
r3 = 1 - r1 - r2
r = [r1, r2, r3]
x = compute_stage_2(r, ps, βs, game)
x = compute_stage_2(r, ps, βs, game, initial_guess=initial_guess)
defender_signal_0 = x[Block(1)]
attacker_signal_0_world_idx = x[Block(world_idx + num_worlds + 1)]
misid_cost = J_1(defender_signal_0, attacker_signal_0_world_idx, βs[world_idx])
Expand Down Expand Up @@ -669,7 +674,7 @@ function J_2(u, v, β)

end

function activate(δ; k=10.0)
function activate(δ; k=1.0)
return 1/(1 + exp(-2 * δ * k))
end

Expand Down Expand Up @@ -817,7 +822,16 @@ function compute_stage_2(r, ps, βs, game; initial_guess = nothing, verbose = fa
solution = solve(
game,
r;
initial_guess = isnothing(initial_guess) ? 1/3 * ones(total_dim(game)) : initial_guess,
initial_guess = isnothing(initial_guess) ? 1/3 * ones(total_dim(game)) : initial_guess, ### gives smooth cost surfaces
# initial_guess = isnothing(initial_guess) ? repeat([1.0,0.0,0.0],24) : initial_guess,
# initial_guess = isnothing(initial_guess) ? vcat(repeat([1.0,0.0,0.0],10),zeros(14*3)) : initial_guess,
# initial_guess = isnothing(initial_guess) ? vcat(repeat([1.0,0.0,0.0],10),(1/3) * ones(14*3)) : initial_guess,
# initial_guess = isnothing(initial_guess) ? vcat((1/3)*ones(30),(1/3)*ones(10),zeros(32)) : initial_guess,
# initial_guess = isnothing(initial_guess) ? vcat((1/3)*ones(30),(0.0)*ones(10),(1/3)*ones(30),zeros(2)) : initial_guess, ### gives smooth cost surfaces
# initial_guess = isnothing(initial_guess) ? vcat((1/3)*ones(30),(0.0)*ones(10),repeat([0.0, 0.5,0.5],10),zeros(2)) : initial_guess, ## also smooth
# initial_guess = isnothing(initial_guess) ? vcat(repeat([1.0, 0.0,0.0],10),(0.0)*ones(10),repeat([0.0, 0.5,0.5],10),zeros(2)) : initial_guess,
# initial_guess = isnothing(initial_guess) ? vcat(repeat([0.9,0.05,0.05],4),repeat([0.1,0.5,0.4],6),(1/3)*ones(14*3)) : initial_guess,
# initial_guess = initial_guess,
verbose = verbose,
return_primals = false,
)
Expand Down

0 comments on commit e5e7c5b

Please sign in to comment.