Skip to content

Commit

Permalink
Make web interface also type safe
Browse files Browse the repository at this point in the history
  • Loading branch information
credbbl committed Dec 12, 2023
1 parent b1f5fde commit 7060913
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
7 changes: 7 additions & 0 deletions src/glvd/data/cpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ def __post_init__(self) -> None:
elif self.other is None:
self.other = CpeOtherDebian()

@property
def other_debian(self) -> CpeOtherDebian:
other = self.other
if isinstance(other, CpeOtherDebian):
return other
raise TypeError('Not debian related CPE')

@classmethod
def _parse_one(cls, field: dataclasses.Field, v: str, /) -> Any:
if v == '*':
Expand Down
21 changes: 13 additions & 8 deletions src/glvd/web/v1_cves.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: MIT

from typing import Any

from quart import Blueprint, current_app, request
from sqlalchemy import (
bindparam,
Expand Down Expand Up @@ -91,10 +93,10 @@


@bp.route('/<cve_id>')
async def get_cve_id(cve_id):
async def get_cve_id(cve_id: str) -> tuple[Any, int]:
stmt = stmt_cve_id.bindparams(cve_id=cve_id)

async with current_app.db_begin() as conn:
async with getattr(current_app, 'db_begin')() as conn:
data = (await conn.execute(stmt)).one_or_none()

if data:
Expand All @@ -104,30 +106,33 @@ async def get_cve_id(cve_id):


@bp.route('/findByCpe')
async def get_cpe_name():
cpe = Cpe.parse(request.args.get('cpeName', type=str))
async def get_cpe_name() -> tuple[Any, int]:
cpe = request.args.get('cpeName', type=Cpe.parse)
deb_version = request.args.get('debVersionEnd', type=str)

if cpe is None:
return 'No CPE', 400
if not cpe.is_debian:
return 'Not Debian related CPE', 400
cpe_other = cpe.other_debian

if cpe.other.deb_source and deb_version:
if cpe_other.deb_source and deb_version:
stmt = stmt_cpe_version.bindparams(
cpe_vendor=cpe.vendor,
cpe_product=cpe.product,
cpe_version=cpe.version or '%',
deb_source=cpe.other.deb_source,
deb_source=cpe_other.deb_source,
deb_version=deb_version,
)
else:
stmt = stmt_cpe_vulnerable.bindparams(
cpe_vendor=cpe.vendor,
cpe_product=cpe.product,
cpe_version=cpe.version or '%',
deb_source=cpe.other.deb_source or '%',
deb_source=cpe_other.deb_source or '%',
)

async with current_app.db_begin() as conn:
async with getattr(current_app, 'db_begin')() as conn:
return (
(await conn.execute(stmt)).one()[0],
200
Expand Down
14 changes: 10 additions & 4 deletions tests/data/test_cpe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: MIT

import pytest

from glvd.data.cpe import Cpe, CpePart


Expand All @@ -18,6 +20,8 @@ def test_init(self):
assert c.target_hw is None
assert c.other is None
assert c.is_debian is False
with pytest.raises(TypeError):
c.other_debian

def test_parse(self):
s = r'cpe:2.3:h:a:b:c\:\%\*\;c:d:*:-:-:-:*:*'
Expand All @@ -34,6 +38,8 @@ def test_parse(self):
assert c.target_hw is None
assert c.other is None
assert c.is_debian is False
with pytest.raises(TypeError):
c.other_debian
assert str(c) == s

def test_debian(self):
Expand All @@ -49,16 +55,16 @@ def test_debian(self):
assert c.sw_edition is None
assert c.target_sw is None
assert c.target_hw is None
assert c.other.deb_source == 'hello'
assert c.other.deb_version == '1'
assert c.other_debian.deb_source == 'hello'
assert c.other_debian.deb_version == '1'
assert c.is_debian is True
assert str(c) == s

def test_debian_any(self):
s = r'cpe:2.3:o:debian:debian_linux:12:d:*:*:*:*:*:*'
c = Cpe.parse(s)

assert c.other.deb_source is None
assert c.other.deb_version is None
assert c.other_debian.deb_source is None
assert c.other_debian.deb_version is None
assert c.is_debian is True
assert str(c) == s

0 comments on commit 7060913

Please sign in to comment.