Skip to content

Commit

Permalink
feat: Vector multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 13, 2024
1 parent 4a0a9d7 commit 26a89aa
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ ignore = [
"TCH003", # same as TCH001

"SIM103", # less readable in some cases imo

"SLF001", # necessary for vector operations
]

[tool.mypy]
Expand Down
29 changes: 29 additions & 0 deletions quantio/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np

from ._quantity_base import _QuantityBase

T = TypeVar("T")


Expand Down Expand Up @@ -38,6 +40,33 @@ def __sub__(self, other: Vector[T] | np.ndarray) -> Vector[T]:
other_elements = other._elements if isinstance(other, Vector) else np.array(other)
return Vector[T](self._elements - other_elements)

def __mul__(self, other: Vector | np.ndarray | float) -> np.ndarray:
"""Multipy this vector with either another vector or a scalar."""
if isinstance(self._elements[0], _QuantityBase):
self_elements = np.array([element._base_value for element in self._elements])
else:
self_elements = self._elements

if isinstance(other, (float, int)):
other_elements = np.array([other])

elif isinstance(other, _QuantityBase):
other_elements = np.array([other._base_value])

elif isinstance(other, Vector):
if isinstance(self._elements[0], _QuantityBase):
other_elements = np.array([element._base_value for element in other._elements])
else:
other_elements = other._elements

elif isinstance(other, np.ndarray):
other_elements = other

else:
raise TypeError

return self_elements * other_elements

def __eq__(self, other: object) -> bool:
"""Assess if this object is the same as another."""
if not isinstance(other, Vector):
Expand Down
48 changes: 48 additions & 0 deletions test/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,53 @@ def test_subtract__wrong_type():
vec1 - vec2


def test_multiply__with_vector_of_quantities():
vec1: Vector[Length] = Vector([Length(meters=1), Length(meters=2)])
vec2: Vector[Length] = Vector([Length(meters=3), Length(meters=4)])

actual = vec2 * vec1
assert np.all(actual == np.array([3, 8]))


def test_multiply__with_vector_of_float():
vec1: Vector[float] = Vector([1, 2])
vec2: Vector[float] = Vector([3, 4])

actual = vec2 * vec1
assert np.all(actual == np.array([3, 8]))


def test_multiply__with_array():
vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2)])
array = np.array([3, 4])

actual = vec * array
assert np.all(actual == np.array([3, 8]))


def test_multiply__with_scalar_float():
vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2)])
scalar = 5

actual = vec * scalar
assert np.all(actual == np.array([5, 10]))


def test_multiply__with_scalar_quantitiy():
vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2)])
scalar = Length(meters=5)

actual = vec * scalar
assert np.all(actual == np.array([5, 10]))


def test_multiply__wrong_dimension():
vec1: Vector[Length] = Vector([Length(meters=1), Length(meters=2), Length(meters=2)])
vec2: Vector[Length] = Vector([Length(meters=3), Length(meters=4)])

with pytest.raises(ValueError):
vec1 * vec2


if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit 26a89aa

Please sign in to comment.