Skip to content

Commit

Permalink
fixed streaming response & linting & types
Browse files Browse the repository at this point in the history
  • Loading branch information
nferc committed Jun 13, 2024
1 parent 19d484d commit d39bf55
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 141 deletions.
4 changes: 3 additions & 1 deletion README.MD
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
<!-- markdownlint-disable MD033 -->

# AV-Gate

Proxy für den Antivirus-Scan von Dokumenten zu der elektronischen Patientenakte (ePA). Dieser Proxy wird zwischen dem Konnektor der Gematik und den Primärsystemen geschaltet und überprüft sämtliche Dokumente der ePA vor der Übertragung an die Primärsysteme.
Expand Down Expand Up @@ -111,6 +113,7 @@ In der `avgate.ini` ist für jeden Konnektor eine Gruppe anzulegen. Der Gruppenn
> Auf die Verwendung von Namen statt IP-Adressen wurde verzichtet, weil ein Großteil der Primärsysteme keine Namen in der Konfiguration verwenden kann.
Für jeden Konnektor (jede Gruppe) kann konfiguriert werden:

- konnektor = https://<host>
- ssl_verify = true
Die Zertifikate des Konnektors werden auf Gültigkeit überprüft. Verbindungen mit ungültigen (auch selfsigned) Zertifikaten werden abgelehnt.
Expand Down Expand Up @@ -182,7 +185,6 @@ Die ENV vars:
- LOG_LEVEL - default INFO
einer aus DEBUG INFO WARNING ERROR CRITICAL


## Primärsysteme

Das AV-Gate wurde für folgende Primärsysteme getestet:
Expand Down
130 changes: 70 additions & 60 deletions avgate/avgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import ssl
import types
from email.message import EmailMessage
from typing import Callable, Generator, List, cast
from typing import Any, Callable, Generator, List, cast
from urllib.parse import unquote, urlparse

import flask
import lxml.etree as ET
import requests
import urllib3
from flask import Flask, Response, abort, request, stream_with_context

__version__ = "1.10"

Expand All @@ -37,7 +37,7 @@
# to prevent flooding log
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

app = Flask(__name__)
app = flask.Flask(__name__)

config = configparser.ConfigParser()

Expand All @@ -57,6 +57,8 @@
ALL_PNG_MALICIOUS = config["config"].getboolean("all_png_malicious", False)
ALL_PDF_MALICIOUS = config["config"].getboolean("all_pdf_malicious", False)

clamav_sock: Any = None

if config.has_option("config", "clamd_socket"):
import clamd # type: ignore

Expand All @@ -71,7 +73,7 @@ def connector_sds():
# <si:Service Name="PHRService">
# <si:EndpointTLS Location="https://kon-instanz1.titus.ti-dienste.de:443/ws/PHRService/1.3.0"/>

logger.debug(f"Client Cert: {request.headers.get('X-Client-Cert') or None}")
logger.debug(f"Client Cert: {flask.request.headers.get('X-Client-Cert') or None}")
client_config = get_client_config()
with request_upstream(client_config, warn=False) as upstream:
xml = ET.fromstring(upstream.content)
Expand All @@ -80,15 +82,15 @@ def connector_sds():
for e in xml.findall("{*}ServiceInformation/{*}Service//{*}EndpointTLS"):
previous_url = urlparse(e.attrib["Location"])
e.attrib["Location"] = (
f"{previous_url.scheme}://{request.host}{previous_url.path}"
f"{previous_url.scheme}://{flask.request.host}{previous_url.path}"
)

for e in xml.findall(
"{*}ServiceInformation/{*}Service[@Name='PHRService']//{*}EndpointTLS"
):
previous_url = urlparse(e.attrib["Location"])
e.attrib["Location"] = (
f"{previous_url.scheme}://{request.host}{previous_url.path}"
f"{previous_url.scheme}://{flask.request.host}{previous_url.path}"
)
global phr_service_path
phr_service_path = previous_url.path
Expand Down Expand Up @@ -118,10 +120,10 @@ def fav():
@app.route("/health")
def health():
    """Health check.

    Runs the ClamAV probe and the ICAP probe once each and concatenates
    their failure messages.

    Returns:
        The plain string "OK" (newline-terminated) when neither probe
        reports a problem, otherwise a text/plain response with status 503
        whose body is the concatenated failure messages.
    """
    # `or ""` keeps the concatenation safe even if a probe returns None.
    problems = (check_clamav() or "") + (check_icap() or "")
    if problems:
        return flask.Response(problems, mimetype="text/plain", status=503)
    return "OK\n"


Expand All @@ -136,7 +138,7 @@ def icap():
@app.post("/icap")
def icap_post():
    """Running icap with post data.

    Reads the raw body of the current request and hands it to get_icap().

    Returns:
        Whatever get_icap() returns for the posted body.
    """
    payload = flask.request.get_data()
    return get_icap(payload)


@app.route("/check")
Expand All @@ -158,7 +160,7 @@ def check():

try:
test = requests.request(
method=request.method,
method=flask.request.method,
url=konn + "/connector.sds",
cert=cert,
verify=verify,
Expand All @@ -179,29 +181,35 @@ def check():
res += f"{client} {konn}: {err} \n"
logger.warning(f"check failed for Konnektor: {client} {konn} {err}")

return Response(res, mimetype="text/plain", status=503 if err_count else 200)
return flask.Response(res, mimetype="text/plain", status=503 if err_count else 200)


def check_clamav() -> str:
    """Probe the ClamAV daemon for the health check.

    Returns:
        An empty string when no `clamd_socket` is configured or the daemon
        answers the ping with "PONG"; otherwise the newline-terminated
        failure line "clamav: no ping".
    """
    if not config["config"].get("clamd_socket"):
        return ""

    try:
        answer = clamav_sock.ping()
    except Exception as err:
        # A daemon that cannot be reached at all must show up as a failed
        # probe (503 from the health endpoint), the same way the ICAP probe
        # reports its errors, instead of escaping as an unhandled exception.
        logger.warning(f"Healtchckeck failed for clamav: {err}")
        return "clamav: no ping\n"

    if answer != "PONG":
        logger.warning(f"Healtchckeck failed for clamav: {answer}")
        return "clamav: no ping\n"
    return ""


def check_icap() -> str:
    """Probe the ICAP service for the health check.

    Sends a tiny payload through scan_file_icap(); the scan result itself is
    ignored, only whether the call raises matters.

    Returns:
        An empty string when no `icap_host` is configured or the probe scan
        completes; otherwise the newline-terminated failure line
        "icap: failed".
    """
    if not config["config"].get("icap_host"):
        return ""

    try:
        scan_file_icap(b"ping\r\n")
    except Exception as err:
        logger.warning(f"Healtcheck failed for icap: {err}")
        return "icap: failed\n"
    return ""

def phr_service() -> Response:

def phr_service() -> flask.Response:
"""Scan AV on xop documents for retrieveDocumentSetRequest"""
client_config = get_client_config()
with request_upstream(client_config) as upstream:
Expand All @@ -219,26 +227,22 @@ def phr_service() -> Response:
return response


def other() -> flask.Response:
    """Streamed forward without scan.

    Opens a streaming request to the upstream Konnektor and relays its body
    to the client chunk by chunk.

    Returns:
        The response built by create_response() around the streaming
        generator and the upstream response.
    """
    client_config = get_client_config()
    upstream = request_upstream(client_config, stream=True)

    def generate():
        # The upstream response has to stay open until the body has been
        # relayed; closing it before returning (e.g. via a `with` block
        # around the return) would leave nothing to stream. `finally` closes
        # it once the generator is exhausted or closed early.
        try:
            # An explicit chunk size avoids iter_content()'s default of
            # yielding one byte at a time.
            yield from upstream.iter_content(chunk_size=8192)
        finally:
            upstream.close()

    # A function (not an iterator) is passed so that create_response() takes
    # its streaming branch and calls it itself.
    return create_response(generate, upstream)


def request_upstream(client_config, warn=True, stream=False) -> Response:
def request_upstream(
client_config, warn=True, stream=False
) -> requests.models.Response:
"""Request to real Konnektor"""

konn = client_config["Konnektor"]
url = konn + request.path
data = request.get_data()
url = konn + flask.request.path
data = flask.request.stream if stream else flask.request.get_data()

# client cert
cert = None
Expand All @@ -248,13 +252,13 @@ def request_upstream(client_config, warn=True, stream=False) -> Response:

headers = {
key: value
for key, value in request.headers.items()
for key, value in flask.request.headers.items()
if key not in ("X-Real-Ip", "Host")
}

try:
response = requests.request(
method=request.method,
method=flask.request.method,
url=url,
headers=headers,
data=data,
Expand All @@ -265,26 +269,28 @@ def request_upstream(client_config, warn=True, stream=False) -> Response:

if warn and not stream and bytes(konn, "ascii") in response.content:
logger.warning(
f"Found Konnektor Address in response: {konn} - {request.url}"
f"Found Konnektor Address in response: {konn} - {flask.request.url}"
)

if not response.ok:
logger.warning(
f"Error from Konnektor: {response.url} - {response.status_code} {response.reason}"
)
logger.warning(f"Response: {response.content}")
logger.warning(f"Cert: {request.headers.get('X-Client-Cert')}")
logger.warning(f"Response: {response.content.decode()}")
logger.warning(f"Cert: {flask.request.headers.get('X-Client-Cert')}")

return response

except Exception as err:
logger.error(err)
abort(502)
flask.abort(502)


def get_client_config() -> configparser.ConfigParser:
request_ip = request.headers.get("X-real-ip", request.host.split(":")[0])
port = request.host.split(":")[1] if ":" in request.host else "443"
def get_client_config() -> configparser.SectionProxy:
request_ip = flask.request.headers.get(
"X-real-ip", flask.request.host.split(":")[0]
)
port = flask.request.host.split(":")[1] if ":" in flask.request.host else "443"

client = f"{request_ip}:{port}"
logger.debug(f"client {client}")
Expand All @@ -297,10 +303,12 @@ def get_client_config() -> configparser.ConfigParser:
return config["default"]
else:
logger.error(f"Client {client} or default not found in avgate.ini")
abort(503)
flask.abort(503)


def create_response(data, upstream: Response) -> Response:
def create_response(
data, upstream: requests.models.Response
) -> flask.wrappers.Response:
"""Create new response with copying headers from origin response"""
headers = {
k: v
Expand All @@ -316,17 +324,17 @@ def create_response(data, upstream: Response) -> Response:
)
}

if type(data) is types.FunctionType:
response = Response(
stream_with_context(data()),
if isinstance(data, types.FunctionType):
response = flask.Response(
response=flask.stream_with_context(data()),
status=upstream.status_code,
headers=headers,
mimetype=upstream.headers.get("Mimetype"),
content_type=upstream.headers.get("Content-Type"),
direct_passthrough=True,
)
else:
response = Response(
response = flask.Response(
response=data,
status=upstream.status_code,
headers=headers,
Expand All @@ -338,7 +346,7 @@ def create_response(data, upstream: Response) -> Response:
return response


def run_antivirus(res: requests.Response):
def run_antivirus(res: requests.models.Response):
"""Remove or exchange document when virus was found"""

# only interested in multipart
Expand Down Expand Up @@ -533,7 +541,7 @@ def fix_status(xml_resp, xml_errlist, xml_ns, msg):


def build_payload(
msg: EmailMessage, malicious_content_ids: List[str], res: requests.Response
msg: EmailMessage, malicious_content_ids: List[str], res: requests.models.Response
) -> bytes:
"""create payload based on original response with replacing only payoad for malicious_content_ids"""

Expand Down Expand Up @@ -588,7 +596,7 @@ def get_content_id(content: bytes):

def get_replacement(mimetype) -> bytes:
    """Get content for replacements.

    Looks up the file registered for `mimetype` in `replacement_files`,
    falling back to the "text/plain" entry when there is none (or it is
    empty), and returns that file's raw bytes.

    Args:
        mimetype: key to look up in `replacement_files`.

    Returns:
        The bytes of the selected replacement file.
    """
    # Indexing the fallback (instead of .get) fails at the lookup when the
    # "text/plain" entry is missing, rather than passing None on to open().
    path = replacement_files.get(mimetype) or replacement_files["text/plain"]
    with open(path, "rb") as handle:
        return handle.read()

Expand Down Expand Up @@ -623,6 +631,8 @@ def get_file_scanner() -> Callable[[bytes], List[str | None]]:

def scan_file_clamav(content: bytes) -> List[str | None]:
"""return scan result, do use clamav socket"""
if not isinstance(clamav_sock, clamd.ClamdUnixSocket):
raise AttributeError("clamav socket is not configured")
scan_res = clamav_sock.instream(io.BytesIO(content))["stream"]
return scan_res

Expand All @@ -647,7 +657,7 @@ def scan_file_icap(content: bytes) -> List[str | None]:

# real finding
if found:
return ["FOUND", found[1]]
return ["FOUND", found[1].decode()]

# in case of 200 the content should be unchanged
if content == content_back:
Expand All @@ -656,8 +666,8 @@ def scan_file_icap(content: bytes) -> List[str | None]:

# modified content without infection found
logger.warning("ICAP modified content without findings")
logger.debug(f"IN ...{content[-100:]}")
logger.debug(f"OUT ...{content_back[-100:]})")
logger.debug(f"IN ...{content[-100:].decode()}")
logger.debug(f"OUT ...{content_back[-100:].decode()})")

return ["OK", None]

Expand All @@ -670,7 +680,7 @@ def get_icap(content: bytes) -> bytes:
icap_tls = config["config"].getboolean("icap_tls", False)

req_hdr = "GET /resource HTTP/1.1\r\n"
req_hdr += f"Host: {request.host}\r\n"
req_hdr += f"Host: {flask.request.host}\r\n"
req_hdr += "\r\n"

res_hdr = "HTTP/1.1 200 OK\r\n"
Expand Down Expand Up @@ -702,7 +712,7 @@ def get_icap(content: bytes) -> bytes:
return b"".join(rcv_chunks)


def _open_sock(host: str, port: int, tls: bool) -> socket:
def _open_sock(host: str, port: int, tls: bool) -> socket.socket:
"""returns socket, with TLS if needed"""
if tls:
with socket.create_connection((host, port)) as sock:
Expand Down
Loading

0 comments on commit d39bf55

Please sign in to comment.