diff --git a/include/internal/csv_format.hpp b/include/internal/csv_format.hpp index 71b5ec0a..431e0c6f 100644 --- a/include/internal/csv_format.hpp +++ b/include/internal/csv_format.hpp @@ -10,6 +10,12 @@ namespace csv { class CSVReader; + /** Stores the inferred format of a CSV file. */ + struct CSVGuessResult { + char delim; + int header_row; + }; + /** Stores information about how to parse a CSV file. * Can be used to construct a csv::CSVReader. */ diff --git a/include/internal/csv_reader.cpp b/include/internal/csv_reader.cpp index 4c004784..a0e7470c 100644 --- a/include/internal/csv_reader.cpp +++ b/include/internal/csv_reader.cpp @@ -40,7 +40,7 @@ namespace csv { } } - CSVFormat CSVGuesser::guess_delim() { + CSVGuessResult CSVGuesser::guess_delim() { /** Guess the delimiter of a CSV by scanning the first 100 lines by * First assuming that the header is on the first row * If the first guess returns too few rows, then we move to the second @@ -49,7 +49,7 @@ namespace csv { CSVFormat format; if (!first_guess()) second_guess(); - return format.delimiter(this->delim).header_row(this->header_row); + return { delim, header_row }; } bool CSVGuesser::first_guess() { @@ -158,7 +158,7 @@ namespace csv { } /** Guess the delimiter used by a delimiter-separated values file */ - CSVFormat guess_format(csv::string_view filename, const std::vector& delims) { + CSVGuessResult guess_format(csv::string_view filename, const std::vector& delims) { internals::CSVGuesser guesser(filename, delims); return guesser.guess_delim(); } @@ -247,16 +247,18 @@ namespace csv { * */ CSVReader::CSVReader(csv::string_view filename, CSVFormat format) { - if (format.guess_delim()) - format = guess_format(filename, format.possible_delimiters); + /** Guess delimiter and header row */ + if (format.guess_delim()) { + auto guess_result = guess_format(filename, format.possible_delimiters); + format.delimiter(guess_result.delim); + format.header = guess_result.header_row; + } if (!format.col_names.empty()) { this->set_col_names(format.col_names); } - else { - header_row = format.header; - } + header_row = format.header; delimiter = format.get_delim(); quote_char = format.quote_char; strict = format.strict; @@ -272,9 +274,13 @@ namespace csv { CSVFormat CSVReader::get_format() const { CSVFormat format; format.delimiter(this->delimiter) - .quote(this->quote_char) - .header_row(this->header_row) - .column_names(this->col_names->col_names); + .quote(this->quote_char); + + // Since users are normally not allowed to set + // column names and header row simulatenously, + // we will set the backing variables directly here + format.col_names = this->col_names->col_names; + format.header = this->header_row; return format; } diff --git a/include/internal/csv_reader.hpp b/include/internal/csv_reader.hpp index 46a00eed..864290b4 100644 --- a/include/internal/csv_reader.hpp +++ b/include/internal/csv_reader.hpp @@ -268,7 +268,7 @@ namespace csv { public: CSVGuesser(csv::string_view _filename, const std::vector& _delims) : filename(_filename), delims(_delims) {}; - CSVFormat guess_delim(); + CSVGuessResult guess_delim(); bool first_guess(); void second_guess(); diff --git a/include/internal/csv_utility.hpp b/include/internal/csv_utility.hpp index ef243473..cc9d3445 100644 --- a/include/internal/csv_utility.hpp +++ b/include/internal/csv_utility.hpp @@ -29,7 +29,7 @@ namespace csv { ///@{ std::unordered_map csv_data_types(const std::string&); CSVFileInfo get_file_info(const std::string& filename); - CSVFormat guess_format(csv::string_view filename, + CSVGuessResult guess_format(csv::string_view filename, const std::vector& delims = { ',', '|', '\t', ';', '^', '~' }); std::vector get_col_names( const std::string& filename, diff --git a/tests/test_read_csv.cpp b/tests/test_read_csv.cpp index a5b8f9b3..8c3c1481 100644 --- a/tests/test_read_csv.cpp +++ b/tests/test_read_csv.cpp @@ -22,24 +22,48 @@ TEST_CASE("col_pos() Test", "[test_col_pos]") { } TEST_CASE("guess_delim() Test - Pipe", "[test_guess_pipe]") { - CSVFormat format = guess_format( + CSVGuessResult format = guess_format( "./tests/data/real_data/2009PowerStatus.txt"); - REQUIRE(format.get_delim() == '|'); - REQUIRE(format.get_header() == 0); + REQUIRE(format.delim == '|'); + REQUIRE(format.header_row == 0); } TEST_CASE("guess_delim() Test - Semi-Colon", "[test_guess_scolon]") { - CSVFormat format = guess_format( + CSVGuessResult format = guess_format( "./tests/data/real_data/YEAR07_CBSA_NAC3.txt"); - REQUIRE(format.get_delim() == ';'); - REQUIRE(format.get_header() == 0); + REQUIRE(format.delim == ';'); + REQUIRE(format.header_row == 0); } TEST_CASE("guess_delim() Test - CSV with Comments", "[test_guess_comment]") { - CSVFormat format = guess_format( + CSVGuessResult format = guess_format( "./tests/data/fake_data/ints_comments.csv"); - REQUIRE(format.get_delim() == ','); - REQUIRE(format.get_header() == 5); + REQUIRE(format.delim == ','); + REQUIRE(format.header_row == 5); +} + +TEST_CASE("Prevent Column Names From Being Overwritten", "[csv_col_names_overwrite]") { + std::vector column_names = { "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10" }; + + // Test against a variety of different CSVFormat objects + std::vector formats = {}; + formats.push_back(CSVFormat::GUESS_CSV); + formats.push_back(CSVFormat()); + formats.back().delimiter(std::vector({ ',', '\t', '|'})); + formats.push_back(CSVFormat()); + formats.back().delimiter(std::vector({ ',', '~'})); + + for (auto& format_in : formats) { + // Set up the CSVReader + format_in.column_names(column_names); + CSVReader reader("./tests/data/fake_data/ints_comments.csv", format_in); + + // Assert that column names weren't overwritten + CSVFormat format_out = reader.get_format(); + REQUIRE(reader.get_col_names() == column_names); + REQUIRE(format_out.get_delim() == ','); + REQUIRE(format_out.get_header() == 5); + } } // get_file_info()