From d44ac616e3cc0dd6e2623b37fd0883f42d47aa45 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 2 Apr 2021 15:28:42 -0400 Subject: [PATCH 1/2] Adding round robin iteration to AggregateDataSource. --- .../datasource/AggregateDataSource.java | 120 ++++++++++++++++-- .../tribuo/provenance/ModelProvenance.java | 7 +- .../datasource/AggregateDataSourceTest.java | 76 +++++++++++ 3 files changed, 193 insertions(+), 10 deletions(-) create mode 100644 Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java diff --git a/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java b/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java index 2cebddee3..d6078da89 100644 --- a/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java +++ b/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java @@ -19,6 +19,7 @@ import com.oracle.labs.mlrg.olcut.provenance.ListProvenance; import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.Provenance; +import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance; import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.DataSource; @@ -26,24 +27,60 @@ import org.tribuo.Output; import org.tribuo.OutputFactory; import org.tribuo.provenance.DataSourceProvenance; +import org.tribuo.provenance.ModelProvenance; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; +import java.util.Deque; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; +import java.util.Optional; /** - * Aggregates multiple {@link DataSource}s, and round-robins the iterators. + * Aggregates multiple {@link DataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the + * iteration order. 
*/ public class AggregateDataSource> implements DataSource { - + + /** + * Specifies the iteration order of the inner sources. + */ + public enum IterationOrder { + /** + * Round-robins the iterators (i.e., chooses one from each in turn). + */ + ROUNDROBIN, + /** + * Iterates each dataset sequentially, in the order of the sources list. + */ + SEQUENTIAL; + } + + private final IterationOrder order; + private final List> sources; + /** + * Creates an aggregate data source which will iterate the provided + * sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}. + * @param sources The sources to aggregate. + */ public AggregateDataSource(List> sources) { + this(sources,IterationOrder.SEQUENTIAL); + } + + /** + * Creates an aggregate data source using the supplied sources and iteration order. + * @param sources The sources to iterate. + * @param order The iteration order. + */ + public AggregateDataSource(List> sources, IterationOrder order) { this.sources = Collections.unmodifiableList(new ArrayList<>(sources)); + this.order = order; } @Override @@ -58,7 +95,14 @@ public OutputFactory getOutputFactory() { @Override public Iterator> iterator() { - return new ADSIterator(); + switch (order) { + case ROUNDROBIN: + return new ADSRRIterator<>(sources); + case SEQUENTIAL: + return new ADSSeqIterator<>(sources); + default: + throw new IllegalStateException("Unknown enum value " + order); + } } @Override @@ -66,9 +110,51 @@ public DataSourceProvenance getProvenance() { return new AggregateDataSourceProvenance(this); } - private class ADSIterator implements Iterator> { - Iterator> si = sources.iterator(); - Iterator> curr = null; + private static class ADSRRIterator> implements Iterator> { + private final Deque>> queue; + + ADSRRIterator(List> sources) { + this.queue = new ArrayDeque<>(sources.size()); + for (DataSource ds : sources) { + Iterator> itr = ds.iterator(); + if (itr.hasNext()) { + queue.addLast(itr); + } + } + } + + @Override + public 
boolean hasNext() { + return !queue.isEmpty(); + } + + @Override + public Example next() { + if (!hasNext()) { + throw new NoSuchElementException("Iterator exhausted"); + } + Iterator> itr = queue.pollFirst(); + if (itr.hasNext()) { + Example buff = itr.next(); + if (itr.hasNext()) { + queue.addLast(itr); + } + return buff; + } else { + throw new IllegalStateException("Invalid iterator in queue"); + } + } + } + + private static class ADSSeqIterator> implements Iterator> { + private final Iterator> si; + private Iterator> curr; + + ADSSeqIterator(List> sources) { + this.si = sources.iterator(); + this.curr = null; + } + @Override public boolean hasNext() { if (curr == null) { @@ -106,19 +192,25 @@ public static class AggregateDataSourceProvenance implements DataSourceProvenanc private static final long serialVersionUID = 1L; private static final String SOURCES = "sources"; + private static final String ORDER = "order"; private final StringProvenance className; private final ListProvenance provenances; + private EnumProvenance orderProvenance; > AggregateDataSourceProvenance(AggregateDataSource host) { this.className = new StringProvenance(CLASS_NAME,host.getClass().getName()); this.provenances = ListProvenance.createListProvenance(host.sources); + this.orderProvenance = new EnumProvenance<>(ORDER,host.order); } - @SuppressWarnings("unchecked") //ListProvenance cast + @SuppressWarnings({"unchecked","rawtypes"}) //ListProvenance cast, EnumProvenance cast public AggregateDataSourceProvenance(Map map) { this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class,AggregateDataSourceProvenance.class.getSimpleName()); this.provenances = ObjectProvenance.checkAndExtractProvenance(map,SOURCES,ListProvenance.class,AggregateDataSourceProvenance.class.getSimpleName()); + // TODO fix this when we upgrade OLCUT. 
+ Optional opt = ModelProvenance.maybeExtractProvenance(map,ORDER,EnumProvenance.class); + this.orderProvenance = opt.orElseGet(() -> new EnumProvenance<>(ORDER, IterationOrder.SEQUENTIAL)); } @Override @@ -132,22 +224,32 @@ public Iterator> iterator() { list.add(new Pair<>(CLASS_NAME,className)); list.add(new Pair<>(SOURCES,provenances)); + list.add(new Pair<>(ORDER,getOrder())); return list.iterator(); } + private EnumProvenance getOrder() { + if (orderProvenance != null) { + return orderProvenance; + } else { + return new EnumProvenance<>(ORDER,IterationOrder.SEQUENTIAL); + } + } + @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof AggregateDataSourceProvenance)) return false; AggregateDataSourceProvenance pairs = (AggregateDataSourceProvenance) o; return className.equals(pairs.className) && - provenances.equals(pairs.provenances); + provenances.equals(pairs.provenances) && + getOrder().equals(pairs.getOrder()); } @Override public int hashCode() { - return Objects.hash(className, provenances); + return Objects.hash(className, provenances, getOrder()); } @Override diff --git a/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java b/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java index af986af91..f914a4772 100644 --- a/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java +++ b/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java @@ -156,6 +156,8 @@ public ModelProvenance(Map map) { this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue(); this.instanceProvenance = (MapProvenance) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName()); this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class,ModelProvenance.class.getSimpleName()).getValue(); + + // TODO fix this when we 
upgrade OLCUT. this.javaVersionString = maybeExtractProvenance(map,JAVA_VERSION_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION); this.osString = maybeExtractProvenance(map,OS_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION); this.archString = maybeExtractProvenance(map,ARCH_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION); @@ -164,6 +166,8 @@ public ModelProvenance(Map map) { /** * Like {@link ObjectProvenance#checkAndExtractProvenance(Map, String, Class, String)} but doesn't * throw if it fails to find the key, only if the value is of the wrong type. + * + * @deprecated Deprecated as it's in OLCUT. * @param map The map to inspect. * @param key The key to find. * @param type The class of the value. @@ -172,7 +176,8 @@ public ModelProvenance(Map map) { * @throws ProvenanceException If the value is the wrong type. */ @SuppressWarnings("unchecked") // Guarded by isInstance check - private static Optional maybeExtractProvenance(Map map, String key, Class type) throws ProvenanceException { + @Deprecated + public static Optional maybeExtractProvenance(Map map, String key, Class type) throws ProvenanceException { Provenance tmp = map.remove(key); if (tmp != null) { if (type.isInstance(tmp)) { diff --git a/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java b/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java new file mode 100644 index 000000000..e6bd270c8 --- /dev/null +++ b/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.datasource; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tribuo.DataSource; +import org.tribuo.Example; +import org.tribuo.impl.ArrayExample; +import org.tribuo.provenance.SimpleDataSourceProvenance; +import org.tribuo.test.MockOutput; +import org.tribuo.test.MockOutputFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.StreamSupport; + +public class AggregateDataSourceTest { + + @Test + public void testIterationOrder() { + MockOutputFactory factory = new MockOutputFactory(); + String[] featureNames = new String[] {"X1","X2"}; + double[] featureValues = new double[] {1.0, 2.0}; + + List> first = new ArrayList<>(); + first.add(new ArrayExample<>(new MockOutput("A"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("B"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("C"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("D"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("E"),featureNames,featureValues)); + ListDataSource firstSource = new ListDataSource<>(first,factory,new SimpleDataSourceProvenance("First",factory)); + + List> second = new ArrayList<>(); + second.add(new ArrayExample<>(new MockOutput("F"),featureNames,featureValues)); + second.add(new ArrayExample<>(new MockOutput("G"),featureNames,featureValues)); + ListDataSource secondSource = new ListDataSource<>(second,factory,new SimpleDataSourceProvenance("Second",factory)); + 
+ List> third = new ArrayList<>(); + third.add(new ArrayExample<>(new MockOutput("H"),featureNames,featureValues)); + third.add(new ArrayExample<>(new MockOutput("I"),featureNames,featureValues)); + third.add(new ArrayExample<>(new MockOutput("J"),featureNames,featureValues)); + third.add(new ArrayExample<>(new MockOutput("K"),featureNames,featureValues)); + ListDataSource thirdSource = new ListDataSource<>(third,factory,new SimpleDataSourceProvenance("Third",factory)); + + List> sources = new ArrayList<>(); + sources.add(firstSource); + sources.add(secondSource); + sources.add(thirdSource); + + AggregateDataSource adsSeq = new AggregateDataSource<>(sources, AggregateDataSource.IterationOrder.SEQUENTIAL); + String[] expectedSeq = new String[] {"A","B","C","D","E","F","G","H","I","J","K"}; + String[] actualSeq = StreamSupport.stream(adsSeq.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new); + Assertions.assertArrayEquals(expectedSeq,actualSeq); + + AggregateDataSource adsRR = new AggregateDataSource<>(sources, AggregateDataSource.IterationOrder.ROUNDROBIN); + String[] expectedRR = new String[] {"A","F","H","B","G","I","C","J","D","K","E"}; + String[] actualRR = StreamSupport.stream(adsRR.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new); + Assertions.assertArrayEquals(expectedRR,actualRR); + } + +} From 33003c248fc8ab228898ff91d0fd31edb9f744d5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 2 Apr 2021 16:08:19 -0400 Subject: [PATCH 2/2] Adding AggregateConfigurableDataSource and slightly refactoring the way provenance marshalling is tested. 
--- .../AggregateConfigurableDataSource.java | 127 ++++++++++++++++++ .../datasource/AggregateDataSource.java | 6 +- .../datasource/AggregateDataSourceTest.java | 103 +++++++++++++- .../test/java/org/tribuo/test/Helpers.java | 11 +- 4 files changed, 240 insertions(+), 7 deletions(-) create mode 100644 Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java diff --git a/Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java b/Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java new file mode 100644 index 000000000..07708e0f6 --- /dev/null +++ b/Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.datasource; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; +import com.oracle.labs.mlrg.olcut.provenance.Provenance; +import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; +import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; +import org.tribuo.ConfigurableDataSource; +import org.tribuo.Example; +import org.tribuo.Output; +import org.tribuo.OutputFactory; +import org.tribuo.datasource.AggregateDataSource.IterationOrder; +import org.tribuo.provenance.DataSourceProvenance; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Aggregates multiple {@link ConfigurableDataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the + * iteration order. + *

+ * Identical to {@link AggregateDataSource} except it can be configured. + */ +public class AggregateConfigurableDataSource> implements ConfigurableDataSource { + + @Config(mandatory = true, description = "The iteration order.") + private IterationOrder order; + + @Config(mandatory = true, description = "The sources to aggregate.") + private List> sources; + + /** + * Creates an aggregate data source which will iterate the provided + * sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}). + * @param sources The sources to aggregate. + */ + public AggregateConfigurableDataSource(List> sources) { + this(sources, IterationOrder.SEQUENTIAL); + } + + /** + * Creates an aggregate data source using the supplied sources and iteration order. + * @param sources The sources to iterate. + * @param order The iteration order. + */ + public AggregateConfigurableDataSource(List> sources, IterationOrder order) { + this.sources = Collections.unmodifiableList(new ArrayList<>(sources)); + this.order = order; + } + + @Override + public String toString() { + return "AggregateConfigurableDataSource(sources="+sources.toString()+",order="+order+")"; + } + + @Override + public OutputFactory getOutputFactory() { + return sources.get(0).getOutputFactory(); + } + + @Override + public Iterator> iterator() { + switch (order) { + case ROUNDROBIN: + return new AggregateDataSource.ADSRRIterator<>(sources); + case SEQUENTIAL: + return new AggregateDataSource.ADSSeqIterator<>(sources); + default: + throw new IllegalStateException("Unknown enum value " + order); + } + } + + @Override + public DataSourceProvenance getProvenance() { + return new AggregateConfigurableDataSourceProvenance(this); + } + + /** + * Provenance for the {@link AggregateConfigurableDataSource}. 
+ */ + public static class AggregateConfigurableDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance { + private static final long serialVersionUID = 1L; + + > AggregateConfigurableDataSourceProvenance(AggregateConfigurableDataSource host) { + super(host, "DataSource"); + } + + /** + * Deserialization constructor. + * @param map The provenance to deserialize. + */ + public AggregateConfigurableDataSourceProvenance(Map map) { + this(extractProvenanceInfo(map)); + } + + private AggregateConfigurableDataSourceProvenance(ExtractedInfo info) { + super(info); + } + + protected static ExtractedInfo extractProvenanceInfo(Map map) { + Map configuredParameters = new HashMap<>(map); + String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, AggregateConfigurableDataSourceProvenance.class.getSimpleName()).getValue(); + String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, AggregateConfigurableDataSourceProvenance.class.getSimpleName()).getValue(); + return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap()); + } + } +} diff --git a/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java b/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java index d6078da89..de1fac1c6 100644 --- a/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java +++ b/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java @@ -85,7 +85,7 @@ public AggregateDataSource(List> sources, IterationOrder order) { @Override public String toString() { - return "AggregateDataSource(sources="+sources.toString()+")"; + return "AggregateDataSource(sources="+sources.toString()+",order="+order+")"; } @Override @@ -110,7 +110,7 @@ public DataSourceProvenance getProvenance() { return new AggregateDataSourceProvenance(this); } - private static class 
ADSRRIterator> implements Iterator> { + static class ADSRRIterator> implements Iterator> { private final Deque>> queue; ADSRRIterator(List> sources) { @@ -146,7 +146,7 @@ public Example next() { } } - private static class ADSSeqIterator> implements Iterator> { + static class ADSSeqIterator> implements Iterator> { private final Iterator> si; private Iterator> curr; diff --git a/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java b/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java index e6bd270c8..f94ec17df 100644 --- a/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java +++ b/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java @@ -18,21 +18,28 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.tribuo.ConfigurableDataSource; import org.tribuo.DataSource; import org.tribuo.Example; +import org.tribuo.Output; +import org.tribuo.OutputFactory; import org.tribuo.impl.ArrayExample; +import org.tribuo.provenance.DataSourceProvenance; import org.tribuo.provenance.SimpleDataSourceProvenance; +import org.tribuo.test.Helpers; import org.tribuo.test.MockOutput; import org.tribuo.test.MockOutputFactory; import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.stream.StreamSupport; public class AggregateDataSourceTest { @Test - public void testIterationOrder() { + public void testADSIterationOrder() { MockOutputFactory factory = new MockOutputFactory(); String[] featureNames = new String[] {"X1","X2"}; double[] featureValues = new double[] {1.0, 2.0}; @@ -66,11 +73,105 @@ public void testIterationOrder() { String[] expectedSeq = new String[] {"A","B","C","D","E","F","G","H","I","J","K"}; String[] actualSeq = StreamSupport.stream(adsSeq.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new); Assertions.assertArrayEquals(expectedSeq,actualSeq); + 
Helpers.testProvenanceMarshalling(adsSeq.getProvenance()); AggregateDataSource adsRR = new AggregateDataSource<>(sources, AggregateDataSource.IterationOrder.ROUNDROBIN); String[] expectedRR = new String[] {"A","F","H","B","G","I","C","J","D","K","E"}; String[] actualRR = StreamSupport.stream(adsRR.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new); Assertions.assertArrayEquals(expectedRR,actualRR); + Helpers.testProvenanceMarshalling(adsRR.getProvenance()); + } + + @Test + public void testACDSIterationOrder() { + MockOutputFactory factory = new MockOutputFactory(); + String[] featureNames = new String[] {"X1","X2"}; + double[] featureValues = new double[] {1.0, 2.0}; + + List> first = new ArrayList<>(); + first.add(new ArrayExample<>(new MockOutput("A"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("B"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("C"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("D"),featureNames,featureValues)); + first.add(new ArrayExample<>(new MockOutput("E"),featureNames,featureValues)); + MockListConfigurableDataSource firstSource = new MockListConfigurableDataSource<>(first,factory,new SimpleDataSourceProvenance("First",factory)); + + List> second = new ArrayList<>(); + second.add(new ArrayExample<>(new MockOutput("F"),featureNames,featureValues)); + second.add(new ArrayExample<>(new MockOutput("G"),featureNames,featureValues)); + MockListConfigurableDataSource secondSource = new MockListConfigurableDataSource<>(second,factory,new SimpleDataSourceProvenance("Second",factory)); + + List> third = new ArrayList<>(); + third.add(new ArrayExample<>(new MockOutput("H"),featureNames,featureValues)); + third.add(new ArrayExample<>(new MockOutput("I"),featureNames,featureValues)); + third.add(new ArrayExample<>(new MockOutput("J"),featureNames,featureValues)); + third.add(new ArrayExample<>(new 
MockOutput("K"),featureNames,featureValues)); + MockListConfigurableDataSource thirdSource = new MockListConfigurableDataSource<>(third,factory,new SimpleDataSourceProvenance("Third",factory)); + + List> sources = new ArrayList<>(); + sources.add(firstSource); + sources.add(secondSource); + sources.add(thirdSource); + + AggregateConfigurableDataSource acdsSeq = new AggregateConfigurableDataSource<>(sources, AggregateDataSource.IterationOrder.SEQUENTIAL); + String[] expectedSeq = new String[] {"A","B","C","D","E","F","G","H","I","J","K"}; + String[] actualSeq = StreamSupport.stream(acdsSeq.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new); + Assertions.assertArrayEquals(expectedSeq,actualSeq); + Helpers.testProvenanceMarshalling(acdsSeq.getProvenance()); + + AggregateConfigurableDataSource acdsRR = new AggregateConfigurableDataSource<>(sources, AggregateDataSource.IterationOrder.ROUNDROBIN); + String[] expectedRR = new String[] {"A","F","H","B","G","I","C","J","D","K","E"}; + String[] actualRR = StreamSupport.stream(acdsRR.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new); + Assertions.assertArrayEquals(expectedRR,actualRR); + Helpers.testProvenanceMarshalling(acdsRR.getProvenance()); + + } + + /** + * This isn't actually configurable, it's used to test {@link AggregateConfigurableDataSource}. + * @param The output type. + */ + private static class MockListConfigurableDataSource> implements ConfigurableDataSource { + + private final List> data; + + private final OutputFactory factory; + + private final DataSourceProvenance provenance; + + public MockListConfigurableDataSource(List> list, OutputFactory factory, DataSourceProvenance provenance) { + this.data = Collections.unmodifiableList(new ArrayList<>(list)); + this.factory = factory; + this.provenance = provenance; + } + + /** + * Number of examples. + * @return The number of examples. 
+ */ + public int size() { + return data.size(); + } + + @Override + public OutputFactory getOutputFactory() { + return factory; + } + + @Override + public DataSourceProvenance getProvenance() { + return provenance; + } + + @Override + public Iterator> iterator() { + return data.iterator(); + } + + @Override + public String toString() { + return provenance.toString(); + } } } diff --git a/Core/src/test/java/org/tribuo/test/Helpers.java b/Core/src/test/java/org/tribuo/test/Helpers.java index 60610fe25..bbabfe743 100644 --- a/Core/src/test/java/org/tribuo/test/Helpers.java +++ b/Core/src/test/java/org/tribuo/test/Helpers.java @@ -16,6 +16,7 @@ package org.tribuo.test; +import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance; import org.junit.jupiter.api.Assertions; @@ -75,11 +76,15 @@ public static Example mkExample(MockOutput label, String... features return ex; } + public static void testProvenanceMarshalling(ObjectProvenance inputProvenance) { + List provenanceList = ProvenanceUtil.marshalProvenance(inputProvenance); + ObjectProvenance unmarshalledProvenance = ProvenanceUtil.unmarshalProvenance(provenanceList); + Assertions.assertEquals(unmarshalledProvenance,inputProvenance); + } + public static > void testModelSerialization(Model model, Class outputClazz) { // test provenance marshalling - List provenanceList = ProvenanceUtil.marshalProvenance(model.getProvenance()); - ModelProvenance provenance = (ModelProvenance) ProvenanceUtil.unmarshalProvenance(provenanceList); - Assertions.assertEquals(provenance,model.getProvenance()); + testProvenanceMarshalling(model.getProvenance()); // write to byte array ByteArrayOutputStream baos = new ByteArrayOutputStream();