Skip to content

Commit

Permalink
[Enhancement] Enhance Paimon reader for decimal type (StarRocks#31570)
Browse files Browse the repository at this point in the history
Signed-off-by: leoyy0316 <[email protected]>
  • Loading branch information
leoyy0316 authored Oct 12, 2023
1 parent 3d7fcd7 commit 11cd428
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 48 deletions.
37 changes: 3 additions & 34 deletions be/src/exec/jni_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,37 +180,6 @@ Status JniScanner::_append_string_data(const FillColumnArgs& args) {
return Status::OK();
}

// Fills a DECIMAL32/64/128 column from JNI chunk meta that encodes each
// decimal as a plain string.
//
// NOTE(review): this commit removes this function — the dispatch below now
// routes decimal types to _append_primitive_data<...>, which receives
// pre-encoded fixed-width integers and avoids per-row string parsing.
template <LogicalType type>
Status JniScanner::_append_decimal_data(const FillColumnArgs& args) {
    // Chunk meta layout: an int offset array (num_rows + 1 entries) delimiting
    // each row's bytes, followed by the contiguous character buffer.
    int* offset_ptr = static_cast<int*>(next_chunk_meta_as_ptr());
    char* column_ptr = static_cast<char*>(next_chunk_meta_as_ptr());

    using ColumnType = typename starrocks::RunTimeColumnType<type>;
    using CppType = typename starrocks::RunTimeCppType<type>;
    auto* runtime_column = down_cast<ColumnType*>(args.column);
    // Resize without zero-filling; every non-null slot is assigned below.
    runtime_column->resize_uninitialized(args.num_rows);
    CppType* runtime_data = runtime_column->get_data().data();

    int precision = args.slot_type.precision;
    int scale = args.slot_type.scale;

    for (int i = 0; i < args.num_rows; i++) {
        if (args.nulls && args.nulls[i]) {
            // NULL — the value slot is left uninitialized; presumably masked
            // by the null column upstream. TODO confirm.
        } else {
            // Row i's string occupies [offset_ptr[i], offset_ptr[i + 1]).
            std::string decimal_str(column_ptr + offset_ptr[i], column_ptr + offset_ptr[i + 1]);
            CppType cpp_val;
            // A truthy return from from_string is treated as a parse/overflow
            // failure and aborts the whole fill with a data-quality error.
            if (DecimalV3Cast::from_string<CppType>(&cpp_val, precision, scale, decimal_str.data(),
                                                    decimal_str.size())) {
                return Status::DataQualityError(
                        fmt::format("Invalid value occurs in column[{}], value is [{}]", args.slot_name, decimal_str));
            }
            runtime_data[i] = cpp_val;
        }
    }
    return Status::OK();
}

Status JniScanner::_append_date_data(const FillColumnArgs& args) {
int* offset_ptr = static_cast<int*>(next_chunk_meta_as_ptr());
char* column_ptr = static_cast<char*>(next_chunk_meta_as_ptr());
Expand Down Expand Up @@ -407,11 +376,11 @@ Status JniScanner::_fill_column(FillColumnArgs* pargs) {
} else if (column_type == LogicalType::TYPE_DATETIME) {
RETURN_IF_ERROR((_append_datetime_data(args)));
} else if (column_type == LogicalType::TYPE_DECIMAL32) {
RETURN_IF_ERROR((_append_decimal_data<TYPE_DECIMAL32>(args)));
RETURN_IF_ERROR((_append_primitive_data<TYPE_DECIMAL32>(args)));
} else if (column_type == LogicalType::TYPE_DECIMAL64) {
RETURN_IF_ERROR((_append_decimal_data<TYPE_DECIMAL64>(args)));
RETURN_IF_ERROR((_append_primitive_data<TYPE_DECIMAL64>(args)));
} else if (column_type == LogicalType::TYPE_DECIMAL128) {
RETURN_IF_ERROR((_append_decimal_data<TYPE_DECIMAL128>(args)));
RETURN_IF_ERROR((_append_primitive_data<TYPE_DECIMAL128>(args)));
} else if (column_type == LogicalType::TYPE_ARRAY) {
RETURN_IF_ERROR((_append_array_data(args)));
} else if (column_type == LogicalType::TYPE_MAP) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.starrocks.jni.connector.ColumnType;
import com.starrocks.jni.connector.ColumnValue;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
Expand All @@ -24,6 +25,7 @@
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.io.LongWritable;

import java.math.BigDecimal;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -152,4 +154,9 @@ public void unpackStruct(List<Integer> structFieldIndex, List<ColumnValue> value
public byte getByte() {
throw new UnsupportedOperationException("Hoodie type does not support tinyint");
}

@Override
public BigDecimal getDecimal() {
    // Hive hands back a HiveDecimal; unwrap it into a java.math.BigDecimal
    // so the generic column writer can consume it.
    HiveDecimal hiveDecimal = (HiveDecimal) inspectObject();
    return hiveDecimal.bigDecimalValue();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ public class HudiSliceScanner extends ConnectorScanner {
private final int fetchSize;
private final ClassLoader classLoader;

public static final int MAX_DECIMAL32_PRECISION = 9;
public static final int MAX_DECIMAL64_PRECISION = 18;

public HudiSliceScanner(int fetchSize, Map<String, String> params) {
this.fetchSize = fetchSize;
this.hiveColumnNames = params.get("hive_column_names");
Expand Down Expand Up @@ -125,7 +128,12 @@ private void parseRequiredTypes() {
for (int i = 0; i < requiredFields.length; i++) {
requiredColumnIds[i] = hiveColumnNameToIndex.get(requiredFields[i]);
String type = hiveColumnNameToType.get(requiredFields[i]);
requiredTypes[i] = new ColumnType(requiredFields[i], type);

if (type.startsWith("decimal")) {
parseDecimal(type, i);
} else {
requiredTypes[i] = new ColumnType(requiredFields[i], type);
}
}

// prune fields
Expand All @@ -140,6 +148,28 @@ private void parseRequiredTypes() {
}
}

// Maps a Hive decimal type string ("decimal", "decimal(p)", "decimal(p,s)")
// to the width-specific physical type (decimal32/decimal64/decimal128) and
// records the scale on the resulting ColumnType.
//
// Fixes two defects of the previous version:
//  - a bare "decimal" (no parentheses) used to fall through with the literal
//    type name "decimal", which no longer exists in the type mapping, and a
//    scale of -1; Hive defines bare decimal as decimal(10,0), so default to that.
//  - "decimal(p)" with no scale used to throw ArrayIndexOutOfBoundsException;
//    Hive defines it as decimal(p,0).
private void parseDecimal(String type, int i) {
    // Hive defaults: decimal == decimal(10, 0); decimal(p) == decimal(p, 0).
    int precision = 10;
    int scale = 0;
    int open = type.indexOf('(');
    int close = type.indexOf(')');
    if (open != -1 && close != -1) {
        String[] parts = type.substring(open + 1, close).split(",");
        precision = Integer.parseInt(parts[0].trim());
        if (parts.length > 1) {
            scale = Integer.parseInt(parts[1].trim());
        }
    }
    // Pick the narrowest physical representation that can hold the precision.
    if (precision <= MAX_DECIMAL32_PRECISION) {
        type = "decimal32";
    } else if (precision <= MAX_DECIMAL64_PRECISION) {
        type = "decimal64";
    } else {
        type = "decimal128";
    }
    requiredTypes[i] = new ColumnType(requiredFields[i], type);
    requiredTypes[i].setScale(scale);
}

private Properties makeProperties() {
Properties properties = new Properties();
properties.setProperty("hive.io.file.readcolumn.ids",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ public enum TypeValue {
DATETIME_MICROS,
// INT64 timestamp type, TIMESTAMP(isAdjustedToUTC=true, unit=MILLIS)
DATETIME_MILLIS,
DECIMAL,
DECIMALV2,
DECIMAL32,
DECIMAL64,
DECIMAL128,
ARRAY,
MAP,
STRUCT,
Expand All @@ -55,7 +58,7 @@ public enum TypeValue {
List<String> childNames;
List<ColumnType> childTypes;
List<Integer> fieldIndex;

int scale = -1;
private static final Map<String, TypeValue> PRIMITIVE_TYPE_VALUE_MAPPING = new HashMap<>();
private static final Map<TypeValue, Integer> PRIMITIVE_TYPE_VALUE_SIZE = new HashMap<>();

Expand All @@ -73,7 +76,10 @@ public enum TypeValue {
PRIMITIVE_TYPE_VALUE_MAPPING.put("timestamp", TypeValue.DATETIME);
PRIMITIVE_TYPE_VALUE_MAPPING.put("timestamp-micros", TypeValue.DATETIME_MICROS);
PRIMITIVE_TYPE_VALUE_MAPPING.put("timestamp-millis", TypeValue.DATETIME_MILLIS);
PRIMITIVE_TYPE_VALUE_MAPPING.put("decimal", TypeValue.DECIMAL);
PRIMITIVE_TYPE_VALUE_MAPPING.put("decimalv2", TypeValue.DECIMALV2);
PRIMITIVE_TYPE_VALUE_MAPPING.put("decimal32", TypeValue.DECIMAL32);
PRIMITIVE_TYPE_VALUE_MAPPING.put("decimal64", TypeValue.DECIMAL64);
PRIMITIVE_TYPE_VALUE_MAPPING.put("decimal128", TypeValue.DECIMAL128);
PRIMITIVE_TYPE_VALUE_MAPPING.put("tinyint", TypeValue.TINYINT);

PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.BYTE, 1);
Expand All @@ -84,6 +90,10 @@ public enum TypeValue {
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.LONG, 8);
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.DOUBLE, 8);
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.TINYINT, 1);
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.DECIMALV2, 16);
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.DECIMAL32, 4);
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.DECIMAL64, 8);
PRIMITIVE_TYPE_VALUE_SIZE.put(TypeValue.DECIMAL128, 16);
}

@Override
Expand Down Expand Up @@ -160,7 +170,12 @@ private void parseStruct(List<String> childNames, List<ColumnType> childTypeValu
}

private void parse(StringScanner scanner) {
int p = scanner.indexOf('<', ',', '>');
int p;
if (scanner.s.startsWith("decimal")) {
p = scanner.s.length();
} else {
p = scanner.indexOf('<', ',', '>');
}
String t = scanner.substr(p);
scanner.moveTo(p);
// assume there is no blank char in `type`.
Expand Down Expand Up @@ -191,10 +206,6 @@ private void parse(StringScanner scanner) {
}
break;
default: {
// convert decimal(x,y) to decimal
if (t.startsWith("decimal")) {
t = "decimal";
}
typeValue = PRIMITIVE_TYPE_VALUE_MAPPING.getOrDefault(t, null);
}
}
Expand All @@ -220,7 +231,7 @@ public ColumnType(String name, String type) {
}

public boolean isByteStorageType() {
return typeValue == TypeValue.STRING || typeValue == TypeValue.DATE || typeValue == TypeValue.DECIMAL
return typeValue == TypeValue.STRING || typeValue == TypeValue.DATE
|| typeValue == TypeValue.BINARY || typeValue == TypeValue.DATETIME
|| typeValue == TypeValue.DATETIME_MICROS || typeValue == TypeValue.DATETIME_MILLIS;
}
Expand Down Expand Up @@ -272,7 +283,10 @@ public int computeColumnSize() {
}
case STRING:
case BINARY:
case DECIMAL:
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
case DATE:
case DATETIME:
case DATETIME_MICROS:
Expand Down Expand Up @@ -370,4 +384,12 @@ public void buildNestedFieldsSpec(String top, StringBuilder sb) {
sb.append(',');
}
}

// Records the decimal scale parsed from the logical type string
// (e.g. the `s` of decimal(p, s)). Initialized to -1, meaning
// "no scale set" for non-decimal columns.
public void setScale(int scale) {
    this.scale = scale;
}

// Returns the decimal scale, or -1 when no scale was ever set.
public int getScale() {
    return scale;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.starrocks.jni.connector;

import java.math.BigDecimal;
import java.util.List;

public interface ColumnValue {
Expand All @@ -40,4 +41,6 @@ public interface ColumnValue {
void unpackStruct(List<Integer> structFieldIndex, List<ColumnValue> values);

byte getByte();

BigDecimal getDecimal();
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

import com.starrocks.utils.Platform;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
Expand Down Expand Up @@ -313,6 +317,32 @@ public double getDouble(int rowId) {
return Platform.getDouble(null, data + rowId * 8L);
}

/**
 * Appends a decimal value at the next free row position.
 *
 * @param value the decimal to store; must be representable at the column's
 *              declared scale without rounding
 * @return the row id the value was written to
 */
public int appendDecimal(BigDecimal value) {
    reserve(elementsAppended + 1);
    putDecimal(elementsAppended, value);
    return elementsAppended++;
}

/**
 * Writes {@code value} as a little-endian two's-complement integer of the
 * column's fixed byte width (4/8/16 for decimal32/64/128).
 *
 * @throws ArithmeticException if rescaling to the column's scale would
 *         require rounding (RoundingMode.UNNECESSARY)
 */
private void putDecimal(int rowId, BigDecimal value) {
    int typeSize = type.getPrimitiveTypeValueSize();
    // Rescale to the column's declared scale; the unscaled integer is the
    // on-wire representation the BE expects.
    BigInteger dataValue = value.setScale(type.getScale(), RoundingMode.UNNECESSARY).unscaledValue();
    // BigInteger.toByteArray() is big-endian; reverse to little-endian.
    byte[] bytes = changeByteOrder(dataValue.toByteArray());
    byte[] newValue = new byte[typeSize];
    if (dataValue.signum() == -1) {
        // Sign-extend: in little-endian order the trailing (high) bytes of a
        // negative value must be 0xFF.
        Arrays.fill(newValue, (byte) -1);
    }
    System.arraycopy(bytes, 0, newValue, 0, Math.min(bytes.length, newValue.length));
    // Widen rowId before multiplying to avoid int overflow for large offsets
    // (getDecimal below already did this; the write path now matches).
    Platform.copyMemory(newValue, Platform.BYTE_ARRAY_OFFSET, null, data + (long) rowId * typeSize, typeSize);
}

/**
 * Reads the fixed-width little-endian decimal at {@code rowId} back into a
 * BigDecimal using the column's declared scale.
 */
public BigDecimal getDecimal(int rowId) {
    int typeSize = type.getPrimitiveTypeValueSize();
    byte[] bytes = new byte[typeSize];
    Platform.copyMemory(null, data + (long) rowId * typeSize, bytes, Platform.BYTE_ARRAY_OFFSET, typeSize);
    // Reverse back to big-endian before handing the bytes to BigInteger.
    BigInteger value = new BigInteger(changeByteOrder(bytes));
    return new BigDecimal(value, type.getScale());
}

private void putBytes(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count);
}
Expand Down Expand Up @@ -472,9 +502,14 @@ public void appendValue(ColumnValue o) {
break;
case STRING:
case DATE:
case DECIMAL:
appendString(o.getString(typeValue));
break;
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
appendDecimal(o.getDecimal());
break;
case DATETIME:
case DATETIME_MICROS:
case DATETIME_MILLIS:
Expand Down Expand Up @@ -561,9 +596,14 @@ public void dump(StringBuilder sb, int i) {
case DATETIME:
case DATETIME_MICROS:
case DATETIME_MILLIS:
case DECIMAL:
sb.append(getUTF8String(i));
break;
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
sb.append(getDecimal(i));
break;
case ARRAY: {
int begin = getArrayOffset(i);
int end = getArrayOffset(i + 1);
Expand Down Expand Up @@ -646,4 +686,14 @@ public void checkMeta(OffHeapTable.MetaChecker checker) {
checker.check(context + "#data", data);
}
}

// Reverses the byte array in place (big-endian <-> little-endian) and
// returns the same array for call chaining.
public byte[] changeByteOrder(byte[] bytes) {
    int left = 0;
    int right = bytes.length - 1;
    while (left < right) {
        byte tmp = bytes[left];
        bytes[left] = bytes[right];
        bytes[right] = tmp;
        ++left;
        --right;
    }
    return bytes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.starrocks.jni.connector.ColumnType;
import com.starrocks.jni.connector.ColumnValue;
import org.apache.paimon.data.Decimal;
import org.apache.paimon.data.InternalArray;
import org.apache.paimon.data.InternalMap;
import org.apache.paimon.data.Timestamp;
Expand All @@ -27,6 +28,7 @@
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.InternalRowUtils;

import java.math.BigDecimal;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.List;
Expand Down Expand Up @@ -143,6 +145,10 @@ public byte getByte() {
return (byte) fieldData;
}

// Unwraps Paimon's Decimal wrapper into a java.math.BigDecimal for the
// generic column writer. Added @Override for consistency with the other
// ColumnValue implementations (the Hudi one carries it) and so the compiler
// verifies the interface contract.
@Override
public BigDecimal getDecimal() {
    return ((Decimal) fieldData).toBigDecimal();
}

private void toPaimonColumnValue(List<ColumnValue> values, InternalArray array, DataType dataType) {
for (int i = 0; i < array.size(); i++) {
PaimonColumnValue cv = new PaimonColumnValue(InternalRowUtils.get(array, i, dataType), dataType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.paimon.table.source.ReadBuilder;
import org.apache.paimon.table.source.Split;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DecimalType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.InternalRowUtils;

Expand Down Expand Up @@ -128,6 +129,9 @@ private void parseRequiredTypes() {
DataType dataType = table.rowType().getTypeAt(index);
String type = PaimonTypeUtils.fromPaimonType(dataType);
requiredTypes[i] = new ColumnType(type);
if (dataType instanceof DecimalType) {
requiredTypes[i].setScale(((DecimalType) dataType).getScale());
}
logicalTypes[i] = dataType;
}

Expand Down
Loading

0 comments on commit 11cd428

Please sign in to comment.