From 67bea2fedd536c3bf66f1317f5f3e05e0f29462f Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 30 Aug 2022 22:55:13 +0200 Subject: [PATCH 01/31] refactor!: rewrite `Bot.load_extensions` and corresponding utils method --- disnake/ext/commands/common_bot_base.py | 57 ++++++++-- disnake/utils.py | 91 +++++++++++----- docs/api.rst | 2 +- tests/test_utils.py | 138 +++++++++++++++++------- 4 files changed, 215 insertions(+), 73 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 441d928f5a..5620abffaf 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -16,6 +16,7 @@ Callable, Dict, Generic, + Iterable, List, Mapping, Optional, @@ -574,18 +575,62 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: sys.modules.update(modules) raise - def load_extensions(self, path: str) -> None: - """Loads all extensions in a directory. + def load_extensions( + self, + root_module: str, + *, + package: Optional[str] = None, + ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + ) -> None: + """ + Loads all extensions in a given module, also traversing into sub-packages. + + See :func:`disnake.utils.walk_extensions` for details on how packages are found. .. versionadded:: 2.4 + .. versionchanged:: 2.6 + Now accepts a module name instead of a filesystem path. + Also improved package traversal, adding support for more complex extensions + with ``__init__.py`` files, and added ``ignore`` parameter. + Parameters ---------- - path: :class:`str` - The path to search for extensions + root_module: :class:`str` + The module/package name to search in, for example `cogs.admin`. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when ``root_module`` is relative, e.g ``.cogs.admin``. + Defaults to ``None``. + ignore: Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]] + An iterable of module names to ignore, or a callable that's used for ignoring + modules (where the callable returning ``True`` results in the module being ignored). + + See :func:`disnake.utils.walk_extensions` for details. """ - for extension in disnake.utils.search_directory(path): - self.load_extension(extension) + if "/" in root_module or "\\" in root_module: + # likely a path, try to be backwards compatible by converting to + # a relative path and using that as the module name + disnake.utils.warn_deprecated( + "Using a directory with `load_extensions` is deprecated. Use a module name (optionally with a package) instead.", + stacklevel=2, + ) + + path = os.path.relpath(root_module) + if ".." in path: + raise ImportError( + "Paths outside the cwd are not supported. Try using the module name instead." + ) + root_module = path.replace(os.sep, ".") + + if not (spec := importlib.util.find_spec(root_module, package)): + raise ImportError(f"Unable to find root module '{root_module}' in package '{package}'") + + if not (paths := spec.submodule_search_locations): + raise ImportError(f"Module '{root_module}' is not a package") + + for ext_name in disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore): + self.load_extension(ext_name) @property def extensions(self) -> Mapping[str, types.ModuleType]: diff --git a/disnake/utils.py b/disnake/utils.py index 50ba3ab3d3..8868b7f7c6 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -6,8 +6,8 @@ import asyncio import datetime import functools +import importlib import json -import os import pkgutil import re import sys @@ -69,7 +69,7 @@ "escape_mentions", "as_chunks", "format_dt", - "search_directory", + "walk_extensions", "as_valid_locale", ) @@ -1282,40 +1282,79 @@ def format_dt(dt: Union[datetime.datetime, float], /, style: TimestampStyle = "f return f"" -def search_directory(path: str) -> Iterator[str]: - """Walk through a directory and yield all modules. +def walk_extensions( + paths: Iterable[str], + prefix: str = "", + ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, +) -> Iterator[str]: + """ + Walk through the given package paths, and recursively yield modules. + + This is similar to :func:`py:pkgutil.walk_packages`, but supports ignoring + modules/packages. + + Namespace packages are not considered, meaning every package must have an + ``__init__.py`` file. If a package has a ``setup`` function, this method will + yield its name and not traverse the package further. + + Nonexistent paths are silently ignored. + + .. note:: + This imports all *packages* (not all modules) in the given path(s) + to access the ``__path__`` attribute for finding submodules, + unless they are filtered by the ``ignore`` parameter. Parameters ---------- - path: :class:`str` - The path to search for modules + paths: Iterable[:class:`str`] + The filesystem paths of packages to search in. + prefix: :class:`str` + The prefix to prepend to all module names. This should be set + accordingly to produce importable package names. + + For example, if ``paths`` contains ``/bot/cogs/admin``, this should + be set to `cogs.admin.` assuming the current working directory is ``/bot``. + ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] + An iterable of module names to ignore, or a callable that's used for ignoring + modules (where the callable returning ``True`` results in the module being ignored). + Defaults to ``None``, i.e. no modules are ignored. + + If it's an iterable, the elements must be module names. That is, + a module like ``cogs.admin.eval_cmd`` will be ignored if ``admin``, ``eval_cmd``, + or ``admin.eval_cmd`` is given, but not with ``admin.eval`` or ``cmd``, to name + a few examples. Yields - ------- + ------ :class:`str` - The name of the found module. (usable in load_extension) + The full module names in the given package paths. """ - relpath = os.path.relpath(path) # relative and normalized - if ".." in relpath: - raise ValueError("Modules outside the cwd require a package to be specified") - - abspath = os.path.abspath(path) - if not os.path.exists(relpath): - raise ValueError(f"Provided path '{abspath}' does not exist") - if not os.path.isdir(relpath): - raise ValueError(f"Provided path '{abspath}' is not a directory") - - prefix = relpath.replace(os.sep, ".") - if prefix in ("", "."): - prefix = "" - else: - prefix += "." + if isinstance(ignore, str): + raise TypeError("`ignore` must be an iterable of strings or a callable") + + if isinstance(ignore, Iterable): + ignore_parts = "|".join(re.escape(i) for i in ignore) + ignore_re = re.compile(rf"(^|\.)({ignore_parts})(\.|$)") + ignore = lambda path: ignore_re.search(path) is not None + # else, it's already a callable or None + + for _, name, ispkg in pkgutil.iter_modules(paths, prefix): + if ignore and ignore(name): + continue - for _, name, ispkg in pkgutil.iter_modules([path]): if ispkg: - yield from search_directory(os.path.join(path, name)) + mod = importlib.import_module(name) + + # if this module is a package but also has a `setup` function, + # yield it and don't look for other files in this module + if hasattr(mod, "setup"): + yield name + continue + + if sub_paths := mod.__path__: + yield from walk_extensions(sub_paths, prefix=f"{name}.", ignore=ignore) else: - yield prefix + name + yield name def as_valid_locale(locale: str) -> Optional[str]: diff --git a/docs/api.rst b/docs/api.rst index 4215491f56..9763bd3a5a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1560,7 +1560,7 @@ Utility Functions .. autofunction:: disnake.utils.as_chunks -.. autofunction:: disnake.utils.search_directory +.. autofunction:: disnake.utils.walk_extensions .. autofunction:: disnake.utils.as_valid_locale diff --git a/tests/test_utils.py b/tests/test_utils.py index a3eabf6f7c..8e78b35dde 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,6 +8,7 @@ import warnings from dataclasses import dataclass from datetime import timedelta, timezone +from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union from unittest import mock @@ -808,61 +809,118 @@ def test_format_dt(dt, style, expected): assert utils.format_dt(dt, style) == expected +def _create_dirs(parent: Path, data: Dict[str, Any]) -> None: + for name, value in data.items(): + path = parent / name + if isinstance(value, dict): + path.mkdir() + _create_dirs(path, value) + elif isinstance(value, str): + path.write_text(value) + + @pytest.fixture(scope="session") -def tmp_module_root(tmp_path_factory): - # this obviously isn't great code, but it'll do just fine for tests +def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): tmpdir = tmp_path_factory.mktemp("module_root") - for d in ["empty", "not_a_module", "mod/sub1/sub2"]: - (tmpdir / d).mkdir(parents=True) - for f in [ - "test.py", - "not_a_module/abc.py", - "mod/__init__.py", - "mod/ext.py", - "mod/sub1/sub2/__init__.py", - "mod/sub1/sub2/abc.py", - ]: - (tmpdir / f).touch() - return tmpdir + setup = "def setup(bot): ..." + _create_dirs( + tmpdir, + { + "toplevel": { + "__init__.py": "", + "nosetup.py": "", + "withsetup.py": setup, + "empty_dir": {}, + "not_a_module": {"abc.py": setup}, + "a_module": {"__init__.py": "", "abc.py": setup}, + "uncool_ext": {"__init__.py": ""}, + "cool_ext": {"__init__.py": setup}, + "mod": { + "__init__.py": "", + "ext.py": setup, + "not_a_submodule": { + "sub": {"__init__.py": setup}, + }, + "sub": { + "__init__.py": "", + "sub1": {"__init__.py": "", "abc.py": setup, "def.py": setup}, + "sub2": {"__init__.py": setup, "abc.py": setup}, + }, + }, + }, + }, + ) -@pytest.mark.parametrize( - ("path", "expected"), - [ - (".", ["test", "mod.ext"]), - ("./", ["test", "mod.ext"]), - ("empty/", []), - ], -) -def test_search_directory(tmp_module_root, path, expected): orig_cwd = os.getcwd() try: - os.chdir(tmp_module_root) - - # test relative and absolute paths - for p in [path, os.path.abspath(path)]: - assert sorted(utils.search_directory(p)) == sorted(expected) + os.chdir(tmpdir) + sys.path.insert(0, str(tmpdir)) + yield tmpdir finally: os.chdir(orig_cwd) + sys.path.remove(str(tmpdir)) @pytest.mark.parametrize( - ("path", "exc"), + ("ignore", "expected"), [ - ("../../", r"Modules outside the cwd require a package to be specified"), - ("nonexistent", r"Provided path '.*?nonexistent' does not exist"), - ("test.py", r"Provided path '.*?test.py' is not a directory"), + ( + None, + [ + "toplevel.nosetup", + "toplevel.withsetup", + "toplevel.a_module.abc", + "toplevel.cool_ext", + "toplevel.mod.ext", + "toplevel.mod.sub.sub1.abc", + "toplevel.mod.sub.sub1.def", + "toplevel.mod.sub.sub2", + ], + ), + ( + ["sub1.abc"], + [ + "toplevel.nosetup", + "toplevel.withsetup", + "toplevel.a_module.abc", + "toplevel.cool_ext", + "toplevel.mod.ext", + "toplevel.mod.sub.sub1.def", + "toplevel.mod.sub.sub2", + ], + ), + ( + ["ext", "a_module.abc", "sub.sub1"], + [ + "toplevel.nosetup", + "toplevel.withsetup", + "toplevel.cool_ext", + "toplevel.mod.sub.sub2", + ], + ), + ( + lambda name: "ext" in name, # pyright: ignore[reportUnknownLambdaType] + [ + "toplevel.nosetup", + "toplevel.withsetup", + "toplevel.a_module.abc", + "toplevel.mod.sub.sub1.abc", + "toplevel.mod.sub.sub1.def", + "toplevel.mod.sub.sub2", + ], + ), ], ) -def test_search_directory_exc(tmp_module_root, path, exc): - orig_cwd = os.getcwd() - try: - os.chdir(tmp_module_root) +def test_walk_extensions(tmp_module_root: Path, ignore, expected): + path = str(tmp_module_root / "toplevel") + assert sorted(utils.walk_extensions([path], "toplevel.", ignore)) == sorted(expected) - with pytest.raises(ValueError, match=exc): - list(utils.search_directory(tmp_module_root / path)) - finally: - os.chdir(orig_cwd) + +def test_walk_extensions_nonexistent(tmp_module_root: Path): + assert ( + list(utils.walk_extensions([str(tmp_module_root / "doesnotexist")], "doesnotexist.")) == [] + ) @pytest.mark.parametrize( From 60e76229a189e34e1b7c81fc2532cf60d868986b Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 30 Aug 2022 23:10:50 +0200 Subject: [PATCH 02/31] feat: yield loaded extension names --- disnake/ext/commands/common_bot_base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 5620abffaf..c42e52d465 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -17,6 +17,7 @@ Dict, Generic, Iterable, + Iterator, List, Mapping, Optional, @@ -581,7 +582,7 @@ def load_extensions( *, package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, - ) -> None: + ) -> Iterator[str]: """ Loads all extensions in a given module, also traversing into sub-packages. @@ -607,6 +608,11 @@ def load_extensions( modules (where the callable returning ``True`` results in the module being ignored). See :func:`disnake.utils.walk_extensions` for details. + + Yields + ------ + :class:`str` + The module names as they are being loaded. """ if "/" in root_module or "\\" in root_module: # likely a path, try to be backwards compatible by converting to @@ -631,6 +637,7 @@ def load_extensions( for ext_name in disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore): self.load_extension(ext_name) + yield ext_name @property def extensions(self) -> Mapping[str, types.ModuleType]: From 0c9182b7cb066bc017915c9fac92a8b0fa84d3b1 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 30 Aug 2022 23:11:22 +0200 Subject: [PATCH 03/31] fix: catch `ValueError` raised by `resolve_name` on 3.8 --- disnake/ext/commands/common_bot_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index c42e52d465..afd31ba463 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -425,7 +425,7 @@ def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) def _resolve_name(self, name: str, package: Optional[str]) -> str: try: return importlib.util.resolve_name(name, package) - except ImportError: + except (ValueError, ImportError): # 3.8 raises ValueError instead of ImportError raise errors.ExtensionNotFound(name) def load_extension(self, name: str, *, package: Optional[str] = None) -> None: From 8b8db8f2044bcaf9cbb95f8d4bb06fafcc6cc445 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 30 Aug 2022 23:32:28 +0200 Subject: [PATCH 04/31] fix: change exception types --- disnake/ext/commands/common_bot_base.py | 27 +++++++++++++++++++++---- disnake/utils.py | 7 +++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index afd31ba463..46fe839e74 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -588,6 +588,9 @@ def load_extensions( See :func:`disnake.utils.walk_extensions` for details on how packages are found. + This may raise any errors that :func:`load_extension` can raise, in addition to + the ones documented below. + .. versionadded:: 2.4 .. versionchanged:: 2.6 @@ -609,6 +612,17 @@ def load_extensions( See :func:`disnake.utils.walk_extensions` for details. + Raises + ------ + ExtensionNotFound + The given root module could not be found. + This is also raised if the name of the root module could not + be resolved using the provided ``package`` parameter. + + This, as well as other extension-related errors, may also be + raised as this method calls :func:`load_extension` on all found extensions. + See :func:`load_extension` for further details on raised exceptions. + Yields ------ :class:`str` @@ -624,16 +638,21 @@ def load_extensions( path = os.path.relpath(root_module) if ".." in path: - raise ImportError( + raise ValueError( "Paths outside the cwd are not supported. Try using the module name instead." ) root_module = path.replace(os.sep, ".") - if not (spec := importlib.util.find_spec(root_module, package)): - raise ImportError(f"Unable to find root module '{root_module}' in package '{package}'") + # `find_spec` already calls `resolve_name`, but we want our custom error handling here + root_module = self._resolve_name(root_module, package) + + if not (spec := importlib.util.find_spec(root_module)): + raise errors.ExtensionNotFound( + f"Unable to find root module '{root_module}' in package '{package}'" + ) if not (paths := spec.submodule_search_locations): - raise ImportError(f"Module '{root_module}' is not a package") + raise errors.ExtensionNotFound(f"Module '{root_module}' is not a package") for ext_name in disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore): self.load_extension(ext_name) diff --git a/disnake/utils.py b/disnake/utils.py index 8868b7f7c6..4d8ed4c57d 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1324,6 +1324,13 @@ def walk_extensions( or ``admin.eval_cmd`` is given, but not with ``admin.eval`` or ``cmd``, to name a few examples. + Raises + ------ + TypeError + The ``ignore`` parameter is of an invalid type. + ImportError + A package couldn't be imported. + Yields ------ :class:`str` From 5e2ada434e3db36a6a2b2aa672bd1f70054cca88 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 18:33:10 +0200 Subject: [PATCH 05/31] feat: add `return_exceptions` parameter --- disnake/ext/commands/common_bot_base.py | 57 ++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 46fe839e74..3fa2d69943 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -19,11 +19,13 @@ Iterable, Iterator, List, + Literal, Mapping, Optional, Set, TypeVar, Union, + overload, ) import disnake @@ -576,13 +578,36 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: sys.modules.update(modules) raise + @overload def load_extensions( self, root_module: str, *, package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + return_exceptions: Literal[False] = False, ) -> Iterator[str]: + ... + + @overload + def load_extensions( + self, + root_module: str, + *, + package: Optional[str] = None, + ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + return_exceptions: Literal[True], + ) -> Iterator[Union[str, errors.ExtensionError]]: + ... + + def load_extensions( + self, + root_module: str, + *, + package: Optional[str] = None, + ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + return_exceptions: bool = False, + ) -> Iterator[Union[str, errors.ExtensionError]]: """ Loads all extensions in a given module, also traversing into sub-packages. @@ -611,6 +636,11 @@ def load_extensions( modules (where the callable returning ``True`` results in the module being ignored). See :func:`disnake.utils.walk_extensions` for details. + return_exceptions: :class:`bool` + If set to ``True``, exceptions raised by the internal :func:`load_extension` calls + are yielded/returned instead of immediately propagating the first exception to the caller + (similar to :func:`py:asyncio.gather`). + Defaults to ``False``. Raises ------ @@ -619,14 +649,17 @@ def load_extensions( This is also raised if the name of the root module could not be resolved using the provided ``package`` parameter. - This, as well as other extension-related errors, may also be - raised as this method calls :func:`load_extension` on all found extensions. + ExtensionError + If ``return_exceptions=False``, other extension-related errors may also be raised + as this method calls :func:`load_extension` on all found extensions. See :func:`load_extension` for further details on raised exceptions. Yields ------ :class:`str` - The module names as they are being loaded. + The module names as they are being loaded (if ``return_exceptions=False``, the default). + Union[:class:`str`, :class:`ExtensionError`] + The module names or raised exceptions as they are being loaded (if ``return_exceptions=True``). """ if "/" in root_module or "\\" in root_module: # likely a path, try to be backwards compatible by converting to @@ -648,15 +681,27 @@ def load_extensions( if not (spec := importlib.util.find_spec(root_module)): raise errors.ExtensionNotFound( - f"Unable to find root module '{root_module}' in package '{package}'" + f"Unable to find root module '{root_module}' in package '{package or ''}'" ) if not (paths := spec.submodule_search_locations): raise errors.ExtensionNotFound(f"Module '{root_module}' is not a package") for ext_name in disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore): - self.load_extension(ext_name) - yield ext_name + try: + self.load_extension(ext_name) + except Exception as e: + # always wrap in `ExtensionError` if not already + # (this should never happen, but we're doing it just in case) + if not isinstance(e, errors.ExtensionError): + e = errors.ExtensionFailed(ext_name, e) + + if return_exceptions: + yield e + else: + raise e + else: + yield ext_name @property def extensions(self) -> Mapping[str, types.ModuleType]: From 42ccd93bec017db93dd39094aa271f9d35a01c8f Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:00:17 +0200 Subject: [PATCH 06/31] feat: discover all ext names first --- disnake/ext/commands/common_bot_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 3fa2d69943..00616c6728 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -687,7 +687,10 @@ def load_extensions( if not (paths := spec.submodule_search_locations): raise errors.ExtensionNotFound(f"Module '{root_module}' is not a package") - for ext_name in disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore): + # collect all extension names first, in case of discovery errors + exts = list(disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore)) + + for ext_name in exts: try: self.load_extension(ext_name) except Exception as e: From 82300deaed8c1c1c52fb0cdfc8f897e8d37b09d8 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:02:14 +0200 Subject: [PATCH 07/31] chore: rename `walk_extensions` to `walk_modules` --- disnake/ext/commands/common_bot_base.py | 6 +++--- disnake/utils.py | 6 +++--- docs/api.rst | 2 +- tests/test_utils.py | 10 ++++------ 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 00616c6728..99135636bb 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -611,7 +611,7 @@ def load_extensions( """ Loads all extensions in a given module, also traversing into sub-packages. - See :func:`disnake.utils.walk_extensions` for details on how packages are found. + See :func:`disnake.utils.walk_modules` for details on how packages are found. This may raise any errors that :func:`load_extension` can raise, in addition to the ones documented below. @@ -635,7 +635,7 @@ def load_extensions( An iterable of module names to ignore, or a callable that's used for ignoring modules (where the callable returning ``True`` results in the module being ignored). - See :func:`disnake.utils.walk_extensions` for details. + See :func:`disnake.utils.walk_modules` for details. return_exceptions: :class:`bool` If set to ``True``, exceptions raised by the internal :func:`load_extension` calls are yielded/returned instead of immediately propagating the first exception to the caller @@ -688,7 +688,7 @@ def load_extensions( raise errors.ExtensionNotFound(f"Module '{root_module}' is not a package") # collect all extension names first, in case of discovery errors - exts = list(disnake.utils.walk_extensions(paths, prefix=f"{spec.name}.", ignore=ignore)) + exts = list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) for ext_name in exts: try: diff --git a/disnake/utils.py b/disnake/utils.py index 4d8ed4c57d..e05f801bb9 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -69,7 +69,7 @@ "escape_mentions", "as_chunks", "format_dt", - "walk_extensions", + "walk_modules", "as_valid_locale", ) @@ -1282,7 +1282,7 @@ def format_dt(dt: Union[datetime.datetime, float], /, style: TimestampStyle = "f return f"" -def walk_extensions( +def walk_modules( paths: Iterable[str], prefix: str = "", ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, @@ -1359,7 +1359,7 @@ def walk_extensions( continue if sub_paths := mod.__path__: - yield from walk_extensions(sub_paths, prefix=f"{name}.", ignore=ignore) + yield from walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore) else: yield name diff --git a/docs/api.rst b/docs/api.rst index 9763bd3a5a..232f342248 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1560,7 +1560,7 @@ Utility Functions .. autofunction:: disnake.utils.as_chunks -.. autofunction:: disnake.utils.walk_extensions +.. autofunction:: disnake.utils.walk_modules .. autofunction:: disnake.utils.as_valid_locale diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e78b35dde..e816505bf5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -912,15 +912,13 @@ def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): ), ], ) -def test_walk_extensions(tmp_module_root: Path, ignore, expected): +def test_walk_modules(tmp_module_root: Path, ignore, expected): path = str(tmp_module_root / "toplevel") - assert sorted(utils.walk_extensions([path], "toplevel.", ignore)) == sorted(expected) + assert sorted(utils.walk_modules([path], "toplevel.", ignore)) == sorted(expected) -def test_walk_extensions_nonexistent(tmp_module_root: Path): - assert ( - list(utils.walk_extensions([str(tmp_module_root / "doesnotexist")], "doesnotexist.")) == [] - ) +def test_walk_modules_nonexistent(tmp_module_root: Path): + assert list(utils.walk_modules([str(tmp_module_root / "doesnotexist")], "doesnotexist.")) == [] @pytest.mark.parametrize( From 9c4e784b0180b2166f418bbd9211eac75ad00093 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:12:54 +0200 Subject: [PATCH 08/31] docs: improve `walk_modules` docs --- disnake/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index e05f801bb9..9830a7ed82 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1288,14 +1288,15 @@ def walk_modules( ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, ) -> Iterator[str]: """ - Walk through the given package paths, and recursively yield modules. + Walks through the given package paths, and recursively yields modules. This is similar to :func:`py:pkgutil.walk_packages`, but supports ignoring modules/packages. + If a package has a ``setup`` function, this method will + yield its name and not traverse the package further. Namespace packages are not considered, meaning every package must have an - ``__init__.py`` file. If a package has a ``setup`` function, this method will - yield its name and not traverse the package further. + ``__init__.py`` file. Nonexistent paths are silently ignored. From f3b24db43ae772617f1ae88c23983d33a0433301 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:13:38 +0200 Subject: [PATCH 09/31] fix: update `load_extensions` return type not really necessary since Iterator is covariant, but might as well do it anyway --- disnake/ext/commands/common_bot_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 99135636bb..d9af599602 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -607,7 +607,7 @@ def load_extensions( package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, return_exceptions: bool = False, - ) -> Iterator[Union[str, errors.ExtensionError]]: + ) -> Union[Iterator[str], Iterator[Union[str, errors.ExtensionError]]]: """ Loads all extensions in a given module, also traversing into sub-packages. From 33819be2ed5e12c73d1d3084a1cbb9240b0358e3 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:21:29 +0200 Subject: [PATCH 10/31] fix: make `walk_modules:ignore` less complex --- disnake/utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index 9830a7ed82..909e91f247 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1320,10 +1320,7 @@ def walk_modules( modules (where the callable returning ``True`` results in the module being ignored). Defaults to ``None``, i.e. no modules are ignored. - If it's an iterable, the elements must be module names. That is, - a module like ``cogs.admin.eval_cmd`` will be ignored if ``admin``, ``eval_cmd``, - or ``admin.eval_cmd`` is given, but not with ``admin.eval`` or ``cmd``, to name - a few examples. + If it's an iterable, module names that start with any of the given strings will be ignored. Raises ------ @@ -1341,9 +1338,8 @@ def walk_modules( raise TypeError("`ignore` must be an iterable of strings or a callable") if isinstance(ignore, Iterable): - ignore_parts = "|".join(re.escape(i) for i in ignore) - ignore_re = re.compile(rf"(^|\.)({ignore_parts})(\.|$)") - ignore = lambda path: ignore_re.search(path) is not None + ignore_tup = tuple(ignore) + ignore = lambda path: path.startswith(ignore_tup) # else, it's already a callable or None for _, name, ispkg in pkgutil.iter_modules(paths, prefix): From 2b7d78a7f3e5042811538aca9f8420a45af3dc6d Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:40:54 +0200 Subject: [PATCH 11/31] feat: avoid duplicate `__path__` entries --- disnake/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/disnake/utils.py b/disnake/utils.py index 909e91f247..8d1de46951 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1342,6 +1342,8 @@ def walk_modules( ignore = lambda path: path.startswith(ignore_tup) # else, it's already a callable or None + seen: Set[str] = set() + for _, name, ispkg in pkgutil.iter_modules(paths, prefix): if ignore and ignore(name): continue @@ -1355,7 +1357,13 @@ def walk_modules( yield name continue - if sub_paths := mod.__path__: + sub_paths: List[str] = [] + for p in mod.__path__ or []: + if p not in seen: + seen.add(p) + sub_paths.append(p) + + if sub_paths: yield from walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore) else: yield name From a09045e7deb637404e79d9582116ead207b19fec Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 19:53:11 +0200 Subject: [PATCH 12/31] fix(test): update walk_modules tests --- tests/test_utils.py | 57 ++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 34 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e816505bf5..4909a5d8c8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -827,7 +827,7 @@ def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): _create_dirs( tmpdir, { - "toplevel": { + "a": { "__init__.py": "", "nosetup.py": "", "withsetup.py": setup, @@ -868,53 +868,42 @@ def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): ( None, [ - "toplevel.nosetup", - "toplevel.withsetup", - "toplevel.a_module.abc", - "toplevel.cool_ext", - "toplevel.mod.ext", - "toplevel.mod.sub.sub1.abc", - "toplevel.mod.sub.sub1.def", - "toplevel.mod.sub.sub2", + "a.nosetup", + "a.withsetup", + "a.a_module.abc", + "a.cool_ext", + "a.mod.ext", + "a.mod.sub.sub1.abc", + "a.mod.sub.sub1.def", + "a.mod.sub.sub2", ], ), ( - ["sub1.abc"], + ["a.nosetup", "a.mod.sub.sub1.abc", "a.mod.ext"], [ - "toplevel.nosetup", - "toplevel.withsetup", - "toplevel.a_module.abc", - "toplevel.cool_ext", - "toplevel.mod.ext", - "toplevel.mod.sub.sub1.def", - "toplevel.mod.sub.sub2", - ], - ), - ( - ["ext", "a_module.abc", "sub.sub1"], - [ - "toplevel.nosetup", - "toplevel.withsetup", - "toplevel.cool_ext", - "toplevel.mod.sub.sub2", + "a.withsetup", + "a.a_module.abc", + "a.cool_ext", + "a.mod.sub.sub1.def", + "a.mod.sub.sub2", ], ), ( lambda name: "ext" in name, # pyright: ignore[reportUnknownLambdaType] [ - "toplevel.nosetup", - "toplevel.withsetup", - "toplevel.a_module.abc", - "toplevel.mod.sub.sub1.abc", - "toplevel.mod.sub.sub1.def", - "toplevel.mod.sub.sub2", + "a.nosetup", + "a.withsetup", + "a.a_module.abc", + "a.mod.sub.sub1.abc", + "a.mod.sub.sub1.def", + "a.mod.sub.sub2", ], ), ], ) def test_walk_modules(tmp_module_root: Path, ignore, expected): - path = str(tmp_module_root / "toplevel") - assert sorted(utils.walk_modules([path], "toplevel.", ignore)) == sorted(expected) + path = str(tmp_module_root / "a") + assert sorted(utils.walk_modules([path], "a.", ignore)) == sorted(expected) def test_walk_modules_nonexistent(tmp_module_root: Path): From ea069ae299ed1216c6f6bf5d0ff3db6619f83e5d Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 20:10:30 +0200 Subject: [PATCH 13/31] fix: un-deprecate `load_extensions()` --- disnake/ext/commands/common_bot_base.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index d9af599602..27fd1f54ef 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -627,9 +627,10 @@ def load_extensions( ---------- root_module: :class:`str` The module/package name to search in, for example `cogs.admin`. + Also supports paths in the current working directory. package: Optional[:class:`str`] The package name to resolve relative imports with. - This is required when ``root_module`` is relative, e.g ``.cogs.admin``. + This is required when ``root_module`` is a relative module name, e.g ``.cogs.admin``. Defaults to ``None``. ignore: Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]] An iterable of module names to ignore, or a callable that's used for ignoring @@ -662,13 +663,6 @@ def load_extensions( The module names or raised exceptions as they are being loaded (if ``return_exceptions=True``). """ if "/" in root_module or "\\" in root_module: - # likely a path, try to be backwards compatible by converting to - # a relative path and using that as the module name - disnake.utils.warn_deprecated( - "Using a directory with `load_extensions` is deprecated. Use a module name (optionally with a package) instead.", - stacklevel=2, - ) - path = os.path.relpath(root_module) if ".." in path: raise ValueError( From 2ff26beef56272ab7cd97e9bd16cb895368f030f Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 20:40:19 +0200 Subject: [PATCH 14/31] docs: improve exception documentation --- disnake/ext/commands/common_bot_base.py | 7 ++++++- disnake/utils.py | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 27fd1f54ef..a3c2b72592 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -649,11 +649,16 @@ def load_extensions( The given root module could not be found. This is also raised if the name of the root module could not be resolved using the provided ``package`` parameter. - ExtensionError If ``return_exceptions=False``, other extension-related errors may also be raised as this method calls :func:`load_extension` on all found extensions. See :func:`load_extension` for further details on raised exceptions. + ValueError + ``root_module`` is a path and outside of the cwd. + TypeError + The ``ignore`` parameter is of an invalid type. + ImportError + A package (not module) couldn't be imported. Yields ------ diff --git a/disnake/utils.py b/disnake/utils.py index 8d1de46951..0837ab3e3d 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1324,10 +1324,12 @@ def walk_modules( Raises ------ + ValueError + The ``paths`` parameter is not an iterable. TypeError The ``ignore`` parameter is of an invalid type. ImportError - A package couldn't be imported. + A package (not module) couldn't be imported. Yields ------ From be6122e0476e70a4e2d6e561763770d2423bf423 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 20:54:06 +0200 Subject: [PATCH 15/31] fix: use proper error type --- disnake/ext/commands/common_bot_base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index a3c2b72592..e7eeb8abd0 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -645,11 +645,9 @@ def load_extensions( Raises ------ - ExtensionNotFound - The given root module could not be found. - This is also raised if the name of the root module could not - be resolved using the provided ``package`` parameter. ExtensionError + The given root module could not be found, + or the name of the root module could not be resolved using the provided ``package`` parameter. If ``return_exceptions=False``, other extension-related errors may also be raised as this method calls :func:`load_extension` on all found extensions. See :func:`load_extension` for further details on raised exceptions. @@ -679,12 +677,15 @@ def load_extensions( root_module = self._resolve_name(root_module, package) if not (spec := importlib.util.find_spec(root_module)): - raise errors.ExtensionNotFound( - f"Unable to find root module '{root_module}' in package '{package or ''}'" + raise errors.ExtensionError( + f"Unable to find root module '{root_module}' in package '{package or ''}'", + name=root_module, ) if not (paths := spec.submodule_search_locations): - raise errors.ExtensionNotFound(f"Module '{root_module}' is not a package") + raise errors.ExtensionError( + f"Module '{root_module}' is not a package", name=root_module + ) # collect all extension names first, in case of discovery errors exts = list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) From 3610789ed7aa4dd48aebcd0e1c6adbcf12cd4de6 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Sun, 9 Oct 2022 20:54:24 +0200 Subject: [PATCH 16/31] feat: update `test_bot.__main__` --- test_bot/__main__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test_bot/__main__.py b/test_bot/__main__.py index 5e084fdcb3..b964463f1f 100644 --- a/test_bot/__main__.py +++ b/test_bot/__main__.py @@ -2,7 +2,6 @@ import asyncio import logging -import os import sys import traceback @@ -51,10 +50,6 @@ async def on_ready(self): ) # fmt: on - def add_cog(self, cog: commands.Cog, *, override: bool = False) -> None: - logger.info(f"Loading cog {cog.qualified_name}.") - return super().add_cog(cog, override=override) - async def on_command_error(self, ctx: commands.Context, error: commands.CommandError) -> None: msg = f"Command `{ctx.command}` failed due to `{error}`" logger.error(msg, exc_info=True) @@ -126,5 +121,6 @@ async def on_message_command_error( if __name__ == "__main__": bot = TestBot() - bot.load_extensions(os.path.join(__package__, Config.cogs_folder)) + for e in bot.load_extensions(".cogs", package=__package__): + logger.info(f"Loaded extension {e}.") bot.run(Config.token) From 460cdf3d4a1f8d4ba4e211c4070e1e9bff7b67f8 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 14:23:39 +0200 Subject: [PATCH 17/31] fix: update module not found exception --- disnake/ext/commands/common_bot_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index e7eeb8abd0..e5df4550ac 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -678,8 +678,7 @@ def load_extensions( if not (spec := importlib.util.find_spec(root_module)): raise errors.ExtensionError( - f"Unable to find root module '{root_module}' in package '{package or ''}'", - name=root_module, + f"Unable to find root module '{root_module}'", name=root_module ) if not (paths := spec.submodule_search_locations): From 30846984a98594f7d790c030f08c78d55f1aaa0f Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 14:41:50 +0200 Subject: [PATCH 18/31] feat: add `load_callback` parameter, remove iterator --- disnake/ext/commands/common_bot_base.py | 51 ++++++++++++++++--------- test_bot/__main__.py | 7 +++- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index e5df4550ac..9233b8db7c 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -17,11 +17,11 @@ Dict, Generic, Iterable, - Iterator, List, Literal, Mapping, Optional, + Sequence, Set, TypeVar, Union, @@ -585,8 +585,9 @@ def load_extensions( *, package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + load_callback: Optional[Callable[[str], None]] = None, return_exceptions: Literal[False] = False, - ) -> Iterator[str]: + ) -> Sequence[str]: ... @overload @@ -596,8 +597,9 @@ def load_extensions( *, package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + load_callback: Optional[Callable[[Union[str, errors.ExtensionError]], None]] = None, return_exceptions: Literal[True], - ) -> Iterator[Union[str, errors.ExtensionError]]: + ) -> Sequence[Union[str, errors.ExtensionError]]: ... def load_extensions( @@ -606,8 +608,11 @@ def load_extensions( *, package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + load_callback: Optional[ + Union[Callable[[str], None], Callable[[Union[str, errors.ExtensionError]], None]] + ] = None, return_exceptions: bool = False, - ) -> Union[Iterator[str], Iterator[Union[str, errors.ExtensionError]]]: + ) -> Union[Sequence[str], Sequence[Union[str, errors.ExtensionError]]]: """ Loads all extensions in a given module, also traversing into sub-packages. @@ -620,8 +625,9 @@ def load_extensions( .. versionchanged:: 2.6 Now accepts a module name instead of a filesystem path. - Also improved package traversal, adding support for more complex extensions - with ``__init__.py`` files, and added ``ignore`` parameter. + Improved package traversal, adding support for more complex extensions + with ``__init__.py`` files. + Also added ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. Parameters ---------- @@ -632,15 +638,18 @@ def load_extensions( The package name to resolve relative imports with. This is required when ``root_module`` is a relative module name, e.g ``.cogs.admin``. Defaults to ``None``. - ignore: Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]] + ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] An iterable of module names to ignore, or a callable that's used for ignoring modules (where the callable returning ``True`` results in the module being ignored). See :func:`disnake.utils.walk_modules` for details. + load_callback: Optional[Union[Callable[[:class:`str`], None], Callable[[Union[:class:`str`, :class:`ExtensionError`]], None]]] + A callback that gets invoked with the extension name when each extension gets loaded. + If ``return_exceptions=True``, also receives raised exceptions that occured while trying to load extensions. return_exceptions: :class:`bool` If set to ``True``, exceptions raised by the internal :func:`load_extension` calls - are yielded/returned instead of immediately propagating the first exception to the caller - (similar to :func:`py:asyncio.gather`). + are not immediately propagated to the caller (similar to :func:`py:asyncio.gather`). + See ``load_callback`` and the ``Raises`` and ``Returns`` sections. Defaults to ``False``. Raises @@ -658,12 +667,11 @@ def load_extensions( ImportError A package (not module) couldn't be imported. - Yields - ------ - :class:`str` - The module names as they are being loaded (if ``return_exceptions=False``, the default). - Union[:class:`str`, :class:`ExtensionError`] - The module names or raised exceptions as they are being loaded (if ``return_exceptions=True``). + Returns + ------- + Union[Sequence[:class:`str`], Sequence[Union[:class:`str`, :class:`ExtensionError`]]] + The list of module names that have been loaded + (including :class:`ExtensionError`\\s if ``return_exceptions=True``). """ if "/" in root_module or "\\" in root_module: path = os.path.relpath(root_module) @@ -689,6 +697,13 @@ def load_extensions( # collect all extension names first, in case of discovery errors exts = list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) + ret: List[Union[str, errors.ExtensionError]] = [] + + def add_result(r: Union[str, errors.ExtensionError]) -> None: + ret.append(r) + if load_callback: + load_callback(r) # type: ignore # can't assert callable parameter type + for ext_name in exts: try: self.load_extension(ext_name) @@ -699,11 +714,13 @@ def load_extensions( e = errors.ExtensionFailed(ext_name, e) if return_exceptions: - yield e + add_result(e) else: raise e else: - yield ext_name + add_result(ext_name) + + return ret @property def extensions(self) -> Mapping[str, types.ModuleType]: diff --git a/test_bot/__main__.py b/test_bot/__main__.py index b964463f1f..e340b2efa7 100644 --- a/test_bot/__main__.py +++ b/test_bot/__main__.py @@ -121,6 +121,9 @@ async def on_message_command_error( if __name__ == "__main__": bot = TestBot() - for e in bot.load_extensions(".cogs", package=__package__): - logger.info(f"Loaded extension {e}.") + bot.load_extensions( + ".cogs", + package=__package__, + load_callback=lambda e: logger.info(f"Loaded extension {e}."), + ) bot.run(Config.token) From 7690556d55e0385ecd9add3ddec68e16bf7def2b Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 15:20:41 +0200 Subject: [PATCH 19/31] refactor: move to separate find_extensions method for easier customization --- disnake/ext/commands/common_bot_base.py | 113 ++++++++++++++++-------- 1 file changed, 75 insertions(+), 38 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 9233b8db7c..0996d72e2a 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -578,6 +578,75 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: sys.modules.update(modules) raise + def find_extensions( + self, + root_module: str, + *, + package: Optional[str] = None, + ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, + ) -> List[str]: + """ + Finds all extensions in a given module, also traversing into sub-packages. + + See :func:`disnake.utils.walk_modules` for details on how packages are found. + + .. versionadded:: 2.7 + + Parameters + ---------- + root_module: :class:`str` + The module/package name to search in, for example ``cogs.admin``. + Also supports paths in the current working directory. + package: Optional[:class:`str`] + The package name to resolve relative imports with. + This is required when ``root_module`` is a relative module name, e.g ``.cogs.admin``. + Defaults to ``None``. + ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] + An iterable of module names to ignore, or a callable that's used for ignoring + modules (where the callable returning ``True`` results in the module being ignored). + + See :func:`disnake.utils.walk_modules` for details. + + Raises + ------ + ExtensionError + The given root module could not be found, + or the name of the root module could not be resolved using the provided ``package`` parameter. + ValueError + ``root_module`` is a path and outside of the cwd. + TypeError + The ``ignore`` parameter is of an invalid type. + ImportError + A package couldn't be imported. + + Returns + ------- + List[:class:`str`] + The list of full extension names. + """ + if "/" in root_module or "\\" in root_module: + path = os.path.relpath(root_module) + if ".." in path: + raise ValueError( + "Paths outside the cwd are not supported. Try using the module name instead." + ) + root_module = path.replace(os.sep, ".") + + # `find_spec` already calls `resolve_name`, but we want our custom error handling here + root_module = self._resolve_name(root_module, package) + + if not (spec := importlib.util.find_spec(root_module)): + raise errors.ExtensionError( + f"Unable to find root module '{root_module}'", name=root_module + ) + + if not (paths := spec.submodule_search_locations): + raise errors.ExtensionError( + f"Module '{root_module}' is not a package", name=root_module + ) + + return list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) + @overload def load_extensions( self, @@ -618,31 +687,23 @@ def load_extensions( See :func:`disnake.utils.walk_modules` for details on how packages are found. - This may raise any errors that :func:`load_extension` can raise, in addition to - the ones documented below. - .. versionadded:: 2.4 - .. versionchanged:: 2.6 + .. versionchanged:: 2.7 Now accepts a module name instead of a filesystem path. Improved package traversal, adding support for more complex extensions with ``__init__.py`` files. - Also added ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. + Also added ``package``, ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. Parameters ---------- root_module: :class:`str` - The module/package name to search in, for example `cogs.admin`. + The module/package name to search in, for example ``cogs.admin``. Also supports paths in the current working directory. package: Optional[:class:`str`] - The package name to resolve relative imports with. - This is required when ``root_module`` is a relative module name, e.g ``.cogs.admin``. - Defaults to ``None``. + See :func:`find_extensions`. ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] - An iterable of module names to ignore, or a callable that's used for ignoring - modules (where the callable returning ``True`` results in the module being ignored). - - See :func:`disnake.utils.walk_modules` for details. + See :func:`find_extensions`. load_callback: Optional[Union[Callable[[:class:`str`], None], Callable[[Union[:class:`str`, :class:`ExtensionError`]], None]]] A callback that gets invoked with the extension name when each extension gets loaded. If ``return_exceptions=True``, also receives raised exceptions that occured while trying to load extensions. @@ -673,30 +734,6 @@ def load_extensions( The list of module names that have been loaded (including :class:`ExtensionError`\\s if ``return_exceptions=True``). """ - if "/" in root_module or "\\" in root_module: - path = os.path.relpath(root_module) - if ".." in path: - raise ValueError( - "Paths outside the cwd are not supported. Try using the module name instead." - ) - root_module = path.replace(os.sep, ".") - - # `find_spec` already calls `resolve_name`, but we want our custom error handling here - root_module = self._resolve_name(root_module, package) - - if not (spec := importlib.util.find_spec(root_module)): - raise errors.ExtensionError( - f"Unable to find root module '{root_module}'", name=root_module - ) - - if not (paths := spec.submodule_search_locations): - raise errors.ExtensionError( - f"Module '{root_module}' is not a package", name=root_module - ) - - # collect all extension names first, in case of discovery errors - exts = list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) - ret: List[Union[str, errors.ExtensionError]] = [] def add_result(r: Union[str, errors.ExtensionError]) -> None: @@ -704,7 +741,7 @@ def add_result(r: Union[str, errors.ExtensionError]) -> None: if load_callback: load_callback(r) # type: ignore # can't assert callable parameter type - for ext_name in exts: + for ext_name in self.find_extensions(root_module, package=package, ignore=ignore): try: self.load_extension(ext_name) except Exception as e: From 08033a7410ae806741a146ccbcddd56c6b068879 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 15:31:32 +0200 Subject: [PATCH 20/31] chore: use list instead of sequence --- disnake/ext/commands/common_bot_base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 0996d72e2a..38e3f0a026 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -21,7 +21,6 @@ Literal, Mapping, Optional, - Sequence, Set, TypeVar, Union, @@ -656,7 +655,7 @@ def load_extensions( ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, load_callback: Optional[Callable[[str], None]] = None, return_exceptions: Literal[False] = False, - ) -> Sequence[str]: + ) -> List[str]: ... @overload @@ -668,7 +667,7 @@ def load_extensions( ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, load_callback: Optional[Callable[[Union[str, errors.ExtensionError]], None]] = None, return_exceptions: Literal[True], - ) -> Sequence[Union[str, errors.ExtensionError]]: + ) -> List[Union[str, errors.ExtensionError]]: ... def load_extensions( @@ -681,7 +680,7 @@ def load_extensions( Union[Callable[[str], None], Callable[[Union[str, errors.ExtensionError]], None]] ] = None, return_exceptions: bool = False, - ) -> Union[Sequence[str], Sequence[Union[str, errors.ExtensionError]]]: + ) -> Union[List[str], List[Union[str, errors.ExtensionError]]]: """ Loads all extensions in a given module, also traversing into sub-packages. @@ -730,7 +729,7 @@ def load_extensions( Returns ------- - Union[Sequence[:class:`str`], Sequence[Union[:class:`str`, :class:`ExtensionError`]]] + Union[List[:class:`str`], List[Union[:class:`str`, :class:`ExtensionError`]]] The list of module names that have been loaded (including :class:`ExtensionError`\\s if ``return_exceptions=True``). """ From 2f28f207142a40b58a02b021b98dd2a1b95c748e Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 15:35:25 +0200 Subject: [PATCH 21/31] docs: mention find_extensions for customization --- disnake/ext/commands/common_bot_base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 38e3f0a026..19ebfb6659 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -694,6 +694,15 @@ def load_extensions( with ``__init__.py`` files. Also added ``package``, ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. + .. note:: + For further customization, you may use :func:`find_extensions`: + + .. code-block:: python3 + + for extension_name in bot.find_extensions(...): + ... # custom logic + bot.load_extension(extension_name) + Parameters ---------- root_module: :class:`str` From b402320be64a8c40bb39c5c360ec45670645e2f0 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 15:49:48 +0200 Subject: [PATCH 22/31] chore(test): move directory utils to separate file --- tests/helpers.py | 29 ++++++++++++++++++++++++++++- tests/test_utils.py | 21 ++------------------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index a8f7479feb..1ed5ff54fb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: MIT import asyncio +import contextlib import datetime import functools +import os +import sys import types -from typing import Callable, ContextManager, Optional, Type, TypeVar +from pathlib import Path +from typing import Any, Callable, ContextManager, Dict, Iterator, Optional, Type, TypeVar, Union from unittest import mock CallableT = TypeVar("CallableT", bound=Callable) @@ -59,3 +63,26 @@ def wrap_sync(*args, **kwargs): return func(*args, **kwargs) return wrap_sync # type: ignore + + +def create_dirs(parent: Union[str, Path[str]], data: Dict[str, Any]) -> None: + parent = Path(parent) if isinstance(parent, str) else parent + for name, value in data.items(): + path = parent / name + if isinstance(value, dict): + path.mkdir() + create_dirs(path, value) + elif isinstance(value, str): + path.write_text(value) + + +@contextlib.contextmanager +def chdir_module(path: Union[str, Path[str]]) -> Iterator[None]: + orig_cwd = os.getcwd() + try: + os.chdir(path) + sys.path.insert(0, str(path)) + yield + finally: + os.chdir(orig_cwd) + sys.path.remove(str(path)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4909a5d8c8..634d2b47a6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,6 @@ import asyncio import datetime import inspect -import os import sys import warnings from dataclasses import dataclass @@ -809,22 +808,12 @@ def test_format_dt(dt, style, expected): assert utils.format_dt(dt, style) == expected -def _create_dirs(parent: Path, data: Dict[str, Any]) -> None: - for name, value in data.items(): - path = parent / name - if isinstance(value, dict): - path.mkdir() - _create_dirs(path, value) - elif isinstance(value, str): - path.write_text(value) - - @pytest.fixture(scope="session") def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): tmpdir = tmp_path_factory.mktemp("module_root") setup = "def setup(bot): ..." - _create_dirs( + helpers.create_dirs( tmpdir, { "a": { @@ -852,14 +841,8 @@ def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): }, ) - orig_cwd = os.getcwd() - try: - os.chdir(tmpdir) - sys.path.insert(0, str(tmpdir)) + with helpers.chdir_module(tmpdir): yield tmpdir - finally: - os.chdir(orig_cwd) - sys.path.remove(str(tmpdir)) @pytest.mark.parametrize( From e83249ca8c21066d2fa314decb17ee61e072051b Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 16:23:27 +0200 Subject: [PATCH 23/31] test: add some simple tests --- tests/ext/__init__.py | 1 + tests/ext/commands/__init__.py | 1 + tests/ext/commands/test_common_bot_base.py | 45 ++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 tests/ext/__init__.py create mode 100644 tests/ext/commands/__init__.py create mode 100644 tests/ext/commands/test_common_bot_base.py diff --git a/tests/ext/__init__.py b/tests/ext/__init__.py new file mode 100644 index 0000000000..548d2d447d --- /dev/null +++ b/tests/ext/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: MIT diff --git a/tests/ext/commands/__init__.py b/tests/ext/commands/__init__.py new file mode 100644 index 0000000000..548d2d447d --- /dev/null +++ b/tests/ext/commands/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: MIT diff --git a/tests/ext/commands/test_common_bot_base.py b/tests/ext/commands/test_common_bot_base.py new file mode 100644 index 0000000000..ffc7e62fc3 --- /dev/null +++ b/tests/ext/commands/test_common_bot_base.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: MIT + +import asyncio +from pathlib import Path +from typing import Iterator +from unittest import mock + +import pytest + +from disnake.ext.commands import errors +from disnake.ext.commands.common_bot_base import CommonBotBase + +from ... import helpers + + +class TestExtensions: + @pytest.fixture() + def module_root(self, tmpdir: Path) -> Iterator[str]: + with helpers.chdir_module(tmpdir): + yield str(tmpdir) + + @pytest.fixture() + def bot(self): + with mock.patch.object(asyncio, "get_event_loop", mock.Mock()), mock.patch.object( + CommonBotBase, "_fill_owners", mock.Mock() + ): + bot = CommonBotBase() + return bot + + def test_find_path_invalid(self, bot: CommonBotBase): + with pytest.raises(ValueError, match=r"Paths outside the cwd are not supported"): + bot.find_extensions("../../etc/passwd") + + def test_find(self, bot: CommonBotBase, module_root: str): + helpers.create_dirs(module_root, {"test_cogs": {"__init__.py": "", "admin.py": ""}}) + + assert bot.find_extensions("test_cogs") + + with pytest.raises(errors.ExtensionError, match=r"Unable to find root module 'other_cogs'"): + bot.find_extensions("other_cogs") + + with pytest.raises( + errors.ExtensionError, match=r"Module 'test_cogs.admin' is not a package" + ): + bot.find_extensions(".admin", package="test_cogs") From 3be2be58714210b44aaf28458d1f9edc3d542f2e Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 16:25:10 +0200 Subject: [PATCH 24/31] fix: python 3.8 moment --- tests/helpers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 1ed5ff54fb..27f225d5c8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: MIT +from __future__ import annotations + import asyncio import contextlib import datetime @@ -65,7 +67,7 @@ def wrap_sync(*args, **kwargs): return wrap_sync # type: ignore -def create_dirs(parent: Union[str, Path[str]], data: Dict[str, Any]) -> None: +def create_dirs(parent: Union[str, Path], data: Dict[str, Any]) -> None: parent = Path(parent) if isinstance(parent, str) else parent for name, value in data.items(): path = parent / name @@ -77,7 +79,7 @@ def create_dirs(parent: Union[str, Path[str]], data: Dict[str, Any]) -> None: @contextlib.contextmanager -def chdir_module(path: Union[str, Path[str]]) -> Iterator[None]: +def chdir_module(path: Union[str, Path]) -> Iterator[None]: orig_cwd = os.getcwd() try: os.chdir(path) From 413a5a3ad897f064585d7bb4a64ab18a068c3f89 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Mon, 10 Oct 2022 18:04:51 +0200 Subject: [PATCH 25/31] docs: add changelog entries --- changelog/796.breaking.rst | 1 + changelog/796.feature.rst | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog/796.breaking.rst create mode 100644 changelog/796.feature.rst diff --git a/changelog/796.breaking.rst b/changelog/796.breaking.rst new file mode 100644 index 0000000000..a6cf043ab5 --- /dev/null +++ b/changelog/796.breaking.rst @@ -0,0 +1 @@ +Remove :func:`disnake.utils.search_directory` in favor of :func:`disnake.utils.walk_modules`. diff --git a/changelog/796.feature.rst b/changelog/796.feature.rst new file mode 100644 index 0000000000..1eba022400 --- /dev/null +++ b/changelog/796.feature.rst @@ -0,0 +1,3 @@ +|commands| Improve :func:`Bot.load_extensions `, add :func:`Bot.find_extensions `. +- Better support for more complex extension hierarchies. +- New ``package``, ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. From c6a779b9711d7806d4007701536f01295de4d304 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 11 Oct 2022 12:35:26 +0200 Subject: [PATCH 26/31] revert: remove `return_exceptions`, unnecessary complexity --- changelog/796.feature.rst | 2 +- disnake/ext/commands/common_bot_base.py | 74 ++++--------------------- 2 files changed, 12 insertions(+), 64 deletions(-) diff --git a/changelog/796.feature.rst b/changelog/796.feature.rst index 1eba022400..52a9af6221 100644 --- a/changelog/796.feature.rst +++ b/changelog/796.feature.rst @@ -1,3 +1,3 @@ |commands| Improve :func:`Bot.load_extensions `, add :func:`Bot.find_extensions `. - Better support for more complex extension hierarchies. -- New ``package``, ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. +- New ``package``, ``ignore``, and ``load_callback`` parameters. diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 19ebfb6659..14c139ceed 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -18,13 +18,11 @@ Generic, Iterable, List, - Literal, Mapping, Optional, Set, TypeVar, Union, - overload, ) import disnake @@ -646,7 +644,6 @@ def find_extensions( return list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) - @overload def load_extensions( self, root_module: str, @@ -654,32 +651,6 @@ def load_extensions( package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, load_callback: Optional[Callable[[str], None]] = None, - return_exceptions: Literal[False] = False, - ) -> List[str]: - ... - - @overload - def load_extensions( - self, - root_module: str, - *, - package: Optional[str] = None, - ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, - load_callback: Optional[Callable[[Union[str, errors.ExtensionError]], None]] = None, - return_exceptions: Literal[True], - ) -> List[Union[str, errors.ExtensionError]]: - ... - - def load_extensions( - self, - root_module: str, - *, - package: Optional[str] = None, - ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, - load_callback: Optional[ - Union[Callable[[str], None], Callable[[Union[str, errors.ExtensionError]], None]] - ] = None, - return_exceptions: bool = False, ) -> Union[List[str], List[Union[str, errors.ExtensionError]]]: """ Loads all extensions in a given module, also traversing into sub-packages. @@ -692,7 +663,7 @@ def load_extensions( Now accepts a module name instead of a filesystem path. Improved package traversal, adding support for more complex extensions with ``__init__.py`` files. - Also added ``package``, ``ignore``, ``load_callback`` and ``return_exceptions`` parameters. + Also added ``package``, ``ignore``, and ``load_callback`` parameters. .. note:: For further customization, you may use :func:`find_extensions`: @@ -706,27 +677,20 @@ def load_extensions( Parameters ---------- root_module: :class:`str` - The module/package name to search in, for example ``cogs.admin``. - Also supports paths in the current working directory. + See :func:`find_extensions`. package: Optional[:class:`str`] See :func:`find_extensions`. ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] See :func:`find_extensions`. - load_callback: Optional[Union[Callable[[:class:`str`], None], Callable[[Union[:class:`str`, :class:`ExtensionError`]], None]]] + load_callback: Optional[Callable[[:class:`str`], None]] A callback that gets invoked with the extension name when each extension gets loaded. - If ``return_exceptions=True``, also receives raised exceptions that occured while trying to load extensions. - return_exceptions: :class:`bool` - If set to ``True``, exceptions raised by the internal :func:`load_extension` calls - are not immediately propagated to the caller (similar to :func:`py:asyncio.gather`). - See ``load_callback`` and the ``Raises`` and ``Returns`` sections. - Defaults to ``False``. Raises ------ ExtensionError The given root module could not be found, or the name of the root module could not be resolved using the provided ``package`` parameter. - If ``return_exceptions=False``, other extension-related errors may also be raised + Other extension-related errors may also be raised as this method calls :func:`load_extension` on all found extensions. See :func:`load_extension` for further details on raised exceptions. ValueError @@ -738,32 +702,16 @@ def load_extensions( Returns ------- - Union[List[:class:`str`], List[Union[:class:`str`, :class:`ExtensionError`]]] - The list of module names that have been loaded - (including :class:`ExtensionError`\\s if ``return_exceptions=True``). + List[:class:`str`] + The list of module names that have been loaded. """ - ret: List[Union[str, errors.ExtensionError]] = [] - - def add_result(r: Union[str, errors.ExtensionError]) -> None: - ret.append(r) - if load_callback: - load_callback(r) # type: ignore # can't assert callable parameter type + ret: List[str] = [] for ext_name in self.find_extensions(root_module, package=package, ignore=ignore): - try: - self.load_extension(ext_name) - except Exception as e: - # always wrap in `ExtensionError` if not already - # (this should never happen, but we're doing it just in case) - if not isinstance(e, errors.ExtensionError): - e = errors.ExtensionFailed(ext_name, e) - - if return_exceptions: - add_result(e) - else: - raise e - else: - add_result(ext_name) + self.load_extension(ext_name) + ret.append(ext_name) + if load_callback: + load_callback(ext_name) return ret From e159c7ec13d8a788f9f2af65142f5a97234f4e91 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 11 Oct 2022 16:51:01 +0200 Subject: [PATCH 27/31] chore: move to separate documentation --- changelog/796.breaking.rst | 2 +- disnake/ext/commands/common_bot_base.py | 14 ++++-- disnake/utils.py | 55 ++---------------------- docs/api.rst | 2 - docs/ext/commands/extensions.rst | 57 +++++++++++++++++++++++++ tests/test_utils.py | 4 +- 6 files changed, 73 insertions(+), 61 deletions(-) diff --git a/changelog/796.breaking.rst b/changelog/796.breaking.rst index a6cf043ab5..73a291d1cb 100644 --- a/changelog/796.breaking.rst +++ b/changelog/796.breaking.rst @@ -1 +1 @@ -Remove :func:`disnake.utils.search_directory` in favor of :func:`disnake.utils.walk_modules`. +Remove :func:`disnake.utils.search_directory` in favor of :func:`disnake.ext.commands.Bot.find_extensions`. diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 14c139ceed..ad6f55bbcf 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -585,10 +585,15 @@ def find_extensions( """ Finds all extensions in a given module, also traversing into sub-packages. - See :func:`disnake.utils.walk_modules` for details on how packages are found. + See :ref:`ext_commands_extensions_load` for details on how packages are found. .. versionadded:: 2.7 + .. note:: + This imports all *packages* (not all modules) in the given path(s) + to access the ``__path__`` attribute for finding submodules, + unless they are filtered by the ``ignore`` parameter. + Parameters ---------- root_module: :class:`str` @@ -601,8 +606,9 @@ def find_extensions( ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] An iterable of module names to ignore, or a callable that's used for ignoring modules (where the callable returning ``True`` results in the module being ignored). + Defaults to ``None``, i.e. no modules are ignored. - See :func:`disnake.utils.walk_modules` for details. + If it's an iterable, module names that start with any of the given strings will be ignored. Raises ------ @@ -642,7 +648,7 @@ def find_extensions( f"Module '{root_module}' is not a package", name=root_module ) - return list(disnake.utils.walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) + return list(disnake.utils._walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) def load_extensions( self, @@ -655,7 +661,7 @@ def load_extensions( """ Loads all extensions in a given module, also traversing into sub-packages. - See :func:`disnake.utils.walk_modules` for details on how packages are found. + See :func:`find_extensions` for details. .. versionadded:: 2.4 diff --git a/disnake/utils.py b/disnake/utils.py index 0837ab3e3d..cc55a1b5c1 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -69,7 +69,6 @@ "escape_mentions", "as_chunks", "format_dt", - "walk_modules", "as_valid_locale", ) @@ -1282,60 +1281,12 @@ def format_dt(dt: Union[datetime.datetime, float], /, style: TimestampStyle = "f return f"" -def walk_modules( +# this is similar to pkgutil.walk_packages, but with a few modifications +def _walk_modules( paths: Iterable[str], prefix: str = "", ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, ) -> Iterator[str]: - """ - Walks through the given package paths, and recursively yields modules. - - This is similar to :func:`py:pkgutil.walk_packages`, but supports ignoring - modules/packages. - If a package has a ``setup`` function, this method will - yield its name and not traverse the package further. - - Namespace packages are not considered, meaning every package must have an - ``__init__.py`` file. - - Nonexistent paths are silently ignored. - - .. note:: - This imports all *packages* (not all modules) in the given path(s) - to access the ``__path__`` attribute for finding submodules, - unless they are filtered by the ``ignore`` parameter. - - Parameters - ---------- - paths: Iterable[:class:`str`] - The filesystem paths of packages to search in. - prefix: :class:`str` - The prefix to prepend to all module names. This should be set - accordingly to produce importable package names. - - For example, if ``paths`` contains ``/bot/cogs/admin``, this should - be set to `cogs.admin.` assuming the current working directory is ``/bot``. - ignore: Optional[Union[Iterable[:class:`str`], Callable[[:class:`str`], :class:`bool`]]] - An iterable of module names to ignore, or a callable that's used for ignoring - modules (where the callable returning ``True`` results in the module being ignored). - Defaults to ``None``, i.e. no modules are ignored. - - If it's an iterable, module names that start with any of the given strings will be ignored. - - Raises - ------ - ValueError - The ``paths`` parameter is not an iterable. - TypeError - The ``ignore`` parameter is of an invalid type. - ImportError - A package (not module) couldn't be imported. - - Yields - ------ - :class:`str` - The full module names in the given package paths. - """ if isinstance(ignore, str): raise TypeError("`ignore` must be an iterable of strings or a callable") @@ -1366,7 +1317,7 @@ def walk_modules( sub_paths.append(p) if sub_paths: - yield from walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore) + yield from _walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore) else: yield name diff --git a/docs/api.rst b/docs/api.rst index 232f342248..2cd662ef8f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1560,8 +1560,6 @@ Utility Functions .. autofunction:: disnake.utils.as_chunks -.. autofunction:: disnake.utils.walk_modules - .. autofunction:: disnake.utils.as_valid_locale .. _discord-api-enums: diff --git a/docs/ext/commands/extensions.rst b/docs/ext/commands/extensions.rst index b0c178a920..fbb76d090a 100644 --- a/docs/ext/commands/extensions.rst +++ b/docs/ext/commands/extensions.rst @@ -64,3 +64,60 @@ Although rare, sometimes an extension needs to clean-up or know when it's being def teardown(bot): print('I am being unloaded!') + +.. _ext_commands_extensions_load: + +Loading multiple extensions +----------------------------- + +Commonly, you might have a package/folder that contains several modules. +Instead of manually loading them one by one, you can use :meth:`.Bot.load_extensions` to load the entire package in one sweep. + +Consider the following directory structure: + +.. code-block:: + + my_bot/ + ├─── cogs/ + │ ├─── admin.py + │ ├─── fun.py + │ └─── other_complex_thing/ + │ ├─── __init__.py (contains setup) + │ ├─── data.py + │ └─── models.py + └─── main.py + +Now, you could call :meth:`.Bot.load_extension` separately on ``cogs.admin``, ``cogs.fun``, and ``cogs.other_complex_thing``; +however, if you add a new extension, you'd need to once again load it separately. + +Instead, you can use ``bot.load_extensions("my_bot.cogs")`` (or ``.load_extensions(".cogs", package=__package__)``) +to load all of them automatically, without any extra work required. + +Customization ++++++++++++++++ + +To adjust the loading process, for example to handle exceptions that may occur, use :meth:`.Bot.find_extensions`. +This is also what :meth:`.Bot.load_extensions` uses internally. + +As an example, one could load extensions like this: + +.. code-block:: python3 + + for extension in bot.find_extensions("my_bot.cogs"): + try: + bot.load_extension(extension) + except commands.ExtensionError as e: + logger.warning(f"Failed to load extension {extension}: {e}") + +Discovery ++++++++++++ + +:meth:`.Bot.find_extensions` (and by extension, :meth:`.Bot.load_extensions`) discover modules/extensions +similar to :func:`py:pkgutil.walk_packages`; the given root package name is resolved, +and submodules/-packages are iterated through recursively. + +If a package has a ``setup`` function (similar to ``my_bot.cogs.other_complex_thing`` above), +it won't be traversed further, i.e. ``data.py`` and ``models.py`` in the example won't be considered separate extensions. + +Namespace packages (see `PEP 420 `__) are not supported (other than +the provided root package), meaning every subpackage must have an ``__init__.py`` file. diff --git a/tests/test_utils.py b/tests/test_utils.py index 634d2b47a6..276d2e0291 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -886,11 +886,11 @@ def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): ) def test_walk_modules(tmp_module_root: Path, ignore, expected): path = str(tmp_module_root / "a") - assert sorted(utils.walk_modules([path], "a.", ignore)) == sorted(expected) + assert sorted(utils._walk_modules([path], "a.", ignore)) == sorted(expected) def test_walk_modules_nonexistent(tmp_module_root: Path): - assert list(utils.walk_modules([str(tmp_module_root / "doesnotexist")], "doesnotexist.")) == [] + assert list(utils._walk_modules([str(tmp_module_root / "doesnotexist")], "doesnotexist.")) == [] @pytest.mark.parametrize( From 1a3b309ba72d022d1df368fdee5829a35abdfccb Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 11 Oct 2022 17:27:40 +0200 Subject: [PATCH 28/31] fix: reinstate `utils.search_directory` for now --- changelog/796.breaking.rst | 1 - changelog/796.deprecate.rst | 1 + disnake/utils.py | 41 +++++++++++++++++++++++++++++++++++++ docs/api.rst | 2 ++ 4 files changed, 44 insertions(+), 1 deletion(-) delete mode 100644 changelog/796.breaking.rst create mode 100644 changelog/796.deprecate.rst diff --git a/changelog/796.breaking.rst b/changelog/796.breaking.rst deleted file mode 100644 index 73a291d1cb..0000000000 --- a/changelog/796.breaking.rst +++ /dev/null @@ -1 +0,0 @@ -Remove :func:`disnake.utils.search_directory` in favor of :func:`disnake.ext.commands.Bot.find_extensions`. diff --git a/changelog/796.deprecate.rst b/changelog/796.deprecate.rst new file mode 100644 index 0000000000..1b24b620bf --- /dev/null +++ b/changelog/796.deprecate.rst @@ -0,0 +1 @@ +:func:`disnake.utils.search_directory` will be removed in a future version, in favor of :func:`disnake.ext.commands.Bot.find_extensions` which is the most common usecase and is more consistent. diff --git a/disnake/utils.py b/disnake/utils.py index cc55a1b5c1..376ba87cfa 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -8,6 +8,7 @@ import functools import importlib import json +import os import pkgutil import re import sys @@ -69,6 +70,7 @@ "escape_mentions", "as_chunks", "format_dt", + "search_directory", "as_valid_locale", ) @@ -1281,6 +1283,45 @@ def format_dt(dt: Union[datetime.datetime, float], /, style: TimestampStyle = "f return f"" +@deprecated("disnake.ext.commands.Bot.find_extensions") +def search_directory(path: str) -> Iterator[str]: + """Walk through a directory and yield all modules. + + .. deprecated:: 2.7 + + Parameters + ---------- + path: :class:`str` + The path to search for modules + + Yields + ------- + :class:`str` + The name of the found module. (usable in load_extension) + """ + relpath = os.path.relpath(path) # relative and normalized + if ".." in relpath: + raise ValueError("Modules outside the cwd require a package to be specified") + + abspath = os.path.abspath(path) + if not os.path.exists(relpath): + raise ValueError(f"Provided path '{abspath}' does not exist") + if not os.path.isdir(relpath): + raise ValueError(f"Provided path '{abspath}' is not a directory") + + prefix = relpath.replace(os.sep, ".") + if prefix in ("", "."): + prefix = "" + else: + prefix += "." + + for _, name, ispkg in pkgutil.iter_modules([path]): + if ispkg: + yield from search_directory(os.path.join(path, name)) + else: + yield prefix + name + + # this is similar to pkgutil.walk_packages, but with a few modifications def _walk_modules( paths: Iterable[str], diff --git a/docs/api.rst b/docs/api.rst index 2cd662ef8f..4215491f56 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1560,6 +1560,8 @@ Utility Functions .. autofunction:: disnake.utils.as_chunks +.. autofunction:: disnake.utils.search_directory + .. autofunction:: disnake.utils.as_valid_locale .. _discord-api-enums: From bece2cf86ac16782f54470e1e7eb9a10726bb682 Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 15 Nov 2022 11:58:14 +0100 Subject: [PATCH 29/31] feat: use tuple instead of list in `find_extensions` --- disnake/ext/commands/common_bot_base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 054b0ab352..7385b6c432 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -20,6 +20,7 @@ List, Mapping, Optional, + Sequence, Set, TypeVar, Union, @@ -583,7 +584,7 @@ def find_extensions( *, package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, - ) -> List[str]: + ) -> Sequence[str]: """ Finds all extensions in a given module, also traversing into sub-packages. @@ -626,7 +627,7 @@ def find_extensions( Returns ------- - List[:class:`str`] + Sequence[:class:`str`] The list of full extension names. """ if "/" in root_module or "\\" in root_module: @@ -650,7 +651,7 @@ def find_extensions( f"Module '{root_module}' is not a package", name=root_module ) - return list(disnake.utils._walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) + return tuple(disnake.utils._walk_modules(paths, prefix=f"{spec.name}.", ignore=ignore)) def load_extensions( self, From b008329c83ad12009c976f28bb73a132b30dc41c Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 15 Nov 2022 11:59:24 +0100 Subject: [PATCH 30/31] chore: invert condition to save one indent level --- disnake/utils.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/disnake/utils.py b/disnake/utils.py index 33e91e5771..9bac68f484 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1346,25 +1346,27 @@ def _walk_modules( if ignore and ignore(name): continue - if ispkg: - mod = importlib.import_module(name) - - # if this module is a package but also has a `setup` function, - # yield it and don't look for other files in this module - if hasattr(mod, "setup"): - yield name - continue - - sub_paths: List[str] = [] - for p in mod.__path__ or []: - if p not in seen: - seen.add(p) - sub_paths.append(p) - - if sub_paths: - yield from _walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore) - else: + if not ispkg: yield name + continue + + # it's a package here + mod = importlib.import_module(name) + + # if this module is a package but also has a `setup` function, + # yield it and don't look for other files in this module + if hasattr(mod, "setup"): + yield name + continue + + sub_paths: List[str] = [] + for p in mod.__path__ or []: + if p not in seen: + seen.add(p) + sub_paths.append(p) + + if sub_paths: + yield from _walk_modules(sub_paths, prefix=f"{name}.", ignore=ignore) def as_valid_locale(locale: str) -> Optional[str]: From 4ed28a662c713baddcbaf0e12986a7a5ea8157ed Mon Sep 17 00:00:00 2001 From: shiftinv Date: Tue, 4 Apr 2023 15:34:53 +0200 Subject: [PATCH 31/31] chore: resolve lint issues --- disnake/ext/commands/common_bot_base.py | 6 ++---- test_bot/__main__.py | 2 +- tests/ext/commands/test_common_bot_base.py | 8 ++++---- tests/test_utils.py | 4 ++-- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index 8d0d105496..8135377662 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -632,8 +632,7 @@ def find_extensions( package: Optional[str] = None, ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, ) -> Sequence[str]: - """ - Finds all extensions in a given module, also traversing into sub-packages. + """Finds all extensions in a given module, also traversing into sub-packages. See :ref:`ext_commands_extensions_load` for details on how packages are found. @@ -708,8 +707,7 @@ def load_extensions( ignore: Optional[Union[Iterable[str], Callable[[str], bool]]] = None, load_callback: Optional[Callable[[str], None]] = None, ) -> Union[List[str], List[Union[str, errors.ExtensionError]]]: - """ - Loads all extensions in a given module, also traversing into sub-packages. + """Loads all extensions in a given module, also traversing into sub-packages. See :func:`find_extensions` for details. diff --git a/test_bot/__main__.py b/test_bot/__main__.py index 47c0466adc..7058695edf 100644 --- a/test_bot/__main__.py +++ b/test_bot/__main__.py @@ -124,6 +124,6 @@ async def on_message_command_error( bot.load_extensions( ".cogs", package=__package__, - load_callback=lambda e: logger.info(f"Loaded extension {e}."), + load_callback=lambda e: logger.info("Loaded extension %s.", e), ) bot.run(Config.token) diff --git a/tests/ext/commands/test_common_bot_base.py b/tests/ext/commands/test_common_bot_base.py index ffc7e62fc3..fc91ccc3fa 100644 --- a/tests/ext/commands/test_common_bot_base.py +++ b/tests/ext/commands/test_common_bot_base.py @@ -14,12 +14,12 @@ class TestExtensions: - @pytest.fixture() + @pytest.fixture def module_root(self, tmpdir: Path) -> Iterator[str]: with helpers.chdir_module(tmpdir): yield str(tmpdir) - @pytest.fixture() + @pytest.fixture def bot(self): with mock.patch.object(asyncio, "get_event_loop", mock.Mock()), mock.patch.object( CommonBotBase, "_fill_owners", mock.Mock() @@ -27,11 +27,11 @@ def bot(self): bot = CommonBotBase() return bot - def test_find_path_invalid(self, bot: CommonBotBase): + def test_find_path_invalid(self, bot: CommonBotBase) -> None: with pytest.raises(ValueError, match=r"Paths outside the cwd are not supported"): bot.find_extensions("../../etc/passwd") - def test_find(self, bot: CommonBotBase, module_root: str): + def test_find(self, bot: CommonBotBase, module_root: str) -> None: helpers.create_dirs(module_root, {"test_cogs": {"__init__.py": "", "admin.py": ""}}) assert bot.find_extensions("test_cogs") diff --git a/tests/test_utils.py b/tests/test_utils.py index 0da369d68b..879ff55e74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -883,12 +883,12 @@ def tmp_module_root(tmp_path_factory: pytest.TempPathFactory): ), ], ) -def test_walk_modules(tmp_module_root: Path, ignore, expected): +def test_walk_modules(tmp_module_root: Path, ignore, expected) -> None: path = str(tmp_module_root / "a") assert sorted(utils._walk_modules([path], "a.", ignore)) == sorted(expected) -def test_walk_modules_nonexistent(tmp_module_root: Path): +def test_walk_modules_nonexistent(tmp_module_root: Path) -> None: assert list(utils._walk_modules([str(tmp_module_root / "doesnotexist")], "doesnotexist.")) == []