diff --git a/quantio/vector.py b/quantio/vector.py index 526508a..ca26bfe 100644 --- a/quantio/vector.py +++ b/quantio/vector.py @@ -33,6 +33,11 @@ def __add__(self, other: Vector[T] | np.ndarray) -> Vector[T]: other_elements = other._elements if isinstance(other, Vector) else np.array(other) return Vector[T](self._elements + other_elements) + def __sub__(self, other: Vector[T] | np.ndarray) -> Vector[T]: + """Subtract another vector from this one.""" + other_elements = other._elements if isinstance(other, Vector) else np.array(other) + return Vector[T](self._elements - other_elements) + def __eq__(self, other: object) -> bool: """Assess if this object is the same as another.""" if not isinstance(other, Vector): diff --git a/test/test_vector.py b/test/test_vector.py index 6304d05..5f922ac 100644 --- a/test/test_vector.py +++ b/test/test_vector.py @@ -2,7 +2,7 @@ import numpy as np -from quantio import Vector, Length, CanNotAddTypesError +from quantio import Vector, Length, CanNotAddTypesError, CanNotSubtractTypesError def test_init(): @@ -43,5 +43,21 @@ def test_addition__wrong_type(): vec1 + vec2 +def test_subtract(): + vec1: Vector[Length] = Vector([Length(meters=1), Length(meters=2)]) + vec2: Vector[Length] = Vector([Length(meters=3), Length(meters=3)]) + + actual = vec2 - vec1 + assert actual == Vector([Length(meters=2), Length(meters=1)]) + + +def test_subtract__wrong_type(): + vec1: Vector[Length] = Vector([Length(meters=1), Length(meters=2)]) + vec2: Vector[float] = Vector([3, 4]) + + with pytest.raises(CanNotSubtractTypesError): + vec1 - vec2 + + if __name__ == "__main__": pytest.main([__file__, "-v"])