diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala index e5c5681..9fa4cad 100644 --- a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Combinations.scala @@ -17,24 +17,20 @@ abstract class AbstractCombinations[T]( protected val values = iterable.toSeq - assert(k > 0, "k must be greater than 0") - assert(k <= values.length, s"k must be less than or equal to length of the iterable, ${values.length}") - def sample(rng: Random): Iterable[T] = { if (replacement) sampleWithReplacement(rng) else sampleNoReplacement(rng) } def sampleWithReplacement(rng: Random) = { - val combo = mutable.SortedSet[T]() + val combo = new mutable.PriorityQueue[T]() while (combo.size < k) { val index = rng.nextInt(values.length) - val element = values(rng.nextInt(values.length)) - combo.add(element) + combo += values(index) } - combo.toSeq + combo.toIndexedSeq } def sampleNoReplacement(rng: Random) = { @@ -49,7 +45,7 @@ abstract class AbstractCombinations[T]( } } - combo.toSeq + combo.toIndexedSeq } } diff --git a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala index c551166..cbbc7eb 100644 --- a/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala +++ b/core/src/main/scala/com/eharmony/spotz/optimizer/hyperparam/Subsets.scala @@ -14,7 +14,19 @@ import scala.util.Random abstract class AbstractSubset[T](iterable: Iterable[T], k: Int, replacement: Boolean = false)(implicit ord: Ordering[T]) extends Serializable { protected val values = iterable.toIndexedSeq - def sample(rng: Random): Iterable[T] = { + def sampleWithReplacement(rng: Random): Iterable[T] = { + val sampleSize = rng.nextInt(k) + 1 + val subset = new mutable.PriorityQueue[T]() + + while (subset.size < k) { + val index = rng.nextInt(values.length) + subset += values(index) + } + + subset.toIndexedSeq + } + + def sampleNoReplacement(rng: Random): Iterable[T] = { val sampleSize = rng.nextInt(k) + 1 val subset = mutable.SortedSet[T]() val indices = mutable.Set[Int]() @@ -22,9 +34,7 @@ abstract class AbstractSubset[T](iterable: Iterable[T], k: Int, replacement: Boo while (subset.size < sampleSize) { val index = rng.nextInt(values.size) - if (replacement) { - subset.add(values(index)) - } else if (!indices.contains(index)) { + if (!indices.contains(index)) { indices.add(index) subset.add(values(index)) } @@ -32,6 +42,11 @@ abstract class AbstractSubset[T](iterable: Iterable[T], k: Int, replacement: Boo subset.toIndexedSeq } + + def sample(rng: Random): Iterable[T] = { + if (replacement) sampleWithReplacement(rng) + else sampleNoReplacement(rng) + } } /** diff --git a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala index 2646e31..c6ffcb4 100644 --- a/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala +++ b/examples/src/main/scala/com/eharmony/spotz/examples/vw/VwHoldout.scala @@ -41,7 +41,7 @@ trait VwHoldout extends AbstractVwHoldout { trait VwHoldoutRandomSearch extends VwHoldout { val space = Map( ("l", UniformDouble(0, 1)), - ("q", Combinations(List("a", "b", "c", "d"), k = 2, x = 2)) + ("interactions", Combinations('a' to 'z', k = 4, x = 7)) ) def main(args: Array[String]) {