Merge branch 'datahub-project:master' into master
hsheth2 authored Jan 12, 2024
2 parents 37bf193 + 33e3294 commit 0c09de9
Showing 11 changed files with 234 additions and 82 deletions.
@@ -36,6 +36,7 @@ public SchemaField apply(
    }
    result.setIsPartOfKey(input.isIsPartOfKey());
    result.setIsPartitioningKey(input.isIsPartitioningKey());
    result.setJsonProps(input.getJsonProps());
    return result;
  }

5 changes: 5 additions & 0 deletions datahub-graphql-core/src/main/resources/entity.graphql
@@ -2892,6 +2892,11 @@ type SchemaField {
  Whether the field is part of a partitioning key schema
  """
  isPartitioningKey: Boolean

  """
  For schema fields that have other properties that are not modeled explicitly, represented as a JSON string.
  """
  jsonProps: String
}

"""
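The mapper change above copies jsonProps from the underlying model, and this GraphQL field exposes it as a raw JSON string. A minimal consumption sketch in Python (the field dict standing in for a query result is hypothetical):

    import json

    # "field" is one SchemaField object out of a hypothetical GraphQL response.
    extra_props = json.loads(field["jsonProps"]) if field.get("jsonProps") else {}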
@@ -1,16 +1,16 @@
from __future__ import print_function

import datetime
import itertools
import logging
import re
from contextlib import contextmanager
from dataclasses import dataclass, field as dataclasses_field
from enum import Enum
from functools import lru_cache
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
@@ -1126,6 +1126,14 @@ def report_stage_end(self, stage_name: str) -> None:
        if self.stage_latency[-1].name == stage_name:
            self.stage_latency[-1].end_time = datetime.datetime.now()

    @contextmanager
    def report_stage(self, stage_name: str) -> Iterator[None]:
        try:
            self.report_stage_start(stage_name)
            yield
        finally:
            self.report_stage_end(stage_name)

    def compute_stats(self) -> None:
        if self.total_dashboards:
            self.dashboard_process_percentage_completion = round(
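The new context manager guarantees that report_stage_end runs even if the wrapped code raises, replacing the manual start/end calls elsewhere in this commit. A minimal usage sketch (reporter and do_work are hypothetical stand-ins):

    with reporter.report_stage("dashboard_chart_metadata"):
        do_work()  # any exception still triggers report_stage_end via the finally block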
@@ -6,6 +6,7 @@
from typing import Dict, List, MutableMapping, Optional, Sequence, Set, Union, cast

import looker_sdk
import looker_sdk.rtl.requests_transport as looker_requests_transport
from looker_sdk.error import SDKError
from looker_sdk.rtl.transport import TransportOptions
from looker_sdk.sdk.api40.models import (
@@ -21,6 +22,7 @@
    WriteQuery,
)
from pydantic import BaseModel, Field
from requests.adapters import HTTPAdapter

from datahub.configuration import ConfigModel
from datahub.configuration.common import ConfigurationError
@@ -46,6 +48,7 @@ class LookerAPIConfig(ConfigModel):
        None,
        description="Populates the [TransportOptions](https://github.com/looker-open-source/sdk-codegen/blob/94d6047a0d52912ac082eb91616c1e7c379ab262/python/looker_sdk/rtl/transport.py#L70) struct for looker client",
    )
    max_retries: int = Field(3, description="Number of retries for Looker API calls")


class LookerAPIStats(BaseModel):
@@ -76,6 +79,20 @@ def __init__(self, config: LookerAPIConfig) -> None:
        os.environ["LOOKERSDK_BASE_URL"] = config.base_url

        self.client = looker_sdk.init40()

        # Somewhat hacky mechanism for enabling retries on the Looker SDK.
        # Unfortunately, it doesn't expose a cleaner way to do this.
        if isinstance(
            self.client.transport, looker_requests_transport.RequestsTransport
        ):
            adapter = HTTPAdapter(
                max_retries=self.config.max_retries,
            )
            self.client.transport.session.mount("http://", adapter)
            self.client.transport.session.mount("https://", adapter)
        elif self.config.max_retries > 0:
            logger.warning("Unable to configure retries on the Looker SDK transport.")

        self.transport_options = (
            config.transport_options.get_transport_options()
            if config.transport_options is not None
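One caveat on the retry hook above: when requests' HTTPAdapter receives an integer max_retries, urllib3 retries only connection failures, never requests whose data already reached the server. A sketch of the Retry-object alternative, should response-level retries ever be needed (not part of this change):

    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry

    # Retries connect errors and the listed HTTP statuses, with exponential backoff.
    adapter = HTTPAdapter(
        max_retries=Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
        )
    )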
@@ -1,4 +1,3 @@
import concurrent.futures
import datetime
import json
import logging
@@ -91,6 +90,7 @@
    OwnershipClass,
    OwnershipTypeClass,
)
from datahub.utilities.advanced_thread_executor import BackpressureAwareExecutor

logger = logging.getLogger(__name__)

@@ -700,28 +700,19 @@ def _make_explore_metadata_events(
        explores_to_fetch = list(self.list_all_explores())
        explores_to_fetch.sort()

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.source_config.max_threads
        ) as async_executor:
            self.reporter.total_explores = len(explores_to_fetch)

            explore_futures = {
                async_executor.submit(self.fetch_one_explore, model, explore): (
                    model,
                    explore,
                )
                for (model, explore) in explores_to_fetch
            }

            for future in concurrent.futures.wait(explore_futures).done:
                events, explore_id, start_time, end_time = future.result()
                del explore_futures[future]
                self.reporter.explores_scanned += 1
                yield from events
                self.reporter.report_upstream_latency(start_time, end_time)
                logger.debug(
                    f"Running time of fetch_one_explore for {explore_id}: {(end_time - start_time).total_seconds()}"
                )
        self.reporter.total_explores = len(explores_to_fetch)
        for future in BackpressureAwareExecutor.map(
            self.fetch_one_explore,
            ((model, explore) for (model, explore) in explores_to_fetch),
            max_workers=self.source_config.max_threads,
        ):
            events, explore_id, start_time, end_time = future.result()
            self.reporter.explores_scanned += 1
            yield from events
            self.reporter.report_upstream_latency(start_time, end_time)
            logger.debug(
                f"Running time of fetch_one_explore for {explore_id}: {(end_time - start_time).total_seconds()}"
            )

    def list_all_explores(self) -> Iterable[Tuple[str, str]]:
        # returns a list of (model, explore) tuples
@@ -1277,28 +1268,24 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
        ]

        looker_dashboards_for_usage: List[looker_usage.LookerDashboardForUsage] = []
        self.reporter.report_stage_start("dashboard_chart_metadata")

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.source_config.max_threads
        ) as async_executor:
            async_workunits = {}
            for dashboard_id in dashboard_ids:
                if dashboard_id is not None:
                    job = async_executor.submit(
                        self.process_dashboard, dashboard_id, fields
                    )
                    async_workunits[job] = dashboard_id

            for job in concurrent.futures.as_completed(async_workunits):
        with self.reporter.report_stage("dashboard_chart_metadata"):
            for job in BackpressureAwareExecutor.map(
                self.process_dashboard,
                (
                    (dashboard_id, fields)
                    for dashboard_id in dashboard_ids
                    if dashboard_id is not None
                ),
                max_workers=self.source_config.max_threads,
            ):
                (
                    work_units,
                    dashboard_usage,
                    dashboard_id,
                    start_time,
                    end_time,
                ) = job.result()
                del async_workunits[job]
                logger.debug(
                    f"Running time of process_dashboard for {dashboard_id} = {(end_time - start_time).total_seconds()}"
                )
@@ -1308,8 +1295,6 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
                if dashboard_usage is not None:
                    looker_dashboards_for_usage.append(dashboard_usage)

        self.reporter.report_stage_end("dashboard_chart_metadata")

        if (
            self.source_config.extract_owners
            and self.reporter.resolved_user_ids > 0
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/ingestion/source/mongodb.py
@@ -421,7 +421,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
            )
            collection_fields = sorted(
                collection_schema.values(),
                key=lambda x: x["count"],
                key=lambda x: (x["count"], x["delimited_name"]),
                reverse=True,
            )[0:max_schema_size]
            # Add this information to the custom properties so user can know they are looking at downsampled schema
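The added secondary sort key makes the truncated schema deterministic when field counts tie; a quick illustration with hypothetical entries:

    fields = [
        {"count": 5, "delimited_name": "a.a"},
        {"count": 5, "delimited_name": "a.b"},
    ]
    # Descending by count, then by name: equal counts now order the same way on every run.
    assert sorted(fields, key=lambda x: (x["count"], x["delimited_name"]), reverse=True)[0]["delimited_name"] == "a.b"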
@@ -1,8 +1,22 @@
from __future__ import annotations

import collections
import concurrent.futures
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import BoundedSemaphore
from typing import Any, Callable, Deque, Dict, Optional, Tuple, TypeVar
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

from datahub.ingestion.api.closeable import Closeable

@@ -130,3 +144,74 @@ def shutdown(self) -> None:

    def close(self) -> None:
        self.shutdown()


class BackpressureAwareExecutor:
    # This couldn't be a real executor because the semantics of submit wouldn't really make sense.
    # In this variant, if we blocked on submit, then we would also be blocking the thread that
    # we expect to be consuming the results. As such, I made it accept the full list of args
    # up front, and that way the consumer can read results at its own pace.

    @classmethod
    def map(
        cls,
        fn: Callable[..., _R],
        args_list: Iterable[Tuple[Any, ...]],
        max_workers: int,
        max_pending: Optional[int] = None,
    ) -> Iterator[Future[_R]]:
        """Similar to concurrent.futures.ThreadPoolExecutor#map, except that it won't run ahead of the consumer.

        The main benefit is that the ThreadPoolExecutor isn't stuck holding a ton of result
        objects in memory if the consumer is slow. Instead, the consumer can read the results
        at its own pace and the executor threads will idle if they need to.

        Args:
            fn: The function to apply to each input.
            args_list: The list of inputs. In contrast to the builtin map, this is a list
                of tuples, where each tuple is the arguments to fn.
            max_workers: The maximum number of threads to use.
            max_pending: The maximum number of pending results to keep in memory.
                If not set, it will be set to 2*max_workers.

        Returns:
            An iterable of futures.

            This differs from a traditional map because it returns futures
            instead of the actual results, so that the caller is required
            to handle exceptions.

            Additionally, it does not maintain the order of the arguments.
            If you want to know which result corresponds to which input,
            the mapped function should return some form of an identifier.
        """

        if max_pending is None:
            max_pending = 2 * max_workers
        assert max_pending >= max_workers

        pending_futures: Set[Future] = set()

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for args in args_list:
                # If the pending list is full, wait until one is done.
                if len(pending_futures) >= max_pending:
                    (done, _) = concurrent.futures.wait(
                        pending_futures, return_when=concurrent.futures.FIRST_COMPLETED
                    )
                    for future in done:
                        pending_futures.remove(future)

                        # We don't want to call result() here because we want the caller
                        # to handle exceptions/cancellation.
                        yield future

                # Now that there's space in the pending list, enqueue the next task.
                pending_futures.add(executor.submit(fn, *args))

            # Wait for all the remaining tasks to complete.
            for future in concurrent.futures.as_completed(pending_futures):
                pending_futures.remove(future)
                yield future

            assert not pending_futures
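A minimal standalone sketch of consuming the new executor (fetch_page is a hypothetical work function; real callers pass methods like fetch_one_explore or process_dashboard):

    from datahub.utilities.advanced_thread_executor import BackpressureAwareExecutor

    def fetch_page(page: int) -> str:
        # Hypothetical task; runs on one of the worker threads.
        return f"page-{page}"

    # Futures arrive roughly in completion order, with at most max_pending
    # (default 2 * max_workers) tasks outstanding at any time.
    for future in BackpressureAwareExecutor.map(
        fetch_page, ((i,) for i in range(100)), max_workers=4
    ):
        print(future.result())  # the caller handles exceptions raised by the task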
@@ -1,7 +1,10 @@
import time
from concurrent.futures import Future

from datahub.utilities.advanced_thread_executor import PartitionExecutor
from datahub.utilities.advanced_thread_executor import (
    BackpressureAwareExecutor,
    PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer


@@ -68,3 +71,58 @@ def task(id: str) -> str:
    # Wait for everything to finish.
    executor.flush()
    assert len(done_tasks) == 16


def test_backpressure_aware_executor_simple():
    def task(i):
        return i

    assert set(
        res.result()
        for res in BackpressureAwareExecutor.map(
            task, ((i,) for i in range(10)), max_workers=2
        )
    ) == set(range(10))


def test_backpressure_aware_executor_advanced():
    task_duration = 0.5
    started = set()
    executed = set()

    def task(x, y):
        assert x + 1 == y
        started.add(x)
        time.sleep(task_duration)
        executed.add(x)
        return x

    args_list = [(i, i + 1) for i in range(10)]

    with PerfTimer() as timer:
        results = BackpressureAwareExecutor.map(
            task, args_list, max_workers=2, max_pending=4
        )
        assert timer.elapsed_seconds() < task_duration

        # No tasks should have completed yet.
        assert len(executed) == 0

        # Consume the first result.
        first_result = next(results)
        assert 0 <= first_result.result() < 4
        assert timer.elapsed_seconds() > task_duration

        # By now, the first four tasks should have started.
        time.sleep(task_duration)
        assert {0, 1, 2, 3}.issubset(started)
        assert 2 <= len(executed) <= 4

        # Finally, consume the rest of the results.
        assert set(r.result() for r in results) == {
            i for i in range(10) if i != first_result.result()
        }

        # Validate that the entire process took about 5-10x the task duration.
        # That's because we have 2 workers and 10 tasks.
        assert 5 * task_duration < timer.elapsed_seconds() < 10 * task_duration