Skip to content

Commit

Permalink
Refactor MultiplexPipelineChannel (#29541)
Browse files Browse the repository at this point in the history
* Refactor MultiplexPipelineChannel

* Refactor PipelineChannelCreator

* Refactor PipelineTaskUtils

* Fix test case

* Fix test case
  • Loading branch information
terrymanu authored Dec 25, 2023
1 parent add2466 commit 5d893c6
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
* limitations under the License.
*/

package org.apache.shardingsphere.data.pipeline.core.channel.memory;
package org.apache.shardingsphere.data.pipeline.core.channel;

import org.apache.shardingsphere.data.pipeline.core.constant.PipelineSQLOperationType;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelAckCallback;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.DataRecord;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.FinishedRecord;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.PlaceholderRecord;
Expand All @@ -34,20 +32,19 @@
import java.util.stream.IntStream;

/**
* Multiplex memory pipeline channel.
* Multiplex pipeline channel.
*/
public final class MultiplexMemoryPipelineChannel implements PipelineChannel {
public final class MultiplexPipelineChannel implements PipelineChannel {

private final int channelCount;

private final List<PipelineChannel> channels;

private final Map<String, Integer> channelAssignment = new HashMap<>();

public MultiplexMemoryPipelineChannel(final int channelCount, final int blockQueueSize, final PipelineChannelAckCallback ackCallback) {
public MultiplexPipelineChannel(final int channelCount, final PipelineChannelCreator channelCreator, final int importerBatchSize, final PipelineChannelAckCallback ackCallback) {
this.channelCount = channelCount;
int handledQueueSize = blockQueueSize < 1 ? 5 : blockQueueSize;
channels = IntStream.range(0, channelCount).mapToObj(each -> new SimpleMemoryPipelineChannel(handledQueueSize, ackCallback)).collect(Collectors.toList());
channels = IntStream.range(0, channelCount).mapToObj(each -> channelCreator.newInstance(importerBatchSize, ackCallback)).collect(Collectors.toList());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ public interface PipelineChannelCreator extends TypedSPI {
/**
* Create new instance of pipeline channel.
*
* @param outputConcurrency output concurrency
* @param averageElementSize average element size, affect the size of the queue
* @param importerBatchSize importer batch size
* @param ackCallback ack callback
* @return created instance
*/
PipelineChannel newInstance(int outputConcurrency, int averageElementSize, PipelineChannelAckCallback ackCallback);
PipelineChannel newInstance(int importerBatchSize, PipelineChannelAckCallback ackCallback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
import java.util.concurrent.TimeUnit;

/**
* Simple memory pipeline channel.
* Memory pipeline channel.
*/
public final class SimpleMemoryPipelineChannel implements PipelineChannel {
public final class MemoryPipelineChannel implements PipelineChannel {

private final BlockingQueue<List<Record>> queue;

private final PipelineChannelAckCallback ackCallback;

public SimpleMemoryPipelineChannel(final int blockQueueSize, final PipelineChannelAckCallback ackCallback) {
public MemoryPipelineChannel(final int blockQueueSize, final PipelineChannelAckCallback ackCallback) {
queue = blockQueueSize < 1 ? new SynchronousQueue<>(true) : new ArrayBlockingQueue<>(blockQueueSize, true);
this.ackCallback = ackCallback;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.apache.shardingsphere.data.pipeline.core.channel.memory;

import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelCreator;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelAckCallback;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelCreator;

import java.util.Properties;

Expand All @@ -32,18 +32,17 @@ public final class MemoryPipelineChannelCreator implements PipelineChannelCreato

private static final String BLOCK_QUEUE_SIZE_DEFAULT_VALUE = "2000";

private int blockQueueSize;
private int queueSize;

@Override
public void init(final Properties props) {
blockQueueSize = Integer.parseInt(props.getProperty(BLOCK_QUEUE_SIZE_KEY, BLOCK_QUEUE_SIZE_DEFAULT_VALUE));
queueSize = Integer.parseInt(props.getProperty(BLOCK_QUEUE_SIZE_KEY, BLOCK_QUEUE_SIZE_DEFAULT_VALUE));
}

@Override
public PipelineChannel newInstance(final int outputConcurrency, final int averageElementSize, final PipelineChannelAckCallback ackCallback) {
return 1 == outputConcurrency
? new SimpleMemoryPipelineChannel(blockQueueSize / averageElementSize, ackCallback)
: new MultiplexMemoryPipelineChannel(outputConcurrency, blockQueueSize, ackCallback);
public PipelineChannel newInstance(final int importerBatchSize, final PipelineChannelAckCallback ackCallback) {
int queueSize = this.queueSize / importerBatchSize;
return new MemoryPipelineChannel(0 == queueSize ? 1 : queueSize, ackCallback);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.data.pipeline.core.channel.MultiplexPipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelCreator;
import org.apache.shardingsphere.data.pipeline.core.ingest.dumper.context.InventoryDumperContext;
Expand Down Expand Up @@ -63,26 +64,28 @@ public static IncrementalTaskProgress createIncrementalTaskProgress(final Ingest
}

/**
* Create channel for inventory task.
* Create pipeline channel for inventory task.
*
* @param pipelineChannelCreator channel creator
* @param averageElementSize average element size
* @param channelCreator pipeline channel creator
* @param importerBatchSize importer batch size
* @param position ingest position
* @return channel
* @return created pipeline channel
*/
public static PipelineChannel createInventoryChannel(final PipelineChannelCreator pipelineChannelCreator, final int averageElementSize, final AtomicReference<IngestPosition> position) {
return pipelineChannelCreator.newInstance(1, averageElementSize, new InventoryTaskAckCallback(position));
public static PipelineChannel createInventoryChannel(final PipelineChannelCreator channelCreator, final int importerBatchSize, final AtomicReference<IngestPosition> position) {
return channelCreator.newInstance(importerBatchSize, new InventoryTaskAckCallback(position));
}

/**
* Create incremental channel.
* Create pipeline channel for incremental task.
*
* @param concurrency output concurrency
* @param pipelineChannelCreator channel creator
* @param channelCreator pipeline channel creator
* @param progress incremental task progress
* @return channel
* @return created pipeline channel
*/
public static PipelineChannel createIncrementalChannel(final int concurrency, final PipelineChannelCreator pipelineChannelCreator, final IncrementalTaskProgress progress) {
return pipelineChannelCreator.newInstance(concurrency, 5, new IncrementalTaskAckCallback(progress));
public static PipelineChannel createIncrementalChannel(final int concurrency, final PipelineChannelCreator channelCreator, final IncrementalTaskProgress progress) {
return 1 == concurrency
? channelCreator.newInstance(5, new IncrementalTaskAckCallback(progress))
: new MultiplexPipelineChannel(concurrency, channelCreator, 5, new IncrementalTaskAckCallback(progress));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,27 @@

package org.apache.shardingsphere.data.pipeline.core.channel.memory;

import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelAckCallback;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelCreator;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.test.util.PropertiesBuilder;
import org.apache.shardingsphere.test.util.PropertiesBuilder.Property;
import org.junit.jupiter.api.Test;
import org.mockito.internal.configuration.plugins.Plugins;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;

class MemoryPipelineChannelCreatorTest {

@Test
void assertInitWithBlockQueueSize() throws Exception {
PipelineChannelCreator creator = TypedSPILoader.getService(PipelineChannelCreator.class, "MEMORY", PropertiesBuilder.build(new Property("block-queue-size", "200")));
assertThat(Plugins.getMemberAccessor().get(MemoryPipelineChannelCreator.class.getDeclaredField("blockQueueSize"), creator), is(200));
assertThat(Plugins.getMemberAccessor().get(MemoryPipelineChannelCreator.class.getDeclaredField("queueSize"), creator), is(200));
}

@Test
void assertInitWithoutBlockQueueSize() throws Exception {
void assertNewInstanceWithoutBlockQueueSize() throws Exception {
PipelineChannelCreator creator = TypedSPILoader.getService(PipelineChannelCreator.class, "MEMORY");
assertThat(Plugins.getMemberAccessor().get(MemoryPipelineChannelCreator.class.getDeclaredField("blockQueueSize"), creator), is(2000));
}

@Test
void assertCreateSimpleMemoryPipelineChannel() {
assertThat(TypedSPILoader.getService(PipelineChannelCreator.class, "MEMORY").newInstance(1, 1, mock(PipelineChannelAckCallback.class)), instanceOf(SimpleMemoryPipelineChannel.class));
}

@Test
void assertCreateMultiplexMemoryPipelineChannel() {
assertThat(TypedSPILoader.getService(PipelineChannelCreator.class, "MEMORY").newInstance(2, 1, mock(PipelineChannelAckCallback.class)), instanceOf(MultiplexMemoryPipelineChannel.class));
assertThat(Plugins.getMemberAccessor().get(MemoryPipelineChannelCreator.class.getDeclaredField("queueSize"), creator), is(2000));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;

class SimpleMemoryPipelineChannelTest {
class MemoryPipelineChannelTest {

@SneakyThrows(InterruptedException.class)
@Test
void assertZeroQueueSizeWorks() {
SimpleMemoryPipelineChannel channel = new SimpleMemoryPipelineChannel(0, records -> {
MemoryPipelineChannel channel = new MemoryPipelineChannel(0, records -> {

});
List<Record> records = Collections.singletonList(new PlaceholderRecord(new IngestFinishedPosition()));
Expand All @@ -48,7 +48,7 @@ void assertZeroQueueSizeWorks() {

@Test
void assertFetchRecordsTimeoutCorrectly() {
SimpleMemoryPipelineChannel channel = new SimpleMemoryPipelineChannel(10, records -> {
MemoryPipelineChannel channel = new MemoryPipelineChannel(10, records -> {

});
long startMillis = System.currentTimeMillis();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import org.apache.shardingsphere.data.pipeline.core.channel.MultiplexPipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelAckCallback;
import org.apache.shardingsphere.data.pipeline.core.channel.PipelineChannelCreator;
import org.apache.shardingsphere.data.pipeline.core.constant.PipelineSQLOperationType;
import org.apache.shardingsphere.data.pipeline.core.ingest.position.IngestPosition;
import org.apache.shardingsphere.data.pipeline.core.ingest.position.type.placeholder.IngestPlaceholderPosition;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.DataRecord;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.FinishedRecord;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.PlaceholderRecord;
import org.apache.shardingsphere.data.pipeline.core.ingest.record.Record;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.junit.jupiter.api.Test;

import java.security.SecureRandom;
Expand Down Expand Up @@ -61,6 +64,18 @@ void assertAckCallbackResultSortable() {
}, countDataRecord(records), records);
}

private Record[] mockRecords() {
Record[] result = new Record[100];
for (int i = 1; i <= result.length; i++) {
result[i - 1] = random.nextBoolean() ? new DataRecord(PipelineSQLOperationType.INSERT, "t1", new IntPosition(i), 0) : new PlaceholderRecord(new IntPosition(i));
}
return result;
}

private int countDataRecord(final Record[] records) {
return (int) Arrays.stream(records).filter(each -> each instanceof DataRecord).count();
}

@Test
void assertBroadcastFinishedRecord() {
execute(records -> assertThat(records.size(), is(1)), 2, new FinishedRecord(new IngestPlaceholderPosition()));
Expand All @@ -69,20 +84,20 @@ void assertBroadcastFinishedRecord() {
@SneakyThrows(InterruptedException.class)
private void execute(final PipelineChannelAckCallback ackCallback, final int recordCount, final Record... records) {
CountDownLatch countDownLatch = new CountDownLatch(recordCount);
MultiplexMemoryPipelineChannel memoryChannel = new MultiplexMemoryPipelineChannel(CHANNEL_NUMBER, 10000, ackCallback);
fetchWithMultiThreads(memoryChannel, countDownLatch);
memoryChannel.push(Arrays.asList(records));
MultiplexPipelineChannel channel = new MultiplexPipelineChannel(CHANNEL_NUMBER, TypedSPILoader.getService(PipelineChannelCreator.class, "MEMORY"), 10000, ackCallback);
fetchWithMultiThreads(channel, countDownLatch);
channel.push(Arrays.asList(records));
boolean awaitResult = countDownLatch.await(10, TimeUnit.SECONDS);
assertTrue(awaitResult, "await failed");
}

private void fetchWithMultiThreads(final MultiplexMemoryPipelineChannel memoryChannel, final CountDownLatch countDownLatch) {
private void fetchWithMultiThreads(final MultiplexPipelineChannel memoryChannel, final CountDownLatch countDownLatch) {
for (int i = 0; i < CHANNEL_NUMBER; i++) {
new Thread(() -> fetch(memoryChannel, countDownLatch)).start();
}
}

private void fetch(final MultiplexMemoryPipelineChannel memoryChannel, final CountDownLatch countDownLatch) {
private void fetch(final MultiplexPipelineChannel memoryChannel, final CountDownLatch countDownLatch) {
int maxLoopCount = 10;
for (int j = 1; j <= maxLoopCount; j++) {
List<Record> records = memoryChannel.fetch(100, 1, TimeUnit.SECONDS);
Expand All @@ -94,18 +109,6 @@ private void fetch(final MultiplexMemoryPipelineChannel memoryChannel, final Cou
}
}

private Record[] mockRecords() {
Record[] result = new Record[100];
for (int i = 1; i <= result.length; i++) {
result[i - 1] = random.nextBoolean() ? new DataRecord(PipelineSQLOperationType.INSERT, "t1", new IntPosition(i), 0) : new PlaceholderRecord(new IntPosition(i));
}
return result;
}

private int countDataRecord(final Record[] records) {
return (int) Arrays.stream(records).filter(each -> each instanceof DataRecord).count();
}

@RequiredArgsConstructor
@Getter
private static final class IntPosition implements IngestPosition {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.shardingsphere.data.pipeline.mysql.ingest;

import org.apache.shardingsphere.data.pipeline.api.type.StandardPipelineDataSourceConfiguration;
import org.apache.shardingsphere.data.pipeline.core.channel.memory.SimpleMemoryPipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.channel.memory.MemoryPipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.constant.PipelineSQLOperationType;
import org.apache.shardingsphere.data.pipeline.core.datasource.PipelineDataSourceManager;
import org.apache.shardingsphere.data.pipeline.core.datasource.PipelineDataSourceWrapper;
Expand Down Expand Up @@ -87,7 +87,7 @@ void setUp() throws SQLException {
IncrementalDumperContext dumperContext = createDumperContext();
initTableData(dumperContext);
PipelineTableMetaDataLoader metaDataLoader = mock(PipelineTableMetaDataLoader.class);
SimpleMemoryPipelineChannel channel = new SimpleMemoryPipelineChannel(10000, records -> {
MemoryPipelineChannel channel = new MemoryPipelineChannel(10000, records -> {

});
incrementalDumper = new MySQLIncrementalDumper(dumperContext, new BinlogPosition("binlog-000001", 4L, 0L), channel, metaDataLoader);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.shardingsphere.data.pipeline.postgresql.ingest;

import org.apache.shardingsphere.data.pipeline.api.type.StandardPipelineDataSourceConfiguration;
import org.apache.shardingsphere.data.pipeline.core.channel.memory.SimpleMemoryPipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.channel.memory.MemoryPipelineChannel;
import org.apache.shardingsphere.data.pipeline.core.datasource.PipelineDataSourceManager;
import org.apache.shardingsphere.data.pipeline.core.exception.IngestException;
import org.apache.shardingsphere.data.pipeline.core.ingest.dumper.context.DumperCommonContext;
Expand Down Expand Up @@ -75,14 +75,14 @@ class PostgreSQLWALDumperTest {

private PostgreSQLWALDumper walDumper;

private SimpleMemoryPipelineChannel channel;
private MemoryPipelineChannel channel;

private final PipelineDataSourceManager dataSourceManager = new PipelineDataSourceManager();

@BeforeEach
void setUp() {
position = new WALPosition(new PostgreSQLLogSequenceNumber(LogSequenceNumber.valueOf(100L)));
channel = new SimpleMemoryPipelineChannel(10000, records -> {
channel = new MemoryPipelineChannel(10000, records -> {

});
String jdbcUrl = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;DATABASE_TO_UPPER=false;MODE=PostgreSQL";
Expand Down

0 comments on commit 5d893c6

Please sign in to comment.