Skip to content

Commit

Permalink
feat(#23): Prepare
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Abramov committed Oct 20, 2023
1 parent af01c5a commit 8c969fd
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
package de.edux.data.handler;

import java.util.ArrayList;
import java.util.List;

public class DropIncompleteRecordsHandler implements IIncompleteRecordsHandler {
@Override
public List<String[]> getCleanedDataset(List<String[]> dataset) {
List<String[]> cleanedDataset =
List<String[]> filteredList =
dataset.stream().filter(this::containsOnlyCompletedFeatures).toList();

if (cleanedDataset.size() < dataset.size() * 0.5) {
if (filteredList.size() < dataset.size() * 0.5) {
throw new RuntimeException(
"More than 50% of the records will be dropped with this IncompleteRecordsHandlerStrategy. "
+ "Consider using another IncompleteRecordsHandlerStrategy or handle this exception.");
}

List<String[]> cleanedDataset = new ArrayList<>();
for (String[] item : filteredList) {
cleanedDataset.add(item);
}
return cleanedDataset;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package de.edux.data.provider;

import java.util.List;
import java.util.Optional;

public abstract class DataPostProcessor<T> {
public abstract void normalize(List<T> rowDataset);
Expand All @@ -21,4 +22,10 @@ public abstract class DataPostProcessor<T> {

public abstract double[][] getTestFeatures();

/**
 * Looks up the position of the column with the given name.
 *
 * @param columnName name of the column to look for
 * @return the index of the column, or an empty {@code Optional} when the implementation
 *     finds no column with that name
 */
public abstract Optional<Integer> getIndexOfColumn(String columnName);

/**
 * Returns the values of the column with the given name.
 *
 * @param columnName name of the column whose values are requested
 * @return the values of that column, one element per row
 */
// NOTE(review): behavior for an unknown column name is implementation-defined here — confirm.
public abstract String[] getColumnDataOf(String columnName);

/**
 * Returns the names of the columns known to this processor.
 *
 * @return the column names
 */
public abstract String[] getColumnNames();

}
53 changes: 43 additions & 10 deletions lib/src/main/java/de/edux/data/provider/DataProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

public abstract class DataProcessor<T> extends DataPostProcessor<T> implements IDataUtil<T> {
private static final Logger LOG = LoggerFactory.getLogger(DataProcessor.class);
private final IDataReader csvDataReader;
private ArrayList<T> dataset;
private Dataset<T> splitedDataset;
private String[] columnNames;
private List<String[]> rawDataset;

public DataProcessor() {
this.csvDataReader = new CSVIDataReader();
Expand All @@ -27,23 +30,24 @@ public DataProcessor(IDataReader csvDataReader) {
}

@Override
public List<T> loadDataSetFromCSV(File csvFile, char csvSeparator, boolean normalize, boolean shuffle, EIncompleteRecordsHandlerStrategy incompleteRecordHandlerStrategy) {
List<String[]> rawDataset = csvDataReader.readFile(csvFile, csvSeparator);
List<String[]> cleanedDataset = incompleteRecordHandlerStrategy.getHandler().getCleanedDataset(rawDataset);
public List<T> loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHeadline, boolean shuffle, EIncompleteRecordsHandlerStrategy incompleteRecordHandlerStrategy) {
rawDataset = csvDataReader.readFile(csvFile, csvSeparator);

List<T> unmodifiableDataset = cleanedDataset
if (skipHeadline) {
columnNames = rawDataset.remove(0);
} else {
columnNames = rawDataset.get(0);
}
List<String[]> csvDataset = incompleteRecordHandlerStrategy.getHandler().getCleanedDataset(rawDataset);

List<T> unmodifiableDataset = csvDataset
.stream()
.map(this::mapToDataRecord)
.toList();

dataset = new ArrayList<>(unmodifiableDataset);
LOG.info("Dataset loaded");

if (normalize) {
normalize(dataset);
LOG.info("Dataset normalized");
}

if (shuffle) {
Collections.shuffle(dataset);
LOG.info("Dataset shuffled");
Expand All @@ -54,7 +58,7 @@ public List<T> loadDataSetFromCSV(File csvFile, char csvSeparator, boolean norma
/**
* Split data into train and test data
*
* @param data data to split
* @param data data to split
* @param trainTestSplitRatio ratio of train data
* @return list of train and test data. First element is train data, second element is test data.
*/
Expand All @@ -72,12 +76,41 @@ public Dataset<T> split(List<T> data, double trainTestSplitRatio) {
splitedDataset = new Dataset<>(trainDataset, testDataset);
return splitedDataset;
}

/**
 * Gives access to the dataset held by this processor.
 *
 * @return the list stored in the {@code dataset} field; the list is not copied
 */
public ArrayList<T> getDataset() {
    return this.dataset;
}

/**
 * Gives access to the train/test split held by this processor.
 *
 * @return the value stored in the {@code splitedDataset} field
 */
public Dataset<T> getSplitedDataset() {
    return this.splitedDataset;
}

/**
 * Finds the zero-based position of a column in the stored header row.
 *
 * <p>The comparison is an exact, case-sensitive {@code String.equals} match. If several
 * columns share the name, the first one wins.
 *
 * @param columnName name of the column to look for
 * @return the index of the first matching column, or an empty {@code Optional} when the
 *     name is {@code null}, no header row has been stored yet, or no column matches
 */
@Override
public Optional<Integer> getIndexOfColumn(String columnName) {
    // Without a header row (nothing loaded yet) or without a name there is nothing to match.
    if (columnNames == null || columnName == null) {
        return Optional.empty();
    }
    for (int i = 0; i < columnNames.length; i++) {
        // Call equals on the non-null argument so a null header cell cannot throw.
        if (columnName.equals(columnNames[i])) {
            return Optional.of(i);
        }
    }
    return Optional.empty();
}

/**
 * Collects the values of one column from the raw rows returned by the CSV reader.
 *
 * <p>The values come from {@code rawDataset}, not from the cleaned dataset, so rows that an
 * incomplete-records handler dropped are still included.
 *
 * @param columnName name of the column whose values are requested
 * @return the cell of that column for every raw row, in row order
 * @throws IllegalStateException if no dataset has been loaded yet
 * @throws IllegalArgumentException if no column with that name exists
 */
// NOTE(review): when the dataset was loaded without skipping the headline, the header row
// is still part of rawDataset, so element 0 is the column name itself — confirm intended.
@Override
public String[] getColumnDataOf(String columnName) {
    if (rawDataset == null) {
        throw new IllegalStateException("No dataset loaded");
    }
    int columnIndex = getIndexOfColumn(columnName)
            .orElseThrow(() -> new IllegalArgumentException("Column name not found"));
    return rawDataset.stream()
            .map(row -> row[columnIndex])
            .toArray(String[]::new);
}

/**
 * Returns the column names taken from the first row of the loaded CSV file.
 *
 * <p>A copy is returned so that callers cannot modify the stored header, which may be the
 * same array as the first raw row.
 *
 * @return a copy of the column names, or an empty array when no header has been stored yet
 */
@Override
public String[] getColumnNames() {
    return columnNames == null ? new String[0] : columnNames.clone();
}
}

3 changes: 1 addition & 2 deletions lib/src/main/java/de/edux/data/provider/IDataUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import java.util.List;

/**
 * Operations for loading a CSV dataset, splitting it and extracting model inputs and targets.
 *
 * @param <T> type of a single data record
 */
public interface IDataUtil<T> {
    /**
     * Loads the records of a CSV file.
     *
     * @param csvFile the CSV file to read
     * @param csvSeparator the character that separates columns in the file
     * @param skipHeadline whether the first row is a headline that must not become a record
     * @param shuffle whether the loaded records are shuffled
     * @param incompleteRecordHandlerStrategy strategy applied to records with missing values
     * @return the loaded records
     */
    List<T> loadDataSetFromCSV(File csvFile, char csvSeparator, boolean skipHeadline, boolean shuffle, EIncompleteRecordsHandlerStrategy incompleteRecordHandlerStrategy);

    /**
     * Splits a dataset into train and test data.
     *
     * @param dataset the records to split
     * @param trainTestSplitRatio ratio of train data
     * @return the train and test data
     */
    Dataset<T> split(List<T> dataset, double trainTestSplitRatio);

    /**
     * Extracts the inputs of the given records.
     *
     * @param dataset the records to convert
     * @return a two-dimensional array of inputs (presumably one row per record — confirm)
     */
    double[][] getInputs(List<T> dataset);

    /**
     * Extracts the targets of the given records.
     *
     * @param dataset the records to convert
     * @return a two-dimensional array of targets (presumably one row per record — confirm)
     */
    double[][] getTargets(List<T> dataset);

}
3 changes: 1 addition & 2 deletions lib/src/main/java/de/edux/data/reader/CSVIDataReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@

public class CSVIDataReader implements IDataReader {

public List<String[]> readFile(File file, char separator) {
public List<String[]> readFile(File file, char separator ) {
CSVParser customCSVParser = new CSVParserBuilder().withSeparator(separator).build();
List<String[]> result;
try(CSVReader reader = new CSVReaderBuilder(
new FileReader(file))
.withCSVParser(customCSVParser)
.withSkipLines(1)
.build()){
result = reader.readAll();
} catch (CsvException | IOException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package de.edux.data.handler;

import de.edux.data.provider.SeabornDataProcessor;
import de.edux.data.provider.SeabornProvider;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Verifies that column look-ups work on the penguins CSV after it has been loaded with the
 * {@code DROP_RECORDS} strategy.
 */
class DropIncompleteRecordsHandlerTest {
    private static final boolean SHUFFLE = true;
    private static final boolean SKIP_HEADLINE = true;
    private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS;
    private static final char CSV_SEPARATOR = ',';
    private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
    private static final String SPECIES_COLUMN = "species";

    /**
     * Loads the penguins dataset and checks that the {@code species} column can be located
     * and that its values can be read.
     *
     * @throws URISyntaxException if the resource URL cannot be converted to a URI
     */
    @Test
    void shouldReturnColumnData() throws URISyntaxException {
        URL url = DropIncompleteRecordsHandlerTest.class.getClassLoader().getResource(CSV_FILE_PATH);
        if (url == null) {
            throw new IllegalStateException("Cannot find file: " + CSV_FILE_PATH);
        }
        // Go through the URI so URL-encoded characters (e.g. %20 for spaces) are decoded.
        File csvFile = new File(url.toURI());
        var seabornDataProcessor = new SeabornDataProcessor();
        var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, CSV_SEPARATOR, SKIP_HEADLINE, SHUFFLE, INCOMPLETE_RECORD_HANDLER_STRATEGY);
        seabornDataProcessor.normalize(dataset);

        Optional<Integer> indexOfSpecies = seabornDataProcessor.getIndexOfColumn(SPECIES_COLUMN);
        String[] speciesData = seabornDataProcessor.getColumnDataOf(SPECIES_COLUMN);

        // JUnit assertions always run; the assert keyword depends on the JVM -ea flag.
        assertTrue(indexOfSpecies.isPresent(), "column 'species' should be found");
        assertTrue(speciesData.length > 0, "column 'species' should contain data");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
Expand Down Expand Up @@ -146,6 +147,11 @@ public double[][] getTestLabels() {
public double[][] getTestFeatures() {
return new double[0][];
}

// Stub implementation: ignores columnName and always reports "no such column".
@Override
public Optional<Integer> getIndexOfColumn(String columnName) {
return Optional.empty();
}
};
}

Expand Down
4 changes: 4 additions & 0 deletions lib/src/test/java/de/edux/data/provider/SeabornProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ private double[][] featuresOf(List<Penguin> data) {

for (int i = 0; i < data.size(); i++) {
Penguin p = data.get(i);
if (p == null){
continue;
/* throw new IllegalArgumentException("Missed value in dataset, try to use Imputation methods");*/
}
features[i][0] = p.billLengthMm();
features[i][1] = p.billDepthMm();
features[i][2] = p.flipperLengthMm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

class MultilayerPerceptronTest {
private static final boolean SHUFFLE = true;
private static final boolean NORMALIZE = true;
private static final boolean SKIP_HEADLINE = true;
private static final EIncompleteRecordsHandlerStrategy INCOMPLETE_RECORD_HANDLER_STRATEGY = EIncompleteRecordsHandlerStrategy.DROP_RECORDS;
private static final double TRAIN_TEST_SPLIT_RATIO = 0.7;
private static final String CSV_FILE_PATH = "testdatasets/seaborn-penguins/penguins.csv";
Expand All @@ -32,10 +32,11 @@ void setUp() {
}
File csvFile = new File(url.getPath());
var seabornDataProcessor = new SeabornDataProcessor();
var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', SHUFFLE, NORMALIZE, INCOMPLETE_RECORD_HANDLER_STRATEGY);
var dataset = seabornDataProcessor.loadDataSetFromCSV(csvFile, ',', true, true, INCOMPLETE_RECORD_HANDLER_STRATEGY);
seabornDataProcessor.normalize(dataset);
var trainTestSplittedList = seabornDataProcessor.split(dataset, TRAIN_TEST_SPLIT_RATIO);
seabornProvider = new SeabornProvider(dataset, trainTestSplittedList.trainData(), trainTestSplittedList.testData());

System.out.println("SeabornProvider loaded");
}

@RepeatedTest(3)
Expand Down

0 comments on commit 8c969fd

Please sign in to comment.