diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f74299ba1..2c5ed63053 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ These changes are available on the `master` branch, but have not yet been releas `tags`. ([#2520](https://github.com/Pycord-Development/pycord/pull/2520)) - Added `Member.guild_banner` and `Member.display_banner` properties. ([#2556](https://github.com/Pycord-Development/pycord/pull/2556)) +- Added optional `filter` parameter to `utils.basic_autocomplete()`. + ([#2590](https://github.com/Pycord-Development/pycord/pull/2590)) ### Fixed diff --git a/discord/utils.py b/discord/utils.py index 9f01f53e71..fbcf9c1f31 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -1306,9 +1306,12 @@ def generate_snowflake(dt: datetime.datetime | None = None) -> int: AV = Awaitable[V] Values = Union[V, Callable[[AutocompleteContext], Union[V, AV]], AV] AutocompleteFunc = Callable[[AutocompleteContext], AV] +FilterFunc = Callable[[AutocompleteContext, Any], Union[bool, Awaitable[bool]]] -def basic_autocomplete(values: Values) -> AutocompleteFunc: +def basic_autocomplete( + values: Values, *, filter: FilterFunc | None = None +) -> AutocompleteFunc: """A helper function to make a basic autocomplete for slash commands. This is a pretty standard autocomplete and will return any options that start with the value from the user, case-insensitive. If the ``values`` parameter is callable, it will be called with the AutocompleteContext. @@ -1320,18 +1323,21 @@ def basic_autocomplete(values: Values) -> AutocompleteFunc: values: Union[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Callable[[:class:`.AutocompleteContext`], Union[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]], Awaitable[Union[Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] Possible values for the option. Accepts an iterable of :class:`str`, a callable (sync or async) that takes a single argument of :class:`.AutocompleteContext`, or a coroutine. Must resolve to an iterable of :class:`str`. + filter: Optional[Callable[[:class:`.AutocompleteContext`, Any], Union[:class:`bool`, Awaitable[:class:`bool`]]]] + An optional callable (sync or async) used to filter the autocomplete options. It accepts two arguments: + the :class:`.AutocompleteContext` and an item from ``values`` iteration treated as callback parameters. If ``None`` is provided, a default filter is used that includes items whose string representation starts with the user's input value, case-insensitive. + + .. versionadded:: 2.7 Returns ------- Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]] A wrapped callback for the autocomplete. - Note - ---- - Autocomplete cannot be used for options that have specified choices. + Examples + -------- - Example - ------- + Basic usage: .. code-block:: python3 @@ -1344,7 +1350,17 @@ async def autocomplete(ctx): Option(str, "name", autocomplete=basic_autocomplete(autocomplete)) + With filter parameter: + + .. code-block:: python3 + + Option(str, "color", autocomplete=basic_autocomplete(("red", "green", "blue"), filter=lambda c, i: str(c.value or "") in i)) + .. versionadded:: 2.0 + + Note + ---- + Autocomplete cannot be used for options that have specified choices. """ async def autocomplete_callback(ctx: AutocompleteContext) -> V: @@ -1355,11 +1371,23 @@ async def autocomplete_callback(ctx: AutocompleteContext) -> V: if asyncio.iscoroutine(_values): _values = await _values - def check(item: Any) -> bool: - item = getattr(item, "name", item) - return str(item).lower().startswith(str(ctx.value or "").lower()) + if filter is None: + + def _filter(ctx: AutocompleteContext, item: Any) -> bool: + item = getattr(item, "name", item) + return str(item).lower().startswith(str(ctx.value or "").lower()) + + gen = (val for val in _values if _filter(ctx, val)) + + elif asyncio.iscoroutinefunction(filter): + gen = (val for val in _values if await filter(ctx, val)) + + elif callable(filter): + gen = (val for val in _values if filter(ctx, val)) + + else: + raise TypeError("``filter`` must be callable.") - gen = (val for val in _values if check(val)) return iter(itertools.islice(gen, 25)) return autocomplete_callback