Skip to content

Commit

Permalink
Merge pull request #55 from edornd/feat/54-dyn-sources
Browse files Browse the repository at this point in the history
✨ Rework dynamic sources (#54)
  • Loading branch information
edornd authored Aug 6, 2024
2 parents 4907a66 + 9651241 commit 5a2d3b7
Show file tree
Hide file tree
Showing 19 changed files with 378 additions and 149 deletions.
25 changes: 16 additions & 9 deletions argdantic/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from pydantic_core import PydanticUndefined

from argdantic.parsing import ActionTracker, Argument, PrimitiveArgument, registry
from argdantic.sources.base import SourceBaseModel
from argdantic.utils import is_optional
from argdantic.sources import DEFAULT_SOURCE_FIELD
from argdantic.utils import get_optional_type, is_optional


def format_description(description: Optional[str], has_default: bool, is_required: bool) -> str:
Expand Down Expand Up @@ -64,7 +64,7 @@ def argument_from_field(
assert not lenient_issubclass(field_info.annotation, BaseModel)
base_option_name = delimiter.join(parent_path + (kebab_name,))
full_option_name = f"--{base_option_name}"
extra_fields: dict[str, Any] = (
extra_fields: Dict[str, Any] = (
field_info.json_schema_extra or {} if isinstance(field_info.json_schema_extra, dict) else {}
)
extra_names = extra_fields.get("names", ())
Expand Down Expand Up @@ -114,27 +114,34 @@ def model_to_args(
# checks on delimiters to be done
kebab_name = field_name.replace("_", "-")
assert internal_delimiter not in kebab_name
if lenient_issubclass(field_info.annotation, BaseModel):

annotation = (
field_info.annotation
if not is_optional(field_info.annotation)
else get_optional_type(field_info.annotation)
)
if lenient_issubclass(annotation, BaseModel):
yield from model_to_args(
cast(Type[BaseModel], field_info.annotation),
cast(Type[BaseModel], annotation),
delimiter,
internal_delimiter,
parent_path=parent_path + (kebab_name,),
)
# if the model requires a file source, we add an extra argument
# whose name is the same as the model's name
if lenient_issubclass(field_info.annotation, SourceBaseModel):
# whose name is the same as the model's name (yes I'm not gonna bother with mypy here)
if hasattr(annotation, "__arg_source_field__") and annotation.__arg_source_field__ is None: # type: ignore
default = PydanticUndefined if annotation.__arg_source_required__ else None # type: ignore
info = FieldInfo(
annotation=Path,
alias=field_info.alias,
title=field_info.title,
description=field_info.description,
default=field_info.default,
default=default,
json_schema_extra=field_info.json_schema_extra,
)
base_name = delimiter.join(parent_path + (kebab_name,))
internal_name = base_name.replace(delimiter, internal_delimiter).replace("-", "_")
custom_identifier = f"{internal_name}{internal_delimiter}_source"
custom_identifier = f"{internal_name}{internal_delimiter}{DEFAULT_SOURCE_FIELD}"
yield argument_from_field(
field_info=info,
kebab_name=kebab_name,
Expand Down
19 changes: 9 additions & 10 deletions argdantic/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from argparse import ArgumentParser, Namespace, _SubParsersAction
from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, Type, TypeVar, cast, get_type_hints
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Type, TypeVar, cast, get_type_hints

from pydantic import BaseModel, ValidationError, create_model
from pydantic.v1.utils import lenient_issubclass
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
self.delimiter = delimiter
self.arguments = arguments or []
self.stores = stores or []
self.trackers: dict[str, ActionTracker] = {}
self.trackers: Dict[str, ActionTracker] = {}

def __repr__(self) -> str:
return f"<Command {self.name}>"
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(
internal_delimiter: str = "__",
subcommand_meta: str = "<command>",
) -> None:
self.entrypoint: ArgumentParser | None = None
self.entrypoint: Optional[ArgumentParser] = None
self.name = name
self.description = description
self.force_group = force_group
Expand All @@ -125,7 +125,7 @@ def __init__(
self._subcommand_meta = subcommand_meta
# keeping a reference to subparser is necessary to add subparsers
# Each cli level can only have one subparser.
self._subparser: _SubParsersAction | None = None
self._subparser: Optional[_SubParsersAction] = None

def __repr__(self) -> str:
name = f" '{self.name}'" if self.name else ""
Expand Down Expand Up @@ -287,7 +287,7 @@ def decorator(f: Callable) -> Command:
# set the base Model and Config class
if sources:

class SourceSettings(BaseSettings):
class StaticSourceSettings(BaseSettings):
# patch the config class so that pydantic functionality remains
# the same, but the sources are properly initialized

Expand All @@ -304,12 +304,11 @@ def settings_customise_sources(
# this is needed to make sure that the config class is properly
# initialized with the sources declared by the user on CLI init.
# Env and file sources are discarded, the user must provide them explicitly.
if sources is not None:
callables = [source(settings_cls) for source in sources]
return (*callables, init_settings)
return (init_settings,)
source_list = cast(List[SettingSourceCallable], sources)
callables = [source(settings_cls) for source in source_list]
return (*callables, init_settings)

model_class = SourceSettings if model_class is None else (model_class, SourceSettings)
model_class = StaticSourceSettings if model_class is None else (model_class, StaticSourceSettings)

cfg_class = create_model( # type: ignore
"WrapperModel",
Expand Down
15 changes: 1 addition & 14 deletions argdantic/parsing/actions.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
from argparse import OPTIONAL, Action, ArgumentParser, Namespace
from argparse import OPTIONAL, Action, ArgumentParser, Namespace, _copy_items # type: ignore
from typing import Any, Iterable, Optional, Sequence, Union


def _copy_items(items):
if items is None:
return []
# The copy module is used only in the 'append' and 'append_const'
# actions, and it is needed only when the default value isn't a list.
# Delay its import for speeding up the common case.
if type(items) is list:
return items[:]
import copy

return copy.copy(items)


class StoreAction(Action):
"""
Store action for argparse. This class is used to store the value of an argument.
Expand Down
2 changes: 1 addition & 1 deletion argdantic/parsing/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(

@abstractmethod
def build(self, parser: ArgumentParser) -> ActionTracker:
raise NotImplementedError
raise NotImplementedError # pragma: no cover

def build_internal(self, parser: ArgumentParser, *, action: Type[Action], **optional_fields: Any) -> ActionTracker:
tracker = ActionTracker(action)
Expand Down
4 changes: 2 additions & 2 deletions argdantic/registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections.abc import MutableMapping
from typing import Any, Iterator, Type, Union, get_origin
from typing import Any, Dict, Iterator, Type, Union, get_origin


class Registry(MutableMapping):
"""Simple class registry for mapping types and their argument handlers."""

def __init__(self) -> None:
self.store: dict[type, Any] = dict()
self.store: Dict[type, Any] = dict()

def __getitem__(self, key: type) -> Any:
# do not allow Union types (unless they are Optional, handled in conversion)
Expand Down
15 changes: 9 additions & 6 deletions argdantic/sources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from argdantic.sources.base import EnvSettingsSource, SecretsSettingsSource
from argdantic.sources.json import JsonModel, JsonSettingsSource
from argdantic.sources.toml import TomlModel, TomlSettingsSource
from argdantic.sources.yaml import YamlModel, YamlSettingsSource
from argdantic.sources.dynamic import DEFAULT_SOURCE_FIELD, from_file
from argdantic.sources.json import JsonFileLoader, JsonSettingsSource
from argdantic.sources.toml import TomlFileLoader, TomlSettingsSource
from argdantic.sources.yaml import YamlFileLoader, YamlSettingsSource

__all__ = [
"from_file",
"DEFAULT_SOURCE_FIELD",
"EnvSettingsSource",
"SecretsSettingsSource",
"JsonSettingsSource",
"TomlSettingsSource",
"YamlSettingsSource",
"JsonModel",
"TomlModel",
"YamlModel",
"JsonFileLoader",
"TomlFileLoader",
"YamlFileLoader",
]
15 changes: 0 additions & 15 deletions argdantic/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,6 @@ def __init__(self, settings_cls: Type[BaseSettings], path: Union[str, Path]) ->
self.path = Path(path)


class SourceBaseModel(BaseSettings):
"""
A base model that reads additional settings from a file.
This helps making the CLI more flexible and allow composability via file.
"""

def __init__(self, _source: Path, _source_cls: Type[FileBaseSettingsSource], **data) -> None:
if _source is not None:
reader = _source_cls(self, _source) # type: ignore
extra_data = reader()
extra_data.update(data)
data = extra_data
super().__init__(**data)


class PydanticMultiEnvSource(PydanticEnvSource):
"""
A pydantic settings source that loads settings from multiple environment sources.
Expand Down
95 changes: 95 additions & 0 deletions argdantic/sources/dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type, cast

from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings, InitSettingsSource, PydanticBaseSettingsSource

from argdantic.sources.base import FileBaseSettingsSource

DEFAULT_SOURCE_FIELD = "_source"


class DynamicFileSource(PydanticBaseSettingsSource):
"""
Source class for loading values provided during settings class initialization.
"""

def __init__(
self,
settings_cls: Type[BaseSettings],
source_cls: Type[FileBaseSettingsSource],
init_kwargs: Dict[str, Any],
required: bool,
field_name: Optional[str] = None,
):
super().__init__(settings_cls)
self.init_kwargs = init_kwargs
self.field_name = field_name or DEFAULT_SOURCE_FIELD
if self.field_name not in init_kwargs:
if required:
raise ValueError("Missing required source")
self.source = None
else:
self.source = source_cls(settings_cls, init_kwargs[self.field_name])

def get_field_value(self, field: Any, field_name: str) -> Tuple[Any, str, bool]:
# Nothing to do here. Only implement the return statement to make mypy happy
return None, "", False # pragma: no cover

def __call__(self) -> Dict[str, Any]:
if self.source is not None:
main_kwargs = self.source()
main_kwargs.update(self.init_kwargs)
# remove the source field if it is the default one
if self.field_name == DEFAULT_SOURCE_FIELD:
main_kwargs.pop(self.field_name)
return main_kwargs
return self.init_kwargs

def __repr__(self) -> str:
return f"DynamicFileSource(source={self.source!r})"


def from_file(
loader: Type[FileBaseSettingsSource],
use_field: Optional[str] = None,
required: bool = True,
):
def decorator(cls):
if not issubclass(cls, BaseModel):
raise TypeError("@from_file can only be applied to Pydantic models")
if use_field is not None:
if use_field not in cls.model_fields:
raise ValueError(f"Field {use_field} not found in model {cls.__name__}")
field_annotation = cls.model_fields[use_field].annotation
if not issubclass(field_annotation, (str, Path)):
raise ValueError(f"Field {use_field} must be a string or Path to be used as file source")

class DynamicSourceSettings(cls, BaseSettings):
# required to eventually add a cli argument to the model
# if cli_field is None, an additional argument will be added
__arg_source_field__ = use_field
__arg_source_required__ = required
model_config = ConfigDict(extra="ignore")

@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
source = DynamicFileSource(
settings_cls,
loader,
cast(InitSettingsSource, init_settings).init_kwargs,
required,
use_field,
)
return (source,)

return DynamicSourceSettings

return decorator
18 changes: 4 additions & 14 deletions argdantic/sources/json.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from pathlib import Path
from typing import Any, Dict, Tuple, Type

from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource

from argdantic.sources.base import FileBaseSettingsSource, FileSettingsSourceBuilder, SourceBaseModel
from argdantic.sources.base import FileBaseSettingsSource, FileSettingsSourceBuilder


class PydanticJsonSource(FileBaseSettingsSource):
class JsonFileLoader(FileBaseSettingsSource):
"""
Class internal to pydantic-settings that reads settings from a JSON file.
This gets spawned by the JsonSettingsSource class.
"""

def get_field_value(self, field: FieldInfo, field_name: str) -> Tuple[Any, str, bool]:
return None, field_name, False
return None, field_name, False # pragma: no cover

def __call__(self) -> Dict[str, Any]:
try:
Expand All @@ -34,16 +33,7 @@ class JsonSettingsSource(FileSettingsSourceBuilder):
"""

def __call__(self, settings: Type[BaseSettings]) -> PydanticBaseSettingsSource:
return PydanticJsonSource(settings, self.path)
return JsonFileLoader(settings, self.path)

def __repr__(self) -> str:
return f"<JsonSettingsSource path={self.path}>"


class JsonModel(SourceBaseModel):
"""
A base model that reads additional settings from a JSON file.
"""

def __init__(self, _source: Path, **data) -> None:
super().__init__(_source, PydanticJsonSource, **data)
18 changes: 4 additions & 14 deletions argdantic/sources/toml.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from pathlib import Path
from typing import Any, Dict, Tuple, Type

from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource

from argdantic.sources.base import FileBaseSettingsSource, FileSettingsSourceBuilder, SourceBaseModel
from argdantic.sources.base import FileBaseSettingsSource, FileSettingsSourceBuilder


class PydanticTomlSource(FileBaseSettingsSource):
class TomlFileLoader(FileBaseSettingsSource):
"""
Class internal to pydantic-settings that reads settings from a TOML file.
This gets spawned by the TomlSettingsSource class.
"""

def get_field_value(self, field: FieldInfo, field_name: str) -> Tuple[Any, str, bool]:
return None, field_name, False
return None, field_name, False # pragma: no cover

def __call__(self) -> Dict[str, Any]:
try:
Expand All @@ -34,16 +33,7 @@ class TomlSettingsSource(FileSettingsSourceBuilder):
"""

def __call__(self, settings: Type[BaseSettings]) -> PydanticBaseSettingsSource:
return PydanticTomlSource(settings, self.path)
return TomlFileLoader(settings, self.path)

def __repr__(self) -> str:
return f"<TomlSettingsSource path={self.path}>"


class TomlModel(SourceBaseModel):
"""
A base model that reads additional settings from a TOML file.
"""

def __init__(self, _source: Path, **data) -> None:
super().__init__(_source, PydanticTomlSource, **data)
Loading

0 comments on commit 5a2d3b7

Please sign in to comment.