diff --git a/config.yaml b/config.yaml index 8e05fec54c..943b49bfad 100644 --- a/config.yaml +++ b/config.yaml @@ -215,6 +215,10 @@ options: default: false type: boolean description: Enable uuid_ossp extension. + plugin_spi_enable: + default: false + type: boolean + description: Enable spi extension. profile: description: | Profile representing the scope of deployment, and used to tune resource allocation. diff --git a/src/charm.py b/src/charm.py index 43fd2bdde4..3713e45b93 100755 --- a/src/charm.py +++ b/src/charm.py @@ -482,6 +482,7 @@ def enable_disable_extensions(self, database: str = None) -> None: Args: database: optional database where to enable/disable the extension. """ + spi_module = ["refint", "autoinc", "insert_username", "moddatetime"] original_status = self.unit.status extensions = {} # collect extensions @@ -491,6 +492,10 @@ def enable_disable_extensions(self, database: str = None) -> None: # Enable or disable the plugin/extension. extension = "_".join(plugin.split("_")[1:-1]) + if extension == "spi": + for ext in spi_module: + extensions[ext] = enable + continue extension = plugins_exception.get(extension, extension) extensions[extension] = enable self.unit.status = WaitingStatus("Updating extensions") diff --git a/src/config.py b/src/config.py index b0339b7a67..3121c23dc2 100644 --- a/src/config.py +++ b/src/config.py @@ -62,6 +62,7 @@ class CharmConfig(BaseConfigModel): plugin_tsm_system_rows_enable: bool plugin_tsm_system_time_enable: bool plugin_uuid_ossp_enable: bool + plugin_spi_enable: bool request_date_style: Optional[str] request_standard_conforming_strings: Optional[bool] request_time_zone: Optional[str] diff --git a/tests/integration/test_plugins.py b/tests/integration/test_plugins.py index 9293fb4cd5..d538b853d5 100644 --- a/tests/integration/test_plugins.py +++ b/tests/integration/test_plugins.py @@ -53,6 +53,10 @@ TSM_SYSTEM_ROWS_EXTENSION_STATEMENT = "CREATE TABLE tsm_system_rows_test (i int);SELECT * FROM tsm_system_rows_test TABLESAMPLE SYSTEM_ROWS(100);" TSM_SYSTEM_TIME_EXTENSION_STATEMENT = "CREATE TABLE tsm_system_time_test (i int);SELECT * FROM tsm_system_time_test TABLESAMPLE SYSTEM_TIME(1000);" UUID_OSSP_EXTENSION_STATEMENT = "SELECT uuid_nil();" +REFINT_EXTENSION_STATEMENT = "CREATE TABLE A (ID int4 not null); CREATE UNIQUE INDEX AI ON A (ID);CREATE TABLE B (REFB int4);CREATE INDEX BI ON B (REFB);CREATE TRIGGER BT BEFORE INSERT OR UPDATE ON B FOR EACH ROW EXECUTE PROCEDURE check_primary_key ('REFB', 'A', 'ID');" +AUTOINC_EXTENSION_STATEMENT = "CREATE TABLE ids (id int4, idesc text);CREATE TRIGGER ids_nextid BEFORE INSERT OR UPDATE ON ids FOR EACH ROW EXECUTE PROCEDURE autoinc (id, next_id);" +INSERT_USERNAME_EXTENSION_STATEMENT = "CREATE TABLE username_test (name text, username text not null);CREATE TRIGGER insert_usernames BEFORE INSERT OR UPDATE ON username_test FOR EACH ROW EXECUTE PROCEDURE insert_username (username);" +MODDATETIME_EXTENSION_STATEMENT = "CREATE TABLE mdt (moddate timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL);CREATE TRIGGER mdt_moddatetime BEFORE UPDATE ON mdt FOR EACH ROW EXECUTE PROCEDURE moddatetime (moddate);" @pytest.mark.abort_on_fail @@ -92,6 +96,12 @@ async def test_plugins(ops_test: OpsTest) -> None: "plugin_tsm_system_rows_enable": TSM_SYSTEM_ROWS_EXTENSION_STATEMENT, "plugin_tsm_system_time_enable": TSM_SYSTEM_TIME_EXTENSION_STATEMENT, "plugin_uuid_ossp_enable": UUID_OSSP_EXTENSION_STATEMENT, + "plugin_spi_enable": [ + REFINT_EXTENSION_STATEMENT, + AUTOINC_EXTENSION_STATEMENT, + INSERT_USERNAME_EXTENSION_STATEMENT, + MODDATETIME_EXTENSION_STATEMENT, + ], } def enable_disable_config(enabled: False): @@ -113,8 +123,13 @@ def enable_disable_config(enabled: False): with db_connect(host=address, password=password) as connection: connection.autocommit = True for query in sql_tests.values(): - with pytest.raises(psycopg2.Error): - connection.cursor().execute(query) + if isinstance(query, list): + for test in query: + with pytest.raises(psycopg2.Error): + connection.cursor().execute(test) + else: + with pytest.raises(psycopg2.Error): + connection.cursor().execute(query) connection.close() # Enable the plugins. @@ -129,5 +144,9 @@ def enable_disable_config(enabled: False): with db_connect(host=address, password=password) as connection: connection.autocommit = True for query in sql_tests.values(): - connection.cursor().execute(query) + if isinstance(query, list): + for test in query: + connection.cursor().execute(test) + else: + connection.cursor().execute(query) connection.close()