Skip to content

Commit

Permalink
perf: improve compute_probability_of_success algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Net-Mist committed Jan 13, 2024
1 parent 002215b commit abec26f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 35 deletions.
6 changes: 6 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ Contains the definitions of `PlanetId`, `GalaxyRoutes`, `PlanetCatalog` and `Bou

Contains the `compute_probability_of_success` function.

> Implementation notes:
> It is fundamentally an A\* algorithm, patched to explore first paths without bounty hunters.
> First, a Dijkstra algorithm is run to compute the shortest distance between every planet and the destination. This will be the heuristic function of the A\* algorithm.
> Then the A\* is run, first without allowing to cross the path of a bounty hunter, then allowing a single one, ...
> This logic is automatically implemented thanks to a BinaryHeap.
### Application services

Contains the definition of `MillenniumFalconData` and `EmpireData` matching the json formats specified in the requirements of the app, and `Route` matching the database data format (but without db-related types or field) and some code to bridge the data.
Expand Down
148 changes: 113 additions & 35 deletions src/domain_services.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,87 @@
use std::{cmp::Reverse, collections::BinaryHeap};
use std::{
cmp::Reverse,
collections::{BinaryHeap, HashMap, HashSet},
};

use anyhow::Result;

use crate::domain_models::{BountyHunterPlanning, GalaxyRoutes, PlanetCatalog, PlanetId};

/// When exploring all the states of the simulation to find the best route, we want to
/// privilege the ones with less bounty hunter, then the ones that minimize the elapsed_time.
/// To do this, we derive the trait Ord that will automatically sort the State, first by
/// `n_bounty_hunter`, then by `elapsed_time` (then by `fluel` and then by `planet`, but this we don't care)
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct State {
pub n_bounty_hunter: u64,
pub elapsed_time: u64,
pub fluel: u64,
pub planet: PlanetId,
n_bounty_hunter: u64,
elapsed_time: u64,
time_to_destination: u64,
fuel: u64,
planet: PlanetId,
}

impl PartialOrd for State {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for State {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
if self.n_bounty_hunter != other.n_bounty_hunter {
return self.n_bounty_hunter.cmp(&other.n_bounty_hunter);
}
let sum_time = self.elapsed_time + self.time_to_destination;
let other_sum_time = other.elapsed_time + other.time_to_destination;
sum_time.cmp(&other_sum_time)
}
}

#[derive(PartialEq, Eq)]
struct AllTimeState {
time: u64,
planet_id: PlanetId,
}

impl PartialOrd for AllTimeState {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for AllTimeState {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.time.cmp(&other.time)
}
}

/// run a Dijkstra algorithm to compute the minimal distance from a planet to the destination,
/// without considering the bounty hunters.
/// the output of this function will be the A* heuristic
fn compute_all_time_to_destination(
galaxy_routes: &GalaxyRoutes,
destination_id: &PlanetId,
) -> Result<HashMap<PlanetId, u64>> {
let mut time_to_destination = HashMap::new();
let mut planet_to_process = BinaryHeap::from([Reverse(AllTimeState {
time: 0,
planet_id: *destination_id,
})]);

while let Some(Reverse(state)) = planet_to_process.pop() {
if let std::collections::hash_map::Entry::Vacant(e) =
time_to_destination.entry(state.planet_id)
{
// first time we see this planet
e.insert(state.time);
} else {
// this planet has already been processed
continue;
}
for (neighbour_planet_id, time) in galaxy_routes.get(&state.planet_id)? {
planet_to_process.push(Reverse(AllTimeState {
time: state.time + time,
planet_id: *neighbour_planet_id,
}));
}
}
Ok(time_to_destination)
}

pub fn compute_probability_of_success(
Expand All @@ -35,17 +103,28 @@ pub fn compute_probability_of_success(
None => return Ok(0.), // arrival planet is not connected to other planets. How did the rebel get there ?
};

let mut state_to_process = BinaryHeap::new();
state_to_process.push(Reverse(State {
let all_time_to_destination = compute_all_time_to_destination(galaxy_routes, arrival_id)?;

let mut state_to_process = BinaryHeap::from([Reverse(State {
n_bounty_hunter: 0,
elapsed_time: 0,
fluel: autonomy,
fuel: autonomy,
planet: *departure_id,
}));
time_to_destination: *all_time_to_destination
.get(departure_id)
.unwrap_or(&u64::MAX),
})]);

while let Some(state) = state_to_process.pop() {
let state = state.0;
if state.elapsed_time > countdown {
let mut seen_state = HashSet::new();

while let Some(Reverse(state)) = state_to_process.pop() {
if seen_state.contains(&state) {
// this state has already been explored
continue;
}
seen_state.insert(state.clone());
if state.elapsed_time.saturating_add(state.time_to_destination) > countdown {
// then it is not possible to reach the destination from this state
continue;
}
let n_bounty_hunter = state.n_bounty_hunter
Expand All @@ -55,24 +134,28 @@ pub fn compute_probability_of_success(
return Ok(1. - probability_been_captured(n_bounty_hunter));
}

// Millennium Falcon can refluel
// Millennium Falcon can refuel
state_to_process.push(Reverse(State {
n_bounty_hunter,
elapsed_time: state.elapsed_time + 1,
fluel: autonomy,
fuel: autonomy,
planet: state.planet,
time_to_destination: state.time_to_destination,
}));

// or visit neightbours planets, if it has enough fluel
for (new_planet_id, time) in galaxy_routes.get(&state.planet)? {
if *time > state.fluel {
if *time > state.fuel {
continue;
}
state_to_process.push(Reverse(State {
n_bounty_hunter,
elapsed_time: state.elapsed_time + time,
fluel: state.fluel - time,
fuel: state.fuel - time,
planet: *new_planet_id,
time_to_destination: *all_time_to_destination
.get(new_planet_id)
.unwrap_or(&u64::MAX),
}));
}
}
Expand All @@ -95,6 +178,8 @@ fn probability_been_captured(n_bounty_hunter: u64) -> f64 {
#[cfg(test)]
mod test {

use std::collections::{HashMap, HashSet};

use crate::{
domain_models::{BountyHunterPlanning, GalaxyRoutes, PlanetCatalog},
domain_services::probability_been_captured,
Expand Down Expand Up @@ -123,20 +208,13 @@ mod test {
let dagobah_id = *planet_id_map.get("Dagobah").unwrap();
let endor_id = *planet_id_map.get("Endor").unwrap();

let hunter_planning = BountyHunterPlanning::new(
[(dagobah_id, [1].into_iter().collect())]
.into_iter()
.collect(),
);
let galaxy_routes = GalaxyRoutes::from_hashmap(
[
(tatooine_id, vec![(dagobah_id, 1)]),
(dagobah_id, vec![(tatooine_id, 1), (endor_id, 1)]),
(endor_id, vec![(dagobah_id, 1)]),
]
.into_iter()
.collect(),
)
let hunter_planning =
BountyHunterPlanning::new([(dagobah_id, HashSet::from([1]))].into_iter().collect());
let galaxy_routes = GalaxyRoutes::from_hashmap(HashMap::from([
(tatooine_id, vec![(dagobah_id, 1)]),
(dagobah_id, vec![(tatooine_id, 1), (endor_id, 1)]),
(endor_id, vec![(dagobah_id, 1)]),
]))
.unwrap();

let r = compute_probability_of_success(
Expand Down

0 comments on commit abec26f

Please sign in to comment.