diff --git a/Core/src/main/java/org/tribuo/Dataset.java b/Core/src/main/java/org/tribuo/Dataset.java index 7673ea871..c25efcc5b 100644 --- a/Core/src/main/java/org/tribuo/Dataset.java +++ b/Core/src/main/java/org/tribuo/Dataset.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ public abstract class Dataset> implements Iterable example : data) { + for (Example example : this) { for (Feature f : example) { if (featureStats.containsKey(f.getName())) { if (!initialisedSparseCounts) { diff --git a/Core/src/main/java/org/tribuo/dataset/DatasetView.java b/Core/src/main/java/org/tribuo/dataset/DatasetView.java index d860a22c7..026a2a62e 100644 --- a/Core/src/main/java/org/tribuo/dataset/DatasetView.java +++ b/Core/src/main/java/org/tribuo/dataset/DatasetView.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -373,9 +373,22 @@ public ImmutableOutputInfo getOutputInfo() { return outputIDInfo; } + @Override + public synchronized void shuffle(boolean shuffle) { + if (shuffle) { + indices = Util.randperm(size(), rng); + } else { + indices = null; + } + } + @Override public Iterator> iterator() { - return new ViewIterator<>(this); + if (indices != null) { + return new ShuffledViewIterator<>(this); + } else{ + return new ViewIterator<>(this); + } } @Override @@ -459,6 +472,29 @@ private static boolean validateIndices(int size, int[] indices) { return valid; } + private static final class ShuffledViewIterator> implements Iterator> { + + private int counter = 0; + private final DatasetView dataset; + + ShuffledViewIterator(DatasetView dataset) { + this.dataset = dataset; + } + + @Override + public boolean hasNext() { + return counter < dataset.size(); + } + + @Override + public Example next() { + Example example = dataset.getExample(dataset.indices[counter]); + counter++; + return example; + } + + } + private static final class ViewIterator> implements Iterator> { private int counter = 0; @@ -509,7 +545,7 @@ > DatasetViewProvenance(DatasetView dataset, boolean stor this.weighted = new BooleanProvenance(WEIGHTED,dataset.weighted); this.sampled = new BooleanProvenance(SAMPLED,dataset.sampled); this.tag = new StringProvenance(TAG,dataset.tag); - this.indices = storeIndices ? dataset.indices : new int[0]; + this.indices = storeIndices ? dataset.exampleIndices : new int[0]; } /** @@ -525,7 +561,7 @@ public DatasetViewProvenance(Map map) { this.sampled = ObjectProvenance.checkAndExtractProvenance(map,SAMPLED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName()); @SuppressWarnings("unchecked") // List provenance cast ListProvenance listIndices = ObjectProvenance.checkAndExtractProvenance(map,INDICES,ListProvenance.class, DatasetViewProvenance.class.getSimpleName()); - if (listIndices.getList().size() > 0) { + if (!listIndices.getList().isEmpty()) { try { IntProvenance i = listIndices.getList().get(0); } catch (ClassCastException e) {