diff --git a/pyproject.toml b/pyproject.toml index 9f8395d..d118aad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,9 @@ packages = [ { include = "glvd", from = "src" }, ] +[tool.poetry.scripts] +glvd = 'glvd.cli.__main__:main' + [tool.poetry.dependencies] python = ">=3.11" asyncpg = ">=0.28" diff --git a/src/glvd/cli/__init__.py b/src/glvd/cli/__init__.py new file mode 100644 index 0000000..0fe5d52 --- /dev/null +++ b/src/glvd/cli/__init__.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import argparse +import dataclasses +from collections.abc import ( + Callable, + Iterable, +) + + +@dataclasses.dataclass +class _ActionWrapper: + args: tuple + kw: dict + + +class Cli: + parser: argparse.ArgumentParser + subparsers: argparse._SubParsersAction + + def __init__(self) -> None: + self.parser = argparse.ArgumentParser( + allow_abbrev=False, + prog='glvd', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + self.subparsers = self.parser.add_subparsers( + help='sub-command help', + ) + + def add_argument(self, *args, **kw) -> _ActionWrapper: + return _ActionWrapper(args, kw) + + def register( + self, + name: str, + arguments: Iterable[_ActionWrapper], + usage: str = '%(prog)s', + epilog: str | None = None, + ) -> Callable: + parser_main = argparse.ArgumentParser( + allow_abbrev=False, + prog=f'glvd.{name}', + usage=usage, + epilog=epilog, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser_sub = self.subparsers.add_parser( + name=name, + usage=usage, + epilog=epilog, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + for p in (parser_main, parser_sub): + for w in arguments: + p.add_argument(*w.args, **w.kw) + + def wrap(func: Callable) -> Callable: + parser_sub.set_defaults(func=func) + + def run() -> None: + args = parser_main.parse_args() + func(**vars(args)) + + return run + + return wrap + + def main(self) -> None: + args = self.parser.parse_args() + v = vars(args) + func = v.pop('func', None) + if func: + func(**v) + else: + self.parser.print_help() + + +cli = Cli() diff --git a/src/glvd/cli/__main__.py b/src/glvd/cli/__main__.py new file mode 100644 index 0000000..ef299a0 --- /dev/null +++ b/src/glvd/cli/__main__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from . import cli + +# Import to register all the commands +from . import ( # noqa: F401 + combine_all, + combine_deb, + ingest_debsec, + ingest_debsrc, + ingest_nvd, +) + + +def main() -> None: + cli.main() + + +if __name__ == '__main__': + main() diff --git a/src/glvd/cli/combine_all.py b/src/glvd/cli/combine_all.py index 6ee7750..23f0955 100644 --- a/src/glvd/cli/combine_all.py +++ b/src/glvd/cli/combine_all.py @@ -17,12 +17,34 @@ ) from ..database import Base, AllCve +from . import cli logger = logging.getLogger(__name__) -class CombineDeb: +class CombineAll: + @staticmethod + @cli.register( + 'combine-all', + arguments=[ + cli.add_argument( + '--database', + default='postgresql+asyncpg:///', + help='the database to use, must use asyncio compatible SQLAlchemy driver', + ), + cli.add_argument( + '--debug', + action='store_true', + help='enable debug output', + ), + ] + ) + def run(database: str, debug: bool) -> None: + logging.basicConfig(level=debug and logging.DEBUG or logging.INFO) + engine = create_async_engine(database, echo=debug) + asyncio.run(CombineAll()(engine)) + stmt_combine_new = ( text(''' SELECT @@ -103,12 +125,4 @@ async def __call__( if __name__ == '__main__': - import argparse - logging.basicConfig(level=logging.DEBUG) - parser = argparse.ArgumentParser() - args = parser.parse_args() - engine = create_async_engine( - "postgresql+asyncpg:///", - ) - main = CombineDeb() - asyncio.run(main(engine)) + CombineAll.run() diff --git a/src/glvd/cli/combine_deb.py b/src/glvd/cli/combine_deb.py index b72eddd..e5dd376 100644 --- a/src/glvd/cli/combine_deb.py +++ b/src/glvd/cli/combine_deb.py @@ -24,12 +24,34 @@ from ..database import Base, DistCpe, DebCve from ..data.cpe import Cpe, CpeOtherDebian from ..data.cvss import CvssSeverity +from . import cli logger = logging.getLogger(__name__) class CombineDeb: + @staticmethod + @cli.register( + 'combine-deb', + arguments=[ + cli.add_argument( + '--database', + default='postgresql+asyncpg:///', + help='the database to use, must use asyncio compatible SQLAlchemy driver', + ), + cli.add_argument( + '--debug', + action='store_true', + help='enable debug output', + ), + ] + ) + def run(database: str, debug: bool) -> None: + logging.basicConfig(level=debug and logging.DEBUG or logging.INFO) + engine = create_async_engine(database, echo=debug) + asyncio.run(CombineDeb()(engine)) + stmt_combine_new = ( text(''' SELECT @@ -207,12 +229,4 @@ async def __call__( if __name__ == '__main__': - import argparse - logging.basicConfig(level=logging.DEBUG) - parser = argparse.ArgumentParser() - args = parser.parse_args() - engine = create_async_engine( - "postgresql+asyncpg:///", - ) - main = CombineDeb() - asyncio.run(main(engine)) + CombineDeb.run() diff --git a/src/glvd/cli/ingest_debsec.py b/src/glvd/cli/ingest_debsec.py index 42e8382..7838708 100644 --- a/src/glvd/cli/ingest_debsec.py +++ b/src/glvd/cli/ingest_debsec.py @@ -17,12 +17,46 @@ from ..database import Base, DistCpe, DebsecCve from ..data.debsec_cve import DebsecCveFile from ..data.dist_cpe import DistCpeMapper +from . import cli logger = logging.getLogger(__name__) class IngestDebsec: + @staticmethod + @cli.register( + 'ingest-debsec', + arguments=[ + cli.add_argument( + 'cpe_product', + choices=sorted(DistCpeMapper.keys()), + help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}', + metavar='CPE_PRODUCT', + ), + cli.add_argument( + 'dir', + help='data directory out of https://salsa.debian.org/security-tracker-team/security-tracker', + metavar='DEBSEC', + type=Path, + ), + cli.add_argument( + '--database', + default='postgresql+asyncpg:///', + help='the database to use, must use asyncio compatible SQLAlchemy driver', + ), + cli.add_argument( + '--debug', + action='store_true', + help='enable debug output', + ), + ] + ) + def run(cpe_product: str, dir: Path, database: str, debug: bool) -> None: + logging.basicConfig(level=debug and logging.DEBUG or logging.INFO) + engine = create_async_engine(database, echo=debug) + asyncio.run(IngestDebsec(cpe_product, dir)(engine)) + def __init__(self, cpe_product: str, path: Path) -> None: self.path = path @@ -116,24 +150,4 @@ async def __call__( if __name__ == '__main__': - import argparse - logging.basicConfig(level=logging.DEBUG) - parser = argparse.ArgumentParser() - parser.add_argument( - 'cpe_product', - choices=sorted(DistCpeMapper.keys()), - help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}', - metavar='CPE_PRODUCT', - ) - parser.add_argument( - 'dir', - help='data directory out of https://salsa.debian.org/security-tracker-team/security-tracker', - metavar='DEBSEC', - type=Path, - ) - args = parser.parse_args() - engine = create_async_engine( - "postgresql+asyncpg:///", - ) - ingest = IngestDebsec(args.cpe_product, args.dir) - asyncio.run(ingest(engine)) + IngestDebsec.run() diff --git a/src/glvd/cli/ingest_debsrc.py b/src/glvd/cli/ingest_debsrc.py index 9c63a60..cd99300 100644 --- a/src/glvd/cli/ingest_debsrc.py +++ b/src/glvd/cli/ingest_debsrc.py @@ -17,12 +17,51 @@ from ..database import Base, DistCpe, Debsrc from ..data.debsrc import DebsrcFile from ..data.dist_cpe import DistCpeMapper +from . import cli logger = logging.getLogger(__name__) class IngestDebsrc: + @staticmethod + @cli.register( + 'ingest-debsrc', + arguments=[ + cli.add_argument( + 'cpe_product', + choices=sorted(DistCpeMapper.keys()), + help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}', + metavar='CPE_PRODUCT', + ), + cli.add_argument( + 'deb_codename', + help='codename of APT archive', + metavar='CODENAME', + ), + cli.add_argument( + 'file', + help='uncompressed Sources file', + metavar='SOURCES', + type=Path, + ), + cli.add_argument( + '--database', + default='postgresql+asyncpg:///', + help='the database to use, must use asyncio compatible SQLAlchemy driver', + ), + cli.add_argument( + '--debug', + action='store_true', + help='enable debug output', + ), + ] + ) + def run(cpe_product: str, deb_codename: str, file: Path, database: str, debug: bool) -> None: + logging.basicConfig(level=debug and logging.DEBUG or logging.INFO) + engine = create_async_engine(database, echo=debug) + asyncio.run(IngestDebsrc(cpe_product, deb_codename, file)(engine)) + def __init__(self, cpe_product: str, deb_codename: str, file: Path) -> None: self.file = file @@ -103,29 +142,4 @@ async def __call__( if __name__ == '__main__': - import argparse - logging.basicConfig(level=logging.DEBUG) - parser = argparse.ArgumentParser() - parser.add_argument( - 'cpe_product', - choices=sorted(DistCpeMapper.keys()), - help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}', - metavar='CPE_PRODUCT', - ) - parser.add_argument( - 'deb_codename', - help='codename of APT archive', - metavar='CODENAME', - ) - parser.add_argument( - 'file', - help='uncompressed Sources file', - metavar='SOURCES', - type=Path, - ) - args = parser.parse_args() - engine = create_async_engine( - "postgresql+asyncpg:///", - ) - ingest = IngestDebsrc(args.cpe_product, args.deb_codename, args.file) - asyncio.run(ingest(engine)) + IngestDebsrc.run() diff --git a/src/glvd/cli/ingest_nvd.py b/src/glvd/cli/ingest_nvd.py index 97cfa21..813075e 100644 --- a/src/glvd/cli/ingest_nvd.py +++ b/src/glvd/cli/ingest_nvd.py @@ -17,6 +17,7 @@ from ..database import Base, NvdCve from ..util import requests +from . import cli logger = logging.getLogger(__name__) @@ -25,6 +26,27 @@ class IngestNvd: wait: int + @staticmethod + @cli.register( + 'ingest-nvd', + arguments=[ + cli.add_argument( + '--database', + default='postgresql+asyncpg:///', + help='the database to use, must use asyncio compatible SQLAlchemy driver', + ), + cli.add_argument( + '--debug', + action='store_true', + help='enable debug output', + ), + ] + ) + def run(database: str, debug: bool) -> None: + logging.basicConfig(level=debug and logging.DEBUG or logging.INFO) + engine = create_async_engine(database, echo=debug) + asyncio.run(IngestNvd()(engine)) + def __init__(self, *, wait: int = 6) -> None: self.wait = wait @@ -122,9 +144,4 @@ async def __call__(self, engine: AsyncEngine) -> None: if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) - engine = create_async_engine( - "postgresql+asyncpg:///", - ) - ingest = IngestNvd() - asyncio.run(ingest(engine)) + IngestNvd.run()