diff --git a/core/pom.xml b/core/pom.xml index c7e2385..a728649 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -43,6 +43,10 @@ commons-io commons-io + + com.github.dpaukov + combinatoricslib3 + org.slf4j slf4j-api diff --git a/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala b/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala index cfdb013..30aa999 100644 --- a/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala +++ b/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala @@ -1,8 +1,8 @@ package com.eharmony.spotz.backend import com.eharmony.spotz.objective.Objective -import com.eharmony.spotz.optimizer.RandomSampler import com.eharmony.spotz.optimizer.grid.Grid +import com.eharmony.spotz.optimizer.hyperparam.RandomSampler import scala.reflect.ClassTag @@ -15,13 +15,13 @@ import scala.reflect.ClassTag */ trait BackendFunctions { protected def bestRandomPointAndLoss[P, L]( - startIndex: Long, - batchSize: Long, - objective: Objective[P, L], - reducer: ((P, L), (P, L)) => (P, L), - hyperParams: Map[String, RandomSampler[_]], - seed: Long = 0, - sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) + startIndex: Long, + batchSize: Long, + objective: Objective[P, L], + reducer: ((P, L), (P, L)) => (P, L), + hyperParams: Map[String, RandomSampler[_]], + seed: Long = 0, + sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) protected def bestGridPointAndLoss[P, L]( startIndex: Long, diff --git a/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala b/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala index 1498b54..98853c3 100644 --- a/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala +++ b/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala @@ -1,8 +1,8 @@ package com.eharmony.spotz.backend import com.eharmony.spotz.objective.Objective -import com.eharmony.spotz.optimizer.RandomSampler import com.eharmony.spotz.optimizer.grid.Grid +import com.eharmony.spotz.optimizer.hyperparam.RandomSampler import scala.reflect.ClassTag @@ -31,13 +31,13 @@ trait ParallelFunctions extends BackendFunctions { * @return the best point with the best loss as a tuple */ protected override def bestRandomPointAndLoss[P, L]( - startIndex: Long, - batchSize: Long, - objective: Objective[P, L], - reducer: ((P, L), (P, L)) => (P, L), - hyperParams: Map[String, RandomSampler[_]], - seed: Long = 0, - sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = { + startIndex: Long, + batchSize: Long, + objective: Objective[P, L], + reducer: ((P, L), (P, L)) => (P, L), + hyperParams: Map[String, RandomSampler[_]], + seed: Long = 0, + sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = { val pointsAndLosses = (startIndex until (startIndex + batchSize)).par.map { trial => val point = sampleFunction(hyperParams, seed + trial) diff --git a/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala b/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala index 446c289..d1e78a5 100644 --- a/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala +++ b/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala @@ -1,8 +1,8 @@ package com.eharmony.spotz.backend import com.eharmony.spotz.objective.Objective -import com.eharmony.spotz.optimizer.RandomSampler import com.eharmony.spotz.optimizer.grid.Grid +import com.eharmony.spotz.optimizer.hyperparam.RandomSampler import org.apache.spark.SparkContext import scala.reflect.ClassTag @@ -33,13 +33,13 @@ trait SparkFunctions extends BackendFunctions { * @return the best point with the best loss as a tuple */ protected override def bestRandomPointAndLoss[P, L]( - startIndex: Long, - batchSize: Long, - objective: Objective[P, L], - reducer: ((P, L), (P, L)) => (P, L), - hyperParams: Map[String, RandomSampler[_]], - seed: Long = 0, - sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = { + startIndex: Long, + batchSize: Long, + objective: Objective[P, L], + reducer: ((P, L), (P, L)) => (P, L), + hyperParams: Map[String, RandomSampler[_]], + seed: Long = 0, + sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = { assert(batchSize > 0, "batchSize must be greater than 0") diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala deleted file mode 100644 index 8f44305..0000000 --- a/core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala +++ /dev/null @@ -1,150 +0,0 @@ -package com.eharmony.spotz.optimizer - -import scala.util.Random - -/** - * @author vsuthichai - */ -trait SamplerFunction[T] extends Serializable - -/** - * A sampler function dependent on a pseudo random number generator. The generator - * is passed in as a parameter, and this is intentional. It allows the user to - * change seeds and switch generators. This becomes important when sampling on - * Spark workers and more control over the rng is necessary. - * - * @tparam T - */ -abstract class RandomSampler[T] extends SamplerFunction[T] { - def apply(rng: Random): T -} - -abstract class Uniform[T](lb: T, ub: T) extends RandomSampler[T] - -/** - * Sample a Double within the bounds lb <= x < ub with uniform random distribution - * - * {{{ - * val rng = new Random(seed) - * val sampler = UniformDouble(0, 1)) - * val sample = sampler(rng) - * }}} - * - * @param lb lower bound - * @param ub upper bound - */ -case class UniformDouble(lb: Double, ub: Double) extends Uniform[Double](lb, ub) { - if (lb >= ub) - throw new IllegalArgumentException("lb must be less than ub") - - override def apply(rng: Random): Double = lb + ((ub - lb) * rng.nextDouble) -} - -/** - * Sample an Int within the bounds lb <= x < ub with uniform random distribution - * - * {{{ - * val hyperParamSpace = Map( - * ("x1", UniformInt(0, 10)) - * ) - * }}} - * - * @param lb lower bound - * @param ub upper bound - */ -case class UniformInt(lb: Int, ub: Int) extends Uniform[Int](lb, ub) { - if (lb >= ub) - throw new IllegalArgumentException("lb must be less than ub") - - override def apply(rng: Random): Int = lb + rng.nextInt(ub - lb) -} - -/** - * Sample from a normal distribution given the mean and standard deviation - * - * {{{ - * val hyperParamSpace = Map( - * ("x1", NormalDistribution(0, 0.1)) - * ) - * }}} - * - * @param mean mean - * @param std standard deviation - */ -case class NormalDistribution(mean: Double, std: Double) extends RandomSampler[Double] { - override def apply(rng: Random): Double = { - std * rng.nextGaussian() + mean - } -} - -/** - * Given an iterable of RandomSampler functions, choose a function at random and - * sample from it. - * - * {{{ - * val hyperParamSpace = Map( - * ("x1", Union(UniformDouble(0, 1), UniformDouble(10, 11))) - * ) - * }}} - * - * @param iterable an iterable of RandomSampler[T] functions - * @param probs an iterable of probabilities that should sum to 1. This specifies the probabilities that - * the sampler functions are chosen. If the length of this is not the same as the length - * of the iterable of RandomSampler[T] functions, then an IllegalArgumentException is thrown. - * Not specifying an iterable of probabilities will force a default uniform random sampling. - * If the length of the iterable of probabilities is equal to the length of the iterable of - * RandomSampler[T] functions, then - * @tparam T type parameter of the sample - */ -case class Union[T](iterable: Iterable[RandomSampler[T]], probs: Iterable[Double] = Seq()) extends RandomSampler[T] { - private val indexedSeq = iterable.toIndexedSeq - - private val probabilities = - if (probs.isEmpty) - Seq.fill(indexedSeq.length)(1.toDouble / indexedSeq.length) - else if (probs.toIndexedSeq.length != indexedSeq.length) - throw new IllegalArgumentException("iterable lengths must match") - else if (probs.exists(p => p <= 0)) - throw new IllegalArgumentException("Must be positive or valid probabilities") - else if (probs.sum != 1.0) - probs.map(p => p / probs.sum) - - override def apply(rng: Random): T = ??? - private def bucket(probability: Double): Int = ??? -} - -/** - * Sample an element from an Iterable of fixed length with uniform random distribution. - * - * {{{ - * val hyperParamSpace = Map( - * ("x1", RandomChoice(Seq("svm", "logistic"))) - * ) - * }}} - * - * @param iterable an iterable of type T - * @tparam T type parameter of iterable - */ -case class RandomChoice[T](iterable: Iterable[T]) extends RandomSampler[T] { - private val values = iterable.toIndexedSeq - - if (values.length < 1) - throw new IllegalArgumentException("Empty iterable") - - override def apply(rng: Random): T = values(rng.nextInt(values.length)) -} - -/** - * N Choose K, where N is the size of an Iterable. - * - * @param iterable - * @param k - * @tparam T - */ -case class BinomialCoefficient[T](iterable: Iterable[T], k: Int) extends RandomSampler[Iterable[T]] { - private val values = iterable.toSeq - - override def apply(rng: Random): Iterable[T] = { - rng.shuffle(values).take(k) - } -} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala new file mode 100644 index 0000000..f0f9f76 --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala @@ -0,0 +1,96 @@ +package com.eharmony.spotz.optimizer.hyperparam + +import scala.util.Random + + +trait CombinatoricRandomSampler[T] extends RandomSampler[Iterable[Iterable[T]]] +trait IterableRandomSampler[T] extends RandomSampler[Iterable[T]] + +/** + * + * @param iterable + * @param k + * @param x + * @param replacement + * @tparam T + */ +abstract class AbstractCombinations[T]( + iterable: Iterable[T], + k: Int, + x: Int = 1, + replacement: Boolean = false) extends Serializable { + + import org.paukov.combinatorics3.Generator + + import scala.collection.JavaConverters._ + + private val values = iterable.toSeq + + assert(k > 0, "k must be greater than 0") + assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}") + + // TODO: This is hideous! Rewrite this to be more memory efficient by unranking combinations. For now, use a Java lib. + val combinations = Generator.combination(iterable.asJavaCollection).simple(k).asScala.toIndexedSeq.map(l => l.asScala.toIndexedSeq) + + /** + * + * @param rng + * @return + */ + def combos(rng: Random): Iterable[Iterable[T]] = { + if (replacement) { + Seq.fill(x)(combinations(rng.nextInt(combinations.size))) + } else { + val indices = collection.mutable.Set[Int]() + val numElements = scala.math.min(x, combinations.size) + val ret = new collection.mutable.ArrayBuffer[Iterable[T]](numElements) + while (indices.size < numElements) { + val index = rng.nextInt(combinations.size) + if (!indices.contains(index)) { + indices.add(index) + ret += combinations(index) + } + } + ret.toIndexedSeq + } + } +} + + +/** + * Sample a single combination of K unordered items from the iterable of length N. + * + * @param iterable + * @param k + * @param replacement + * @tparam T + */ +case class Combination[T]( + iterable: Iterable[T], + k: Int, + replacement: Boolean = false) + extends AbstractCombinations[T](iterable, k, 1, replacement) with IterableRandomSampler[T] { + + override def apply(rng: Random): Iterable[T] = combos(rng).head +} + + +/** + * Binomial coefficient implementation. Pick K unordered items from an Iterable of N items. + * Also known as N Choose K, where N is the size of an Iterable and K is the desired number + * of items to be chosen. This implementation will actually compute all the possible choices + * and return them as an Iterable. + * + * @param iterable an iterable of finite length + * @param k the number of items to choose + * @tparam T + */ +case class Combinations[T]( + iterable: Iterable[T], + k: Int, + x: Int = 1, + replacement: Boolean = false) + extends AbstractCombinations[T](iterable, k, x, replacement) with CombinatoricRandomSampler[T] { + + override def apply(rng: Random): Iterable[Iterable[T]] = combos(rng) +} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala new file mode 100644 index 0000000..595b7e0 --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala @@ -0,0 +1,21 @@ +import com.eharmony.spotz.optimizer.hyperparam.RandomSampler + +import scala.util.Random + +/** + * Sample from a normal distribution given the mean and standard deviation + * + * {{{ + * val hyperParamSpace = Map( + * ("x1", NormalDistribution(0, 0.1)) + * ) + * }}} + * + * @param mean mean + * @param std standard deviation + */ +case class NormalDistribution(mean: Double, std: Double) extends RandomSampler[Double] { + override def apply(rng: Random): Double = { + std * rng.nextGaussian() + mean + } +} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala new file mode 100644 index 0000000..7f73fc3 --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala @@ -0,0 +1,24 @@ +package com.eharmony.spotz.optimizer.hyperparam + +import scala.util.Random + +/** + * Sample an element from an Iterable of fixed length with uniform random distribution. + * + * {{{ + * val hyperParamSpace = Map( + * ("x1", RandomChoice(Seq("svm", "logistic"))) + * ) + * }}} + * + * @param iterable an iterable of type T + * @tparam T type parameter of iterable + */ +case class RandomChoice[T](iterable: Iterable[T]) extends RandomSampler[T] { + private val values = iterable.toIndexedSeq + + if (values.length < 1) + throw new IllegalArgumentException("Empty iterable") + + override def apply(rng: Random): T = values(rng.nextInt(values.length)) +} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomSampler.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomSampler.scala new file mode 100644 index 0000000..b9f38c0 --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomSampler.scala @@ -0,0 +1,15 @@ +package com.eharmony.spotz.optimizer.hyperparam + +import scala.util.Random + +/** + * A sampler function dependent on a pseudo random number generator. The generator + * is passed in as a parameter, and this is intentional. It allows the user to + * change seeds and switch generators. This becomes important when sampling on + * Spark workers and more control over the rng is necessary. + * + * @tparam T + */ +trait RandomSampler[T] extends Serializable { + def apply(rng: Random): T +} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala new file mode 100644 index 0000000..dc5435c --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala @@ -0,0 +1,44 @@ +package com.eharmony.spotz.optimizer.hyperparam + +import scala.collection.mutable +import scala.util.Random + +case class Subset[T](iterable: Iterable[T], k: Int) + + +case class Subsets[T](iterable: Iterable[T], k: Int, x: Int, replacement: Boolean = false)(implicit ord: Ordering[T]) extends CombinatoricRandomSampler[T] { + private val values = iterable.toIndexedSeq + + def sample(rng: Random): Iterable[T] = { + val sampleSize = rng.nextInt(k) + 1 + val subset = mutable.SortedSet[T]() + val indices = mutable.Set[Int]() + + while (subset.size < sampleSize) { + val index = rng.nextInt(values.size) + if (replacement) { + subset.add(values(index)) + } else if (!indices.contains(index)) { + indices.add(index) + subset.add(values(index)) + } + } + subset.toIndexedSeq + } + + def apply(rng: Random): Iterable[Iterable[T]] = { + val numSubsets = rng.nextInt(x) + 1 + + if (replacement) { + Seq.fill(numSubsets)(sample(rng)) + } else { + val subsets = mutable.Set[Iterable[T]]() + while (subsets.size < numSubsets) { + subsets.add(sample(rng)) + } + subsets.toIndexedSeq + } + } +} + + diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Uniform.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Uniform.scala new file mode 100644 index 0000000..b0de664 --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Uniform.scala @@ -0,0 +1,46 @@ +package com.eharmony.spotz.optimizer.hyperparam + +import scala.util.Random + +/** + * Created by vsuthichai on 8/18/16. + */ +abstract class Uniform[T](lb: T, ub: T) extends RandomSampler[T] + +/** + * Sample a Double within the bounds lb <= x < ub with uniform random distribution + * + * {{{ + * val rng = new Random(seed) + * val sampler = UniformDouble(0, 1)) + * val sample = sampler(rng) + * }}} + * + * @param lb lower bound + * @param ub upper bound + */ +case class UniformDouble(lb: Double, ub: Double) extends Uniform[Double](lb, ub) { + if (lb >= ub) + throw new IllegalArgumentException("lb must be less than ub") + + override def apply(rng: Random): Double = lb + ((ub - lb) * rng.nextDouble) +} + +/** + * Sample an Int within the bounds lb <= x < ub with uniform random distribution + * + * {{{ + * val hyperParamSpace = Map( + * ("x1", UniformInt(0, 10)) + * ) + * }}} + * + * @param lb lower bound + * @param ub upper bound + */ +case class UniformInt(lb: Int, ub: Int) extends Uniform[Int](lb, ub) { + if (lb >= ub) + throw new IllegalArgumentException("lb must be less than ub") + + override def apply(rng: Random): Int = lb + rng.nextInt(ub - lb) +} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Union.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Union.scala new file mode 100644 index 0000000..7029ed2 --- /dev/null +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Union.scala @@ -0,0 +1,39 @@ +package com.eharmony.spotz.optimizer.hyperparam + +import scala.util.Random + +/** + * Given an iterable of RandomSampler functions, choose a function at random and + * sample from it. + * + * {{{ + * val hyperParamSpace = Map( + * ("x1", Union(UniformDouble(0, 1), UniformDouble(10, 11))) + * ) + * }}} + * + * @param iterable an iterable of RandomSampler[T] functions + * @param probs an iterable of probabilities that should sum to 1. This specifies the probabilities that + * the sampler functions are chosen. If the length of this is not the same as the length + * of the iterable of RandomSampler[T] functions, then an IllegalArgumentException is thrown. + * Not specifying an iterable of probabilities will force a default uniform random sampling. + * If the length of the iterable of probabilities is equal to the length of the iterable of + * RandomSampler[T] functions, then + * @tparam T type parameter of the sample + */ +case class Union[T](iterable: Iterable[RandomSampler[T]], probs: Iterable[Double] = Seq()) extends RandomSampler[T] { + private val indexedSeq = iterable.toIndexedSeq + + private val probabilities = + if (probs.isEmpty) + Seq.fill(indexedSeq.length)(1.toDouble / indexedSeq.length) + else if (probs.toIndexedSeq.length != indexedSeq.length) + throw new IllegalArgumentException("iterable lengths must match") + else if (probs.exists(p => p <= 0)) + throw new IllegalArgumentException("Must be positive or valid probabilities") + else if (probs.sum != 1.0) + probs.map(p => p / probs.sum) + + override def apply(rng: Random): T = ??? + private def bucket(probability: Double): Int = ??? +} diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala index 3dd4fcb..ca2e90b 100644 --- a/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala @@ -3,6 +3,7 @@ package com.eharmony.spotz.optimizer.random import com.eharmony.spotz.backend.{BackendFunctions, ParallelFunctions, SparkFunctions} import com.eharmony.spotz.objective.Objective import com.eharmony.spotz.optimizer._ +import com.eharmony.spotz.optimizer.hyperparam.RandomSampler import com.eharmony.spotz.util.{DurationUtils, Logging} import org.apache.spark.SparkContext import org.joda.time.{DateTime, Duration} diff --git a/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala b/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala index d6902d9..9617714 100644 --- a/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala +++ b/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala @@ -4,8 +4,6 @@ import java.io.{File, PrintWriter} import org.apache.spark.{SparkContext, SparkFiles} -import scala.io.Source - /** * Provide capability to save and retrieve files from inside the objective * functions. Users are free to interact with the underlying file system freely as they desire, @@ -16,7 +14,7 @@ import scala.io.Source * get inside the apply method. */ trait FileFunctions { - def save(inputPath: String): String = save(Source.fromInputStream(FileUtil.loadFile(inputPath)).getLines()) + def save(inputPath: String): String = save(FileUtil.loadFile(inputPath)) def save(inputIterable: Iterable[String]): String = save(inputIterable.toIterator) diff --git a/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala b/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala index bed7ab0..6477851 100644 --- a/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala +++ b/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala @@ -3,13 +3,15 @@ package com.eharmony.spotz.util import java.io.{File, InputStream} import org.apache.commons.io.FilenameUtils -import org.apache.commons.vfs2.{FileSystemException, VFS} +import org.apache.commons.vfs2.{FileNotFoundException, VFS} + +import scala.io.Source /** * @author vsuthichai */ object FileUtil { - val vfs2 = VFS.getManager + private val vfs2 = VFS.getManager /** * Return a file with a filename guaranteed not to be used on the file system. This is @@ -27,12 +29,58 @@ object FileUtil { f } + /** + * + * @param filename + * @param deleteOnExit + * @return + */ def tempFile(filename: String, deleteOnExit: Boolean = true): File = { tempFile(FilenameUtils.getBaseName(filename), FilenameUtils.getExtension(filename), deleteOnExit) } - def loadFile(path: String): InputStream = { + /** + * Load the lines of a file as an iterator. + * + * @param path input path + * @return lines of the file as an Iterator[String] + */ + def loadFile(path: String): Iterator[String] = { + Source.fromInputStream(loadFileInputStream(path)).getLines() + } + + /** + * + * @param path + * @return + */ + def loadFileInputStream(path: String): InputStream = { val vfsFile = vfs2.resolveFile(path) vfsFile.getContent.getInputStream } } + +object SparkFileUtil { + import org.apache.spark.SparkContext + import org.apache.hadoop.mapred.InvalidInputException + + /** + * Load the lines of a file as an iterator. Also attempt to load the file from HDFS + * since the SparkContext is available. + * + * @param path input path + * @return lines of the file as an Iterator[String] + */ + def loadFile(sc: SparkContext, path: String): Iterator[String] = { + try { + FileUtil.loadFile(path) + } catch { + case e: FileNotFoundException => + try { + sc.textFile(path).toLocalIterator + } catch { + case e: InvalidInputException => Source.fromInputStream(this.getClass.getResourceAsStream(path)).getLines() + } + } + } +} diff --git a/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala b/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala index 11e7115..18bc04e 100644 --- a/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala +++ b/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala @@ -1,5 +1,6 @@ package com.eharmony.spotz.optimizer +import com.eharmony.spotz.optimizer.hyperparam.RandomChoice import org.junit.Assert._ import org.junit.Test diff --git a/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala b/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala index f96d50d..3dbc8f0 100644 --- a/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala +++ b/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala @@ -1,5 +1,6 @@ package com.eharmony.spotz.optimizer +import com.eharmony.spotz.optimizer.hyperparam.UniformDouble import org.junit.Assert._ import org.junit.Test diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala b/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala index f72201c..6f0c5c1 100644 --- a/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala +++ b/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala @@ -2,7 +2,8 @@ package com.eharmony.spotz.examples import com.eharmony.spotz.Preamble.Point import com.eharmony.spotz.objective.Objective -import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy, UniformDouble} +import com.eharmony.spotz.optimizer.hyperparam.UniformDouble +import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy} import scala.math._ diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala b/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala index d9ebac3..718d6cb 100644 --- a/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala +++ b/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala @@ -2,7 +2,8 @@ package com.eharmony.spotz.examples import com.eharmony.spotz.Preamble.Point import com.eharmony.spotz.objective.Objective -import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy, UniformDouble} +import com.eharmony.spotz.optimizer.hyperparam.UniformDouble +import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy} import scala.math._ diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala b/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala index 14a54b8..4f83588 100644 --- a/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala +++ b/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala @@ -4,7 +4,8 @@ import com.eharmony.spotz.Preamble.Point import com.eharmony.spotz.objective.Objective import com.eharmony.spotz.optimizer.grid.{GridSearchResult, ParGridSearch, SparkGridSearch} import com.eharmony.spotz.optimizer.random.{ParRandomSearch, RandomSearchResult, SparkRandomSearch} -import com.eharmony.spotz.optimizer.{RandomSampler, StopStrategy} +import com.eharmony.spotz.optimizer.StopStrategy +import com.eharmony.spotz.optimizer.hyperparam.RandomSampler import org.apache.spark.{SparkConf, SparkContext} /** diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala index 0161e18..4ee945f 100644 --- a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala +++ b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala @@ -2,7 +2,8 @@ package com.eharmony.spotz.examples.vw import com.eharmony.spotz.examples._ import com.eharmony.spotz.objective.vw.{AbstractVwCrossValidationObjective, SparkVwCrossValidationObjective, VwCrossValidationObjective} -import com.eharmony.spotz.optimizer.{RandomSampler, StopStrategy, UniformDouble} +import com.eharmony.spotz.optimizer.hyperparam.{RandomSampler, UniformDouble} +import com.eharmony.spotz.optimizer.StopStrategy import org.apache.spark.SparkContext /** diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala index f629654..2646e31 100644 --- a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala +++ b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala @@ -2,7 +2,8 @@ package com.eharmony.spotz.examples.vw import com.eharmony.spotz.examples._ import com.eharmony.spotz.objective.vw.{AbstractVwHoldoutObjective, SparkVwHoldoutObjective, VwHoldoutObjective} -import com.eharmony.spotz.optimizer.{RandomSampler, StopStrategy, UniformDouble} +import com.eharmony.spotz.optimizer._ +import com.eharmony.spotz.optimizer.hyperparam.{Combinations, RandomSampler, UniformDouble} import org.apache.spark.SparkContext /** @@ -39,7 +40,8 @@ trait VwHoldout extends AbstractVwHoldout { trait VwHoldoutRandomSearch extends VwHoldout { val space = Map( - ("l", UniformDouble(0, 1)) + ("l", UniformDouble(0, 1)), + ("q", Combinations(List("a", "b", "c", "d"), k = 2, x = 2)) ) def main(args: Array[String]) { diff --git a/pom.xml b/pom.xml index 7c9c232..ed909e4 100644 --- a/pom.xml +++ b/pom.xml @@ -72,6 +72,7 @@ 1.8.1 2.1 1.7.21 + 3.0.0 1.1.7 4.11 1.3 @@ -151,6 +152,11 @@ commons-vfs2 ${vfs2.version} + + com.github.dpaukov + combinatoricslib3 + ${combinatoricslib3.version} + org.slf4j slf4j-api diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala index 371111d..e3cff0f 100644 --- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala +++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala @@ -3,13 +3,16 @@ package com.eharmony.spotz.objective.vw import com.eharmony.spotz.Preamble.Point import com.eharmony.spotz.objective.Objective import com.eharmony.spotz.objective.vw.util.{FSVwDatasetFunctions, SparkVwDatasetFunctions, VwCrossValidation} -import com.eharmony.spotz.util.{FileUtil, Logging} +import com.eharmony.spotz.util.{FileUtil, Logging, SparkFileUtil} import org.apache.spark.SparkContext -import scala.io.Source - /** - * @author vsuthichai + * Perform K Fold cross validation given a dataset formatted for Vowpal Wabbit. + * + * @param numFolds + * @param vwDataset + * @param vwTrainParamsString + * @param vwTestParamsString */ abstract class AbstractVwCrossValidationObjective( val numFolds: Int, @@ -32,7 +35,7 @@ abstract class AbstractVwCrossValidationObjective( vwDatasetPath: String, vwTrainParamsString: Option[String], vwTestParamsString: Option[String]) = { - this(numFolds, Source.fromInputStream(FileUtil.loadFile(vwDatasetPath)).getLines(), vwTrainParamsString, vwTestParamsString) + this(numFolds, FileUtil.loadFile(vwDatasetPath), vwTrainParamsString, vwTestParamsString) } val vwTrainParamsMap = parseVwArgs(vwTrainParamsString) @@ -114,7 +117,7 @@ class SparkVwCrossValidationObjective( vwDatasetPath: String, vwTrainParamsString: Option[String], vwTestParamsString: Option[String]) = { - this(sc, numFolds, Source.fromInputStream(FileUtil.loadFile(vwDatasetPath)).getLines(), vwTrainParamsString, vwTestParamsString) + this(sc, numFolds, SparkFileUtil.loadFile(sc, vwDatasetPath), vwTrainParamsString, vwTestParamsString) } } @@ -137,6 +140,6 @@ class VwCrossValidationObjective( vwDatasetPath: String, vwTrainParamsString: Option[String], vwTestParamsString: Option[String]) = { - this(numFolds, Source.fromInputStream(FileUtil.loadFile(vwDatasetPath)).getLines(), vwTrainParamsString, vwTestParamsString) + this(numFolds, FileUtil.loadFile(vwDatasetPath), vwTrainParamsString, vwTestParamsString) } } \ No newline at end of file diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala index c3130e8..14db664 100644 --- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala +++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala @@ -18,18 +18,25 @@ trait VwFunctions { * * @param vwParamMap a Map[String, String] where the key is a VW argument and the value * is the argument value. - * @param point a point object representing the hyperparameter values - * @return a new Map[String, String] which is the result of merging the vwParamMap and the point. + * @param point a point object representing the hyper parameter values + * @return a new Map[String, String] which is the result of merging the vwParamMap and + * the point. */ def mergeVwParams(vwParamMap: Map[String, _], point: Point): Map[String, _] = { - val vwParamsMutableMap = mutable.Map[String, Any]() - - vwParamMap.foldLeft(vwParamsMutableMap) { case (mutableMap, (k, v)) => - mutableMap += ((k, v)) - } + val vwParamsMutableMap = mutable.Map[String, Any](vwParamMap.toSeq: _*) point.getHyperParameterLabels.foldLeft(vwParamsMutableMap) { (mutableMap, vwHyperParam) => - mutableMap += ((vwHyperParam, point.get(vwHyperParam))) + if (mutableMap.contains(vwHyperParam)) { + val listValues = mutableMap(vwHyperParam) match { + case it: Iterable[_] => it.toList ++ point.get(vwHyperParam) + case value => List(value, point.get(vwHyperParam)) + } + mutableMap += ((vwHyperParam, listValues)) + } else { + mutableMap += ((vwHyperParam, point.get(vwHyperParam))) + } + + mutableMap } // Remove cache params @@ -57,7 +64,10 @@ trait VwFunctions { vwParamMap.foldLeft(new StringBuilder) { case (sb, (vwArg, vwValue)) => val dashes = if (vwArg.length == 1) "-" else "--" val vwParam = vwValue match { - case value: Iterable[_] => value.map(x => s"$dashes$vwArg ${x.toString}").mkString(" ") + case value: Iterable[_] => value.map { + case it: Iterable[_] => s"$dashes$vwArg ${it.mkString}" + case x => s"$dashes$vwArg ${x.toString}" + }.mkString(" ") + " " case _ => s"$dashes$vwArg $vwValue " } sb ++= vwParam diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala index 45f9a39..3618e7b 100644 --- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala +++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala @@ -3,11 +3,9 @@ package com.eharmony.spotz.objective.vw import com.eharmony.spotz.Preamble.Point import com.eharmony.spotz.objective.Objective import com.eharmony.spotz.objective.vw.util.{FSVwDatasetFunctions, SparkVwDatasetFunctions, VwDatasetFunctions} -import com.eharmony.spotz.util.{FileUtil, Logging} +import com.eharmony.spotz.util.{FileUtil, Logging, SparkFileUtil} import org.apache.spark.SparkContext -import scala.io.Source - /** * @author vsuthichai */ @@ -32,8 +30,7 @@ abstract class AbstractVwHoldoutObjective( vwTrainParamsString: Option[String], vwTestSetPath: String, vwTestParamsString: Option[String]) = { - this(Source.fromInputStream(FileUtil.loadFile(vwTrainSetPath)).getLines(), vwTrainParamsString, - Source.fromInputStream(FileUtil.loadFile(vwTestSetPath)).getLines(), vwTestParamsString) + this(FileUtil.loadFile(vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString) } val vwTrainParamMap = parseVwArgs(vwTrainParamsString) @@ -95,8 +92,7 @@ class SparkVwHoldoutObjective( vwTrainParamsString: Option[String], vwTestSetPath: String, vwTestParamsString: Option[String]) = { - this(sc, Source.fromInputStream(FileUtil.loadFile(vwTrainSetPath)).getLines(), vwTrainParamsString, - Source.fromInputStream(FileUtil.loadFile(vwTestSetPath)).getLines(), vwTestParamsString) + this(sc, SparkFileUtil.loadFile(sc, vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString) } } @@ -119,7 +115,6 @@ class VwHoldoutObjective( vwTrainParamsString: Option[String], vwTestSetPath: String, vwTestParamsString: Option[String]) = { - this(Source.fromInputStream(FileUtil.loadFile(vwTrainSetPath)).getLines(), vwTrainParamsString, - Source.fromInputStream(FileUtil.loadFile(vwTestSetPath)).getLines(), vwTestParamsString) + this(FileUtil.loadFile(vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString) } } diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala index c4066bb..4e94862 100644 --- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala +++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala @@ -17,7 +17,7 @@ import scala.io.Source trait VwDatasetFunctions extends FileFunctions { def saveAsCache(inputIterator: Iterator[String]): String = saveAsCache(inputIterator, "dataset.cache") def saveAsCache(inputIterable: Iterable[String]): String = saveAsCache(inputIterable.toIterator) - def saveAsCache(inputPath: String): String = saveAsCache(Source.fromInputStream(FileUtil.loadFile(inputPath)).getLines()) + def saveAsCache(inputPath: String): String = saveAsCache(FileUtil.loadFile(inputPath)) def saveAsCache(vwDataset: Iterator[String], vwCacheFilename: String): String = { // Write VW dataset to a temporary file @@ -61,7 +61,7 @@ trait SparkVwDatasetFunctions extends VwDatasetFunctions with SparkFileFunctions */ trait VwCrossValidation extends VwDatasetFunctions { def kFold(inputPath: String, folds: Int): Map[Int, (String, String)] = { - val enumeratedVwInput = Source.fromInputStream(FileUtil.loadFile(inputPath)).getLines() + val enumeratedVwInput = FileUtil.loadFile(inputPath) kFold(enumeratedVwInput, folds) }