From f3e91d46985f3c8858c9669063cf81edc8e9cadf Mon Sep 17 00:00:00 2001
From: Kamil Piechowiak <32928185+KamilPiechowiak@users.noreply.github.com>
Date: Tue, 19 Nov 2024 17:18:23 +0100
Subject: [PATCH] memory-neutral concat (#7569)

GitOrigin-RevId: 5e4f562a7c7e4b944604589d6d96721378941174
---
 CHANGELOG.md                        |  3 ++
 python/pathway/debug/__init__.py    | 10 ++++-
 python/pathway/internals/api.py     | 28 ++++++++++---
 python/pathway/internals/table.py   |  4 ++
 python/pathway/tests/test_common.py |  9 +++--
 python/pathway/tests/test_errors.py | 63 ++++++++++++-----------------
 python/pathway/tests/utils.py       | 28 ++++++++-----
 src/connectors/metadata.rs          |  2 +-
 src/engine/dataflow.rs              | 39 +++++++++---------
 src/engine/error.rs                 |  3 ++
 src/python_api.rs                   |  1 -
 11 files changed, 112 insertions(+), 78 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ed24980f..ebccb6fb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,9 @@
 All notable changes to this project will be documented in this file.
 This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Changed
+- `pw.Table.concat`, `pw.Table.with_id`, `pw.Table.with_id_from` no longer check that ids are unique. This improves memory usage.
+
 ### Fixed
 - `query_as_of_now` of `pw.stdlib.indexing.DataIndex` and `pw.stdlib.indexing.HybridIndex` now work in constant memory for infinite query stream (no query-related data is kept after query is answered).
diff --git a/python/pathway/debug/__init__.py b/python/pathway/debug/__init__.py
index c1b76f00..dc69c9cd 100644
--- a/python/pathway/debug/__init__.py
+++ b/python/pathway/debug/__init__.py
@@ -63,7 +63,9 @@ def table_to_dicts(
     **kwargs,
 ) -> tuple[list[api.Pointer], dict[str, dict[api.Pointer, api.Value]]]:
     captured = _compute_tables(table, **kwargs)[0]
-    output_data = api.squash_updates(captured)
+    output_data = api.squash_updates(
+        captured, terminate_on_error=kwargs.get("terminate_on_error", True)
+    )
     keys = list(output_data.keys())
     columns = {
         name: {key: output_data[key][index] for key in keys}
@@ -115,6 +117,7 @@ def _compute_and_print_internal(
         include_id=include_id,
         short_pointers=short_pointers,
         n_rows=n_rows,
+        terminate_on_error=kwargs.get("terminate_on_error", True),
     )
 
 
@@ -126,10 +129,13 @@ def _compute_and_print_single(
     include_id: bool,
     short_pointers: bool,
     n_rows: int | None,
+    terminate_on_error: bool,
 ) -> None:
     columns = list(table._columns.keys())
     if squash_updates:
-        output_data = list(api.squash_updates(captured).items())
+        output_data = list(
+            api.squash_updates(captured, terminate_on_error=terminate_on_error).items()
+        )
     else:
         columns.extend([api.TIME_PSEUDOCOLUMN, api.DIFF_PSEUDOCOLUMN])
         output_data = []
diff --git a/python/pathway/internals/api.py b/python/pathway/internals/api.py
index fad7893e..068b5692 100644
--- a/python/pathway/internals/api.py
+++ b/python/pathway/internals/api.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import datetime
+import warnings
 from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeAlias, TypeVar, Union
 
 import numpy as np
@@ -32,6 +33,7 @@
     json.Json,
     dict[str, _Value],
     tuple[_Value, ...],
+    Error,
 ]
 CapturedTable = dict[Pointer, tuple[Value, ...]]
 CapturedStream = list[DataRow]
@@ -191,20 +193,34 @@ def static_table_from_pandas(
     return scope.static_table(input_data, connector_properties)
 
 
-def squash_updates(updates: CapturedStream) -> CapturedTable:
+def squash_updates(
+    updates: CapturedStream, *, terminate_on_error: bool = True
+) -> CapturedTable:
     state: CapturedTable = {}
     updates.sort(key=lambda row: (row.time, row.diff))
+
+    def handle_error(row: DataRow, msg: str):
+        if terminate_on_error:
+            raise KeyError(msg)
+        else:
+            warnings.warn(msg)
+            t: tuple[Value, ...] = (ERROR,) * len(row.values)
+            state[row.key] = t
+
     for row in updates:
         if row.diff == 1:
-            assert row.key not in state, f"duplicated entries for key {row.key}"
+            if row.key in state:
+                handle_error(row, f"duplicated entries for key {row.key}")
+                continue
             state[row.key] = tuple(row.values)
         elif row.diff == -1:
-            assert state[row.key] == tuple(
-                row.values
-            ), f"deleting non-existing entry {row.values}"
+            if state[row.key] != tuple(row.values):
+                handle_error(row, f"deleting non-existing entry {row.values}")
+                continue
            del state[row.key]
         else:
-            raise AssertionError(f"Invalid diff value: {row.diff}")
+            handle_error(row, f"invalid diff value: {row.diff}")
+            continue
     return state
 
 
diff --git a/python/pathway/internals/table.py b/python/pathway/internals/table.py
index 4a20f2c7..ae5b9cde 100644
--- a/python/pathway/internals/table.py
+++ b/python/pathway/internals/table.py
@@ -1644,6 +1644,8 @@ def with_id(self, new_index: expr.ColumnReference) -> Table:
 
         To generate ids based on arbitrary valued columns, use `with_id_from`.
         Values assigned must be row-wise unique.
+        The uniqueness is not checked by Pathway. Failing to provide unique ids can
+        cause unexpected errors downstream.
 
         Args:
             new_id: column to be used as the new index.
@@ -1686,6 +1688,8 @@ def with_id_from(
     ) -> Table:
         """Compute new ids based on values in columns.
         Ids computed from `columns` must be row-wise unique.
+        The uniqueness is not checked by Pathway. Failing to provide unique ids can
+        cause unexpected errors downstream.
 
         Args:
             columns: columns to be used as primary keys.
diff --git a/python/pathway/tests/test_common.py b/python/pathway/tests/test_common.py
index 4da2e7a5..caa9e1ba 100644
--- a/python/pathway/tests/test_common.py
+++ b/python/pathway/tests/test_common.py
@@ -988,9 +988,12 @@ def test_concat_errors_on_intersecting_universes():
     )
     pw.universes.promise_are_pairwise_disjoint(t1, t2)
-    pw.Table.concat(t1, t2)
-    with pytest.raises(KeyError, match="duplicate key"):
-        run_all()
+    result = pw.Table.concat(t1, t2)
+    with pytest.raises(
+        KeyError,
+        match=re.escape("duplicated entries for key ^YYY4HABTRW7T8VX2Q429ZYV70W"),
+    ):
+        pw.debug.compute_and_print(result)
 
 
 @pytest.mark.parametrize("dtype", [np.int64, np.float64])
diff --git a/python/pathway/tests/test_errors.py b/python/pathway/tests/test_errors.py
index 80511e0b..1f5bfca4 100644
--- a/python/pathway/tests/test_errors.py
+++ b/python/pathway/tests/test_errors.py
@@ -1,6 +1,7 @@
 # Copyright © 2024 Pathway
 
 import logging
+import re
 from pathlib import Path
 from unittest import mock
 
@@ -463,26 +464,20 @@ def test_concat():
     )
     expected = pw.debug.table_from_markdown(
         """
-        a | b
-        -1 | -1
-        2 | 5
-        3 | 1
-        4 | 3
-        5 | 1
+        a | b | e
+        -1 | -1 | 0
+        2 | 5 | 1
+        3 | 1 | 1
+        4 | 3 | 1
+        5 | 1 | 1
     """
-    )
-    expected_errors = T(
-        """
-        message
-        duplicate key: ^YYY4HABTRW7T8VX2Q429ZYV70W
-    """,
-        split_on_whitespace=False,
-    )
-    assert_table_equality_wo_index(
-        (res, pw.global_error_log().select(pw.this.message)),
-        (expected, expected_errors),
-        terminate_on_error=False,
-    )
+    ).select(a=pw.this.a // pw.this.e, b=pw.this.b // pw.this.e)
+    # column e used to produce ERROR in the first row
+    with pytest.warns(
+        UserWarning,
+        match=re.escape("duplicated entries for key ^YYY4HABTRW7T8VX2Q429ZYV70W"),
+    ):
+        assert_table_equality_wo_index(res, expected, terminate_on_error=False)
 
 
 def test_left_join_preserving_id():
@@ -702,27 +697,21 @@ def test_reindex_with_duplicate_key():
     expected = (
         pw.debug.table_from_markdown(
             """
-            a | b
-            1 | 3
-            2 | 4
-            3 | -1
+            a | b | e
+            1 | 3 | 1
+            2 | 4 | 1
+            3 | -1 | 0
         """
         )
         .with_id_from(pw.this.a)
-        .with_columns(a=pw.if_else(pw.this.a == 3, -1, pw.this.a))
-    )
-    expected_errors = T(
-        """
-        message
-        duplicate key: ^3CZ78B48PASGNT231ZECWPER90
-    """,
-        split_on_whitespace=False,
-    )
-    assert_table_equality_wo_index(
-        (res, pw.global_error_log().select(pw.this.message)),
-        (expected, expected_errors),
-        terminate_on_error=False,
-    )
+        .select(a=pw.this.a // pw.this.e, b=pw.this.b // pw.this.e)
+    )
+    # column e used to produce ERROR in the row with the duplicated key
+    with pytest.warns(
+        UserWarning,
+        match=re.escape("duplicated entries for key ^3CZ78B48PASGNT231ZECWPER90"),
+    ):
+        assert_table_equality_wo_index(res, expected, terminate_on_error=False)
 
 
 def test_groupby_with_error_in_grouping_column():
diff --git a/python/pathway/tests/utils.py b/python/pathway/tests/utils.py
index fe123d68..91692ea9 100644
--- a/python/pathway/tests/utils.py
+++ b/python/pathway/tests/utils.py
@@ -225,8 +225,10 @@ def assert_stream_equal(expected: list[DiffEntry], table: pw.Table):
     pw.io.subscribe(table, callback, callback.on_end)
 
 
-def assert_equal_tables(t0: api.CapturedStream, t1: api.CapturedStream) -> None:
-    assert api.squash_updates(t0) == api.squash_updates(t1)
+def assert_equal_tables(
+    t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
+) -> None:
+    assert api.squash_updates(t0, **kwargs) == api.squash_updates(t1, **kwargs)
 
 
 def make_value_hashable(val: api.Value):
@@ -243,16 +245,18 @@ def make_row_hashable(row: tuple[api.Value, ...]):
 
 
 def assert_equal_tables_wo_index(
-    s0: api.CapturedStream, s1: api.CapturedStream
+    s0: api.CapturedStream, s1: api.CapturedStream, **kwargs
 ) -> None:
-    t0 = api.squash_updates(s0)
-    t1 = api.squash_updates(s1)
+    t0 = api.squash_updates(s0, **kwargs)
+    t1 = api.squash_updates(s1, **kwargs)
     assert collections.Counter(
         make_row_hashable(row) for row in t0.values()
     ) == collections.Counter(make_row_hashable(row) for row in t1.values())
 
 
-def assert_equal_streams(t0: api.CapturedStream, t1: api.CapturedStream) -> None:
+def assert_equal_streams(
+    t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
+) -> None:
     def transform(row: api.DataRow) -> Hashable:
         t = (row.key,) + tuple(row.values) + (row.time, row.diff)
         return make_row_hashable(t)
@@ -263,7 +267,7 @@ def transform(row: api.DataRow) -> Hashable:
 
 
 def assert_equal_streams_wo_index(
-    t0: api.CapturedStream, t1: api.CapturedStream
+    t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
 ) -> None:
     def transform(row: api.DataRow) -> Hashable:
         t = tuple(row.values) + (row.time, row.diff)
@@ -304,7 +308,7 @@ def assert_split_into_time_groups(
 
 
 def assert_streams_in_time_groups(
-    t0: api.CapturedStream, t1: api.CapturedStream
+    t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
 ) -> None:
     def transform(row: api.DataRow) -> tuple[Hashable, int]:
         t = (row.key, *row.values, row.diff)
@@ -314,7 +318,7 @@ def transform(row: api.DataRow) -> tuple[Hashable, int]:
 
 
 def assert_streams_in_time_groups_wo_index(
-    t0: api.CapturedStream, t1: api.CapturedStream
+    t0: api.CapturedStream, t1: api.CapturedStream, **kwargs
 ) -> None:
     def transform(row: api.DataRow) -> tuple[Hashable, int]:
         t = (*row.values, row.diff)
@@ -460,7 +464,11 @@ def inner(
         n = len(expected)
         captured_tables, captured_expected = captured[:n], captured[n:]
         for captured_t, captured_ex in zip(captured_tables, captured_expected):
-            verifier(captured_t, captured_ex)
+            verifier(
+                captured_t,
+                captured_ex,
+                terminate_on_error=kwargs.get("terminate_on_error", True),
+            )
 
     return inner
diff --git a/src/connectors/metadata.rs b/src/connectors/metadata.rs
index 65985504..8f0bcd26 100644
--- a/src/connectors/metadata.rs
+++ b/src/connectors/metadata.rs
@@ -59,7 +59,7 @@ mod file_owner {
 
 #[cfg(not(target_os = "linux"))]
 mod file_owner {
-    pub fn get_owner(metadata: &std::fs::Metadata) -> Option<String> {
+    pub fn get_owner(_metadata: &std::fs::Metadata) -> Option<String> {
         None
     }
 }
diff --git a/src/engine/dataflow.rs b/src/engine/dataflow.rs
index b72f683b..f3876cda 100644
--- a/src/engine/dataflow.rs
+++ b/src/engine/dataflow.rs
@@ -1767,13 +1767,6 @@ impl DataflowGraphInner {
                 }
             });
 
-        let new_values = if self.ignore_asserts {
-            new_values
-        } else {
-            let error_logger = self.create_error_logger()?;
-            new_values.replace_duplicates_with_error(|_| Value::Error, error_logger)
-        };
-
         Ok(self
             .tables
             .alloc(Table::from_collection(new_values).with_properties(table_properties)))
@@ -1823,12 +1816,6 @@ impl DataflowGraphInner {
             })
             .collect::<Result<Vec<_>>>()?;
         let result = concatenate(&mut self.scope, table_collections);
-        let result = if self.ignore_asserts {
-            result
-        } else {
-            let error_logger = self.create_error_logger()?;
-            result.replace_duplicates_with_error(|_| Value::Error, error_logger)
-        };
         let table = Table::from_collection(result).with_properties(table_properties);
         let table_handle = self.tables.alloc(table);
         Ok(table_handle)
@@ -2006,11 +1993,12 @@ impl DataflowGraphInner {
         update_handle: TableHandle,
         table_properties: Arc<TableProperties>,
     ) -> Result<TableHandle> {
+        let error_logger = self.create_error_logger()?;
         let both_arranged = self.update_rows_arrange(table_handle, update_handle)?;
 
         let updated_values = both_arranged.reduce_abelian(
             "update_rows_table::updated",
-            move |_key, input, output| {
+            move |key, input, output| {
                 let values = match input {
                     [(MaybeUpdate::Original(original_values), 1)] => original_values,
                     [(MaybeUpdate::Update(new_values), 1)] => new_values,
@@ -2018,7 +2006,8 @@ impl DataflowGraphInner {
                         new_values
                     }
                     _ => {
-                        panic!("unexpected counts in input");
+                        error_logger.log_error(DataError::DuplicateKey(*key));
+                        return;
                     }
                 };
                 output.push((values.clone(), 1));
@@ -2038,6 +2027,7 @@ impl DataflowGraphInner {
         update_paths: Vec<ColumnPath>,
         table_properties: Arc<TableProperties>,
     ) -> Result<TableHandle> {
+        let error_logger = self.create_error_logger()?;
         let both_arranged = self.update_rows_arrange(table_handle, update_handle)?;
 
         let error_reporter = self.error_reporter.clone();
@@ -2055,11 +2045,21 @@ impl DataflowGraphInner {
                     [
                         (MaybeUpdate::Original(original_values), 1),
                         (MaybeUpdate::Update(new_values), 1),
                     ] => {
                         (original_values, new_values, &update_paths)
                     }
+                    [
+                        (MaybeUpdate::Original(original_values), 1),
+                        (MaybeUpdate::Update(_), _),
+                        ..
+                    ] => { // if there's exactly one original entry, keep it to preserve the universe keys
+                        error_logger.log_error(DataError::DuplicateKey(*key));
+                        (original_values, &Value::Error, &update_paths)
+                    },
                     [(MaybeUpdate::Update(_), 1)] => {
-                        panic!("updating a row that does not exist");
+                        error_logger.log_error(DataError::UpdatingNonExistingRow(*key));
+                        return;
                     }
                     _ => {
-                        panic!("unexpected counts in input");
+                        error_logger.log_error(DataError::DuplicateKey(*key));
+                        return;
                     }
                 };
                 let updates: Vec<_> = selected_paths
@@ -3092,6 +3092,7 @@ where
         let new_values_persisted = if let Some(persistent_id) = persistent_id {
             let error_reporter = self.error_reporter.clone();
+            let error_logger = self.create_error_logger()?;
             let snapshot_writer = self
                 .worker_persistent_storage
                 .as_ref()
@@ -3109,7 +3110,9 @@ where
                 if *time == ARTIFICIAL_TIME_ON_REWIND_START {
                     continue;
                 }
-                assert!(*diff == 1 || *diff == -1);
+                if *diff != 1 && *diff != -1 {
+                    error_logger.log_error(DataError::DuplicateKey(*key));
+                }
                 let values_vec: Vec<Value> =
                     (**values.as_tuple().unwrap_with_reporter(&error_reporter)).into();
                 let event = if *diff == 1 {
diff --git a/src/engine/error.rs b/src/engine/error.rs
index d2195133..bdb5e788 100644
--- a/src/engine/error.rs
+++ b/src/engine/error.rs
@@ -318,6 +318,9 @@ pub enum DataError {
     #[error("mixing types in npsum is not allowed")]
     MixingTypesInNpSum,
 
+    #[error("updating a row that does not exist, key: {0}")]
+    UpdatingNonExistingRow(Key),
+
     #[error(transparent)]
     Other(DynError),
 }
diff --git a/src/python_api.rs b/src/python_api.rs
index e0249632..68c9b1a0 100644
--- a/src/python_api.rs
+++ b/src/python_api.rs
@@ -3220,7 +3220,6 @@ fn capture_table_data(
     let table_data = table_data.clone();
     let callbacks = SubscribeCallbacksBuilder::new()
         .on_data(Box::new(move |key, values, time, diff| {
-            assert!(diff == 1 || diff == -1);
             table_data.lock().unwrap().push(DataRow::from_engine(
                 key,
                 Vec::from(values),
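
Behaviour sketch (illustrative, not part of the commit): with this change `pw.Table.concat` no longer validates id uniqueness while the graph is built, so colliding ids only surface when the result is consumed. The snippet below assumes two debug tables whose auto-generated ids collide; the "duplicated entries for key" message and the `terminate_on_error` switch come from the tests and debug helpers changed above.

import pathway as pw

# Two single-row debug tables; table_from_markdown derives ids from the row
# number, so these rows are assumed to collide on the same id.
t1 = pw.debug.table_from_markdown(
    """
    a | b
    1 | 2
    """
)
t2 = pw.debug.table_from_markdown(
    """
    a | b
    3 | 4
    """
)
pw.universes.promise_are_pairwise_disjoint(t1, t2)  # promise given although it is false
result = pw.Table.concat(t1, t2)  # no id-uniqueness check is performed here any more

# Materializing the result surfaces the collision: by default the debug helpers
# raise KeyError("duplicated entries for key ..."); with terminate_on_error=False
# they emit a UserWarning and return ERROR values for the colliding rows instead.
pw.debug.compute_and_print(result, terminate_on_error=False)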