Skip to content

Commit

Permalink
Refactor GetConverter and handle LowCardinality() type
Browse files Browse the repository at this point in the history
  • Loading branch information
adamyeats committed May 10, 2024
1 parent 0039bf0 commit 0d457e3
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 44 deletions.
129 changes: 85 additions & 44 deletions pkg/converters/converters.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package converters

import (
"encoding/json"
"errors"
"fmt"
"math/big"
"net"
Expand Down Expand Up @@ -29,6 +30,7 @@ var matchRegexes = map[string]*regexp.Regexp{
"Decimal": regexp.MustCompile(`^Decimal`),
"FixedString()": regexp.MustCompile(`^Nullable\(FixedString\(.*\)\)`),
"IP": regexp.MustCompile(`^IPv[4,6]`),
"LowCardinality()": regexp.MustCompile(`^LowCardinality\(([^)]*)\)`),
"Map()": regexp.MustCompile(`^Map\(.*\)`),
"Nested()": regexp.MustCompile(`^Nested\(.*\)`),
"Nullable(Date)": regexp.MustCompile(`^Nullable\(Date\(?`),
Expand All @@ -41,162 +43,166 @@ var matchRegexes = map[string]*regexp.Regexp{
}

var Converters = map[string]Converter{
"String": {
fieldType: data.FieldTypeString,
scanType: reflect.PointerTo(reflect.TypeOf("")),
},
"Bool": {
fieldType: data.FieldTypeBool,
scanType: reflect.PtrTo(reflect.TypeOf(true)),
scanType: reflect.PointerTo(reflect.TypeOf(true)),
},
"Nullable(Bool)": {
fieldType: data.FieldTypeNullableBool,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(true))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(true))),
},
"Float64": {
fieldType: data.FieldTypeFloat64,
scanType: reflect.PtrTo(reflect.TypeOf(float64(0))),
scanType: reflect.PointerTo(reflect.TypeOf(float64(0))),
},
"Float32": {
fieldType: data.FieldTypeFloat32,
scanType: reflect.PtrTo(reflect.TypeOf(float32(0))),
scanType: reflect.PointerTo(reflect.TypeOf(float32(0))),
},
"Nullable(Float32)": {
fieldType: data.FieldTypeNullableFloat32,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(float32(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(float32(0)))),
},
"Nullable(Float64)": {
fieldType: data.FieldTypeNullableFloat64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(float64(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(float64(0)))),
},
"Int64": {
fieldType: data.FieldTypeInt64,
scanType: reflect.PtrTo(reflect.TypeOf(int64(0))),
scanType: reflect.PointerTo(reflect.TypeOf(int64(0))),
},
"Int32": {
fieldType: data.FieldTypeInt32,
scanType: reflect.PtrTo(reflect.TypeOf(int32(0))),
scanType: reflect.PointerTo(reflect.TypeOf(int32(0))),
},
"Int16": {
fieldType: data.FieldTypeInt16,
scanType: reflect.PtrTo(reflect.TypeOf(int16(0))),
scanType: reflect.PointerTo(reflect.TypeOf(int16(0))),
},
"Int8": {
fieldType: data.FieldTypeInt8,
scanType: reflect.PtrTo(reflect.TypeOf(int8(0))),
scanType: reflect.PointerTo(reflect.TypeOf(int8(0))),
},
"UInt64": {
fieldType: data.FieldTypeUint64,
scanType: reflect.PtrTo(reflect.TypeOf(uint64(0))),
scanType: reflect.PointerTo(reflect.TypeOf(uint64(0))),
},
"UInt32": {
fieldType: data.FieldTypeUint32,
scanType: reflect.PtrTo(reflect.TypeOf(uint32(0))),
scanType: reflect.PointerTo(reflect.TypeOf(uint32(0))),
},
"UInt16": {
fieldType: data.FieldTypeUint16,
scanType: reflect.PtrTo(reflect.TypeOf(uint16(0))),
scanType: reflect.PointerTo(reflect.TypeOf(uint16(0))),
},
"UInt8": {
fieldType: data.FieldTypeUint8,
scanType: reflect.PtrTo(reflect.TypeOf(uint8(0))),
scanType: reflect.PointerTo(reflect.TypeOf(uint8(0))),
},
"Nullable(UInt64)": {
fieldType: data.FieldTypeNullableUint64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(uint64(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(uint64(0)))),
},
"Nullable(UInt32)": {
fieldType: data.FieldTypeNullableUint32,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(uint32(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(uint32(0)))),
},
"Nullable(UInt16)": {
fieldType: data.FieldTypeNullableUint16,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(uint16(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(uint16(0)))),
},
"Nullable(UInt8)": {
fieldType: data.FieldTypeNullableUint8,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(uint8(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(uint8(0)))),
},
"Nullable(Int64)": {
fieldType: data.FieldTypeNullableInt64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(int64(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(int64(0)))),
},
"Nullable(Int32)": {
fieldType: data.FieldTypeNullableInt32,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(int32(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(int32(0)))),
},
"Nullable(Int16)": {
fieldType: data.FieldTypeNullableInt16,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(int16(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(int16(0)))),
},
"Nullable(Int8)": {
fieldType: data.FieldTypeNullableInt8,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(int8(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(int8(0)))),
},
// this is in precise and in appropriate for any math, but everything goes to floats in JS anyway
"Int128": {
convert: bigIntConvert,
fieldType: data.FieldTypeFloat64,
scanType: reflect.PtrTo(reflect.TypeOf(big.NewInt(0))),
scanType: reflect.PointerTo(reflect.TypeOf(big.NewInt(0))),
},
"Nullable(Int128)": {
convert: bigIntNullableConvert,
fieldType: data.FieldTypeNullableFloat64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(big.NewInt(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(big.NewInt(0)))),
},
"Int256": {
convert: bigIntConvert,
fieldType: data.FieldTypeFloat64,
scanType: reflect.PtrTo(reflect.TypeOf(big.NewInt(0))),
scanType: reflect.PointerTo(reflect.TypeOf(big.NewInt(0))),
},
"Nullable(Int256)": {
convert: bigIntNullableConvert,
fieldType: data.FieldTypeNullableFloat64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(big.NewInt(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(big.NewInt(0)))),
},
"UInt128": {
convert: bigIntConvert,
fieldType: data.FieldTypeFloat64,
scanType: reflect.PtrTo(reflect.TypeOf(big.NewInt(0))),
scanType: reflect.PointerTo(reflect.TypeOf(big.NewInt(0))),
},
"Nullable(UInt128)": {
convert: bigIntNullableConvert,
fieldType: data.FieldTypeNullableFloat64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(big.NewInt(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(big.NewInt(0)))),
},
"UInt256": {
convert: bigIntConvert,
fieldType: data.FieldTypeFloat64,
scanType: reflect.PtrTo(reflect.TypeOf(big.NewInt(0))),
scanType: reflect.PointerTo(reflect.TypeOf(big.NewInt(0))),
},
"Nullable(UInt256)": {
convert: bigIntNullableConvert,
fieldType: data.FieldTypeNullableFloat64,
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(big.NewInt(0)))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(big.NewInt(0)))),
},
// covers DateTime with tz, DateTime64 - see regexes, Date32
"Date": {
fieldType: data.FieldTypeTime,
matchRegex: matchRegexes["Date"],
scanType: reflect.PtrTo(reflect.TypeOf(time.Time{})),
scanType: reflect.PointerTo(reflect.TypeOf(time.Time{})),
},
"Nullable(Date)": {
fieldType: data.FieldTypeNullableTime,
matchRegex: matchRegexes["Nullable(Date)"],
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(time.Time{}))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(time.Time{}))),
},
"Nullable(String)": {
fieldType: data.FieldTypeNullableString,
matchRegex: matchRegexes["Nullable(String)"],
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(""))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(""))),
},
"Decimal": {
convert: decimalConvert,
fieldType: data.FieldTypeFloat64,
matchRegex: matchRegexes["Decimal"],
scanType: reflect.PtrTo(reflect.TypeOf(decimal.Decimal{})),
scanType: reflect.PointerTo(reflect.TypeOf(decimal.Decimal{})),
},
"Nullable(Decimal)": {
convert: decimalNullConvert,
fieldType: data.FieldTypeNullableFloat64,
matchRegex: matchRegexes["Nullable(Decimal)"],
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(decimal.Decimal{}))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(decimal.Decimal{}))),
},
"Tuple()": {
convert: jsonConverter,
Expand Down Expand Up @@ -226,19 +232,19 @@ var Converters = map[string]Converter{
"FixedString()": {
fieldType: data.FieldTypeNullableString,
matchRegex: matchRegexes["FixedString()"],
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(""))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(""))),
},
"IP": {
convert: ipConverter,
fieldType: data.FieldTypeString,
matchRegex: matchRegexes["IP"],
scanType: reflect.PtrTo(reflect.TypeOf(net.IP{})),
scanType: reflect.PointerTo(reflect.TypeOf(net.IP{})),
},
"Nullable(IP)": {
convert: ipNullConverter,
fieldType: data.FieldTypeNullableString,
matchRegex: matchRegexes["Nullable(IP)"],
scanType: reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(net.IP{}))),
scanType: reflect.PointerTo(reflect.PointerTo(reflect.TypeOf(net.IP{}))),
},
"SimpleAggregateFunction()": {
convert: jsonConverter,
Expand All @@ -264,19 +270,39 @@ func ClickHouseConverters() []sqlutil.Converter {
return list
}

// GetConverter returns a sqlutil.Converter for the given column type.
func GetConverter(columnType string) sqlutil.Converter {
converter, ok := Converters[columnType]
if ok {
// check for 'LowCardinality()' type first and get the converter for the inner type
if ok, innerType := extractLowCardinalityType(columnType); ok {
return GetConverter(innerType)
}

// direct match or regex-based match in `Converters` map
if converter, ok := Converters[columnType]; ok {
return createConverter(columnType, converter)
}

// regex-based search through `Converters` map
return findConverterWithRegex(columnType)
}

// extractLowCardinalityType checks if the column type is a `LowCardinality()` type and returns the inner type.
func extractLowCardinalityType(columnType string) (bool, string) {
if matches := matchRegexes["LowCardinality()"].FindStringSubmatch(columnType); len(matches) > 1 {
return true, matches[1]
}

return false, ""
}

// findConverterWithRegex searches through the `Converters` map using regex matching.
func findConverterWithRegex(columnType string) sqlutil.Converter {
for name, converter := range Converters {
if name == columnType {
return createConverter(name, converter)
}
if converter.matchRegex != nil && converter.matchRegex.MatchString(columnType) {
return createConverter(name, converter)
}
}

return sqlutil.Converter{}
}

Expand Down Expand Up @@ -317,7 +343,22 @@ func defaultConvert(in interface{}) (interface{}, error) {
if in == nil {
return reflect.Zero(reflect.TypeOf(in)).Interface(), nil
}
return reflect.ValueOf(in).Elem().Interface(), nil

// check the type of the input and handle strings separately because they cannot be dereferenced
val := reflect.ValueOf(in)
if val.Kind() == reflect.String {
return in, nil
}

// handle pointers and dereference if possible
if val.Kind() == reflect.Ptr {
if val.IsNil() {
return nil, errors.New("nil pointer cannot be dereferenced in defaultConvert")
}
return val.Elem().Interface(), nil
}

return in, nil
}

func decimalConvert(in interface{}) (interface{}, error) {
Expand Down
16 changes: 16 additions & 0 deletions pkg/converters/converters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,19 @@ func TestPoint(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, msg, *v.(*json.RawMessage))
}

func TestLowCardinality(t *testing.T) {
value := "value"
sut := converters.GetConverter("LowCardinality(String)")
v, err := sut.FrameConverter.ConverterFunc(value)
assert.Nil(t, err)
assert.Equal(t, value, v)
}

func TestLowCardinalityNullable(t *testing.T) {
value := "value"
sut := converters.GetConverter("LowCardinality(Nullable(String))")
v, err := sut.FrameConverter.ConverterFunc(&value)
assert.Nil(t, err)
assert.Equal(t, value, v)
}

0 comments on commit 0d457e3

Please sign in to comment.