Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and optimize Enum class #188

Merged
merged 3 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
fixed:
- FutureWarning issue with bools and enums.
123 changes: 46 additions & 77 deletions policyengine_core/enums/enum.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,79 @@
from __future__ import annotations

import enum
from typing import Union

import numpy

import numpy as np
from .config import ENUM_ARRAY_DTYPE
from .enum_array import EnumArray

import warnings

warnings.simplefilter("ignore", category=FutureWarning)


class Enum(enum.Enum):
"""
Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_, whose items
have an index.
"""

# Tweak enums to add an index attribute to each enum item
def __init__(self, name: str) -> None:
# When the enum item is initialized, self._member_names_ contains the
# names of the previously initialized items, so its length is the index
# of this item.
"""
Initialize an Enum item with a name and an index.

The index is automatically assigned based on the order of the Enum items.
"""
self.index = len(self._member_names_)

# Bypass the slow Enum.__eq__
__eq__ = object.__eq__

# In Python 3, __hash__ must be defined if __eq__ is defined to stay
# hashable.
__hash__ = object.__hash__

@classmethod
def encode(
cls,
array: Union[
EnumArray,
numpy.int_,
numpy.float_,
numpy.object_,
],
) -> EnumArray:
def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray:
"""
Encode a string numpy array, an enum item numpy array, or an int numpy
array into an :any:`EnumArray`. See :any:`EnumArray.decode` for
decoding.

:param numpy.ndarray array: Array of string identifiers, or of enum
items, to encode.
Encode an array of enum items or string identifiers into an EnumArray.

:returns: An :any:`EnumArray` encoding the input array values.
:rtype: :any:`EnumArray`
Args:
array: The input array to encode. Can be an EnumArray, a NumPy array
of enum items, or a NumPy array of string identifiers.

For instance:
Returns:
An EnumArray containing the encoded values.

>>> string_identifier_array = asarray(['free_lodger', 'owner'])
>>> encoded_array = HousingOccupancyStatus.encode(string_identifier_array)
>>> encoded_array[0]
2 # Encoded value
Examples:
>>> string_array = np.array(["ITEM_1", "ITEM_2", "ITEM_3"])
>>> encoded_array = MyEnum.encode(string_array)
>>> encoded_array
EnumArray([1, 2, 3], dtype=int8)

>>> free_lodger = HousingOccupancyStatus.free_lodger
>>> owner = HousingOccupancyStatus.owner
>>> enum_item_array = asarray([free_lodger, owner])
>>> encoded_array = HousingOccupancyStatus.encode(enum_item_array)
>>> encoded_array[0]
2 # Encoded value
>>> item_array = np.array([MyEnum.ITEM_1, MyEnum.ITEM_2, MyEnum.ITEM_3])
>>> encoded_array = MyEnum.encode(item_array)
>>> encoded_array
EnumArray([1, 2, 3], dtype=int8)
"""
if isinstance(array, EnumArray):
return array

if isinstance(array == 0, bool):
if array.dtype.kind == "b":
# Convert boolean array to string array
array = array.astype(str)

# String array
if isinstance(array, numpy.ndarray) and array.dtype.kind in {"U", "S"}:
array = numpy.select(
if array.dtype.kind in {"U", "S"}:
# String array
indices = np.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(ENUM_ARRAY_DTYPE)

# Enum items arrays
elif isinstance(array, numpy.ndarray) and array.dtype.kind == "O":
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
# directly importing a module containing an Enum class. However,
# variables (and hence their possible_values) are loaded by a call
# to load_module, which gives them a different identity from the
# ones imported in the usual way.
#
# So, instead of relying on the "cls" passed in, we use only its
# name to check that the values in the array, if non-empty, are of
# the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
if array[0].__class__.__name__ != "bytes":
array = numpy.select(
[array == item for item in cls],
[item.index for item in cls],
).astype(ENUM_ARRAY_DTYPE)
else:
array = numpy.select(
[array.astype(str) == item.name for item in cls],
[item.index for item in cls],
).astype(ENUM_ARRAY_DTYPE)

return EnumArray(array, cls)
)
elif array.dtype.kind == "O":
# Enum items array
if len(array) > 0:
first_item = array[0]
if cls.__name__ == type(first_item).__name__:
# Use the same Enum class as the array items
cls = type(first_item)
indices = np.select(
[array == item for item in cls],
[item.index for item in cls],
)
elif array.dtype.kind in {"i", "u"}:
# Integer array
indices = array
else:
raise ValueError(f"Unsupported array dtype: {array.dtype}")

return EnumArray(indices.astype(ENUM_ARRAY_DTYPE), cls)
Loading