diff --git a/requirements.txt b/requirements.txt index 80ce400d..a947937e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,6 +44,9 @@ yapf isort==4.3.21 imageio +# Processing +scikit-learn + # View streamlit>=1.18.1 hydralit_components>= 1.0.10 diff --git a/src/processing/team_detect.py b/src/processing/team_detect.py index edd83789..ba9eb7a6 100644 --- a/src/processing/team_detect.py +++ b/src/processing/team_detect.py @@ -1,136 +1,53 @@ -from typing import Tuple from state import GameState +import numpy as np +from sklearn.cluster import SpectralClustering """ Team Detection and Possession Finder -This module contains functions that will find the best team split. It will -also create a list of players in the order of ball possession. +This module contains functions that will find the best team split. """ -# def connections(pos_lst, players): -# """ -# Input: -# pos_lst [list]: list of player ids in the order of ball possession -# throughout the video -# players [list]: list of player ids -# Output: -# connects [dict]: dictionary of connections between players -# """ -# connects = {} -# for i, player in enumerate(players): -# for j, player2 in enumerate(players): -# if i == j: -# continue -# name = player + player2 -# connects.update({name: 0}) - -# curr = pos_lst[0] -# for i in range(1, len(pos_lst)): -# name = curr + pos_lst[i] -# connects[name] += 1 -# return connects - - -def connections(pos_lst, players, player_idx): - """ - Input: - pos_lst [list]: list of player ids in the order of ball possession - throughout the video - players [list]: list of player ids - player_idx [dict]: dictionary of player ids to their index in the - players list - Output: - connects [list of lists]: 2D array of connections between players where - connects[i][j] is the number of times - player i passes to player j - """ - connects = [[0 for _ in range(len(players))] for _ in range(len(players))] - for i in range(0, len(pos_lst) - 1): - connects[player_idx.get(pos_lst[i][0])][player_idx.get(pos_lst[i + 1][0])] += 1 - return connects - - -def possible_teams(players): - """ - Input: - players [list]: list of player ids - Output: - acc [list]: list of possible team splits - """ - num_people = len(players) - - acc = [] - - def permutation(i, t): - if i >= num_people: - return - if len(t) == ppl_per_team: - acc.append((t, (set(players) - set(t)))) - else: - permutation(i + 1, t.copy()) - t.add(players[i]) - permutation(i + 1, t.copy()) - - if num_people % 2 != 0: - ppl_per_team = int(num_people / 2) + 1 - permutation(0, set()) - ppl_per_team -= 1 - permutation(0, set()) - else: - ppl_per_team = int(num_people / 2) - permutation(0, set()) - return acc - - -def team_split(state: GameState): - """ - Input: - state: a StatState class that holds all sorts of information - on the video - Output: - best_team [tuple]: tuple of two sets of player ids that are the best - team split - pos_lst [list[tuple]]: list of player ids in the order of ball - possession with start and finish frames - """ - player_list = state.players.keys() - pos_lst = possession_list(state.frames, player_list, thresh=11) - player_idx = {player: i for i, player in enumerate(player_list)} - connects = connections(pos_lst, player_list, player_idx) - teams = possible_teams(player_list) - best_team = None - min_count = 100000 - for team in teams: - count = 0 - team1 = list(team[0]) - team2 = list(team[1]) - for player1 in team1: - for player2 in team2: - count += connects[player_idx.get(player1)][player_idx.get(player2)] - count += connects[player_idx.get(player2)][player_idx.get(player1)] - if count < min_count: - min_count = count - best_team = team - return best_team, pos_lst, player_list - - -def compute_possession(player_pos, team1) -> Tuple[float, float]: - """ - Input: player possession, list of players on team 1 - Computes and returns team1 possession, team2 possession. - """ - # total frames of each team's possession - team1_pos = 0 - team2_pos = 0 - for player, pos in player_pos.items(): - for intervals in pos: - pos_time = intervals[1] - intervals[0] - if player in team1: - team1_pos += pos_time - else: - team2_pos += pos_time - total_pos = team1_pos + team2_pos - - return team1_pos / total_pos, team2_pos / total_pos +def passing_matrix(state: GameState, p_list: list): + "computes passing matrix of a state" + n = len(p_list) + graph = np.zeros((n, n)) + for i in range(n): + for j in range(n): + graph[i][j] = state.passes[p_list[i]][p_list[j]] + graph = graph + graph.T # directed -> undirected + return graph + + +def sparest_cut_weighted(graph): + "get sparest cut given pass frequency matrix" + # Perform spectral clustering + clustering = SpectralClustering( + n_clusters=2, affinity="precomputed", random_state=0, eigen_solver="arpack" + ) + labels = clustering.fit_predict(graph) + + # Calculate sparsity cut for weighted graph + cut_size = np.sum(graph[labels == 0][:, labels == 1]) + sparsity = cut_size / min(np.sum(labels == 0), np.sum(labels == 1)) + + return labels, sparsity + + +def split_team(state: GameState): + "splits players into two teams" + n = len(state.players) + p_list = list(state.players.keys()) + + graph = passing_matrix(state, p_list) + labels, _ = sparest_cut_weighted(graph) + assert len(labels) == n + + state.team1.clear() + state.team2.clear() + for i in range(n): + if labels[i] == 0: + state.team1.add(p_list[i]) + else: # label[i] == 1 + state.team2.add(p_list[i]) diff --git a/src/processrunner.py b/src/processrunner.py index 5748a1a9..66724858 100644 --- a/src/processrunner.py +++ b/src/processrunner.py @@ -42,36 +42,10 @@ def run_parse(self): def run_possession(self): self.state.filter_players(threshold=100) self.state.recompute_possession_list(threshold=20, join_threshold=20) + self.state.recompute_pass_from_possession() def run_team_detect(self): - """ - TODO figure out how to decouple team and general processing more - TODO explain pos_list - Splits identified players into teams, then curates: - ball state, passes, player possession, and team possession - """ - - teams, pos_list, playerids = team_detect.team_split(self.state.frames) - self.state.possession_list = pos_list - for pid in playerids: - self.state.players[pid] = { - "shots": 0, - "points": 0, - "rebounds": 0, - "assists": 0, - } - self.state.ball_state = general_detect.ball_state_update( - pos_list, len(self.state.frames) - 1 - ) - self.state.passes = general_detect.player_passes(pos_list) - self.state.possession = general_detect.player_possession(pos_list) - - self.state.team1 = teams[0] - self.state.team2 = teams[1] - - self.state.team1_pos, self.state.team2_pos = team_detect.compute_possession( - self.state.possession, self.state.team1 - ) + team_detect.split_team(self.state) def run_shot_detect(self): """Runs shot detection and updates scores.""" diff --git a/src/state.py b/src/state.py index e012d79d..25dd698b 100644 --- a/src/state.py +++ b/src/state.py @@ -307,8 +307,7 @@ class BallState: def __init__(self) -> None: """ - Ball state containing - frames: number frames player appeared in + Ball state containing ball stuff """ # MUTABLE @@ -356,15 +355,14 @@ class GameState: def __init__(self) -> None: """ Initialises state; contains the following instance variables: - states: list of dictionaries with info at each frame - players: dictionary of players to PlayerState - balls: dictioanry of balls to BallState - possession_list: list of ball possession tuples - passes: dictionary of passes with their start and end frames and players involved - possession: dictionary of players with their possessions as list of frame tuples - team1, team2: list of players on each team - score1, score2: score of each team - team1_pos, team2_pos: percentage of possession for each team + frames: list of PlayerFrame + players: dictionary of PlayerState + ball: BallState + possessions: list of PossessionInterval + passes: dictionary of passes + shots: list of shots by player + team1: set of players on one team + team2: est of players on other team """ # MUTABLE @@ -374,30 +372,20 @@ def __init__(self) -> None: self.players: dict = {} "Global player data: {player_0 : PlayerState, player_1 : PlayerState}" - self.balls: dict = {} - "{ball_0 : BallState, ball_1 : BallState}" + self.ball: BallState = BallState() + "Global ball data" self.possessions: list = [] "[PossessionInterval]" + self.passes: dict = {} + "dictionary of passes {player_0 : {player_0 : 3}}" + self.shots: list = [] " list of shots: [(player_[id],start,end)]" - # EVERYTHING BELOW THIS POINT IS OUT-OF-DATE - - # [(start_frame, end_frame, BallFrame)] - self.ball_state = None - # {'pass_id': {'frames': (start_frame, end_frame)}, 'players':(p1_id, p2_id)}} - self.passes = None - - self.team1 = None - self.team2 = None - - # statistics - self.score1 = 0 - self.score2 = 0 - self.team1_pos = 0 - self.team2_pos = 0 + self.team1: set = set() + self.team2: set = set() def recompute_frame_count(self): "recompute frame count of all players in frames" @@ -416,7 +404,7 @@ def recompute_possession_list(self, threshold=20, join_threshold=20): """ lst = [] prev = None - while lst != prev: # until lists have converged + while lst != prev: # until lists have converged prev = lst.copy() self.grow_poss(lst) self.join_poss(lst, threshold) @@ -469,7 +457,7 @@ def filter_poss(self, lst: list, threshold: int = 20): i = 0 while i < len(lst): p: PossessionInterval = lst[i] - if p.length < threshold: + if p.length < threshold or p.playerid not in self.players: lst.pop(i) else: i += 1 # next interval @@ -481,6 +469,20 @@ def filter_players(self, threshold: int): if v.frames < threshold: self.players.pop(k) + def recompute_pass_from_possession(self): + "Recompute passes naively from possession list" + self.passes: dict = {} # reset pass dictionary + for p in self.players: + self.passes.update({p: {}}) + for c in self.players: + self.passes.get(p).update({c: 0}) + + i = 0 + for i in range(len(self.possessions) - 1): + p1 = self.possessions[i].playerid + p2 = self.possessions[i + 1].playerid + self.passes[p1][p2] += 1 + def update_scores(self, madeshot_list): """ TODO check for correctness + potentially move out of state.py