From 8f0a35014f40ee2ab910e3433dc95b0f0ea3b6ce Mon Sep 17 00:00:00 2001 From: Jarkko Jaakola Date: Wed, 24 Jul 2024 11:27:04 +0300 Subject: [PATCH] fix: catch all exceptions in forked process If a BaseException is raised in the forked process, the communication queue will block in the calling process and cause timeouts. This is prevalent in the unit tests for protobuf serialization in GitHub Actions, caused by the protobuf compiler work directory already existing. The fix is to use pathlib.Path.mkdir with exist_ok=True and parents=True to avoid the FileExistsError, and to change the error handling logic to also pass BaseExceptions back through the multiprocessing queue. --- karapace/protobuf/io.py | 111 ++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/karapace/protobuf/io.py b/karapace/protobuf/io.py index 9cd36b803..2c87073d3 100644 --- a/karapace/protobuf/io.py +++ b/karapace/protobuf/io.py @@ -13,13 +13,13 @@ from karapace.protobuf.schema import ProtobufSchema from karapace.protobuf.type_element import TypeElement from multiprocessing import Process, Queue +from pathlib import Path from typing import Dict, Final, Generator, Iterable, Protocol from typing_extensions import Self, TypeAlias import hashlib import importlib import importlib.util -import os import subprocess import sys @@ -96,7 +96,7 @@ def get_protobuf_class_instance( class_name: str, cfg: Config, ) -> _ProtobufModel: - directory = cfg["protobuf_runtime_directory"] + directory = Path(cfg["protobuf_runtime_directory"]) deps_list = crawl_dependencies(schema) root_class_name = "" for value in deps_list.values(): @@ -104,21 +104,19 @@ def get_protobuf_class_instance( root_class_name = root_class_name + str(schema) proto_name = calculate_class_name(root_class_name) - proto_path = f"{proto_name}.proto" - work_dir = f"{directory}/{proto_name}" - if not os.path.isdir(directory): - os.mkdir(directory) - if not os.path.isdir(work_dir): - os.mkdir(work_dir) - class_path = 
f"{directory}/{proto_name}/{proto_name}_pb2.py" - if not os.path.exists(class_path): + main_proto_filename = f"{proto_name}.proto" + work_dir = directory / Path(proto_name) + work_dir.mkdir(exist_ok=True, parents=True) + class_path = work_dir / Path(f"{proto_name}_pb2.py") + + if not class_path.exists(): with open(f"{directory}/{proto_name}/{proto_name}.proto", mode="w", encoding="utf8") as proto_text: proto_text.write(replace_imports(str(schema), deps_list)) protoc_arguments = [ "protoc", "--python_out=./", - proto_path, + main_proto_filename, ] for value in deps_list.values(): proto_file_name = value["unique_class_name"] + ".proto" @@ -127,16 +125,18 @@ def get_protobuf_class_instance( with open(dependency_path, mode="w", encoding="utf8") as proto_text: proto_text.write(replace_imports(value["schema"], deps_list)) - if not os.path.isfile(class_path): + if not class_path.is_file(): subprocess.run( protoc_arguments, check=True, cwd=work_dir, ) - # todo: This will leave residues on sys.path in case of exceptions. If really must - # mutate sys.path, we should at least wrap in try-finally. - sys.path.append(f"./runtime/{proto_name}") + runtime_proto_path = f"./runtime/{proto_name}" + if runtime_proto_path not in sys.path: + # todo: This will leave residues on sys.path in case of exceptions. If really must + # mutate sys.path, we should at least wrap in try-finally. + sys.path.append(runtime_proto_path) spec = importlib.util.spec_from_file_location(f"{proto_name}_pb2", class_path) # This is reasonable to assert because we just created this file. 
assert spec is not None @@ -168,25 +168,25 @@ def read_data( return class_instance -_ReaderQueue: TypeAlias = "Queue[dict[object, object] | Exception]" +_ReaderQueue: TypeAlias = "Queue[dict[object, object] | BaseException]" def reader_process( - queue: _ReaderQueue, + reader_queue: _ReaderQueue, config: Config, writer_schema: ProtobufSchema, reader_schema: ProtobufSchema, bio: BytesIO, ) -> None: try: - queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True)) - # todo: This lint ignore does not look reasonable. If it is, reasoning should be - # documented. - except Exception as e: # pylint: disable=broad-except - queue.put(e) + reader_queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True)) + # Reading happens in the forked process, catch is broad so exception will get communicated + # back to calling process. + except BaseException as base_exception: # pylint: disable=broad-except + reader_queue.put(base_exception) -def reader_mp( +def read_in_forked_multiprocess_process( config: Config, writer_schema: ProtobufSchema, reader_schema: ProtobufSchema, @@ -200,14 +200,18 @@ def reader_mp( # To avoid problem with enum values for basic SerDe support we # will isolate work with call protobuf libraries in child process. 
if __name__ == "karapace.protobuf.io": - queue: _ReaderQueue = Queue() - p = Process(target=reader_process, args=(queue, config, writer_schema, reader_schema, bio)) + reader_queue: _ReaderQueue = Queue() + p = Process(target=reader_process, args=(reader_queue, config, writer_schema, reader_schema, bio)) p.start() - result = queue.get() - p.join() + TEN_SECONDS_WAIT = 10 + try: + result = reader_queue.get(True, TEN_SECONDS_WAIT) + finally: + p.join() + reader_queue.close() if isinstance(result, Dict): return result - if isinstance(result, Exception): + if isinstance(result, BaseException): raise result raise IllegalArgumentException() return {"Error": "This never must be returned"} @@ -233,34 +237,37 @@ def __init__( def read(self, bio: BytesIO) -> dict: if self._reader_schema is None: self._reader_schema = self._writer_schema - return reader_mp(self.config, self._writer_schema, self._reader_schema, bio) + return read_in_forked_multiprocess_process(self.config, self._writer_schema, self._reader_schema, bio) -_WriterQueue: TypeAlias = "Queue[bytes | Exception]" +_WriterQueue: TypeAlias = "Queue[bytes | str | BaseException]" def writer_process( - queue: _WriterQueue, + writer_queue: _WriterQueue, config: Config, writer_schema: ProtobufSchema, message_name: str, datum: dict, ) -> None: - class_instance = get_protobuf_class_instance(writer_schema, message_name, config) try: + class_instance = get_protobuf_class_instance(writer_schema, message_name, config) dict_to_protobuf(class_instance, datum) - # todo: This does not look like a reasonable place to catch any exception, - # especially since we're effectively silencing them. - except Exception: - # pylint: disable=raise-missing-from - e = ProtobufTypeException(writer_schema, datum) - queue.put(e) - raise e - queue.put(class_instance.SerializeToString()) - - -# todo: What is mp? Expand the abbreviation or add an explaining comment. 
-def writer_mp( + result = class_instance.SerializeToString() + writer_queue.put(result) + # Writing happens in the forked process, catch is broad so exception will get communicated + # back to calling process. + except Exception as bare_exception: # pylint: disable=broad-exception-caught + try: + raise ProtobufTypeException(writer_schema, datum) from bare_exception + except ProtobufTypeException as protobuf_exception: + writer_queue.put(protobuf_exception) + raise protobuf_exception + except BaseException as base_exception: # pylint: disable=broad-exception-caught + writer_queue.put(base_exception) + + +def write_in_forked_multiprocess_process( config: Config, writer_schema: ProtobufSchema, message_name: str, @@ -274,14 +281,18 @@ def writer_mp( # To avoid problem with enum values for basic SerDe support we # will isolate work with call protobuf libraries in child process. if __name__ == "karapace.protobuf.io": - queue: _WriterQueue = Queue() - p = Process(target=writer_process, args=(queue, config, writer_schema, message_name, datum)) + writer_queue: _WriterQueue = Queue(1) + p = Process(target=writer_process, args=(writer_queue, config, writer_schema, message_name, datum)) p.start() - result = queue.get() - p.join() + TEN_SECONDS_WAIT = 10 + try: + result = writer_queue.get(True, TEN_SECONDS_WAIT) # Block for ten seconds + finally: + p.join() + writer_queue.close() if isinstance(result, bytes): return result - if isinstance(result, Exception): + if isinstance(result, BaseException): raise result raise IllegalArgumentException() raise NotImplementedError("Error: Reached unreachable code") @@ -309,4 +320,4 @@ def write_index(self, writer: BytesIO) -> None: write_indexes(writer, [self._message_index]) def write(self, datum: dict[object, object], writer: BytesIO) -> None: - writer.write(writer_mp(self.config, self._writer_schema, self._message_name, datum)) + writer.write(write_in_forked_multiprocess_process(self.config, self._writer_schema, self._message_name, 
datum))