-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
1 parent
95daa98
commit 24247ec
Showing
27 changed files
with
421 additions
and
213 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 0 additions & 150 deletions
150
core/src/main/scala/com/eharmony/spotz/optimizer/HyperParameter.scala
This file was deleted.
Oops, something went wrong.
96 changes: 96 additions & 0 deletions
96
core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package com.eharmony.spotz.optimizer.hyperparam | ||
|
||
import scala.util.Random | ||
|
||
|
||
trait CombinatoricRandomSampler[T] extends RandomSampler[Iterable[Iterable[T]]] | ||
trait IterableRandomSampler[T] extends RandomSampler[Iterable[T]] | ||
|
||
/** | ||
* | ||
* @param iterable | ||
* @param k | ||
* @param x | ||
* @param replacement | ||
* @tparam T | ||
*/ | ||
abstract class AbstractCombinations[T]( | ||
iterable: Iterable[T], | ||
k: Int, | ||
x: Int = 1, | ||
replacement: Boolean = false) extends Serializable { | ||
|
||
import org.paukov.combinatorics3.Generator | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
private val values = iterable.toSeq | ||
|
||
assert(k > 0, "k must be greater than 0") | ||
assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}") | ||
|
||
// TODO: This is hideous! Rewrite this to be more memory efficient by unranking combinations. For now, use a Java lib. | ||
val combinations = Generator.combination(iterable.asJavaCollection).simple(k).asScala.toIndexedSeq.map(l => l.asScala.toIndexedSeq) | ||
|
||
/** | ||
* | ||
* @param rng | ||
* @return | ||
*/ | ||
def combos(rng: Random): Iterable[Iterable[T]] = { | ||
if (replacement) { | ||
Seq.fill(x)(combinations(rng.nextInt(combinations.size))) | ||
} else { | ||
val indices = collection.mutable.Set[Int]() | ||
val numElements = scala.math.min(x, combinations.size) | ||
val ret = new collection.mutable.ArrayBuffer[Iterable[T]](numElements) | ||
while (indices.size < numElements) { | ||
val index = rng.nextInt(combinations.size) | ||
if (!indices.contains(index)) { | ||
indices.add(index) | ||
ret += combinations(index) | ||
} | ||
} | ||
ret.toIndexedSeq | ||
} | ||
} | ||
} | ||
|
||
|
||
/** | ||
* Sample a single combination of K unordered items from the iterable of length N. | ||
* | ||
* @param iterable | ||
* @param k | ||
* @param replacement | ||
* @tparam T | ||
*/ | ||
case class Combination[T]( | ||
iterable: Iterable[T], | ||
k: Int, | ||
replacement: Boolean = false) | ||
extends AbstractCombinations[T](iterable, k, 1, replacement) with IterableRandomSampler[T] { | ||
|
||
override def apply(rng: Random): Iterable[T] = combos(rng).head | ||
} | ||
|
||
|
||
/** | ||
* Binomial coefficient implementation. Pick K unordered items from an Iterable of N items. | ||
* Also known as N Choose K, where N is the size of an Iterable and K is the desired number | ||
* of items to be chosen. This implementation will actually compute all the possible choices | ||
* and return them as an Iterable. | ||
* | ||
* @param iterable an iterable of finite length | ||
* @param k the number of items to choose | ||
* @tparam T | ||
*/ | ||
case class Combinations[T]( | ||
iterable: Iterable[T], | ||
k: Int, | ||
x: Int = 1, | ||
replacement: Boolean = false) | ||
extends AbstractCombinations[T](iterable, k, x, replacement) with CombinatoricRandomSampler[T] { | ||
|
||
override def apply(rng: Random): Iterable[Iterable[T]] = combos(rng) | ||
} |
21 changes: 21 additions & 0 deletions
21
core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/NormalDistribution.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import com.eharmony.spotz.optimizer.hyperparam.RandomSampler | ||
|
||
import scala.util.Random | ||
|
||
/** | ||
* Sample from a normal distribution given the mean and standard deviation | ||
* | ||
* {{{ | ||
* val hyperParamSpace = Map( | ||
* ("x1", NormalDistribution(0, 0.1)) | ||
* ) | ||
* }}} | ||
* | ||
* @param mean mean | ||
* @param std standard deviation | ||
*/ | ||
case class NormalDistribution(mean: Double, std: Double) extends RandomSampler[Double] { | ||
override def apply(rng: Random): Double = { | ||
std * rng.nextGaussian() + mean | ||
} | ||
} |
24 changes: 24 additions & 0 deletions
24
core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/RandomChoice.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
package com.eharmony.spotz.optimizer.hyperparam | ||
|
||
import scala.util.Random | ||
|
||
/** | ||
* Sample an element from an Iterable of fixed length with uniform random distribution. | ||
* | ||
* {{{ | ||
* val hyperParamSpace = Map( | ||
* ("x1", RandomChoice(Seq("svm", "logistic"))) | ||
* ) | ||
* }}} | ||
* | ||
* @param iterable an iterable of type T | ||
* @tparam T type parameter of iterable | ||
*/ | ||
case class RandomChoice[T](iterable: Iterable[T]) extends RandomSampler[T] { | ||
private val values = iterable.toIndexedSeq | ||
|
||
if (values.length < 1) | ||
throw new IllegalArgumentException("Empty iterable") | ||
|
||
override def apply(rng: Random): T = values(rng.nextInt(values.length)) | ||
} |
Oops, something went wrong.