[CAY-1251] Introduce worker-side model cache #1252

Merged
merged 24 commits into from
Nov 29, 2017
Changes from 10 commits
@@ -0,0 +1,195 @@
/*
* Copyright (C) 2017 Seoul National University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.snu.cay.dolphin.async;

import edu.snu.cay.dolphin.async.metric.Tracer;
import edu.snu.cay.services.et.evaluator.api.Table;
import edu.snu.cay.services.et.evaluator.api.TableAccessor;
import edu.snu.cay.services.et.evaluator.api.UpdateFunction;
import edu.snu.cay.services.et.exceptions.TableNotExistException;
import org.apache.reef.tang.annotations.Parameter;

import javax.inject.Inject;
import java.util.*;
import java.util.concurrent.*;

/**
* A {@link ModelAccessor} implementation with model cache.
*/
public final class CachedModelAccessor<K, P, V> implements ModelAccessor<K, P, V> {

private final Map<K, V> cache = new ConcurrentHashMap<>();

private final Table<K, V, P> modelTable;
private final UpdateFunction<K, V, P> modelUpdateFunction;

private final Tracer pushTracer = new Tracer();
private final Tracer pullTracer = new Tracer();

@Inject
private CachedModelAccessor(@Parameter(DolphinParameters.ModelTableId.class) final String modelTableId,
final TableAccessor tableAccessor,
final UpdateFunction<K, V, P> modelUpdateFunction) throws TableNotExistException {
this.modelTable = tableAccessor.getTable(modelTableId);
this.modelUpdateFunction = modelUpdateFunction;

// TODO #00: introduce a sophisticated cache refresh/eviction policy
Contributor

Please update the issue number to #1254

Executors.newSingleThreadScheduledExecutor()
.scheduleWithFixedDelay(this::refreshCache, 10, 10, TimeUnit.SECONDS);
}

/**
* Push a delta value for a key, applying the change to cache.
*/
@Override
public void push(final K key, final P deltaValue) {
pushTracer.startTimer();
modelTable.updateNoReply(key, deltaValue);
pushTracer.recordTime(1);

// update local cache. oldValue always exists
Contributor

I'm a bit confused by the phrase "oldValue always exists": does it imply that push() always occurs after the parameters have been loaded via pull()? If so, could you please add a comment explaining why oldValue is guaranteed to exist?

cache.compute(key, (k, oldValue) -> modelUpdateFunction.updateValue(k, oldValue, deltaValue));
}
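
A note on the review question above about `oldValue`: `ConcurrentHashMap.compute` passes `null` as the old value when the key has no mapping, so the guarantee holds only if every pushed key has been pulled first. The following is a minimal sketch, not the PR's code, of one way to make the cache update tolerate an uncached key by using `computeIfPresent`. `CacheUpdateSketch`, `applyDelta`, and `pushToCache` are hypothetical names, and `applyDelta` is an assumed stand-in for `UpdateFunction.updateValue`.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class CacheUpdateSketch {

  // Hypothetical stand-in for UpdateFunction#updateValue: adds the delta to the old value.
  // Like many update functions, it fails with a NullPointerException if oldValue is null.
  static Integer applyDelta(final Integer oldValue, final Integer delta) {
    return oldValue + delta;
  }

  // Applies the delta only when the key is already cached.
  // For an uncached key the remapping function is never invoked and no entry is created,
  // so the next pull() for that key would still go to the server.
  static void pushToCache(final Map<String, Integer> cache, final String key, final Integer delta) {
    cache.computeIfPresent(key, (k, oldValue) -> applyDelta(oldValue, delta));
  }

  public static void main(final String[] args) {
    final Map<String, Integer> cache = new ConcurrentHashMap<>();
    cache.put("w0", 10);

    pushToCache(cache, "w0", 5);  // cached key: the value becomes 15
    pushToCache(cache, "w1", 5);  // uncached key: the cache is left untouched

    System.out.println(cache);
  }
}
```

With plain `compute`, the second call would invoke `applyDelta` with a `null` old value. Whether skipping the update or pulling the value first is the better behavior depends on how the trainer uses the accessor, which this sketch does not try to decide.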

/**
* Retrieve a value for a requested key.
* Pull value from servers, if cache does not have value for the key.
*/
@Override
public V pull(final K key) {
// 1. in cache
final V cachedValue = cache.get(key);
if (cachedValue != null) {
return cachedValue;
} else {
// 2. not in cache
final V pulledValue;
try {
pullTracer.startTimer();
pulledValue = modelTable.getOrInit(key).get();
pullTracer.recordTime(1);
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
cache.put(key, pulledValue);
return pulledValue;
}
}

/**
* Retrieve values for requested keys.
* Pull values from servers, if cache does not have all values for the keys.
*/
@Override
public List<V> pull(final List<K> keys) {
// 1. all values are in cache
if (cache.keySet().containsAll(keys)) {
final List<V> resultValues = new ArrayList<>(keys.size());
keys.forEach(key -> resultValues.add(cache.get(key)));
return resultValues;
} else {
// 2. some values are not in cache
final Map<K, V> resultMap = new HashMap<>(keys.size());
final Map<K, Future<V>> pullFutures = new HashMap<>();
for (final K key : keys) {
final V value = cache.get(key);
if (value == null) {
pullFutures.put(key, modelTable.getOrInit(key));
} else {
resultMap.put(key, value);
}
}

if (!pullFutures.isEmpty()) {
pullTracer.startTimer();
// pull non-cached values
pullFutures.forEach((key, valueFuture) -> {
final V value;
try {
value = valueFuture.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
cache.put(key, value);
resultMap.put(key, value);
});
pullTracer.recordTime(pullFutures.size());
}

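// Return values in the order of the requested keys; HashMap#values() gives no ordering guarantee.
final List<V> resultValues = new ArrayList<>(keys.size());
keys.forEach(key -> resultValues.add(resultMap.get(key)));
return resultValues;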
}
}

/**
* This method does not care about cache.
Contributor

It'd be great if we clarified the differences between pull(List<K> keys, Table table) and pull(List<K> keys). The distinction is missing from the base interface, but this implementation treats them differently (with cache vs. without cache).

Contributor Author
@wynot12 Nov 7, 2017

Hmm, actually this part is exactly the same as in the no-cache version.

Contributor

Oh, then the question is rather:
why does this method ignore the cache, while pull(final List<K> keys) reads from it?

That is what prompted my original question above: what are the differences between the two methods?

Contributor Author

pull(List<K> keys, Table table) is for reading from other tables, which have no cache.
This ModelAccessor implementation provides a cache only for the table held in its field.

Contributor Author

pull(List<K> keys, Table table) was added to support offline model evaluation,
so this method may seem awkward in the ModelAccessor interface.

*/
@Override
public List<V> pull(final List<K> keys, final Table aModelTable) {
final List<Future<V>> resultList = new ArrayList<>(keys.size());
keys.forEach(key -> resultList.add(aModelTable.getOrInit(key)));

final List<V> resultValues = new ArrayList<>(keys.size());
for (final Future<V> opResult : resultList) {
V result;
while (true) {
try {
result = opResult.get();
break;
} catch (InterruptedException e) {
// ignore and keep waiting
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

resultValues.add(result);
}

return resultValues;
}
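
To illustrate the distinction discussed in the review thread above (a cache only for the accessor's own table, direct reads for any table passed in), here is a simplified sketch. Plain `Map`s stand in for `Table`, reads are synchronous rather than `Future`-based, and `ScopedCacheSketch` is a hypothetical name; none of this is the PR's actual code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class ScopedCacheSketch {
  // Stand-in for the model table injected into the accessor.
  private final Map<String, Integer> ownTable;
  // The cache covers ownTable only.
  private final Map<String, Integer> cache = new ConcurrentHashMap<>();

  ScopedCacheSketch(final Map<String, Integer> ownTable) {
    this.ownTable = ownTable;
  }

  // Cached path: reads ownTable only on a cache miss.
  List<Integer> pull(final List<String> keys) {
    final List<Integer> values = new ArrayList<>(keys.size());
    for (final String key : keys) {
      values.add(cache.computeIfAbsent(key, ownTable::get));
    }
    return values;
  }

  // Uncached path: always reads the table that the caller passes in.
  List<Integer> pull(final List<String> keys, final Map<String, Integer> otherTable) {
    final List<Integer> values = new ArrayList<>(keys.size());
    for (final String key : keys) {
      values.add(otherTable.get(key));
    }
    return values;
  }

  public static void main(final String[] args) {
    final Map<String, Integer> table = new HashMap<>();
    table.put("w0", 1);
    final ScopedCacheSketch accessor = new ScopedCacheSketch(table);
    final List<String> keys = Arrays.asList("w0");

    System.out.println(accessor.pull(keys));         // [1], fetched and cached
    table.put("w0", 2);
    System.out.println(accessor.pull(keys));         // still [1], served from the cache
    System.out.println(accessor.pull(keys, table));  // [2], read directly from the given table
  }
}
```

The usage in `main` shows why the two-argument form suits offline evaluation: it always reflects the current contents of the table it is given, regardless of what the worker has cached.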

@Override
public Map<String, Double> getAndResetMetrics() {
final Map<String, Double> metrics = new HashMap<>();
metrics.put(METRIC_TOTAL_PULL_TIME_SEC, pullTracer.totalElapsedTime());
metrics.put(METRIC_TOTAL_PUSH_TIME_SEC, pushTracer.totalElapsedTime());
metrics.put(METRIC_AVG_PULL_TIME_SEC, pullTracer.avgTimePerElem());
metrics.put(METRIC_AVG_PUSH_TIME_SEC, pushTracer.avgTimePerElem());

pullTracer.resetTrace();
pushTracer.resetTrace();
return metrics;
}

private void refreshCache() {
final Set<K> keys = cache.keySet();

if (!keys.isEmpty()) {
pullTracer.startTimer();
final Map<K, Future<V>> pullFutures = new HashMap<>(keys.size());
keys.forEach(key -> pullFutures.put(key, modelTable.getOrInit(key)));

pullFutures.forEach((key, pullFuture) -> {
try {
cache.put(key, pullFuture.get());
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
});
pullTracer.recordTime(pullFutures.size());
}
}
}
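
One point about the periodic refresh scheduled in the constructor: per the `ScheduledExecutorService.scheduleWithFixedDelay` contract, if any execution of the task throws, subsequent executions are suppressed. Since `refreshCache` rethrows pull failures as `RuntimeException`, a single failed round could silently stop all later refreshes. The sketch below shows one possible guard, assuming it is acceptable to log a failed round and retry on the next one. It also uses a daemon thread and exposes `close()` so the executor can be shut down. `PeriodicRefresherSketch` and `survivesFailure` are hypothetical names, not part of the PR.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

final class PeriodicRefresherSketch implements AutoCloseable {
  private final ScheduledExecutorService executor;

  PeriodicRefresherSketch(final Runnable refreshTask, final long delay, final TimeUnit unit) {
    // A daemon thread does not keep the JVM alive if close() is never called.
    executor = Executors.newSingleThreadScheduledExecutor(runnable -> {
      final Thread thread = new Thread(runnable, "cache-refresher");
      thread.setDaemon(true);
      return thread;
    });
    executor.scheduleWithFixedDelay(() -> {
      try {
        refreshTask.run();
      } catch (final RuntimeException e) {
        // Catch here so that one failed round does not cancel all future rounds.
        System.err.println("cache refresh failed, will retry: " + e);
      }
    }, delay, delay, unit);
  }

  @Override
  public void close() {
    executor.shutdownNow();
  }

  // Runs a task that fails on its first execution and reports whether it still ran three times.
  static boolean survivesFailure() {
    final AtomicInteger runs = new AtomicInteger();
    final CountDownLatch threeRuns = new CountDownLatch(3);
    final Runnable flakyRefresh = () -> {
      threeRuns.countDown();
      if (runs.incrementAndGet() == 1) {
        throw new IllegalStateException("simulated pull failure");
      }
    };
    try (PeriodicRefresherSketch refresher =
             new PeriodicRefresherSketch(flakyRefresh, 10, TimeUnit.MILLISECONDS)) {
      return threeRuns.await(5, TimeUnit.SECONDS);
    } catch (final InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }

  public static void main(final String[] args) {
    System.out.println("still refreshing after a failure: " + survivesFailure());
  }
}
```

Without the `try`/`catch` inside the scheduled lambda, the first exception would end the schedule and the latch in `survivesFailure` would never reach zero.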

@@ -72,6 +72,11 @@ public final class NumServerBlocks implements Name<Integer> {
public final class HyperThreadEnabled implements Name<Boolean> {
}

@NamedParameter(doc = "Whether the model cache is enabled.",
short_name = "model_cache_enabled", default_value = "false")
public final class ModelCacheEnabled implements Name<Boolean> {
}

@NamedParameter(doc = "Desired memory size for each worker evaluator (MBs)",
short_name = "worker_mem_size",
default_value = "128")

@@ -124,14 +124,21 @@ public static LauncherStatus launch(final String jobName,
.bindNamedParameter(UpdateValueCodec.class, dolphinConf.getModelUpdateValueCodecClass())
.build());

final Injector workerParameterInjector = Tang.Factory.getTang().newInjector(workerParamConf);

final boolean modelCacheEnabled = workerParameterInjector.getNamedInstance(ModelCacheEnabled.class);
final Class<? extends ModelAccessor> modelAccessorClass =
modelCacheEnabled ? CachedModelAccessor.class : ETModelAccessor.class;

// worker conf
final Configuration workerConf = Configurations.merge(
workerParamConf, userParamConf,
Tang.Factory.getTang().newConfigurationBuilder()
.bindImplementation(Trainer.class, dolphinConf.getTrainerClass())
.bindImplementation(DataParser.class, dolphinConf.getInputParserClass())
.bindImplementation(TrainingDataProvider.class, ETTrainingDataProvider.class)
- .bindImplementation(ModelAccessor.class, ETModelAccessor.class)
+ .bindImplementation(ModelAccessor.class, modelAccessorClass)
.bindImplementation(UpdateFunction.class, dolphinConf.getModelUpdateFunctionClass())
.bindNamedParameter(KeyCodec.class, dolphinConf.getInputKeyCodecClass())
.bindNamedParameter(ValueCodec.class, dolphinConf.getInputValueCodecClass())
.build());
@@ -226,7 +233,8 @@ private static List<Configuration> parseCommandLine(
final List<Class<? extends Name<?>>> serverParamList = Collections.emptyList();

final List<Class<? extends Name<?>>> workerParamList = Arrays.asList(
- HyperThreadEnabled.class, MaxNumEpochs.class, NumTotalMiniBatches.class, TestDataPath.class);
+ HyperThreadEnabled.class, ModelCacheEnabled.class,
+ MaxNumEpochs.class, NumTotalMiniBatches.class, TestDataPath.class);

// commonly used parameters for ML apps
final List<Class<? extends Name<?>>> commonAppParamList = Arrays.asList(

@@ -93,13 +93,7 @@ public List<V> pull(final List<K> keys) {
return resultValues;
}

- /**
- * Do {@link #pull(List)} with a given table.
- * @param keys a list of keys of model parameter
- * @param aModelTable a table to read value from
- * @return a list of values associated with the given {@code keys}.
- * Some positions in the list can be {@code null}, if the key has no associated value
- */
+ @Override
public List<V> pull(final List<K> keys, final Table aModelTable) {
final List<Future<V>> resultList = new ArrayList<>(keys.size());
keys.forEach(key -> resultList.add(aModelTable.getOrInit(key)));

@@ -15,6 +15,8 @@
*/
package edu.snu.cay.dolphin.async;

import edu.snu.cay.services.et.evaluator.api.Table;

import java.util.List;
import java.util.Map;

@@ -69,6 +71,15 @@ public interface ModelAccessor<K, P, V> {
*/
List<V> pull(List<K> keys);

/**
* Do {@link #pull(List)} with a given table.
* @param keys a list of keys of model parameter
* @param aModelTable a table to read value from
* @return a list of values associated with the given {@code keys}.
* Some positions in the list can be {@code null}, if the key has no associated value
*/
List<V> pull(List<K> keys, Table aModelTable);

/**
* Fetches the collected metrics and reset the tracers for collecting metrics in the next round.
* @return the metrics that are identified by the keys in this interface.

@@ -130,7 +130,7 @@ public Map<CharSequence, Double> evaluateModel(final Collection<Document> inputD
final edu.snu.cay.services.et.evaluator.api.Table modelTable) {

LOG.log(Level.INFO, "Pull model to compute log likelihood");
- final List<int[]> wordTopicCounts = ((ETModelAccessor)modelAccessor).pull(vocabList, modelTable);
+ final List<int[]> wordTopicCounts = modelAccessor.pull(vocabList, modelTable);
final int[] wordTopicCountsSummary = wordTopicCounts.remove(numVocabs);

LOG.log(Level.INFO, "Start computing log likelihood");

@@ -266,7 +266,7 @@ public Map<CharSequence, Double> evaluateModel(final Collection<MLRData> inputDa
* Pull models one last time and perform validation.
*/
private MLRModel pullModelsToEvaluate(final List<Integer> keys, final Table<Integer, Vector, Vector> modelTable) {
- final List<Vector> partitions = ((ETModelAccessor) modelAccessor).pull(keys, modelTable);
+ final List<Vector> partitions = modelAccessor.pull(keys, modelTable);

final MLRModel mlrModel = new MLRModel(new Vector[numClasses]);
final Vector[] params = mlrModel.getParams();

@@ -206,7 +206,7 @@ public Map<CharSequence, Double> evaluateModel(final Collection<NMFData> inputDa

private NMFModel pullModelToEvaluate(final List<Integer> keys, final Table<Integer, Vector, Vector> modelTable) {
final Map<Integer, Vector> rMatrix = new HashMap<>(keys.size());
- final List<Vector> vectors = ((ETModelAccessor) modelAccessor).pull(keys, modelTable);
+ final List<Vector> vectors = modelAccessor.pull(keys, modelTable);
for (int i = 0; i < keys.size(); ++i) {
rMatrix.put(keys.get(i), vectors.get(i));
}