diff --git a/locomotive/stores.py b/locomotive/stores.py index 7ee0b56..21819eb 100644 --- a/locomotive/stores.py +++ b/locomotive/stores.py @@ -92,26 +92,30 @@ class Stations: See https://github.com/trainline-eu/stations. """ - conn: sqlite3.Connection path: Path def __init__(self, path: Optional[Path] = None) -> None: - if path is None: - path = self.default_path() - # TODO: Close DB ? - self.conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + self.path = self.default_path() + if path: + self.path = path + + def _conn(self) -> sqlite3.Connection: + return sqlite3.connect(f"file:{self.path}?mode=ro", uri=True) @classmethod def default_path(cls) -> Path: return Path(__file__).parent.joinpath("data", "stations.sqlite3") def count(self) -> int: - with contextlib.closing(self.conn.cursor()) as c: + with self._conn() as conn: + c = conn.cursor() c.execute("SELECT COUNT(*) FROM stations") return int(c.fetchone()[0]) def find(self, query: str) -> Optional[Station]: - with contextlib.closing(self.conn.cursor()) as c: + with self._conn() as conn: + c = conn.cursor() + # a) Try to find matching IDs c.execute("SELECT * FROM stations WHERE sncf_id LIKE ?", (query,)) row = c.fetchone()