Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MaskRule #30548

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
/**
* Mask algorithm meta data.
*/
@SuppressWarnings("rawtypes")
@RequiredArgsConstructor
public final class MaskAlgorithmMetaData {

Expand All @@ -42,14 +41,12 @@ public final class MaskAlgorithmMetaData {
* Find mask algorithm.
*
* @param columnIndex column index
* @return maskAlgorithm
* @return found mask algorithm
*/
public Optional<MaskAlgorithm> findMaskAlgorithmByColumnIndex(final int columnIndex) {
@SuppressWarnings("rawtypes")
public Optional<MaskAlgorithm> findMaskAlgorithm(final int columnIndex) {
Optional<ColumnProjection> columnProjection = findColumnProjection(columnIndex);
if (!columnProjection.isPresent()) {
return Optional.empty();
}
return maskRule.findMaskAlgorithm(columnProjection.get().getOriginalTable().getValue(), columnProjection.get().getName().getValue());
return columnProjection.isPresent() ? maskRule.findAlgorithm(columnProjection.get().getOriginalTable().getValue(), columnProjection.get().getName().getValue()) : Optional.empty();
}

private Optional<ColumnProjection> findColumnProjection(final int columnIndex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public boolean next() throws SQLException {
@SuppressWarnings({"rawtypes", "unchecked"})
@Override
public Object getValue(final int columnIndex, final Class<?> type) throws SQLException {
Optional<MaskAlgorithm> maskAlgorithm = metaData.findMaskAlgorithmByColumnIndex(columnIndex);
Optional<MaskAlgorithm> maskAlgorithm = metaData.findMaskAlgorithm(columnIndex);
if (!maskAlgorithm.isPresent()) {
return mergedResult.getValue(columnIndex, type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,41 @@

package org.apache.shardingsphere.mask.rule;

import com.cedarsoftware.util.CaseInsensitiveMap;
import lombok.Getter;
import org.apache.shardingsphere.infra.rule.scope.DatabaseRule;
import org.apache.shardingsphere.infra.rule.attribute.RuleAttributes;
import org.apache.shardingsphere.infra.rule.scope.DatabaseRule;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.mask.api.config.MaskRuleConfiguration;
import org.apache.shardingsphere.mask.rule.attribute.MaskTableMapperRuleAttribute;
import org.apache.shardingsphere.mask.spi.MaskAlgorithm;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Mask rule.
*/
@SuppressWarnings("rawtypes")
public final class MaskRule implements DatabaseRule {

@Getter
private final MaskRuleConfiguration configuration;

private final Map<String, MaskAlgorithm> maskAlgorithms = new LinkedHashMap<>();
private final Map<String, MaskTable> tables;

private final Map<String, MaskTable> tables = new LinkedHashMap<>();
private final Map<String, MaskAlgorithm<?, ?>> maskAlgorithms;

@Getter
private final RuleAttributes attributes;

@SuppressWarnings("unchecked")
public MaskRule(final MaskRuleConfiguration ruleConfig) {
configuration = ruleConfig;
ruleConfig.getMaskAlgorithms().forEach((key, value) -> maskAlgorithms.put(key, TypedSPILoader.getService(MaskAlgorithm.class, value.getType(), value.getProps())));
ruleConfig.getTables().forEach(each -> tables.put(each.getName().toLowerCase(), new MaskTable(each)));
tables = ruleConfig.getTables().stream().collect(Collectors.toMap(each -> each.getName().toLowerCase(), MaskTable::new, (oldValue, currentValue) -> oldValue, CaseInsensitiveMap::new));
maskAlgorithms = ruleConfig.getMaskAlgorithms().entrySet().stream()
.collect(Collectors.toMap(Entry::getKey, entry -> TypedSPILoader.getService(MaskAlgorithm.class, entry.getValue().getType(), entry.getValue().getProps())));
attributes = new RuleAttributes(new MaskTableMapperRuleAttribute(ruleConfig.getTables()));
}

Expand All @@ -57,10 +60,10 @@ public MaskRule(final MaskRuleConfiguration ruleConfig) {
*
* @param logicTable logic table name
* @param logicColumn logic column name
* @return maskAlgorithm
* @return mask algorithm
*/
public Optional<MaskAlgorithm> findMaskAlgorithm(final String logicTable, final String logicColumn) {
String lowerCaseLogicTable = logicTable.toLowerCase();
return tables.containsKey(lowerCaseLogicTable) ? tables.get(lowerCaseLogicTable).findMaskAlgorithmName(logicColumn).map(maskAlgorithms::get) : Optional.empty();
@SuppressWarnings("rawtypes")
public Optional<MaskAlgorithm> findAlgorithm(final String logicTable, final String logicColumn) {
return tables.containsKey(logicTable) ? tables.get(logicTable).findAlgorithmName(logicColumn).map(maskAlgorithms::get) : Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ public MaskTable(final MaskTableRuleConfiguration config) {
/**
* Find mask algorithm name.
*
* @param logicColumn column name
* @return mask algorithm name
* @param logicColumn logic column name
* @return found mask algorithm name
*/
public Optional<String> findMaskAlgorithmName(final String logicColumn) {
public Optional<String> findAlgorithmName(final String logicColumn) {
return columns.containsKey(logicColumn) ? Optional.of(columns.get(logicColumn).getMaskAlgorithm()) : Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ class MaskAlgorithmMetaDataTest {

@SuppressWarnings("rawtypes")
@Test
void assertFindMaskAlgorithmByColumnIndex() {
when(maskRule.findMaskAlgorithm("t_order", "order_id")).thenReturn(Optional.of(TypedSPILoader.getService(MaskAlgorithm.class, "MD5")));
void assertFindAlgorithmByColumnIndex() {
when(maskRule.findAlgorithm("t_order", "order_id")).thenReturn(Optional.of(TypedSPILoader.getService(MaskAlgorithm.class, "MD5")));
ColumnProjection columnProjection = new ColumnProjection(null, "order_id", null, mock(DatabaseType.class));
columnProjection.setOriginalColumn(new IdentifierValue("order_id"));
columnProjection.setOriginalTable(new IdentifierValue("t_order"));
when(selectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Collections.singletonList(columnProjection));
when(selectStatementContext.getTablesContext().getTableNames()).thenReturn(Collections.singleton("t_order"));
Optional<MaskAlgorithm> actual = new MaskAlgorithmMetaData(maskRule, selectStatementContext).findMaskAlgorithmByColumnIndex(1);
Optional<MaskAlgorithm> actual = new MaskAlgorithmMetaData(maskRule, selectStatementContext).findMaskAlgorithm(1);
assertTrue(actual.isPresent());
assertThat(actual.get().getType(), is("MD5"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void assertGetValue() throws SQLException {
when(mergedResult.getValue(1, Object.class)).thenReturn("VALUE");
MaskAlgorithm<String, String> maskAlgorithm = mock(MaskAlgorithm.class);
when(maskAlgorithm.mask("VALUE")).thenReturn("MASK_VALUE");
when(metaData.findMaskAlgorithmByColumnIndex(1)).thenReturn(Optional.of(maskAlgorithm));
when(metaData.findMaskAlgorithm(1)).thenReturn(Optional.of(maskAlgorithm));
assertThat(new MaskMergedResult(metaData, mergedResult).getValue(1, String.class), is("MASK_VALUE"));
}

Expand Down