Skip to content

Commit

Permalink
[airbyte-cdk] emit source recordCount as float instead of integer (#36560)
Browse files Browse the repository at this point in the history

Co-authored-by: Ella Rohm-Ensing <[email protected]>
  • Loading branch information
brianjlai and erohmensing authored Mar 27, 2024
1 parent aba3054 commit 624415d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
11 changes: 6 additions & 5 deletions airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,27 @@ def read(
if self.source.check_config_against_spec:
self.validate_connection(source_spec, config)

stream_message_counter: DefaultDict[HashableStreamDescriptor, int] = defaultdict(int)
# The Airbyte protocol dictates that counts be expressed as float/double to better protect against integer overflows
stream_message_counter: DefaultDict[HashableStreamDescriptor, float] = defaultdict(float)
for message in self.source.read(self.logger, config, catalog, state):
yield self.handle_record_counts(message, stream_message_counter)
for message in self._emit_queued_messages(self.source):
yield self.handle_record_counts(message, stream_message_counter)

@staticmethod
def handle_record_counts(message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float]) -> AirbyteMessage:
    """
    Track per-stream record counts and stamp them onto outgoing state messages.

    RECORD messages increment the counter of the stream they belong to. STATE messages receive the
    count accumulated for their stream as ``sourceStats.recordCount``, after which that stream's
    counter is reset. Every other message type leaves the counter untouched.

    Counts are kept as floats rather than ints: the Airbyte protocol dictates that counts be
    expressed as float/double to better protect against integer overflows.

    :param message: the message that is about to be emitted
    :param stream_message_count: running record count per stream descriptor; mutated in place
    :return: the same message object, with ``sourceStats`` populated when it is a STATE message
    """
    if message.type == Type.RECORD:
        stream_message_count[message_utils.get_stream_descriptor(message)] += 1.0

    elif message.type == Type.STATE:
        stream_descriptor = message_utils.get_stream_descriptor(message)

        # Set record count from the counter onto the state message.
        # float() guards against a counter that was seeded with int values by the caller,
        # so the emitted recordCount is a float regardless of how the mapping was built.
        message.state.sourceStats = message.state.sourceStats or AirbyteStateStats()
        message.state.sourceStats.recordCount = float(stream_message_count.get(stream_descriptor, 0.0))

        # Reset the counter so the next state message only reports records emitted after this one
        stream_message_count[stream_descriptor] = 0.0
    return message

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_full_refresh_sync(self, http_mocker):
validate_message_order([Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "users"
assert actual_messages.state_messages[0].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0

@HttpMocker()
def test_full_refresh_with_slices(self, http_mocker):
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_full_refresh_with_slices(self, http_mocker):
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "dividers"
assert actual_messages.state_messages[0].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 4
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 4.0


@freezegun.freeze_time(_NOW)
Expand Down Expand Up @@ -266,10 +266,10 @@ def test_incremental_sync(self, http_mocker):
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "planets"
assert actual_messages.state_messages[0].state.stream.stream_state == {"created_at": last_record_date_0}
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3.0
assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets"
assert actual_messages.state_messages[1].state.stream.stream_state == {"created_at": last_record_date_1}
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0

@HttpMocker()
def test_incremental_running_as_full_refresh(self, http_mocker):
Expand Down Expand Up @@ -299,7 +299,7 @@ def test_incremental_running_as_full_refresh(self, http_mocker):
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "planets"
assert actual_messages.state_messages[0].state.stream.stream_state == {"created_at": last_record_date_1}
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 5
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 5.0

@HttpMocker()
def test_legacy_incremental_sync(self, http_mocker):
Expand Down Expand Up @@ -329,10 +329,10 @@ def test_legacy_incremental_sync(self, http_mocker):
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "legacies"
assert actual_messages.state_messages[0].state.stream.stream_state == {"created_at": last_record_date_0}
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 3.0
assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "legacies"
assert actual_messages.state_messages[1].state.stream.stream_state == {"created_at": last_record_date_1}
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0


@freezegun.freeze_time(_NOW)
Expand Down Expand Up @@ -402,16 +402,16 @@ def test_incremental_and_full_refresh_streams(self, http_mocker):
], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "users"
assert actual_messages.state_messages[0].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0
assert actual_messages.state_messages[1].state.stream.stream_descriptor.name == "planets"
assert actual_messages.state_messages[1].state.stream.stream_state == {"created_at": last_record_date_0}
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 3
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 3.0
assert actual_messages.state_messages[2].state.stream.stream_descriptor.name == "planets"
assert actual_messages.state_messages[2].state.stream.stream_state == {"created_at": last_record_date_1}
assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2
assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2.0
assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "dividers"
assert actual_messages.state_messages[3].state.stream.stream_state == {"__ab_full_refresh_state_message": True}
assert actual_messages.state_messages[3].state.sourceStats.recordCount == 4
assert actual_messages.state_messages[3].state.sourceStats.recordCount == 4.0


def emits_successful_sync_status_messages(status_messages: List[AirbyteStreamStatus]) -> bool:
Expand Down
51 changes: 29 additions & 22 deletions airbyte-cdk/python/unit_tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,94 +328,101 @@ def test_filter_internal_requests(deployment_mode, url, expected_error):
[
pytest.param(
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 100},
{HashableStreamDescriptor(name="customers"): 100.0},
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 101},
{HashableStreamDescriptor(name="customers"): 101.0},
id="test_handle_record_message",
),
pytest.param(
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="customers"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")))),
{HashableStreamDescriptor(name="customers"): 100},
{HashableStreamDescriptor(name="customers"): 100.0},
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="customers"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")),
sourceStats=AirbyteStateStats(recordCount=100.0))),
{HashableStreamDescriptor(name="customers"): 0},
{HashableStreamDescriptor(name="customers"): 0.0},
id="test_handle_state_message",
),
pytest.param(
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)),
defaultdict(int),
defaultdict(float),
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="customers", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 1},
{HashableStreamDescriptor(name="customers"): 1.0},
id="test_handle_first_record_message",
),
pytest.param(
AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.STREAM_STATUS,
stream_status=AirbyteStreamStatusTraceMessage(
stream_descriptor=StreamDescriptor(name="customers"),
status=AirbyteStreamStatus.COMPLETE), emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 5},
{HashableStreamDescriptor(name="customers"): 5.0},
AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.STREAM_STATUS,
stream_status=AirbyteStreamStatusTraceMessage(
stream_descriptor=StreamDescriptor(name="customers"),
status=AirbyteStreamStatus.COMPLETE), emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 5},
{HashableStreamDescriptor(name="customers"): 5.0},
id="test_handle_other_message_type",
),
pytest.param(
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 100, HashableStreamDescriptor(name="others"): 27},
{HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 27.0},
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="others", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers"): 100, HashableStreamDescriptor(name="others"): 28},
{HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 28.0},
id="test_handle_record_message_for_other_stream",
),
pytest.param(
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="others"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")))),
{HashableStreamDescriptor(name="customers"): 100, HashableStreamDescriptor(name="others"): 27},
{HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 27.0},
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="others"), stream_state=AirbyteStateBlob(updated_at="2024-02-02")),
sourceStats=AirbyteStateStats(recordCount=27.0))),
{HashableStreamDescriptor(name="customers"): 100, HashableStreamDescriptor(name="others"): 0},
{HashableStreamDescriptor(name="customers"): 100.0, HashableStreamDescriptor(name="others"): 0.0},
id="test_handle_state_message_for_other_stream",
),
pytest.param(
AirbyteMessage(type=Type.RECORD,
record=AirbyteRecordMessage(stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers", namespace="public"): 100},
{HashableStreamDescriptor(name="customers", namespace="public"): 100.0},
AirbyteMessage(type=Type.RECORD,
record=AirbyteRecordMessage(stream="customers", namespace="public", data={"id": "12345"}, emitted_at=1)),
{HashableStreamDescriptor(name="customers", namespace="public"): 101},
{HashableStreamDescriptor(name="customers", namespace="public"): 101.0},
id="test_handle_record_message_with_descriptor",
),
pytest.param(
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="customers", namespace="public"),
stream_state=AirbyteStateBlob(updated_at="2024-02-02")))),
{HashableStreamDescriptor(name="customers", namespace="public"): 100},
{HashableStreamDescriptor(name="customers", namespace="public"): 100.0},
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="customers", namespace="public"),
stream_state=AirbyteStateBlob(updated_at="2024-02-02")), sourceStats=AirbyteStateStats(recordCount=100.0))),
{HashableStreamDescriptor(name="customers", namespace="public"): 0},
{HashableStreamDescriptor(name="customers", namespace="public"): 0.0},
id="test_handle_state_message_with_descriptor",
),
pytest.param(
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="others", namespace="public"),
stream_state=AirbyteStateBlob(updated_at="2024-02-02")))),
{HashableStreamDescriptor(name="customers", namespace="public"): 100},
{HashableStreamDescriptor(name="customers", namespace="public"): 100.0},
AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="others", namespace="public"),
stream_state=AirbyteStateBlob(updated_at="2024-02-02")), sourceStats=AirbyteStateStats(recordCount=0.0))),
{HashableStreamDescriptor(name="customers", namespace="public"): 100,
HashableStreamDescriptor(name="others", namespace="public"): 0},
{HashableStreamDescriptor(name="customers", namespace="public"): 100.0,
HashableStreamDescriptor(name="others", namespace="public"): 0.0},
id="test_handle_state_message_no_records",
),
]
)
def test_handle_record_counts(incoming_message, stream_message_count, expected_message, expected_records_by_stream):
    """
    Verify that handle_record_counts returns the expected message and leaves the per-stream
    counter in the expected state, with every count expressed as a float.
    """
    entrypoint = AirbyteEntrypoint(source=MockSource())
    actual_message = entrypoint.handle_record_counts(message=incoming_message, stream_message_count=stream_message_count)
    assert actual_message == expected_message

    # Compare the whole mapping so that a stream key missing from (or unexpectedly added to)
    # the counter is caught; iterating over the actual counter alone would not notice a missing key.
    assert stream_message_count == expected_records_by_stream

    for stream_descriptor, message_count in stream_message_count.items():
        # Python assertions against different number types won't fail if the value is equivalent
        # (100 == 100.0), so the type has to be checked explicitly.
        assert isinstance(message_count, float), f"count for {stream_descriptor} should be expressed as a float"

    if actual_message.type == Type.STATE:
        assert isinstance(actual_message.state.sourceStats.recordCount, float), "recordCount value should be expressed as a float"

0 comments on commit 624415d

Please sign in to comment.