Merge pull request Aiven-Open#921 from Aiven-Open/jjaakola-aiven-catch-all-exceptions-in-forked-process

fix: catch all exceptions in forked process
keejon authored Jul 24, 2024
2 parents 83e329a + 8f0a350 commit 8afd513
Showing 1 changed file with 61 additions and 50 deletions.
111 changes: 61 additions & 50 deletions karapace/protobuf/io.py
@@ -13,13 +13,13 @@
from karapace.protobuf.schema import ProtobufSchema
from karapace.protobuf.type_element import TypeElement
from multiprocessing import Process, Queue
+ from pathlib import Path
from typing import Dict, Final, Generator, Iterable, Protocol
from typing_extensions import Self, TypeAlias

import hashlib
import importlib
import importlib.util
- import os
import subprocess
import sys

@@ -96,29 +96,27 @@ def get_protobuf_class_instance(
class_name: str,
cfg: Config,
) -> _ProtobufModel:
- directory = cfg["protobuf_runtime_directory"]
+ directory = Path(cfg["protobuf_runtime_directory"])
deps_list = crawl_dependencies(schema)
root_class_name = ""
for value in deps_list.values():
root_class_name = root_class_name + value["unique_class_name"]
root_class_name = root_class_name + str(schema)
proto_name = calculate_class_name(root_class_name)

proto_path = f"{proto_name}.proto"
work_dir = f"{directory}/{proto_name}"
if not os.path.isdir(directory):
os.mkdir(directory)
if not os.path.isdir(work_dir):
os.mkdir(work_dir)
class_path = f"{directory}/{proto_name}/{proto_name}_pb2.py"
if not os.path.exists(class_path):
main_proto_filename = f"{proto_name}.proto"
work_dir = directory / Path(proto_name)
work_dir.mkdir(exist_ok=True, parents=True)
class_path = work_dir / Path(f"{proto_name}_pb2.py")

if not class_path.exists():
with open(f"{directory}/{proto_name}/{proto_name}.proto", mode="w", encoding="utf8") as proto_text:
proto_text.write(replace_imports(str(schema), deps_list))

protoc_arguments = [
"protoc",
"--python_out=./",
- proto_path,
+ main_proto_filename,
]
for value in deps_list.values():
proto_file_name = value["unique_class_name"] + ".proto"
@@ -127,16 +125,18 @@ def get_protobuf_class_instance(
with open(dependency_path, mode="w", encoding="utf8") as proto_text:
proto_text.write(replace_imports(value["schema"], deps_list))

- if not os.path.isfile(class_path):
+ if not class_path.is_file():
subprocess.run(
protoc_arguments,
check=True,
cwd=work_dir,
)

- # todo: This will leave residues on sys.path in case of exceptions. If really must
- # mutate sys.path, we should at least wrap in try-finally.
- sys.path.append(f"./runtime/{proto_name}")
+ runtime_proto_path = f"./runtime/{proto_name}"
+ if runtime_proto_path not in sys.path:
+ # todo: This will leave residues on sys.path in case of exceptions. If really must
+ # mutate sys.path, we should at least wrap in try-finally.
+ sys.path.append(runtime_proto_path)
spec = importlib.util.spec_from_file_location(f"{proto_name}_pb2", class_path)
# This is reasonable to assert because we just created this file.
assert spec is not None
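
# --- Aside (not part of the diff): the hunk above stops right after the spec
# --- assertion. The standard importlib continuation for loading a generated
# --- *_pb2 module from an explicit file path looks like the hedged sketch
# --- below; the helper name load_module_from_path is hypothetical.
import importlib.util
from pathlib import Path
from types import ModuleType

def load_module_from_path(module_name: str, file_path: Path) -> ModuleType:
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the module body, populating its namespace
    return module
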
@@ -168,25 +168,25 @@ def read_data(
return class_instance


- _ReaderQueue: TypeAlias = "Queue[dict[object, object] | Exception]"
+ _ReaderQueue: TypeAlias = "Queue[dict[object, object] | BaseException]"


def reader_process(
- queue: _ReaderQueue,
+ reader_queue: _ReaderQueue,
config: Config,
writer_schema: ProtobufSchema,
reader_schema: ProtobufSchema,
bio: BytesIO,
) -> None:
try:
- queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True))
- # todo: This lint ignore does not look reasonable. If it is, reasoning should be
- # documented.
- except Exception as e: # pylint: disable=broad-except
- queue.put(e)
+ reader_queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True))
+ # Reading happens in the forked process, catch is broad so exception will get communicated
+ # back to calling process.
+ except BaseException as base_exception: # pylint: disable=broad-except
+ reader_queue.put(base_exception)


- def reader_mp(
+ def read_in_forked_multiprocess_process(
config: Config,
writer_schema: ProtobufSchema,
reader_schema: ProtobufSchema,
@@ -200,14 +200,18 @@ def reader_mp(
# To avoid problem with enum values for basic SerDe support we
# will isolate work with call protobuf libraries in child process.
if __name__ == "karapace.protobuf.io":
- queue: _ReaderQueue = Queue()
- p = Process(target=reader_process, args=(queue, config, writer_schema, reader_schema, bio))
+ reader_queue: _ReaderQueue = Queue()
+ p = Process(target=reader_process, args=(reader_queue, config, writer_schema, reader_schema, bio))
p.start()
- result = queue.get()
- p.join()
+ TEN_SECONDS_WAIT = 10
+ try:
+ result = reader_queue.get(True, TEN_SECONDS_WAIT)
+ finally:
+ p.join()
+ reader_queue.close()
if isinstance(result, Dict):
return result
- if isinstance(result, Exception):
+ if isinstance(result, BaseException):
raise result
raise IllegalArgumentException()
return {"Error": "This never must be returned"}
@@ -233,34 +237,37 @@ def __init__(
def read(self, bio: BytesIO) -> dict:
if self._reader_schema is None:
self._reader_schema = self._writer_schema
- return reader_mp(self.config, self._writer_schema, self._reader_schema, bio)
+ return read_in_forked_multiprocess_process(self.config, self._writer_schema, self._reader_schema, bio)


- _WriterQueue: TypeAlias = "Queue[bytes | Exception]"
+ _WriterQueue: TypeAlias = "Queue[bytes | str | BaseException]"


def writer_process(
- queue: _WriterQueue,
+ writer_queue: _WriterQueue,
config: Config,
writer_schema: ProtobufSchema,
message_name: str,
datum: dict,
) -> None:
- class_instance = get_protobuf_class_instance(writer_schema, message_name, config)
try:
+ class_instance = get_protobuf_class_instance(writer_schema, message_name, config)
dict_to_protobuf(class_instance, datum)
- # todo: This does not look like a reasonable place to catch any exception,
- # especially since we're effectively silencing them.
- except Exception:
- # pylint: disable=raise-missing-from
- e = ProtobufTypeException(writer_schema, datum)
- queue.put(e)
- raise e
- queue.put(class_instance.SerializeToString())


- # todo: What is mp? Expand the abbreviation or add an explaining comment.
- def writer_mp(
+ result = class_instance.SerializeToString()
+ writer_queue.put(result)
+ # Writing happens in the forked process, catch is broad so exception will get communicated
+ # back to calling process.
+ except Exception as bare_exception: # pylint: disable=broad-exception-caught
+ try:
+ raise ProtobufTypeException(writer_schema, datum) from bare_exception
+ except ProtobufTypeException as protobuf_exception:
+ writer_queue.put(protobuf_exception)
+ raise protobuf_exception
+ except BaseException as base_exception: # pylint: disable=broad-exception-caught
+ writer_queue.put(base_exception)


+ def write_in_forked_multiprocess_process(
config: Config,
writer_schema: ProtobufSchema,
message_name: str,
@@ -274,14 +281,18 @@ def writer_mp(
# To avoid problem with enum values for basic SerDe support we
# will isolate work with call protobuf libraries in child process.
if __name__ == "karapace.protobuf.io":
- queue: _WriterQueue = Queue()
- p = Process(target=writer_process, args=(queue, config, writer_schema, message_name, datum))
+ writer_queue: _WriterQueue = Queue(1)
+ p = Process(target=writer_process, args=(writer_queue, config, writer_schema, message_name, datum))
p.start()
- result = queue.get()
- p.join()
+ TEN_SECONDS_WAIT = 10
+ try:
+ result = writer_queue.get(True, TEN_SECONDS_WAIT) # Block for ten seconds
+ finally:
+ p.join()
+ writer_queue.close()
if isinstance(result, bytes):
return result
- if isinstance(result, Exception):
+ if isinstance(result, BaseException):
raise result
raise IllegalArgumentException()
raise NotImplementedError("Error: Reached unreachable code")
@@ -309,4 +320,4 @@ def write_index(self, writer: BytesIO) -> None:
write_indexes(writer, [self._message_index])

def write(self, datum: dict[object, object], writer: BytesIO) -> None:
- writer.write(writer_mp(self.config, self._writer_schema, self._message_name, datum))
+ writer.write(write_in_forked_multiprocess_process(self.config, self._writer_schema, self._message_name, datum))
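
A hedged illustration of the failure mode this commit guards against (not Karapace code; the name child_that_crashes is made up): if the child process dies without putting anything on the queue, for example via a BaseException such as SystemExit that the old except Exception handler would not relay, the parent's bounded get() now raises queue.Empty after the timeout, whereas the previous unbounded queue.get() would block forever.

import queue
from multiprocessing import Process, Queue

def child_that_crashes(out_queue: "Queue[object]") -> None:
    # Exits via a BaseException that a plain `except Exception` would not catch.
    raise SystemExit(1)

if __name__ == "__main__":
    out_queue: "Queue[object]" = Queue()
    worker = Process(target=child_that_crashes, args=(out_queue,))
    worker.start()
    try:
        out_queue.get(True, 2)  # bounded wait: two seconds
    except queue.Empty:
        print("child produced no result within the timeout")
    finally:
        worker.join()
        out_queue.close()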
