Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows configurable iteration order in AggregateDataSource, and adds a configurable version #125

Merged
merged 2 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.datasource.AggregateDataSource.IterationOrder;
import org.tribuo.provenance.DataSourceProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
* Aggregates multiple {@link ConfigurableDataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the
* iteration order.
* <p>
* Identical to {@link AggregateDataSource} except it can be configured.
*/
public class AggregateConfigurableDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {

@Config(mandatory = true, description = "The iteration order.")
private IterationOrder order;

@Config(mandatory = true, description = "The sources to aggregate.")
private List<ConfigurableDataSource<T>> sources;

/**
* Creates an aggregate data source which will iterate the provided
* sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}.
* @param sources The sources to aggregate.
*/
public AggregateConfigurableDataSource(List<ConfigurableDataSource<T>> sources) {
this(sources, IterationOrder.SEQUENTIAL);
}

/**
* Creates an aggregate data source using the supplied sources and iteration order.
* @param sources The sources to iterate.
* @param order The iteration order.
*/
public AggregateConfigurableDataSource(List<ConfigurableDataSource<T>> sources, IterationOrder order) {
this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
this.order = order;
}

@Override
public String toString() {
return "AggregateConfigurableDataSource(sources="+sources.toString()+",order="+order+")";
}

@Override
public OutputFactory<T> getOutputFactory() {
return sources.get(0).getOutputFactory();
}

@Override
public Iterator<Example<T>> iterator() {
switch (order) {
case ROUNDROBIN:
return new AggregateDataSource.ADSRRIterator<>(sources);
case SEQUENTIAL:
return new AggregateDataSource.ADSSeqIterator<>(sources);
default:
throw new IllegalStateException("Unknown enum value " + order);
}
}

@Override
public DataSourceProvenance getProvenance() {
return new AggregateConfigurableDataSourceProvenance(this);
}

/**
* Provenance for the {@link AggregateConfigurableDataSource}.
*/
public static class AggregateConfigurableDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
private static final long serialVersionUID = 1L;

<T extends Output<T>> AggregateConfigurableDataSourceProvenance(AggregateConfigurableDataSource<T> host) {
super(host, "DataSource");
}

/**
* Deserialization constructor.
* @param map The provenance to deserialize.
*/
public AggregateConfigurableDataSourceProvenance(Map<String, Provenance> map) {
this(extractProvenanceInfo(map));
}

private AggregateConfigurableDataSourceProvenance(ExtractedInfo info) {
super(info);
}

protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
Map<String, Provenance> configuredParameters = new HashMap<>(map);
String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, AggregateConfigurableDataSourceProvenance.class.getSimpleName()).getValue();
String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, AggregateConfigurableDataSourceProvenance.class.getSimpleName()).getValue();
return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
}
}
}
122 changes: 112 additions & 10 deletions Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,73 @@
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;

/**
* Aggregates multiple {@link DataSource}s, and round-robins the iterators.
* Aggregates multiple {@link DataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the
* iteration order.
*/
public class AggregateDataSource<T extends Output<T>> implements DataSource<T> {


/**
* Specifies the iteration order of the inner sources.
*/
public enum IterationOrder {
/**
* Round-robins the iterators (i.e., chooses one from each in turn).
*/
ROUNDROBIN,
/**
* Iterates each dataset sequentially, in the order of the sources list.
*/
SEQUENTIAL;
}

private final IterationOrder order;

private final List<DataSource<T>> sources;

/**
* Creates an aggregate data source which will iterate the provided
* sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}.
* @param sources The sources to aggregate.
*/
public AggregateDataSource(List<DataSource<T>> sources) {
this(sources,IterationOrder.SEQUENTIAL);
}

/**
* Creates an aggregate data source using the supplied sources and iteration order.
* @param sources The sources to iterate.
* @param order The iteration order.
*/
public AggregateDataSource(List<DataSource<T>> sources, IterationOrder order) {
this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
this.order = order;
}

@Override
public String toString() {
return "AggregateDataSource(sources="+sources.toString()+")";
return "AggregateDataSource(sources="+sources.toString()+",order="+order+")";
}

@Override
Expand All @@ -58,17 +95,66 @@ public OutputFactory<T> getOutputFactory() {

@Override
public Iterator<Example<T>> iterator() {
return new ADSIterator();
switch (order) {
case ROUNDROBIN:
return new ADSRRIterator<>(sources);
case SEQUENTIAL:
return new ADSSeqIterator<>(sources);
default:
throw new IllegalStateException("Unknown enum value " + order);
}
}

@Override
public DataSourceProvenance getProvenance() {
return new AggregateDataSourceProvenance(this);
}

private class ADSIterator implements Iterator<Example<T>> {
Iterator<DataSource<T>> si = sources.iterator();
Iterator<Example<T>> curr = null;
static class ADSRRIterator<T extends Output<T>> implements Iterator<Example<T>> {
private final Deque<Iterator<Example<T>>> queue;

ADSRRIterator(List<? extends DataSource<T>> sources) {
this.queue = new ArrayDeque<>(sources.size());
for (DataSource<T> ds : sources) {
Iterator<Example<T>> itr = ds.iterator();
if (itr.hasNext()) {
queue.addLast(itr);
}
}
}

@Override
public boolean hasNext() {
return !queue.isEmpty();
}

@Override
public Example<T> next() {
if (!hasNext()) {
throw new NoSuchElementException("Iterator exhausted");
}
Iterator<Example<T>> itr = queue.pollFirst();
if (itr.hasNext()) {
Example<T> buff = itr.next();
if (itr.hasNext()) {
queue.addLast(itr);
}
return buff;
} else {
throw new IllegalStateException("Invalid iterator in queue");
}
}
}

static class ADSSeqIterator<T extends Output<T>> implements Iterator<Example<T>> {
private final Iterator<? extends DataSource<T>> si;
private Iterator<Example<T>> curr;

ADSSeqIterator(List<? extends DataSource<T>> sources) {
this.si = sources.iterator();
this.curr = null;
}

@Override
public boolean hasNext() {
if (curr == null) {
Expand Down Expand Up @@ -106,19 +192,25 @@ public static class AggregateDataSourceProvenance implements DataSourceProvenanc
private static final long serialVersionUID = 1L;

private static final String SOURCES = "sources";
private static final String ORDER = "order";

private final StringProvenance className;
private final ListProvenance<DataSourceProvenance> provenances;
private EnumProvenance<IterationOrder> orderProvenance;

<T extends Output<T>> AggregateDataSourceProvenance(AggregateDataSource<T> host) {
this.className = new StringProvenance(CLASS_NAME,host.getClass().getName());
this.provenances = ListProvenance.createListProvenance(host.sources);
this.orderProvenance = new EnumProvenance<>(ORDER,host.order);
}

@SuppressWarnings("unchecked") //ListProvenance cast
@SuppressWarnings({"unchecked","rawtypes"}) //ListProvenance cast, EnumProvenance cast
public AggregateDataSourceProvenance(Map<String,Provenance> map) {
this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
this.provenances = ObjectProvenance.checkAndExtractProvenance(map,SOURCES,ListProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
// TODO fix this when we upgrade OLCUT.
Optional<EnumProvenance> opt = ModelProvenance.maybeExtractProvenance(map,ORDER,EnumProvenance.class);
this.orderProvenance = opt.orElseGet(() -> new EnumProvenance<>(ORDER, IterationOrder.SEQUENTIAL));
}

@Override
Expand All @@ -132,22 +224,32 @@ public Iterator<Pair<String, Provenance>> iterator() {

list.add(new Pair<>(CLASS_NAME,className));
list.add(new Pair<>(SOURCES,provenances));
list.add(new Pair<>(ORDER,getOrder()));

return list.iterator();
}

private EnumProvenance<IterationOrder> getOrder() {
if (orderProvenance != null) {
return orderProvenance;
} else {
return new EnumProvenance<>(ORDER,IterationOrder.SEQUENTIAL);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof AggregateDataSourceProvenance)) return false;
AggregateDataSourceProvenance pairs = (AggregateDataSourceProvenance) o;
return className.equals(pairs.className) &&
provenances.equals(pairs.provenances);
provenances.equals(pairs.provenances) &&
getOrder().equals(pairs.getOrder());
}

@Override
public int hashCode() {
return Objects.hash(className, provenances);
return Objects.hash(className, provenances, getOrder());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ public ModelProvenance(Map<String,Provenance> map) {
this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
this.instanceProvenance = (MapProvenance<?>) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName());
this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class,ModelProvenance.class.getSimpleName()).getValue();

// TODO fix this when we upgrade OLCUT.
this.javaVersionString = maybeExtractProvenance(map,JAVA_VERSION_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.osString = maybeExtractProvenance(map,OS_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.archString = maybeExtractProvenance(map,ARCH_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
Expand All @@ -164,6 +166,8 @@ public ModelProvenance(Map<String,Provenance> map) {
/**
* Like {@link ObjectProvenance#checkAndExtractProvenance(Map, String, Class, String)} but doesn't
* throw if it fails to find the key, only if the value is of the wrong type.
*
* @deprecated Deprecated as it's in OLCUT.
* @param map The map to inspect.
* @param key The key to find.
* @param type The class of the value.
Expand All @@ -172,7 +176,8 @@ public ModelProvenance(Map<String,Provenance> map) {
* @throws ProvenanceException If the value is the wrong type.
*/
@SuppressWarnings("unchecked") // Guarded by isInstance check
private static <T extends Provenance> Optional<T> maybeExtractProvenance(Map<String,Provenance> map, String key, Class<T> type) throws ProvenanceException {
@Deprecated
public static <T extends Provenance> Optional<T> maybeExtractProvenance(Map<String,Provenance> map, String key, Class<T> type) throws ProvenanceException {
Provenance tmp = map.remove(key);
if (tmp != null) {
if (type.isInstance(tmp)) {
Expand Down
Loading