class AvroMerge:
    """Merge an Avro schema with the schemas it references into one JSON text.

    The merged text is a JSON array: every referenced schema comes first
    (depth-first, in ``dependencies`` iteration order) and the schema itself
    comes last.  Most entries are wrapped in a synthetic, uniquely named
    record so that a union schema can sit inside the outer array.
    """

    # Matches schema text whose first non-whitespace character opens a JSON
    # array, i.e. an Avro union.  Compiled once instead of on every call.
    _UNION_START = re.compile(r"^\s*\[")

    def __init__(self, schema_str: str, dependencies: Mapping[str, Dependency]):
        self.schema_str = schema_str
        self.dependencies = dependencies
        # Counter used to give every synthetic wrapper record a distinct name.
        self.unique_id = 0

    def union_safe_schema_str(self, schema_str: str) -> str:
        """Wrap ``schema_str`` in a synthetic record named after ``unique_id``.

        A union (text starting with ``[``) is used as-is as the type of the
        record's single field; any other schema is placed inside a
        ``["string", <schema>]`` union first.
        """
        wrapper_head = (
            "{"
            + f' "name": "___un__ique_x_n_q_karapace____{self.unique_id}",'
            + '"type":"record","fields":[{"name":"name", "type":'
        )
        if self._UNION_START.match(schema_str):
            # in case we meet union - we use it as is
            return wrapper_head + schema_str + "}]}"
        return wrapper_head + ' ["string",' + schema_str + "]}]}"

    def builder(self, schema_str: str, dependencies: Mapping[str, Dependency]) -> str:
        """To support references in AVRO we recursively merge all referenced schemas with current schema

        Returns the comma-separated array members (without the enclosing
        brackets): the merged text of each dependency, then ``schema_str``.
        """
        if not dependencies:
            # A schema reached before any wrapper has been generated is
            # returned untouched; this keeps a reference-free top-level
            # schema exactly as the caller supplied it.
            # NOTE(review): the first leaf dependency also takes this path and
            # is therefore emitted unwrapped - confirm that is intended.
            if self.unique_id == 0:
                return schema_str
            self.unique_id += 1
            return self.union_safe_schema_str(schema_str)

        members = [
            self.builder(dependency.schema.schema_str, dependency.schema.dependencies)
            for dependency in dependencies.values()
        ]
        # The counter is bumped only after all dependencies were rendered, so
        # the wrapper of this schema gets the next free id.
        self.unique_id += 1
        members.append(self.union_safe_schema_str(schema_str))
        return ",\n".join(members)

    def merge(self) -> str:
        """Return the merged schema text as a JSON array.

        The counter is reset first so that calling ``merge()`` repeatedly on
        the same instance always yields the same text.
        """
        self.unique_id = 0
        return "[\n" + self.builder(self.schema_str, self.dependencies) + "\n]"
+ json={"schemaType": "AVRO", "schema": json.dumps(schema_union2), "references": two_references}, + ) + assert res.status_code == 200 + assert "id" in res.json()