From 782b0f6a9bbfcaecd559f39a3403cbf663419cc7 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 3 Oct 2023 14:19:32 +0300 Subject: [PATCH] feat: add support for nan and inf (#72) --- driver_integration_test.go | 18 +++++++++++++----- rows.go | 31 +++++++++++++++++++++++++++++-- rows_test.go | 16 ++++++++++++---- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/driver_integration_test.go b/driver_integration_test.go index fd67352..f1bce9c 100644 --- a/driver_integration_test.go +++ b/driver_integration_test.go @@ -7,8 +7,10 @@ import ( "context" "database/sql" "fmt" + "math" "os" "reflect" + "runtime/debug" "strings" "testing" "time" @@ -59,16 +61,17 @@ func TestDriverQueryResult(t *testing.T) { t.Errorf("failed unexpectedly with %v", err) } rows, err := db.Query( - "SELECT CAST('2020-01-03 19:08:45' AS DATETIME) as dt, CAST('2020-01-03' AS DATE) as d, CAST(1 AS INT) as i " + + "SELECT CAST('2020-01-03 19:08:45' AS DATETIME) as dt, CAST('2020-01-03' AS DATE) as d, CAST(1 AS INT) as i, CAST(-1/0 as FLOAT) as f " + "UNION " + - "SELECT CAST('2021-01-03 19:38:34' AS DATETIME) as dt, CAST('2000-12-03' AS DATE) as d, CAST(2 AS INT) as i ORDER BY i") + "SELECT CAST('2021-01-03 19:38:34' AS DATETIME) as dt, CAST('2000-12-03' AS DATE) as d, CAST(2 AS INT) as i, CAST(0/0 as FLOAT) as f ORDER BY i") if err != nil { t.Errorf("db.Query returned an error: %v", err) } var dt, d time.Time var i int + var f float64 - expectedColumns := []string{"dt", "d", "i"} + expectedColumns := []string{"dt", "d", "i", "f"} if columns, err := rows.Columns(); reflect.DeepEqual(expectedColumns, columns) && err != nil { t.Errorf("columns are not equal (%v != %v) and error is %v", expectedColumns, columns, err) } @@ -76,18 +79,23 @@ func TestDriverQueryResult(t *testing.T) { if !rows.Next() { t.Errorf("Next returned end of output") } - assert(rows.Scan(&dt, &d, &i), nil, t, "Scan returned an error") + assert(rows.Scan(&dt, &d, &i, &f), nil, t, "Scan returned an error") assert(dt, time.Date(2020, 01, 03, 19, 8, 45, 0, loc), t, "results not equal for datetime") assert(d, time.Date(2020, 01, 03, 0, 0, 0, 0, loc), t, "results not equal for date") assert(i, 1, t, "results not equal for int") + assert(f, math.Inf(-1), t, "results not equal for float") if !rows.Next() { t.Errorf("Next returned end of output") } - assert(rows.Scan(&dt, &d, &i), nil, t, "Scan returned an error") + assert(rows.Scan(&dt, &d, &i, &f), nil, t, "Scan returned an error") assert(dt, time.Date(2021, 01, 03, 19, 38, 34, 0, loc), t, "results not equal for datetime") assert(d, time.Date(2000, 12, 03, 0, 0, 0, 0, loc), t, "results not equal for date") assert(i, 2, t, "results not equal for int") + if !math.IsNaN(f) { + t.Log(string(debug.Stack())) + t.Errorf("results not equal for float Expected: NaN Got: %f", f) + } if rows.Next() { t.Errorf("Next didn't returned false, although no data is expected") diff --git a/rows.go b/rows.go index 1182cd8..abe5d97 100644 --- a/rows.go +++ b/rows.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "io" + "math" "reflect" "strings" "time" @@ -94,6 +95,13 @@ func checkTypeValue(columnType string, val interface{}) error { switch columnType { case intType, longType, floatType, doubleType: if _, ok := val.(float64); !ok { + if columnType == floatType || columnType == doubleType { + for _, v := range []string{"inf", "-inf", "nan", "-nan"} { + if val == v { + return nil + } + } + } return fmt.Errorf("expected to convert a value to float64, but couldn't: %v", val) } return nil @@ -142,6 +150,24 @@ func parseDateTimeValue(columnType string, value string) (driver.Value, error) { return nil, fmt.Errorf("type not known: %s", columnType) } +func parseFloatValue(val interface{}) (float64, error) { + if _, notNum := val.(string); notNum { + switch val.(string) { + case "inf": + return math.Inf(1), nil + case "-inf": + return math.Inf(-1), nil + case "nan": + return math.NaN(), nil + case "-nan": + return math.NaN(), nil + default: + return 0, fmt.Errorf("unknown float value: %s", val) + } + } + return val.(float64), nil +} + // parseSingleValue parses all columns types except arrays func parseSingleValue(columnType string, val interface{}) (driver.Value, error) { if err := checkTypeValue(columnType, val); err != nil { @@ -154,9 +180,10 @@ func parseSingleValue(columnType string, val interface{}) (driver.Value, error) case longType: return int64(val.(float64)), nil case floatType: - return float32(val.(float64)), nil + v, err := parseFloatValue(val) + return float32(v), err case doubleType: - return val.(float64), nil + return parseFloatValue(val) case textType: return val.(string), nil case dateType, pgDateType, timestampType, timestampNtzType, timestampTzType: diff --git a/rows_test.go b/rows_test.go index b2fb80a..3d2fa4f 100644 --- a/rows_test.go +++ b/rows_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/json" "io" + "math" "reflect" "runtime/debug" "testing" @@ -58,9 +59,9 @@ func mockRows(isMultiStatement bool) driver.RowsNextResultSet { "data":[ [null,1,0.312321,123213.321321,"text", "2080-12-31","1989-04-15 01:02:03","0002-01-01","1989-04-15 01:02:03.123456","1989-04-15 02:02:03.123456+00",1,[1,2,3],[[]],true, 123.12345678, [123.12345678], "\\x616263313233"], [2,1,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1989-04-15 01:02:03.1234+05:30",1,[1,2,3],[[]],true, -123.12345678, [-123.12345678, 0.0], "\\x6162630A0AE3858D20E3858E5C"], - [3,null,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123458","1989-04-15 01:02:03+01",1,[5,2,3,2],[["TEST","TEST1"],["TEST3"]],false, 0.0, [0.0], null], - [2,1,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1111-01-05 17:04:42.123456+05:53:28",1,[1,2,3],[[]],false, 123456781234567812345678.123456781234567812345678, [123456781234567812345678.12345678123456781234567812345678], null], - [2,1,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1989-04-15 02:02:03.123456-01",1,[1,2,3],[[]],null, null, [null], null] + [3,null,"inf",123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123458","1989-04-15 01:02:03+01",1,[5,2,3,2],[["TEST","TEST1"],["TEST3"]],false, 0.0, [0.0], null], + [2,1,"-inf",123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1111-01-05 17:04:42.123456+05:53:28",1,[1,2,3],[[]],false, 123456781234567812345678.123456781234567812345678, [123456781234567812345678.12345678123456781234567812345678], null], + [2,1,"-nan",123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1989-04-15 02:02:03.123456-01",1,[1,2,3],[[]],null, null, [null], null] ], "rows":5, "statistics":{ @@ -177,7 +178,7 @@ func TestRowsNext(t *testing.T) { assert(err, nil, t, "Next shouldn't return an error") assert(dest[0], int32(3), t, "results not equal for int32") assert(dest[1], nil, t, "results not equal for int64") - assert(dest[2], float32(0.312321), t, "results not equal for float32") + assert(dest[2], float32(math.Inf(1)), t, "results not equal for float32") assert(dest[3], float64(123213.321321), t, "results not equal for float64") assert(dest[4], "text", t, "results not equal for string") assert(dest[13], false, t, "results not equal for boolean") @@ -189,6 +190,7 @@ func TestRowsNext(t *testing.T) { // Fourth row err = rows.Next(dest) assert(err, nil, t, "Next shouldn't return an error") + assert(dest[2], float32(math.Inf(-1)), t, "results not equal for float32") assertDates(dest[9].(time.Time), time.Date(1111, 01, 5, 11, 11, 14, 123456000, loc), t, "") assert(dest[13], false, t, "results not equal for boolean") var long_double = 123456781234567812345678.12345678123456781234567812345678 @@ -200,6 +202,12 @@ func TestRowsNext(t *testing.T) { // Fifth row err = rows.Next(dest) assert(err, nil, t, "Next shouldn't return an error") + // Cannot do assert since NaN != NaN according to the standard + // math.IsNaN only works for float64, converting float32 NaN to float64 results in 0 + if !(dest[2].(float32) != dest[2].(float32)) { + t.Log(string(debug.Stack())) + t.Errorf("results not equal for float32 Expected: NaN Got: %s", dest[2]) + } assertDates(dest[9].(time.Time), time.Date(1989, 4, 15, 3, 2, 3, 123456000, loc), t, "") assert(dest[13], nil, t, "results not equal for boolean") assert(dest[14], nil, t, "results not equal for decimal")