Skip to content

Commit

Permalink
Merge pull request #18622 from jackdelv/xpathOnWriteParquet
Browse files Browse the repository at this point in the history
HPCC-31753 Parquet uses field names incorrectly

Reviewed-By: Dan S. Camper <[email protected]>
Reviewed-by: Gavin Halliday <[email protected]>
Merged-by: Gavin Halliday <[email protected]>
  • Loading branch information
ghalliday authored May 7, 2024
2 parents a42ebb4 + 8df89a0 commit 782f6b3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 95 deletions.
146 changes: 66 additions & 80 deletions plugins/parquet/parquetembed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ std::shared_ptr<arrow::NestedType> ParquetWriter::makeChildRecord(const RtlField

for (int i = 0; i < count; i++, fields++)
{
reportIfFailure(fieldToNode((*fields)->name, *fields, childFields));
reportIfFailure(fieldToNode(*fields, childFields));
}

return std::make_shared<arrow::StructType>(childFields);
Expand All @@ -712,54 +712,54 @@ std::shared_ptr<arrow::NestedType> ParquetWriter::makeChildRecord(const RtlField
const RtlTypeInfo *child = typeInfo->queryChildType();
const RtlFieldInfo childFieldInfo = RtlFieldInfo("", "", child);
std::vector<std::shared_ptr<arrow::Field>> childField;
reportIfFailure(fieldToNode(childFieldInfo.name, &childFieldInfo, childField));
reportIfFailure(fieldToNode(&childFieldInfo, childField));
return std::make_shared<arrow::ListType>(childField[0]);
}
}

/**
* @brief Converts an RtlFieldInfo object into an arrow field and adds it to the output vector.
*
* @param name The name of the field
* @param field The field containing metadata for the record.
* @param arrowFields Output vector for pushing new nodes to.
* @return Status of the operation
*/
arrow::Status ParquetWriter::fieldToNode(const std::string &name, const RtlFieldInfo *field, std::vector<std::shared_ptr<arrow::Field>> &arrowFields)
arrow::Status ParquetWriter::fieldToNode(const RtlFieldInfo *field, std::vector<std::shared_ptr<arrow::Field>> &arrowFields)
{
unsigned len = field->type->length;
StringBuffer name;
xpathOrName(name, field);

switch (field->type->getType())
{
case type_boolean:
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::boolean()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::boolean()));
break;
case type_int:
if (field->type->isSigned())
{
if (len > 4)
if (field->type->length > 4)
{
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::int64()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::int64()));
}
else
{
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::int32()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::int32()));
}
}
else
{
if (len > 4)
if (field->type->length > 4)
{
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::uint64()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::uint64()));
}
else
{
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::uint32()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::uint32()));
}
}
break;
case type_real:
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::float64()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::float64()));
break;
case type_char:
case type_string:
Expand All @@ -769,16 +769,16 @@ arrow::Status ParquetWriter::fieldToNode(const std::string &name, const RtlField
case type_unicode:
case type_varunicode:
case type_decimal:
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::utf8())); //TODO add decimal encoding
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::utf8())); //TODO add decimal encoding
break;
case type_data:
arrowFields.push_back(std::make_shared<arrow::Field>(name, arrow::large_binary()));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), arrow::large_binary()));
break;
case type_record:
arrowFields.push_back(std::make_shared<arrow::Field>(name, makeChildRecord(field)));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), makeChildRecord(field)));
break;
case type_set:
arrowFields.push_back(std::make_shared<arrow::Field>(name, makeChildRecord(field)));
arrowFields.push_back(std::make_shared<arrow::Field>(name.str(), makeChildRecord(field)));
break;
default:
failx("Datatype %i is not compatible with this plugin.", field->type->getType());
Expand All @@ -802,7 +802,7 @@ arrow::Status ParquetWriter::fieldsToSchema(const RtlTypeInfo *typeInfo)

for (int i = 0; i < count; i++, fields++)
{
ARROW_RETURN_NOT_OK(fieldToNode((*fields)->name, *fields, arrowFields));
ARROW_RETURN_NOT_OK(fieldToNode(*fields, arrowFields));
}

schema = std::make_shared<arrow::Schema>(arrowFields);
Expand Down Expand Up @@ -834,15 +834,15 @@ arrow::Status ParquetWriter::fieldsToSchema(const RtlTypeInfo *typeInfo)
/**
* @brief Gets the child ArrayBuilder from the recordBatchBuilder and adds it to the stack.
*/
void ParquetWriter::beginSet(const char *fieldName)
void ParquetWriter::beginSet(const RtlFieldInfo *field)
{
if (!recordBatchBuilder)
{
PARQUET_ASSIGN_OR_THROW(recordBatchBuilder, arrow::RecordBatchBuilder::Make(schema, pool, maxRowCountInBatch));
}
arrow::ArrayBuilder *childBuilder;
arrow::FieldPath match = getNestedFieldBuilder(fieldName, childBuilder);
fieldBuilderStack.push_back(std::make_shared<ArrayBuilderTracker>(fieldName, childBuilder, CPNTSet, std::move(match)));
arrow::FieldPath match = getNestedFieldBuilder(field, childBuilder);
fieldBuilderStack.push_back(std::make_shared<ArrayBuilderTracker>(field, childBuilder, CPNTSet, std::move(match)));

arrow::ListBuilder *listBuilder = static_cast<arrow::ListBuilder *>(childBuilder);
reportIfFailure(listBuilder->Append());
Expand All @@ -851,17 +851,17 @@ void ParquetWriter::beginSet(const char *fieldName)
/**
* @brief Gets the child ArrayBuilder from the recordBatchBuilder and adds it to the stack.
*/
void ParquetWriter::beginRow(const char *fieldName)
void ParquetWriter::beginRow(const RtlFieldInfo *field)
{
if (!recordBatchBuilder)
{
PARQUET_ASSIGN_OR_THROW(recordBatchBuilder, arrow::RecordBatchBuilder::Make(schema, pool, maxRowCountInBatch));
}
else if (!strieq(fieldName, "<row>"))
else if (!strieq(field->name, "<row>"))
{
arrow::ArrayBuilder *childBuilder;
arrow::FieldPath match = getNestedFieldBuilder(fieldName, childBuilder);
fieldBuilderStack.push_back(std::make_shared<ArrayBuilderTracker>(fieldName, childBuilder, CPNTDataset, std::move(match)));
arrow::FieldPath match = getNestedFieldBuilder(field, childBuilder);
fieldBuilderStack.push_back(std::make_shared<ArrayBuilderTracker>(field, childBuilder, CPNTDataset, std::move(match)));

arrow::StructBuilder *structBuilder = static_cast<arrow::StructBuilder *>(childBuilder);
reportIfFailure(structBuilder->Append());
Expand Down Expand Up @@ -936,10 +936,14 @@ arrow::Status ParquetWriter::checkDirContents()
/**
* @brief Finds the correct field builder from the stack of nested field builders or from the RecordBatchBuilder.
*/
arrow::ArrayBuilder *ParquetWriter::getFieldBuilder(const char *fieldName)
arrow::ArrayBuilder *ParquetWriter::getFieldBuilder(const RtlFieldInfo *field)
{
if (fieldBuilderStack.empty())
return recordBatchBuilder->GetField(schema->GetFieldIndex(fieldName));
{
StringBuffer fieldName;
xpathOrName(fieldName, field);
return recordBatchBuilder->GetField(schema->GetFieldIndex(fieldName.str()));
}
else if (fieldBuilderStack.back()->nodeType == CPNTSet)
return static_cast<arrow::ListBuilder *>(fieldBuilderStack.back()->structPtr)->value_builder();
else
Expand All @@ -953,9 +957,12 @@ arrow::ArrayBuilder *ParquetWriter::getFieldBuilder(const char *fieldName)
* @param childBuilder Child builder for the nested field
* @return arrow::FieldPath A vector of indices to the nested field.
*/
arrow::FieldPath ParquetWriter::getNestedFieldBuilder(const char *fieldName, arrow::ArrayBuilder *&childBuilder)
arrow::FieldPath ParquetWriter::getNestedFieldBuilder(const RtlFieldInfo *field, arrow::ArrayBuilder *&childBuilder)
{
StringBuffer fieldName;
xpathOrName(fieldName, field);
arrow::FieldPath match;

if (fieldBuilderStack.empty())
{
PARQUET_ASSIGN_OR_THROW(match, arrow::FieldRef(fieldName).FindOne(*schema.get()));
Expand All @@ -979,9 +986,9 @@ arrow::FieldPath ParquetWriter::getNestedFieldBuilder(const char *fieldName, arr
/**
* @brief Helper method for adding string type fields to the ArrayBuilder
*/
void ParquetWriter::addFieldToBuilder(const char *fieldName, unsigned len, const char *data)
void ParquetWriter::addFieldToBuilder(const RtlFieldInfo *field, unsigned len, const char *data)
{
arrow::ArrayBuilder *fieldBuilder = getFieldBuilder(fieldName);
arrow::ArrayBuilder *fieldBuilder = getFieldBuilder(field);
switch(fieldBuilder->type()->id())
{
case arrow::Type::type::STRING:
Expand Down Expand Up @@ -1451,7 +1458,7 @@ void ParquetRowBuilder::processBeginSet(const RtlFieldInfo *field, bool &isAll)

if (arrayVisitor->type == ListType)
{
ParquetColumnTracker newPathNode(field->name, arrayVisitor->listArr, CPNTSet);
ParquetColumnTracker newPathNode(field, arrayVisitor->listArr, CPNTSet);
newPathNode.childCount = arrayVisitor->listArr->value_slice(currentRow)->length();
pathStack.push_back(newPathNode);
}
Expand Down Expand Up @@ -1490,27 +1497,17 @@ void ParquetRowBuilder::processBeginDataset(const RtlFieldInfo *field)
*/
void ParquetRowBuilder::processBeginRow(const RtlFieldInfo *field)
{
StringBuffer xpath;
xpathOrName(xpath, field);

if (!xpath.isEmpty())
if (strncmp(field->name, "<row>", 5) != 0)
{
if (strncmp(xpath, "<row>", 5) != 0)
nextField(field);
if (arrayVisitor->type == StructType)
{
nextField(field);
if (arrayVisitor->type == StructType)
{
pathStack.push_back(ParquetColumnTracker(field->name, arrayVisitor->structArr, CPNTScalar));
}
else
{
failx("proccessBeginRow: Incorrect type for row.");
}
pathStack.push_back(ParquetColumnTracker(field, arrayVisitor->structArr, CPNTScalar));
}
else
{
failx("proccessBeginRow: Incorrect type for row.");
}
}
else
{
failx("processBeginRow: Field name or xpath missing");
}
}

Expand All @@ -1533,10 +1530,7 @@ bool ParquetRowBuilder::processNextRow(const RtlFieldInfo *field)
*/
void ParquetRowBuilder::processEndSet(const RtlFieldInfo *field)
{
StringBuffer xpath;
xpathOrName(xpath, field);

if (!xpath.isEmpty() && !pathStack.empty() && strcmp(xpath.str(), pathStack.back().nodeName) == 0)
if (!pathStack.empty() && field->equivalent(pathStack.back().field))
{
pathStack.pop_back();
}
Expand All @@ -1559,26 +1553,16 @@ void ParquetRowBuilder::processEndDataset(const RtlFieldInfo *field)
*/
void ParquetRowBuilder::processEndRow(const RtlFieldInfo *field)
{
StringBuffer xpath;
xpathOrName(xpath, field);

if (!xpath.isEmpty())
if (!pathStack.empty())
{
if (!pathStack.empty())
if (pathStack.back().nodeType == CPNTDataset)
{
if (pathStack.back().nodeType == CPNTDataset)
{
pathStack.back().childrenProcessed++;
}
else if (strcmp(xpath.str(), pathStack.back().nodeName) == 0)
{
pathStack.pop_back();
}
pathStack.back().childrenProcessed++;
}
else if (field->equivalent(pathStack.back().field))
{
pathStack.pop_back();
}
}
else
{
failx("processEndRow: Field name or xpath missing");
}
}

Expand All @@ -1593,7 +1577,9 @@ void ParquetRowBuilder::nextFromStruct(const RtlFieldInfo *field)
reportIfFailure(structPtr->Accept(arrayVisitor.get()));
if (pathStack.back().nodeType == CPNTScalar)
{
auto child = arrayVisitor->structArr->GetFieldByName(field->name);
StringBuffer fieldName;
xpathOrName(fieldName, field);
auto child = arrayVisitor->structArr->GetFieldByName(fieldName.str());
reportIfFailure(child->Accept(arrayVisitor.get()));
}
else if (pathStack.back().nodeType == CPNTSet)
Expand Down Expand Up @@ -1675,7 +1661,7 @@ void bindStringParam(unsigned len, const char *value, const RtlFieldInfo *field,
rtlDataAttr utf8;
rtlStrToUtf8X(utf8chars, utf8.refstr(), len, value);

parquetWriter->addFieldToBuilder(field->name, rtlUtf8Size(utf8chars, utf8.getdata()), utf8.getstr());
parquetWriter->addFieldToBuilder(field, rtlUtf8Size(utf8chars, utf8.getdata()), utf8.getstr());
}

/**
Expand Down Expand Up @@ -1710,7 +1696,7 @@ void ParquetRecordBinder::processString(unsigned len, const char *value, const R
*/
void ParquetRecordBinder::processBool(bool value, const RtlFieldInfo *field)
{
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field->name);
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field);
if (fieldBuilder->type()->id() == arrow::Type::type::BOOL)
{
arrow::BooleanBuilder *boolBuilder = static_cast<arrow::BooleanBuilder *>(fieldBuilder);
Expand All @@ -1729,7 +1715,7 @@ void ParquetRecordBinder::processBool(bool value, const RtlFieldInfo *field)
*/
void ParquetRecordBinder::processData(unsigned len, const void *value, const RtlFieldInfo *field)
{
parquetWriter->addFieldToBuilder(field->name, len, (const char *)value);
parquetWriter->addFieldToBuilder(field, len, (const char *)value);
}

/**
Expand All @@ -1740,7 +1726,7 @@ void ParquetRecordBinder::processData(unsigned len, const void *value, const Rtl
*/
void ParquetRecordBinder::processInt(__int64 value, const RtlFieldInfo *field)
{
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field->name);
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field);
switch(fieldBuilder->type()->id())
{
case arrow::Type::type::INT32:
Expand Down Expand Up @@ -1768,7 +1754,7 @@ void ParquetRecordBinder::processInt(__int64 value, const RtlFieldInfo *field)
*/
void ParquetRecordBinder::processUInt(unsigned __int64 value, const RtlFieldInfo *field)
{
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field->name);
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field);
switch(fieldBuilder->type()->id())
{
case arrow::Type::type::UINT32:
Expand Down Expand Up @@ -1796,7 +1782,7 @@ void ParquetRecordBinder::processUInt(unsigned __int64 value, const RtlFieldInfo
*/
void ParquetRecordBinder::processReal(double value, const RtlFieldInfo *field)
{
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field->name);
arrow::ArrayBuilder *fieldBuilder = parquetWriter->getFieldBuilder(field);
if (fieldBuilder->type()->id() == arrow::Type::type::DOUBLE)
{
arrow::DoubleBuilder *doubleBuilder = static_cast<arrow::DoubleBuilder *>(fieldBuilder);
Expand All @@ -1822,7 +1808,7 @@ void ParquetRecordBinder::processDecimal(const void *value, unsigned digits, uns
val.setDecimal(digits, precision, value);
val.getStringX(bytes, decText.refstr());

parquetWriter->addFieldToBuilder(field->name, bytes, decText.getstr());
parquetWriter->addFieldToBuilder(field, bytes, decText.getstr());
}

/**
Expand All @@ -1838,7 +1824,7 @@ void ParquetRecordBinder::processUnicode(unsigned chars, const UChar *value, con
char *utf8;
rtlUnicodeToUtf8X(utf8chars, utf8, chars, value);

parquetWriter->addFieldToBuilder(field->name, rtlUtf8Size(utf8chars, utf8), utf8);
parquetWriter->addFieldToBuilder(field, rtlUtf8Size(utf8chars, utf8), utf8);
}

/**
Expand Down Expand Up @@ -1866,7 +1852,7 @@ void ParquetRecordBinder::processQString(unsigned len, const char *value, const
*/
void ParquetRecordBinder::processUtf8(unsigned chars, const char *value, const RtlFieldInfo *field)
{
parquetWriter->addFieldToBuilder(field->name, rtlUtf8Size(chars, value), value);
parquetWriter->addFieldToBuilder(field, rtlUtf8Size(chars, value), value);
}

/**
Expand Down
Loading

0 comments on commit 782f6b3

Please sign in to comment.