Skip to content

Commit

Permalink
Add mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
kmagusiak committed Jun 5, 2022
1 parent b79dfea commit fdacfc9
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 17 deletions.
31 changes: 20 additions & 11 deletions alphaconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import sys
import uuid
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf

Expand Down Expand Up @@ -34,7 +34,7 @@
]

"""Map of default values"""
_DEFAULTS = {
_DEFAULTS: Dict[str, Any] = {
'configurations': [],
'helpers': {},
'testing_configurations': [],
Expand All @@ -51,6 +51,8 @@ class Application:
name, version, short_description, description, etc.
"""

__config: Optional[DictConfig]

def __init__(self, **properties) -> None:
"""Initialize the application.
Expand Down Expand Up @@ -109,7 +111,7 @@ def configuration(self) -> DictConfig:
"""Get the configuration of the application, initialize if necessary"""
if self.__config is None:
self.setup_configuration(
arguments=None, resolve_configuration=False, setup_logging=False
arguments=False, resolve_configuration=False, setup_logging=False
)
_log.info('alphaconf initialized')
assert self.__config is not None
Expand Down Expand Up @@ -187,8 +189,13 @@ def _get_configurations(
for path in self._get_possible_configuration_paths():
if os.path.exists(path):
_log.debug('Load configuration from %s', path)
yield OmegaConf.load(path)
conf = OmegaConf.load(path)
if isinstance(conf, DictConfig):
yield conf
else:
yield from conf
# Environment
prefixes: Optional[Tuple[str, ...]]
if env_prefixes is True:
_log.debug('Detecting accepted env prefixes')
default_keys = {k for cfg in _DEFAULTS['configurations'] for k in cfg.keys()}
Expand Down Expand Up @@ -247,7 +254,7 @@ def setup_configuration(
configurations = list(self._get_configurations(env_prefixes=env_prefixes))
if self.parsed:
configurations.extend(self.parsed.configurations())
self.__config = OmegaConf.merge(*configurations)
self.__config = cast(DictConfig, OmegaConf.merge(*configurations))
_log.debug('Merged %d configurations', len(configurations))

# Handle the result
Expand Down Expand Up @@ -303,7 +310,7 @@ def update_configuration(self, conf: Union[DictConfig, Dict]):
"""
current_config = self.configuration
try:
self.__config = OmegaConf.merge(current_config, conf)
self.__config = cast(DictConfig, OmegaConf.merge(current_config, conf))
yield self
finally:
self.__config = current_config
Expand Down Expand Up @@ -413,7 +420,7 @@ def __str__(self) -> str:


"""The application context"""
application = contextvars.ContextVar('application')
application: contextvars.ContextVar[Application] = contextvars.ContextVar('application')


def configuration() -> DictConfig:
Expand All @@ -440,16 +447,18 @@ def setup_configuration(
:param testing: If set, True adds the configuration to testing configurations,
if False, the testing configurations are cleared
"""
if not isinstance(conf, DictConfig):
conf = OmegaConf.create(conf)
if isinstance(conf, DictConfig):
config = conf
else:
config = cast(DictConfig, OmegaConf.create(conf))
if testing is False:
_DEFAULTS['testing_configurations'].clear()
config_key = 'testing_configurations' if testing else 'configurations'
_DEFAULTS[config_key].append(conf)
_DEFAULTS[config_key].append(config)
# setup helpers
for h_key in helpers:
key = h_key.split('.', 1)[0]
if key not in conf:
if key not in config:
raise ValueError('Invalid helper not in configuration [%s]' % key)
_DEFAULTS['helpers'].update(helpers)

Expand Down
16 changes: 12 additions & 4 deletions alphaconf/arg_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Dict, Iterable, List, Union
from typing import Dict, Iterable, List, Optional, Union, cast

from omegaconf import DictConfig, OmegaConf

Expand Down Expand Up @@ -122,6 +122,10 @@ def handle(self, result, value):
class ParseResult:
"""The result of argument parsing"""

result: Optional[Action]
rest: List[str]
_config: List[Union[str, DictConfig]]

def __init__(self) -> None:
"""Initialize the result"""
self.result = None
Expand Down Expand Up @@ -150,9 +154,9 @@ def configurations(self) -> Iterable[DictConfig]:
return
for typ, conf in itertools.groupby(configuration_list, type):
if issubclass(typ, DictConfig):
yield from conf
yield from cast(Iterable[DictConfig], conf)
else:
yield OmegaConf.from_dotlist(list(conf))
yield OmegaConf.from_dotlist(list(cast(Iterable[str], conf)))

def __repr__(self) -> str:
return f"(result={self.result}, config={self._config}, rest={self.rest})"
Expand All @@ -161,6 +165,10 @@ def __repr__(self) -> str:
class ArgumentParser:
"""Parses arguments for alphaconf"""

_opt_actions: Dict[str, Action]
_pos_actions: List[Action]
help_messages: Dict[str, str]

def __init__(self) -> None:
self._opt_actions = {}
self._pos_actions = []
Expand Down Expand Up @@ -201,7 +209,7 @@ def parse_args(self, arguments: List[str]) -> ParseResult:
# parse positional arguments
if value is None:
value = arg
arg = None
arg = None # type: ignore
action_result = f"Unrecognized argument: {value}"
for action in self._pos_actions:
if not action.check_argument(value):
Expand Down
2 changes: 1 addition & 1 deletion alphaconf/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ def invoke_application(
# Just configure the namespace and set the application
application.set(app)
namespace.configure(app.get_config())
app.setup_logging()
app.setup_configuration(arguments=False, load_dotenv=False, setup_logging=True)
return app
3 changes: 2 additions & 1 deletion alphaconf/logging_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import traceback
from logging import Formatter, LogRecord
from typing import Any

try:
import colorama
Expand Down Expand Up @@ -50,7 +51,7 @@ class JSONFormatter(Formatter):
"""Format the log message as a single-line JSON dict"""

def format(self, record: LogRecord) -> str:
d = collections.OrderedDict()
d: collections.OrderedDict[str, Any] = collections.OrderedDict()
if self.usesTime():
d['time'] = self.formatTime(record, self.datefmt)
d['level'] = record.levelname
Expand Down
1 change: 1 addition & 0 deletions pre-commit
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ esac
flake8
black --check .
isort --check-only .
mypy .
echo "All good to commit"
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ build-backend = "setuptools.build_meta"
line-length = 100
skip-string-normalization = 1

[tool.mypy]
ignore_missing_imports = true

[tool.isort]
profile = "black"
line_length = 100
Expand Down

0 comments on commit fdacfc9

Please sign in to comment.