diff --git a/quantio/vector.py b/quantio/vector.py index fea8477..640205b 100644 --- a/quantio/vector.py +++ b/quantio/vector.py @@ -23,3 +23,7 @@ def __class_getitem__(cls, *_: object) -> type: def __getitem__(self, index: int) -> T: """Return the element at a specific index.""" return self._elements[index] + + def __setitem__(self, index: int, value: T) -> None: + """Set the element at a specific index.""" + self._elements[index] = value diff --git a/test/test_vector.py b/test/test_vector.py index dc1624d..60229ae 100644 --- a/test/test_vector.py +++ b/test/test_vector.py @@ -6,18 +6,25 @@ def test_init(): - actual: Vector[Length] = Vector([Length.zero(), Length.zero()]) - assert np.all(actual._elements == np.array([Length.zero(), Length.zero()])) + vec: Vector[Length] = Vector([Length.zero(), Length.zero()]) + assert np.all(vec._elements == np.array([Length.zero(), Length.zero()])) def test_init_with_type_hint(): - actual = Vector[Length]([Length.zero(), Length.zero()]) - assert np.all(actual._elements == np.array([Length.zero(), Length.zero()])) + vec = Vector[Length]([Length.zero(), Length.zero()]) + assert np.all(vec._elements == np.array([Length.zero(), Length.zero()])) def test_indexing(): - actual: Vector[Length] = Vector([Length(meters=1), Length(meters=2)]) - assert actual[0] == Length(meters=1) + vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2)]) + assert vec[0] == Length(meters=1) + + +def test_set_item(): + vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2)]) + vec[0] = Length(meters=3) + + assert vec[0] == Length(meters=3) if __name__ == "__main__":