Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce redundancy in shape factory by adding names #242

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/tutorials/shape-creation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ to calculate its position and scale:

class XLines(LineCollection):

name = 'x'

def __init__(self, dataset: Dataset) -> None:
xmin, xmax = dataset.morph_bounds.x_bounds
ymin, ymax = dataset.morph_bounds.y_bounds

super().__init__([[xmin, ymin], [xmax, ymax]], [[xmin, ymax], [xmax, ymin]])

def __str__(self) -> str:
return 'x'

Since we inherit from :class:`.LineCollection` here, we don't need to define
the ``distance()`` and ``plot()`` methods (unless we want to override them).
We do override the ``__str__()`` method here since the default will result in
We do set the ``name`` attribute here since the default will result in
a value of ``xlines`` and ``x`` makes more sense for use in the documentation
(see :class:`.ShapeFactory`).

Expand All @@ -89,8 +89,8 @@ For the ``data-morph`` CLI to find your shape, you need to register it with the
2. Add your shape to ``__all__`` in that module's ``__init__.py`` (*e.g.*, use
``src/data_morph/shapes/points/__init__.py`` for a new shape inheriting from
:class:`.PointCollection`).
3. Add an entry to the ``ShapeFactory._SHAPE_MAPPING`` dictionary in
``src/data_morph/shapes/factory.py``.
3. Add an entry to the ``ShapeFactory._SHAPE_CLASSES`` tuple in
``src/data_morph/shapes/factory.py``, preserving alphabetical order.

Test out the shape
------------------
Expand Down
21 changes: 20 additions & 1 deletion src/data_morph/shapes/bases/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@
class Shape(ABC):
"""Abstract base class for a shape."""

name: str | None = None
"""The display name for the shape, if the lowercased class name is not desired."""

@classmethod
def get_name(cls) -> str:
"""
Get the name of the shape.

Returns
-------
str
The name of the shape.
"""
return cls.name or cls.__name__.lower()

def __repr__(self) -> str:
"""
Return string representation of the shape.
Expand All @@ -32,8 +47,12 @@ def __str__(self) -> str:
-------
str
The human-readable string representation of the shape.

See Also
--------
get_name : This calls the :meth:`.get_name` class method.
"""
return self.__class__.__name__.lower()
return self.get_name()

@abstractmethod
def distance(self, x: Number, y: Number) -> float:
Expand Down
52 changes: 29 additions & 23 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,39 @@ class ShapeFactory:
The starting dataset to morph into other shapes.
"""

_SHAPE_CLASSES: tuple[type[Shape]] = (
Bullseye,
Circle,
Club,
Diamond,
DotsGrid,
DownParabola,
Heart,
HighLines,
HorizontalLines,
LeftParabola,
Rectangle,
RightParabola,
Rings,
Scatter,
SlantDownLines,
SlantUpLines,
Spade,
Star,
UpParabola,
VerticalLines,
WideLines,
XLines,
)
"""New shape classes must be registered here."""

_SHAPE_MAPPING: ClassVar[dict[str, type[Shape]]] = {
'bullseye': Bullseye,
'circle': Circle,
'high_lines': HighLines,
'h_lines': HorizontalLines,
'slant_down': SlantDownLines,
'slant_up': SlantUpLines,
'v_lines': VerticalLines,
'wide_lines': WideLines,
'x': XLines,
'dots': DotsGrid,
'down_parab': DownParabola,
'heart': Heart,
'left_parab': LeftParabola,
'scatter': Scatter,
'right_parab': RightParabola,
'up_parab': UpParabola,
'diamond': Diamond,
'rectangle': Rectangle,
'rings': Rings,
'star': Star,
'club': Club,
'spade': Spade,
shape_cls.get_name(): shape_cls for shape_cls in _SHAPE_CLASSES
}
"""Mapping of shape display names to classes."""

AVAILABLE_SHAPES: list[str] = sorted(_SHAPE_MAPPING.keys())
"""list[str]: The list of available shapes, which can be visualized with
"""The list of available shapes, which can be visualized with
:meth:`.plot_available_shapes`."""

def __init__(self, dataset: Dataset) -> None:
Expand Down
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/high_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class HighLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'high_lines'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
y_bounds = dataset.data_bounds.y_bounds
Expand All @@ -36,6 +38,3 @@ def __init__(self, dataset: Dataset) -> None:
[[x_bounds[0], lower], [x_bounds[1], lower]],
[[x_bounds[0], upper], [x_bounds[1], upper]],
)

def __str__(self) -> str:
return 'high_lines'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/horizontal_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class HorizontalLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'h_lines'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
y_bounds = dataset.data_bounds.y_bounds
Expand All @@ -36,6 +38,3 @@ def __init__(self, dataset: Dataset) -> None:
for y in np.linspace(y_bounds[0], y_bounds[1], 5)
]
)

def __str__(self) -> str:
return 'h_lines'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/slant_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SlantDownLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'slant_down'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.morph_bounds.x_bounds
y_bounds = dataset.morph_bounds.y_bounds
Expand All @@ -43,6 +45,3 @@ def __init__(self, dataset: Dataset) -> None:
[[xmin + x_offset, ymax], [xmax, ymin + y_offset]],
[[xmid, ymax], [xmax, ymid]],
)

def __str__(self) -> str:
return 'slant_down'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/slant_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SlantUpLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'slant_up'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.morph_bounds.x_bounds
y_bounds = dataset.morph_bounds.y_bounds
Expand All @@ -43,6 +45,3 @@ def __init__(self, dataset: Dataset) -> None:
[[xmin + x_offset, ymin], [xmax, ymid + y_offset]],
[[xmid, ymin], [xmax, ymid]],
)

def __str__(self) -> str:
return 'slant_up'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/vertical_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class VerticalLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'v_lines'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
y_bounds = dataset.data_bounds.y_bounds
Expand All @@ -36,6 +38,3 @@ def __init__(self, dataset: Dataset) -> None:
for x in np.linspace(x_bounds[0], x_bounds[1], 5)
]
)

def __str__(self) -> str:
return 'v_lines'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/wide_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class WideLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'wide_lines'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
y_bounds = dataset.data_bounds.y_bounds
Expand All @@ -36,6 +38,3 @@ def __init__(self, dataset: Dataset) -> None:
[[lower, y_bounds[0]], [lower, y_bounds[1]]],
[[upper, y_bounds[0]], [upper, y_bounds[1]]],
)

def __str__(self) -> str:
return 'wide_lines'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/lines/x_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ class XLines(LineCollection):
The starting dataset to morph into other shapes.
"""

name = 'x'

def __init__(self, dataset: Dataset) -> None:
xmin, xmax = dataset.morph_bounds.x_bounds
ymin, ymax = dataset.morph_bounds.y_bounds

super().__init__([[xmin, ymin], [xmax, ymax]], [[xmin, ymax], [xmax, ymin]])

def __str__(self) -> str:
return 'x'
5 changes: 2 additions & 3 deletions src/data_morph/shapes/points/dots_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class DotsGrid(PointCollection):
The starting dataset to morph into other shapes.
"""

name = 'dots'

def __init__(self, dataset: Dataset) -> None:
xlow, xhigh = dataset.df.x.quantile([0.05, 0.95]).tolist()
ylow, yhigh = dataset.df.y.quantile([0.05, 0.95]).tolist()
Expand All @@ -36,6 +38,3 @@ def __init__(self, dataset: Dataset) -> None:
super().__init__(
*list(itertools.product([xlow, xmid, xhigh], [ylow, ymid, yhigh]))
)

def __str__(self) -> str:
return 'dots'
20 changes: 8 additions & 12 deletions src/data_morph/shapes/points/parabola.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class DownParabola(PointCollection):
The starting dataset to morph into other shapes.
"""

name = 'down_parab'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
xmin, xmax = x_bounds
Expand All @@ -41,9 +43,6 @@ def __init__(self, dataset: Dataset) -> None:

super().__init__(*np.stack(poly.linspace(), axis=1))

def __str__(self) -> str:
return 'down_parab'


class LeftParabola(PointCollection):
"""
Expand All @@ -65,6 +64,8 @@ class LeftParabola(PointCollection):
The starting dataset to morph into other shapes.
"""

name = 'left_parab'

def __init__(self, dataset: Dataset) -> None:
y_bounds = dataset.data_bounds.y_bounds
ymin, ymax = y_bounds
Expand All @@ -80,9 +81,6 @@ def __init__(self, dataset: Dataset) -> None:

super().__init__(*np.stack(poly.linspace()[::-1], axis=1))

def __str__(self) -> str:
return 'left_parab'


class RightParabola(PointCollection):
"""
Expand All @@ -104,6 +102,8 @@ class RightParabola(PointCollection):
The starting dataset to morph into other shapes.
"""

name = 'right_parab'

def __init__(self, dataset: Dataset) -> None:
y_bounds = dataset.data_bounds.y_bounds
ymin, ymax = y_bounds
Expand All @@ -119,9 +119,6 @@ def __init__(self, dataset: Dataset) -> None:

super().__init__(*np.stack(poly.linspace()[::-1], axis=1))

def __str__(self) -> str:
return 'right_parab'


class UpParabola(PointCollection):
"""
Expand All @@ -143,6 +140,8 @@ class UpParabola(PointCollection):
The starting dataset to morph into other shapes.
"""

name = 'up_parab'

def __init__(self, dataset: Dataset) -> None:
x_bounds = dataset.data_bounds.x_bounds
xmin, xmax = x_bounds
Expand All @@ -157,6 +156,3 @@ def __init__(self, dataset: Dataset) -> None:
poly = np.polynomial.Polynomial.fit([xmin, xmid, xmax], [ymax, ymin, ymax], 2)

super().__init__(*np.stack(poly.linspace(), axis=1))

def __str__(self) -> str:
return 'up_parab'