diff --git a/changelog/1238.feature.rst b/changelog/1238.feature.rst new file mode 100644 index 0000000000..29445fc514 --- /dev/null +++ b/changelog/1238.feature.rst @@ -0,0 +1 @@ +Add support for ``BaseFlags`` to allow comparison with ``flag_values`` and vice versa. diff --git a/disnake/flags.py b/disnake/flags.py index de753b30ba..6b2c24e71c 100644 --- a/disnake/flags.py +++ b/disnake/flags.py @@ -55,6 +55,16 @@ def __init__(self, func: Callable[[Any], int]) -> None: self.__doc__ = func.__doc__ self._parent: Type[T] = MISSING + def __eq__(self, other: Any) -> bool: + if isinstance(other, flag_value): + return self.flag == other.flag + if isinstance(other, BaseFlags): + return self._parent is other.__class__ and self.flag == other.value + return False + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + def __or__(self, other: Union[flag_value[T], T]) -> T: if isinstance(other, BaseFlags): if self._parent is not other.__class__: @@ -148,7 +158,11 @@ def _from_value(cls, value: int) -> Self: return self def __eq__(self, other: Any) -> bool: - return isinstance(other, self.__class__) and self.value == other.value + if isinstance(other, self.__class__): + return self.value == other.value + if isinstance(other, flag_value): + return self.__class__ is other._parent and self.value == other.flag + return False def __ne__(self, other: Any) -> bool: return not self.__eq__(other) diff --git a/tests/test_flags.py b/tests/test_flags.py index cb9d64b0ea..575851c26e 100644 --- a/tests/test_flags.py +++ b/tests/test_flags.py @@ -184,6 +184,21 @@ def test__eq__(self) -> None: assert not ins == other assert ins != other + def test__eq__flag_value(self) -> None: + ins = TestFlags(one=True) + other = TestFlags(one=True, two=True) + + assert ins == TestFlags.one + assert TestFlags.one == ins + + assert not ins != TestFlags.one + assert ins != TestFlags.two + + assert other != TestFlags.one + assert other != TestFlags.two + + assert other == TestFlags.three + def test__and__(self) -> None: ins = TestFlags(one=True, two=True) other = TestFlags(one=True, two=True)