Skip to content

Commit

Permalink
Dataset.createTransformers fix for DatasetView/TransformTrainer (#364)
Browse files Browse the repository at this point in the history
* Fix a bug where transformations would be incorrectly computed if called on a DatasetView.

* Fix copyright year.

* More fixes for things which access Dataset.data when they shouldn't. Also a fix for DatasetView's provenance recording which was picking up the wrong indices field.

* Fixing copyright.

* More fixing copyright.
  • Loading branch information
Craigacp authored Apr 30, 2024
1 parent 9ad419c commit 83e197f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
10 changes: 5 additions & 5 deletions Core/src/main/java/org/tribuo/Dataset.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,7 +72,7 @@ public abstract class Dataset<T extends Output<T>> implements Iterable<Example<T
/**
* Users of this RNG should synchronize on the Dataset to prevent replicability issues.
*/
private static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);
protected static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);

/**
* The data in this data set.
Expand Down Expand Up @@ -395,7 +395,7 @@ public TransformerMap createTransformers(TransformationMap transformations, bool
}
// Add the queue to the map for that feature
featureStats.put(entry.getKey(),l);
sparseCount.put(entry.getKey(), new MutableLong(data.size()));
sparseCount.put(entry.getKey(), new MutableLong(size()));
}
if (!transformations.getGlobalTransformations().isEmpty()) {
// Append all the global transformations
Expand All @@ -411,7 +411,7 @@ public TransformerMap createTransformers(TransformationMap transformations, bool
// Add the queue to the map for that feature
featureStats.put(v, l);
// Generate the sparse count initialised to the number of examples in the dataset.
sparseCount.putIfAbsent(v, new MutableLong(data.size()));
sparseCount.putIfAbsent(v, new MutableLong(size()));
ndone++;
if(logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
Expand All @@ -424,7 +424,7 @@ public TransformerMap createTransformers(TransformationMap transformations, bool
boolean initialisedSparseCounts = false;
// Iterate through the dataset max(transformations.length) times.
while (!featureStats.isEmpty()) {
for (Example<T> example : data) {
for (Example<T> example : this) {
for (Feature f : example) {
if (featureStats.containsKey(f.getName())) {
if (!initialisedSparseCounts) {
Expand Down
44 changes: 40 additions & 4 deletions Core/src/main/java/org/tribuo/dataset/DatasetView.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -373,9 +373,22 @@ public ImmutableOutputInfo<T> getOutputInfo() {
return outputIDInfo;
}

/**
 * Switches shuffled iteration on or off for this view.
 * <p>
 * When {@code shuffle} is true the {@code indices} field is replaced with the
 * result of {@code Util.randperm(size(), rng)} (presumably a random permutation
 * of the positions in this view, drawn from the shared RNG). When it is false
 * {@code indices} is reset to {@code null}.
 *
 * @param shuffle If true install a new iteration order, otherwise clear it.
 */
@Override
public synchronized void shuffle(boolean shuffle) {
    // A null index array means no shuffled order is installed.
    indices = shuffle ? Util.randperm(size(), rng) : null;
}

/**
 * Creates an iterator over the examples in this view.
 * <p>
 * The conditional below selects a {@code ShuffledViewIterator} when the
 * {@code indices} field is non-null and a {@code ViewIterator} otherwise.
 *
 * @return An iterator over the examples in this view.
 */
@Override
public Iterator<Example<T>> iterator() {
// NOTE(review): this unconditional return looks like the pre-change line retained by the
// diff rendering; the conditional that follows is presumably its replacement. Confirm
// against the committed file, as code after an unconditional return is unreachable.
return new ViewIterator<>(this);
if (indices != null) {
return new ShuffledViewIterator<>(this);
} else{
return new ViewIterator<>(this);
}
}

@Override
Expand Down Expand Up @@ -459,6 +472,29 @@ private static boolean validateIndices(int size, int[] indices) {
return valid;
}

/**
 * Iterator over a {@link DatasetView} which visits the examples in the order
 * given by the view's {@code indices} array.
 * <p>
 * Each call to {@link #next()} reads {@code dataset.indices[counter]} and fetches
 * that position from the view via {@code getExample}.
 *
 * @param <T> The output type of the examples.
 */
private static final class ShuffledViewIterator<T extends Output<T>> implements Iterator<Example<T>> {

    /** Position of the next entry to read from the view's {@code indices} array. */
    private int counter = 0;

    /** The view being iterated. */
    private final DatasetView<T> dataset;

    /**
     * Constructs an iterator over the supplied view.
     *
     * @param dataset The view to iterate.
     */
    ShuffledViewIterator(DatasetView<T> dataset) {
        this.dataset = dataset;
    }

    /**
     * Checks whether any examples remain.
     *
     * @return True if fewer than {@code dataset.size()} examples have been returned.
     */
    @Override
    public boolean hasNext() {
        return counter < dataset.size();
    }

    /**
     * Returns the next example in the order given by {@code indices}.
     *
     * @return The next example.
     * @throws java.util.NoSuchElementException If the iteration has no more examples.
     */
    @Override
    public Example<T> next() {
        // Honour the Iterator contract rather than failing with an array index error.
        if (!hasNext()) {
            throw new java.util.NoSuchElementException("Iterator exhausted, the view contains " + dataset.size() + " examples");
        }
        Example<T> example = dataset.getExample(dataset.indices[counter]);
        counter++;
        return example;
    }

}

private static final class ViewIterator<T extends Output<T>> implements Iterator<Example<T>> {

private int counter = 0;
Expand Down Expand Up @@ -509,7 +545,7 @@ <T extends Output<T>> DatasetViewProvenance(DatasetView<T> dataset, boolean stor
this.weighted = new BooleanProvenance(WEIGHTED,dataset.weighted);
this.sampled = new BooleanProvenance(SAMPLED,dataset.sampled);
this.tag = new StringProvenance(TAG,dataset.tag);
this.indices = storeIndices ? dataset.indices : new int[0];
this.indices = storeIndices ? dataset.exampleIndices : new int[0];
}

/**
Expand All @@ -525,7 +561,7 @@ public DatasetViewProvenance(Map<String,Provenance> map) {
this.sampled = ObjectProvenance.checkAndExtractProvenance(map,SAMPLED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName());
@SuppressWarnings("unchecked") // List provenance cast
ListProvenance<IntProvenance> listIndices = ObjectProvenance.checkAndExtractProvenance(map,INDICES,ListProvenance.class, DatasetViewProvenance.class.getSimpleName());
if (listIndices.getList().size() > 0) {
if (!listIndices.getList().isEmpty()) {
try {
IntProvenance i = listIndices.getList().get(0);
} catch (ClassCastException e) {
Expand Down

0 comments on commit 83e197f

Please sign in to comment.