Skip to content

Commit

Permalink
perf: make Vector.__add__() and Vector.__sub__() faster by not callin…
Browse files Browse the repository at this point in the history
…g Vector.elements
  • Loading branch information
unexcellent committed Dec 11, 2024
1 parent 3730c1f commit 3679c64
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
20 changes: 12 additions & 8 deletions quantio/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ def tile(cls, element: T | Vector[T] | list[T], length: int) -> Vector[T]:
return Vector(np.tile(element, length))

@classmethod
def from_numpy(
cls, array: np.ndarray, element_class: type[Quantity], unit: str
) -> Vector[Quantity]:
def from_numpy(cls, array: np.ndarray, element_class: type[Quantity], unit: str) -> Vector[T]:
"""Construct a quantity vector from a numpy array."""
vector: Vector[Quantity] = Vector([0])
vector: Vector[T] = Vector([0])

if unit == element_class.BASE_UNIT:
vector._elements = array
Expand Down Expand Up @@ -119,15 +117,21 @@ def __add__(self, other: Vector[T] | np.ndarray) -> Vector[T]:
"""Add another vector to this one."""
if not isinstance(other[0], self._quantitiy):
raise CanNotAddTypesError(self[0].__class__.__name__, other[0].__class__.__name__)
other_elements = other.elements if isinstance(other, Vector) else np.array(other)
return Vector[T](self.elements + other_elements)
return Vector[T].from_numpy(
self._elements + other._elements,
self._quantitiy,
self._quantitiy.BASE_UNIT, # type: ignore[attr-defined]
)

def __sub__(self, other: Vector[T] | np.ndarray) -> Vector[T]:
"""Subtract another vector from this one."""
if not isinstance(other[0], self._quantitiy):
raise CanNotSubtractTypesError(self[0].__class__.__name__, other[0].__class__.__name__)
other_elements = other.elements if isinstance(other, Vector) else np.array(other)
return Vector[T](self.elements - other_elements)
return Vector[T].from_numpy(
self._elements - other._elements,
self._quantitiy,
self._quantitiy.BASE_UNIT, # type: ignore[attr-defined]
)

def __mul__(self, other: Vector | np.ndarray | float) -> np.ndarray:
"""Multiply this vector with either another vector or a scalar."""
Expand Down
2 changes: 1 addition & 1 deletion test/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_sum__quantity():

def test_from_numpy__quantity():
array = np.array([0, 1, 2])
actual = Vector.from_numpy(array, Length, "meters")
actual = Vector[Length].from_numpy(array, Length, "meters")
assert actual == Vector([Length(meters=0), Length(meters=1), Length(meters=2)])


Expand Down

0 comments on commit 3679c64

Please sign in to comment.