
Commit

Merge pull request #15406 from cdapio/dataset_transform_610
[🍒 6.10] [CDAP-20657] fixes & additional support for Datasets
tivv authored Nov 4, 2023
2 parents 06d6eb8 + 61b01a7 commit 1bb6684
Showing 6 changed files with 32 additions and 19 deletions.
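In summary: the commit fixes a delegation bug in which multiOutputTransform forwarded to the single-output transform, makes the delegating collection's store-task factories lazy, drops the checked throws Exception from createStoreTask across the SparkCollection hierarchy, and adds Dataset-backed transform and multiOutputTransform implementations.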
File 1 of 6
@@ -55,7 +55,7 @@ public SparkCollection<RecordInfo<Object>> transform(
   @Override
   public SparkCollection<RecordInfo<Object>> multiOutputTransform(StageSpec stageSpec,
       StageStatisticsCollector collector) {
-    return getDelegate().transform(stageSpec, collector);
+    return getDelegate().multiOutputTransform(stageSpec, collector);
   }
 
   @Override
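The first hunk is a copy-paste fix: the delegating collection forwarded multiOutputTransform to the delegate's transform, so stages with multiple output ports were silently run through the single-output path. The override now forwards to the delegate method of the same name.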
@@ -94,8 +94,8 @@ public SparkCollection union(SparkCollection other) {

   @Override
   public Runnable createStoreTask(StageSpec stageSpec,
-      SparkSink sink) throws Exception {
-    return getDelegate().createStoreTask(stageSpec, sink);
+      SparkSink sink) {
+    return () -> getDelegate().createStoreTask(stageSpec, sink).run();
   }
 
   @Override
@@ -135,13 +135,13 @@ public <U> SparkCollection<U> compute(StageSpec stageSpec, SparkCompute<T, U> co
   @Override
   public Runnable createStoreTask(StageSpec stageSpec,
       PairFlatMapFunction<T, Object, Object> sinkFunction) {
-    return getDelegate().createStoreTask(stageSpec, sinkFunction);
+    return () -> getDelegate().createStoreTask(stageSpec, sinkFunction).run();
   }
 
   @Override
   public Runnable createMultiStoreTask(PhaseSpec phaseSpec, Set<String> group, Set<String> sinks,
       Map<String, StageStatisticsCollector> collectors) {
-    return getDelegate().createMultiStoreTask(phaseSpec, group, sinks, collectors);
+    return () -> getDelegate().createMultiStoreTask(phaseSpec, group, sinks, collectors).run();
   }
 
   @Override
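The remaining changes in this file all follow one pattern: instead of resolving the delegate eagerly and returning its Runnable, each factory method now returns a lambda that resolves the delegate only when the task actually runs. A minimal sketch of the difference, using simplified stand-in types rather than the actual CDAP classes:

```java
import java.util.function.Supplier;

class LazyStoreTask {

  interface Collection {
    Runnable createStoreTask();
  }

  // Eager: the delegate is resolved as soon as the task is created, which
  // may force work (e.g. materializing data) long before the task runs.
  static Runnable eager(Supplier<Collection> delegate) {
    return delegate.get().createStoreTask();
  }

  // Deferred: the delegate is resolved inside the lambda, when run() fires.
  static Runnable deferred(Supplier<Collection> delegate) {
    return () -> delegate.get().createStoreTask().run();
  }
}
```

Deferring matters for wrappers whose getDelegate() is expensive; the fifth file below shows exactly that situation, with a pull() happening inside the returned lambda.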
File 2 of 6
@@ -24,6 +24,8 @@
 import io.cdap.cdap.etl.common.RecordInfo;
 import io.cdap.cdap.etl.common.StageStatisticsCollector;
 import io.cdap.cdap.etl.proto.v2.spec.StageSpec;
+import io.cdap.cdap.etl.spark.function.PluginFunctionContext;
+import io.cdap.cdap.etl.spark.function.TransformFunction;
 import io.cdap.cdap.etl.spark.join.JoinExpressionRequest;
 import io.cdap.cdap.etl.spark.join.JoinRequest;
 import org.apache.spark.api.java.function.FlatMapFunction;
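The two imports added here are presumably used by transform-related code in a portion of this file that the view leaves collapsed; the visible hunk below only touches createStoreTask.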
@@ -81,7 +83,7 @@ SparkCollection<RecordInfo<Object>> reduceAggregate(StageSpec stageSpec, @Nullab
   Runnable createMultiStoreTask(PhaseSpec phaseSpec, Set<String> group, Set<String> sinks,
       Map<String, StageStatisticsCollector> collectors);
 
-  Runnable createStoreTask(StageSpec stageSpec, SparkSink<T> sink) throws Exception;
+  Runnable createStoreTask(StageSpec stageSpec, SparkSink<T> sink);
 
   void publishAlerts(StageSpec stageSpec, StageStatisticsCollector collector) throws Exception;

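On the interface, createStoreTask drops its throws Exception clause. Now that every implementation returns a deferred task rather than doing work up front, the checked exception has no source at creation time; implementations that can still fail at run time wrap the exception inside the returned Runnable instead (see the fifth file below).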
File 3 of 6
@@ -17,25 +17,21 @@
package io.cdap.cdap.etl.spark.batch;

import io.cdap.cdap.api.data.DatasetContext;
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.spark.JavaSparkExecutionContext;
import io.cdap.cdap.etl.api.batch.SparkCompute;
import io.cdap.cdap.etl.api.batch.SparkExecutionPluginContext;
import io.cdap.cdap.etl.common.Constants;
import io.cdap.cdap.etl.common.PipelineRuntime;
import io.cdap.cdap.etl.common.RecordInfo;
import io.cdap.cdap.etl.common.StageStatisticsCollector;
import io.cdap.cdap.etl.proto.v2.spec.StageSpec;
import io.cdap.cdap.etl.spark.DelegatingSparkCollection;
import io.cdap.cdap.etl.spark.SparkCollection;
import io.cdap.cdap.etl.spark.SparkPipelineRuntime;
import io.cdap.cdap.etl.spark.function.DatasetAggregationAccumulator;
import io.cdap.cdap.etl.spark.function.DatasetAggregationFinalizeFunction;
import io.cdap.cdap.etl.spark.function.DatasetAggregationGetKeyFunction;
import io.cdap.cdap.etl.spark.function.DatasetAggregationReduceFunction;
import io.cdap.cdap.etl.spark.function.FunctionCache;
import io.cdap.cdap.etl.spark.function.MultiOutputTransformFunction;
import io.cdap.cdap.etl.spark.function.PluginFunctionContext;
import io.cdap.cdap.etl.spark.function.TransformFunction;
import javax.annotation.Nullable;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
@@ -176,6 +172,21 @@ public SparkCollection<RecordInfo<Object>> reduceAggregate(StageSpec stageSpec,
     return reduceDatasetAggregate(stageSpec, partitions, collector);
   }
 
+  @Override
+  public SparkCollection<RecordInfo<Object>> transform(StageSpec stageSpec, StageStatisticsCollector collector) {
+    PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
+    return flatMap(stageSpec, new TransformFunction<T>(
+        pluginFunctionContext, functionCacheFactory.newCache()));
+  }
+
+  @Override
+  public SparkCollection<RecordInfo<Object>> multiOutputTransform(StageSpec stageSpec,
+      StageStatisticsCollector collector) {
+    PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
+    return flatMap(stageSpec, new MultiOutputTransformFunction<T>(
+        pluginFunctionContext, functionCacheFactory.newCache()));
+  }
+
   /**
    * Performs reduce aggregate using Dataset API. This allows SPARK to perform various optimizations that
    * are not available when working on the RDD level.
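These two new overrides are the "additional support for Datasets" from the commit title: the Dataset-backed collection now implements transform and multiOutputTransform directly, funneling each stage through its own flatMap(stageSpec, ...) rather than falling back to an RDD-based path. Conceptually a transform stage is a flat map: each input record yields zero or more output records. A stand-alone sketch of that shape (illustrative names only, not the actual TransformFunction internals):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

class TransformShape {
  // Each record is handed to the stage, which may emit 0..n records;
  // the outputs are concatenated, exactly like Spark's flatMap.
  static <T, R> List<R> flatMapTransform(List<T> input, Function<T, List<R>> stage) {
    List<R> out = new ArrayList<>();
    for (T record : input) {
      out.addAll(stage.apply(record));
    }
    return out;
  }
}
```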
File 4 of 6
@@ -144,16 +144,16 @@ public SparkCollection<T> union(SparkCollection<T> other) {
   @Override
   public SparkCollection<RecordInfo<Object>> transform(StageSpec stageSpec, StageStatisticsCollector collector) {
     PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
-    return wrap(rdd.flatMap(new TransformFunction<T>(
-        pluginFunctionContext, functionCacheFactory.newCache())));
+    return flatMap(stageSpec, new TransformFunction<T>(
+        pluginFunctionContext, functionCacheFactory.newCache()));
   }
 
   @Override
   public SparkCollection<RecordInfo<Object>> multiOutputTransform(StageSpec stageSpec,
       StageStatisticsCollector collector) {
     PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
-    return wrap(rdd.flatMap(new MultiOutputTransformFunction<T>(
-        pluginFunctionContext, functionCacheFactory.newCache())));
+    return flatMap(stageSpec, new MultiOutputTransformFunction<T>(
+        pluginFunctionContext, functionCacheFactory.newCache()));
   }
 
   @Override
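The RDD-backed collection gets the mirror-image change: transform and multiOutputTransform stop calling rdd.flatMap(...) plus wrap(...) directly and instead route through the same flatMap(stageSpec, ...) entry point used by the Dataset-backed collection above, leaving a single override point for per-stage flat-map behavior.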
@@ -305,7 +305,7 @@ private void recordLineage(String name) {
   }
 
   @Override
-  public Runnable createStoreTask(final StageSpec stageSpec, final SparkSink<T> sink) throws Exception {
+  public Runnable createStoreTask(final StageSpec stageSpec, final SparkSink<T> sink) {
     return new Runnable() {
       @Override
       public void run() {
File 5 of 6
@@ -290,7 +290,7 @@ public Runnable createMultiStoreTask(PhaseSpec phaseSpec, Set<String> group, Set
   }
 
   @Override
-  public Runnable createStoreTask(StageSpec stageSpec, SparkSink<T> sink) throws Exception {
+  public Runnable createStoreTask(StageSpec stageSpec, SparkSink<T> sink) {
     return () -> {
       try {
         pull().createStoreTask(stageSpec, sink).run();
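This implementation still calls code that can fail (pull() resolves the underlying collection before the store task can be built), so the returned lambda wraps the call in a try block. The catch clause is collapsed in this view, but the usual pattern is to rethrow unchecked so the Runnable signature holds. A minimal sketch of that pattern, with a hypothetical Callable standing in for the throwing call:

```java
import java.util.concurrent.Callable;

class CheckedToUnchecked {
  // The lambda absorbs the checked exception so the returned Runnable can
  // satisfy the now exception-free createStoreTask contract.
  static Runnable storeTask(Callable<Runnable> createTask) {
    return () -> {
      try {
        createTask.call().run();
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    };
  }
}
```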
File 6 of 6
@@ -205,7 +205,7 @@ public void run() {
   }
 
   @Override
-  public Runnable createStoreTask(StageSpec stageSpec, SparkSink<T> sink) throws Exception {
+  public Runnable createStoreTask(StageSpec stageSpec, SparkSink<T> sink) {
     return new Runnable() {
       @Override
       public void run() {
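The final file applies the same signature cleanup to one more createStoreTask implementation. Taken together, the six files fix the multi-output delegation bug, defer delegate resolution in store tasks, and make store-task creation exception-free at the interface level.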

