Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset.createTransformers fix for DatasetView/TransformTrainer #364

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Core/src/main/java/org/tribuo/Dataset.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,7 +72,7 @@ public abstract class Dataset<T extends Output<T>> implements Iterable<Example<T
/**
* Users of this RNG should synchronize on the Dataset to prevent replicability issues.
*/
private static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);
protected static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);

/**
* The data in this data set.
Expand Down Expand Up @@ -395,7 +395,7 @@ public TransformerMap createTransformers(TransformationMap transformations, bool
}
// Add the queue to the map for that feature
featureStats.put(entry.getKey(),l);
sparseCount.put(entry.getKey(), new MutableLong(data.size()));
sparseCount.put(entry.getKey(), new MutableLong(size()));
}
if (!transformations.getGlobalTransformations().isEmpty()) {
// Append all the global transformations
Expand All @@ -411,7 +411,7 @@ public TransformerMap createTransformers(TransformationMap transformations, bool
// Add the queue to the map for that feature
featureStats.put(v, l);
// Generate the sparse count initialised to the number of features.
sparseCount.putIfAbsent(v, new MutableLong(data.size()));
sparseCount.putIfAbsent(v, new MutableLong(size()));
ndone++;
if(logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
Expand All @@ -424,7 +424,7 @@ public TransformerMap createTransformers(TransformationMap transformations, bool
boolean initialisedSparseCounts = false;
// Iterate through the dataset max(transformations.length) times.
while (!featureStats.isEmpty()) {
for (Example<T> example : data) {
for (Example<T> example : this) {
for (Feature f : example) {
if (featureStats.containsKey(f.getName())) {
if (!initialisedSparseCounts) {
Expand Down
44 changes: 40 additions & 4 deletions Core/src/main/java/org/tribuo/dataset/DatasetView.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -373,9 +373,22 @@ public ImmutableOutputInfo<T> getOutputInfo() {
return outputIDInfo;
}

@Override
public synchronized void shuffle(boolean shuffle) {
if (shuffle) {
indices = Util.randperm(size(), rng);
} else {
indices = null;
}
}

@Override
public Iterator<Example<T>> iterator() {
return new ViewIterator<>(this);
if (indices != null) {
return new ShuffledViewIterator<>(this);
} else{
return new ViewIterator<>(this);
}
}

@Override
Expand Down Expand Up @@ -459,6 +472,29 @@ private static boolean validateIndices(int size, int[] indices) {
return valid;
}

private static final class ShuffledViewIterator<T extends Output<T>> implements Iterator<Example<T>> {

private int counter = 0;
private final DatasetView<T> dataset;

ShuffledViewIterator(DatasetView<T> dataset) {
this.dataset = dataset;
}

@Override
public boolean hasNext() {
return counter < dataset.size();
}

@Override
public Example<T> next() {
Example<T> example = dataset.getExample(dataset.indices[counter]);
counter++;
return example;
}

}

private static final class ViewIterator<T extends Output<T>> implements Iterator<Example<T>> {

private int counter = 0;
Expand Down Expand Up @@ -509,7 +545,7 @@ <T extends Output<T>> DatasetViewProvenance(DatasetView<T> dataset, boolean stor
this.weighted = new BooleanProvenance(WEIGHTED,dataset.weighted);
this.sampled = new BooleanProvenance(SAMPLED,dataset.sampled);
this.tag = new StringProvenance(TAG,dataset.tag);
this.indices = storeIndices ? dataset.indices : new int[0];
this.indices = storeIndices ? dataset.exampleIndices : new int[0];
}

/**
Expand All @@ -525,7 +561,7 @@ public DatasetViewProvenance(Map<String,Provenance> map) {
this.sampled = ObjectProvenance.checkAndExtractProvenance(map,SAMPLED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName());
@SuppressWarnings("unchecked") // List provenance cast
ListProvenance<IntProvenance> listIndices = ObjectProvenance.checkAndExtractProvenance(map,INDICES,ListProvenance.class, DatasetViewProvenance.class.getSimpleName());
if (listIndices.getList().size() > 0) {
if (!listIndices.getList().isEmpty()) {
try {
IntProvenance i = listIndices.getList().get(0);
} catch (ClassCastException e) {
Expand Down