Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nary tree #26

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ generated
.vscode
rendering_times.csv
media/
.coverage

venv
.coverage
.python-version
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ repos:
flake8-rst-docstrings==0.3.0,
flake8-simplify==0.19.3,
]
- repo: local
hooks:
- id: pytest
name: pytest
entry: poetry run pytest -cov=src tests/
language: system
pass_filenames: false
# alternatively you could `types: [python]` so it only runs when python files change
# though tests might be invalidated if you were to say change a data file
always_run: true
stages: [push]
# - repo: local
# hooks:
# - id: pytest
# name: pytest
# entry: poetry run pytest -cov=src tests/
# language: system
# pass_filenames: false
# # alternatively you could `types: [python]` so it only runs when python files change
# # though tests might be invalidated if you were to say change a data file
# always_run: true
# stages: [push]
144 changes: 144 additions & 0 deletions src/manim_data_structures/m_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import operator as op
import random
from collections import defaultdict
from copy import copy
from functools import partialmethod, reduce
from typing import Any, Callable, Dict, Hashable, List, Tuple

import numpy as np
from manim import *
from manim import WHITE, Graph, Mobject, VMobject


class Tree(VMobject):
"""Computer Science Tree Data Structure"""

_graph: Graph
__layout_config: dict
__layout_scale: float
__layout: str | dict
__vertex_type: Callable[..., Mobject]

# __parents: list
# __children: dict[Hashable, list] = defaultdict(list)

def __init__(
self,
nodes: dict[int, Any],
edges: list[tuple[int, int]],
vertex_type: Callable[..., Mobject],
edge_buff=0.4,
layout="tree",
layout_config={"vertex_spacing": (-1, 1)},
root_vertex=0,
**kwargs
):
super().__init__(**kwargs)
vertex_mobjects = {k: vertex_type(v) for k, v in nodes.items()}
self.__layout_config = layout_config
self.__layout_scale = len(nodes) * 0.5
self.__layout = layout
self.__vertex_type = vertex_type
self._graph = Graph(
list(nodes),
edges,
vertex_mobjects=vertex_mobjects,
layout=layout,
root_vertex=0,
layout_config=self.__layout_config,
layout_scale=len(nodes) * 0.5,
edge_config={"stroke_width": 1, "stroke_color": WHITE},
)

def update_edges(graph: Graph):
"""Updates edges of graph"""
for (u, v), edge in graph.edges.items():
buff_vec = (
edge_buff
* (graph[u].get_center() - graph[v].get_center())
/ np.linalg.norm(graph[u].get_center() - graph[v].get_center())
)
edge.put_start_and_end_on(
graph[u].get_center() - buff_vec, graph[v].get_center() + buff_vec
)

self._graph.updaters.clear()
self._graph.updaters.append(update_edges)
self.add(self._graph)

def insert_node(self, node: Any, edge: tuple[Hashable, Hashable]):
"""Inserts a node into the graph as (parent, node)"""
self._graph.add_vertices(
edge[1], vertex_mobjects={edge[1]: self.__vertex_type(node)}
)
self._graph.add_edges(edge)
return self

def insert_node2(self, node: Any, edge: tuple[Hashable, Hashable]):
"""Inserts a node into the graph as (parent, node)"""
self._graph.change_layout(
self.__layout,
layout_scale=self.__layout_scale,
layout_config=self.__layout_config,
root_vertex=0,
)
for mob in self.family_members_with_points():
if (mob.get_center() == self._graph[edge[1]].get_center()).all():
mob.points = mob.points.astype("float")
return self

def insert_node3(self, node: Any, edge: tuple[Hashable, Hashable]):
"""Inserts a node into the graph as (parent, node)"""
self.suspend_updating()
self.insert_node(node, edge)
# self.resume_updating()
self.insert_node2(node, edge)

return self

def remove_node(self, node: Hashable):
"""Removes a node from the graph"""
self._graph.remove_vertices(node)

# def insert_node2(self):
# """Shift by the given vectors.
#
# Parameters
# ----------
# vectors
# Vectors to shift by. If multiple vectors are given, they are added
# together.
#
# Returns
# -------
# :class:`Mobject`
# ``self``
#
# See also
# --------
# :meth:`move_to`
# """
#
# total_vector = reduce(op.add, vectors)
# for mob in self.family_members_with_points():
# mob.points = mob.points.astype("float")
# mob.points += total_vector
#
# return self


if __name__ == "__main__":

class TestScene(Scene):
def construct(self):
# make a parent list for a tree
tree = Tree({0: 0, 1: 1, 2: 2, 3: 3}, [(0, 1), (0, 2), (1, 3)], Integer)
self.play(Create(tree))
self.wait()
self.play(tree.animate.insert_node3(4, (2, 4)), run_time=0)
self.wait()

config.preview = True
config.renderer = "cairo"
config.quality = "low_quality"
TestScene().render(preview=True)
118 changes: 118 additions & 0 deletions src/manim_data_structures/nary_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import Any, Callable, Hashable

import networkx as nx
from m_tree import Tree
from manim import Mobject


def _nary_layout(
T: nx.classes.graph.Graph,
vertex_spacing: tuple | None = None,
n: int | None = None,
):
if not n:
raise ValueError("the n-ary tree layout requires the n parameter")
if not nx.is_tree(T):
raise ValueError("The tree layout must be used with trees")

max_height = NaryTree.calc_loc(max(T), n)[1]

def calc_pos(x, y):
"""
Scales the coordinates to the desired spacing
"""
return (x - (n**y - 1) / 2) * vertex_spacing[0] * n ** (
max_height - y
), y * vertex_spacing[1]

return {
i: np.array([x, y, 0])
for i, (x, y) in ((i, calc_pos(*NaryTree.calc_loc(i, n))) for i in T)
}


class NaryTree(Tree):
def __init__(
self,
nodes: dict[int, Any],
num_child: int,
vertex_type: Callable[..., Mobject],
edge_buff=0.4,
layout_config=None,
**kwargs
):
if layout_config is None:
layout_config = {"vertex_spacing": (-1, 1)}
self.__layout_config = layout_config
self.num_child = num_child

edges = [(self.get_parent(e), e) for e in nodes if e != 0]
super().__init__(nodes, edges, vertex_type, edge_buff, **kwargs)
dict_layout = _nary_layout(self._graph._graph, n=num_child, **layout_config)
self._graph.change_layout(dict_layout)

@staticmethod
def calc_loc(i, n):
"""
Calculates the coordinates in terms of the shifted level order x position and level height
"""
if n == 1:
return 1, i + 1
height = int(np.emath.logn(n, i * (n - 1) + 1))
node_shift = (1 - n**height) // (1 - n)
return i - node_shift, height

@staticmethod
def calc_idx(loc, n):
"""
Calculates the index from the coordinates
"""
x, y = loc
if n == 1:
return y - 1

return int(x + (1 - n**y) // (1 - n))

def get_parent(self, idx):
"""
Returns the index of the parent of the node at the given index
"""
x, y = NaryTree.calc_loc(idx, self.num_child)
new_loc = x // self.num_child, y - 1
return NaryTree.calc_idx(new_loc, self.num_child)

def insert_node(self, node: Any, index: Hashable):
"""Inserts a node into the graph"""
res = super().insert_node(node, (self.get_parent(index), index))
dict_layout = _nary_layout(
self._graph._graph, n=self.num_child, **self.__layout_config
)
self._graph.change_layout(dict_layout)
self.update()
return res


if __name__ == "__main__":
from manim import *

class TestScene(Scene):
def construct(self):
tree = NaryTree(
{0: 0, 1: 1, 4: 4},
num_child=2,
vertex_type=Integer,
layout_config={"vertex_spacing": (1, -1)},
)
# tree._graph.change_layout(root_vertex=0, layout_config=tree._Tree__layout_config,
# layout_scale=tree._Tree__layout_scale)
self.play(Create(tree))
self.wait()
tree.insert_node(1, 3)
self.wait()
tree.remove_node(4)
self.wait()

config.preview = True
config.renderer = "cairo"
config.quality = "low_quality"
TestScene().render(preview=True)
11 changes: 11 additions & 0 deletions tests/test_mtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# TODO: Fill with appropriate tests
def test_getitem():
pass


def test_setitem():
pass


def test_iteration():
pass