Skip to content

Commit

Permalink
feat: Implement option 'truncate' of argument 'if_exists' in 'DataFra…
Browse files Browse the repository at this point in the history
…me.to_sql' API.
  • Loading branch information
gmcrocetti committed Aug 2, 2024
1 parent 642d244 commit c85da12
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 16 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Other enhancements
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)

- Add ``"truncate"`` option to the ``if_exists`` argument of :meth:`DataFrame.to_sql`, which truncates an existing table before inserting data (:issue:`37210`)

.. ---------------------------------------------------------------------------
.. _whatsnew_300.notable_bug_fixes:

Expand Down
62 changes: 52 additions & 10 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def to_sql(
name: str,
con,
schema: str | None = None,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label: IndexLabel | None = None,
chunksize: int | None = None,
Expand All @@ -759,10 +759,12 @@ def to_sql(
schema : str, optional
Name of SQL schema in database to write to (if database flavor
supports this). If None, use default schema (default).
if_exists : {'fail', 'replace', 'append'}, default 'fail'
if_exists : {'fail', 'replace', 'append', 'truncate'}, default 'fail'
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
- truncate: If table exists, remove all its rows, then insert data.
Create if does not exist. Raises NotImplementedError if the
database does not support 'TRUNCATE TABLE'.
index : bool, default True
Write DataFrame index as a column.
index_label : str or sequence, optional
Expand Down Expand Up @@ -813,7 +815,7 @@ def to_sql(
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__
""" # noqa: E501
if if_exists not in ("fail", "replace", "append"):
if if_exists not in ("fail", "replace", "append", "truncate"):
raise ValueError(f"'{if_exists}' is not valid for if_exists")

if isinstance(frame, Series):
Expand Down Expand Up @@ -921,7 +923,7 @@ def __init__(
pandas_sql_engine,
frame=None,
index: bool | str | list[str] | None = True,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
prefix: str = "pandas",
index_label=None,
schema=None,
Expand Down Expand Up @@ -969,11 +971,13 @@ def create(self) -> None:
if self.exists():
if self.if_exists == "fail":
raise ValueError(f"Table '{self.name}' already exists.")
if self.if_exists == "replace":
elif self.if_exists == "replace":
self.pd_sql.drop_table(self.name, self.schema)
self._execute_create()
elif self.if_exists == "append":
pass
elif self.if_exists == "truncate":
self.pd_sql.truncate_table(self.name, self.schema)
else:
raise ValueError(f"'{self.if_exists}' is not valid for if_exists")
else:
Expand Down Expand Up @@ -1465,7 +1469,7 @@ def to_sql(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label=None,
schema=None,
Expand Down Expand Up @@ -1850,7 +1854,7 @@ def prep_table(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool | str | list[str] | None = True,
index_label=None,
schema=None,
Expand Down Expand Up @@ -1927,7 +1931,7 @@ def to_sql(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label=None,
schema: str | None = None,
Expand All @@ -1945,10 +1949,12 @@ def to_sql(
frame : DataFrame
name : string
Name of SQL table.
if_exists : {'fail', 'replace', 'append'}, default 'fail'
if_exists : {'fail', 'replace', 'append', 'truncate'}, default 'fail'
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
- truncate: If table exists, remove all its rows, then insert data.
Create if does not exist. Raises NotImplementedError if the
database does not support 'TRUNCATE TABLE'.
index : boolean, default True
Write DataFrame index as a column.
index_label : string or sequence, default None
Expand Down Expand Up @@ -2045,6 +2051,25 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
self.get_table(table_name, schema).drop(bind=self.con)
self.meta.clear()

def truncate_table(self, table_name: str, schema: str | None = None) -> None:
    """
    Remove every row from an existing table using ``TRUNCATE TABLE``.

    Nothing is executed when the table does not exist.

    Parameters
    ----------
    table_name : str
        Name of the table to truncate.
    schema : str, optional
        Schema that holds the table. Falls back to ``self.meta.schema``
        when not given.

    Raises
    ------
    NotImplementedError
        If executing the statement raises a SQLAlchemy
        ``OperationalError``, which is treated here as the database not
        supporting ``TRUNCATE``. The original error is kept as
        ``__cause__``.
    """
    from sqlalchemy.exc import OperationalError

    schema = schema or self.meta.schema

    if self.has_table(table_name, schema):
        self.meta.reflect(
            bind=self.con, only=[table_name], schema=schema, views=True
        )
        with self.run_transaction():
            table = self.get_table(table_name, schema)
            # Let the dialect render the identifier: this adds the schema
            # prefix and quotes names that need it, so the statement targets
            # exactly the reflected table rather than a same-named table in
            # the default schema.
            target = self.con.dialect.identifier_preparer.format_table(table)
            try:
                self.execute(f"TRUNCATE TABLE {target}")
            except OperationalError as err:
                # Chain the driver error so the real failure stays visible.
                raise NotImplementedError(
                    "'TRUNCATE' is not supported by this database."
                ) from err
        # NOTE(review): the listing this was taken from has no indentation;
        # clearing the metadata inside the ``if`` is assumed - confirm.
        self.meta.clear()


def _create_sql_schema(
self,
frame: DataFrame,
Expand Down Expand Up @@ -2301,7 +2326,7 @@ def to_sql(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label=None,
schema: str | None = None,
Expand All @@ -2323,6 +2348,8 @@ def to_sql(
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
- truncate: If table exists, remove all its rows, then insert data.
Create if does not exist. Raises NotImplementedError if the
database does not support 'TRUNCATE TABLE'.
index : boolean, default True
Write DataFrame index as a column.
index_label : string or sequence, default None
Expand All @@ -2340,6 +2367,8 @@ def to_sql(
engine : {'auto', 'sqlalchemy'}, default 'auto'
Raises NotImplementedError if not set to 'auto'
"""
from adbc_driver_manager import ProgrammingError

if index_label:
raise NotImplementedError(
"'index_label' is not implemented for ADBC drivers"
Expand Down Expand Up @@ -2373,6 +2402,14 @@ def to_sql(
cur.execute(f"DROP TABLE {table_name}")
elif if_exists == "append":
mode = "append"
elif if_exists == "truncate":
mode = "append"
with self.con.cursor() as cur:
try:
cur.execute(f"TRUNCATE TABLE {table_name}")
except ProgrammingError:
raise NotImplementedError("'TRUNCATE' is not supported by this database.")


import pyarrow as pa

Expand Down Expand Up @@ -2778,6 +2815,8 @@ def to_sql(
fail: If table exists, do nothing.
replace: If table exists, drop it, recreate it, and insert data.
append: If table exists, insert data. Create if it does not exist.
truncate: If table exists, remove all its rows, then insert data.
Create if it does not exist. Raises NotImplementedError if the
database does not support 'TRUNCATE TABLE'.
index : bool, default True
Write DataFrame index as a column
index_label : string or sequence, default None
Expand Down Expand Up @@ -2853,6 +2892,9 @@ def drop_table(self, name: str, schema: str | None = None) -> None:
drop_sql = f"DROP TABLE {_get_valid_sqlite_name(name)}"
self.execute(drop_sql)

def truncate_table(self, name: str, schema: str | None = None) -> None:
    """
    Reject truncation: this backend always reports it as unsupported.

    Parameters
    ----------
    name : str
        Name of the table; unused.
    schema : str, optional
        Schema of the table; unused. Defaults to None so the signature
        matches ``drop_table`` and the call works with the name alone.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("'TRUNCATE' is not supported by this database.")

def _create_sql_schema(
self,
frame,
Expand Down
17 changes: 12 additions & 5 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def mysql_pymysql_engine():
sqlalchemy = pytest.importorskip("sqlalchemy")
pymysql = pytest.importorskip("pymysql")
engine = sqlalchemy.create_engine(
"mysql+pymysql://root@localhost:3306/pandas",
"mysql+pymysql://root@mysql:3306/pandas",
connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS},
poolclass=sqlalchemy.pool.NullPool,
)
Expand Down Expand Up @@ -654,7 +654,7 @@ def postgresql_psycopg2_engine():
sqlalchemy = pytest.importorskip("sqlalchemy")
pytest.importorskip("psycopg2")
engine = sqlalchemy.create_engine(
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas",
"postgresql+psycopg2://postgres:postgres@postgres:5432/pandas",
poolclass=sqlalchemy.pool.NullPool,
)
yield engine
Expand Down Expand Up @@ -689,7 +689,7 @@ def postgresql_adbc_conn():
pytest.importorskip("adbc_driver_postgresql")
from adbc_driver_postgresql import dbapi

uri = "postgresql://postgres:postgres@localhost:5432/pandas"
uri = "postgresql://postgres:postgres@postgres:5432/pandas"
with dbapi.connect(uri) as conn:
yield conn
for view in get_all_views(conn):
Expand Down Expand Up @@ -1067,12 +1067,19 @@ def test_to_sql(conn, method, test_frame1, request):


# Writes test_frame1 to the "test_frame" table twice - first with
# if_exists="fail", then with the parametrized `mode` - and checks that the
# table exists and holds num_row_coef * len(test_frame1) rows.
# NOTE(review): this listing appears to contain both the pre-change and the
# post-change version of some lines (the two "mode, num_row_coef" parametrize
# decorators, and the second to_sql call with and without `with context:`);
# confirm which ones belong to the final file.
@pytest.mark.parametrize("conn", all_connectable)
@pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)])
@pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2), ("truncate", 1)])
def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
# Connection names for which mode == "truncate" is expected to raise
# NotImplementedError; `conn` is still the fixture name at this point.
connections_without_truncate = sqlite_connectable + ["sqlite_buildin", "sqlite_adbc_conn"]
if conn in connections_without_truncate and mode == "truncate":
context = pytest.raises(NotImplementedError, match="'TRUNCATE' is not supported by this database.")
else:
# Every other combination runs the second write under a no-op context.
context = contextlib.nullcontext()
# Resolve the fixture name into the actual connection object.
conn = request.getfixturevalue(conn)

with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
# First write uses if_exists="fail".
pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode)
with context:
pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode)
assert pandasSQL.has_table("test_frame")
# "truncate" uses coefficient 1. When the truncate write is expected to raise,
# presumably only the rows from the first write remain, which would give the
# same count - TODO confirm.
assert count_rows(conn, "test_frame") == num_row_coef * len(test_frame1)

Expand Down

0 comments on commit c85da12

Please sign in to comment.