diff --git a/redshift_connector/core.py b/redshift_connector/core.py
index ba1a4c5..1adf449 100644
--- a/redshift_connector/core.py
+++ b/redshift_connector/core.py
@@ -43,7 +43,7 @@
     NotSupportedError,
     OperationalError,
     ProgrammingError,
-    Warning,
+    Warning, DataError,
 )
 from redshift_connector.utils import (
     FC_BINARY,
@@ -1705,6 +1705,18 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
             #   Int32 - The OID of the parameter data type.
             val: typing.Union[bytes, bytearray] = bytearray(statement_name_bin)
             typing.cast(bytearray, val).extend(statement.encode(_client_encoding) + NULL_BYTE)
+            # h_pack below encodes the parameter count as a signed int16, so at most 32767 bind parameters can be sent
+            if len(params) > 32767:
+                raise DataError("Prepared statement exceeds bind parameter limit 32767. {} bind parameters "
+                                "were provided. Please retry with fewer bind parameters.\n"
+                                "If you are a dbt user, please see the following:\n"
+                                "  1. Please try setting a smaller batch size. See "
+                                "the dbt documentation for details: "
+                                "https://docs.getdbt.com/docs/core/connect-data-platform/redshift-setup \n"
+                                "  2. Your seed use case may be inappropriate. Try using Redshift COPY instead. See "
+                                "the seed documentation for details: "
+                                "https://docs.getdbt.com/docs/build/seeds \n".format(len(params)))
+
             typing.cast(bytearray, val).extend(h_pack(len(params)))
             for oid, fc, send_func in params:  # type: ignore
                 # Parse message doesn't seem to handle the -1 type_oid for NULL
diff --git a/redshift_connector/cursor.py b/redshift_connector/cursor.py
index 3961cbc..2121297 100644
--- a/redshift_connector/cursor.py
+++ b/redshift_connector/cursor.py
@@ -17,7 +17,7 @@
 from redshift_connector.error import (
     MISSING_MODULE_ERROR_MSG,
     InterfaceError,
-    ProgrammingError,
+    ProgrammingError, DataError,
 )
 
 if TYPE_CHECKING:
@@ -338,7 +338,9 @@ def insert_data_bulk(
                     sql_param_lists = [sql_param_list_template] * row_count
                     insert_stmt = base_stmt + ", ".join(sql_param_lists) + ";"
                     self.execute(insert_stmt, values_list)
-
+        except DataError as e:
+            raise DataError("Prepared statement exceeds bind parameter limit 32767. Please set a smaller "
+                            "batch size or retry with fewer bind parameters.") from e
         except Exception as e:
             raise InterfaceError(e)
         finally:
diff --git a/test/integration/test_cursor.py b/test/integration/test_cursor.py
index 4028c8a..fae9bad 100644
--- a/test/integration/test_cursor.py
+++ b/test/integration/test_cursor.py
@@ -4,7 +4,7 @@
 import pytest  # type: ignore
 
 import redshift_connector
-from redshift_connector import InterfaceError
+from redshift_connector import InterfaceError, DataError
 
 
 @pytest.mark.parametrize("col_name", (("apples", "apples"), ("author‎ ", "author\u200e")))
@@ -21,11 +21,11 @@
 @pytest.mark.parametrize(
     "col_names",
     (
-        ("(c1 int, c2 int, c3 int)", ("c1", "c2", "c3")),
-        (
-            "(áppleṣ int, orañges int, passion⁘fruit int, papaya int, bañanaș int)",
-            ("áppleṣ", "orañges", "passion⁘fruit", "papaya\u205f", "bañanaș"),
-        ),
+        ("(c1 int, c2 int, c3 int)", ("c1", "c2", "c3")),
+        (
+            "(áppleṣ int, orañges int, passion⁘fruit int, papaya int, bañanaș int)",
+            ("áppleṣ", "orañges", "passion⁘fruit", "papaya\u205f", "bañanaș"),
+        ),
     ),
 )
 def test_get_description_multiple_column_names(db_kwargs, col_names):
@@ -55,8 +55,8 @@ def test_insert_data_invalid_column_raises(mocked_csv, db_kwargs):
             cursor.execute("create temporary table githubissue161 (id int)")
 
             with pytest.raises(
-                InterfaceError,
-                match="Invalid column name. No results were returned when performing column name validity check.",
+                InterfaceError,
+                match="Invalid column name. No results were returned when performing column name validity check.",
             ):
                 cursor.insert_data_bulk(
                     filename="mocked_csv",
@@ -66,3 +66,43 @@
                     delimiter=",",
                     batch_size=3,
                 )
+
+
+def test_insert_data_raises_too_many_params(db_kwargs):
+    max_params = 32767
+    prepared_stmt = "INSERT INTO githubissue165 (col1) VALUES " + "(%s), " * max_params + "(%s);"
+    params = [1 for _ in range(max_params + 1)]
+
+    with redshift_connector.connect(**db_kwargs) as conn:
+        with conn.cursor() as cursor:
+            cursor.execute("create temporary table githubissue165 (col1 int)")
+
+            with pytest.raises(
+                DataError,
+                match="Prepared statement exceeds bind parameter limit 32767.",
+            ):
+                cursor.execute(prepared_stmt, params)
+
+
+@patch("builtins.open", new_callable=mock_open)
+def test_insert_data_bulk_raises_too_many_params(mocked_csv, db_kwargs):
+    max_params = 32767
+    indexes, names = (
+        [0],
+        ["col1"],
+    )
+    csv_str = "col1\n" + "1\n" * max_params + "1"  # 32768 data rows
+    mocked_csv.side_effect = [StringIO(csv_str)]
+
+    with redshift_connector.connect(**db_kwargs) as conn:
+        with conn.cursor() as cursor:
+            cursor.execute("create temporary table githubissue165 (col1 int)")
+            with pytest.raises(DataError, match="Prepared statement exceeds bind parameter limit 32767."):
+                cursor.insert_data_bulk(
+                    filename="mocked_csv",
+                    table_name="githubissue165",
+                    parameter_indices=indexes,
+                    column_names=["col1"],
+                    delimiter=",",
+                    batch_size=max_params + 1,
+                )
diff --git a/test/unit/test_cursor.py b/test/unit/test_cursor.py
index e17eb5e..2eed3b3 100644
--- a/test/unit/test_cursor.py
+++ b/test/unit/test_cursor.py
@@ -6,7 +6,7 @@
 
 import pytest  # type: ignore
 
-from redshift_connector import Connection, Cursor, InterfaceError
+from redshift_connector import Connection, Cursor, InterfaceError, DataError
 
 IS_SINGLE_DATABASE_METADATA_TOGGLE: typing.List[bool] = [True, False]
 
@@ -406,3 +406,54 @@ def test_insert_data_uses_batch_size(mocked_csv, batch_size, mocker):
             actual_insert_stmts_executed += 1
 
     assert actual_insert_stmts_executed == ceil(3 / batch_size)
+
+
+@patch("builtins.open", new_callable=mock_open)
+def test_insert_data_bulk_raises_too_many_parameters(mocked_csv, mocker):
+    # mock fetchone to return a truthy row so the table_name and column_name
+    # validation steps pass
+    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])
+
+    mock_cursor: Cursor = Cursor.__new__(Cursor)
+
+    # mock out the connection to raise DataError
+    mock_cursor._c = Mock()
+    mocker.patch.object(mock_cursor._c, "execute", side_effect=DataError("Prepared statement exceeds bind "
+                                                                         "parameter limit 32767."))
+    mock_cursor.paramstyle = "mocked"
+
+    max_params = 32767
+    indexes, names = (
+        [0],
+        ["col1"],
+    )
+
+    csv_str = "col1\n" + "1\n" * max_params + "1"  # 32768 data rows
+    mocked_csv.side_effect = [StringIO(csv_str)]
+
+    with pytest.raises(DataError, match="Prepared statement exceeds bind parameter limit 32767."):
+        mock_cursor.insert_data_bulk(
+            filename="mocked_csv",
+            table_name="githubissue165",
+            parameter_indices=indexes,
+            column_names=["col1"],
+            delimiter=",",
+            batch_size=max_params + 1,
+        )
+
+
+def test_insert_data_raises_too_many_parameters():
+    mock_cursor: Cursor = Cursor.__new__(Cursor)
+
+    # mock out the connection to raise DataError
+    mock_cursor._c = Mock()
+    mock_cursor._c.execute.side_effect = DataError("Prepared statement exceeds bind "
+                                                   "parameter limit 32767.")
+    mock_cursor.paramstyle = "mocked"
+
+    max_params = 32767
+    prepared_stmt = "INSERT INTO githubissue165 (col1) VALUES " + "(%s), " * max_params + "(%s);"
+    params = [1 for _ in range(max_params + 1)]
+
+    with pytest.raises(DataError, match="Prepared statement exceeds bind parameter limit 32767."):
+        mock_cursor.execute(prepared_stmt, params)