diff --git a/pom.xml b/pom.xml index a3d72148..c48f3806 100644 --- a/pom.xml +++ b/pom.xml @@ -66,6 +66,7 @@ 1.0.10 5.11.3 1.11.3 + 5.14.2 target @@ -124,6 +125,18 @@ ${jsonassert.version} test + + org.mockito + mockito-core + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + diff --git a/src/main/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricks.java b/src/main/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricks.java index e886c477..6b90f160 100644 --- a/src/main/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricks.java +++ b/src/main/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricks.java @@ -4,26 +4,39 @@ import liquibase.exception.DatabaseException; import liquibase.ext.databricks.database.DatabricksDatabase; import liquibase.snapshot.CachedRow; +import liquibase.snapshot.DatabaseSnapshot; import liquibase.snapshot.SnapshotGenerator; import liquibase.snapshot.jvm.ColumnSnapshotGenerator; +import liquibase.statement.DatabaseFunction; import liquibase.structure.DatabaseObject; import liquibase.structure.core.Column; import liquibase.structure.core.DataType; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + public class ColumnSnapshotGeneratorDatabricks extends ColumnSnapshotGenerator { + private static final String ALL_DATA_TYPES = " BIGINT | BINARY | BOOLEAN | DATE | DECIMAL| DECIMAL\\(| DOUBLE | FLOAT | INT | INTERVAL | VOID | SMALLINT | STRING | VARCHAR\\(\\d+\\) | TIMESTAMP | TIMESTAMP_NTZ | TINYINT | ARRAY<| MAP<| STRUCT<| VARIANT| OBJECT<"; + private static final String DEFAULT_CLAUSE_TERMINATORS = "(?i)(\\s+COMMENT\\s+'| PRIMARY\\s+KEY | FOREIGN\\s+KEY | MASK\\s+\\w+|$|,(\\s+\\w+\\s+" + ALL_DATA_TYPES + "|\\)$)"; + private static final String GENERATED_BY_DEFAULT_REGEX = "(?i)\\s+GENERATED\\s+BY\\s+DEFAULT\\s+AS\\s+IDENTITY"; + private static final String GENERIC_DEFAULT_VALUE_REGEX = "DEFAULT\\s+(.*?)(" + DEFAULT_CLAUSE_TERMINATORS + "?))"; + private static final String SANITIZE_TABLE_SPECIFICATION_REGEX = "(\\(.*?\\))\\s*(?i)(USING|OPTIONS|PARTITIONED BY|CLUSTER BY|LOCATION|TBLPROPERTIES|WITH|$|;$)"; + private static final Pattern DEFAULT_VALUE_PATTERN = Pattern.compile(GENERIC_DEFAULT_VALUE_REGEX); + private static final Pattern SANITIZE_TABLE_SPECIFICATION_PATTERN = Pattern.compile(SANITIZE_TABLE_SPECIFICATION_REGEX); + private static final Pattern FUNCTION_PATTERN = Pattern.compile("^(\\w+)\\(.*\\)"); + @Override public int getPriority(Class objectType, Database database) { if (database instanceof DatabricksDatabase) { - return super.getPriority(objectType, database) + PRIORITY_DATABASE; - } else { - return PRIORITY_NONE; + return PRIORITY_DATABASE; } + return PRIORITY_NONE; } @Override public Class[] replaces() { - return new Class[] { ColumnSnapshotGenerator.class }; + return new Class[]{ColumnSnapshotGenerator.class}; } /** @@ -43,4 +56,68 @@ protected DataType readDataType(CachedRow columnMetadataResultSet, Column column } return super.readDataType(columnMetadataResultSet, column, database); } + + @Override + protected DatabaseObject snapshotObject(DatabaseObject example, DatabaseSnapshot snapshot) throws DatabaseException { + //This should work after fix on Databricks side + if (example instanceof Column) { + Column column = (Column) super.snapshotObject(example, snapshot); + //These two are used too often, avoiding them? otherwise there would be too much DB calls + String showCreateRelatedTableQuery = String.format("SHOW CREATE TABLE %s.%s.%s;", + column.getRelation().getSchema().getCatalog(), + column.getRelation().getSchema().getName(), + column.getRelation().getName()); + if (snapshot.getScratchData(showCreateRelatedTableQuery) != null) { + String showCreateTableStatement = (String) snapshot.getScratchData(showCreateRelatedTableQuery); + String defaultValue = extractDefaultValue(showCreateTableStatement, column.getName()); + if (defaultValue != null) { + Matcher functionMatcher = FUNCTION_PATTERN.matcher(defaultValue); + if (functionMatcher.find()) { + DatabaseFunction function = new DatabaseFunction(defaultValue); + column.setDefaultValue(function); + column.setComputed(true); + } else { + column.setDefaultValue(defaultValue); + } + } + } + return column; + } else { + return example; + } + } + + private String extractDefaultValue(String createTableStatement, String columnName) { + String defaultValue = null; + String sanitizedCreateTableStatement = sanitizeStatement(createTableStatement); + Pattern columnWithPotentialDefaultPattern = Pattern.compile("[\\(|,]\\s*(" + columnName + "\\s*\\b\\w*\\b.*?)\\s*(?i)(" + ALL_DATA_TYPES + "|( CONSTRAINT |$))"); + Matcher columnWithPotentialDefaultMatcher = columnWithPotentialDefaultPattern.matcher(sanitizedCreateTableStatement); + + String columnWithPotentialDefault = ""; + if (columnWithPotentialDefaultMatcher.find()) { + columnWithPotentialDefault = columnWithPotentialDefaultMatcher.group(1); + Matcher stringColumnTypeMatcher = Pattern.compile(columnName + "\\s+(?i)(VARCHAR\\(\\d+\\)|STRING )") + .matcher(sanitizedCreateTableStatement); + Matcher defaultStringValueMatcher = Pattern.compile(columnName + ".+?(?i)DEFAULT\\s+(\\'|\\\")(.*?)\\1") + .matcher(sanitizedCreateTableStatement); + Matcher defaultValueMatcher = DEFAULT_VALUE_PATTERN.matcher(columnWithPotentialDefault); + if (defaultValueMatcher.find()) { + defaultValue = defaultValueMatcher.group(1); + if (stringColumnTypeMatcher.find() && defaultStringValueMatcher.find()) { + defaultValue = defaultStringValueMatcher.group(2); + } + } + } + return defaultValue; + } + + private String sanitizeStatement(String createTableStatement) { + createTableStatement = createTableStatement.replace("\n", ""); + String sanitizedCreateTableStatement = createTableStatement.replaceAll(GENERATED_BY_DEFAULT_REGEX, " "); + Matcher tableSpecificationMatcher = SANITIZE_TABLE_SPECIFICATION_PATTERN.matcher(sanitizedCreateTableStatement); + if (tableSpecificationMatcher.find()) { + sanitizedCreateTableStatement = tableSpecificationMatcher.group(1); + } + return sanitizedCreateTableStatement; + } } diff --git a/src/main/java/liquibase/ext/databricks/snapshot/jvm/TableSnapshotGeneratorDatabricks.java b/src/main/java/liquibase/ext/databricks/snapshot/jvm/TableSnapshotGeneratorDatabricks.java index c4041108..74ef6f87 100644 --- a/src/main/java/liquibase/ext/databricks/snapshot/jvm/TableSnapshotGeneratorDatabricks.java +++ b/src/main/java/liquibase/ext/databricks/snapshot/jvm/TableSnapshotGeneratorDatabricks.java @@ -3,6 +3,7 @@ import liquibase.Scope; import liquibase.database.Database; import liquibase.exception.DatabaseException; +import liquibase.executor.Executor; import liquibase.executor.ExecutorService; import liquibase.ext.databricks.database.DatabricksDatabase; import liquibase.snapshot.DatabaseSnapshot; @@ -47,8 +48,18 @@ protected DatabaseObject snapshotObject(DatabaseObject example, DatabaseSnapshot if (table != null) { String query = String.format("DESCRIBE TABLE EXTENDED %s.%s.%s;", database.getDefaultCatalogName(), database.getDefaultSchemaName(), example.getName()); - List> tablePropertiesResponse = Scope.getCurrentScope().getSingleton(ExecutorService.class) - .getExecutor("jdbc", database).queryForList(new RawParameterizedSqlStatement(query)); + Executor jdbcExecutor = Scope.getCurrentScope().getSingleton(ExecutorService.class).getExecutor("jdbc", database); + List> tablePropertiesResponse = jdbcExecutor.queryForList(new RawParameterizedSqlStatement(query)); + //Skipping changelog tables default values processing + List changelogTableNames = Arrays.asList(database.getDatabaseChangeLogLockTableName(), database.getDatabaseChangeLogTableName()); + if(!changelogTableNames.contains(table.getName())) { + String showCreateTableQuery = String.format("SHOW CREATE TABLE %s.%s.%s;", table.getSchema().getCatalog(), + table.getSchema().getName(), table.getName()); + if(snapshot.getScratchData(showCreateTableQuery) == null) { + String createTableStatement = jdbcExecutor.queryForObject(new RawParameterizedSqlStatement(showCreateTableQuery), String.class); + snapshot.setScratchData(showCreateTableQuery, createTableStatement); + } + } StringBuilder tableFormat = new StringBuilder(); // DESCRIBE TABLE EXTENDED returns both columns and additional information. // We need to make sure "Location" is not column in the table, but table location in s3 diff --git a/src/test/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricksTest.java b/src/test/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricksTest.java new file mode 100644 index 00000000..f019af6d --- /dev/null +++ b/src/test/java/liquibase/ext/databricks/snapshot/jvm/ColumnSnapshotGeneratorDatabricksTest.java @@ -0,0 +1,99 @@ +package liquibase.ext.databricks.snapshot.jvm; + +import liquibase.database.jvm.JdbcConnection; +import liquibase.exception.DatabaseException; +import liquibase.ext.databricks.database.DatabricksDatabase; +import liquibase.snapshot.JdbcDatabaseSnapshot; +import liquibase.statement.DatabaseFunction; +import liquibase.structure.DatabaseObject; +import liquibase.structure.core.Column; +import liquibase.structure.core.Table; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class ColumnSnapshotGeneratorDatabricksTest { + + @Mock + private JdbcDatabaseSnapshot snapshot; + @InjectMocks + private ColumnSnapshotGeneratorDatabricks snapshotGenerator; + @Captor + private ArgumentCaptor queryCaptor; + private DatabaseObject testedColumn; + + private static final Map COLUMN_WITH_DEFAULT_NAMES = new HashMap() {{ + put("intcolumn", "-101"); + put("eventDescription", "default string, regular ? almost();!$#@^%[] String bigint"); + put("eventShortDescription", "short desc"); + }}; + private static final Map COLUMN_WITH_DEFAULT_COMPUTED_NAMES = new HashMap() {{ + put("eventTime", new DatabaseFunction("current_timestamp()")); + put("year", new DatabaseFunction("YEAR(CURRENT_TIMESTAMP())")); + put("eventDate", new DatabaseFunction("CAST(CURRENT_TIMESTAMP() AS DATE)")); + }}; + private static final String TEST_CATALOG_NAME = "main"; + private static final String TEST_SCHEMA_NAME = "liquibase_harness_test_ds"; + private static final String TEST_TABLE_NAME = "tablewithdefaultvalues"; + private static final String EXPECTED_SHOW_CREATE_QUERY = "SHOW CREATE TABLE main.liquibase_harness_test_ds.tablewithdefaultvalues;"; + private static final String SHOW_CREATE_TABLE_RESPONSE = "CREATE TABLE main.liquibase_harness_test_ds.tablewithdefaultvalues(" + + "longcolumn BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY 1), " + + "intcolumn INT DEFAULT -101, " + + "eventTime TIMESTAMP NOT NULL DEFAULT current_timestamp(), " + + "year INT DEFAULT YEAR(CURRENT_TIMESTAMP()), " + + "eventDate DATE DEFAULT CAST(CURRENT_TIMESTAMP() AS DATE) COMMENT 'a comment, " + + "eventDescription STRING NOT NULL DEFAULT 'default string, regular ? almost();!$#@^%[] String bigint', " + + "eventShortDescription STRING DEFAULT \"short desc\") USING delta " + + " TBLPROPERTIES ('delta.columnMapping.mode' = 'name', 'delta.feature.allowColumnDefaults' = 'supported') "; + + + @BeforeEach + public void setUp() throws DatabaseException, SQLException { + snapshotGenerator = new ColumnSnapshotGeneratorDatabricks(); + testedColumn = new Column(); + testedColumn.setAttribute("relation", new Table(TEST_CATALOG_NAME, TEST_SCHEMA_NAME, TEST_TABLE_NAME)); + when(snapshot.getScratchData(queryCaptor.capture())).thenReturn(SHOW_CREATE_TABLE_RESPONSE); + } + + @Test + void snapshotObjectTest() throws DatabaseException, SQLException { + for(Map.Entry columnWithDefault : COLUMN_WITH_DEFAULT_NAMES.entrySet()) { + testedColumn.setName(columnWithDefault.getKey()); + testedColumn.setAttribute("liquibase-complete", true); + DatabaseObject databaseObject = snapshotGenerator.snapshotObject(testedColumn, snapshot); + assertTrue(databaseObject instanceof Column); + assertNull(((Column) databaseObject).getComputed()); + assertNotNull(((Column) databaseObject).getDefaultValue()); + assertEquals(columnWithDefault.getValue(), ((Column) databaseObject).getDefaultValue()); + } + for(Map.Entry columnWithDefaultComputed: COLUMN_WITH_DEFAULT_COMPUTED_NAMES.entrySet()) { + testedColumn.setName(columnWithDefaultComputed.getKey()); + testedColumn.setAttribute("liquibase-complete", true); + DatabaseObject databaseObject = snapshotGenerator.snapshotObject(testedColumn, snapshot); + assertTrue(databaseObject instanceof Column); + assertTrue(((Column) databaseObject).getComputed()); + assertNotNull(((Column) databaseObject).getDefaultValue()); + assertEquals(columnWithDefaultComputed.getValue(), ((Column) databaseObject).getDefaultValue()); + } + assertEquals(EXPECTED_SHOW_CREATE_QUERY, queryCaptor.getValue()); + } +} diff --git a/src/test/resources/liquibase/harness/change/changelogs/databricks/createTableWithDefaultValues.xml b/src/test/resources/liquibase/harness/change/changelogs/databricks/createTableWithDefaultValues.xml index 51fbe93b..b776f0c6 100644 --- a/src/test/resources/liquibase/harness/change/changelogs/databricks/createTableWithDefaultValues.xml +++ b/src/test/resources/liquibase/harness/change/changelogs/databricks/createTableWithDefaultValues.xml @@ -11,13 +11,15 @@ - - - - + - + + + + + + diff --git a/src/test/resources/liquibase/harness/change/expectedSql/databricks/createTableWithDefaultValues.sql b/src/test/resources/liquibase/harness/change/expectedSql/databricks/createTableWithDefaultValues.sql index 950940a0..0f88440a 100644 --- a/src/test/resources/liquibase/harness/change/expectedSql/databricks/createTableWithDefaultValues.sql +++ b/src/test/resources/liquibase/harness/change/expectedSql/databricks/createTableWithDefaultValues.sql @@ -1 +1 @@ -CREATE TABLE main.liquibase_harness_test_ds.tableWithDefaultValues (longcolumn LONG GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY 1), eventTime TIMESTAMP, year INT GENERATED ALWAYS AS (YEAR(eventTime)), eventDate date GENERATED ALWAYS AS (CAST(eventTime AS DATE)), eventDescription STRING NOT NULL, eventShortDescription STRING GENERATED ALWAYS AS (SUBSTRING(eventDescription, 0, 1))) USING delta TBLPROPERTIES('delta.feature.allowColumnDefaults' = 'supported', 'delta.columnMapping.mode' = 'name', 'delta.enableDeletionVectors' = true) \ No newline at end of file +CREATE TABLE main.liquibase_harness_test_ds.tableWithDefaultValues (longcolumn LONG GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY 1), eventTime TIMESTAMP DEFAULT current_timestamp() NOT NULL, year INT DEFAULT YEAR(CURRENT_TIMESTAMP()), eventDate date DEFAULT CAST(CURRENT_TIMESTAMP() AS DATE), eventDescription STRING DEFAULT 'default string, regular ? almost();!$#@^%[] String bigint' NOT NULL, eventShortDescription STRING DEFAULT 'short desc') USING delta TBLPROPERTIES('delta.feature.allowColumnDefaults' = 'supported', 'delta.columnMapping.mode' = 'name', 'delta.enableDeletionVectors' = true) \ No newline at end of file