fix(ingestion/kafka): OAuth callback execution (datahub-project#11900)
sid-acryl authored Nov 22, 2024
1 parent dac80fb commit 86b8175
Showing 15 changed files with 183 additions and 26 deletions.
22 changes: 22 additions & 0 deletions metadata-ingestion/docs/sources/kafka/kafka.md
@@ -102,7 +102,29 @@ source:
    connection:
      bootstrap: "broker:9092"
      schema_registry_url: http://localhost:8081
```

### OAuth Callback
The OAuth callback function can be set up using `config.connection.consumer_config.oauth_cb`.

You need to specify a Python function reference in the format `<python-module>:<function-name>`.

For example, in the configuration `oauth:create_token`, `create_token` is a function defined in `oauth.py`, and `oauth.py` must be accessible in the PYTHONPATH.

```YAML
source:
  type: "kafka"
  config:
    # Set the custom schema registry implementation class
    schema_registry_class: "datahub.ingestion.source.confluent_schema_registry.ConfluentSchemaRegistry"
    # Coordinates
    connection:
      bootstrap: "broker:9092"
      schema_registry_url: http://localhost:8081
      consumer_config:
        security.protocol: "SASL_PLAINTEXT"
        sasl.mechanism: "OAUTHBEARER"
        oauth_cb: "oauth:create_token"
# sink configs
```
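
For reference, a minimal sketch of what such a callback module might contain, assuming confluent-kafka's OAUTHBEARER callback contract of returning a `(token, expiry)` pair; the module name, token value, and lifetime below are placeholders:

```python
# oauth.py -- illustrative sketch only; replace the placeholder token logic
# with a real call to your identity provider.
import time
from typing import Any, Tuple


def create_token(*args: Any, **kwargs: Any) -> Tuple[str, float]:
    # Obtain a bearer token from your identity provider here.
    token = "<bearer-token>"
    # confluent-kafka expects the expiry as an absolute time in seconds since the epoch.
    expiry = time.time() + 3600
    return token, expiry
```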

4 changes: 2 additions & 2 deletions metadata-ingestion/setup.py
@@ -741,8 +741,8 @@
"hive = datahub.ingestion.source.sql.hive:HiveSource",
"hive-metastore = datahub.ingestion.source.sql.hive_metastore:HiveMetastoreSource",
"json-schema = datahub.ingestion.source.schema.json_schema:JsonSchemaSource",
"kafka = datahub.ingestion.source.kafka:KafkaSource",
"kafka-connect = datahub.ingestion.source.kafka_connect:KafkaConnectSource",
"kafka = datahub.ingestion.source.kafka.kafka:KafkaSource",
"kafka-connect = datahub.ingestion.source.kafka.kafka_connect:KafkaConnectSource",
"ldap = datahub.ingestion.source.ldap:LDAPSource",
"looker = datahub.ingestion.source.looker.looker_source:LookerDashboardSource",
"lookml = datahub.ingestion.source.looker.lookml_source:LookMLSource",
13 changes: 12 additions & 1 deletion metadata-ingestion/src/datahub/configuration/kafka.py
@@ -1,6 +1,7 @@
from pydantic import Field, validator

from datahub.configuration.common import ConfigModel
from datahub.configuration.common import ConfigModel, ConfigurationError
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig
from datahub.configuration.validate_host_port import validate_host_port


@@ -36,6 +37,16 @@ class KafkaConsumerConnectionConfig(_KafkaConnectionConfig):
description="Extra consumer config serialized as JSON. These options will be passed into Kafka's DeserializingConsumer. See https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#deserializingconsumer and https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md .",
)

    @validator("consumer_config")
    @classmethod
    def resolve_callback(cls, value: dict) -> dict:
        if CallableConsumerConfig.is_callable_config(value):
            try:
                value = CallableConsumerConfig(value).callable_config()
            except Exception as e:
                raise ConfigurationError(e)
        return value


class KafkaProducerConnectionConfig(_KafkaConnectionConfig):
"""Configuration class for holding connectivity information for Kafka producers"""
metadata-ingestion/src/datahub/configuration/kafka_consumer_config.py
@@ -0,0 +1,35 @@
import logging
from typing import Any, Dict, Optional

from datahub.ingestion.api.registry import import_path

logger = logging.getLogger(__name__)


class CallableConsumerConfig:
    CALLBACK_ATTRIBUTE: str = "oauth_cb"

    def __init__(self, config: Dict[str, Any]):
        self._config = config

        self._resolve_oauth_callback()

    def callable_config(self) -> Dict[str, Any]:
        return self._config

    @staticmethod
    def is_callable_config(config: Dict[str, Any]) -> bool:
        return CallableConsumerConfig.CALLBACK_ATTRIBUTE in config

    def get_call_back_attribute(self) -> Optional[str]:
        return self._config.get(CallableConsumerConfig.CALLBACK_ATTRIBUTE)

    def _resolve_oauth_callback(self) -> None:
        if not self.get_call_back_attribute():
            return

        call_back = self.get_call_back_attribute()

        assert call_back  # to silence lint
        # Set the callback
        self._config[CallableConsumerConfig.CALLBACK_ATTRIBUTE] = import_path(call_back)
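
A minimal usage sketch of the helper above (assuming an `oauth.py` module exposing `create_token` is importable on the PYTHONPATH); after resolution, the `oauth_cb` entry holds the imported callable rather than the string reference:

```python
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig

consumer_config = {
    "security.protocol": "SASL_PLAINTEXT",
    "sasl.mechanism": "OAUTHBEARER",
    "oauth_cb": "oauth:create_token",  # "<module>:<function>" reference
}

if CallableConsumerConfig.is_callable_config(consumer_config):
    # callable_config() returns the same dict with "oauth_cb" replaced by the imported callable.
    resolved = CallableConsumerConfig(consumer_config).callable_config()
```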
metadata-ingestion/src/datahub/ingestion/source/confluent_schema_registry.py
@@ -16,8 +16,10 @@
from datahub.ingestion.extractor import protobuf_util, schema_util
from datahub.ingestion.extractor.json_schema_util import JsonSchemaTranslator
from datahub.ingestion.extractor.protobuf_util import ProtobufSchema
from datahub.ingestion.source.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
KafkaSchemaRegistryBase,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
KafkaSchema,
SchemaField,
Empty file.
metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py
@@ -18,6 +18,7 @@

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.kafka import KafkaConsumerConnectionConfig
from datahub.configuration.kafka_consumer_config import CallableConsumerConfig
from datahub.configuration.source_common import (
DatasetSourceConfigMixin,
LowerCaseDatasetUrnConfigMixin,
@@ -49,7 +50,9 @@
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
KafkaSchemaRegistryBase,
)
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
StaleEntityRemovalSourceReport,
@@ -143,14 +146,21 @@ class KafkaSourceConfig(
def get_kafka_consumer(
connection: KafkaConsumerConnectionConfig,
) -> confluent_kafka.Consumer:
return confluent_kafka.Consumer(
consumer = confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": connection.bootstrap,
**connection.consumer_config,
}
)

if CallableConsumerConfig.is_callable_config(connection.consumer_config):
# As per documentation, we need to explicitly call the poll method to make sure the OAuth callback gets executed
# https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration
consumer.poll(timeout=30)

return consumer


@dataclass
class KafkaSourceReport(StaleEntityRemovalSourceReport):
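
The `consumer.poll(timeout=30)` call above is what forces the OAuth callback to run at ingestion setup time: a confluent-kafka client only invokes `oauth_cb` once its event loop is driven. A standalone sketch of the same pattern (the broker address, group id, and callback module are illustrative):

```python
import confluent_kafka

from oauth import create_token  # hypothetical callback module on the PYTHONPATH

consumer = confluent_kafka.Consumer(
    {
        "group.id": "demo",
        "bootstrap.servers": "broker:9092",
        "security.protocol": "SASL_PLAINTEXT",
        "sasl.mechanism": "OAUTHBEARER",
        "oauth_cb": create_token,  # the resolved callable, not the "module:function" string
    }
)

# poll() drives the client event loop, which triggers the OAuth callback.
consumer.poll(timeout=30)
```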
20 changes: 20 additions & 0 deletions metadata-ingestion/tests/integration/kafka/kafka_to_file_oauth.yml
@@ -0,0 +1,20 @@
run_id: kafka-test

source:
  type: kafka
  config:
    connection:
      bootstrap: "localhost:29092"
      schema_registry_url: "http://localhost:28081"
      consumer_config:
        security.protocol: "SASL_PLAINTEXT"
        sasl.mechanism: "OAUTHBEARER"
        oauth_cb: "oauth:create_token"
    domain:
      "urn:li:domain:sales":
        allow:
          - "key_value_topic"
sink:
  type: file
  config:
    filename: "./kafka_mces.json"
14 changes: 14 additions & 0 deletions metadata-ingestion/tests/integration/kafka/oauth.py
@@ -0,0 +1,14 @@
import logging
from typing import Any, Tuple

logger = logging.getLogger(__name__)

MESSAGE: str = "OAuth token `create_token` callback"


def create_token(*args: Any, **kwargs: Any) -> Tuple[str, int]:
    logger.warning(MESSAGE)
    return (
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJjbGllbnRfaWQiOiJrYWZrYV9jbGllbnQiLCJleHAiOjE2OTg3NjYwMDB9.dummy_sig_abcdef123456",
        3600,
    )
39 changes: 38 additions & 1 deletion metadata-ingestion/tests/integration/kafka/test_kafka.py
@@ -1,10 +1,14 @@
import logging
import subprocess

import pytest
import yaml
from freezegun import freeze_time

from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.source.kafka import KafkaSource
from datahub.ingestion.run.pipeline import Pipeline
from datahub.ingestion.source.kafka.kafka import KafkaSource
from tests.integration.kafka import oauth # type: ignore
from tests.test_helpers import mce_helpers, test_connection_helpers
from tests.test_helpers.click_helpers import run_datahub_cmd
from tests.test_helpers.docker_helpers import wait_for_port
@@ -99,3 +103,36 @@ def test_kafka_test_connection(mock_kafka_service, config_dict, is_success):
SourceCapability.SCHEMA_METADATA: "Failed to establish a new connection"
},
)


@freeze_time(FROZEN_TIME)
@pytest.mark.integration
def test_kafka_oauth_callback(
    mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time
):
    # Run the metadata ingestion pipeline.
    config_file = (test_resources_dir / "kafka_to_file_oauth.yml").resolve()

    log_file = tmp_path / "kafka_oauth_message.log"

    file_handler = logging.FileHandler(
        str(log_file)
    )  # Add a file handler to later validate a test-case
    logging.getLogger().addHandler(file_handler)

    recipe: dict = {}
    with open(config_file) as fp:
        recipe = yaml.safe_load(fp)

    pipeline = Pipeline.create(recipe)

    pipeline.run()

    is_found: bool = False
    with open(log_file, "r") as file:
        for line_number, line in enumerate(file, 1):
            if oauth.MESSAGE in line:
                is_found = True
                break

    assert is_found
8 changes: 6 additions & 2 deletions metadata-ingestion/tests/unit/api/test_pipeline.py
@@ -33,7 +33,9 @@

class TestPipeline:
@patch("confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.KafkaSource.get_workunits", autospec=True)
@patch(
"datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
)
@patch("datahub.ingestion.sink.console.ConsoleSink.close", autospec=True)
@freeze_time(FROZEN_TIME)
def test_configure(self, mock_sink, mock_source, mock_consumer):
@@ -198,7 +200,9 @@ def test_configure_with_rest_sink_with_additional_props_initializes_graph(
assert pipeline.ctx.graph.config.token == pipeline.config.sink.config["token"]

@freeze_time(FROZEN_TIME)
@patch("datahub.ingestion.source.kafka.KafkaSource.get_workunits", autospec=True)
@patch(
"datahub.ingestion.source.kafka.kafka.KafkaSource.get_workunits", autospec=True
)
def test_configure_with_file_sink_does_not_init_graph(self, mock_source, tmp_path):
pipeline = Pipeline.create(
{
@@ -8,7 +8,7 @@
)

from datahub.ingestion.source.confluent_schema_registry import ConfluentSchemaRegistry
from datahub.ingestion.source.kafka import KafkaSourceConfig, KafkaSourceReport
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig, KafkaSourceReport


class ConfluentSchemaRegistryTest(unittest.TestCase):
32 changes: 17 additions & 15 deletions metadata-ingestion/tests/unit/test_kafka_source.py
@@ -23,7 +23,7 @@
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.kafka import KafkaSource, KafkaSourceConfig
from datahub.ingestion.source.kafka.kafka import KafkaSource, KafkaSourceConfig
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.schema_classes import (
BrowsePathsClass,
@@ -38,11 +38,13 @@

@pytest.fixture
def mock_admin_client():
with patch("datahub.ingestion.source.kafka.AdminClient", autospec=True) as mock:
with patch(
"datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True
) as mock:
yield mock


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_configuration(mock_kafka):
ctx = PipelineContext(run_id="test")
kafka_source = KafkaSource(
@@ -53,7 +55,7 @@ def test_kafka_source_configuration(mock_kafka):
assert mock_kafka.call_count == 1


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
mock_cluster_metadata = MagicMock()
@@ -74,7 +76,7 @@ def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client):
assert len(workunits) == 4


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_topic_pattern(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
mock_cluster_metadata = MagicMock()
@@ -108,7 +110,7 @@ def test_kafka_source_workunits_topic_pattern(mock_kafka, mock_admin_client):
assert len(workunits) == 4


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_client):
PLATFORM_INSTANCE = "kafka_cluster"
PLATFORM = "kafka"
@@ -160,7 +162,7 @@ def test_kafka_source_workunits_with_platform_instance(mock_kafka, mock_admin_cl
assert f"/prod/{PLATFORM}/{PLATFORM_INSTANCE}" in browse_path_aspects[0].paths


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_no_platform_instance(mock_kafka, mock_admin_client):
PLATFORM = "kafka"
TOPIC_NAME = "test"
@@ -204,7 +206,7 @@ def test_kafka_source_workunits_no_platform_instance(mock_kafka, mock_admin_clie
assert f"/prod/{PLATFORM}" in browse_path_aspects[0].paths


@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_close(mock_kafka, mock_admin_client):
mock_kafka_instance = mock_kafka.return_value
ctx = PipelineContext(run_id="test")
@@ -223,7 +225,7 @@ def test_close(mock_kafka, mock_admin_client):
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_workunits_schema_registry_subject_name_strategies(
mock_kafka_consumer, mock_schema_registry_client, mock_admin_client
):
@@ -415,7 +417,7 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]:
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_ignore_warnings_on_schema_type(
mock_kafka_consumer,
mock_schema_registry_client,
@@ -483,8 +485,8 @@ def mock_get_latest_version(subject_name: str) -> Optional[RegisteredSchema]:
assert kafka_source.report.warnings


@patch("datahub.ingestion.source.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_succeeds_with_admin_client_init_error(
mock_kafka, mock_kafka_admin_client
):
@@ -513,8 +515,8 @@ def test_kafka_source_succeeds_with_admin_client_init_error(
assert len(workunits) == 2


@patch("datahub.ingestion.source.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.AdminClient", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_succeeds_with_describe_configs_error(
mock_kafka, mock_kafka_admin_client
):
@@ -550,7 +552,7 @@ def test_kafka_source_succeeds_with_describe_configs_error(
"datahub.ingestion.source.confluent_schema_registry.SchemaRegistryClient",
autospec=True,
)
@patch("datahub.ingestion.source.kafka.confluent_kafka.Consumer", autospec=True)
@patch("datahub.ingestion.source.kafka.kafka.confluent_kafka.Consumer", autospec=True)
def test_kafka_source_topic_meta_mappings(
mock_kafka_consumer, mock_schema_registry_client, mock_admin_client
):
