Skip to content

Commit

Permalink
Reworked team detect
Browse files Browse the repository at this point in the history
- apply spectral clustering to solve sparest cut
- separate possession and pass calculation out for orthogonality
  • Loading branch information
Mikonooooo committed Oct 23, 2023
1 parent ab20720 commit a4cd8fd
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 186 deletions.
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ yapf
isort==4.3.21
imageio

# Processing
scikit-learn

# View
streamlit>=1.18.1
hydralit_components>= 1.0.10
Expand Down
173 changes: 45 additions & 128 deletions src/processing/team_detect.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,53 @@
from typing import Tuple
from state import GameState
import numpy as np
from sklearn.cluster import SpectralClustering

"""
Team Detection and Possession Finder
This module contains functions that will find the best team split. It will
also create a list of players in the order of ball possession.
This module contains functions that will find the best team split.
"""


# def connections(pos_lst, players):
# """
# Input:
# pos_lst [list]: list of player ids in the order of ball possession
# throughout the video
# players [list]: list of player ids
# Output:
# connects [dict]: dictionary of connections between players
# """
# connects = {}
# for i, player in enumerate(players):
# for j, player2 in enumerate(players):
# if i == j:
# continue
# name = player + player2
# connects.update({name: 0})

# curr = pos_lst[0]
# for i in range(1, len(pos_lst)):
# name = curr + pos_lst[i]
# connects[name] += 1
# return connects


def connections(pos_lst, players, player_idx):
"""
Input:
pos_lst [list]: list of player ids in the order of ball possession
throughout the video
players [list]: list of player ids
player_idx [dict]: dictionary of player ids to their index in the
players list
Output:
connects [list of lists]: 2D array of connections between players where
connects[i][j] is the number of times
player i passes to player j
"""
connects = [[0 for _ in range(len(players))] for _ in range(len(players))]
for i in range(0, len(pos_lst) - 1):
connects[player_idx.get(pos_lst[i][0])][player_idx.get(pos_lst[i + 1][0])] += 1
return connects


def possible_teams(players):
"""
Input:
players [list]: list of player ids
Output:
acc [list]: list of possible team splits
"""
num_people = len(players)

acc = []

def permutation(i, t):
if i >= num_people:
return
if len(t) == ppl_per_team:
acc.append((t, (set(players) - set(t))))
else:
permutation(i + 1, t.copy())
t.add(players[i])
permutation(i + 1, t.copy())

if num_people % 2 != 0:
ppl_per_team = int(num_people / 2) + 1
permutation(0, set())
ppl_per_team -= 1
permutation(0, set())
else:
ppl_per_team = int(num_people / 2)
permutation(0, set())
return acc


def team_split(state: GameState):
"""
Input:
state: a StatState class that holds all sorts of information
on the video
Output:
best_team [tuple]: tuple of two sets of player ids that are the best
team split
pos_lst [list[tuple]]: list of player ids in the order of ball
possession with start and finish frames
"""
player_list = state.players.keys()
pos_lst = possession_list(state.frames, player_list, thresh=11)
player_idx = {player: i for i, player in enumerate(player_list)}
connects = connections(pos_lst, player_list, player_idx)
teams = possible_teams(player_list)
best_team = None
min_count = 100000
for team in teams:
count = 0
team1 = list(team[0])
team2 = list(team[1])
for player1 in team1:
for player2 in team2:
count += connects[player_idx.get(player1)][player_idx.get(player2)]
count += connects[player_idx.get(player2)][player_idx.get(player1)]
if count < min_count:
min_count = count
best_team = team
return best_team, pos_lst, player_list


def compute_possession(player_pos, team1) -> Tuple[float, float]:
"""
Input: player possession, list of players on team 1
Computes and returns team1 possession, team2 possession.
"""
# total frames of each team's possession
team1_pos = 0
team2_pos = 0
for player, pos in player_pos.items():
for intervals in pos:
pos_time = intervals[1] - intervals[0]
if player in team1:
team1_pos += pos_time
else:
team2_pos += pos_time
total_pos = team1_pos + team2_pos

return team1_pos / total_pos, team2_pos / total_pos
def passing_matrix(state: GameState, p_list: list):
"computes passing matrix of a state"
n = len(p_list)
graph = np.zeros((n, n))
for i in range(n):
for j in range(n):
graph[i][j] = state.passes[p_list[i]][p_list[j]]
graph = graph + graph.T # directed -> undirected
return graph


def sparest_cut_weighted(graph):
"get sparest cut given pass frequency matrix"
# Perform spectral clustering
clustering = SpectralClustering(
n_clusters=2, affinity="precomputed", random_state=0, eigen_solver="arpack"
)
labels = clustering.fit_predict(graph)

# Calculate sparsity cut for weighted graph
cut_size = np.sum(graph[labels == 0][:, labels == 1])
sparsity = cut_size / min(np.sum(labels == 0), np.sum(labels == 1))

return labels, sparsity


def split_team(state: GameState):
"splits players into two teams"
n = len(state.players)
p_list = list(state.players.keys())

graph = passing_matrix(state, p_list)
labels, _ = sparest_cut_weighted(graph)
assert len(labels) == n

state.team1.clear()
state.team2.clear()
for i in range(n):
if labels[i] == 0:
state.team1.add(p_list[i])
else: # label[i] == 1
state.team2.add(p_list[i])
30 changes: 2 additions & 28 deletions src/processrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,36 +42,10 @@ def run_parse(self):
def run_possession(self):
self.state.filter_players(threshold=100)
self.state.recompute_possession_list(threshold=20, join_threshold=20)
self.state.recompute_pass_from_possession()

def run_team_detect(self):
"""
TODO figure out how to decouple team and general processing more
TODO explain pos_list
Splits identified players into teams, then curates:
ball state, passes, player possession, and team possession
"""

teams, pos_list, playerids = team_detect.team_split(self.state.frames)
self.state.possession_list = pos_list
for pid in playerids:
self.state.players[pid] = {
"shots": 0,
"points": 0,
"rebounds": 0,
"assists": 0,
}
self.state.ball_state = general_detect.ball_state_update(
pos_list, len(self.state.frames) - 1
)
self.state.passes = general_detect.player_passes(pos_list)
self.state.possession = general_detect.player_possession(pos_list)

self.state.team1 = teams[0]
self.state.team2 = teams[1]

self.state.team1_pos, self.state.team2_pos = team_detect.compute_possession(
self.state.possession, self.state.team1
)
team_detect.split_team(self.state)

def run_shot_detect(self):
"""Runs shot detection and updates scores."""
Expand Down
62 changes: 32 additions & 30 deletions src/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ class BallState:

def __init__(self) -> None:
"""
Ball state containing
frames: number frames player appeared in
Ball state containing ball stuff
"""
# MUTABLE
Expand Down Expand Up @@ -356,15 +355,14 @@ class GameState:
def __init__(self) -> None:
"""
Initialises state; contains the following instance variables:
states: list of dictionaries with info at each frame
players: dictionary of players to PlayerState
balls: dictioanry of balls to BallState
possession_list: list of ball possession tuples
passes: dictionary of passes with their start and end frames and players involved
possession: dictionary of players with their possessions as list of frame tuples
team1, team2: list of players on each team
score1, score2: score of each team
team1_pos, team2_pos: percentage of possession for each team
frames: list of PlayerFrame
players: dictionary of PlayerState
ball: BallState
possessions: list of PossessionInterval
passes: dictionary of passes
shots: list of shots by player
team1: set of players on one team
team2: est of players on other team
"""
# MUTABLE

Expand All @@ -374,30 +372,20 @@ def __init__(self) -> None:
self.players: dict = {}
"Global player data: {player_0 : PlayerState, player_1 : PlayerState}"

self.balls: dict = {}
"{ball_0 : BallState, ball_1 : BallState}"
self.ball: BallState = BallState()
"Global ball data"

self.possessions: list = []
"[PossessionInterval]"

self.passes: dict = {}
"dictionary of passes {player_0 : {player_0 : 3}}"

self.shots: list = []
" list of shots: [(player_[id],start,end)]"

# EVERYTHING BELOW THIS POINT IS OUT-OF-DATE

# [(start_frame, end_frame, BallFrame)]
self.ball_state = None
# {'pass_id': {'frames': (start_frame, end_frame)}, 'players':(p1_id, p2_id)}}
self.passes = None

self.team1 = None
self.team2 = None

# statistics
self.score1 = 0
self.score2 = 0
self.team1_pos = 0
self.team2_pos = 0
self.team1: set = set()
self.team2: set = set()

def recompute_frame_count(self):
"recompute frame count of all players in frames"
Expand All @@ -416,7 +404,7 @@ def recompute_possession_list(self, threshold=20, join_threshold=20):
"""
lst = []
prev = None
while lst != prev: # until lists have converged
while lst != prev: # until lists have converged
prev = lst.copy()
self.grow_poss(lst)
self.join_poss(lst, threshold)
Expand Down Expand Up @@ -469,7 +457,7 @@ def filter_poss(self, lst: list, threshold: int = 20):
i = 0
while i < len(lst):
p: PossessionInterval = lst[i]
if p.length < threshold:
if p.length < threshold or p.playerid not in self.players:
lst.pop(i)
else:
i += 1 # next interval
Expand All @@ -481,6 +469,20 @@ def filter_players(self, threshold: int):
if v.frames < threshold:
self.players.pop(k)

def recompute_pass_from_possession(self):
"Recompute passes naively from possession list"
self.passes: dict = {} # reset pass dictionary
for p in self.players:
self.passes.update({p: {}})
for c in self.players:
self.passes.get(p).update({c: 0})

i = 0
for i in range(len(self.possessions) - 1):
p1 = self.possessions[i].playerid
p2 = self.possessions[i + 1].playerid
self.passes[p1][p2] += 1

def update_scores(self, madeshot_list):
"""
TODO check for correctness + potentially move out of state.py
Expand Down

0 comments on commit a4cd8fd

Please sign in to comment.