From b7051a4ce9e6248ff5fceb32a3d930e8bb9b20c4 Mon Sep 17 00:00:00 2001 From: Jon Massey Date: Wed, 15 May 2024 17:03:04 +0100 Subject: [PATCH] Support upserts in database This is required as codespaces' state will change over time, and their data will disappear from the API once deleted so we are unable to do a full refresh of the data each time. Uses PostgreSQL "INSERT..ON CONFLICT.. UPDATE" style as newer "MERGE" statement not yet supported in SQLAlchemy. --- metrics/timescaledb/db.py | 44 ++++++++++++++++++++++++++-- tests/metrics/timescaledb/test_db.py | 26 ++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/metrics/timescaledb/db.py b/metrics/timescaledb/db.py index 489ccd82..33095508 100644 --- a/metrics/timescaledb/db.py +++ b/metrics/timescaledb/db.py @@ -19,8 +19,7 @@ def reset_table(table, batch_size=None): def write(table, rows): - max_params = 65535 # limit for postgresql - batch_size = max_params // len(table.columns) + batch_size = _batch_size(table) with _get_engine().begin() as connection: for values in batched(rows, batch_size): @@ -28,6 +27,47 @@ def write(table, rows): log.info("Inserted %s rows", len(values), table=table.name) +def upsert(table, rows): + _ensure_table(table) + batch_size = _batch_size(table) + non_pk_columns = set(table.columns) - set(table.primary_key.columns) + + # use the primary key constraint to match rows to be updated, + # using default constraint name if not otherwise specified + constraint = table.primary_key.name or table.name + "_pkey" + + with _get_engine().begin() as connection: + for values in batched(rows, batch_size): + # Construct a PostgreSQL "INSERT..ON CONFLICT" style upsert statement + # https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#insert-on-conflict-upsert + + # "Vanilla" statement to start, we need this to be able to derive + # the "excluded" columns in the values which we need to use to update + # the target table in case of conflict on the constraint + 
insert_stmt = insert(table).values(values) + + # This dict dictates which columns in the target table are updated (the + # non-PK columns) and the corresponding values with which they are updated + update_set_clause = { + c: insert_stmt.excluded[c.name] for c in non_pk_columns + } + + # Extend the insert statement to include checking for row conflicts using + # the primary key constraint and telling the database to update + # the conflicting rows according to the SET clause + insert_stmt = insert_stmt.on_conflict_do_update( + constraint=constraint, + set_=update_set_clause, + ) + connection.execute(insert_stmt) + log.info("Inserted %s rows", len(values), table=table.name) + + +def _batch_size(table): + max_params = 65535 # limit for postgresql + return max_params // len(table.columns) + + def _drop_table(table, batch_size): with _get_engine().begin() as connection: log.debug("Removing table: %s", table.name) diff --git a/tests/metrics/timescaledb/test_db.py b/tests/metrics/timescaledb/test_db.py index e3eec767..58f17e85 100644 --- a/tests/metrics/timescaledb/test_db.py +++ b/tests/metrics/timescaledb/test_db.py @@ -164,3 +164,29 @@ def test_write(engine, table): # check rows are in table rows = get_rows(engine, table) assert len(rows) == 3 + + +def test_upsert(engine, table): + # add a non-PK column to the table + table.append_column(Column("value2", Text)) + + # insert initial rows + rows = [{"value": i, "value2": "a"} for i in range(1, 4)] + db.upsert(table, rows) + + # second batch of rows, some with preexisting value, some new + # all with different value2 + rows = [{"value": i, "value2": "b"} for i in range(3, 6)] + db.upsert(table, rows) + + # check all rows are in table + rows = get_rows(engine, table) + assert len(rows) == 5 + + # check upsert leaves unmatched rows 1-2 intact + original_rows = [r for r in rows if int(r[0]) < 3] + assert original_rows == [("1", "a"), ("2", "a")] + + # check upsert modifies matched row 3 and new rows 4-5 + modified_rows = [r 
for r in rows if int(r[0]) >= 3] + assert modified_rows == [("3", "b"), ("4", "b"), ("5", "b")]