Skip to content

Commit

Permalink
Support upserts in database
Browse files Browse the repository at this point in the history
This is required as codespaces' state will change over time
  • Loading branch information
Jongmassey committed May 17, 2024
1 parent 607b870 commit 1c85128
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
28 changes: 26 additions & 2 deletions metrics/timescaledb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,42 @@
log = structlog.get_logger()


MAX_PARAMS = 65535 # limit for postgresql


def reset_table(table, batch_size=None):
_drop_table(table, batch_size)
_ensure_table(table)
log.info("Reset table", table=table.name)


def write(table, rows):
max_params = 65535 # limit for postgresql
batch_size = max_params // len(table.columns)
batch_size = _batch_size(table)

with _get_engine().begin() as connection:
for values in batched(rows, batch_size):
connection.execute(insert(table).values(values))
log.info("Inserted %s rows", len(values), table=table.name)


def upsert(table, rows):
_ensure_table(table)
batch_size = _batch_size(table)
non_pk_columns = set(table.columns) - set(table.primary_key.columns)

with _get_engine().begin() as connection:
for values in batched(rows, batch_size):
insert_stmt = insert(table).values(values)

set_ = {c: insert_stmt.excluded[c.name] for c in non_pk_columns}

insert_stmt = insert_stmt.on_conflict_do_update(
constraint=table.primary_key, set_=set_
)
connection.execute(insert_stmt)
log.info("Inserted %s rows", len(values), table=table.name)


def _drop_table(table, batch_size):
with _get_engine().begin() as connection:
log.debug("Removing table: %s", table.name)
Expand Down Expand Up @@ -114,6 +134,10 @@ def _ensure_table(table):
)


def _batch_size(table):
return MAX_PARAMS // len(table.columns)


@functools.cache
def _get_engine():
return create_engine(_get_url())
Expand Down
30 changes: 30 additions & 0 deletions tests/metrics/timescaledb/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,33 @@ def test_write(engine, table):
# check rows are in table
rows = get_rows(engine, table)
assert len(rows) == 3


def test_upsert(engine, table):
# add a non-PK column to the table
table.append_column(Column("value2", Text))

# set up a table to upsert into
db._ensure_table(table)

# insert intital rows
rows = [{"value": i, "value2": i} for i in range(1, 4)]

db.upsert(table, rows)

# second batch of rows, some with differing value2s, some new
rows = [{"value": i, "value2": 2 * i} for i in range(3, 6)]

db.upsert(table, rows)

# check rows are in table
rows = get_rows(engine, table)
assert len(rows) == 5

# check upsert leaves unmatched rows intact
original_rows = [int(r[0]) / int(r[1]) for r in rows if int(r[0]) < 3]
assert set(original_rows) == {1}

# check upsert modifies matched rows
modified_rows = [int(r[0]) / int(r[1]) for r in rows if int(r[0]) >= 3]
assert set(modified_rows) == {0.5}

0 comments on commit 1c85128

Please sign in to comment.