Skip to content

Commit

Permalink
feat: Implement option 'truncate' of argument 'if_exists' in 'DataFra…
Browse files Browse the repository at this point in the history
…me.to_sql' API.
  • Loading branch information
gmcrocetti committed Aug 3, 2024
1 parent 642d244 commit 29346ff
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 14 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ Other enhancements
- :meth:`DataFrame.pivot_table` and :func:`pivot_table` now allow the passing of keyword arguments to ``aggfunc`` through ``**kwargs`` (:issue:`57884`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
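- Added the ``"truncate"`` option to the ``if_exists`` argument of :meth:`DataFrame.to_sql`, which truncates an existing table before inserting data; ``NotImplementedError`` is raised when the database does not support ``TRUNCATE TABLE`` (:issue:`37210`)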
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)

.. ---------------------------------------------------------------------------
.. _whatsnew_300.notable_bug_fixes:

Expand Down
66 changes: 55 additions & 11 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def to_sql(
name: str,
con,
schema: str | None = None,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label: IndexLabel | None = None,
chunksize: int | None = None,
Expand All @@ -759,10 +759,12 @@ def to_sql(
schema : str, optional
Name of SQL schema in database to write to (if database flavor
supports this). If None, use default schema (default).
if_exists : {'fail', 'replace', 'append'}, default 'fail'
if_exists : {'fail', 'replace', 'append', 'truncate'}, default 'fail'
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
- truncate: If table exists, truncate it. Create if does not exist.
Raises NotImplementedError if 'TRUNCATE TABLE' is not supported
index : bool, default True
Write DataFrame index as a column.
index_label : str or sequence, optional
Expand Down Expand Up @@ -813,7 +815,7 @@ def to_sql(
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__
""" # noqa: E501
if if_exists not in ("fail", "replace", "append"):
if if_exists not in ("fail", "replace", "append", "truncate"):
raise ValueError(f"'{if_exists}' is not valid for if_exists")

if isinstance(frame, Series):
Expand Down Expand Up @@ -921,7 +923,7 @@ def __init__(
pandas_sql_engine,
frame=None,
index: bool | str | list[str] | None = True,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
prefix: str = "pandas",
index_label=None,
schema=None,
Expand Down Expand Up @@ -969,11 +971,13 @@ def create(self) -> None:
if self.exists():
if self.if_exists == "fail":
raise ValueError(f"Table '{self.name}' already exists.")
if self.if_exists == "replace":
elif self.if_exists == "replace":
self.pd_sql.drop_table(self.name, self.schema)
self._execute_create()
elif self.if_exists == "append":
pass
elif self.if_exists == "truncate":
self.pd_sql.truncate_table(self.name, self.schema)
else:
raise ValueError(f"'{self.if_exists}' is not valid for if_exists")
else:
Expand Down Expand Up @@ -1465,7 +1469,7 @@ def to_sql(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label=None,
schema=None,
Expand Down Expand Up @@ -1850,7 +1854,7 @@ def prep_table(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool | str | list[str] | None = True,
index_label=None,
schema=None,
Expand Down Expand Up @@ -1927,7 +1931,7 @@ def to_sql(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label=None,
schema: str | None = None,
Expand All @@ -1945,10 +1949,12 @@ def to_sql(
frame : DataFrame
name : string
Name of SQL table.
if_exists : {'fail', 'replace', 'append'}, default 'fail'
if_exists : {'fail', 'replace', 'append', 'truncate'}, default 'fail'
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
- truncate: If table exists, truncate it. Create if does not exist.
Raises NotImplementedError if 'TRUNCATE TABLE' is not supported
index : boolean, default True
Write DataFrame index as a column.
index_label : string or sequence, default None
Expand Down Expand Up @@ -2045,6 +2051,26 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
self.get_table(table_name, schema).drop(bind=self.con)
self.meta.clear()

def truncate_table(self, table_name: str, schema: str | None = None) -> None:
    """
    Remove all rows from an existing table with ``TRUNCATE TABLE``.

    Nothing happens when the table does not exist.

    Parameters
    ----------
    table_name : str
        Name of the table to truncate.
    schema : str, optional
        Schema holding the table. Falls back to ``self.meta.schema`` when
        not given (or empty).

    Raises
    ------
    NotImplementedError
        If executing the statement raises ``sqlalchemy.exc.OperationalError``,
        which is taken to mean the database has no ``TRUNCATE TABLE``.
    """
    from sqlalchemy.exc import OperationalError

    schema = schema or self.meta.schema

    if not self.has_table(table_name, schema):
        return

    self.meta.reflect(bind=self.con, only=[table_name], schema=schema, views=True)
    with self.run_transaction():
        table = self.get_table(table_name, schema)
        # Render the identifier through the dialect instead of interpolating
        # ``table.name``: this keeps the schema qualifier (previously dropped,
        # so a table in a non-default schema was never the one truncated) and
        # quotes names that are reserved words or contain special characters.
        # NOTE(review): assumes ``get_table`` returns a SQLAlchemy ``Table``
        # and ``self.con`` exposes ``.dialect`` -- confirm against the class.
        qualified_name = self.con.dialect.identifier_preparer.format_table(table)
        try:
            self.execute(f"TRUNCATE TABLE {qualified_name}")
        except OperationalError as exc:
            raise NotImplementedError(
                "'TRUNCATE TABLE' is not supported by this database."
            ) from exc

    # Drop the reflected metadata so later lookups re-read the database.
    self.meta.clear()

def _create_sql_schema(
self,
frame: DataFrame,
Expand Down Expand Up @@ -2301,7 +2327,7 @@ def to_sql(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool = True,
index_label=None,
schema: str | None = None,
Expand All @@ -2323,6 +2349,8 @@ def to_sql(
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
- truncate: If table exists, truncate it. Create if does not exist.
Raises NotImplementedError if 'TRUNCATE TABLE' is not supported
index : boolean, default True
Write DataFrame index as a column.
index_label : string or sequence, default None
Expand All @@ -2340,6 +2368,8 @@ def to_sql(
engine : {'auto', 'sqlalchemy'}, default 'auto'
Raises NotImplementedError if not set to 'auto'
"""
from adbc_driver_manager import ProgrammingError

if index_label:
raise NotImplementedError(
"'index_label' is not implemented for ADBC drivers"
Expand Down Expand Up @@ -2373,6 +2403,15 @@ def to_sql(
cur.execute(f"DROP TABLE {table_name}")
elif if_exists == "append":
mode = "append"
elif if_exists == "truncate":
mode = "append"
with self.con.cursor() as cur:
try:
cur.execute(f"TRUNCATE TABLE {table_name}")
except ProgrammingError as exc:
raise NotImplementedError(
"'TRUNCATE TABLE' is not supported by this database."
) from exc

import pyarrow as pa

Expand Down Expand Up @@ -2774,10 +2813,12 @@ def to_sql(
frame: DataFrame
name: string
Name of SQL table.
if_exists: {'fail', 'replace', 'append'}, default 'fail'
if_exists: {'fail', 'replace', 'append', 'truncate'}, default 'fail'
fail: If table exists, do nothing.
replace: If table exists, drop it, recreate it, and insert data.
append: If table exists, insert data. Create if it does not exist.
truncate: If table exists, truncate it. Create if does not exist.
Raises NotImplementedError if 'TRUNCATE TABLE' is not supported
index : bool, default True
Write DataFrame index as a column
index_label : string or sequence, default None
Expand Down Expand Up @@ -2853,6 +2894,9 @@ def drop_table(self, name: str, schema: str | None = None) -> None:
drop_sql = f"DROP TABLE {_get_valid_sqlite_name(name)}"
self.execute(drop_sql)

def truncate_table(self, name: str, schema: str | None = None) -> None:
    """
    Reject the request: truncation is not implemented for this backend.

    Parameters
    ----------
    name : str
        Name of the table (unused).
    schema : str, optional
        Schema of the table (unused).

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "'TRUNCATE TABLE' is not supported by this database."
    raise NotImplementedError(message)

def _create_sql_schema(
self,
frame,
Expand Down
60 changes: 58 additions & 2 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,12 +1067,27 @@ def test_to_sql(conn, method, test_frame1, request):


@pytest.mark.parametrize("conn", all_connectable)
@pytest.mark.parametrize(
    "mode, num_row_coef", [("replace", 1), ("append", 2), ("truncate", 1)]
)
def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request):
    """
    Write the same frame twice and check the row count for each ``if_exists``.

    ``num_row_coef`` is the expected multiple of ``len(test_frame1)`` left in
    the table after the second write.
    """
    # Connections for which ``if_exists="truncate"`` is expected to raise.
    connections_without_truncate = sqlite_connectable + [
        "sqlite_buildin",
        "sqlite_adbc_conn",
    ]
    if conn in connections_without_truncate and mode == "truncate":
        context = pytest.raises(
            NotImplementedError,
            match="'TRUNCATE TABLE' is not supported by this database.",
        )
    else:
        context = contextlib.nullcontext()
    conn = request.getfixturevalue(conn)

    with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
        pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail")
        # The second write happens only once, and only under ``context``, so
        # the expected NotImplementedError cannot escape the test.
        with context:
            pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode)
        assert pandasSQL.has_table("test_frame")
    # When truncation is rejected the rows from the first write remain, which
    # is why the "truncate" coefficient of 1 also holds for those connections.
    assert count_rows(conn, "test_frame") == num_row_coef * len(test_frame1)

Expand Down Expand Up @@ -2697,6 +2712,47 @@ def test_drop_table(conn, request):
assert not insp.has_table("temp_frame")


@pytest.mark.parametrize("conn", mysql_connectable + postgresql_connectable)
def test_truncate_table_success(conn, test_frame1, request):
    """Truncating a freshly written table leaves it with zero rows."""
    conn = request.getfixturevalue(conn)
    target = "temp_frame"

    with sql.SQLDatabase(conn) as database:
        # Seed the table in its own transaction.
        with database.run_transaction():
            written = database.to_sql(test_frame1, target, if_exists="replace")
            assert written == 4

        # Truncate in a separate transaction, then inspect the result.
        with database.run_transaction():
            database.truncate_table(target)
        assert count_rows(conn, target) == 0


@pytest.mark.parametrize("conn", sqlite_connectable)
def test_truncate_table_not_supported(conn, test_frame1, request):
    """A rejected truncate raises NotImplementedError and keeps the rows."""
    conn = request.getfixturevalue(conn)
    target = "temp_frame"
    expect_unsupported = pytest.raises(
        NotImplementedError,
        match="'TRUNCATE TABLE' is not supported by this database.",
    )

    with sql.SQLDatabase(conn) as database:
        # Seed the table in its own transaction.
        with database.run_transaction():
            written = database.to_sql(test_frame1, target, if_exists="replace")
            assert written == 4

        # The truncate attempt must fail inside its transaction ...
        with database.run_transaction(), expect_unsupported:
            database.truncate_table(target)
        # ... and the previously written rows must still be there.
        assert count_rows(conn, target) == len(test_frame1)


def test_truncate_table_sqlite(sqlite_buildin):
    """``SQLiteDatabase.truncate_table`` raises NotImplementedError."""
    expect_unsupported = pytest.raises(
        NotImplementedError,
        match="'TRUNCATE TABLE' is not supported by this database.",
    )
    with sql.SQLiteDatabase(sqlite_buildin) as database, expect_unsupported:
        database.truncate_table("table")


@pytest.mark.parametrize("conn", all_connectable)
def test_roundtrip(conn, request, test_frame1):
if conn == "sqlite_str":
Expand Down

0 comments on commit 29346ff

Please sign in to comment.