Skip to content

Commit

Permalink
feat: extend Vector.to_numpy() to include units
Browse files Browse the repository at this point in the history
  • Loading branch information
unexcellent committed Nov 14, 2024
1 parent a99a18b commit 7d00b0d
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 9 deletions.
16 changes: 16 additions & 0 deletions generate/generate_boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def generate_boilerplate(quantities_with_fields: dict[str, dict[str, str]]) -> N
current_line_is_part_of_autogenerated_code = True
target_content_with_boilerplate.append(line)
target_content_with_boilerplate.append("")
target_content_with_boilerplate.extend(
_generate_base_unit(current_class, quantities_with_fields[current_class])
)
target_content_with_boilerplate.extend(
_generate_properties(current_class, quantities_with_fields[current_class])
)
Expand All @@ -46,6 +49,19 @@ def generate_boilerplate(quantities_with_fields: dict[str, dict[str, str]]) -> N
target_file.writelines(target_content_with_boilerplate)


def _generate_base_unit(current_class: str, units: dict[str, str]) -> list[str]:
base_unit = None
for unit, factor in units.items():
if float(eval(factor)) == 1:
base_unit = unit
break

if base_unit is None:
raise ValueError(f"{current_class} needs a unit with factor equal to 1.")

return [f' _BASE_UNIT = "{base_unit}"', ""]


def _generate_properties(current_class: str, units: dict[str, str]) -> list[str]:
code = []

Expand Down
3 changes: 2 additions & 1 deletion quantio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""The main quantio package."""

from .exceptions import CanNotAddTypesError, CanNotSubtractTypesError
from .exceptions import CanNotAddTypesError, CanNotSubtractTypesError, NoUnitSpecifiedError
from .quantities import Acceleration, Angle, Area, Length, Mass, Time, Velocity
from .vector import Vector

Expand All @@ -15,4 +15,5 @@
"Vector",
"CanNotAddTypesError",
"CanNotSubtractTypesError",
"NoUnitSpecifiedError",
]
3 changes: 3 additions & 0 deletions quantio/_quantity_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class _QuantityBase(ABC):
_base_value: float
"The base unit of the quantity."

_BASE_UNIT: str
"Name of the unit with a factor of 1."

def __eq__(self, other: object) -> bool:
"""Assess if this object is the same as another."""
if isinstance(other, type(self)):
Expand Down
10 changes: 10 additions & 0 deletions quantio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@ class CanNotSubtractTypesError(TypeError):

def __init__(self, self_type_descriptor: str, other_type_descriptor: str) -> None:
super().__init__(f"Can not subtract {other_type_descriptor} from {self_type_descriptor}")


class NoUnitSpecifiedError(TypeError):
"""Raised when a Vector[_QuantityBase] is converted to a np.array without specifying a unit."""

def __init__(self) -> None:
super().__init__(
"When a vector with quantity elements is converted into a numpy array, a unit must be "
"specified."
)
14 changes: 14 additions & 0 deletions quantio/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class Acceleration(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "meters_per_square_second"

@property
def meters_per_square_second(self) -> float:
"""The acceleration in meters per square second."""
Expand Down Expand Up @@ -40,6 +42,8 @@ class Angle(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "radians"

@property
def degrees(self) -> float:
"""The angle in degrees."""
Expand Down Expand Up @@ -72,6 +76,8 @@ class Area(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "square_meters"

@property
def square_miles(self) -> float:
"""The area in square miles."""
Expand Down Expand Up @@ -146,6 +152,8 @@ class Length(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "meters"

@property
def miles(self) -> float:
"""The length in miles."""
Expand Down Expand Up @@ -220,6 +228,8 @@ class Mass(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "kilograms"

@property
def tonnes(self) -> float:
"""The mass in tonnes."""
Expand Down Expand Up @@ -287,6 +297,8 @@ class Velocity(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "meters_per_second"

@property
def meters_per_second(self) -> float:
"""The velocity in meters per second."""
Expand Down Expand Up @@ -326,6 +338,8 @@ class Time(_QuantityBase):

# --- This part is auto generated. Do not change manually. ---

_BASE_UNIT = "seconds"

@property
def hours(self) -> float:
"""The time in hours."""
Expand Down
24 changes: 17 additions & 7 deletions quantio/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from ._quantity_base import _QuantityBase
from .exceptions import NoUnitSpecifiedError

T = TypeVar("T")

Expand All @@ -17,13 +18,12 @@ class Vector(Generic[T]):
def __init__(self, elements: list | tuple | np.ndarray) -> None:
self._elements = np.array(elements)

def to_numpy(self) -> np.ndarray[float]:
def to_numpy(self, unit: str | None = None) -> np.ndarray[float]:
"""Convert this vector into a numpy array of floats."""
if len(self._elements) == 0:
return np.array([])

if isinstance(self._elements[0], _QuantityBase):
return np.array([element._base_value for element in self._elements])
if unit is None:
raise NoUnitSpecifiedError
return np.array([getattr(element, unit) for element in self._elements])

return np.array([float(element) for element in self._elements])

Expand Down Expand Up @@ -52,11 +52,19 @@ def __sub__(self, other: Vector[T] | np.ndarray) -> Vector[T]:

def __mul__(self, other: Vector | np.ndarray | float) -> np.ndarray:
"""Multipy this vector with either another vector or a scalar."""
return self.to_numpy() * _other_to_numpy(other)
if isinstance(self._elements[0], _QuantityBase):
self_to_numpy = self.to_numpy(self._elements[0]._BASE_UNIT)
else:
self_to_numpy = self.to_numpy()
return self_to_numpy * _other_to_numpy(other)

def __truediv__(self, other: Vector | np.ndarray | float) -> np.ndarray:
"""Multipy this vector with either another vector or a scalar."""
return self.to_numpy() / _other_to_numpy(other)
if isinstance(self._elements[0], _QuantityBase):
self_to_numpy = self.to_numpy(self._elements[0]._BASE_UNIT)
else:
self_to_numpy = self.to_numpy()
return self_to_numpy / _other_to_numpy(other)

def __eq__(self, other: object) -> bool:
"""Assess if this object is the same as another."""
Expand All @@ -74,6 +82,8 @@ def _other_to_numpy(other: Vector | np.ndarray | float) -> np.ndarray:
return np.array([other._base_value])

if isinstance(other, Vector):
if isinstance(other._elements[0], _QuantityBase):
return other.to_numpy(other._elements[0]._BASE_UNIT)
return other.to_numpy()

if isinstance(other, np.ndarray):
Expand Down
29 changes: 28 additions & 1 deletion test/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import numpy as np

from quantio import Vector, Length, CanNotAddTypesError, CanNotSubtractTypesError
from quantio import (
Vector,
Length,
CanNotAddTypesError,
CanNotSubtractTypesError,
NoUnitSpecifiedError,
)


def test_init():
Expand Down Expand Up @@ -155,5 +161,26 @@ def test_divide__wrong_dimension():
vec1 / vec2


def test_to_numpy__floats():
vec: Vector[float] = Vector([0.0, 1.0])

actual = vec.to_numpy()
assert np.all(actual == np.array([0.0, 1.0]))


def test_to_numpy__quantity():
vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2), Length(meters=3)])

actual = vec.to_numpy("centimeters")
assert np.all(actual == np.array([100.0, 200.0, 300.0]))


def test_to_numpy__quantity_no_unit():
vec: Vector[Length] = Vector([Length(meters=1), Length(meters=2), Length(meters=3)])

with pytest.raises(NoUnitSpecifiedError):
vec.to_numpy()


if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit 7d00b0d

Please sign in to comment.