From ff44a2f6a2349bda4af1f2ca3d9f8a5258e48e49 Mon Sep 17 00:00:00 2001 From: wlwlwlzhang Date: Wed, 6 Sep 2023 21:14:01 +0800 Subject: [PATCH] add test for pool --- test/test_dataset.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/test_dataset.py b/test/test_dataset.py index f7c94eb..9579687 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -1,4 +1,5 @@ import os +import threading import unittest from datetime import datetime from collections import OrderedDict @@ -598,5 +599,43 @@ def test_iter(self): assert c == len(self.tbl) +class DatabasePoolTestCase(unittest.TestCase): + def test_pool(self): + target_num = 30 + table_name = "test_pool" + + def insert_data(): + with db as tx: + tx[table_name].insert(dict(name='John Doe', age=46, country='China')) + for _ in range(10): + engine_kwargs = {"echo": False, "pool_size": target_num, "max_overflow": 0, "connect_args": {'connect_timeout': 2}} + config_str = 'postgresql://postgres:123456@127.0.0.1:5432/postgres' + db = None + try: + db = connect(config_str, engine_kwargs=engine_kwargs) + with db as tx: + if table_name in tx: + tx[table_name].drop() + tx[table_name].insert(dict(name='John Doe', age=46, country='China')) + threads = [threading.Thread(target=insert_data) for _ in range(target_num)] + [thread.start() for thread in threads] + [thread.join() for thread in threads] + t = db[table_name].count() + with db as tx: + if table_name in tx: + tx[table_name].drop() + assert t == target_num + 1, t + except SQLAlchemyError as e: + if "timeout expired" in str(e) or "Connection refused" in str(e): + break + else: + raise + except Exception: + raise + finally: + if db: + db.close() + + if __name__ == "__main__": unittest.main()