Skip to content

Commit

Permalink
feat: add Vector.from_numpy()
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 22, 2024
1 parent 6069abd commit 598f0d3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
7 changes: 7 additions & 0 deletions quantio/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def tile(cls, element: T | Vector[T] | list[T], length: int) -> Vector[T]:

return Vector(np.tile(element, length))

@classmethod
def from_numpy(
cls, array: np.ndarray, element_class: type[_QuantityBase], unit: str
) -> Vector[_QuantityBase]:
"""Construct a quantity vector from a numpy array."""
return Vector([element_class(**{unit: elem}) for elem in array])

def to_numpy(self, unit: str | None = None) -> np.ndarray[float]:
"""Convert this vector into a numpy array of floats."""
if isinstance(self._elements[0], _QuantityBase):
Expand Down
6 changes: 6 additions & 0 deletions test/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,11 @@ def test_sum__quantity():
assert actual == Length(meters=6)


def test_from_numpy__quantity():
array = np.array([0, 1, 2])
actual = Vector.from_numpy(array, Length, "meters")
assert actual == Vector([Length(meters=0), Length(meters=1), Length(meters=2)])


if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit 598f0d3

Please sign in to comment.