Skip to content

Commit

Permalink
[FLINK-35208][python] Respect pipeline.cached-files during handling P…
Browse files Browse the repository at this point in the history
…ython dependencies
  • Loading branch information
dianfu committed Apr 25, 2024
1 parent 0953199 commit 127a521
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,6 @@ def add_from_file(i):
python_dependency_config = dict(
get_gateway().jvm.org.apache.flink.python.util.PythonDependencyUtils.
configurePythonDependencies(
env._j_stream_execution_environment.getCachedFiles(),
env._j_stream_execution_environment.getConfiguration()).toMap())

# Make sure that user specified files and archives are correctly added.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.PipelineOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonConfig;
Expand Down Expand Up @@ -79,8 +79,16 @@ public static Configuration getEnvironmentConfig(StreamExecutionEnvironment env)
}

public static void configPythonOperator(StreamExecutionEnvironment env) throws Exception {
final Configuration config =
extractPythonConfiguration(env.getCachedFiles(), env.getConfiguration());
final Configuration config = extractPythonConfiguration(env.getConfiguration());

env.getConfiguration()
.getOptional(PipelineOptions.CACHED_FILES)
.ifPresent(
f -> {
env.getCachedFiles().clear();
env.getCachedFiles()
.addAll(DistributedCache.parseCachedFilesFromString(f));
});

for (Transformation<?> transformation : env.getTransformations()) {
alignTransformation(transformation);
Expand All @@ -102,11 +110,9 @@ public static void configPythonOperator(StreamExecutionEnvironment env) throws E
}

/** Extract the configurations which is used in the Python operators. */
public static Configuration extractPythonConfiguration(
List<Tuple2<String, DistributedCache.DistributedCacheEntry>> cachedFiles,
ReadableConfig config) {
public static Configuration extractPythonConfiguration(ReadableConfig config) {
final Configuration pythonDependencyConfig =
PythonDependencyUtils.configurePythonDependencies(cachedFiles, config);
PythonDependencyUtils.configurePythonDependencies(config);
final PythonConfig pythonConfig = new PythonConfig(config, pythonDependencyConfig);
return pythonConfig.toConfiguration();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package org.apache.flink.python.util;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ConfigurationUtils;
import org.apache.flink.configuration.PipelineOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.configuration.WritableConfig;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.StringUtils;
Expand All @@ -34,9 +36,13 @@
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.apache.flink.client.cli.CliFrontendParser.PYARCHIVE_OPTION;
Expand Down Expand Up @@ -69,16 +75,12 @@ public class PythonDependencyUtils {
* returns a new configuration which contains the metadata of the registered python
* dependencies.
*
* @param cachedFiles The list used to store registered cached files.
* @param config The configuration which contains python dependency configuration.
* @return A new configuration which contains the metadata of the registered python
* dependencies.
*/
public static Configuration configurePythonDependencies(
List<Tuple2<String, DistributedCache.DistributedCacheEntry>> cachedFiles,
ReadableConfig config) {
final PythonDependencyManager pythonDependencyManager =
new PythonDependencyManager(cachedFiles, config);
public static Configuration configurePythonDependencies(ReadableConfig config) {
final PythonDependencyManager pythonDependencyManager = new PythonDependencyManager(config);
final Configuration pythonDependencyConfig = new Configuration();
pythonDependencyManager.applyToConfiguration(pythonDependencyConfig);
return pythonDependencyConfig;
Expand Down Expand Up @@ -157,14 +159,10 @@ private static class PythonDependencyManager {
private static final String PYTHON_REQUIREMENTS_FILE_PREFIX = "python_requirements_file";
private static final String PYTHON_REQUIREMENTS_CACHE_PREFIX = "python_requirements_cache";
private static final String PYTHON_ARCHIVE_PREFIX = "python_archive";

private final List<Tuple2<String, DistributedCache.DistributedCacheEntry>> cachedFiles;
private final ReadableConfig config;

private PythonDependencyManager(
List<Tuple2<String, DistributedCache.DistributedCacheEntry>> cachedFiles,
ReadableConfig config) {
this.cachedFiles = cachedFiles;
private PythonDependencyManager(ReadableConfig config) {
Preconditions.checkArgument(config instanceof WritableConfig);
this.config = config;
}

Expand All @@ -178,7 +176,7 @@ private PythonDependencyManager(
private void addPythonFile(Configuration pythonDependencyConfig, String filePath) {
Preconditions.checkNotNull(filePath);
String fileKey = generateUniqueFileKey(PYTHON_FILE_PREFIX, filePath);
registerCachedFileIfNotExist(filePath, fileKey);
registerCachedFileIfNotExist(config, fileKey, filePath);
if (!pythonDependencyConfig.contains(PYTHON_FILES_DISTRIBUTED_CACHE_INFO)) {
pythonDependencyConfig.set(
PYTHON_FILES_DISTRIBUTED_CACHE_INFO, new LinkedHashMap<>());
Expand Down Expand Up @@ -224,7 +222,7 @@ private void setPythonRequirements(

String fileKey =
generateUniqueFileKey(PYTHON_REQUIREMENTS_FILE_PREFIX, requirementsFilePath);
registerCachedFileIfNotExist(requirementsFilePath, fileKey);
registerCachedFileIfNotExist(config, fileKey, requirementsFilePath);
pythonDependencyConfig
.get(PYTHON_REQUIREMENTS_FILE_DISTRIBUTED_CACHE_INFO)
.put(FILE, fileKey);
Expand All @@ -233,7 +231,7 @@ private void setPythonRequirements(
String cacheDirKey =
generateUniqueFileKey(
PYTHON_REQUIREMENTS_CACHE_PREFIX, requirementsCachedDir);
registerCachedFileIfNotExist(requirementsCachedDir, cacheDirKey);
registerCachedFileIfNotExist(config, cacheDirKey, requirementsCachedDir);
pythonDependencyConfig
.get(PYTHON_REQUIREMENTS_FILE_DISTRIBUTED_CACHE_INFO)
.put(CACHE, cacheDirKey);
Expand All @@ -258,7 +256,7 @@ private void addPythonArchive(
String fileKey =
generateUniqueFileKey(
PYTHON_ARCHIVE_PREFIX, archivePath + PARAM_DELIMITER + targetDir);
registerCachedFileIfNotExist(archivePath, fileKey);
registerCachedFileIfNotExist(config, fileKey, archivePath);
pythonDependencyConfig
.get(PYTHON_ARCHIVES_DISTRIBUTED_CACHE_INFO)
.put(fileKey, targetDir);
Expand Down Expand Up @@ -336,20 +334,39 @@ private String generateUniqueFileKey(String prefix, String hashString) {
"%s_%s", prefix, StringUtils.byteToHexString(messageDigest.digest()));
}

private void registerCachedFileIfNotExist(String filePath, String fileKey) {
if (cachedFiles.stream().noneMatch(t -> t.f0.equals(fileKey))) {
cachedFiles.add(
new Tuple2<>(
fileKey,
new DistributedCache.DistributedCacheEntry(filePath, false)));
}
private void registerCachedFileIfNotExist(ReadableConfig config, String name, String path) {
final Set<String> cachedFiles =
config.getOptional(PipelineOptions.CACHED_FILES)
.map(LinkedHashSet::new)
.orElseGet(LinkedHashSet::new);

Map<String, String> map = new HashMap<>();
map.put("name", name);
map.put("path", path);
cachedFiles.add(ConfigurationUtils.convertValue(map, String.class));

((WritableConfig) config)
.set(PipelineOptions.CACHED_FILES, new ArrayList<>(cachedFiles));
}

private void removeCachedFilesByPrefix(String prefix) {
cachedFiles.removeAll(
cachedFiles.stream()
.filter(t -> t.f0.matches("^" + prefix + "_[a-z0-9]{64}$"))
.collect(Collectors.toSet()));
final List<String> cachedFiles =
config.getOptional(PipelineOptions.CACHED_FILES).orElse(new ArrayList<>())
.stream()
.map(m -> Tuple2.of(ConfigurationUtils.parseMap(m), m))
.filter(
t ->
t.f0.get("name") != null
&& !(t.f0.get("name")
.matches(
"^"
+ prefix
+ "_[a-z0-9]{64}$")))
.map(t -> t.f1)
.collect(Collectors.toList());

((WritableConfig) config)
.set(PipelineOptions.CACHED_FILES, new ArrayList<>(cachedFiles));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@

package org.apache.flink.python.util;

import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ConfigurationUtils;
import org.apache.flink.configuration.PipelineOptions;
import org.apache.flink.python.PythonOptions;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
Expand All @@ -47,20 +45,13 @@
/** Tests for PythonDependencyUtils. */
class PythonDependencyUtilsTest {

private List<Tuple2<String, DistributedCache.DistributedCacheEntry>> cachedFiles;

@BeforeEach
void setUp() {
cachedFiles = new ArrayList<>();
}

@Test
void testPythonFiles() {
Configuration config = new Configuration();
config.set(
PythonOptions.PYTHON_FILES,
"hdfs:///tmp_dir/test_file1.py,tmp_dir/test_file2.py,tmp_dir/test_dir,hdfs:///tmp_dir/test_file1.py");
Configuration actual = configurePythonDependencies(cachedFiles, config);
Configuration actual = configurePythonDependencies(config);

Map<String, String> expectedCachedFiles = new HashMap<>();
expectedCachedFiles.put(
Expand All @@ -72,7 +63,7 @@ void testPythonFiles() {
expectedCachedFiles.put(
"python_file_e56bc55ff643576457b3d012b2bba888727c71cf05a958930f2263398c4e9798",
"tmp_dir/test_dir");
verifyCachedFiles(expectedCachedFiles);
verifyCachedFiles(expectedCachedFiles, config);

Configuration expectedConfiguration = new Configuration();
expectedConfiguration.set(PYTHON_FILES_DISTRIBUTED_CACHE_INFO, new HashMap<>());
Expand All @@ -98,13 +89,13 @@ void testPythonFiles() {
void testPythonRequirements() {
Configuration config = new Configuration();
config.set(PYTHON_REQUIREMENTS, "tmp_dir/requirements.txt");
Configuration actual = configurePythonDependencies(cachedFiles, config);
Configuration actual = configurePythonDependencies(config);

Map<String, String> expectedCachedFiles = new HashMap<>();
expectedCachedFiles.put(
"python_requirements_file_69390ca43c69ada3819226fcfbb5b6d27e111132a9427e7f201edd82e9d65ff6",
"tmp_dir/requirements.txt");
verifyCachedFiles(expectedCachedFiles);
verifyCachedFiles(expectedCachedFiles, config);

Configuration expectedConfiguration = new Configuration();
expectedConfiguration.set(PYTHON_REQUIREMENTS_FILE_DISTRIBUTED_CACHE_INFO, new HashMap<>());
Expand All @@ -116,7 +107,7 @@ void testPythonRequirements() {
verifyConfiguration(expectedConfiguration, actual);

config.set(PYTHON_REQUIREMENTS, "tmp_dir/requirements2.txt#tmp_dir/cache");
actual = configurePythonDependencies(cachedFiles, config);
actual = configurePythonDependencies(config);

expectedCachedFiles = new HashMap<>();
expectedCachedFiles.put(
Expand All @@ -125,7 +116,7 @@ void testPythonRequirements() {
expectedCachedFiles.put(
"python_requirements_cache_2f563dd6731c2c7c5e1ef1ef8279f61e907dc3bfc698adb71b109e43ed93e143",
"tmp_dir/cache");
verifyCachedFiles(expectedCachedFiles);
verifyCachedFiles(expectedCachedFiles, config);

expectedConfiguration = new Configuration();
expectedConfiguration.set(PYTHON_REQUIREMENTS_FILE_DISTRIBUTED_CACHE_INFO, new HashMap<>());
Expand All @@ -152,7 +143,7 @@ void testPythonArchives() {
+ "tmp_dir/py37.zip,"
+ "tmp_dir/py37.zip#venv,"
+ "tmp_dir/py37.zip#venv2,tmp_dir/py37.zip#venv");
Configuration actual = configurePythonDependencies(cachedFiles, config);
Configuration actual = configurePythonDependencies(config);

Map<String, String> expectedCachedFiles = new HashMap<>();
expectedCachedFiles.put(
Expand All @@ -167,7 +158,7 @@ void testPythonArchives() {
expectedCachedFiles.put(
"python_archive_c7d970ce1c5794367974ce8ef536c2343bed8fcfe7c2422c51548e58007eee6a",
"tmp_dir/py37.zip");
verifyCachedFiles(expectedCachedFiles);
verifyCachedFiles(expectedCachedFiles, config);

Configuration expectedConfiguration = new Configuration();
expectedConfiguration.set(PYTHON_ARCHIVES_DISTRIBUTED_CACHE_INFO, new HashMap<>());
Expand Down Expand Up @@ -199,7 +190,7 @@ void testPythonExecutables() {
Configuration config = new Configuration();
config.set(PYTHON_EXECUTABLE, "venv/bin/python3");
config.set(PYTHON_CLIENT_EXECUTABLE, "python37");
Configuration actual = configurePythonDependencies(cachedFiles, config);
Configuration actual = configurePythonDependencies(config);

Configuration expectedConfiguration = new Configuration();
expectedConfiguration.set(PYTHON_EXECUTABLE, "venv/bin/python3");
Expand Down Expand Up @@ -246,9 +237,11 @@ void testPythonPath() {
verifyConfiguration(expectedConfiguration, actual);
}

private void verifyCachedFiles(Map<String, String> expected) {
private void verifyCachedFiles(Map<String, String> expected, Configuration config) {
Map<String, String> actual =
cachedFiles.stream().collect(Collectors.toMap(t -> t.f0, t -> t.f1.filePath));
config.getOptional(PipelineOptions.CACHED_FILES).orElse(new ArrayList<>()).stream()
.map(ConfigurationUtils::parseMap)
.collect(Collectors.toMap(m -> m.get("name"), m -> m.get("path")));

assertThat(actual).isEqualTo(expected);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ protected Transformation<RowData> translateToPlanInternal(
final RowType outputRowType = InternalTypeInfo.of(getOutputType()).toRowType();
Configuration pythonConfig =
CommonPythonUtil.extractPythonConfiguration(
planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
planner.getTableConfig(), planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> transform =
createPythonOneInputTransformation(
inputTransform,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ protected Transformation<RowData> translateToPlanInternal(
final Tuple2<Long, Long> windowSizeAndSlideSize = WindowCodeGenerator.getWindowDef(window);
final Configuration pythonConfig =
CommonPythonUtil.extractPythonConfiguration(
planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
planner.getTableConfig(), planner.getFlinkContext().getClassLoader());
int groupBufferLimitSize =
pythonConfig.get(ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ protected Transformation<RowData> translateToPlanInternal(
}
Configuration pythonConfig =
CommonPythonUtil.extractPythonConfiguration(
planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
planner.getTableConfig(), planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> transform =
createPythonOneInputTransformation(
inputTransform,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ protected Transformation<RowData> translateToPlanInternal(
(Transformation<RowData>) inputEdge.translateToPlan(planner);
final Configuration pythonConfig =
CommonPythonUtil.extractPythonConfiguration(
planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
planner.getTableConfig(), planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> ret =
createPythonOneInputTransformation(
inputTransform,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ protected Transformation<RowData> translateToPlanInternal(
(Transformation<RowData>) inputEdge.translateToPlan(planner);
final Configuration pythonConfig =
CommonPythonUtil.extractPythonConfiguration(
planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
planner.getTableConfig(), planner.getFlinkContext().getClassLoader());
final ExecNodeConfig pythonNodeConfig =
ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled());
final OneInputTransformation<RowData, RowData> transform =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ protected Transformation<RowData> translateToPlanInternal(
DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1;
Configuration pythonConfig =
CommonPythonUtil.extractPythonConfiguration(
planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
planner.getTableConfig(), planner.getFlinkContext().getClassLoader());
final OneInputStreamOperator<RowData, RowData> operator =
getPythonAggregateFunctionOperator(
pythonConfig,
Expand Down
Loading

0 comments on commit 127a521

Please sign in to comment.