diff --git a/CHANGELOG.md b/CHANGELOG.md index 861f66ec0f..3735324562 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ These changes are available on the `master` branch, but have not yet been releas `Permissions.use_external_sounds`, and `Permissions.view_creator_monetization_analytics`. ([#2620](https://github.com/Pycord-Development/pycord/pull/2620)) +- Added the ability to use functions with any number of optional arguments, and + functions returning an awaitable as `Option.autocomplete` + ([#2669](https://github.com/Pycord-Development/pycord/pull/2669)). ### Fixed diff --git a/discord/commands/core.py b/discord/commands/core.py index 6dd1b0d636..846e030efc 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -1095,13 +1095,13 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): ctx.value = op.get("value") ctx.options = values - if len(inspect.signature(option.autocomplete).parameters) == 2: + if option.autocomplete._is_instance_method: instance = getattr(option.autocomplete, "__self__", ctx.cog) result = option.autocomplete(instance, ctx) else: result = option.autocomplete(ctx) - if asyncio.iscoroutinefunction(option.autocomplete): + if inspect.isawaitable(result): result = await result choices = [ diff --git a/discord/commands/options.py b/discord/commands/options.py index 4b35a080d9..8b4649a4a5 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -26,8 +26,9 @@ import inspect import logging +from collections.abc import Awaitable, Callable, Iterable from enum import Enum -from typing import TYPE_CHECKING, Literal, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Type, TypeVar, Union from ..abc import GuildChannel, Mentionable from ..channel import ( @@ -39,13 +40,14 @@ Thread, VoiceChannel, ) -from ..commands import ApplicationContext +from ..commands import ApplicationContext, AutocompleteContext from ..enums import ChannelType from ..enums import Enum as DiscordEnum from ..enums import SlashCommandOptionType from ..utils import MISSING, basic_autocomplete if TYPE_CHECKING: + from ..cog import Cog from ..ext.commands import Converter from ..member import Member from ..message import Attachment @@ -71,6 +73,25 @@ Type[DiscordEnum], ] + AutocompleteReturnType = Union[ + Iterable["OptionChoice"], Iterable[str], Iterable[int], Iterable[float] + ] + T = TypeVar("T", bound=AutocompleteReturnType) + MaybeAwaitable = Union[T, Awaitable[T]] + AutocompleteFunction = Union[ + Callable[[AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]], + Callable[[Cog, AutocompleteContext], MaybeAwaitable[AutocompleteReturnType]], + Callable[ + [AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + MaybeAwaitable[AutocompleteReturnType], + ], + Callable[ + [Cog, AutocompleteContext, Any], # pyright: ignore [reportExplicitAny] + MaybeAwaitable[AutocompleteReturnType], + ], + ] + + __all__ = ( "ThreadOption", "Option", @@ -146,15 +167,6 @@ class Option: max_length: Optional[:class:`int`] The maximum length of the string that can be entered. Must be between 1 and 6000 (inclusive). Only applies to Options with an :attr:`input_type` of :class:`str`. - autocomplete: Optional[Callable[[:class:`.AutocompleteContext`], Awaitable[Union[Iterable[:class:`.OptionChoice`], Iterable[:class:`str`], Iterable[:class:`int`], Iterable[:class:`float`]]]]] - The autocomplete handler for the option. Accepts a callable (sync or async) - that takes a single argument of :class:`AutocompleteContext`. - The callable must return an iterable of :class:`str` or :class:`OptionChoice`. - Alternatively, :func:`discord.utils.basic_autocomplete` may be used in place of the callable. - - .. note:: - - Does not validate the input value against the autocomplete results. channel_types: list[:class:`discord.ChannelType`] | None A list of channel types that can be selected in this option. Only applies to Options with an :attr:`input_type` of :class:`discord.SlashCommandOptionType.channel`. @@ -272,6 +284,7 @@ def __init__( ) self.default = kwargs.pop("default", None) + self._autocomplete: AutocompleteFunction | None = None self.autocomplete = kwargs.pop("autocomplete", None) if len(enum_choices) > 25: self.choices: list[OptionChoice] = [] @@ -390,6 +403,43 @@ def to_dict(self) -> dict: def __repr__(self): return f"" + @property + def autocomplete(self) -> AutocompleteFunction | None: + """ + The autocomplete handler for the option. Accepts a callable (sync or async) + that takes a single required argument of :class:`AutocompleteContext` or two arguments + of :class:`discord.Cog` (being the command's cog) and :class:`AutocompleteContext`. + The callable must return an iterable of :class:`str` or :class:`OptionChoice`. + Alternatively, :func:`discord.utils.basic_autocomplete` may be used in place of the callable. + + Returns + ------- + Optional[AutocompleteFunction] + + .. versionchanged:: 2.7 + + .. note:: + Does not validate the input value against the autocomplete results. + """ + return self._autocomplete + + @autocomplete.setter + def autocomplete(self, value: AutocompleteFunction | None) -> None: + self._autocomplete = value + # this is done here so it does not have to be computed every time the autocomplete is invoked + if self._autocomplete is not None: + self._autocomplete._is_instance_method = ( # pyright: ignore [reportFunctionMemberAccess] + sum( + 1 + for param in inspect.signature( + self._autocomplete + ).parameters.values() + if param.default == param.empty # pyright: ignore[reportAny] + and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + ) + == 2 + ) + class OptionChoice: """ diff --git a/examples/app_commands/slash_partial_autocomplete.py b/examples/app_commands/slash_partial_autocomplete.py new file mode 100644 index 0000000000..ca6cd1eab9 --- /dev/null +++ b/examples/app_commands/slash_partial_autocomplete.py @@ -0,0 +1,49 @@ +from functools import partial +from os import getenv + +from dotenv import load_dotenv + +import discord +from discord.ext import commands + +load_dotenv() + +bot = discord.Bot() + +fruits = ["Apple", "Banana", "Orange"] +vegetables = ["Carrot", "Lettuce", "Potato"] + + +async def food_autocomplete( + ctx: discord.AutocompleteContext, food_type: str +) -> list[discord.OptionChoice]: + items = fruits if food_type == "fruit" else vegetables + return [ + discord.OptionChoice(name=item) + for item in items + if ctx.value.lower() in item.lower() + ] + + +class FoodCog(commands.Cog): + @commands.slash_command(name="fruit") + @discord.option( + "choice", + "Pick a fruit", + autocomplete=partial(food_autocomplete, food_type="fruit"), + ) + async def get_fruit(self, ctx: discord.ApplicationContext, choice: str): + await ctx.respond(f'You picked "{choice}"') + + @commands.slash_command(name="vegetable") + @discord.option( + "choice", + "Pick a vegetable", + autocomplete=partial(food_autocomplete, food_type="vegetable"), + ) + async def get_vegetable(self, ctx: discord.ApplicationContext, choice: str): + await ctx.respond(f'You picked "{choice}"') + + +bot.add_cog(FoodCog()) +bot.run(getenv("TOKEN"))