diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/metastore/RecordingHiveMetastore.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/metastore/RecordingHiveMetastore.java index c3fc5e1eea91..17289238ef9b 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/metastore/RecordingHiveMetastore.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/metastore/RecordingHiveMetastore.java @@ -55,13 +55,12 @@ public class RecordingHiveMetastore implements ExtendedHiveMetastore { - private static final String LIST_ROLES_KEY = "LIST_ROLES_KEY"; - private final ExtendedHiveMetastore delegate; private final String recordingPath; private final boolean replay; private volatile Optional> allDatabases = Optional.empty(); + private volatile Optional> allRoles = Optional.empty(); private final Cache> databaseCache; private final Cache> tableCache; @@ -74,9 +73,8 @@ public class RecordingHiveMetastore private final Cache>> partitionNamesCache; private final Cache>> partitionNamesByPartsCache; private final Cache, Map>> partitionsByNamesCache; - private final Cache> listTablePrivilegesCache; - private final Cache> listRolesCache; - private final Cache> listRoleGrantsCache; + private final Cache> tablePrivilegesCache; + private final Cache> roleGrantsCache; @Inject public RecordingHiveMetastore(@ForRecordingHiveMetastore ExtendedHiveMetastore delegate, HiveClientConfig hiveClientConfig) @@ -98,9 +96,8 @@ public RecordingHiveMetastore(@ForRecordingHiveMetastore ExtendedHiveMetastore d partitionNamesCache = createCache(hiveClientConfig); partitionNamesByPartsCache = createCache(hiveClientConfig); partitionsByNamesCache = createCache(hiveClientConfig); - listTablePrivilegesCache = createCache(hiveClientConfig); - listRolesCache = createCache(hiveClientConfig); - listRoleGrantsCache = createCache(hiveClientConfig); + tablePrivilegesCache = createCache(hiveClientConfig); + roleGrantsCache = createCache(hiveClientConfig); if (replay) { loadRecording(); @@ -114,6 +111,7 @@ void loadRecording() Recording recording = new ObjectMapperProvider().get().readValue(new File(recordingPath), Recording.class); allDatabases = recording.getAllDatabases(); + allRoles = recording.getAllRoles(); databaseCache.putAll(toMap(recording.getDatabases())); tableCache.putAll(toMap(recording.getTables())); supportedColumnStatisticsCache.putAll(toMap(recording.getSupportedColumnStatistics())); @@ -125,9 +123,8 @@ void loadRecording() partitionNamesCache.putAll(toMap(recording.getPartitionNames())); partitionNamesByPartsCache.putAll(toMap(recording.getPartitionNamesByParts())); partitionsByNamesCache.putAll(toMap(recording.getPartitionsByNames())); - listTablePrivilegesCache.putAll(toMap(recording.getListTablePrivileges())); - listRolesCache.putAll(toMap(recording.getListRoles())); - listRoleGrantsCache.putAll(toMap(recording.getListRoleGrants())); + tablePrivilegesCache.putAll(toMap(recording.getTablePrivileges())); + roleGrantsCache.putAll(toMap(recording.getRoleGrants())); } private static Cache createCache(HiveClientConfig hiveClientConfig) @@ -152,6 +149,7 @@ public void writeRecording() Recording recording = new Recording( allDatabases, + allRoles, toPairs(databaseCache), toPairs(tableCache), toPairs(supportedColumnStatisticsCache), @@ -163,9 +161,8 @@ public void writeRecording() toPairs(partitionNamesCache), toPairs(partitionNamesByPartsCache), toPairs(partitionsByNamesCache), - toPairs(listTablePrivilegesCache), - toPairs(listRolesCache), - toPairs(listRoleGrantsCache)); + toPairs(tablePrivilegesCache), + toPairs(roleGrantsCache)); new ObjectMapperProvider().get() .writerWithDefaultPrettyPrinter() .writeValue(new File(recordingPath), recording); @@ -389,7 +386,7 @@ public void alterPartition(String databaseName, String tableName, PartitionWithS public Set listTablePrivileges(String databaseName, String tableName, PrestoPrincipal principal) { return loadValue( - listTablePrivilegesCache, + tablePrivilegesCache, new UserTableKey(principal, databaseName, tableName), () -> delegate.listTablePrivileges(databaseName, tableName, principal)); } @@ -432,10 +429,13 @@ public void dropRole(String role) @Override public Set listRoles() { - return loadValue( - listRolesCache, - LIST_ROLES_KEY, - () -> delegate.listRoles()); + if (replay) { + return allRoles.orElseThrow(() -> new PrestoException(NOT_FOUND, "Missing entry for roles")); + } + + Set result = delegate.listRoles(); + allRoles = Optional.of(result); + return result; } @Override @@ -456,7 +456,7 @@ public void revokeRoles(Set roles, Set grantees, boolea public Set listRoleGrants(PrestoPrincipal principal) { return loadValue( - listRoleGrantsCache, + roleGrantsCache, principal, () -> delegate.listRoleGrants(principal)); } @@ -484,6 +484,7 @@ private void verifyRecordingMode() public static class Recording { private final Optional> allDatabases; + private final Optional> allRoles; private final List>> databases; private final List>> tables; private final List>> supportedColumnStatistics; @@ -495,13 +496,13 @@ public static class Recording private final List>>> partitionNames; private final List>>> partitionNamesByParts; private final List, Map>>> partitionsByNames; - private final List>> listTablePrivileges; - private final List>> listRoles; - private final List>> listRoleGrants; + private final List>> tablePrivileges; + private final List>> roleGrants; @JsonCreator public Recording( @JsonProperty("allDatabases") Optional> allDatabases, + @JsonProperty("allRoles") Optional> allRoles, @JsonProperty("databases") List>> databases, @JsonProperty("tables") List>> tables, @JsonProperty("supportedColumnStatistics") List>> supportedColumnStatistics, @@ -513,11 +514,11 @@ public Recording( @JsonProperty("partitionNames") List>>> partitionNames, @JsonProperty("partitionNamesByParts") List>>> partitionNamesByParts, @JsonProperty("partitionsByNames") List, Map>>> partitionsByNames, - @JsonProperty("listTablePrivileges") List>> listTablePrivileges, - @JsonProperty("listRoles") List>> listRoles, - @JsonProperty("listRoleGrants") List>> listRoleGrants) + @JsonProperty("tablePrivileges") List>> tablePrivileges, + @JsonProperty("roleGrants") List>> roleGrants) { this.allDatabases = allDatabases; + this.allRoles = allRoles; this.databases = databases; this.tables = tables; this.supportedColumnStatistics = supportedColumnStatistics; @@ -529,9 +530,8 @@ public Recording( this.partitionNames = partitionNames; this.partitionNamesByParts = partitionNamesByParts; this.partitionsByNames = partitionsByNames; - this.listTablePrivileges = listTablePrivileges; - this.listRoles = listRoles; - this.listRoleGrants = listRoleGrants; + this.tablePrivileges = tablePrivileges; + this.roleGrants = roleGrants; } @JsonProperty @@ -540,6 +540,12 @@ public Optional> getAllDatabases() return allDatabases; } + @JsonProperty + public Optional> getAllRoles() + { + return allRoles; + } + @JsonProperty public List>> getDatabases() { @@ -607,21 +613,15 @@ public List, Map>>> getP } @JsonProperty - public List>> getListTablePrivileges() - { - return listTablePrivileges; - } - - @JsonProperty - public List>> getListRoles() + public List>> getTablePrivileges() { - return listRoles; + return tablePrivileges; } @JsonProperty - public List>> getListRoleGrants() + public List>> getRoleGrants() { - return listRoleGrants; + return roleGrants; } }