memory-neutral concat (#7569)
GitOrigin-RevId: 5e4f562a7c7e4b944604589d6d96721378941174
KamilPiechowiak authored and Manul from Pathway committed Nov 19, 2024
1 parent 6f901c3 commit f3e91d4
Showing 11 changed files with 112 additions and 78 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]

### Changed
- `pw.Table.concat`, `pw.Table.with_id` and `pw.Table.with_id_from` no longer check whether ids are unique. This improves memory usage.

### Fixed
- `query_as_of_now` of `pw.stdlib.indexing.DataIndex` and `pw.stdlib.indexing.HybridIndex` now works in constant memory for an infinite query stream (no query-related data is kept after a query is answered).

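A minimal sketch (not part of the commit) of what the changed behavior means in practice, using only APIs that appear elsewhere in this diff; table contents and column names are illustrative:

```python
import pathway as pw

# Two tables reindexed by the same key column, so their id sets overlap
# (both contain a row with a == 1).
t1 = pw.debug.table_from_markdown(
    """
    a | b
    1 | 10
    2 | 20
    """
).with_id_from(pw.this.a)

t2 = pw.debug.table_from_markdown(
    """
    a | b
    1 | 30
    3 | 40
    """
).with_id_from(pw.this.a)

# concat still requires the promise, but the disjointness itself is no
# longer verified while building the graph.
pw.universes.promise_are_pairwise_disjoint(t1, t2)
result = pw.Table.concat(t1, t2)

# The duplicated id is only reported when the result is materialized;
# with the default terminate_on_error=True this raises
# KeyError: "duplicated entries for key ...".
pw.debug.compute_and_print(result)
```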
10 changes: 8 additions & 2 deletions python/pathway/debug/__init__.py
@@ -63,7 +63,9 @@ def table_to_dicts(
**kwargs,
) -> tuple[list[api.Pointer], dict[str, dict[api.Pointer, api.Value]]]:
captured = _compute_tables(table, **kwargs)[0]
output_data = api.squash_updates(captured)
output_data = api.squash_updates(
captured, terminate_on_error=kwargs.get("terminate_on_error", True)
)
keys = list(output_data.keys())
columns = {
name: {key: output_data[key][index] for key in keys}
@@ -115,6 +117,7 @@ def _compute_and_print_internal(
include_id=include_id,
short_pointers=short_pointers,
n_rows=n_rows,
terminate_on_error=kwargs.get("terminate_on_error", True),
)


@@ -126,10 +129,13 @@ def _compute_and_print_single(
include_id: bool,
short_pointers: bool,
n_rows: int | None,
terminate_on_error: bool,
) -> None:
columns = list(table._columns.keys())
if squash_updates:
output_data = list(api.squash_updates(captured).items())
output_data = list(
api.squash_updates(captured, terminate_on_error=terminate_on_error).items()
)
else:
columns.extend([api.TIME_PSEUDOCOLUMN, api.DIFF_PSEUDOCOLUMN])
output_data = []
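For orientation, a hedged usage sketch (not from the commit): the kwargs plumbing above suggests that the debug helpers forward `terminate_on_error` down to `api.squash_updates`, so with `terminate_on_error=False` a duplicated id is reported as a `UserWarning` and the run continues instead of raising.

```python
import warnings

import pathway as pw

t1 = pw.debug.table_from_markdown("a | b\n1 | 10\n2 | 20").with_id_from(pw.this.a)
t2 = pw.debug.table_from_markdown("a | b\n1 | 30\n3 | 40").with_id_from(pw.this.a)
pw.universes.promise_are_pairwise_disjoint(t1, t2)
result = pw.Table.concat(t1, t2)

# Assuming compute_and_print accepts terminate_on_error (as the kwargs
# handling above suggests): the duplicated key becomes a warning rather
# than a KeyError, and the conflicting row is shown with error values.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pw.debug.compute_and_print(result, terminate_on_error=False)

assert any("duplicated entries for key" in str(w.message) for w in caught)
```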
28 changes: 22 additions & 6 deletions python/pathway/internals/api.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import datetime
import warnings
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeAlias, TypeVar, Union

import numpy as np
@@ -32,6 +33,7 @@
json.Json,
dict[str, _Value],
tuple[_Value, ...],
Error,
]
CapturedTable = dict[Pointer, tuple[Value, ...]]
CapturedStream = list[DataRow]
@@ -191,20 +193,34 @@ def static_table_from_pandas(
return scope.static_table(input_data, connector_properties)


def squash_updates(updates: CapturedStream) -> CapturedTable:
def squash_updates(
updates: CapturedStream, *, terminate_on_error: bool = True
) -> CapturedTable:
state: CapturedTable = {}
updates.sort(key=lambda row: (row.time, row.diff))

def handle_error(row: DataRow, msg: str):
if terminate_on_error:
raise KeyError(msg)
else:
warnings.warn(msg)
t: tuple[Value, ...] = (ERROR,) * len(row.values)
state[row.key] = t

for row in updates:
if row.diff == 1:
assert row.key not in state, f"duplicated entries for key {row.key}"
if row.key in state:
handle_error(row, f"duplicated entries for key {row.key}")
continue
state[row.key] = tuple(row.values)
elif row.diff == -1:
assert state[row.key] == tuple(
row.values
), f"deleting non-existing entry {row.values}"
if state[row.key] != tuple(row.values):
handle_error(row, f"deleting non-existing entry {row.values}")
continue
del state[row.key]
else:
raise AssertionError(f"Invalid diff value: {row.diff}")
handle_error(row, f"invalid diff value: {row.diff}")
continue

return state

4 changes: 4 additions & 0 deletions python/pathway/internals/table.py
@@ -1644,6 +1644,8 @@ def with_id(self, new_index: expr.ColumnReference) -> Table:
To generate ids based on arbitrary valued columns, use `with_id_from`.
Values assigned must be row-wise unique.
The uniqueness is not checked by pathway. Failing to provide unique ids can
cause unexpected errors downstream.
Args:
new_id: column to be used as the new index.
@@ -1686,6 +1688,8 @@ def with_id_from(
) -> Table:
"""Compute new ids based on values in columns.
Ids computed from `columns` must be row-wise unique.
The uniqueness is not checked by pathway. Failing to provide unique ids can
cause unexpected errors downstream.
Args:
columns: columns to be used as primary keys.
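A short illustrative sketch of the documented contract (column names are hypothetical, not from the commit): `with_id_from` reindexes by the given columns and now relies entirely on the caller to guarantee uniqueness.

```python
import pathway as pw

orders = pw.debug.table_from_markdown(
    """
    order_id | amount
    1        | 100
    2        | 250
    3        | 75
    """
)

# order_id is unique here, so the reindex is safe. Pathway no longer checks
# this; duplicated order_id values would only surface later, e.g. as
# "duplicated entries for key ..." errors when the table is materialized.
orders_by_id = orders.with_id_from(pw.this.order_id)
pw.debug.compute_and_print(orders_by_id)
```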
9 changes: 6 additions & 3 deletions python/pathway/tests/test_common.py
@@ -988,9 +988,12 @@ def test_concat_errors_on_intersecting_universes():
)

pw.universes.promise_are_pairwise_disjoint(t1, t2)
pw.Table.concat(t1, t2)
with pytest.raises(KeyError, match="duplicate key"):
run_all()
result = pw.Table.concat(t1, t2)
with pytest.raises(
KeyError,
match=re.escape("duplicated entries for key ^YYY4HABTRW7T8VX2Q429ZYV70W"),
):
pw.debug.compute_and_print(result)


@pytest.mark.parametrize("dtype", [np.int64, np.float64])
63 changes: 26 additions & 37 deletions python/pathway/tests/test_errors.py
@@ -1,6 +1,7 @@
# Copyright © 2024 Pathway

import logging
import re
from pathlib import Path
from unittest import mock

@@ -463,26 +464,20 @@ def test_concat():
)
expected = pw.debug.table_from_markdown(
"""
a | b
-1 | -1
2 | 5
3 | 1
4 | 3
5 | 1
a | b | e
-1 | -1 | 0
2 | 5 | 1
3 | 1 | 1
4 | 3 | 1
5 | 1 | 1
"""
)
expected_errors = T(
"""
message
duplicate key: ^YYY4HABTRW7T8VX2Q429ZYV70W
""",
split_on_whitespace=False,
)
assert_table_equality_wo_index(
(res, pw.global_error_log().select(pw.this.message)),
(expected, expected_errors),
terminate_on_error=False,
)
).select(a=pw.this.a // pw.this.e, b=pw.this.b // pw.this.e)
# column e used to produce ERROR in the first row
with pytest.warns(
UserWarning,
match=re.escape("duplicated entries for key ^YYY4HABTRW7T8VX2Q429ZYV70W"),
):
assert_table_equality_wo_index(res, expected, terminate_on_error=False)


def test_left_join_preserving_id():
@@ -702,27 +697,21 @@ def test_reindex_with_duplicate_key():
expected = (
pw.debug.table_from_markdown(
"""
a | b
1 | 3
2 | 4
3 | -1
a | b | e
1 | 3 | 1
2 | 4 | 1
3 | -1 | 0
"""
)
.with_id_from(pw.this.a)
.with_columns(a=pw.if_else(pw.this.a == 3, -1, pw.this.a))
)
expected_errors = T(
"""
message
duplicate key: ^3CZ78B48PASGNT231ZECWPER90
""",
split_on_whitespace=False,
)
assert_table_equality_wo_index(
(res, pw.global_error_log().select(pw.this.message)),
(expected, expected_errors),
terminate_on_error=False,
)
.select(a=pw.this.a // pw.this.e, b=pw.this.b // pw.this.e)
)
# column e used to produce ERROR in the first row
with pytest.warns(
UserWarning,
match=re.escape("duplicated entries for key ^3CZ78B48PASGNT231ZECWPER90"),
):
assert_table_equality_wo_index(res, expected, terminate_on_error=False)


def test_groupby_with_error_in_grouping_column():
28 changes: 18 additions & 10 deletions python/pathway/tests/utils.py
@@ -225,8 +225,10 @@ def assert_stream_equal(expected: list[DiffEntry], table: pw.Table):
pw.io.subscribe(table, callback, callback.on_end)


def assert_equal_tables(t0: api.CapturedStream, t1: api.CapturedStream) -> None:
assert api.squash_updates(t0) == api.squash_updates(t1)
def assert_equal_tables(
t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
) -> None:
assert api.squash_updates(t0, **kwargs) == api.squash_updates(t1, **kwargs)


def make_value_hashable(val: api.Value):
@@ -243,16 +245,18 @@ def make_row_hashable(row: tuple[api.Value, ...]):


def assert_equal_tables_wo_index(
s0: api.CapturedStream, s1: api.CapturedStream
s0: api.CapturedStream, s1: api.CapturedStream, **kwargs
) -> None:
t0 = api.squash_updates(s0)
t1 = api.squash_updates(s1)
t0 = api.squash_updates(s0, **kwargs)
t1 = api.squash_updates(s1, **kwargs)
assert collections.Counter(
make_row_hashable(row) for row in t0.values()
) == collections.Counter(make_row_hashable(row) for row in t1.values())


def assert_equal_streams(t0: api.CapturedStream, t1: api.CapturedStream) -> None:
def assert_equal_streams(
t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
) -> None:
def transform(row: api.DataRow) -> Hashable:
t = (row.key,) + tuple(row.values) + (row.time, row.diff)
return make_row_hashable(t)
@@ -263,7 +267,7 @@ def transform(row: api.DataRow) -> Hashable:


def assert_equal_streams_wo_index(
t0: api.CapturedStream, t1: api.CapturedStream
t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
) -> None:
def transform(row: api.DataRow) -> Hashable:
t = tuple(row.values) + (row.time, row.diff)
@@ -304,7 +308,7 @@ def assert_split_into_time_groups(


def assert_streams_in_time_groups(
t0: api.CapturedStream, t1: api.CapturedStream
t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
) -> None:
def transform(row: api.DataRow) -> tuple[Hashable, int]:
t = (row.key, *row.values, row.diff)
@@ -314,7 +318,7 @@ def transform(row: api.DataRow) -> tuple[Hashable, int]:


def assert_streams_in_time_groups_wo_index(
t0: api.CapturedStream, t1: api.CapturedStream
t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
) -> None:
def transform(row: api.DataRow) -> tuple[Hashable, int]:
t = (*row.values, row.diff)
@@ -460,7 +464,11 @@ def inner(
n = len(expected)
captured_tables, captured_expected = captured[:n], captured[n:]
for captured_t, captured_ex in zip(captured_tables, captured_expected):
verifier(captured_t, captured_ex)
verifier(
captured_t,
captured_ex,
terminate_on_error=kwargs.get("terminate_on_error", True),
)

return inner

2 changes: 1 addition & 1 deletion src/connectors/metadata.rs
@@ -59,7 +59,7 @@ mod file_owner {

#[cfg(not(target_os = "linux"))]
mod file_owner {
pub fn get_owner(metadata: &std::fs::Metadata) -> Option<String> {
pub fn get_owner(_metadata: &std::fs::Metadata) -> Option<String> {
None
}
}