diff --git a/common/src/main/java/edu/snu/cay/common/param/Parameters.java b/common/src/main/java/edu/snu/cay/common/param/Parameters.java index 9c537f388..36114c6a2 100644 --- a/common/src/main/java/edu/snu/cay/common/param/Parameters.java +++ b/common/src/main/java/edu/snu/cay/common/param/Parameters.java @@ -98,4 +98,10 @@ public final class HostToBandwidthFilePath implements Name { private HostToBandwidthFilePath() { } } + + @NamedParameter(doc = "Whether this parameter server works in synchronous way or asynchronous way.", + short_name = "synchronicity", + default_value = "async") + public final class Synchronicity implements Name { + } } diff --git a/dolphin/async/src/main/avro/syncmsg.avsc b/dolphin/async/src/main/avro/syncmsg.avsc new file mode 100644 index 000000000..7089d99dd --- /dev/null +++ b/dolphin/async/src/main/avro/syncmsg.avsc @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +[ +{ + // {@code SyncPushBarrier} sends this message to the driver to request {@code PushPermitMsg}. + "namespace": "edu.snu.cay.dolphin.async", + "type": "record", + "name": "RequestPushPermissionMsg", + "fields": + [ + // By using roundNum information, driver can check whether this message is synchronously right request message for + // this round. If it is wrong, driver will ignore this message. + {"name": "roundNum", "type": "int"} + ] +}, + +{ + // {@code SyncPushBarrier} sends this message to the driver to notify that its mini-batch is finished. + "namespace": "edu.snu.cay.dolphin.async", + "type": "record", + "name": "MiniBatchFinishedMsg", + "fields": + [ + // Using EpochIdx information, driver chooses which message to send(StartNextMiniBatchMsg or TerminateLearningMsg). + {"name": "epochIdx", "type": "int"} + ] +}, + +{ + // {@code BatchManager} sends this message to workers to permit push operation. + "namespace": "edu.snu.cay.dolphin.async", + "type": "record", + "name": "PermitPushMsg", + "fields": [] +}, + +{ + // {@code BatchManager} sends this message to workers to make them start next mini-batch. + "namespace": "edu.snu.cay.dolphin.async", + "type": "record", + "name": "StartNextMiniBatchMsg", + "fields": + [ + {"name": "nextRoundNum", "type": "int"} + ] +}, + +{ + // {@code BatchManager} sends this message to workers to terminate learning. + "namespace": "edu.snu.cay.dolphin.async", + "type": "record", + "name": "TerminateLearningMsg", + "fields": [] +}, + +{ + /** + * Messages that are exchanged between worker and driver for synchronous system. + * Specifically, {@code SyncPushBarrier}(worker-side component) and {@code BatchManager}(driver-side component) are + * classes that exchange these messages. + */ + "namespace": "edu.snu.cay.dolphin.async", + "type": "record", + "name": "AvroSyncSGDMsg", + "fields": + [ + {"name": "type", + "type": + {"type": "enum", + "name": "SyncSGDMsgType", + "symbols": ["RequestPushPermissionMsg", "MiniBatchFinishedMsg", "PermitPushMsg", + "StartNextMiniBatchMsg", "TerminateLearningMsg"]}}, + {"name": "requestPushPermissionMsg", "type": ["null", "RequestPushPermissionMsg"], "default": null}, + {"name": "miniBatchFinishedMsg", "type": ["null", "MiniBatchFinishedMsg"], "default": null}, + {"name": "permitPushMsg", "type": ["null", "PermitPushMsg"], "default": null}, + {"name": "startNextMiniBatchMsg", "type": ["null", "StartNextMiniBatchMsg"], "default": null}, + {"name": "terminateLearningMsg", "type": ["null", "TerminateLearningMsg"], "default": null} + ] +} +] diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncDolphinLauncher.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncDolphinLauncher.java index 3292ab2bf..401e796bd 100644 --- a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncDolphinLauncher.java +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncDolphinLauncher.java @@ -17,6 +17,11 @@ import edu.snu.cay.common.client.DriverLauncher; import edu.snu.cay.common.dataloader.TextInputFormat; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.BatchManager; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.DriverSideSyncSGDMsgHandler; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.*; import edu.snu.cay.dolphin.async.metric.*; import edu.snu.cay.dolphin.async.dashboard.DashboardConfProvider; import edu.snu.cay.dolphin.async.dashboard.DashboardLauncher; @@ -163,6 +168,9 @@ public static LauncherStatus launch(final String jobName, confSerializer.toString(asyncDolphinConfiguration.getServerConfiguration())) .build(); + final String synchronicity = basicParameterInjector.getNamedInstance(Synchronicity.class); + final boolean isAsync = synchronicity.equals("async"); + // worker-specific configurations // pass the worker class implementation as well as user-defined parameters final Configuration basicWorkerConf = Tang.Factory.getTang().newConfigurationBuilder() @@ -181,6 +189,8 @@ public static LauncherStatus launch(final String jobName, Integer.toString(basicParameterInjector.getNamedInstance(DolphinParameters.NumTrainerThreads.class))) .bindNamedParameter(DolphinParameters.TestDataPath.class, basicParameterInjector.getNamedInstance(DolphinParameters.TestDataPath.class)) + .bindImplementation(PushBarrier.class, isAsync ? NullPushBarrier.class : SyncPushBarrier.class) + .bindImplementation(MiniBatchBarrier.class, isAsync ? NullMiniBatchBarrier.class : SyncMiniBatchBarrier.class) .build(); final Configuration workerConf = Configurations.merge(basicWorkerConf, asyncDolphinConfiguration.getWorkerConfiguration()); @@ -212,7 +222,7 @@ public static LauncherStatus launch(final String jobName, final Configuration dashboardConf = DashboardConfProvider.getConfiguration(dashboardEnabled); // driver-side configurations - final Configuration driverConf = getDriverConfiguration(jobName, basicParameterInjector); + final Configuration driverConf = getDriverConfiguration(jobName, basicParameterInjector, isAsync); final int timeout = basicParameterInjector.getNamedInstance(Timeout.class); final LauncherStatus status = DriverLauncher.getLauncher(runTimeConf).run( @@ -287,6 +297,9 @@ private static Tuple2 parseCommandLine( basicParameterClassList.add(ServerMetricsWindowMs.class); basicParameterClassList.add(PSTraceProbability.class); + // add SyncSGD parameters + basicParameterClassList.add(Synchronicity.class); + // add SSP parameters basicParameterClassList.add(StalenessBound.class); @@ -364,7 +377,7 @@ private static Configuration getLocalRuntimeConfiguration(final int maxNumEvalLo } private static Configuration getDriverConfiguration( - final String jobName, final Injector injector) throws InjectionException { + final String jobName, final Injector injector, final boolean isAsync) throws InjectionException { final ConfigurationModule driverConf = DriverConfiguration.CONF .set(DriverConfiguration.GLOBAL_LIBRARIES, EnvironmentUtils.getClassLocation(AsyncDolphinDriver.class)) .set(DriverConfiguration.DRIVER_IDENTIFIER, jobName) @@ -391,7 +404,7 @@ private static Configuration getDriverConfiguration( final int stalenessBound = injector.getNamedInstance(StalenessBound.class); final boolean isSSPModel = stalenessBound >= 0; final CentCommConf centCommConf = isSSPModel ? - getCentCommConfForSSP() : getDefaultCentCommConf(); + getCentCommConfForSSP() : getDefaultCentCommConf(isAsync); // set up an optimizer configuration final Class optimizerClass; final Class executorClass; @@ -432,12 +445,23 @@ private static CentCommConf.Builder getCentCommConfDefaultBuilder() { EvalSideMetricsMsgHandlerForServer.class); } - private static CentCommConf getDefaultCentCommConf() { - return getCentCommConfDefaultBuilder() - .addCentCommClient(ClockManager.CENT_COMM_CLIENT_NAME, - ClockManager.MessageHandler.class, - AsyncWorkerClock.MessageHandler.class) - .build(); + private static CentCommConf getDefaultCentCommConf(final boolean isAsync) { + if (isAsync) { + return getCentCommConfDefaultBuilder() + .addCentCommClient(ClockManager.CENT_COMM_CLIENT_NAME, + ClockManager.MessageHandler.class, + AsyncWorkerClock.MessageHandler.class) + .build(); + } else { + return getCentCommConfDefaultBuilder() + .addCentCommClient(ClockManager.CENT_COMM_CLIENT_NAME, + ClockManager.MessageHandler.class, + AsyncWorkerClock.MessageHandler.class) + .addCentCommClient(BatchManager.CENT_COMM_CLIENT_NAME, + DriverSideSyncSGDMsgHandler.class, + WorkerSideSyncSGDMsgHandler.class) + .build(); + } } private static CentCommConf getCentCommConfForSSP() { diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncWorkerTask.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncWorkerTask.java index 3ad8b11bd..2be266de2 100644 --- a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncWorkerTask.java +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/AsyncWorkerTask.java @@ -17,12 +17,16 @@ import edu.snu.cay.common.metric.MetricsMsgSender; import edu.snu.cay.common.metric.avro.Metrics; +import edu.snu.cay.common.param.Parameters; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.LearningState; import edu.snu.cay.dolphin.async.metric.avro.WorkerMetrics; import edu.snu.cay.services.em.common.parameters.AddedEval; import edu.snu.cay.services.em.evaluator.api.MemoryStore; import edu.snu.cay.services.ps.worker.api.ParameterWorker; import edu.snu.cay.services.ps.worker.api.WorkerClock; import edu.snu.cay.utils.HostnameResolver; +import edu.snu.cay.utils.StateMachine; import org.apache.reef.driver.task.TaskConfigurationOptions.Identifier; import org.apache.reef.tang.annotations.Parameter; import org.apache.reef.task.Task; @@ -61,6 +65,35 @@ final class AsyncWorkerTask implements Task { private final boolean addedEval; private final String hostname; + /** + * Determine whether client wants synchronous or asynchronous workers. + * If true, client wants asynchronously working workers. + * If false, client wants synchronously working workers. + */ + private final boolean isAsync; + + /** + * MiniBatchBarrier helps all the workers to start next mini-batch synchronously in synchronous system. + */ + private final MiniBatchBarrier miniBatchBarrier; + + /** + * Three WorkerTask's states. These states are for synchronous worker. + */ + private final StateMachine stateMachine; + private enum State { + MINI_BATCH_RUNNING, + WAITING_NEXT_MINI_BATCH, + MINI_BATCH_CLOSING + } + + /** + * Flag which indicates learning state of this worker in synchronous system. + * When this worker receive {@code terminateLearningMsg} from driver, this flag will be changed to + * {@code LearningState.TerminateLearning}. + */ + private LearningState learningFlag = LearningState.ProgressLearning; + /** * A boolean flag shared among all trainer threads. * Trainer threads end when this flag becomes true by {@link #close()}. @@ -72,6 +105,7 @@ private AsyncWorkerTask(@Parameter(Identifier.class) final String taskId, @Parameter(DolphinParameters.MaxNumEpochs.class) final int maxNumEpochs, @Parameter(DolphinParameters.MiniBatchSize.class) final int miniBatchSize, @Parameter(AddedEval.class) final boolean addedEval, + @Parameter(Parameters.Synchronicity.class) final String synchronicity, final WorkerSynchronizer synchronizer, final ParameterWorker parameterWorker, final TrainingDataProvider trainingDataProvider, @@ -79,7 +113,8 @@ private AsyncWorkerTask(@Parameter(Identifier.class) final String taskId, final MemoryStore memoryStore, final Trainer trainer, final MetricsMsgSender metricsMsgSender, - final WorkerClock workerClock) { + final WorkerClock workerClock, + final MiniBatchBarrier miniBatchBarrier) { this.taskId = taskId; this.maxNumEpochs = maxNumEpochs; this.miniBatchSize = miniBatchSize; @@ -93,6 +128,26 @@ private AsyncWorkerTask(@Parameter(Identifier.class) final String taskId, this.metricsMsgSender = metricsMsgSender; this.workerClock = workerClock; this.hostname = HostnameResolver.resolve(); + this.miniBatchBarrier = miniBatchBarrier; + this.stateMachine = initStateMachine(); + this.isAsync = synchronicity.equals("async"); + } + + private StateMachine initStateMachine() { + return StateMachine.newBuilder() + .addState(State.MINI_BATCH_RUNNING, "Mini-batch is running now.") + .addState(State.WAITING_NEXT_MINI_BATCH, "Mini-batch is finished and waiting for StartNextMiniBatchMsg " + + "from driver") + .addState(State.MINI_BATCH_CLOSING, "This worker is slow worker. Mini-batch is stopped and being closed" + + " after it had received StartNextMiniBatchMsg from driver.") + .addTransition(State.MINI_BATCH_RUNNING, State.WAITING_NEXT_MINI_BATCH, "Mini-batch is finished.") + .addTransition(State.MINI_BATCH_RUNNING, State.MINI_BATCH_CLOSING, "This worker is slow worker. Mini-batch" + + " is not finished but should be closed for next mini-batch.") + .addTransition(State.WAITING_NEXT_MINI_BATCH, State.MINI_BATCH_RUNNING, "New mini-batch is started.") + .addTransition(State.MINI_BATCH_CLOSING, State.MINI_BATCH_RUNNING, "Previous mini-batch is closed " + + "successfully. New mini-batch is started.") + .setInitialState(State.MINI_BATCH_RUNNING) + .build(); } @Override @@ -125,7 +180,11 @@ public byte[] call(final byte[] memento) throws Exception { // By starting epochs from the initial clock, which is dynamically fetched from driver, // it prevents workers added by EM from starting from epoch 0 and deferring job completion. // More specifically, added workers start from the minimum epoch index of other existing workers. - for (int epochIdx = initialClock; epochIdx < maxNumEpochs; ++epochIdx) { + for (int epochIdx = initialClock;; ++epochIdx) { + // If asynchronous system is chosen, learning is terminated when epochIdx == maxNumEpochs. + if (isAsync && epochIdx == maxNumEpochs) { + break; + } LOG.log(Level.INFO, "Starting epoch {0}", epochIdx); final long epochStartTime = System.currentTimeMillis(); final int numEMBlocks = memoryStore.getNumBlocks(); @@ -145,6 +204,10 @@ public byte[] call(final byte[] memento) throws Exception { final MiniBatchResult miniBatchResult = trainer.runMiniBatch(miniBatchTrainingData); final double miniBatchElapsedTime = (System.currentTimeMillis() - miniBatchStartTime) / 1000.0D; + stateMachine.setState(State.WAITING_NEXT_MINI_BATCH); + learningFlag = miniBatchBarrier.waitMiniBatchControlMsgFromDriver(epochIdx); + stateMachine.setState(State.MINI_BATCH_RUNNING); + buildAndSendMiniBatchMetrics(miniBatchResult, epochIdx, miniBatchIdx, miniBatchTrainingData.size(), miniBatchElapsedTime); @@ -157,6 +220,10 @@ public byte[] call(final byte[] memento) throws Exception { workerClock.recordClockNetworkWaitingTime(); return null; } + + if (learningFlag == LearningState.TerminateLearning) { + break; + } } final EpochResult epochResult = trainer.onEpochFinished(epochTrainingData, testData, epochIdx); @@ -167,6 +234,9 @@ public byte[] call(final byte[] memento) throws Exception { // TODO #830: Clock should be a unit of mini-batch instead of epoch workerClock.clock(); + if (learningFlag == LearningState.TerminateLearning) { + break; + } } // Synchronize all workers before cleanup for workers diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/PSModelAccessor.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/PSModelAccessor.java index 58a2f2429..9c67ae677 100644 --- a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/PSModelAccessor.java +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/PSModelAccessor.java @@ -15,6 +15,7 @@ */ package edu.snu.cay.dolphin.async; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier; import edu.snu.cay.services.ps.worker.api.ParameterWorker; import javax.inject.Inject; @@ -26,14 +27,18 @@ public class PSModelAccessor implements ModelAccessor { private final ParameterWorker parameterWorker; + private final PushBarrier pushBarrier; @Inject - PSModelAccessor(final ParameterWorker parameterWorker) { + PSModelAccessor(final ParameterWorker parameterWorker, + final PushBarrier pushBarrier) { this.parameterWorker = parameterWorker; + this.pushBarrier = pushBarrier; } @Override public void push(final K key, final P deltaValue) { + pushBarrier.requestPushPermission(); parameterWorker.push(key, deltaValue); } diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/ResettableCountDownLatch.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/ResettableCountDownLatch.java new file mode 100644 index 000000000..03d0a5663 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/ResettableCountDownLatch.java @@ -0,0 +1,195 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.AbstractQueuedSynchronizer; + +public class ResettableCountDownLatch { + /** + * Synchronization control For ResettableCountDownLatch. + * Uses AQS state to represent count. + */ + private static final class Sync extends AbstractQueuedSynchronizer { + private static final long serialVersionUID = 4982264981922014374L; + + Sync(final int count) { + setState(count); + } + + int getCount() { + return getState(); + } + + protected int tryAcquireShared(final int acquires) { + return (getState() == 0) ? 1 : -1; + } + + protected boolean tryReleaseShared(final int releases) { + // Decrement count; signal when transition to zero + for (;;) { + final int c = getState(); + if (c == 0) { + return false; + } + final int nextc = c - 1; + if (compareAndSetState(c, nextc)) { + return nextc == 0; + } + } + } + + public void reset(final int count) { + setState(count); + } + } + + private final Sync sync; + + /** + * Constructs a {@code ResettableCountDownLatch} initialized with the given count. + * + * @param count the number of times {@link #countDown} must be invoked + * before threads can pass through {@link #await} + * @throws IllegalArgumentException if {@code count} is negative + */ + public ResettableCountDownLatch(final int count) { + if (count < 0) { + throw new IllegalArgumentException("count < 0"); + } + this.sync = new Sync(count); + } + + /** + * Causes the current thread to wait until the latch has counted down to + * zero, unless the thread is {@linkplain Thread#interrupt interrupted}. + * + *

If the current count is zero then this method returns immediately. + * + *

If the current count is greater than zero then the current + * thread becomes disabled for thread scheduling purposes and lies + * dormant until one of two things happen: + *

    + *
  • The count reaches zero due to invocations of the + * {@link #countDown} method; or + *
  • Some other thread {@linkplain Thread#interrupt interrupts} + * the current thread. + *
+ * + *

If the current thread: + *

    + *
  • has its interrupted status set on entry to this method; or + *
  • is {@linkplain Thread#interrupt interrupted} while waiting, + *
+ * then {@link InterruptedException} is thrown and the current thread's + * interrupted status is cleared. + * + * @throws InterruptedException if the current thread is interrupted + * while waiting + */ + public void await() throws InterruptedException { + sync.acquireSharedInterruptibly(1); + } + + /** + * Causes the current thread to wait until the latch has counted down to + * zero, unless the thread is {@linkplain Thread#interrupt interrupted}, + * or the specified waiting time elapses. + * + *

If the current count is zero then this method returns immediately + * with the value {@code true}. + * + *

If the current count is greater than zero then the current + * thread becomes disabled for thread scheduling purposes and lies + * dormant until one of three things happen: + *

    + *
  • The count reaches zero due to invocations of the + * {@link #countDown} method; or + *
  • Some other thread {@linkplain Thread#interrupt interrupts} + * the current thread; or + *
  • The specified waiting time elapses. + *
+ * + *

If the count reaches zero then the method returns with the + * value {@code true}. + * + *

If the current thread: + *

    + *
  • has its interrupted status set on entry to this method; or + *
  • is {@linkplain Thread#interrupt interrupted} while waiting, + *
+ * then {@link InterruptedException} is thrown and the current thread's + * interrupted status is cleared. + * + *

If the specified waiting time elapses then the value {@code false} + * is returned. If the time is less than or equal to zero, the method + * will not wait at all. + * + * @param timeout the maximum time to wait + * @param unit the time unit of the {@code timeout} argument + * @return {@code true} if the count reached zero and {@code false} + * if the waiting time elapsed before the count reached zero + * @throws InterruptedException if the current thread is interrupted + * while waiting + */ + public boolean await(final long timeout, final TimeUnit unit) + throws InterruptedException { + return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout)); + } + + /** + * Decrements the count of the latch, releasing all waiting threads if + * the count reaches zero. + * + *

If the current count is greater than zero then it is decremented. + * If the new count is zero then all waiting threads are re-enabled for + * thread scheduling purposes. + * + *

If the current count equals zero then nothing happens. + */ + public void countDown() { + sync.releaseShared(1); + } + + /** + * Returns the current count. + * + *

This method is typically used for debugging and testing purposes. + * + * @return the current count + */ + public long getCount() { + return sync.getCount(); + } + + /** + * Returns a string identifying this latch, as well as its state. + * The state, in brackets, includes the String {@code "Count ="} + * followed by the current count. + * + * @return a string identifying this latch, as well as its state + */ + public String toString() { + return super.toString() + "[Count = " + sync.getCount() + "]"; + } + + /** + * @param count reset count of latch with this value. + */ + public void reset(final int count) { + sync.reset(count); + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/BatchManager.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/BatchManager.java new file mode 100644 index 000000000..2b294caaa --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/BatchManager.java @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide; + +/** + * TODO #940: Should be implemented to manage workers' mini-batch. + */ +public final class BatchManager { + public static final String CENT_COMM_CLIENT_NAME = BatchManager.class.getName(); + + private BatchManager() { + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/DriverSideSyncSGDMsgHandler.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/DriverSideSyncSGDMsgHandler.java new file mode 100644 index 000000000..d18e3f576 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/DriverSideSyncSGDMsgHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide; + +import edu.snu.cay.common.centcomm.avro.CentCommMsg; +import edu.snu.cay.dolphin.async.AvroSyncSGDMsg; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDMsgCodec; +import org.apache.reef.wake.EventHandler; + +import javax.inject.Inject; + +/** + * TODO #940: This handler should be implemented for {@link BatchManager}. + * Handles messages related to SyncSGD from workers. + */ +public final class DriverSideSyncSGDMsgHandler implements EventHandler { + public static final String AGGREGATION_CLIENT_NAME = DriverSideSyncSGDMsgHandler.class.getName(); + private SyncSGDMsgCodec codec; + + @Inject + private DriverSideSyncSGDMsgHandler(final SyncSGDMsgCodec syncSGDMsgCodec) { + this.codec = syncSGDMsgCodec; + } + + @Override + public void onNext(final CentCommMsg centCommMsg) { + final AvroSyncSGDMsg rcvMsg = codec.decode(centCommMsg.getData().array()); + final String workerId = centCommMsg.getSourceId().toString(); + switch (rcvMsg.getType()) { + case RequestPushPermissionMsg: + break; + case MiniBatchFinishedMsg: + break; + default: + throw new RuntimeException("Unexpected message type: " + rcvMsg.getType().toString()); + } + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/package-info.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/package-info.java new file mode 100644 index 000000000..522583db2 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDDriverSide/package-info.java @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Components for driver-side SyncSGD. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide; diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDMsgCodec.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDMsgCodec.java new file mode 100644 index 000000000..4cec69ecb --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDMsgCodec.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD; + +import edu.snu.cay.dolphin.async.AvroSyncSGDMsg; +import edu.snu.cay.utils.AvroUtils; +import org.apache.reef.wake.remote.Codec; + +import javax.inject.Inject; + +/** + * Codec for {@link AvroSyncSGDMsg}. + */ +public final class SyncSGDMsgCodec implements Codec { + @Inject + private SyncSGDMsgCodec() { + } + + @Override + public byte[] encode(final AvroSyncSGDMsg msg) { + return AvroUtils.toBytes(msg, AvroSyncSGDMsg.class); + } + + @Override + public AvroSyncSGDMsg decode(final byte[] data) { + return AvroUtils.fromBytes(data, AvroSyncSGDMsg.class); + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/MiniBatchBarrier.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/MiniBatchBarrier.java new file mode 100644 index 000000000..f44d3730a --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/MiniBatchBarrier.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api; + + +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.WorkerSideSyncSGDMsgHandler; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.LearningState; + +/** + * Before AsyncWorkerTask starts next mini-batch, AsyncWorkerTask asks to {@code MiniBatchBarrier} whether to start next + * mini-batch. + */ +public interface MiniBatchBarrier { + + /** + * WorkerTask will wait in this method until this worker receives MiniBatchControlMsg from driver. + * There are two kinds of MiniBatchControlMsg : TerminateLearningMsg, StartNextMiniBatchMsg. + * @param epochIdx driver decides whether to progress learning or terminate learning by using this value. + * @return If this worker receives TerminateLearningMsg from driver, this method returns + * {@code LearningState.TerminateLearning}. + * If this worker receives StartNextMiniBatchMsg from driver, this method returns + * {@code LearningState.ProgressLearning}. + */ + LearningState waitMiniBatchControlMsgFromDriver(int epochIdx); + + /** + * {@link WorkerSideSyncSGDMsgHandler} will call this method when this worker receives + * {@code StartNextMiniBatchMsg} from driver. + */ + void startNextMiniBatch(); + + /** + * {@link WorkerSideSyncSGDMsgHandler} will call this method when this worker receives + * {@code TerminateLearningMsg} from driver. + */ + void terminateLearning(); +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/PushBarrier.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/PushBarrier.java new file mode 100644 index 000000000..5224038d2 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/PushBarrier.java @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api; + +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.BatchManager; + +/** + * Before {@code ModelAccessor} sends values to the server through push operation, this barrier asks + * {@link BatchManager} whether it would be okay to send the values. + */ +public interface PushBarrier { + + /** + * Request permission for push to the driver. + * This method sends {@code RequestPushPermissionMsg} to the driver and waits until it receives {@code PermitPushMsg}. + * If this worker is slow worker, this method waits until it receives {@code StartNextMiniBatchMsg}. + */ + void requestPushPermission(); + + /** + * Count down pushLatch. + */ + void countDownPushLatch(); + + /** + * When this worker receives {@code startNextMiniBatchMsg} from driver, PushBarrier prepares for the next mini-batch. + * There are two things to prepare for the next mini-batch: + * 1) Update {@code thisRoundNum} value with {@param nextRoundNum}. + * 2) reset {@code pushLatch}. + * @param nextRoundNum driver notify same nextRoundNum integer value to all the workers + */ + void prepareNextMiniBatch(int nextRoundNum); +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/package-info.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/package-info.java new file mode 100644 index 000000000..cf6972a77 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/api/package-info.java @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Components for worker-side SyncSGD. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api; diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/LearningState.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/LearningState.java new file mode 100644 index 000000000..e6f8982fc --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/LearningState.java @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + +/** + * This state indicates the learning state of WorkerTask. + */ +public enum LearningState { + TerminateLearning, ProgressLearning +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/NullMiniBatchBarrier.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/NullMiniBatchBarrier.java new file mode 100644 index 000000000..6011ece81 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/NullMiniBatchBarrier.java @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + + +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier; + +import javax.inject.Inject; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * This implementation is for an asynchronous system. + * {@link MiniBatchBarrier} that always permit push without asking the driver. + */ +public final class NullMiniBatchBarrier implements MiniBatchBarrier { + private static final Logger LOG = Logger.getLogger(NullMiniBatchBarrier.class.getName()); + + @Inject + private NullMiniBatchBarrier() { + } + + /** + * This method does nothing and always returns {@code LearningState.ProgressLearning} because this class is for + * asynchronous system. + */ + @Override + public LearningState waitMiniBatchControlMsgFromDriver(final int epochIdx) { + LOG.log(Level.INFO, "This is NullMiniBatchBarrier"); + return LearningState.ProgressLearning; + } + + @Override + public void startNextMiniBatch() { + } + + @Override + public void terminateLearning() { + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/NullPushBarrier.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/NullPushBarrier.java new file mode 100644 index 000000000..a1361fb27 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/NullPushBarrier.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + + +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier; + +import javax.inject.Inject; + +/** + * This implementation is for an asynchronous system. + */ +public final class NullPushBarrier implements PushBarrier { + @Inject + private NullPushBarrier() { + } + + @Override + public void requestPushPermission() { + } + + @Override + public void countDownPushLatch() { + } + + @Override + public void prepareNextMiniBatch(final int nextRoundNum) { + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/SyncMiniBatchBarrier.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/SyncMiniBatchBarrier.java new file mode 100644 index 000000000..a531916ce --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/SyncMiniBatchBarrier.java @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + + +import edu.snu.cay.dolphin.async.SyncSGD.ResettableCountDownLatch; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier; + +import javax.inject.Inject; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * {@link MiniBatchBarrier} that is implemented for synchronous system. + * This worker will be blocked in this barrier until it receives {@code StartNextMiniBatchMsg} or + * {@code TerminateLearningMsg} from driver. + */ +public final class SyncMiniBatchBarrier implements MiniBatchBarrier { + private static final Logger LOG = Logger.getLogger(SyncMiniBatchBarrier.class.getName()); + private final ResettableCountDownLatch miniBatchLatch; + private LearningState learningState = LearningState.ProgressLearning; + private final WorkerSideSyncSGDMsgSender msgSender; + + @Inject + private SyncMiniBatchBarrier(final WorkerSideSyncSGDMsgSender workerSideSyncSGDMsgSender) { + this.miniBatchLatch = new ResettableCountDownLatch(1); + this.msgSender = workerSideSyncSGDMsgSender; + } + + /** + * When this worker receives MiniBatchControlMsg from driver, {@link WorkerSideSyncSGDMsgHandler} will count down + * {@code miniBatchLatch}. + * @param epochIdx driver decides whether to progress learning or terminate learning by using this value. + * @return learning state decided by driver. + */ + @Override + public LearningState waitMiniBatchControlMsgFromDriver(final int epochIdx) { + try { + LOG.log(Level.INFO, "Mini-batch is finished. Waiting for MiniBatchControlMsg."); + msgSender.sendMiniBatchFinishedMsg(epochIdx); + miniBatchLatch.await(); + miniBatchLatch.reset(1); + } catch (InterruptedException e) { + throw new RuntimeException("Unexpected exception in SyncMiniBatchBarrier" + e); + } + return learningState; + } + + @Override + public void startNextMiniBatch() { + miniBatchLatch.countDown(); + } + + @Override + public void terminateLearning() { + learningState = LearningState.TerminateLearning; + miniBatchLatch.countDown(); + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/SyncPushBarrier.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/SyncPushBarrier.java new file mode 100644 index 000000000..dbb8fb8fa --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/SyncPushBarrier.java @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + + + +import edu.snu.cay.dolphin.async.ModelAccessor; +import edu.snu.cay.dolphin.async.SyncSGD.ResettableCountDownLatch; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier; + +import javax.inject.Inject; + +/** + * {@link PushBarrier} that is implemented for synchronous system. + * {@link ModelAccessor} will be blocked in this barrier until it receives {@code PermitPushMsg} from driver. + */ +public final class SyncPushBarrier implements PushBarrier { + private final ResettableCountDownLatch pushLatch; + private final WorkerSideSyncSGDMsgSender msgSender; + + // thisRoundNum should be tracked to distinguish between up-to-date RequestPushPermissionMsg and deprecated + // RequestPushPermissionMsg. + private int thisRoundNum = 0; + + @Inject + private SyncPushBarrier(final WorkerSideSyncSGDMsgSender msgSender) { + this.pushLatch = new ResettableCountDownLatch(1); + this.msgSender = msgSender; + } + + /** + * Send {@code RequestPushPermissionMsg} to driver and wait until this worker receives {@code PermitPushMsg}. + */ + @Override + public void requestPushPermission() { + try { + if (pushLatch.getCount() != 0) { + msgSender.sendRequestPushPermissionMsg(thisRoundNum); + pushLatch.await(); + } + } catch (InterruptedException e) { + throw new RuntimeException("Unexpected exception in SyncPushBarrier's requestPushPermission", e); + } + } + + /** + * Update thisRoundNum with up-to-date value, {@code nextRoundNum}, and reset {@code pushLatch} for next mini-batch. + * @param nextRoundNum driver notify same nextRoundNum integer value to all the workers. + */ + @Override + public void prepareNextMiniBatch(final int nextRoundNum) { + thisRoundNum = nextRoundNum; + pushLatch.reset(1); + } + + @Override + public void countDownPushLatch() { + pushLatch.countDown(); + } + + /** + * Only for SyncPushBarrierTest. + * @return pushLatch's count + */ + public long getLatchCount() { + return pushLatch.getCount(); + } + + /** + * Only for SyncPushBarrierTest. + * @return thisRoundNum + */ + public int getThisRoundNum() { + return thisRoundNum; + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/WorkerSideSyncSGDMsgHandler.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/WorkerSideSyncSGDMsgHandler.java new file mode 100644 index 000000000..9b8d11c08 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/WorkerSideSyncSGDMsgHandler.java @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + +import edu.snu.cay.common.centcomm.avro.CentCommMsg; +import edu.snu.cay.dolphin.async.AvroSyncSGDMsg; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDMsgCodec; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.MiniBatchBarrier; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.api.PushBarrier; +import org.apache.reef.wake.EventHandler; + +import javax.inject.Inject; + +/** + * Handles events for {@link PushBarrier} and {@link MiniBatchBarrier}. + */ +public final class WorkerSideSyncSGDMsgHandler implements EventHandler { + public static final String AGGREGATION_CLIENT_NAME = WorkerSideSyncSGDMsgHandler.class.getName(); + private final SyncPushBarrier syncPushBarrier; + private final SyncSGDMsgCodec codec; + private final SyncMiniBatchBarrier syncMiniBatchBarrier; + + @Inject + private WorkerSideSyncSGDMsgHandler(final SyncPushBarrier syncPushBarrier, + final SyncSGDMsgCodec syncSGDMsgCodec, + final SyncMiniBatchBarrier syncMiniBatchBarrier) { + this.syncPushBarrier = syncPushBarrier; + this.codec = syncSGDMsgCodec; + this.syncMiniBatchBarrier = syncMiniBatchBarrier; + } + + /** + * Handles three types of messages. + * 1) PermitPushMsg + * When driver permits this worker's push operation, count down {@code pushLatch} of {@link SyncPushBarrier}. + * 2) StartNextMiniBatchMsg + * To start next mini-batch, update {@code thisRoundNum} and reset {@code pushLatch} of {@link SyncPushBarrier}. + * Then, count down miniBatchLatch in syncMiniBatchBarrier which allows starting next mini-batch. + * 3) TerminateLearningMsg + * Change the {@code learningState} value in {@link SyncMiniBatchBarrier} to {@code TerminateLearning}. + * @param centCommMsg received message from driver. + */ + @Override + public void onNext(final CentCommMsg centCommMsg) { + final AvroSyncSGDMsg avroSyncSGDMsg = codec.decode(centCommMsg.getData().array()); + switch (avroSyncSGDMsg.getType()) { + case PermitPushMsg: + syncPushBarrier.countDownPushLatch(); + break; + case StartNextMiniBatchMsg: + final int nextRoundNum = avroSyncSGDMsg.getStartNextMiniBatchMsg().getNextRoundNum(); + // Update thisRoundNum value(in syncPushBarrier) and reset pushLatch(also in syncPushBarrier). + syncPushBarrier.prepareNextMiniBatch(nextRoundNum); + syncMiniBatchBarrier.startNextMiniBatch(); + break; + case TerminateLearningMsg: + syncMiniBatchBarrier.terminateLearning(); + break; + default: + throw new RuntimeException("Unexpected message type: " + avroSyncSGDMsg.getType().toString()); + } + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/WorkerSideSyncSGDMsgSender.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/WorkerSideSyncSGDMsgSender.java new file mode 100644 index 000000000..c8a3618b2 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/WorkerSideSyncSGDMsgSender.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; + +import edu.snu.cay.common.centcomm.slave.SlaveSideCentCommMsgSender; +import edu.snu.cay.dolphin.async.AvroSyncSGDMsg; +import edu.snu.cay.dolphin.async.MiniBatchFinishedMsg; +import edu.snu.cay.dolphin.async.RequestPushPermissionMsg; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.DriverSideSyncSGDMsgHandler; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDMsgCodec; +import edu.snu.cay.dolphin.async.SyncSGDMsgType; + +import javax.inject.Inject; + +/** + * Message sender for SyncSGD. + */ +final class WorkerSideSyncSGDMsgSender { + private final SlaveSideCentCommMsgSender slaveSideCentCommMsgSender; + private final SyncSGDMsgCodec codec; + + @Inject + private WorkerSideSyncSGDMsgSender(final SlaveSideCentCommMsgSender slaveSideCentCommMsgSender, + final SyncSGDMsgCodec syncSGDMsgCodec) { + this.slaveSideCentCommMsgSender = slaveSideCentCommMsgSender; + this.codec = syncSGDMsgCodec; + } + + /** + * Send {@code RequestPushPermissionMsg} to driver. + * @param thisRoundNum driver distinguishes up-to-date msg and deprecated msg by using this value. + */ + void sendRequestPushPermissionMsg(final int thisRoundNum) { + final RequestPushPermissionMsg requestPushPermissionMsg = RequestPushPermissionMsg.newBuilder() + .setRoundNum(thisRoundNum) + .build(); + final AvroSyncSGDMsg avroSyncSGDMsg = AvroSyncSGDMsg.newBuilder() + .setType(SyncSGDMsgType.RequestPushPermissionMsg) + .setRequestPushPermissionMsg(requestPushPermissionMsg) + .build(); + final byte[] data = codec.encode(avroSyncSGDMsg); + slaveSideCentCommMsgSender.send(DriverSideSyncSGDMsgHandler.AGGREGATION_CLIENT_NAME, data); + } + + /** + * Send {@code MiniBatchFinishedMsg} to driver. + * @param epochIdx driver decides whether to progress learning or terminate learning. + */ + void sendMiniBatchFinishedMsg(final int epochIdx) { + final MiniBatchFinishedMsg miniBatchFinishedMsg = MiniBatchFinishedMsg.newBuilder() + .setEpochIdx(epochIdx) + .build(); + final AvroSyncSGDMsg avroSyncSGDMsg = AvroSyncSGDMsg.newBuilder() + .setType(SyncSGDMsgType.MiniBatchFinishedMsg) + .setMiniBatchFinishedMsg(miniBatchFinishedMsg) + .build(); + final byte[] data = codec.encode(avroSyncSGDMsg); + slaveSideCentCommMsgSender.send(DriverSideSyncSGDMsgHandler.AGGREGATION_CLIENT_NAME, data); + } +} diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/package-info.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/package-info.java new file mode 100644 index 000000000..749fa4826 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/SyncSGDWorkerSide/impl/package-info.java @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Components for worker-side SyncSGD. + */ +package edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl; diff --git a/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/package-info.java b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/package-info.java new file mode 100644 index 000000000..8e083d484 --- /dev/null +++ b/dolphin/async/src/main/java/edu/snu/cay/dolphin/async/SyncSGD/package-info.java @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Components for SyncSGD. + */ +package edu.snu.cay.dolphin.async.SyncSGD; diff --git a/dolphin/async/src/test/java/edu/snu/cay/dolphin/async/SyncPushBarrierTest.java b/dolphin/async/src/test/java/edu/snu/cay/dolphin/async/SyncPushBarrierTest.java new file mode 100644 index 000000000..c62a091b2 --- /dev/null +++ b/dolphin/async/src/test/java/edu/snu/cay/dolphin/async/SyncPushBarrierTest.java @@ -0,0 +1,120 @@ +/* + * Copyright (C) 2017 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package edu.snu.cay.dolphin.async; + +import edu.snu.cay.common.centcomm.avro.CentCommMsg; +import edu.snu.cay.common.centcomm.master.MasterSideCentCommMsgSender; +import edu.snu.cay.common.centcomm.slave.SlaveSideCentCommMsgSender; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDDriverSide.BatchManager; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDMsgCodec; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.WorkerSideSyncSGDMsgHandler; +import edu.snu.cay.dolphin.async.SyncSGD.SyncSGDWorkerSide.impl.SyncPushBarrier; +import org.apache.reef.exception.evaluator.NetworkException; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Injector; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.exceptions.InjectionException; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.nio.ByteBuffer; + +import static org.mockito.Matchers.anyObject; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link SyncPushBarrier}. + * It tests whether {@link SyncPushBarrier} works properly regarding to messages from driver. + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({MasterSideCentCommMsgSender.class, SlaveSideCentCommMsgSender.class}) +public final class SyncPushBarrierTest { + private MasterSideCentCommMsgSender masterSideCentCommMsgSender; + private SyncPushBarrier syncPushBarrier; + private WorkerSideSyncSGDMsgHandler workerSideSyncSGDMsgHandler; + private SyncSGDMsgCodec codec; + + @Before + public void setup() throws InjectionException, NetworkException { + final Configuration conf = Tang.Factory.getTang().newConfigurationBuilder() + .build(); + final Injector injector = Tang.Factory.getTang().newInjector(conf); + final SlaveSideCentCommMsgSender slaveSideCentCommMsgSender = mock(SlaveSideCentCommMsgSender.class); + injector.bindVolatileInstance(SlaveSideCentCommMsgSender.class, slaveSideCentCommMsgSender); + masterSideCentCommMsgSender = mock(MasterSideCentCommMsgSender.class); + injector.bindVolatileInstance(MasterSideCentCommMsgSender.class, masterSideCentCommMsgSender); + + this.syncPushBarrier = injector.getInstance(SyncPushBarrier.class); + this.workerSideSyncSGDMsgHandler = injector.getInstance(WorkerSideSyncSGDMsgHandler.class); + this.codec = injector.getInstance(SyncSGDMsgCodec.class); + + doAnswer(invocation -> { + final byte[] data = invocation.getArgumentAt(2, byte[].class); + workerSideSyncSGDMsgHandler.onNext(getTestAggregationMessage("driver", data)); + return null; + }).when(masterSideCentCommMsgSender).send(anyString(), anyString(), anyObject()); + } + + /** + * Test {@link WorkerSideSyncSGDMsgHandler} handles {@code PermitPushMsg} from driver properly. + * When worker receives {@code PermitPushMsg}, {@link SyncPushBarrier} should be unblocked. + */ + @Test + public void testPermitPush() throws InjectionException, NetworkException { + assert syncPushBarrier.getLatchCount() == 1; + final PermitPushMsg permitPushMsg = PermitPushMsg.newBuilder().build(); + final AvroSyncSGDMsg avroSyncSGDMsg = AvroSyncSGDMsg.newBuilder() + .setType(SyncSGDMsgType.PermitPushMsg) + .setPermitPushMsg(permitPushMsg) + .build(); + final byte[] data = codec.encode(avroSyncSGDMsg); + masterSideCentCommMsgSender.send("PushBarrierProtocol", "worker", data); + assert syncPushBarrier.getLatchCount() == 0; + } + + /** + * Test {@link WorkerSideSyncSGDMsgHandler} handles {@code StartNextMiniBatchMsg} from driver properly. + * When worker receives {@code StartNextMiniBatchMsg}, {@code thisRoundNum} value in {@link SyncPushBarrier} should + * be updated. + */ + @Test + public void testStartNextMiniBatch() throws InjectionException, NetworkException { + assert syncPushBarrier.getThisRoundNum() == 0; + final StartNextMiniBatchMsg startNextMiniBatchMsg = StartNextMiniBatchMsg.newBuilder() + .setNextRoundNum(1) + .build(); + final AvroSyncSGDMsg avroSyncSGDMsg = AvroSyncSGDMsg.newBuilder() + .setType(SyncSGDMsgType.StartNextMiniBatchMsg) + .setStartNextMiniBatchMsg(startNextMiniBatchMsg) + .build(); + final byte[] data = codec.encode(avroSyncSGDMsg); + masterSideCentCommMsgSender.send("PushBarrierProtocol", "worker", data); + assert syncPushBarrier.getThisRoundNum() == 1; + } + + private CentCommMsg getTestAggregationMessage(final String driverId, final byte[] data) { + return CentCommMsg.newBuilder() + .setSourceId(driverId) + .setClientClassName(BatchManager.CENT_COMM_CLIENT_NAME) + .setData(ByteBuffer.wrap(data)) + .build(); + } +}