From 2527f54972e58f5053b425bbef6b906e3c7c924c Mon Sep 17 00:00:00 2001 From: skrydal Date: Mon, 18 Nov 2024 19:41:45 +0100 Subject: [PATCH 1/4] feat(ingest/iceberg): Iceberg performance improvement (multi-threading) (#11182) --- .../ingestion/source/iceberg/iceberg.py | 178 +++++-- .../source/iceberg/iceberg_common.py | 60 +++ .../source/iceberg/iceberg_profiler.py | 169 +++--- .../utilities/threaded_iterator_executor.py | 1 - .../iceberg_multiprocessing_to_file.yml | 22 + .../tests/integration/iceberg/test_iceberg.py | 64 ++- metadata-ingestion/tests/unit/test_iceberg.py | 479 +++++++++++++++++- 7 files changed, 821 insertions(+), 152 deletions(-) create mode 100644 metadata-ingestion/tests/integration/iceberg/iceberg_multiprocessing_to_file.yml diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py index d8c6c03ce81e67..258a4b9ad6daf6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py @@ -1,10 +1,15 @@ import json import logging +import threading import uuid from typing import Any, Dict, Iterable, List, Optional from pyiceberg.catalog import Catalog -from pyiceberg.exceptions import NoSuchIcebergTableError +from pyiceberg.exceptions import ( + NoSuchIcebergTableError, + NoSuchNamespaceError, + NoSuchPropertyException, +) from pyiceberg.schema import Schema, SchemaVisitorPerPrimitiveType, visit from pyiceberg.table import Table from pyiceberg.typedef import Identifier @@ -75,6 +80,8 @@ OwnershipClass, OwnershipTypeClass, ) +from datahub.utilities.perf_timer import PerfTimer +from datahub.utilities.threaded_iterator_executor import ThreadedIteratorExecutor LOGGER = logging.getLogger(__name__) logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel( @@ -130,74 +137,149 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: ] def _get_datasets(self, catalog: Catalog) -> Iterable[Identifier]: - for namespace in catalog.list_namespaces(): - yield from catalog.list_tables(namespace) + namespaces = catalog.list_namespaces() + LOGGER.debug( + f"Retrieved {len(namespaces)} namespaces, first 10: {namespaces[:10]}" + ) + self.report.report_no_listed_namespaces(len(namespaces)) + tables_count = 0 + for namespace in namespaces: + try: + tables = catalog.list_tables(namespace) + tables_count += len(tables) + LOGGER.debug( + f"Retrieved {len(tables)} tables for namespace: {namespace}, in total retrieved {tables_count}, first 10: {tables[:10]}" + ) + self.report.report_listed_tables_for_namespace( + ".".join(namespace), len(tables) + ) + yield from tables + except NoSuchNamespaceError: + self.report.report_warning( + "no-such-namespace", + f"Couldn't list tables for namespace {namespace} due to NoSuchNamespaceError exception", + ) + LOGGER.warning( + f"NoSuchNamespaceError exception while trying to get list of tables from namespace {namespace}, skipping it", + ) + except Exception as e: + self.report.report_failure( + "listing-tables-exception", + f"Couldn't list tables for namespace {namespace} due to {e}", + ) + LOGGER.exception( + f"Unexpected exception while trying to get list of tables for namespace {namespace}, skipping it" + ) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: - try: - catalog = self.config.get_catalog() - except Exception as e: - LOGGER.error("Failed to get catalog", exc_info=True) - self.report.report_failure("get-catalog", f"Failed to get catalog: {e}") - return + thread_local = threading.local() - for dataset_path in self._get_datasets(catalog): + def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]: + LOGGER.debug(f"Processing dataset for path {dataset_path}") dataset_name = ".".join(dataset_path) if not self.config.table_pattern.allowed(dataset_name): # Dataset name is rejected by pattern, report as dropped. self.report.report_dropped(dataset_name) - continue - + return try: - # Try to load an Iceberg table. Might not contain one, this will be caught by NoSuchIcebergTableError. - table = catalog.load_table(dataset_path) + if not hasattr(thread_local, "local_catalog"): + LOGGER.debug( + f"Didn't find local_catalog in thread_local ({thread_local}), initializing new catalog" + ) + thread_local.local_catalog = self.config.get_catalog() + + with PerfTimer() as timer: + table = thread_local.local_catalog.load_table(dataset_path) + time_taken = timer.elapsed_seconds() + self.report.report_table_load_time(time_taken) + LOGGER.debug( + f"Loaded table: {table.identifier}, time taken: {time_taken}" + ) yield from self._create_iceberg_workunit(dataset_name, table) + except NoSuchPropertyException as e: + self.report.report_warning( + "table-property-missing", + f"Failed to create workunit for {dataset_name}. {e}", + ) + LOGGER.warning( + f"NoSuchPropertyException while processing table {dataset_path}, skipping it.", + ) except NoSuchIcebergTableError as e: + self.report.report_warning( + "no-iceberg-table", + f"Failed to create workunit for {dataset_name}. {e}", + ) + LOGGER.warning( + f"NoSuchIcebergTableError while processing table {dataset_path}, skipping it.", + ) + except Exception as e: self.report.report_failure("general", f"Failed to create workunit: {e}") LOGGER.exception( f"Exception while processing table {dataset_path}, skipping it.", ) + try: + catalog = self.config.get_catalog() + except Exception as e: + self.report.report_failure("get-catalog", f"Failed to get catalog: {e}") + return + + for wu in ThreadedIteratorExecutor.process( + worker_func=_process_dataset, + args_list=[(dataset_path,) for dataset_path in self._get_datasets(catalog)], + max_workers=self.config.processing_threads, + ): + yield wu + def _create_iceberg_workunit( self, dataset_name: str, table: Table ) -> Iterable[MetadataWorkUnit]: - self.report.report_table_scanned(dataset_name) - dataset_urn: str = make_dataset_urn_with_platform_instance( - self.platform, - dataset_name, - self.config.platform_instance, - self.config.env, - ) - dataset_snapshot = DatasetSnapshot( - urn=dataset_urn, - aspects=[Status(removed=False)], - ) - - # Dataset properties aspect. - custom_properties = table.metadata.properties.copy() - custom_properties["location"] = table.metadata.location - custom_properties["format-version"] = str(table.metadata.format_version) - custom_properties["partition-spec"] = str(self._get_partition_aspect(table)) - if table.current_snapshot(): - custom_properties["snapshot-id"] = str(table.current_snapshot().snapshot_id) - custom_properties["manifest-list"] = table.current_snapshot().manifest_list - dataset_properties = DatasetPropertiesClass( - name=table.name()[-1], - tags=[], - description=table.metadata.properties.get("comment", None), - customProperties=custom_properties, - ) - dataset_snapshot.aspects.append(dataset_properties) + with PerfTimer() as timer: + self.report.report_table_scanned(dataset_name) + LOGGER.debug(f"Processing table {dataset_name}") + dataset_urn: str = make_dataset_urn_with_platform_instance( + self.platform, + dataset_name, + self.config.platform_instance, + self.config.env, + ) + dataset_snapshot = DatasetSnapshot( + urn=dataset_urn, + aspects=[Status(removed=False)], + ) - # Dataset ownership aspect. - dataset_ownership = self._get_ownership_aspect(table) - if dataset_ownership: - dataset_snapshot.aspects.append(dataset_ownership) + # Dataset properties aspect. + custom_properties = table.metadata.properties.copy() + custom_properties["location"] = table.metadata.location + custom_properties["format-version"] = str(table.metadata.format_version) + custom_properties["partition-spec"] = str(self._get_partition_aspect(table)) + if table.current_snapshot(): + custom_properties["snapshot-id"] = str( + table.current_snapshot().snapshot_id + ) + custom_properties[ + "manifest-list" + ] = table.current_snapshot().manifest_list + dataset_properties = DatasetPropertiesClass( + name=table.name()[-1], + tags=[], + description=table.metadata.properties.get("comment", None), + customProperties=custom_properties, + ) + dataset_snapshot.aspects.append(dataset_properties) + # Dataset ownership aspect. + dataset_ownership = self._get_ownership_aspect(table) + if dataset_ownership: + LOGGER.debug( + f"Adding ownership: {dataset_ownership} to the dataset {dataset_name}" + ) + dataset_snapshot.aspects.append(dataset_ownership) - schema_metadata = self._create_schema_metadata(dataset_name, table) - dataset_snapshot.aspects.append(schema_metadata) + schema_metadata = self._create_schema_metadata(dataset_name, table) + dataset_snapshot.aspects.append(schema_metadata) - mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) + mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot) + self.report.report_table_processing_time(timer.elapsed_seconds()) yield MetadataWorkUnit(id=dataset_name, mce=mce) dpi_aspect = self._get_dataplatform_instance_aspect(dataset_urn=dataset_urn) diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py index b74c096d0798e8..98ad9e552d35c9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional +from humanfriendly import format_timespan from pydantic import Field, validator from pyiceberg.catalog import Catalog, load_catalog @@ -18,6 +19,7 @@ OperationConfig, is_profiling_enabled, ) +from datahub.utilities.stats_collections import TopKDict, int_top_k_dict logger = logging.getLogger(__name__) @@ -75,6 +77,9 @@ class IcebergSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin) description="Iceberg table property to look for a `CorpGroup` owner. Can only hold a single group value. If property has no value, no owner information will be emitted.", ) profiling: IcebergProfilingConfig = IcebergProfilingConfig() + processing_threads: int = Field( + default=1, description="How many threads will be processing tables" + ) @validator("catalog", pre=True, always=True) def handle_deprecated_catalog_format(cls, value): @@ -131,17 +136,72 @@ def get_catalog(self) -> Catalog: # Retrieve the dict associated with the one catalog entry catalog_name, catalog_config = next(iter(self.catalog.items())) + logger.debug( + "Initializing the catalog %s with config: %s", catalog_name, catalog_config + ) return load_catalog(name=catalog_name, **catalog_config) +class TimingClass: + times: List[int] + + def __init__(self): + self.times = [] + + def add_timing(self, t): + self.times.append(t) + + def __str__(self): + if len(self.times) == 0: + return "no timings reported" + self.times.sort() + total = sum(self.times) + avg = total / len(self.times) + return str( + { + "average_time": format_timespan(avg, detailed=True, max_units=3), + "min_time": format_timespan(self.times[0], detailed=True, max_units=3), + "max_time": format_timespan(self.times[-1], detailed=True, max_units=3), + # total_time does not provide correct information in case we run in more than 1 thread + "total_time": format_timespan(total, detailed=True, max_units=3), + } + ) + + @dataclass class IcebergSourceReport(StaleEntityRemovalSourceReport): tables_scanned: int = 0 entities_profiled: int = 0 filtered: List[str] = field(default_factory=list) + load_table_timings: TimingClass = field(default_factory=TimingClass) + processing_table_timings: TimingClass = field(default_factory=TimingClass) + profiling_table_timings: TimingClass = field(default_factory=TimingClass) + listed_namespaces: int = 0 + total_listed_tables: int = 0 + tables_listed_per_namespace: TopKDict[str, int] = field( + default_factory=int_top_k_dict + ) + + def report_listed_tables_for_namespace( + self, namespace: str, no_tables: int + ) -> None: + self.tables_listed_per_namespace[namespace] = no_tables + self.total_listed_tables += no_tables + + def report_no_listed_namespaces(self, amount: int) -> None: + self.listed_namespaces = amount def report_table_scanned(self, name: str) -> None: self.tables_scanned += 1 def report_dropped(self, ent_name: str) -> None: self.filtered.append(ent_name) + + def report_table_load_time(self, t: float) -> None: + self.load_table_timings.add_timing(t) + + def report_table_processing_time(self, t: float) -> None: + self.processing_table_timings.add_timing(t) + + def report_table_profiling_time(self, t: float) -> None: + self.profiling_table_timings.add_timing(t) diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py index e1d52752d779a0..9cc6dd08544e4e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_profiler.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Callable, Dict, Iterable, Union, cast from pyiceberg.conversions import from_bytes @@ -33,6 +34,9 @@ DatasetFieldProfileClass, DatasetProfileClass, ) +from datahub.utilities.perf_timer import PerfTimer + +LOGGER = logging.getLogger(__name__) class IcebergProfiler: @@ -109,92 +113,101 @@ def profile_table( Yields: Iterator[Iterable[MetadataWorkUnit]]: Workunits related to datasetProfile. """ - current_snapshot = table.current_snapshot() - if not current_snapshot: - # Table has no data, cannot profile, or we can't get current_snapshot. - return - - row_count = ( - int(current_snapshot.summary.additional_properties["total-records"]) - if current_snapshot.summary - else 0 - ) - column_count = len( - [ - field.field_id - for field in table.schema().fields - if field.field_type.is_primitive - ] - ) - dataset_profile = DatasetProfileClass( - timestampMillis=get_sys_time(), - rowCount=row_count, - columnCount=column_count, - ) - dataset_profile.fieldProfiles = [] + with PerfTimer() as timer: + LOGGER.debug(f"Starting profiling of dataset: {dataset_name}") + current_snapshot = table.current_snapshot() + if not current_snapshot: + # Table has no data, cannot profile, or we can't get current_snapshot. + return + + row_count = ( + int(current_snapshot.summary.additional_properties["total-records"]) + if current_snapshot.summary + else 0 + ) + column_count = len( + [ + field.field_id + for field in table.schema().fields + if field.field_type.is_primitive + ] + ) + dataset_profile = DatasetProfileClass( + timestampMillis=get_sys_time(), + rowCount=row_count, + columnCount=column_count, + ) + dataset_profile.fieldProfiles = [] - total_count = 0 - null_counts: Dict[int, int] = {} - min_bounds: Dict[int, Any] = {} - max_bounds: Dict[int, Any] = {} - try: - for manifest in current_snapshot.manifests(table.io): - for manifest_entry in manifest.fetch_manifest_entry(table.io): - data_file = manifest_entry.data_file + total_count = 0 + null_counts: Dict[int, int] = {} + min_bounds: Dict[int, Any] = {} + max_bounds: Dict[int, Any] = {} + try: + for manifest in current_snapshot.manifests(table.io): + for manifest_entry in manifest.fetch_manifest_entry(table.io): + data_file = manifest_entry.data_file + if self.config.include_field_null_count: + null_counts = self._aggregate_counts( + null_counts, data_file.null_value_counts + ) + if self.config.include_field_min_value: + self._aggregate_bounds( + table.schema(), + min, + min_bounds, + data_file.lower_bounds, + ) + if self.config.include_field_max_value: + self._aggregate_bounds( + table.schema(), + max, + max_bounds, + data_file.upper_bounds, + ) + total_count += data_file.record_count + except Exception as e: + # Catch any errors that arise from attempting to read the Iceberg table's manifests + # This will prevent stateful ingestion from being blocked by an error (profiling is not critical) + self.report.report_warning( + "profiling", + f"Error while profiling dataset {dataset_name}: {e}", + ) + if row_count: + # Iterating through fieldPaths introduces unwanted stats for list element fields... + for field_path, field_id in table.schema()._name_to_id.items(): + field = table.schema().find_field(field_id) + column_profile = DatasetFieldProfileClass(fieldPath=field_path) if self.config.include_field_null_count: - null_counts = self._aggregate_counts( - null_counts, data_file.null_value_counts + column_profile.nullCount = cast( + int, null_counts.get(field_id, 0) + ) + column_profile.nullProportion = float( + column_profile.nullCount / row_count ) + if self.config.include_field_min_value: - self._aggregate_bounds( - table.schema(), - min, - min_bounds, - data_file.lower_bounds, + column_profile.min = ( + self._render_value( + dataset_name, field.field_type, min_bounds.get(field_id) + ) + if field_id in min_bounds + else None ) if self.config.include_field_max_value: - self._aggregate_bounds( - table.schema(), - max, - max_bounds, - data_file.upper_bounds, + column_profile.max = ( + self._render_value( + dataset_name, field.field_type, max_bounds.get(field_id) + ) + if field_id in max_bounds + else None ) - total_count += data_file.record_count - except Exception as e: - # Catch any errors that arise from attempting to read the Iceberg table's manifests - # This will prevent stateful ingestion from being blocked by an error (profiling is not critical) - self.report.report_warning( - "profiling", - f"Error while profiling dataset {dataset_name}: {e}", + dataset_profile.fieldProfiles.append(column_profile) + time_taken = timer.elapsed_seconds() + self.report.report_table_profiling_time(time_taken) + LOGGER.debug( + f"Finished profiling of dataset: {dataset_name} in {time_taken}" ) - if row_count: - # Iterating through fieldPaths introduces unwanted stats for list element fields... - for field_path, field_id in table.schema()._name_to_id.items(): - field = table.schema().find_field(field_id) - column_profile = DatasetFieldProfileClass(fieldPath=field_path) - if self.config.include_field_null_count: - column_profile.nullCount = cast(int, null_counts.get(field_id, 0)) - column_profile.nullProportion = float( - column_profile.nullCount / row_count - ) - - if self.config.include_field_min_value: - column_profile.min = ( - self._render_value( - dataset_name, field.field_type, min_bounds.get(field_id) - ) - if field_id in min_bounds - else None - ) - if self.config.include_field_max_value: - column_profile.max = ( - self._render_value( - dataset_name, field.field_type, max_bounds.get(field_id) - ) - if field_id in max_bounds - else None - ) - dataset_profile.fieldProfiles.append(column_profile) yield MetadataChangeProposalWrapper( entityUrn=dataset_urn, diff --git a/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py b/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py index 216fa155035d3e..4d328ad31c6c4a 100644 --- a/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py +++ b/metadata-ingestion/src/datahub/utilities/threaded_iterator_executor.py @@ -46,7 +46,6 @@ def _worker_wrapper( futures = [f for f in futures if not f.done()] if not futures: break - # Yield the remaining work units. This theoretically should not happen, but adding it just in case. while not out_q.empty(): yield out_q.get_nowait() diff --git a/metadata-ingestion/tests/integration/iceberg/iceberg_multiprocessing_to_file.yml b/metadata-ingestion/tests/integration/iceberg/iceberg_multiprocessing_to_file.yml new file mode 100644 index 00000000000000..e5e866fb561c9b --- /dev/null +++ b/metadata-ingestion/tests/integration/iceberg/iceberg_multiprocessing_to_file.yml @@ -0,0 +1,22 @@ +run_id: iceberg-test + +source: + type: iceberg + config: + processing_threads: 5 + catalog: + default: + type: rest + uri: http://localhost:8181 + s3.access-key-id: admin + s3.secret-access-key: password + s3.region: us-east-1 + warehouse: s3a://warehouse/wh/ + s3.endpoint: http://localhost:9000 + user_ownership_property: owner + group_ownership_property: owner + +sink: + type: file + config: + filename: "./iceberg_mces.json" diff --git a/metadata-ingestion/tests/integration/iceberg/test_iceberg.py b/metadata-ingestion/tests/integration/iceberg/test_iceberg.py index 5a12afa457f01a..85809e557dd8d3 100644 --- a/metadata-ingestion/tests/integration/iceberg/test_iceberg.py +++ b/metadata-ingestion/tests/integration/iceberg/test_iceberg.py @@ -1,5 +1,5 @@ import subprocess -from typing import Any, Dict, List +from typing import Any, Dict from unittest.mock import patch import pytest @@ -18,6 +18,13 @@ FROZEN_TIME = "2020-04-14 07:00:00" GMS_PORT = 8080 GMS_SERVER = f"http://localhost:{GMS_PORT}" +# These paths change from one instance run of the clickhouse docker to the other, and the FROZEN_TIME does not apply to +# these. +PATHS_IN_GOLDEN_FILE_TO_IGNORE = [ + r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['created-at'\]", + r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['snapshot-id'\]", + r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['manifest-list'\]", +] @pytest.fixture(autouse=True, scope="module") @@ -35,6 +42,36 @@ def spark_submit(file_path: str, args: str = "") -> None: assert ret.returncode == 0 +@freeze_time(FROZEN_TIME) +def test_multiprocessing_iceberg_ingest( + docker_compose_runner, pytestconfig, tmp_path, mock_time +): + test_resources_dir = pytestconfig.rootpath / "tests/integration/iceberg/" + + with docker_compose_runner( + test_resources_dir / "docker-compose.yml", "iceberg" + ) as docker_services: + wait_for_port(docker_services, "spark-iceberg", 8888, timeout=120) + + # Run the create.py pyspark file to populate the table. + spark_submit("/home/iceberg/setup/create.py", "nyc.taxis") + + # Run the metadata ingestion pipeline. + config_file = ( + test_resources_dir / "iceberg_multiprocessing_to_file.yml" + ).resolve() + run_datahub_cmd( + ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path + ) + # Verify the output. + mce_helpers.check_golden_file( + pytestconfig, + ignore_paths=PATHS_IN_GOLDEN_FILE_TO_IGNORE, + output_path=tmp_path / "iceberg_mces.json", + golden_path=test_resources_dir / "iceberg_ingest_mces_golden.json", + ) + + @freeze_time(FROZEN_TIME) def test_iceberg_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): test_resources_dir = pytestconfig.rootpath / "tests/integration/iceberg/" @@ -52,16 +89,10 @@ def test_iceberg_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time run_datahub_cmd( ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path ) - # These paths change from one instance run of the clickhouse docker to the other, and the FROZEN_TIME does not apply to these. - ignore_paths: List[str] = [ - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['created-at'\]", - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['snapshot-id'\]", - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['manifest-list'\]", - ] # Verify the output. mce_helpers.check_golden_file( pytestconfig, - ignore_paths=ignore_paths, + ignore_paths=PATHS_IN_GOLDEN_FILE_TO_IGNORE, output_path=tmp_path / "iceberg_mces.json", golden_path=test_resources_dir / "iceberg_ingest_mces_golden.json", ) @@ -170,16 +201,10 @@ def test_iceberg_stateful_ingest( pipeline=pipeline_run2, expected_providers=1 ) - ignore_paths: List[str] = [ - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['created-at'\]", - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['snapshot-id'\]", - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['manifest-list'\]", - ] - # Verify the output. mce_helpers.check_golden_file( pytestconfig, - ignore_paths=ignore_paths, + ignore_paths=PATHS_IN_GOLDEN_FILE_TO_IGNORE, output_path=deleted_mces_path, golden_path=test_resources_dir / "iceberg_deleted_table_mces_golden.json", ) @@ -202,16 +227,11 @@ def test_iceberg_profiling(docker_compose_runner, pytestconfig, tmp_path, mock_t run_datahub_cmd( ["ingest", "--strict-warnings", "-c", f"{config_file}"], tmp_path=tmp_path ) - # These paths change from one instance run of the clickhouse docker to the other, and the FROZEN_TIME does not apply to these. - ignore_paths: List[str] = [ - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['created-at'\]", - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['snapshot-id'\]", - r"root\[\d+\]\['proposedSnapshot'\].+\['aspects'\].+\['customProperties'\]\['manifest-list'\]", - ] + # Verify the output. mce_helpers.check_golden_file( pytestconfig, - ignore_paths=ignore_paths, + ignore_paths=PATHS_IN_GOLDEN_FILE_TO_IGNORE, output_path=tmp_path / "iceberg_mces.json", golden_path=test_resources_dir / "iceberg_profile_mces_golden.json", ) diff --git a/metadata-ingestion/tests/unit/test_iceberg.py b/metadata-ingestion/tests/unit/test_iceberg.py index c8c6c6ac8a85d3..b8a136586a2bf5 100644 --- a/metadata-ingestion/tests/unit/test_iceberg.py +++ b/metadata-ingestion/tests/unit/test_iceberg.py @@ -1,10 +1,21 @@ import uuid from decimal import Decimal -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from unittest import TestCase +from unittest.mock import patch import pytest from pydantic import ValidationError +from pyiceberg.exceptions import ( + NoSuchIcebergTableError, + NoSuchNamespaceError, + NoSuchPropertyException, +) +from pyiceberg.io.pyarrow import PyArrowFileIO +from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.types import ( BinaryType, BooleanType, @@ -29,16 +40,19 @@ ) from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.iceberg.iceberg import ( IcebergProfiler, IcebergSource, IcebergSourceConfig, ) +from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent from datahub.metadata.com.linkedin.pegasus2avro.schema import ArrayType, SchemaField from datahub.metadata.schema_classes import ( ArrayTypeClass, BooleanTypeClass, BytesTypeClass, + DatasetSnapshotClass, DateTypeClass, FixedTypeClass, NumberTypeClass, @@ -48,11 +62,13 @@ ) -def with_iceberg_source() -> IcebergSource: +def with_iceberg_source(processing_threads: int = 1) -> IcebergSource: catalog = {"test": {"type": "rest"}} return IcebergSource( ctx=PipelineContext(run_id="iceberg-source-test"), - config=IcebergSourceConfig(catalog=catalog), + config=IcebergSourceConfig( + catalog=catalog, processing_threads=processing_threads + ), ) @@ -515,3 +531,460 @@ def test_avro_decimal_bytes_nullable() -> None: print( f"After avro parsing, _nullable attribute is preserved: {boolean_avro_schema}" ) + + +class MockCatalog: + def __init__(self, tables: Dict[str, Dict[str, Callable[[], Table]]]): + """ + + :param tables: Dictionary containing namespaces as keys and dictionaries containing names of tables (keys) and + their metadata as values + """ + self.tables = tables + + def list_namespaces(self) -> Iterable[str]: + return [*self.tables.keys()] + + def list_tables(self, namespace: str) -> Iterable[Tuple[str, str]]: + return [(namespace, table) for table in self.tables[namespace].keys()] + + def load_table(self, dataset_path: Tuple[str, str]) -> Table: + return self.tables[dataset_path[0]][dataset_path[1]]() + + +class MockCatalogExceptionListingTables(MockCatalog): + def list_tables(self, namespace: str) -> Iterable[Tuple[str, str]]: + if namespace == "no_such_namespace": + raise NoSuchNamespaceError() + if namespace == "generic_exception": + raise Exception() + return super().list_tables(namespace) + + +class MockCatalogExceptionListingNamespaces(MockCatalog): + def list_namespaces(self) -> Iterable[str]: + raise Exception() + + +def test_exception_while_listing_namespaces() -> None: + source = with_iceberg_source(processing_threads=2) + mock_catalog = MockCatalogExceptionListingNamespaces({}) + with patch( + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + ) as get_catalog, pytest.raises(Exception): + get_catalog.return_value = mock_catalog + [*source.get_workunits_internal()] + + +def test_known_exception_while_listing_tables() -> None: + source = with_iceberg_source(processing_threads=2) + mock_catalog = MockCatalogExceptionListingTables( + { + "namespaceA": { + "table1": lambda: Table( + identifier=("namespaceA", "table1"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table1", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table1", + io=PyArrowFileIO(), + catalog=None, + ) + }, + "no_such_namespace": {}, + "namespaceB": { + "table2": lambda: Table( + identifier=("namespaceB", "table2"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceB/table2", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceB/table2", + io=PyArrowFileIO(), + catalog=None, + ), + "table3": lambda: Table( + identifier=("namespaceB", "table3"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceB/table3", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceB/table3", + io=PyArrowFileIO(), + catalog=None, + ), + }, + "namespaceC": { + "table4": lambda: Table( + identifier=("namespaceC", "table4"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceC/table4", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceC/table4", + io=PyArrowFileIO(), + catalog=None, + ) + }, + "namespaceD": { + "table5": lambda: Table( + identifier=("namespaceD", "table5"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table5", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table5", + io=PyArrowFileIO(), + catalog=None, + ) + }, + } + ) + with patch( + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + ) as get_catalog: + get_catalog.return_value = mock_catalog + wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] + assert len(wu) == 5 # ingested 5 tables, despite exception + urns = [] + for unit in wu: + assert isinstance(unit.metadata, MetadataChangeEvent) + assert isinstance(unit.metadata.proposedSnapshot, DatasetSnapshotClass) + urns.append(unit.metadata.proposedSnapshot.urn) + TestCase().assertCountEqual( + urns, + [ + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceB.table2,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceB.table3,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceC.table4,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceD.table5,PROD)", + ], + ) + assert source.report.warnings.total_elements == 1 + assert source.report.failures.total_elements == 0 + assert source.report.tables_scanned == 5 + + +def test_unknown_exception_while_listing_tables() -> None: + source = with_iceberg_source(processing_threads=2) + mock_catalog = MockCatalogExceptionListingTables( + { + "namespaceA": { + "table1": lambda: Table( + identifier=("namespaceA", "table1"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table1", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table1", + io=PyArrowFileIO(), + catalog=None, + ) + }, + "generic_exception": {}, + "namespaceB": { + "table2": lambda: Table( + identifier=("namespaceB", "table2"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceB/table2", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceB/table2", + io=PyArrowFileIO(), + catalog=None, + ), + "table3": lambda: Table( + identifier=("namespaceB", "table3"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceB/table3", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceB/table3", + io=PyArrowFileIO(), + catalog=None, + ), + }, + "namespaceC": { + "table4": lambda: Table( + identifier=("namespaceC", "table4"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceC/table4", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceC/table4", + io=PyArrowFileIO(), + catalog=None, + ) + }, + "namespaceD": { + "table5": lambda: Table( + identifier=("namespaceD", "table5"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table5", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table5", + io=PyArrowFileIO(), + catalog=None, + ) + }, + } + ) + with patch( + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + ) as get_catalog: + get_catalog.return_value = mock_catalog + wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] + assert len(wu) == 5 # ingested 5 tables, despite exception + urns = [] + for unit in wu: + assert isinstance(unit.metadata, MetadataChangeEvent) + assert isinstance(unit.metadata.proposedSnapshot, DatasetSnapshotClass) + urns.append(unit.metadata.proposedSnapshot.urn) + TestCase().assertCountEqual( + urns, + [ + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceB.table2,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceB.table3,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceC.table4,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceD.table5,PROD)", + ], + ) + assert source.report.warnings.total_elements == 0 + assert source.report.failures.total_elements == 1 + assert source.report.tables_scanned == 5 + + +def test_proper_run_with_multiple_namespaces() -> None: + source = with_iceberg_source(processing_threads=3) + mock_catalog = MockCatalog( + { + "namespaceA": { + "table1": lambda: Table( + identifier=("namespaceA", "table1"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table1", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table1", + io=PyArrowFileIO(), + catalog=None, + ) + }, + "namespaceB": {}, + } + ) + with patch( + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + ) as get_catalog: + get_catalog.return_value = mock_catalog + wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] + assert len(wu) == 1 # only one table processed as an MCE + assert isinstance(wu[0].metadata, MetadataChangeEvent) + assert isinstance(wu[0].metadata.proposedSnapshot, DatasetSnapshotClass) + snapshot: DatasetSnapshotClass = wu[0].metadata.proposedSnapshot + assert ( + snapshot.urn + == "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table1,PROD)" + ) + + +def test_handle_expected_exceptions() -> None: + source = with_iceberg_source(processing_threads=3) + + def _raise_no_such_property_exception(): + raise NoSuchPropertyException() + + def _raise_no_such_table_exception(): + raise NoSuchIcebergTableError() + + mock_catalog = MockCatalog( + { + "namespaceA": { + "table1": lambda: Table( + identifier=("namespaceA", "table1"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table1", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table1", + io=PyArrowFileIO(), + catalog=None, + ), + "table2": lambda: Table( + identifier=("namespaceA", "table2"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table2", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table2", + io=PyArrowFileIO(), + catalog=None, + ), + "table3": lambda: Table( + identifier=("namespaceA", "table3"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table3", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table3", + io=PyArrowFileIO(), + catalog=None, + ), + "table4": lambda: Table( + identifier=("namespaceA", "table4"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table4", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table4", + io=PyArrowFileIO(), + catalog=None, + ), + "table5": _raise_no_such_property_exception, + "table6": _raise_no_such_table_exception, + } + } + ) + with patch( + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + ) as get_catalog: + get_catalog.return_value = mock_catalog + wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] + assert len(wu) == 4 + urns = [] + for unit in wu: + assert isinstance(unit.metadata, MetadataChangeEvent) + assert isinstance(unit.metadata.proposedSnapshot, DatasetSnapshotClass) + urns.append(unit.metadata.proposedSnapshot.urn) + TestCase().assertCountEqual( + urns, + [ + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table2,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table3,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table4,PROD)", + ], + ) + assert source.report.warnings.total_elements == 2 + assert source.report.failures.total_elements == 0 + assert source.report.tables_scanned == 4 + + +def test_handle_unexpected_exceptions() -> None: + source = with_iceberg_source(processing_threads=3) + + def _raise_exception(): + raise Exception() + + mock_catalog = MockCatalog( + { + "namespaceA": { + "table1": lambda: Table( + identifier=("namespaceA", "table1"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table1", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table1", + io=PyArrowFileIO(), + catalog=None, + ), + "table2": lambda: Table( + identifier=("namespaceA", "table2"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table2", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table2", + io=PyArrowFileIO(), + catalog=None, + ), + "table3": lambda: Table( + identifier=("namespaceA", "table3"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table3", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table3", + io=PyArrowFileIO(), + catalog=None, + ), + "table4": lambda: Table( + identifier=("namespaceA", "table4"), + metadata=TableMetadataV2( + partition_specs=[PartitionSpec(spec_id=0)], + location="s3://abcdefg/namespaceA/table4", + last_column_id=0, + schemas=[Schema(schema_id=0)], + ), + metadata_location="s3://abcdefg/namespaceA/table4", + io=PyArrowFileIO(), + catalog=None, + ), + "table5": _raise_exception, + } + } + ) + with patch( + "datahub.ingestion.source.iceberg.iceberg.IcebergSourceConfig.get_catalog" + ) as get_catalog: + get_catalog.return_value = mock_catalog + wu: List[MetadataWorkUnit] = [*source.get_workunits_internal()] + assert len(wu) == 4 + urns = [] + for unit in wu: + assert isinstance(unit.metadata, MetadataChangeEvent) + assert isinstance(unit.metadata.proposedSnapshot, DatasetSnapshotClass) + urns.append(unit.metadata.proposedSnapshot.urn) + TestCase().assertCountEqual( + urns, + [ + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table2,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table3,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:iceberg,namespaceA.table4,PROD)", + ], + ) + assert source.report.warnings.total_elements == 0 + assert source.report.failures.total_elements == 1 + assert source.report.tables_scanned == 4 From 19702c822580301383a4d0a788c6b526191b391b Mon Sep 17 00:00:00 2001 From: Raudzis Sebastian <32541580+raudzis@users.noreply.github.com> Date: Tue, 19 Nov 2024 00:45:05 +0100 Subject: [PATCH 2/4] fix(ingest/lookml): replace class variable with instance variable for improved encapsulation (#11881) --- .../ingestion/source/looker/lookml_source.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py index e4d8dd19fb7917..d258570ec384f7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_source.py @@ -283,23 +283,21 @@ class LookMLSource(StatefulIngestionSourceBase): """ platform = "lookml" - source_config: LookMLSourceConfig - reporter: LookMLSourceReport - looker_client: Optional[LookerAPI] = None - - # This is populated during the git clone step. - base_projects_folder: Dict[str, pathlib.Path] = {} - remote_projects_git_info: Dict[str, GitInfo] = {} def __init__(self, config: LookMLSourceConfig, ctx: PipelineContext): super().__init__(config, ctx) - self.source_config = config + self.source_config: LookMLSourceConfig = config self.ctx = ctx self.reporter = LookMLSourceReport() # To keep track of projects (containers) which have already been ingested self.processed_projects: List[str] = [] + # This is populated during the git clone step. + self.base_projects_folder: Dict[str, pathlib.Path] = {} + self.remote_projects_git_info: Dict[str, GitInfo] = {} + + self.looker_client: Optional[LookerAPI] = None if self.source_config.api: self.looker_client = LookerAPI(self.source_config.api) self.reporter._looker_api = self.looker_client From bf16e58d43563636f3a439d0b9ce85ab0cef02ea Mon Sep 17 00:00:00 2001 From: david-leifker <114954101+david-leifker@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:50:04 -0600 Subject: [PATCH 3/4] docs(urn): urn encoding (#11884) --- docs/what/urn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/what/urn.md b/docs/what/urn.md index 122b93038d9dec..e35ca7fbaca4bc 100644 --- a/docs/what/urn.md +++ b/docs/what/urn.md @@ -41,4 +41,4 @@ There are a few restrictions when creating an urn: 2. Parentheses are reserved characters in URN fields: `( , )` 3. Colons are reserved characters in URN fields: `:` -Please do not use these characters when creating or generating urns. +Please do not use these characters when creating or generating urns. One approach is to use URL encoding for the characters. From 94f1f39667a6edf2570fba04a9e1effc60423f9a Mon Sep 17 00:00:00 2001 From: Andrew Sikowitz Date: Mon, 18 Nov 2024 17:25:43 -0800 Subject: [PATCH 4/4] fix(ingest/partitionExecutor): Fetch ready items for non-empty batch when _pending is empty (#11885) --- .../datahub/utilities/partition_executor.py | 22 ++++++------ .../unit/utilities/test_partition_executor.py | 36 ++++++++++++++++--- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/metadata-ingestion/src/datahub/utilities/partition_executor.py b/metadata-ingestion/src/datahub/utilities/partition_executor.py index 0f0e9784464f6d..4d873d8f74bd8e 100644 --- a/metadata-ingestion/src/datahub/utilities/partition_executor.py +++ b/metadata-ingestion/src/datahub/utilities/partition_executor.py @@ -270,7 +270,7 @@ def __init__( self.read_from_pending_interval = read_from_pending_interval assert self.max_workers > 1 - self.state_lock = threading.Lock() + self._state_lock = threading.Lock() self._executor = ThreadPoolExecutor( # We add one here to account for the clearinghouse worker thread. max_workers=max_workers + 1, @@ -323,7 +323,7 @@ def _handle_batch_completion( if item.done_callback: item.done_callback(future) - def _find_ready_items() -> List[_BatchPartitionWorkItem]: + def _find_ready_items(max_to_add: int) -> List[_BatchPartitionWorkItem]: with clearinghouse_state_lock: # First, update the keys in flight. for key in keys_no_longer_in_flight: @@ -336,10 +336,7 @@ def _find_ready_items() -> List[_BatchPartitionWorkItem]: ready: List[_BatchPartitionWorkItem] = [] for item in pending: - if ( - len(ready) < self.max_per_batch - and item.key not in keys_in_flight - ): + if len(ready) < max_to_add and item.key not in keys_in_flight: ready.append(item) else: pending_key_completion.append(item) @@ -347,7 +344,7 @@ def _find_ready_items() -> List[_BatchPartitionWorkItem]: return ready def _build_batch() -> List[_BatchPartitionWorkItem]: - next_batch = _find_ready_items() + next_batch = _find_ready_items(self.max_per_batch) while ( not self._queue_empty_for_shutdown @@ -382,11 +379,12 @@ def _build_batch() -> List[_BatchPartitionWorkItem]: pending_key_completion.append(next_item) else: next_batch.append(next_item) - - if not next_batch: - next_batch = _find_ready_items() except queue.Empty: - if not blocking: + if blocking: + next_batch.extend( + _find_ready_items(self.max_per_batch - len(next_batch)) + ) + else: break return next_batch @@ -458,7 +456,7 @@ def _ensure_clearinghouse_started(self) -> None: f"{self.__class__.__name__} is shutting down; cannot submit new work items." ) - with self.state_lock: + with self._state_lock: # Lazily start the clearinghouse worker. if not self._clearinghouse_started: self._clearinghouse_started = True diff --git a/metadata-ingestion/tests/unit/utilities/test_partition_executor.py b/metadata-ingestion/tests/unit/utilities/test_partition_executor.py index e3a68405e3c0ac..eba79eafce473b 100644 --- a/metadata-ingestion/tests/unit/utilities/test_partition_executor.py +++ b/metadata-ingestion/tests/unit/utilities/test_partition_executor.py @@ -133,9 +133,9 @@ def process_batch(batch): } -@pytest.mark.timeout(10) +@pytest.mark.timeout(5) def test_batch_partition_executor_max_batch_size(): - n = 20 # Exceed max_pending to test for deadlocks when max_pending exceeded + n = 5 batches_processed = [] def process_batch(batch): @@ -147,8 +147,8 @@ def process_batch(batch): max_pending=10, process_batch=process_batch, max_per_batch=2, - min_process_interval=timedelta(seconds=1), - read_from_pending_interval=timedelta(seconds=1), + min_process_interval=timedelta(seconds=0.1), + read_from_pending_interval=timedelta(seconds=0.1), ) as executor: # Submit more tasks than the max_per_batch to test batching limits. for i in range(n): @@ -161,6 +161,34 @@ def process_batch(batch): assert len(batch) <= 2, "Batch size exceeded max_per_batch limit" +@pytest.mark.timeout(10) +def test_batch_partition_executor_deadlock(): + n = 20 # Exceed max_pending to test for deadlocks when max_pending exceeded + batch_size = 2 + batches_processed = [] + + def process_batch(batch): + batches_processed.append(batch) + time.sleep(0.1) # Simulate batch processing time + + with BatchPartitionExecutor( + max_workers=5, + max_pending=2, + process_batch=process_batch, + max_per_batch=batch_size, + min_process_interval=timedelta(seconds=30), + read_from_pending_interval=timedelta(seconds=0.01), + ) as executor: + # Submit more tasks than the max_per_batch to test batching limits. + executor.submit("key3", "key3", "task0") + executor.submit("key3", "key3", "task1") + executor.submit("key1", "key1", "task1") # Populates second batch + for i in range(3, n): + executor.submit("key3", "key3", f"task{i}") + + assert sum(len(batch) for batch in batches_processed) == n + + def test_empty_batch_partition_executor(): # We want to test that even if no submit() calls are made, cleanup works fine. with BatchPartitionExecutor(