Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CAY-1089, 1127, 1130] Introduce worker-side components of SyncSGD without backup worker #1131

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/src/main/java/edu/snu/cay/common/param/Parameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,10 @@ public final class HostToBandwidthFilePath implements Name<String> {
private HostToBandwidthFilePath() {
}
}

/**
 * Named parameter selecting whether the parameter server job runs synchronously or asynchronously.
 * The default value is {@code "async"}.
 */
@NamedParameter(doc = "Whether this parameter server works in synchronous way or asynchronous way.",
    short_name = "synchronicity",
    default_value = "async")
public final class Synchronicity implements Name<String> {
  // Never instantiated; kept private for consistency with the other parameter classes in this file.
  private Synchronicity() {
  }
}
}
94 changes: 94 additions & 0 deletions dolphin/async/src/main/avro/syncmsg.avsc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright (C) 2017 Seoul National University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

[
{
// {@code SyncPushBarrier} sends this message to the driver to request a {@code PermitPushMsg}.
"namespace": "edu.snu.cay.dolphin.async",
"type": "record",
"name": "RequestPushPermissionMsg",
"fields":
[
// The driver uses roundNum to check whether this request belongs to the round it is currently handling.
// A request whose round number does not match is ignored by the driver.
{"name": "roundNum", "type": "int"}
]
},

{
// {@code SyncPushBarrier} sends this message to the driver to notify that its mini-batch is finished.
// NOTE(review): the worker task waits on {@code MiniBatchBarrier} after each mini-batch, so the actual sender
// may be the mini-batch barrier rather than {@code SyncPushBarrier} - confirm.
"namespace": "edu.snu.cay.dolphin.async",
"type": "record",
"name": "MiniBatchFinishedMsg",
"fields":
[
// Using epochIdx, the driver chooses which message to send back
// (StartNextMiniBatchMsg or TerminateLearningMsg).
{"name": "epochIdx", "type": "int"}
]
},

{
// {@code BatchManager} sends this message to workers to permit push operation.
// The record has no fields; the message itself is the permission.
"namespace": "edu.snu.cay.dolphin.async",
"type": "record",
"name": "PermitPushMsg",
"fields": []
},

{
// {@code BatchManager} sends this message to workers to make them start next mini-batch.
"namespace": "edu.snu.cay.dolphin.async",
"type": "record",
"name": "StartNextMiniBatchMsg",
"fields":
[
// Presumably the round number workers should attach to their next RequestPushPermissionMsg - confirm.
{"name": "nextRoundNum", "type": "int"}
]
},

{
// {@code BatchManager} sends this message to workers to terminate learning.
// The record has no fields; receiving it is the termination signal.
"namespace": "edu.snu.cay.dolphin.async",
"type": "record",
"name": "TerminateLearningMsg",
"fields": []
},

{
/**
* Messages that are exchanged between worker and driver for synchronous system.
* Specifically, {@code SyncPushBarrier}(worker-side component) and {@code BatchManager}(driver-side component) are
* classes that exchange these messages.
*/
"namespace": "edu.snu.cay.dolphin.async",
"type": "record",
"name": "AvroSyncSGDMsg",
"fields":
[
// Discriminator naming which of the five message kinds this envelope carries.
{"name": "type",
"type":
{"type": "enum",
"name": "SyncSGDMsgType",
"symbols": ["RequestPushPermissionMsg", "MiniBatchFinishedMsg", "PermitPushMsg",
"StartNextMiniBatchMsg", "TerminateLearningMsg"]}},
// One optional slot per message kind; each is a nullable union that defaults to null.
// Presumably only the slot matching "type" is populated - confirm against the message handlers.
{"name": "requestPushPermissionMsg", "type": ["null", "RequestPushPermissionMsg"], "default": null},
{"name": "miniBatchFinishedMsg", "type": ["null", "MiniBatchFinishedMsg"], "default": null},
{"name": "permitPushMsg", "type": ["null", "PermitPushMsg"], "default": null},
{"name": "startNextMiniBatchMsg", "type": ["null", "StartNextMiniBatchMsg"], "default": null},
{"name": "terminateLearningMsg", "type": ["null", "TerminateLearningMsg"], "default": null}
]
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

import edu.snu.cay.common.client.DriverLauncher;
import edu.snu.cay.common.dataloader.TextInputFormat;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.BatchManager;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.DriverSideSyncSGDMsgHandler;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.*;
import edu.snu.cay.dolphin.async.metric.*;
import edu.snu.cay.dolphin.async.dashboard.DashboardConfProvider;
import edu.snu.cay.dolphin.async.dashboard.DashboardLauncher;
Expand Down Expand Up @@ -163,6 +168,9 @@ public static LauncherStatus launch(final String jobName,
confSerializer.toString(asyncDolphinConfiguration.getServerConfiguration()))
.build();

final String synchronicity = basicParameterInjector.getNamedInstance(Synchronicity.class);
final boolean isAsync = synchronicity.equals("async");

// worker-specific configurations
// pass the worker class implementation as well as user-defined parameters
final Configuration basicWorkerConf = Tang.Factory.getTang().newConfigurationBuilder()
Expand All @@ -181,6 +189,8 @@ public static LauncherStatus launch(final String jobName,
Integer.toString(basicParameterInjector.getNamedInstance(DolphinParameters.NumTrainerThreads.class)))
.bindNamedParameter(DolphinParameters.TestDataPath.class,
basicParameterInjector.getNamedInstance(DolphinParameters.TestDataPath.class))
.bindImplementation(PushBarrier.class, isAsync ? NullPushBarrier.class : SyncPushBarrier.class)
.bindImplementation(MiniBatchBarrier.class, isAsync ? NullMiniBatchBarrier.class : SyncMiniBatchBarrier.class)
.build();
final Configuration workerConf = Configurations.merge(basicWorkerConf,
asyncDolphinConfiguration.getWorkerConfiguration());
Expand Down Expand Up @@ -212,7 +222,7 @@ public static LauncherStatus launch(final String jobName,
final Configuration dashboardConf = DashboardConfProvider.getConfiguration(dashboardEnabled);

// driver-side configurations
final Configuration driverConf = getDriverConfiguration(jobName, basicParameterInjector);
final Configuration driverConf = getDriverConfiguration(jobName, basicParameterInjector, isAsync);
final int timeout = basicParameterInjector.getNamedInstance(Timeout.class);

final LauncherStatus status = DriverLauncher.getLauncher(runTimeConf).run(
Expand Down Expand Up @@ -287,6 +297,9 @@ private static Tuple2<Configuration, Configuration> parseCommandLine(
basicParameterClassList.add(ServerMetricsWindowMs.class);
basicParameterClassList.add(PSTraceProbability.class);

// add SyncSGD parameters
basicParameterClassList.add(Synchronicity.class);

// add SSP parameters
basicParameterClassList.add(StalenessBound.class);

Expand Down Expand Up @@ -364,7 +377,7 @@ private static Configuration getLocalRuntimeConfiguration(final int maxNumEvalLo
}

private static Configuration getDriverConfiguration(
final String jobName, final Injector injector) throws InjectionException {
final String jobName, final Injector injector, final boolean isAsync) throws InjectionException {
final ConfigurationModule driverConf = DriverConfiguration.CONF
.set(DriverConfiguration.GLOBAL_LIBRARIES, EnvironmentUtils.getClassLocation(AsyncDolphinDriver.class))
.set(DriverConfiguration.DRIVER_IDENTIFIER, jobName)
Expand All @@ -391,7 +404,7 @@ private static Configuration getDriverConfiguration(
final int stalenessBound = injector.getNamedInstance(StalenessBound.class);
final boolean isSSPModel = stalenessBound >= 0;
final CentCommConf centCommConf = isSSPModel ?
getCentCommConfForSSP() : getDefaultCentCommConf();
getCentCommConfForSSP() : getDefaultCentCommConf(isAsync);
// set up an optimizer configuration
final Class<? extends Optimizer> optimizerClass;
final Class<? extends PlanExecutor> executorClass;
Expand Down Expand Up @@ -432,12 +445,23 @@ private static CentCommConf.Builder getCentCommConfDefaultBuilder() {
EvalSideMetricsMsgHandlerForServer.class);
}

private static CentCommConf getDefaultCentCommConf() {
return getCentCommConfDefaultBuilder()
.addCentCommClient(ClockManager.CENT_COMM_CLIENT_NAME,
ClockManager.MessageHandler.class,
AsyncWorkerClock.MessageHandler.class)
.build();
/**
 * Builds the CentComm configuration used when the SSP model is not selected.
 * The clock client is always registered; the SyncSGD client is registered only for synchronous jobs.
 *
 * @param isAsync {@code true} for an asynchronous job, {@code false} for a synchronous (SyncSGD) job
 * @return the CentComm configuration for the requested mode
 */
private static CentCommConf getDefaultCentCommConf(final boolean isAsync) {
  // Both modes share the clock client, so register it once up front.
  final CentCommConf.Builder withClockClient = getCentCommConfDefaultBuilder()
      .addCentCommClient(ClockManager.CENT_COMM_CLIENT_NAME,
          ClockManager.MessageHandler.class,
          AsyncWorkerClock.MessageHandler.class);

  if (isAsync) {
    return withClockClient.build();
  }

  // Synchronous mode additionally needs the driver/worker SyncSGD message handlers.
  return withClockClient
      .addCentCommClient(BatchManager.CENT_COMM_CLIENT_NAME,
          DriverSideSyncSGDMsgHandler.class,
          WorkerSideSyncSGDMsgHandler.class)
      .build();
}

private static CentCommConf getCentCommConfForSSP() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

import edu.snu.cay.common.metric.MetricsMsgSender;
import edu.snu.cay.common.metric.avro.Metrics;
import edu.snu.cay.common.param.Parameters;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier;
import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.LearningState;
import edu.snu.cay.dolphin.async.metric.avro.WorkerMetrics;
import edu.snu.cay.services.em.common.parameters.AddedEval;
import edu.snu.cay.services.em.evaluator.api.MemoryStore;
import edu.snu.cay.services.ps.worker.api.ParameterWorker;
import edu.snu.cay.services.ps.worker.api.WorkerClock;
import edu.snu.cay.utils.HostnameResolver;
import edu.snu.cay.utils.StateMachine;
import org.apache.reef.driver.task.TaskConfigurationOptions.Identifier;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.task.Task;
Expand Down Expand Up @@ -61,6 +65,35 @@ final class AsyncWorkerTask<K, V> implements Task {
private final boolean addedEval;
private final String hostname;

/**
* Whether the client requested asynchronous or synchronous workers.
* If true, the workers run asynchronously.
* If false, the workers run synchronously.
*/
private final boolean isAsync;

/**
* MiniBatchBarrier helps all the workers to start next mini-batch synchronously in synchronous system.
*/
private final MiniBatchBarrier miniBatchBarrier;

/**
* State machine tracking which of the three {@link State}s this worker task is currently in.
* These states are for the synchronous worker.
*/
private final StateMachine stateMachine;

/**
* The states a worker task moves through while processing mini-batches.
*/
private enum State {
// A mini-batch is currently running; this is the initial state.
MINI_BATCH_RUNNING,
// The mini-batch is finished and the worker is waiting for StartNextMiniBatchMsg from the driver.
WAITING_NEXT_MINI_BATCH,
// The worker was slow: its mini-batch is stopped and being closed after StartNextMiniBatchMsg arrived.
MINI_BATCH_CLOSING
}

/**
* Flag which indicates the learning state of this worker in the synchronous system.
* When this worker receives {@code TerminateLearningMsg} from the driver, this flag is changed to
* {@code LearningState.TerminateLearning}.
*/
private LearningState learningFlag = LearningState.ProgressLearning;

/**
* A boolean flag shared among all trainer threads.
* Trainer threads end when this flag becomes true by {@link #close()}.
Expand All @@ -72,14 +105,16 @@ private AsyncWorkerTask(@Parameter(Identifier.class) final String taskId,
@Parameter(DolphinParameters.MaxNumEpochs.class) final int maxNumEpochs,
@Parameter(DolphinParameters.MiniBatchSize.class) final int miniBatchSize,
@Parameter(AddedEval.class) final boolean addedEval,
@Parameter(Parameters.Synchronicity.class) final String synchronicity,
final WorkerSynchronizer synchronizer,
final ParameterWorker parameterWorker,
final TrainingDataProvider<K, V> trainingDataProvider,
final TestDataProvider<V> testDataProvider,
final MemoryStore<K> memoryStore,
final Trainer<V> trainer,
final MetricsMsgSender<WorkerMetrics> metricsMsgSender,
final WorkerClock workerClock) {
final WorkerClock workerClock,
final MiniBatchBarrier miniBatchBarrier) {
this.taskId = taskId;
this.maxNumEpochs = maxNumEpochs;
this.miniBatchSize = miniBatchSize;
Expand All @@ -93,6 +128,26 @@ private AsyncWorkerTask(@Parameter(Identifier.class) final String taskId,
this.metricsMsgSender = metricsMsgSender;
this.workerClock = workerClock;
this.hostname = HostnameResolver.resolve();
this.miniBatchBarrier = miniBatchBarrier;
this.stateMachine = initStateMachine();
this.isAsync = synchronicity.equals("async");
}

/**
* Builds the state machine that tracks this worker's mini-batch progress.
* It registers the three {@link State}s, starts in {@code MINI_BATCH_RUNNING}, and allows these transitions:
* RUNNING -> WAITING_NEXT_MINI_BATCH, RUNNING -> MINI_BATCH_CLOSING,
* WAITING_NEXT_MINI_BATCH -> RUNNING and MINI_BATCH_CLOSING -> RUNNING.
*
* @return the newly built state machine
*/
private StateMachine initStateMachine() {
return StateMachine.newBuilder()
// States, each with a human-readable description.
.addState(State.MINI_BATCH_RUNNING, "Mini-batch is running now.")
.addState(State.WAITING_NEXT_MINI_BATCH, "Mini-batch is finished and waiting for StartNextMiniBatchMsg " +
"from driver")
.addState(State.MINI_BATCH_CLOSING, "This worker is slow worker. Mini-batch is stopped and being closed" +
" after it had received StartNextMiniBatchMsg from driver.")
// Leaving RUNNING: either the mini-batch finished normally, or it must be cut short.
.addTransition(State.MINI_BATCH_RUNNING, State.WAITING_NEXT_MINI_BATCH, "Mini-batch is finished.")
.addTransition(State.MINI_BATCH_RUNNING, State.MINI_BATCH_CLOSING, "This worker is slow worker. Mini-batch" +
" is not finished but should be closed for next mini-batch.")
// Both non-running states lead back to RUNNING when the next mini-batch starts.
.addTransition(State.WAITING_NEXT_MINI_BATCH, State.MINI_BATCH_RUNNING, "New mini-batch is started.")
.addTransition(State.MINI_BATCH_CLOSING, State.MINI_BATCH_RUNNING, "Previous mini-batch is closed " +
"successfully. New mini-batch is started.")
.setInitialState(State.MINI_BATCH_RUNNING)
.build();
}

@Override
Expand Down Expand Up @@ -125,7 +180,11 @@ public byte[] call(final byte[] memento) throws Exception {
// By starting epochs from the initial clock, which is dynamically fetched from driver,
// it prevents workers added by EM from starting from epoch 0 and deferring job completion.
// More specifically, added workers start from the minimum epoch index of other existing workers.
for (int epochIdx = initialClock; epochIdx < maxNumEpochs; ++epochIdx) {
for (int epochIdx = initialClock;; ++epochIdx) {
// If asynchronous system is chosen, learning is terminated when epochIdx == maxNumEpochs.
if (isAsync && epochIdx == maxNumEpochs) {
break;
}
LOG.log(Level.INFO, "Starting epoch {0}", epochIdx);
final long epochStartTime = System.currentTimeMillis();
final int numEMBlocks = memoryStore.getNumBlocks();
Expand All @@ -145,6 +204,10 @@ public byte[] call(final byte[] memento) throws Exception {
final MiniBatchResult miniBatchResult = trainer.runMiniBatch(miniBatchTrainingData);
final double miniBatchElapsedTime = (System.currentTimeMillis() - miniBatchStartTime) / 1000.0D;

stateMachine.setState(State.WAITING_NEXT_MINI_BATCH);
learningFlag = miniBatchBarrier.waitMiniBatchControlMsgFromDriver(epochIdx);
stateMachine.setState(State.MINI_BATCH_RUNNING);

buildAndSendMiniBatchMetrics(miniBatchResult, epochIdx, miniBatchIdx,
miniBatchTrainingData.size(), miniBatchElapsedTime);

Expand All @@ -157,6 +220,10 @@ public byte[] call(final byte[] memento) throws Exception {
workerClock.recordClockNetworkWaitingTime();
return null;
}

if (learningFlag == LearningState.TerminateLearning) {
break;
}
}

final EpochResult epochResult = trainer.onEpochFinished(epochTrainingData, testData, epochIdx);
Expand All @@ -167,6 +234,9 @@ public byte[] call(final byte[] memento) throws Exception {

// TODO #830: Clock should be a unit of mini-batch instead of epoch
workerClock.clock();
if (learningFlag == LearningState.TerminateLearning) {
break;
}
}

// Synchronize all workers before cleanup for workers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package edu.snu.cay.dolphin.async;

import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier;
import edu.snu.cay.services.ps.worker.api.ParameterWorker;

import javax.inject.Inject;
Expand All @@ -26,14 +27,18 @@
public class PSModelAccessor<K, P, V> implements ModelAccessor<K, P, V> {

private final ParameterWorker<K, P, V> parameterWorker;
private final PushBarrier pushBarrier;

@Inject
PSModelAccessor(final ParameterWorker<K, P, V> parameterWorker) {
PSModelAccessor(final ParameterWorker<K, P, V> parameterWorker,
final PushBarrier pushBarrier) {
this.parameterWorker = parameterWorker;
this.pushBarrier = pushBarrier;
}

/**
* Pushes {@code deltaValue} for {@code key} via the parameter worker,
* after first requesting push permission from the {@link PushBarrier}.
*/
@Override
public void push(final K key, final P deltaValue) {
// Consult the barrier before every push.
// NOTE(review): presumably this blocks until the driver permits the push in synchronous mode - confirm.
pushBarrier.requestPushPermission();
parameterWorker.push(key, deltaValue);
}

Expand Down
Loading