diff --git a/core/pom.xml b/core/pom.xml
index c7e2385..a728649 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -43,6 +43,10 @@
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
+        <dependency>
+            <groupId>com.github.dpaukov</groupId>
+            <artifactId>combinatoricslib3</artifactId>
+        </dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
diff --git a/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala b/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala
index cfdb013..30aa999 100644
--- a/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala
+++ b/core/src/main/scala/com/eharmony/spotz/backend/BackendFunctions.scala
@@ -1,8 +1,8 @@
package com.eharmony.spotz.backend
import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.optimizer.RandomSampler
import com.eharmony.spotz.optimizer.grid.Grid
+import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
import scala.reflect.ClassTag
@@ -15,13 +15,13 @@ import scala.reflect.ClassTag
*/
trait BackendFunctions {
protected def bestRandomPointAndLoss[P, L](
- startIndex: Long,
- batchSize: Long,
- objective: Objective[P, L],
- reducer: ((P, L), (P, L)) => (P, L),
- hyperParams: Map[String, RandomSampler[_]],
- seed: Long = 0,
- sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L)
+ startIndex: Long,
+ batchSize: Long,
+ objective: Objective[P, L],
+ reducer: ((P, L), (P, L)) => (P, L),
+ hyperParams: Map[String, RandomSampler[_]],
+ seed: Long = 0,
+ sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L)
protected def bestGridPointAndLoss[P, L](
startIndex: Long,
diff --git a/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala b/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala
index 1498b54..98853c3 100644
--- a/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala
+++ b/core/src/main/scala/com/eharmony/spotz/backend/ParallelFunctions.scala
@@ -1,8 +1,8 @@
package com.eharmony.spotz.backend
import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.optimizer.RandomSampler
import com.eharmony.spotz.optimizer.grid.Grid
+import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
import scala.reflect.ClassTag
@@ -31,13 +31,13 @@ trait ParallelFunctions extends BackendFunctions {
* @return the best point with the best loss as a tuple
*/
protected override def bestRandomPointAndLoss[P, L](
- startIndex: Long,
- batchSize: Long,
- objective: Objective[P, L],
- reducer: ((P, L), (P, L)) => (P, L),
- hyperParams: Map[String, RandomSampler[_]],
- seed: Long = 0,
- sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = {
+ startIndex: Long,
+ batchSize: Long,
+ objective: Objective[P, L],
+ reducer: ((P, L), (P, L)) => (P, L),
+ hyperParams: Map[String, RandomSampler[_]],
+ seed: Long = 0,
+ sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = {
val pointsAndLosses = (startIndex until (startIndex + batchSize)).par.map { trial =>
val point = sampleFunction(hyperParams, seed + trial)
diff --git a/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala b/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala
index 446c289..d1e78a5 100644
--- a/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala
+++ b/core/src/main/scala/com/eharmony/spotz/backend/SparkFunctions.scala
@@ -1,8 +1,8 @@
package com.eharmony.spotz.backend
import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.optimizer.RandomSampler
import com.eharmony.spotz.optimizer.grid.Grid
+import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
import org.apache.spark.SparkContext
import scala.reflect.ClassTag
@@ -33,13 +33,13 @@ trait SparkFunctions extends BackendFunctions {
* @return the best point with the best loss as a tuple
*/
protected override def bestRandomPointAndLoss[P, L](
- startIndex: Long,
- batchSize: Long,
- objective: Objective[P, L],
- reducer: ((P, L), (P, L)) => (P, L),
- hyperParams: Map[String, RandomSampler[_]],
- seed: Long = 0,
- sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = {
+ startIndex: Long,
+ batchSize: Long,
+ objective: Objective[P, L],
+ reducer: ((P, L), (P, L)) => (P, L),
+ hyperParams: Map[String, RandomSampler[_]],
+ seed: Long = 0,
+ sampleFunction: (Map[String, RandomSampler[_]], Long) => P): (P, L) = {
assert(batchSize > 0, "batchSize must be greater than 0")
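Both backends drive random search the same way: each trial in the batch samples a point from a generator seeded with `seed + trial`, so results are reproducible regardless of how trials are partitioned across threads or Spark workers. A minimal sketch of a sample function honouring that contract, assuming a plain Map-based point type (the actual point construction in spotz differs):

    import scala.util.Random
    import com.eharmony.spotz.optimizer.hyperparam.RandomSampler

    // Illustrative only: draw one value per hyperparameter from its sampler.
    val sampleFunction: (Map[String, RandomSampler[_]], Long) => Map[String, Any] =
      (space, seed) => {
        val rng = new Random(seed)  // the caller passes seed + trial
        space.map { case (label, sampler) => label -> sampler(rng) }
      }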
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala
deleted file mode 100644
index 8f44305..0000000
--- a/core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-package com.eharmony.spotz.optimizer
-
-import scala.util.Random
-
-/**
- * @author vsuthichai
- */
-trait SamplerFunction[T] extends Serializable
-
-/**
- * A sampler function dependent on a pseudo random number generator. The generator
- * is passed in as a parameter, and this is intentional. It allows the user to
- * change seeds and switch generators. This becomes important when sampling on
- * Spark workers and more control over the rng is necessary.
- *
- * @tparam T
- */
-abstract class RandomSampler[T] extends SamplerFunction[T] {
- def apply(rng: Random): T
-}
-
-abstract class Uniform[T](lb: T, ub: T) extends RandomSampler[T]
-
-/**
- * Sample a Double within the bounds lb <= x < ub with uniform random distribution
- *
- * {{{
- * val rng = new Random(seed)
- * val sampler = UniformDouble(0, 1))
- * val sample = sampler(rng)
- * }}}
- *
- * @param lb lower bound
- * @param ub upper bound
- */
-case class UniformDouble(lb: Double, ub: Double) extends Uniform[Double](lb, ub) {
- if (lb >= ub)
- throw new IllegalArgumentException("lb must be less than ub")
-
- override def apply(rng: Random): Double = lb + ((ub - lb) * rng.nextDouble)
-}
-
-/**
- * Sample an Int within the bounds lb <= x < ub with uniform random distribution
- *
- * {{{
- * val hyperParamSpace = Map(
- * ("x1", UniformInt(0, 10))
- * )
- * }}}
- *
- * @param lb lower bound
- * @param ub upper bound
- */
-case class UniformInt(lb: Int, ub: Int) extends Uniform[Int](lb, ub) {
- if (lb >= ub)
- throw new IllegalArgumentException("lb must be less than ub")
-
- override def apply(rng: Random): Int = lb + rng.nextInt(ub - lb)
-}
-
-/**
- * Sample from a normal distribution given the mean and standard deviation
- *
- * {{{
- * val hyperParamSpace = Map(
- * ("x1", NormalDistribution(0, 0.1))
- * )
- * }}}
- *
- * @param mean mean
- * @param std standard deviation
- */
-case class NormalDistribution(mean: Double, std: Double) extends RandomSampler[Double] {
- override def apply(rng: Random): Double = {
- std * rng.nextGaussian() + mean
- }
-}
-
-/**
- * Given an iterable of RandomSampler functions, choose a function at random and
- * sample from it.
- *
- * {{{
- * val hyperParamSpace = Map(
- * ("x1", Union(UniformDouble(0, 1), UniformDouble(10, 11)))
- * )
- * }}}
- *
- * @param iterable an iterable of RandomSampler[T] functions
- * @param probs an iterable of probabilities that should sum to 1. This specifies the probabilities that
- * the sampler functions are chosen. If the length of this is not the same as the length
- * of the iterable of RandomSampler[T] functions, then an IllegalArgumentException is thrown.
- * Not specifying an iterable of probabilities will force a default uniform random sampling.
- * If the length of the iterable of probabilities is equal to the length of the iterable of
- * RandomSampler[T] functions, then
- * @tparam T type parameter of the sample
- */
-case class Union[T](iterable: Iterable[RandomSampler[T]], probs: Iterable[Double] = Seq()) extends RandomSampler[T] {
- private val indexedSeq = iterable.toIndexedSeq
-
- private val probabilities =
- if (probs.isEmpty)
- Seq.fill(indexedSeq.length)(1.toDouble / indexedSeq.length)
- else if (probs.toIndexedSeq.length != indexedSeq.length)
- throw new IllegalArgumentException("iterable lengths must match")
- else if (probs.exists(p => p <= 0))
- throw new IllegalArgumentException("Must be positive or valid probabilities")
- else if (probs.sum != 1.0)
- probs.map(p => p / probs.sum)
-
- override def apply(rng: Random): T = ???
- private def bucket(probability: Double): Int = ???
-}
-
-/**
- * Sample an element from an Iterable of fixed length with uniform random distribution.
- *
- * {{{
- * val hyperParamSpace = Map(
- * ("x1", RandomChoice(Seq("svm", "logistic")))
- * )
- * }}}
- *
- * @param iterable an iterable of type T
- * @tparam T type parameter of iterable
- */
-case class RandomChoice[T](iterable: Iterable[T]) extends RandomSampler[T] {
- private val values = iterable.toIndexedSeq
-
- if (values.length < 1)
- throw new IllegalArgumentException("Empty iterable")
-
- override def apply(rng: Random): T = values(rng.nextInt(values.length))
-}
-
-/**
- * N Choose K, where N is the size of an Iterable.
- *
- * @param iterable
- * @param k
- * @tparam T
- */
-case class BinomialCoefficient[T](iterable: Iterable[T], k: Int) extends RandomSampler[Iterable[T]] {
- private val values = iterable.toSeq
-
- override def apply(rng: Random): Iterable[T] = {
- rng.shuffle(values).take(k)
- }
-}
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala
new file mode 100644
index 0000000..f0f9f76
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala
@@ -0,0 +1,96 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.util.Random
+
+
+trait CombinatoricRandomSampler[T] extends RandomSampler[Iterable[Iterable[T]]]
+trait IterableRandomSampler[T] extends RandomSampler[Iterable[T]]
+
+/**
+ * Base class for sampling combinations of k unordered items from an iterable of length N.
+ *
+ * @param iterable an iterable of finite length to draw items from
+ * @param k the number of items per combination
+ * @param x the number of combinations to sample on each draw
+ * @param replacement whether the same combination may be drawn more than once
+ * @tparam T type of the items
+ */
+abstract class AbstractCombinations[T](
+ iterable: Iterable[T],
+ k: Int,
+ x: Int = 1,
+ replacement: Boolean = false) extends Serializable {
+
+ import org.paukov.combinatorics3.Generator
+
+ import scala.collection.JavaConverters._
+
+ private val values = iterable.toSeq
+
+ assert(k > 0, "k must be greater than 0")
+ assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}")
+
+ // TODO: This is hideous! Rewrite this to be more memory efficient by unranking combinations. For now, use a Java lib.
+ val combinations = Generator.combination(iterable.asJavaCollection).simple(k).asScala.toIndexedSeq.map(l => l.asScala.toIndexedSeq)
+
+ /**
+ * Draw the configured number of combinations using the supplied random number generator.
+ *
+ * @param rng pseudo random number generator
+ * @return an Iterable of the sampled combinations
+ */
+ def combos(rng: Random): Iterable[Iterable[T]] = {
+ if (replacement) {
+ Seq.fill(x)(combinations(rng.nextInt(combinations.size)))
+ } else {
+ val indices = collection.mutable.Set[Int]()
+ val numElements = scala.math.min(x, combinations.size)
+ val ret = new collection.mutable.ArrayBuffer[Iterable[T]](numElements)
+ while (indices.size < numElements) {
+ val index = rng.nextInt(combinations.size)
+ if (!indices.contains(index)) {
+ indices.add(index)
+ ret += combinations(index)
+ }
+ }
+ ret.toIndexedSeq
+ }
+ }
+}
+
+
+/**
+ * Sample a single combination of K unordered items from the iterable of length N.
+ *
+ * @param iterable an iterable of finite length to draw items from
+ * @param k the number of items to choose
+ * @param replacement whether to sample with replacement (has no practical effect for a single draw)
+ * @tparam T type of the items
+ */
+case class Combination[T](
+ iterable: Iterable[T],
+ k: Int,
+ replacement: Boolean = false)
+ extends AbstractCombinations[T](iterable, k, 1, replacement) with IterableRandomSampler[T] {
+
+ override def apply(rng: Random): Iterable[T] = combos(rng).head
+}
+
+
+/**
+ * Binomial coefficient implementation. Pick K unordered items from an Iterable of N items.
+ * Also known as N Choose K, where N is the size of an Iterable and K is the desired number
+ * of items to be chosen. This implementation eagerly computes all possible combinations up
+ * front (see the TODO in AbstractCombinations) and samples x of them on each draw.
+ *
+ * @param iterable an iterable of finite length
+ * @param k the number of items to choose per combination
+ * @param x the number of combinations to sample on each draw
+ * @param replacement whether the same combination may be drawn more than once
+ * @tparam T type of the items
+ */
+case class Combinations[T](
+ iterable: Iterable[T],
+ k: Int,
+ x: Int = 1,
+ replacement: Boolean = false)
+ extends AbstractCombinations[T](iterable, k, x, replacement) with CombinatoricRandomSampler[T] {
+
+ override def apply(rng: Random): Iterable[Iterable[T]] = combos(rng)
+}
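A short sketch of how these samplers behave, assuming a seeded scala.util.Random (the drawn values depend on the seed):

    import scala.util.Random
    import com.eharmony.spotz.optimizer.hyperparam.{Combination, Combinations}

    val rng = new Random(42)

    // One unordered pair out of the 6 possible 2-element combinations of 4 items.
    val single = Combination(Seq("a", "b", "c", "d"), k = 2)
    val pair: Iterable[String] = single(rng)

    // Two distinct pairs per draw (x = 2, replacement = false by default).
    val multi = Combinations(Seq("a", "b", "c", "d"), k = 2, x = 2)
    val pairs: Iterable[Iterable[String]] = multi(rng)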
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala
new file mode 100644
index 0000000..595b7e0
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala
@@ -0,0 +1,21 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.util.Random
+
+/**
+ * Sample from a normal distribution given the mean and standard deviation
+ *
+ * {{{
+ * val hyperParamSpace = Map(
+ * ("x1", NormalDistribution(0, 0.1))
+ * )
+ * }}}
+ *
+ * @param mean mean
+ * @param std standard deviation
+ */
+case class NormalDistribution(mean: Double, std: Double) extends RandomSampler[Double] {
+ override def apply(rng: Random): Double = {
+ std * rng.nextGaussian() + mean
+ }
+}
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala
new file mode 100644
index 0000000..7f73fc3
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala
@@ -0,0 +1,24 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.util.Random
+
+/**
+ * Sample an element from an Iterable of fixed length with uniform random distribution.
+ *
+ * {{{
+ * val hyperParamSpace = Map(
+ * ("x1", RandomChoice(Seq("svm", "logistic")))
+ * )
+ * }}}
+ *
+ * @param iterable an iterable of type T
+ * @tparam T type parameter of iterable
+ */
+case class RandomChoice[T](iterable: Iterable[T]) extends RandomSampler[T] {
+ private val values = iterable.toIndexedSeq
+
+ if (values.length < 1)
+ throw new IllegalArgumentException("Empty iterable")
+
+ override def apply(rng: Random): T = values(rng.nextInt(values.length))
+}
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomSampler.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomSampler.scala
new file mode 100644
index 0000000..b9f38c0
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomSampler.scala
@@ -0,0 +1,15 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.util.Random
+
+/**
+ * A sampler function dependent on a pseudo random number generator. The generator
+ * is passed in as a parameter, and this is intentional. It allows the user to
+ * change seeds and switch generators. This becomes important when sampling on
+ * Spark workers and more control over the rng is necessary.
+ *
+ * @tparam T
+ */
+trait RandomSampler[T] extends Serializable {
+ def apply(rng: Random): T
+}
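Any hyperparameter sampler is now just an implementation of this trait. As an illustration, a hypothetical log-uniform sampler (not part of this change) would look like:

    import scala.util.Random
    import com.eharmony.spotz.optimizer.hyperparam.RandomSampler

    // Hypothetical: sample lb <= x < ub uniformly in log space.
    case class LogUniformDouble(lb: Double, ub: Double) extends RandomSampler[Double] {
      require(lb > 0 && lb < ub, "require 0 < lb < ub")
      override def apply(rng: Random): Double =
        math.exp(math.log(lb) + (math.log(ub) - math.log(lb)) * rng.nextDouble)
    }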
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala
new file mode 100644
index 0000000..dc5435c
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala
@@ -0,0 +1,44 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.collection.mutable
+import scala.util.Random
+
+case class Subset[T](iterable: Iterable[T], k: Int)
+
+
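+/**
+ * Sample x subsets of an iterable. Each draw returns between 1 and x subsets, each containing
+ * between 1 and k distinct elements kept in sorted order; with replacement = false the returned
+ * subsets are also distinct from one another.
+ *
+ * @param iterable an iterable of finite length to draw items from
+ * @param k the maximum subset size
+ * @param x the maximum number of subsets per draw
+ * @param replacement whether the same subset may appear more than once in a draw
+ * @tparam T type of the items, which must have an Ordering
+ */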
+case class Subsets[T](iterable: Iterable[T], k: Int, x: Int, replacement: Boolean = false)(implicit ord: Ordering[T]) extends CombinatoricRandomSampler[T] {
+ private val values = iterable.toIndexedSeq
+
+ def sample(rng: Random): Iterable[T] = {
+ val sampleSize = rng.nextInt(k) + 1
+ val subset = mutable.SortedSet[T]()
+ val indices = mutable.Set[Int]()
+
+ while (subset.size < sampleSize) {
+ val index = rng.nextInt(values.size)
+ if (replacement) {
+ subset.add(values(index))
+ } else if (!indices.contains(index)) {
+ indices.add(index)
+ subset.add(values(index))
+ }
+ }
+ subset.toIndexedSeq
+ }
+
+ def apply(rng: Random): Iterable[Iterable[T]] = {
+ val numSubsets = rng.nextInt(x) + 1
+
+ if (replacement) {
+ Seq.fill(numSubsets)(sample(rng))
+ } else {
+ val subsets = mutable.Set[Iterable[T]]()
+ while (subsets.size < numSubsets) {
+ subsets.add(sample(rng))
+ }
+ subsets.toIndexedSeq
+ }
+ }
+}
+
+
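A sketch of the sampling behaviour under an assumed seed (Ordering[Int] is supplied implicitly):

    import scala.util.Random
    import com.eharmony.spotz.optimizer.hyperparam.Subsets

    val rng = new Random(7)
    val sampler = Subsets(1 to 5, k = 2, x = 3)          // 1-3 subsets of size 1-2 per draw
    val subsets: Iterable[Iterable[Int]] = sampler(rng)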
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Uniform.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Uniform.scala
new file mode 100644
index 0000000..b0de664
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Uniform.scala
@@ -0,0 +1,46 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.util.Random
+
+/**
+ * @author vsuthichai
+ */
+abstract class Uniform[T](lb: T, ub: T) extends RandomSampler[T]
+
+/**
+ * Sample a Double within the bounds lb <= x < ub with uniform random distribution
+ *
+ * {{{
+ * val rng = new Random(seed)
+ * val sampler = UniformDouble(0, 1)
+ * val sample = sampler(rng)
+ * }}}
+ *
+ * @param lb lower bound
+ * @param ub upper bound
+ */
+case class UniformDouble(lb: Double, ub: Double) extends Uniform[Double](lb, ub) {
+ if (lb >= ub)
+ throw new IllegalArgumentException("lb must be less than ub")
+
+ override def apply(rng: Random): Double = lb + ((ub - lb) * rng.nextDouble)
+}
+
+/**
+ * Sample an Int within the bounds lb <= x < ub with uniform random distribution
+ *
+ * {{{
+ * val hyperParamSpace = Map(
+ * ("x1", UniformInt(0, 10))
+ * )
+ * }}}
+ *
+ * @param lb lower bound
+ * @param ub upper bound
+ */
+case class UniformInt(lb: Int, ub: Int) extends Uniform[Int](lb, ub) {
+ if (lb >= ub)
+ throw new IllegalArgumentException("lb must be less than ub")
+
+ override def apply(rng: Random): Int = lb + rng.nextInt(ub - lb)
+}
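For completeness, a sketch of drawing from these samplers with an explicit, seeded generator (names are illustrative):

    import scala.util.Random
    import com.eharmony.spotz.optimizer.hyperparam.{UniformDouble, UniformInt}

    val rng = new Random(0)
    val learningRate = UniformDouble(0.0, 1.0)(rng)   // Double in [0.0, 1.0)
    val passes       = UniformInt(1, 10)(rng)         // Int in [1, 10)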
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Union.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Union.scala
new file mode 100644
index 0000000..7029ed2
--- /dev/null
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Union.scala
@@ -0,0 +1,39 @@
+package com.eharmony.spotz.optimizer.hyperparam
+
+import scala.util.Random
+
+/**
+ * Given an iterable of RandomSampler functions, choose a function at random and
+ * sample from it.
+ *
+ * {{{
+ * val hyperParamSpace = Map(
+ * ("x1", Union(UniformDouble(0, 1), UniformDouble(10, 11)))
+ * )
+ * }}}
+ *
+ * @param iterable an iterable of RandomSampler[T] functions
+ * @param probs an iterable of probabilities used to weight how often each sampler function is
+ *              chosen. If its length differs from the length of the iterable of RandomSampler[T]
+ *              functions, an IllegalArgumentException is thrown. Omitting it results in a uniform
+ *              random choice among the sampler functions, and probabilities that do not sum to 1
+ *              are normalized.
+ * @tparam T type parameter of the sample
+ */
+case class Union[T](iterable: Iterable[RandomSampler[T]], probs: Iterable[Double] = Seq()) extends RandomSampler[T] {
+ private val indexedSeq = iterable.toIndexedSeq
+
+  private val probabilities =
+    if (probs.isEmpty)
+      Seq.fill(indexedSeq.length)(1.toDouble / indexedSeq.length)
+    else if (probs.toIndexedSeq.length != indexedSeq.length)
+      throw new IllegalArgumentException("iterable lengths must match")
+    else if (probs.exists(p => p <= 0))
+      throw new IllegalArgumentException("Must be positive or valid probabilities")
+    else if (probs.sum != 1.0)
+      probs.map(p => p / probs.sum).toSeq
+    else
+      probs.toSeq
+
+ override def apply(rng: Random): T = ???
+ private def bucket(probability: Double): Int = ???
+}
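Union.apply and bucket are still left unimplemented (???) in this change, so the class can be constructed but not yet sampled from. The intended construction, per the documentation above, is a sketch like:

    import com.eharmony.spotz.optimizer.hyperparam.{Union, UniformDouble}

    // Choose the first sampler ~80% of the time and the second ~20% of the time.
    val mixed = Union(Seq(UniformDouble(0, 1), UniformDouble(10, 11)), probs = Seq(0.8, 0.2))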
diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala
index 3dd4fcb..ca2e90b 100644
--- a/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala
+++ b/core/src/main/scala/com/eharmony/spotz/optimizer/random/RandomSearch.scala
@@ -3,6 +3,7 @@ package com.eharmony.spotz.optimizer.random
import com.eharmony.spotz.backend.{BackendFunctions, ParallelFunctions, SparkFunctions}
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.optimizer._
+import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
import com.eharmony.spotz.util.{DurationUtils, Logging}
import org.apache.spark.SparkContext
import org.joda.time.{DateTime, Duration}
diff --git a/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala b/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala
index d6902d9..9617714 100644
--- a/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala
+++ b/core/src/main/scala/com/eharmony/spotz/util/FileFunctions.scala
@@ -4,8 +4,6 @@ import java.io.{File, PrintWriter}
import org.apache.spark.{SparkContext, SparkFiles}
-import scala.io.Source
-
/**
* Provide capability to save and retrieve files from inside the objective
* functions. Users are free to interact with the underlying file system freely as they desire,
@@ -16,7 +14,7 @@ import scala.io.Source
* get inside the apply method.
*/
trait FileFunctions {
- def save(inputPath: String): String = save(Source.fromInputStream(FileUtil.loadFile(inputPath)).getLines())
+ def save(inputPath: String): String = save(FileUtil.loadFile(inputPath))
def save(inputIterable: Iterable[String]): String = save(inputIterable.toIterator)
diff --git a/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala b/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala
index bed7ab0..6477851 100644
--- a/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala
+++ b/core/src/main/scala/com/eharmony/spotz/util/FileUtil.scala
@@ -3,13 +3,15 @@ package com.eharmony.spotz.util
import java.io.{File, InputStream}
import org.apache.commons.io.FilenameUtils
-import org.apache.commons.vfs2.{FileSystemException, VFS}
+import org.apache.commons.vfs2.{FileNotFoundException, VFS}
+
+import scala.io.Source
/**
* @author vsuthichai
*/
object FileUtil {
- val vfs2 = VFS.getManager
+ private val vfs2 = VFS.getManager
/**
* Return a file with a filename guaranteed not to be used on the file system. This is
@@ -27,12 +29,58 @@ object FileUtil {
f
}
+ /**
+ * Create a temp file whose base name and extension are taken from the given filename.
+ *
+ * @param filename filename whose base name and extension are reused
+ * @param deleteOnExit whether to delete the temp file when the JVM exits
+ * @return a File with a name guaranteed not to already exist on the file system
+ */
def tempFile(filename: String, deleteOnExit: Boolean = true): File = {
tempFile(FilenameUtils.getBaseName(filename), FilenameUtils.getExtension(filename), deleteOnExit)
}
- def loadFile(path: String): InputStream = {
+ /**
+ * Load the lines of a file as an iterator.
+ *
+ * @param path input path
+ * @return lines of the file as an Iterator[String]
+ */
+ def loadFile(path: String): Iterator[String] = {
+ Source.fromInputStream(loadFileInputStream(path)).getLines()
+ }
+
+ /**
+ * Open an InputStream for the given path, resolved through commons-vfs2.
+ *
+ * @param path input path
+ * @return an InputStream over the file contents
+ */
+ def loadFileInputStream(path: String): InputStream = {
val vfsFile = vfs2.resolveFile(path)
vfsFile.getContent.getInputStream
}
}
+
+object SparkFileUtil {
+ import org.apache.spark.SparkContext
+ import org.apache.hadoop.mapred.InvalidInputException
+
+ /**
+ * Load the lines of a file as an iterator. Also attempt to load the file from HDFS
+ * since the SparkContext is available.
+ *
+ * @param path input path
+ * @return lines of the file as an Iterator[String]
+ */
+ def loadFile(sc: SparkContext, path: String): Iterator[String] = {
+ try {
+ FileUtil.loadFile(path)
+ } catch {
+ case e: FileNotFoundException =>
+ try {
+ sc.textFile(path).toLocalIterator
+ } catch {
+ case e: InvalidInputException => Source.fromInputStream(this.getClass.getResourceAsStream(path)).getLines()
+ }
+ }
+ }
+}
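A sketch of the intended call sites (paths are placeholders): FileUtil.loadFile resolves local and commons-vfs2 paths, while SparkFileUtil.loadFile additionally falls back to HDFS through the SparkContext and finally to the classpath.

    import com.eharmony.spotz.util.{FileUtil, SparkFileUtil}
    import org.apache.spark.{SparkConf, SparkContext}

    val lines: Iterator[String] = FileUtil.loadFile("data/train.vw")

    val sc = new SparkContext(new SparkConf().setAppName("spotz").setMaster("local[*]"))
    val hdfsLines: Iterator[String] = SparkFileUtil.loadFile(sc, "hdfs:///data/train.vw")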
diff --git a/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala b/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala
index 11e7115..18bc04e 100644
--- a/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala
+++ b/core/src/test/scala/com/eharmony/spotz/optimizer/RandomChoiceTest.scala
@@ -1,5 +1,6 @@
package com.eharmony.spotz.optimizer
+import com.eharmony.spotz.optimizer.hyperparam.RandomChoice
import org.junit.Assert._
import org.junit.Test
diff --git a/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala b/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala
index f96d50d..3dbc8f0 100644
--- a/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala
+++ b/core/src/test/scala/com/eharmony/spotz/optimizer/UniformRandomTest.scala
@@ -1,5 +1,6 @@
package com.eharmony.spotz.optimizer
+import com.eharmony.spotz.optimizer.hyperparam.UniformDouble
import org.junit.Assert._
import org.junit.Test
diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala b/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala
index f72201c..6f0c5c1 100644
--- a/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala
+++ b/examples/src/main/scala/com/eharmony/spotz/examples/AckleyExample.scala
@@ -2,7 +2,8 @@ package com.eharmony.spotz.examples
import com.eharmony.spotz.Preamble.Point
import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy, UniformDouble}
+import com.eharmony.spotz.optimizer.hyperparam.UniformDouble
+import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy}
import scala.math._
diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala b/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala
index d9ebac3..718d6cb 100644
--- a/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala
+++ b/examples/src/main/scala/com/eharmony/spotz/examples/BraninExample.scala
@@ -2,7 +2,8 @@ package com.eharmony.spotz.examples
import com.eharmony.spotz.Preamble.Point
import com.eharmony.spotz.objective.Objective
-import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy, UniformDouble}
+import com.eharmony.spotz.optimizer.hyperparam.UniformDouble
+import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy}
import scala.math._
diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala b/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala
index 14a54b8..4f83588 100644
--- a/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala
+++ b/examples/src/main/scala/com/eharmony/spotz/examples/ExampleRunner.scala
@@ -4,7 +4,8 @@ import com.eharmony.spotz.Preamble.Point
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.optimizer.grid.{GridSearchResult, ParGridSearch, SparkGridSearch}
import com.eharmony.spotz.optimizer.random.{ParRandomSearch, RandomSearchResult, SparkRandomSearch}
-import com.eharmony.spotz.optimizer.{RandomSampler, StopStrategy}
+import com.eharmony.spotz.optimizer.StopStrategy
+import com.eharmony.spotz.optimizer.hyperparam.RandomSampler
import org.apache.spark.{SparkConf, SparkContext}
/**
diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala
index 0161e18..4ee945f 100644
--- a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala
+++ b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwCrossValidation.scala
@@ -2,7 +2,8 @@ package com.eharmony.spotz.examples.vw
import com.eharmony.spotz.examples._
import com.eharmony.spotz.objective.vw.{AbstractVwCrossValidationObjective, SparkVwCrossValidationObjective, VwCrossValidationObjective}
-import com.eharmony.spotz.optimizer.{RandomSampler, StopStrategy, UniformDouble}
+import com.eharmony.spotz.optimizer.hyperparam.{RandomSampler, UniformDouble}
+import com.eharmony.spotz.optimizer.StopStrategy
import org.apache.spark.SparkContext
/**
diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala
index f629654..2646e31 100644
--- a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala
+++ b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala
@@ -2,7 +2,8 @@ package com.eharmony.spotz.examples.vw
import com.eharmony.spotz.examples._
import com.eharmony.spotz.objective.vw.{AbstractVwHoldoutObjective, SparkVwHoldoutObjective, VwHoldoutObjective}
-import com.eharmony.spotz.optimizer.{RandomSampler, StopStrategy, UniformDouble}
+import com.eharmony.spotz.optimizer._
+import com.eharmony.spotz.optimizer.hyperparam.{Combinations, RandomSampler, UniformDouble}
import org.apache.spark.SparkContext
/**
@@ -39,7 +40,8 @@ trait VwHoldout extends AbstractVwHoldout {
trait VwHoldoutRandomSearch extends VwHoldout {
val space = Map(
- ("l", UniformDouble(0, 1))
+ ("l", UniformDouble(0, 1)),
+ ("q", Combinations(List("a", "b", "c", "d"), k = 2, x = 2))
)
def main(args: Array[String]) {
diff --git a/pom.xml b/pom.xml
index 7c9c232..ed909e4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -72,6 +72,7 @@
1.8.1
2.1
1.7.21
+        <combinatoricslib3.version>3.0.0</combinatoricslib3.version>
1.1.7
4.11
1.3
@@ -151,6 +152,11 @@
<artifactId>commons-vfs2</artifactId>
<version>${vfs2.version}</version>
+        <dependency>
+            <groupId>com.github.dpaukov</groupId>
+            <artifactId>combinatoricslib3</artifactId>
+            <version>${combinatoricslib3.version}</version>
+        </dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala
index 371111d..e3cff0f 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwCrossValidationObjective.scala
@@ -3,13 +3,16 @@ package com.eharmony.spotz.objective.vw
import com.eharmony.spotz.Preamble.Point
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.objective.vw.util.{FSVwDatasetFunctions, SparkVwDatasetFunctions, VwCrossValidation}
-import com.eharmony.spotz.util.{FileUtil, Logging}
+import com.eharmony.spotz.util.{FileUtil, Logging, SparkFileUtil}
import org.apache.spark.SparkContext
-import scala.io.Source
-
/**
- * @author vsuthichai
+ * Perform K-fold cross validation given a dataset formatted for Vowpal Wabbit.
+ *
+ * @param numFolds number of cross validation folds
+ * @param vwDataset the VW-formatted dataset as an iterator of lines
+ * @param vwTrainParamsString optional VW command line arguments used for training
+ * @param vwTestParamsString optional VW command line arguments used for testing
*/
abstract class AbstractVwCrossValidationObjective(
val numFolds: Int,
@@ -32,7 +35,7 @@ abstract class AbstractVwCrossValidationObjective(
vwDatasetPath: String,
vwTrainParamsString: Option[String],
vwTestParamsString: Option[String]) = {
- this(numFolds, Source.fromInputStream(FileUtil.loadFile(vwDatasetPath)).getLines(), vwTrainParamsString, vwTestParamsString)
+ this(numFolds, FileUtil.loadFile(vwDatasetPath), vwTrainParamsString, vwTestParamsString)
}
val vwTrainParamsMap = parseVwArgs(vwTrainParamsString)
@@ -114,7 +117,7 @@ class SparkVwCrossValidationObjective(
vwDatasetPath: String,
vwTrainParamsString: Option[String],
vwTestParamsString: Option[String]) = {
- this(sc, numFolds, Source.fromInputStream(FileUtil.loadFile(vwDatasetPath)).getLines(), vwTrainParamsString, vwTestParamsString)
+ this(sc, numFolds, SparkFileUtil.loadFile(sc, vwDatasetPath), vwTrainParamsString, vwTestParamsString)
}
}
@@ -137,6 +140,6 @@ class VwCrossValidationObjective(
vwDatasetPath: String,
vwTrainParamsString: Option[String],
vwTestParamsString: Option[String]) = {
- this(numFolds, Source.fromInputStream(FileUtil.loadFile(vwDatasetPath)).getLines(), vwTrainParamsString, vwTestParamsString)
+ this(numFolds, FileUtil.loadFile(vwDatasetPath), vwTrainParamsString, vwTestParamsString)
}
}
\ No newline at end of file
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala
index c3130e8..14db664 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwFunctions.scala
@@ -18,18 +18,25 @@ trait VwFunctions {
*
* @param vwParamMap a Map[String, String] where the key is a VW argument and the value
* is the argument value.
- * @param point a point object representing the hyperparameter values
- * @return a new Map[String, String] which is the result of merging the vwParamMap and the point.
+ * @param point a point object representing the hyperparameter values
+ * @return a new Map[String, String] which is the result of merging the vwParamMap and
+ * the point.
*/
def mergeVwParams(vwParamMap: Map[String, _], point: Point): Map[String, _] = {
- val vwParamsMutableMap = mutable.Map[String, Any]()
-
- vwParamMap.foldLeft(vwParamsMutableMap) { case (mutableMap, (k, v)) =>
- mutableMap += ((k, v))
- }
+ val vwParamsMutableMap = mutable.Map[String, Any](vwParamMap.toSeq: _*)
point.getHyperParameterLabels.foldLeft(vwParamsMutableMap) { (mutableMap, vwHyperParam) =>
- mutableMap += ((vwHyperParam, point.get(vwHyperParam)))
+ if (mutableMap.contains(vwHyperParam)) {
+ val listValues = mutableMap(vwHyperParam) match {
+ case it: Iterable[_] => it.toList ++ point.get(vwHyperParam)
+ case value => List(value, point.get(vwHyperParam))
+ }
+ mutableMap += ((vwHyperParam, listValues))
+ } else {
+ mutableMap += ((vwHyperParam, point.get(vwHyperParam)))
+ }
+
+ mutableMap
}
// Remove cache params
@@ -57,7 +64,10 @@ trait VwFunctions {
vwParamMap.foldLeft(new StringBuilder) { case (sb, (vwArg, vwValue)) =>
val dashes = if (vwArg.length == 1) "-" else "--"
val vwParam = vwValue match {
- case value: Iterable[_] => value.map(x => s"$dashes$vwArg ${x.toString}").mkString(" ")
+ case value: Iterable[_] => value.map {
+ case it: Iterable[_] => s"$dashes$vwArg ${it.mkString}"
+ case x => s"$dashes$vwArg ${x.toString}"
+ }.mkString(" ") + " "
case _ => s"$dashes$vwArg $vwValue "
}
sb ++= vwParam
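The nested Iterable case exists so that a sample drawn from Combinations (an Iterable of Iterables) renders as a repeated VW argument with each inner combination concatenated. For example, per the match above, a hypothetical merged parameter map like the one below would render as "-l 0.05 -q ab -q cd" (argument order follows map iteration order):

    // Illustrative merged map as produced by mergeVwParams:
    val merged: Map[String, Any] = Map(
      "l" -> 0.05,                               // renders as: -l 0.05
      "q" -> Seq(Seq("a", "b"), Seq("c", "d"))   // renders as: -q ab -q cd
    )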
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala
index 45f9a39..3618e7b 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/VwHoldoutObjective.scala
@@ -3,11 +3,9 @@ package com.eharmony.spotz.objective.vw
import com.eharmony.spotz.Preamble.Point
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.objective.vw.util.{FSVwDatasetFunctions, SparkVwDatasetFunctions, VwDatasetFunctions}
-import com.eharmony.spotz.util.{FileUtil, Logging}
+import com.eharmony.spotz.util.{FileUtil, Logging, SparkFileUtil}
import org.apache.spark.SparkContext
-import scala.io.Source
-
/**
* @author vsuthichai
*/
@@ -32,8 +30,7 @@ abstract class AbstractVwHoldoutObjective(
vwTrainParamsString: Option[String],
vwTestSetPath: String,
vwTestParamsString: Option[String]) = {
- this(Source.fromInputStream(FileUtil.loadFile(vwTrainSetPath)).getLines(), vwTrainParamsString,
- Source.fromInputStream(FileUtil.loadFile(vwTestSetPath)).getLines(), vwTestParamsString)
+ this(FileUtil.loadFile(vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString)
}
val vwTrainParamMap = parseVwArgs(vwTrainParamsString)
@@ -95,8 +92,7 @@ class SparkVwHoldoutObjective(
vwTrainParamsString: Option[String],
vwTestSetPath: String,
vwTestParamsString: Option[String]) = {
- this(sc, Source.fromInputStream(FileUtil.loadFile(vwTrainSetPath)).getLines(), vwTrainParamsString,
- Source.fromInputStream(FileUtil.loadFile(vwTestSetPath)).getLines(), vwTestParamsString)
+ this(sc, SparkFileUtil.loadFile(sc, vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString)
}
}
@@ -119,7 +115,6 @@ class VwHoldoutObjective(
vwTrainParamsString: Option[String],
vwTestSetPath: String,
vwTestParamsString: Option[String]) = {
- this(Source.fromInputStream(FileUtil.loadFile(vwTrainSetPath)).getLines(), vwTrainParamsString,
- Source.fromInputStream(FileUtil.loadFile(vwTestSetPath)).getLines(), vwTestParamsString)
+ this(FileUtil.loadFile(vwTrainSetPath), vwTrainParamsString, FileUtil.loadFile(vwTestSetPath), vwTestParamsString)
}
}
diff --git a/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala b/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala
index c4066bb..4e94862 100644
--- a/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala
+++ b/vw/src/main/scala/com/eharmony/spotz/objective/vw/util/VwDatasetFunctions.scala
@@ -17,7 +17,7 @@ import scala.io.Source
trait VwDatasetFunctions extends FileFunctions {
def saveAsCache(inputIterator: Iterator[String]): String = saveAsCache(inputIterator, "dataset.cache")
def saveAsCache(inputIterable: Iterable[String]): String = saveAsCache(inputIterable.toIterator)
- def saveAsCache(inputPath: String): String = saveAsCache(Source.fromInputStream(FileUtil.loadFile(inputPath)).getLines())
+ def saveAsCache(inputPath: String): String = saveAsCache(FileUtil.loadFile(inputPath))
def saveAsCache(vwDataset: Iterator[String], vwCacheFilename: String): String = {
// Write VW dataset to a temporary file
@@ -61,7 +61,7 @@ trait SparkVwDatasetFunctions extends VwDatasetFunctions with SparkFileFunctions
*/
trait VwCrossValidation extends VwDatasetFunctions {
def kFold(inputPath: String, folds: Int): Map[Int, (String, String)] = {
- val enumeratedVwInput = Source.fromInputStream(FileUtil.loadFile(inputPath)).getLines()
+ val enumeratedVwInput = FileUtil.loadFile(inputPath)
kFold(enumeratedVwInput, folds)
}