diff --git a/metrics/timescaledb/db.py b/metrics/timescaledb/db.py index 489ccd82..e3042996 100644 --- a/metrics/timescaledb/db.py +++ b/metrics/timescaledb/db.py @@ -12,6 +12,9 @@ log = structlog.get_logger() +MAX_PARAMS = 65535 # limit for postgresql + + def reset_table(table, batch_size=None): _drop_table(table, batch_size) _ensure_table(table) @@ -19,8 +22,7 @@ def reset_table(table, batch_size=None): def write(table, rows): - max_params = 65535 # limit for postgresql - batch_size = max_params // len(table.columns) + batch_size = _batch_size(table) with _get_engine().begin() as connection: for values in batched(rows, batch_size): @@ -28,6 +30,24 @@ def write(table, rows): log.info("Inserted %s rows", len(values), table=table.name) +def upsert(table, rows): + _ensure_table(table) + batch_size = _batch_size(table) + non_pk_columns = set(table.columns) - set(table.primary_key.columns) + + with _get_engine().begin() as connection: + for values in batched(rows, batch_size): + insert_stmt = insert(table).values(values) + + set_ = {c: insert_stmt.excluded[c.name] for c in non_pk_columns} + + insert_stmt = insert_stmt.on_conflict_do_update( + constraint=table.primary_key, set_=set_ + ) + connection.execute(insert_stmt) + log.info("Inserted %s rows", len(values), table=table.name) + + def _drop_table(table, batch_size): with _get_engine().begin() as connection: log.debug("Removing table: %s", table.name) @@ -114,6 +134,10 @@ def _ensure_table(table): ) +def _batch_size(table): + return MAX_PARAMS // len(table.columns) + + @functools.cache def _get_engine(): return create_engine(_get_url()) diff --git a/tests/metrics/timescaledb/test_db.py b/tests/metrics/timescaledb/test_db.py index e3eec767..214ff6cc 100644 --- a/tests/metrics/timescaledb/test_db.py +++ b/tests/metrics/timescaledb/test_db.py @@ -164,3 +164,33 @@ def test_write(engine, table): # check rows are in table rows = get_rows(engine, table) assert len(rows) == 3 + + +def test_upsert(engine, table): + # add a non-PK column to the table + table.append_column(Column("value2", Text)) + + # set up a table to upsert into + db._ensure_table(table) + + # insert intital rows + rows = [{"value": i, "value2": i} for i in range(1, 4)] + + db.upsert(table, rows) + + # second batch of rows, some with differing value2s, some new + rows = [{"value": i, "value2": 2 * i} for i in range(3, 6)] + + db.upsert(table, rows) + + # check rows are in table + rows = get_rows(engine, table) + assert len(rows) == 5 + + # check upsert leaves unmatched rows intact + original_rows = [int(r[0]) / int(r[1]) for r in rows if int(r[0]) < 3] + assert set(original_rows) == {1} + + # check upsert modifies matched rows + modified_rows = [int(r[0]) / int(r[1]) for r in rows if int(r[0]) >= 3] + assert set(modified_rows) == {0.5}