Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CAY-1251] Introduce worker-side model cache #1252

Merged
merged 24 commits into from
Nov 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
* Copyright (C) 2017 Seoul National University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.snu.cay.dolphin.async;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import edu.snu.cay.dolphin.async.core.worker.ModelAccessor;
import edu.snu.cay.dolphin.async.metric.Tracer;
import edu.snu.cay.services.et.evaluator.api.Table;
import edu.snu.cay.services.et.evaluator.api.TableAccessor;
import edu.snu.cay.services.et.evaluator.api.UpdateFunction;
import edu.snu.cay.services.et.exceptions.TableNotExistException;
import org.apache.reef.tang.annotations.Parameter;

import javax.inject.Inject;
import java.util.*;
import java.util.concurrent.*;

/**
* A {@link ModelAccessor} implementation with model cache.
*/
public final class CachedModelAccessor<K, P, V> implements ModelAccessor<K, P, V> {
private static final int MODEL_REFRESH_SEC = 10; // TODO #1254: introduce a sophisticated cache policy
private static final int CACHE_CONCURRENCY_WRITES = 4;

private final LoadingCache<K, V> modelLoadingCache;

private final Table<K, V, P> modelTable;
private final UpdateFunction<K, V, P> modelUpdateFunction;

private final ScheduledExecutorService refreshExecutor;

private final Tracer pushTracer = new Tracer();
private final Tracer pullTracer = new Tracer();

@Inject
private CachedModelAccessor(@Parameter(DolphinParameters.ModelTableId.class) final String modelTableId,
final TableAccessor tableAccessor,
final UpdateFunction<K, V, P> modelUpdateFunction) throws TableNotExistException {
this.modelTable = tableAccessor.getTable(modelTableId);
this.modelUpdateFunction = modelUpdateFunction;

this.modelLoadingCache = initCache();

refreshExecutor = Executors.newSingleThreadScheduledExecutor();
refreshExecutor.scheduleWithFixedDelay(() -> {
final Set<K> keys = modelLoadingCache.asMap().keySet();

if (!keys.isEmpty()) {
final List<K> keyList = new ArrayList<>(keys.size());
try {
pullTracer.startTimer();
final Map<K, V> kvMap = modelTable.multiGetOrInit(keyList).get();
pullTracer.recordTime(keys.size());

kvMap.forEach(modelLoadingCache::put);
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

}, 0, MODEL_REFRESH_SEC, TimeUnit.SECONDS);
}

public void stopRefreshingCache() {
refreshExecutor.shutdown();
}

private LoadingCache<K, V> initCache() {
return CacheBuilder.newBuilder()
.concurrencyLevel(CACHE_CONCURRENCY_WRITES)
.build(new CacheLoader<K, V>() {
@Override
public V load(final K key) throws Exception {
pullTracer.startTimer();
final Future<V> pullFuture = modelTable.getOrInit(key);
final V value = pullFuture.get();
pullTracer.recordTime(1);
return value;
}

@Override
public Map<K, V> loadAll(final Iterable<? extends K> keys) throws Exception {
final List<K> keyList = new ArrayList<>();
keys.forEach(keyList::add);

pullTracer.startTimer();
final Map<K, V> kvMap = modelTable.multiGetOrInit(keyList).get();
pullTracer.recordTime(kvMap.size());

return kvMap;
}
});
}

/**
* Push a delta value for a key, applying the change to cache.
*/
@Override
public void push(final K key, final P deltaValue) {
pushTracer.startTimer();
modelTable.updateNoReply(key, deltaValue);
pushTracer.recordTime(1);

// update value in cache. this modification will not cause entry loading in cache.
modelLoadingCache.asMap().
computeIfPresent(key, (k, oldValue) -> modelUpdateFunction.updateValue(k, oldValue, deltaValue));
}

@Override
public void push(final Map<K, P> keyToDeltaValueMap) {
pushTracer.startTimer();
modelTable.multiUpdateNoReply(keyToDeltaValueMap);
pushTracer.recordTime(keyToDeltaValueMap.size());

// update value in cache. this modification will not cause entry loading in cache.
keyToDeltaValueMap.forEach((key, deltaValue) -> modelLoadingCache.asMap().
computeIfPresent(key, (k, oldValue) -> modelUpdateFunction.updateValue(k, oldValue, deltaValue)));
}

/**
* Retrieve a value for a requested key.
* Pull value from servers, if cache does not have value for the key.
*/
@Override
public V pull(final K key) {
try {
return modelLoadingCache.get(key);
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

/**
* Retrieve values for requested keys.
* Pull values from servers, if cache does not have all values for the keys.
*/
@Override
public List<V> pull(final List<K> keys) {
try {
final Map<K, V> kvMap = modelLoadingCache.getAll(keys);
final List<V> valueList = new ArrayList<>(keys.size());
keys.forEach(key -> valueList.add(kvMap.get(key)));

return valueList;
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

/**
* This method does not care about cache.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be great if we clarify what are the differences between pull(List<K> keys, Table table) and pull(List<K> keys). The distinction is missing in the base interface, but we differentiate them in this implementation (with cache vs. without cache).

Copy link
Contributor Author

@wynot12 wynot12 Nov 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. actually this part is completely same with no-cache version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, then the question would be more about
why do we not care about cache in this method, while pull(final List<K> keys) gets the data from the cache?

This question raised my original question above: what are the differences between the two methods?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pull(List<K> keys, Table table) is for using other tables that has no caches.
This ModelAccessor implementation provides a cache only for a table in its field.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pull(List<K> keys, Table table) has been inserted to support offline model evaluation.
So this method may seem awkward in ModelAccessor interface.

*/
@Override
public List<V> pull(final List<K> keys, final Table<K, V, P> aModelTable) {
try {
final Map<K, V> result = aModelTable.multiGetOrInit(keys).get();

final List<V> valueList = new ArrayList<>(keys.size());
keys.forEach(key -> valueList.add(result.get(key)));

return valueList;

} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

@Override
public Map<String, Double> getAndResetMetrics() {
final Map<String, Double> metrics = new HashMap<>();
metrics.put(METRIC_TOTAL_PULL_TIME_SEC, pullTracer.totalElapsedTime());
metrics.put(METRIC_TOTAL_PUSH_TIME_SEC, pushTracer.totalElapsedTime());
metrics.put(METRIC_AVG_PULL_TIME_SEC, pullTracer.avgTimePerElem());
metrics.put(METRIC_AVG_PUSH_TIME_SEC, pushTracer.avgTimePerElem());

pullTracer.resetTrace();
pushTracer.resetTrace();
return metrics;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public final class NumServerBlocks implements Name<Integer> {
public final class HyperThreadEnabled implements Name<Boolean> {
}

@NamedParameter(doc = "Whether the model cache is enabled.",
short_name = "model_cache_enabled", default_value = "false")
public final class ModelCacheEnabled implements Name<Boolean> {
}

@NamedParameter(doc = "Desired memory size for each worker evaluator (MBs)",
short_name = "worker_mem_size",
default_value = "128")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
package edu.snu.cay.dolphin.async.core.client;
import edu.snu.cay.dolphin.async.CachedModelAccessor;
import edu.snu.cay.dolphin.async.DolphinParameters.*;
import edu.snu.cay.common.param.Parameters.*;
import edu.snu.cay.dolphin.async.core.driver.DolphinDriver;
Expand Down Expand Up @@ -128,14 +129,21 @@ public static LauncherStatus launch(final String jobName,
.bindNamedParameter(UpdateValueCodec.class, dolphinConf.getModelUpdateValueCodecClass())
.build());

final Injector workerParameterInjector = Tang.Factory.getTang().newInjector(workerParamConf);

final boolean modelCacheEnabled = workerParameterInjector.getNamedInstance(ModelCacheEnabled.class);
final Class<? extends ModelAccessor> modelAccessorClass =
modelCacheEnabled ? CachedModelAccessor.class : ETModelAccessor.class;

// worker conf
final Configuration workerConf = Configurations.merge(
workerParamConf, userParamConf,
Tang.Factory.getTang().newConfigurationBuilder()
.bindImplementation(Trainer.class, dolphinConf.getTrainerClass())
.bindImplementation(DataParser.class, dolphinConf.getInputParserClass())
.bindImplementation(TrainingDataProvider.class, ETTrainingDataProvider.class)
.bindImplementation(ModelAccessor.class, ETModelAccessor.class)
.bindImplementation(ModelAccessor.class, modelAccessorClass)
.bindImplementation(UpdateFunction.class, dolphinConf.getModelUpdateFunctionClass())
.bindNamedParameter(KeyCodec.class, dolphinConf.getInputKeyCodecClass())
.bindNamedParameter(ValueCodec.class, dolphinConf.getInputValueCodecClass())
.build());
Expand Down Expand Up @@ -230,7 +238,8 @@ private static List<Configuration> parseCommandLine(
final List<Class<? extends Name<?>>> serverParamList = Collections.emptyList();

final List<Class<? extends Name<?>>> workerParamList = Arrays.asList(
HyperThreadEnabled.class, MaxNumEpochs.class, NumTotalMiniBatches.class, TestDataPath.class);
HyperThreadEnabled.class, ModelCacheEnabled.class,
MaxNumEpochs.class, NumTotalMiniBatches.class, TestDataPath.class);

// commonly used parameters for ML apps
final List<Class<? extends Name<?>>> commonAppParamList = Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,7 @@ public List<V> pull(final List<K> keys) {
return resultValues;
}

/**
* Do {@link #pull(List)} with a given table.
* @param keys a list of keys of model parameter
* @param aModelTable a table to read value from
* @return a list of values associated with the given {@code keys}.
* Some positions in the list can be {@code null}, if the key has no associated value
*/
@Override
public List<V> pull(final List<K> keys, final Table<K, V, P> aModelTable) {
try {
final Map<K, V> result = aModelTable.multiGetOrInit(keys).get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package edu.snu.cay.dolphin.async.core.worker;

import edu.snu.cay.dolphin.async.CachedModelAccessor;
import edu.snu.cay.dolphin.async.DolphinParameters;
import edu.snu.cay.dolphin.async.metric.avro.*;
import edu.snu.cay.services.et.metric.MetricCollector;
Expand Down Expand Up @@ -139,6 +140,11 @@ public byte[] call(final byte[] memento) throws Exception {
workerGlobalBarrier.await();

trainer.cleanup();

if (modelAccessor instanceof CachedModelAccessor) {
((CachedModelAccessor) modelAccessor).stopRefreshingCache();
}

return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package edu.snu.cay.dolphin.async.core.worker;

import edu.snu.cay.services.et.evaluator.api.Table;

import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -75,6 +77,15 @@ public interface ModelAccessor<K, P, V> {
*/
List<V> pull(List<K> keys);

/**
* Do {@link #pull(List)} with a given table.
* @param keys a list of keys of model parameter
* @param aModelTable a table to read value from
* @return a list of values associated with the given {@code keys}.
* Some positions in the list can be {@code null}, if the key has no associated value
*/
List<V> pull(List<K> keys, Table<K, V, P> aModelTable);

/**
* Fetches the collected metrics and reset the tracers for collecting metrics in the next round.
* @return the metrics that are identified by the keys in this interface.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public Map<CharSequence, Double> evaluateModel(final Collection<Document> inputD
final edu.snu.cay.services.et.evaluator.api.Table modelTable) {

LOG.log(Level.INFO, "Pull model to compute log likelihood");
final List<int[]> wordTopicCounts = ((ETModelAccessor)modelAccessor).pull(vocabList, modelTable);
final List<int[]> wordTopicCounts = modelAccessor.pull(vocabList, modelTable);
final int[] wordTopicCountsSummary = wordTopicCounts.remove(numVocabs);

LOG.log(Level.INFO, "Start computing log likelihood");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import edu.snu.cay.common.math.linalg.Vector;
import edu.snu.cay.common.math.linalg.VectorFactory;
import edu.snu.cay.dolphin.async.*;
import edu.snu.cay.dolphin.async.core.worker.ETModelAccessor;
import edu.snu.cay.dolphin.async.core.worker.ModelAccessor;
import edu.snu.cay.dolphin.async.core.worker.ModelHolder;
import edu.snu.cay.dolphin.async.core.worker.Trainer;
Expand Down Expand Up @@ -266,7 +265,7 @@ public Map<CharSequence, Double> evaluateModel(final Collection<MLRData> inputDa
* Pull models one last time and perform validation.
*/
private MLRModel pullModelsToEvaluate(final List<Integer> keys, final Table<Integer, Vector, Vector> modelTable) {
final List<Vector> partitions = ((ETModelAccessor) modelAccessor).pull(keys, modelTable);
final List<Vector> partitions = modelAccessor.pull(keys, modelTable);

final MLRModel mlrModel = new MLRModel(new Vector[numClasses]);
final Vector[] params = mlrModel.getParams();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import edu.snu.cay.common.math.linalg.Vector;
import edu.snu.cay.common.math.linalg.VectorEntry;
import edu.snu.cay.common.math.linalg.VectorFactory;
import edu.snu.cay.dolphin.async.core.worker.ETModelAccessor;
import edu.snu.cay.dolphin.async.core.worker.ModelAccessor;
import edu.snu.cay.dolphin.async.core.worker.Trainer;
import edu.snu.cay.dolphin.async.core.worker.TrainingDataProvider;
Expand Down Expand Up @@ -203,7 +202,7 @@ public Map<CharSequence, Double> evaluateModel(final Collection<NMFData> inputDa

private NMFModel pullModelToEvaluate(final List<Integer> keys, final Table<Integer, Vector, Vector> modelTable) {
final Map<Integer, Vector> rMatrix = new HashMap<>(keys.size());
final List<Vector> vectors = ((ETModelAccessor) modelAccessor).pull(keys, modelTable);
final List<Vector> vectors = modelAccessor.pull(keys, modelTable);
for (int i = 0; i < keys.size(); ++i) {
rMatrix.put(keys.get(i), vectors.get(i));
}
Expand Down