Skip to content

Commit

Permalink
remove state utils
Browse files Browse the repository at this point in the history
  • Loading branch information
Bonifatius94 committed Jan 8, 2024
1 parent f6da285 commit 7e2ebd7
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 118 deletions.
82 changes: 73 additions & 9 deletions pysocialforce/forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from numba import njit

import logging

from pysocialforce import stateutils
logging.getLogger('numba').setLevel(logging.WARNING)

from pysocialforce.scene import Line2D, Point2D, PedState
Expand Down Expand Up @@ -67,7 +65,7 @@ def __call__(self):
pos = self.peds.pos()
vel = self.peds.vel()
goal = self.peds.goal()
direction, dist = stateutils.normalize(goal - pos)
direction, dist = normalize(goal - pos)
force = np.zeros((self.peds.size(), 2))
force[dist > goal_threshold] = (
direction * self.peds.max_speeds.reshape((-1, 1)) - vel.reshape((-1, 2))
Expand Down Expand Up @@ -317,9 +315,9 @@ def __call__(self):
if len(member_pos) == 0:
continue

com = stateutils.centroid(member_pos)
com = centroid(member_pos)
force_vec = com - member_pos
norms = stateutils.speeds(force_vec)
norms = np.linalg.norm(force_vec, axis=1)
softened_factor = (np.tanh(norms - threshold) + 1) / 2
forces[group, :] += (force_vec.T * softened_factor).T
return forces * self.config.factor
Expand All @@ -344,8 +342,8 @@ def __call__(self):

size = len(group)
member_pos = self.peds.pos()[group, :]
diff = stateutils.each_diff(member_pos) # others - self
_, norms = stateutils.normalize(diff)
diff = each_diff(member_pos) # others - self
_, norms = normalize(diff)
diff[norms > threshold, :] = 0
forces[group, :] += np.sum(diff.reshape((size, -1, 2)), axis=1)

Expand All @@ -366,7 +364,7 @@ def __call__(self):
return forces

ped_positions = self.peds.pos()
directions, dist = stateutils.desired_directions(self.peds.state)
directions, dist = desired_directions(self.peds.state)

for group in self.peds.groups:
group_size = len(group)
Expand All @@ -387,7 +385,7 @@ def group_gaze_force(
for i in range(group_size):
# use center of mass without the current agent
other_member_pos = member_pos[np.arange(group_size) != i, :2]
mass_center_without_ped = stateutils.centroid(other_member_pos)
mass_center_without_ped = centroid(other_member_pos)
relative_com_x = mass_center_without_ped[0] - member_pos[i, 0]
relative_com_y = mass_center_without_ped[1] - member_pos[i, 1]
com_dir, com_dist = norm_vec((relative_com_x, relative_com_y))
Expand All @@ -399,3 +397,69 @@ def group_gaze_force(
out_forces[i, 0] = force_x
out_forces[i, 1] = force_y
return out_forces


@njit
def vec_len_2d(vec_x: float, vec_y: float) -> float:
return (vec_x**2 + vec_y**2)**0.5


@njit
def normalize(vecs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Normalize nx2 array along the second axis
input: [n,2] ndarray
output: (normalized vectors, norm factors)
"""
num_vecs = vecs.shape[0]
vec_lengths = np.zeros((num_vecs))
unit_vecs = np.zeros((num_vecs, 2))

for i, (vec_x, vec_y) in enumerate(vecs):
vec_len = vec_len_2d(vec_x, vec_y)
vec_lengths[i] = vec_len
if vec_len > 0:
unit_vecs[i] = vecs[i] / vec_len

return unit_vecs, vec_lengths


@njit
def desired_directions(state: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Given the current state and destination, compute desired direction."""
destination_vectors = state[:, 4:6] - state[:, 0:2]
directions, dist = normalize(destination_vectors)
return directions, dist


@njit
def vec_diff(vecs: np.ndarray) -> np.ndarray:
"""r_ab
r_ab := r_a − r_b.
"""
diff = np.expand_dims(vecs, 1) - np.expand_dims(vecs, 0)
return diff


def each_diff(vecs: np.ndarray, keepdims=False) -> np.ndarray:
"""
:param vecs: nx2 array
:return: diff with diagonal elements removed
"""
diff = vec_diff(vecs)
diff = diff[~np.eye(diff.shape[0], dtype=bool), :]
if keepdims:
diff = diff.reshape(vecs.shape[0], -1, vecs.shape[1])
return diff


@njit
def centroid(vecs: np.ndarray) -> Tuple[float, float]:
"""Center-of-mass of a given group as arithmetic mean."""
num_datapoints = vecs.shape[0]
centroid_x, centroid_y = 0, 0
for x, y in vecs:
centroid_x += x
centroid_y += y
centroid_x /= num_datapoints
centroid_y /= num_datapoints
return centroid_x, centroid_y
9 changes: 4 additions & 5 deletions pysocialforce/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np

from pysocialforce import stateutils
from pysocialforce.config import SceneConfig


Expand Down Expand Up @@ -71,15 +70,15 @@ def tau(self):

def speeds(self):
"""Return the speeds corresponding to a given state."""
return stateutils.speeds(self.state)
return np.linalg.norm(self.vel(), axis=1)

def step(self, force, groups=None):
"""Move peds according to forces"""
# desired velocity
desired_velocity = self.vel() + self.d_t * force
desired_velocity = self.capped_velocity(desired_velocity, self.max_speeds)
# stop when arrived
desired_velocity[stateutils.desired_directions(self.state)[1] < 0.5] = [0, 0]
# desired_velocity[stateutils.desired_directions(self.state)[1] < 0.5] = [0, 0]

# update state
next_state = self.state
Expand All @@ -88,8 +87,8 @@ def step(self, force, groups=None):
next_groups = groups if groups is not None else self.groups
self.update(next_state, next_groups)

def desired_directions(self):
return stateutils.desired_directions(self.state)[0]
# def desired_directions(self):
# return stateutils.desired_directions(self.state)[0]

@staticmethod
def capped_velocity(desired_velocity, max_velocity):
Expand Down
104 changes: 0 additions & 104 deletions pysocialforce/stateutils.py

This file was deleted.

0 comments on commit 7e2ebd7

Please sign in to comment.