diff --git a/src/glvd/data/cpe.py b/src/glvd/data/cpe.py index babe345..6b7334c 100644 --- a/src/glvd/data/cpe.py +++ b/src/glvd/data/cpe.py @@ -5,7 +5,10 @@ import dataclasses import re from enum import StrEnum -from typing import Any +from typing import ( + Any, + cast, +) class CpePart(StrEnum): @@ -101,6 +104,12 @@ def __post_init__(self) -> None: elif self.other is None: self.other = CpeOtherDebian() + @property + def other_debian(self) -> CpeOtherDebian: + if self.is_debian: + return cast(CpeOtherDebian, self.other) + raise TypeError('Not debian related CPE') + @classmethod def _parse_one(cls, field: dataclasses.Field, v: str, /) -> Any: if v == '*': diff --git a/src/glvd/web/v1_cves.py b/src/glvd/web/v1_cves.py index 0b83c04..734bc63 100644 --- a/src/glvd/web/v1_cves.py +++ b/src/glvd/web/v1_cves.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: MIT +from typing import Any + from quart import Blueprint, current_app, request from sqlalchemy import ( bindparam, @@ -91,10 +93,10 @@ @bp.route('/') -async def get_cve_id(cve_id): +async def get_cve_id(cve_id: str) -> tuple[Any, int]: stmt = stmt_cve_id.bindparams(cve_id=cve_id) - async with current_app.db_begin() as conn: + async with getattr(current_app, 'db_begin')() as conn: data = (await conn.execute(stmt)).one_or_none() if data: @@ -104,19 +106,21 @@ async def get_cve_id(cve_id): @bp.route('/findByCpe') -async def get_cpe_name(): - cpe = Cpe.parse(request.args.get('cpeName', type=str)) +async def get_cpe_name() -> tuple[Any, int]: + cpe = request.args.get('cpeName', type=Cpe.parse) deb_version = request.args.get('debVersionEnd', type=str) + if cpe is None: + return 'No CPE', 400 if not cpe.is_debian: return 'Not Debian related CPE', 400 - if cpe.other.deb_source and deb_version: + if cpe.other_debian.deb_source and deb_version: stmt = stmt_cpe_version.bindparams( cpe_vendor=cpe.vendor, cpe_product=cpe.product, cpe_version=cpe.version or '%', - deb_source=cpe.other.deb_source, + deb_source=cpe.other_debian.deb_source, deb_version=deb_version, ) else: @@ -124,10 +128,10 @@ async def get_cpe_name(): cpe_vendor=cpe.vendor, cpe_product=cpe.product, cpe_version=cpe.version or '%', - deb_source=cpe.other.deb_source or '%', + deb_source=cpe.other_debian.deb_source or '%', ) - async with current_app.db_begin() as conn: + async with getattr(current_app, 'db_begin')() as conn: return ( (await conn.execute(stmt)).one()[0], 200 diff --git a/tests/data/test_cpe.py b/tests/data/test_cpe.py index d8882b9..a494f98 100644 --- a/tests/data/test_cpe.py +++ b/tests/data/test_cpe.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: MIT +import pytest + from glvd.data.cpe import Cpe, CpePart @@ -18,6 +20,8 @@ def test_init(self): assert c.target_hw is None assert c.other is None assert c.is_debian is False + with pytest.raises(TypeError): + c.other_debian def test_parse(self): s = r'cpe:2.3:h:a:b:c\:\%\*\;c:d:*:-:-:-:*:*' @@ -34,6 +38,8 @@ def test_parse(self): assert c.target_hw is None assert c.other is None assert c.is_debian is False + with pytest.raises(TypeError): + c.other_debian assert str(c) == s def test_debian(self): @@ -49,8 +55,8 @@ def test_debian(self): assert c.sw_edition is None assert c.target_sw is None assert c.target_hw is None - assert c.other.deb_source == 'hello' - assert c.other.deb_version == '1' + assert c.other_debian.deb_source == 'hello' + assert c.other_debian.deb_version == '1' assert c.is_debian is True assert str(c) == s @@ -58,7 +64,7 @@ def test_debian_any(self): s = r'cpe:2.3:o:debian:debian_linux:12:d:*:*:*:*:*:*' c = Cpe.parse(s) - assert c.other.deb_source is None - assert c.other.deb_version is None + assert c.other_debian.deb_source is None + assert c.other_debian.deb_version is None assert c.is_debian is True assert str(c) == s