Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ProcessListChangedSubscriber #34203

Merged
merged 9 commits into from
Dec 29, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.metadata.user.Grantee;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.LinkedHashMap;
import java.util.Map;
Expand Down Expand Up @@ -151,4 +152,16 @@ public boolean isIdle() {
public void removeProcessStatement(final ExecutionUnit executionUnit) {
processStatements.remove(System.identityHashCode(executionUnit));
}

/**
* Kill process.
*
* @throws SQLException SQL exception
*/
public void kill() throws SQLException {
setInterrupted(true);
for (Statement each : processStatements.values()) {
each.cancel();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.kernel.connection.SQLExecutionInterruptedException;

import java.sql.SQLException;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -94,11 +95,24 @@ public void remove(final String id) {
}

/**
* List all process.
* List all processes.
*
* @return all processes
*/
public Collection<Process> listAll() {
return processes.values();
}

/**
* Kill process.
*
* @param processId process ID
* @throws SQLException SQL exception
*/
public void kill(final String processId) throws SQLException {
Process process = ProcessRegistry.getInstance().get(processId);
if (null != process) {
process.kill();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
*/
public interface ProcessPersistService {

/**
* Report local processes.
*
* @param instanceId instance ID
* @param taskId task ID
*/
void reportLocalProcesses(String instanceId, String taskId);

/**
* Get process list.
*
Expand All @@ -41,4 +49,12 @@ public interface ProcessPersistService {
* @throws SQLException SQL exception
*/
void killProcess(String processId) throws SQLException;

/**
* Clean process.
*
* @param instanceId instance ID
* @param processId process ID
*/
void cleanProcess(String instanceId, String processId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,57 +18,42 @@
package org.apache.shardingsphere.mode.manager.cluster.event.dispatch.subscriber.type;

import com.google.common.eventbus.Subscribe;
import org.apache.shardingsphere.infra.executor.sql.process.Process;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.lock.ProcessOperationLockRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.yaml.swapper.YamlProcessListSwapper;
import org.apache.shardingsphere.infra.util.yaml.YamlEngine;
import org.apache.shardingsphere.metadata.persist.node.ComputeNode;
import org.apache.shardingsphere.metadata.persist.node.ProcessNode;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.KillLocalProcessCompletedEvent;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.KillLocalProcessEvent;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.ReportLocalProcessesCompletedEvent;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.ReportLocalProcessesEvent;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.subscriber.DispatchEventSubscriber;
import org.apache.shardingsphere.mode.spi.PersistRepository;
import org.apache.shardingsphere.mode.persist.service.divided.ProcessPersistService;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;

/**
* Process list changed subscriber.
*/
public final class ProcessListChangedSubscriber implements DispatchEventSubscriber {

private final ContextManager contextManager;

private final PersistRepository repository;
private final String instanceId;

private final YamlProcessListSwapper swapper;
private final ProcessPersistService processPersistService;

public ProcessListChangedSubscriber(final ContextManager contextManager) {
this.contextManager = contextManager;
repository = contextManager.getPersistServiceFacade().getRepository();
swapper = new YamlProcessListSwapper();
instanceId = contextManager.getComputeNodeInstanceContext().getInstance().getMetaData().getId();
processPersistService = contextManager.getPersistServiceFacade().getProcessPersistService();
}

/**
* Report local processes.
*
* @param event show process list trigger event
* @param event report local processes event
*/
@Subscribe
public void reportLocalProcesses(final ReportLocalProcessesEvent event) {
if (!event.getInstanceId().equals(contextManager.getComputeNodeInstanceContext().getInstance().getMetaData().getId())) {
return;
}
Collection<Process> processes = ProcessRegistry.getInstance().listAll();
if (!processes.isEmpty()) {
repository.persist(ProcessNode.getProcessListInstancePath(event.getTaskId(), event.getInstanceId()), YamlEngine.marshal(swapper.swapToYamlConfiguration(processes)));
if (event.getInstanceId().equals(instanceId)) {
processPersistService.reportLocalProcesses(instanceId, event.getTaskId());
}
repository.delete(ComputeNode.getProcessTriggerInstanceNodePath(event.getInstanceId(), event.getTaskId()));
}

/**
Expand All @@ -89,17 +74,11 @@ public synchronized void completeToReportLocalProcesses(final ReportLocalProcess
*/
@Subscribe
public synchronized void killLocalProcess(final KillLocalProcessEvent event) throws SQLException {
if (!event.getInstanceId().equals(contextManager.getComputeNodeInstanceContext().getInstance().getMetaData().getId())) {
if (!event.getInstanceId().equals(instanceId)) {
return;
}
Process process = ProcessRegistry.getInstance().get(event.getProcessId());
if (null != process) {
process.setInterrupted(true);
for (Statement each : process.getProcessStatements().values()) {
each.cancel();
}
}
repository.delete(ComputeNode.getProcessKillInstanceIdNodePath(event.getInstanceId(), event.getProcessId()));
ProcessRegistry.getInstance().kill(event.getProcessId());
processPersistService.cleanProcess(instanceId, event.getProcessId());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.executor.sql.process.Process;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.lock.ProcessOperationLockRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.yaml.YamlProcessList;
import org.apache.shardingsphere.infra.executor.sql.process.yaml.swapper.YamlProcessListSwapper;
Expand All @@ -43,6 +44,17 @@ public final class ClusterProcessPersistService implements ProcessPersistService

private final PersistRepository repository;

private final YamlProcessListSwapper swapper = new YamlProcessListSwapper();

@Override
public void reportLocalProcesses(final String instanceId, final String taskId) {
Collection<Process> processes = ProcessRegistry.getInstance().listAll();
if (!processes.isEmpty()) {
repository.persist(ProcessNode.getProcessListInstancePath(taskId, instanceId), YamlEngine.marshal(swapper.swapToYamlConfiguration(processes)));
}
repository.delete(ComputeNode.getProcessTriggerInstanceNodePath(instanceId, taskId));
}

@Override
public Collection<Process> getProcessList() {
String taskId = new UUID(ThreadLocalRandom.current().nextLong(), ThreadLocalRandom.current().nextLong()).toString().replace("-", "");
Expand Down Expand Up @@ -98,4 +110,9 @@ private Collection<String> getKillProcessTriggerPaths(final String processId) {
private boolean isReady(final Collection<String> paths) {
return paths.stream().noneMatch(each -> null != repository.query(each));
}

@Override
public void cleanProcess(final String instanceId, final String processId) {
repository.delete(ComputeNode.getProcessKillInstanceIdNodePath(instanceId, processId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

package org.apache.shardingsphere.mode.manager.cluster.event.dispatch.subscriber.type;

import org.apache.shardingsphere.infra.executor.sql.process.Process;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.lock.ProcessOperationLockRegistry;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.KillLocalProcessCompletedEvent;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.KillLocalProcessEvent;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.ReportLocalProcessesCompletedEvent;
import org.apache.shardingsphere.mode.manager.cluster.event.dispatch.event.state.compute.ReportLocalProcessesEvent;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.repository.cluster.ClusterPersistRepository;
import org.apache.shardingsphere.test.mock.AutoMockExtension;
import org.apache.shardingsphere.test.mock.StaticMockSettings;
Expand All @@ -37,12 +36,9 @@
import org.mockito.quality.Strictness;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collections;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -72,19 +68,10 @@ void assertReportLocalProcessesWithNotCurrentInstance() {
}

@Test
void assertReportEmptyLocalProcesses() {
void assertReportLocalProcesses() {
when(ProcessRegistry.getInstance().listAll()).thenReturn(Collections.emptyList());
subscriber.reportLocalProcesses(new ReportLocalProcessesEvent("foo_instance_id", "foo_task_id"));
verify(contextManager.getPersistServiceFacade().getRepository(), times(0)).persist(any(), any());
verify(contextManager.getPersistServiceFacade().getRepository()).delete("/nodes/compute_nodes/show_process_list_trigger/foo_instance_id:foo_task_id");
}

@Test
void assertReportNotEmptyLocalProcesses() {
when(ProcessRegistry.getInstance().listAll()).thenReturn(Collections.singleton(mock(Process.class, RETURNS_DEEP_STUBS)));
subscriber.reportLocalProcesses(new ReportLocalProcessesEvent("foo_instance_id", "foo_task_id"));
verify(contextManager.getPersistServiceFacade().getRepository()).persist(eq("/execution_nodes/foo_task_id/foo_instance_id"), any());
verify(contextManager.getPersistServiceFacade().getRepository()).delete("/nodes/compute_nodes/show_process_list_trigger/foo_instance_id:foo_task_id");
verify(contextManager.getPersistServiceFacade().getProcessPersistService()).reportLocalProcesses("foo_instance_id", "foo_task_id");
}

@Test
Expand All @@ -94,28 +81,15 @@ void assertCompleteToReportLocalProcesses() {
}

@Test
void assertKillLocalProcessWithNotCurrentInstance() throws SQLException {
subscriber.killLocalProcess(new KillLocalProcessEvent("bar_instance_id", "foo_pid"));
verify(contextManager.getPersistServiceFacade().getRepository(), times(0)).delete(any());
}

@Test
void assertKillLocalProcessWithoutExistedProcess() throws SQLException {
when(ProcessRegistry.getInstance().get("foo_pid")).thenReturn(null);
void assertKillLocalProcessWithCurrentInstance() throws SQLException {
subscriber.killLocalProcess(new KillLocalProcessEvent("foo_instance_id", "foo_pid"));
verify(contextManager.getPersistServiceFacade().getRepository()).delete("/nodes/compute_nodes/kill_process_trigger/foo_instance_id:foo_pid");
verify(contextManager.getPersistServiceFacade().getProcessPersistService()).cleanProcess("foo_instance_id", "foo_pid");
}

@Test
void assertKillLocalProcessWithExistedProcess() throws SQLException {
Process process = mock(Process.class, RETURNS_DEEP_STUBS);
Statement statement = mock(Statement.class);
when(process.getProcessStatements()).thenReturn(Collections.singletonMap(1, statement));
when(ProcessRegistry.getInstance().get("foo_pid")).thenReturn(process);
subscriber.killLocalProcess(new KillLocalProcessEvent("foo_instance_id", "foo_pid"));
verify(process).setInterrupted(true);
verify(statement).cancel();
verify(contextManager.getPersistServiceFacade().getRepository()).delete("/nodes/compute_nodes/kill_process_trigger/foo_instance_id:foo_pid");
void assertKillLocalProcessWithNotCurrentInstance() throws SQLException {
subscriber.killLocalProcess(new KillLocalProcessEvent("bar_instance_id", "foo_pid"));
verify(contextManager.getPersistServiceFacade().getProcessPersistService(), times(0)).cleanProcess("bar_instance_id", "foo_pid");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.mode.manager.cluster.persist.service;

import org.apache.shardingsphere.infra.executor.sql.process.Process;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.lock.ProcessOperationLockRegistry;
import org.apache.shardingsphere.infra.executor.sql.process.yaml.YamlProcess;
import org.apache.shardingsphere.infra.executor.sql.process.yaml.YamlProcessList;
Expand All @@ -40,12 +41,14 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.contains;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(AutoMockExtension.class)
@StaticMockSettings(ProcessOperationLockRegistry.class)
@StaticMockSettings({ProcessRegistry.class, ProcessOperationLockRegistry.class})
class ClusterProcessPersistServiceTest {

@Mock
Expand All @@ -58,6 +61,22 @@ void setUp() {
processPersistService = new ClusterProcessPersistService(repository);
}

@Test
void assertReportEmptyLocalProcesses() {
when(ProcessRegistry.getInstance().listAll()).thenReturn(Collections.emptyList());
processPersistService.reportLocalProcesses("foo_instance_id", "foo_task_id");
verify(repository, times(0)).persist(any(), any());
verify(repository).delete("/nodes/compute_nodes/show_process_list_trigger/foo_instance_id:foo_task_id");
}

@Test
void assertReportNotEmptyLocalProcesses() {
when(ProcessRegistry.getInstance().listAll()).thenReturn(Collections.singleton(mock(Process.class, RETURNS_DEEP_STUBS)));
processPersistService.reportLocalProcesses("foo_instance_id", "foo_task_id");
verify(repository).persist(eq("/execution_nodes/foo_task_id/foo_instance_id"), any());
verify(repository).delete("/nodes/compute_nodes/show_process_list_trigger/foo_instance_id:foo_task_id");
}

@Test
void assertGetCompletedProcessList() {
when(ProcessOperationLockRegistry.getInstance().waitUntilReleaseReady(any(), any())).thenReturn(true);
Expand Down Expand Up @@ -111,4 +130,10 @@ private void assertKillProcess() {
processPersistService.killProcess("foo_process_id");
verify(repository).persist("/nodes/compute_nodes/kill_process_trigger/abc:foo_process_id", "");
}

@Test
void assertCleanProcess() {
processPersistService.cleanProcess("foo_instance_id", "foo_pid");
verify(repository).delete("/nodes/compute_nodes/kill_process_trigger/foo_instance_id:foo_pid");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,28 @@
import org.apache.shardingsphere.mode.persist.service.divided.ProcessPersistService;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;

/**
* Standalone process persist service.
*/
public final class StandaloneProcessPersistService implements ProcessPersistService {

@Override
public void reportLocalProcesses(final String instanceId, final String taskId) {
}

@Override
public Collection<Process> getProcessList() {
return ProcessRegistry.getInstance().listAll();
}

@Override
public void killProcess(final String processId) throws SQLException {
Process process = ProcessRegistry.getInstance().get(processId);
if (null == process) {
return;
}
for (Statement each : process.getProcessStatements().values()) {
each.cancel();
}
ProcessRegistry.getInstance().kill(processId);
}

@Override
public void cleanProcess(final String instanceId, final String processId) {
}
}
Loading
Loading