Skip to content

Commit

Permalink
start demoing MILP
Browse files Browse the repository at this point in the history
  • Loading branch information
mcoughlin committed May 7, 2024
1 parent 5d803fc commit 6e4d52f
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 7 deletions.
4 changes: 4 additions & 0 deletions gwemopt/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,8 @@ def parse_args(args):

parser.add_argument("--inclination", action="store_true", default=False)

parser.add_argument("--solverType", default="heuristic")
parser.add_argument("--milpSolver", default="PULP_CBC_CMD")
parser.add_argument("--milpOptions")

return parser.parse_args(args=args)
11 changes: 11 additions & 0 deletions gwemopt/params.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from pathlib import Path

Expand Down Expand Up @@ -146,4 +147,14 @@ def params_struct(opts):
opts.inclination if hasattr(opts, "true_location") else False
)

params["solverType"] = (
opts.solverType if hasattr(opts, "solverType") else "heuristic"
)
params["milpSolver"] = (
opts.milpSolver if hasattr(opts, "milpSolver") else "PULP_CBC_CMD"
)
params["milpOptions"] = (
json.loads(opts.milpOptions) if hasattr(opts, "milpOptions") else {}
)

return params
141 changes: 134 additions & 7 deletions gwemopt/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from ortools.linear_solver import pywraplp

from gwemopt.tiles import balance_tiles, optimize_max_tiles, schedule_alternating
from gwemopt.utils import angular_distance

# from munkres import Munkres, make_cost_matrix
from gwemopt.utils import angular_distance, solve_milp


def get_altaz_tile(ra, dec, observer, obstime):
Expand Down Expand Up @@ -90,7 +88,7 @@ def find_tile(
return idx2, exposureids, probs


def get_order(
def get_order_heuristic(
params, tile_struct, tilesegmentlists, exposurelist, observer, config_struct
):
"""
Expand Down Expand Up @@ -473,6 +471,127 @@ def get_order(
return idxs, filts


def get_order_milp(params, tile_struct, exposurelist, observer, config_struct):
"""
tile_struct: dictionary. key -> struct info.
exposurelist: list of segments that the telescope is supposed to be working.
consecutive segments from the start to the end, with each segment size
being the exposure time.
Returns a list of tile indices in the order of observation.
"""

if "dec_constraint" in config_struct:
dec_constraint = config_struct["dec_constraint"].split(",")
dec_min = float(dec_constraint[0])
dec_max = float(dec_constraint[1])

exposureids = []
probs = []
ras, decs, filts, keys = [], [], [], []
for ii, key in enumerate(list(tile_struct.keys())):
if tile_struct[key]["prob"] == 0:
continue
if "dec_constraint" in config_struct:
if (tile_struct[key]["dec"] < dec_min) or (
tile_struct[key]["dec"] > dec_max
):
continue

exposureids.append(key)
probs.append(tile_struct[key]["prob"])
ras.append(tile_struct[key]["ra"])
decs.append(tile_struct[key]["dec"])
filts.append(tile_struct[key]["filt"])
keys.append(key)

fields = -1 * np.ones(
(len(exposurelist)),
)
filters = ["n"] * len(exposurelist)

if len(probs) == 0:
return fields, filters

probs = np.array(probs)
ras = np.array(ras)
decs = np.array(decs)
exposureids = np.array(exposureids)
keys = np.array(keys)
tilematrix = np.zeros((len(exposurelist), len(ras)))
for ii in np.arange(len(exposurelist)):
# first, create an array of airmass-weighted probabilities
t = Time(exposurelist[ii][0], format="mjd")
altazs = [get_altaz_tile(ra, dec, observer, t) for ra, dec in zip(ras, decs)]
alts = np.array([altaz[0] for altaz in altazs])
horizon = config_struct["horizon"]
horizon_mask = alts <= horizon
airmass = 1 / np.cos((90.0 - alts) * np.pi / 180.0)
airmass_mask = airmass > params["airmass"]

airmass_weight = 10 ** (0.4 * 0.1 * (airmass - 1))

if params["scheduleType"] in ["greedy", "greedy_slew"]:
tilematrix[ii, :] = np.array(probs)
elif params["scheduleType"] == ["airmass_weighted", "airmass_weighted_slew"]:
tilematrix[ii, :] = np.array(probs / airmass_weight)
tilematrix[ii, horizon_mask] = np.nan
tilematrix[ii, airmass_mask] = np.nan

for jj, key in enumerate(keys):
tilesegmentlist = tile_struct[key]["segmentlist"]
if not tilesegmentlist.intersects_segment(exposurelist[ii]):
tilematrix[ii, jj] = np.nan

# which fields are never observable
ind = np.where(np.nansum(tilematrix, axis=0) > 0)[0]

probs = np.array(probs)[ind]
ras = np.array(ras)[ind]
decs = np.array(decs)[ind]
exposureids = np.array(exposureids)[ind]
filts = [filts[i] for i in ind]
tilematrix = tilematrix[:, ind]

# which times do not have any observability
ind = np.where(np.nansum(tilematrix, axis=1) > 0)[0]
tilematrix = tilematrix[ind, :]

cost_matrix = tilematrix
cost_matrix[np.isnan(cost_matrix)] = -np.inf

distmatrix = np.zeros((len(ras), len(ras)))
for ii, (r, d) in enumerate(zip(ras, decs)):
dist = angular_distance(r, d, ras, decs)
if "slew" in params["scheduleType"]:
dist = dist / config_struct["slew_rate"]
dist = dist - config_struct["readout"]
dist[dist < 0] = 0
else:
distmatrix[ii, :] = dist
distmatrix = distmatrix / np.max(distmatrix)

dt = int(np.ceil((exposurelist[1][0] - exposurelist[0][0]) * 86400))
optimal_points = solve_milp(
cost_matrix,
dist_matrix=distmatrix,
useDistance=False,
max_tasks_per_worker=len(params["filters"]),
useTaskSepration=False,
min_task_separation=int(np.ceil(dt / params["mindiff"])),
)

for optimal_point in optimal_points:
idx = ind[optimal_point[0]]
idy = optimal_point[1]
if len(filts[idy]) > 0:
fields[idx] = exposureids[idy]
filters[idx] = filts[idy][0]
filt = filts[idy][1:]
filts[idy] = filt

return fields, filters


def scheduler(params, config_struct, tile_struct):
"""
config_struct: the telescope configurations
Expand Down Expand Up @@ -506,9 +625,17 @@ def scheduler(params, config_struct, tile_struct):
for key in keys:
# segments.py: tile_struct[key]["segmentlist"] is a list of segments when the tile is available for observation
tilesegmentlists.append(tile_struct[key]["segmentlist"])
keys, filts = get_order(
params, tile_struct, tilesegmentlists, exposurelist, observer, config_struct
)

if params["solverType"] == "heuristic":
keys, filts = get_order_heuristic(
params, tile_struct, tilesegmentlists, exposurelist, observer, config_struct
)
elif params["solverType"] == "milp":
keys, filts = get_order_milp(
params, tile_struct, exposurelist, observer, config_struct
)
else:
raise ValueError(f'Unknown solverType {params["solverType"]}')

if params["doPlots"]:
from gwemopt.plotting import make_schedule_plots
Expand Down
1 change: 1 addition & 0 deletions gwemopt/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from gwemopt.utils.geometry import angular_distance
from gwemopt.utils.milp import solve_milp
from gwemopt.utils.misc import auto_rasplit, get_exposures, integrationTime
from gwemopt.utils.observability import calculate_observability
from gwemopt.utils.param_utils import readParamsFromFile
Expand Down
125 changes: 125 additions & 0 deletions gwemopt/utils/milp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time

import numpy as np
import pulp
from tqdm import tqdm


def solve_milp(
cost_matrix,
max_tasks_per_worker=1,
useTaskSepration=False,
min_task_separation=1,
useDistance=False,
dist_matrix=None,
timeLimit=300,
milpSolver="PULP_CBC_CMD",
milpOptions={},
):

cost_matrix_mask = cost_matrix > 10 ** (-10)
optimal_points = []

if cost_matrix_mask.any():
print("Calculating MILP solution...")

# Create a CP-SAT model
problem = pulp.LpProblem("problem", pulp.LpMaximize)

# Define variables
num_exposures, num_fields = cost_matrix.shape

print("Define decision variables...")
# Define decision variables
x = {
(i, j): pulp.LpVariable(f"x_{i}_{j}", cat=pulp.LpBinary)
for i in tqdm(range(num_exposures))
for j in range(num_fields)
}

print("Define binary variables for task separation violation...")
# Define binary variables for task separation violation
s = {
(i, j): pulp.LpVariable(f"s_{i}_{j}", cat=pulp.LpBinary)
for i in tqdm(range(num_exposures))
for j in range(num_fields)
}

obj = pulp.lpSum(
x[i, j] * cost_matrix[i][j]
for i in range(num_exposures)
for j in range(num_fields)
)

if useDistance:
products = [
(x[i, j], dist_matrix[j, k])
for i in range(num_exposures)
for j in range(num_fields)
for k in range(dist_matrix.shape[1])
]
total_distance = pulp.LpAffineExpression(products)
obj -= total_distance

# One field per exposure
for i in range(num_exposures):
problem += pulp.lpSum(x[i, j] for j in range(num_fields)) == 1

# Limit the number of tasks each worker can handle (if applicable)
for j in range(num_fields):
problem += (
pulp.lpSum(x[i, j] for i in range(num_exposures))
<= max_tasks_per_worker
)

print("Add constraints to exclude impossible assignments...")
# Add constraints to exclude impossible assignments
for i in tqdm(range(num_exposures)):
for j in range(num_fields):
if not np.isfinite(cost_matrix[i][j]):
problem += x[i, j] == 0

if useTaskSepration:
print("Define constraints: enforce minimum task separation...")
## Define constraints: enforce minimum task separation
for i in tqdm(range(num_exposures)):
for j in range(num_fields):
for k in range(j + 1, num_fields):
problem += s[i, j] >= x[i, k] - x[i, j] - min_task_separation
problem += s[i, j] >= x[i, j] - x[i, k] - min_task_separation

print("Number of variables:", len(problem.variables()))
print("Number of constraints:", len(problem.constraints))

time_limit = 60 # Stop the solver after 60 seconds

if milpSolver == "PULP_CBC_CMD":
solver = pulp.getSolver(milpSolver)
elif milpSolver == "GUROBI":
solver = pulp.getSolver(
"GUROBI", manageEnv=True, envOptions=milpOptions, mip_gap=1e-6
)
else:
raise ValueError("milpSolver must be either PULP_CBC_CMD or GUROBI")

solver.timeLimit = timeLimit
# solver.msg = True
status = problem.solve(solver)

optimal_points = []
if status in [pulp.LpStatusOptimal]:
for i in range(num_exposures):
for j in range(num_fields):
if (
pulp.value(x[i, j]) is not None
and pulp.value(x[i, j]) > 0.5
and np.isfinite(cost_matrix[i][j])
):
optimal_points.append((i, j))
else:
print("The problem does not have a solution.")

else:
print("The localization is not visible from the site.")

return optimal_points
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
'healpy',
'ortools',
'pandas',
'pulp',
'shapely',
'tables',
'regions'
Expand Down

0 comments on commit 6e4d52f

Please sign in to comment.