Skip to content

Commit

Permalink
feat: add support for nan and inf (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansergeevitch authored Oct 3, 2023
1 parent 0b8ebec commit 782b0f6
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
18 changes: 13 additions & 5 deletions driver_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"context"
"database/sql"
"fmt"
"math"
"os"
"reflect"
"runtime/debug"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -59,35 +61,41 @@ func TestDriverQueryResult(t *testing.T) {
t.Errorf("failed unexpectedly with %v", err)
}
rows, err := db.Query(
"SELECT CAST('2020-01-03 19:08:45' AS DATETIME) as dt, CAST('2020-01-03' AS DATE) as d, CAST(1 AS INT) as i " +
"SELECT CAST('2020-01-03 19:08:45' AS DATETIME) as dt, CAST('2020-01-03' AS DATE) as d, CAST(1 AS INT) as i, CAST(-1/0 as FLOAT) as f " +
"UNION " +
"SELECT CAST('2021-01-03 19:38:34' AS DATETIME) as dt, CAST('2000-12-03' AS DATE) as d, CAST(2 AS INT) as i ORDER BY i")
"SELECT CAST('2021-01-03 19:38:34' AS DATETIME) as dt, CAST('2000-12-03' AS DATE) as d, CAST(2 AS INT) as i, CAST(0/0 as FLOAT) as f ORDER BY i")
if err != nil {
t.Errorf("db.Query returned an error: %v", err)
}
var dt, d time.Time
var i int
var f float64

expectedColumns := []string{"dt", "d", "i"}
expectedColumns := []string{"dt", "d", "i", "f"}
if columns, err := rows.Columns(); reflect.DeepEqual(expectedColumns, columns) && err != nil {
t.Errorf("columns are not equal (%v != %v) and error is %v", expectedColumns, columns, err)
}

if !rows.Next() {
t.Errorf("Next returned end of output")
}
assert(rows.Scan(&dt, &d, &i), nil, t, "Scan returned an error")
assert(rows.Scan(&dt, &d, &i, &f), nil, t, "Scan returned an error")
assert(dt, time.Date(2020, 01, 03, 19, 8, 45, 0, loc), t, "results not equal for datetime")
assert(d, time.Date(2020, 01, 03, 0, 0, 0, 0, loc), t, "results not equal for date")
assert(i, 1, t, "results not equal for int")
assert(f, math.Inf(-1), t, "results not equal for float")

if !rows.Next() {
t.Errorf("Next returned end of output")
}
assert(rows.Scan(&dt, &d, &i), nil, t, "Scan returned an error")
assert(rows.Scan(&dt, &d, &i, &f), nil, t, "Scan returned an error")
assert(dt, time.Date(2021, 01, 03, 19, 38, 34, 0, loc), t, "results not equal for datetime")
assert(d, time.Date(2000, 12, 03, 0, 0, 0, 0, loc), t, "results not equal for date")
assert(i, 2, t, "results not equal for int")
if !math.IsNaN(f) {
t.Log(string(debug.Stack()))
t.Errorf("results not equal for float Expected: NaN Got: %f", f)
}

if rows.Next() {
t.Errorf("Next didn't returned false, although no data is expected")
Expand Down
31 changes: 29 additions & 2 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"io"
"math"
"reflect"
"strings"
"time"
Expand Down Expand Up @@ -94,6 +95,13 @@ func checkTypeValue(columnType string, val interface{}) error {
switch columnType {
case intType, longType, floatType, doubleType:
if _, ok := val.(float64); !ok {
if columnType == floatType || columnType == doubleType {
for _, v := range []string{"inf", "-inf", "nan", "-nan"} {
if val == v {
return nil
}
}
}
return fmt.Errorf("expected to convert a value to float64, but couldn't: %v", val)
}
return nil
Expand Down Expand Up @@ -142,6 +150,24 @@ func parseDateTimeValue(columnType string, value string) (driver.Value, error) {
return nil, fmt.Errorf("type not known: %s", columnType)
}

func parseFloatValue(val interface{}) (float64, error) {
if _, notNum := val.(string); notNum {
switch val.(string) {
case "inf":
return math.Inf(1), nil
case "-inf":
return math.Inf(-1), nil
case "nan":
return math.NaN(), nil
case "-nan":
return math.NaN(), nil
default:
return 0, fmt.Errorf("unknown float value: %s", val)
}
}
return val.(float64), nil
}

// parseSingleValue parses all columns types except arrays
func parseSingleValue(columnType string, val interface{}) (driver.Value, error) {
if err := checkTypeValue(columnType, val); err != nil {
Expand All @@ -154,9 +180,10 @@ func parseSingleValue(columnType string, val interface{}) (driver.Value, error)
case longType:
return int64(val.(float64)), nil
case floatType:
return float32(val.(float64)), nil
v, err := parseFloatValue(val)
return float32(v), err
case doubleType:
return val.(float64), nil
return parseFloatValue(val)
case textType:
return val.(string), nil
case dateType, pgDateType, timestampType, timestampNtzType, timestampTzType:
Expand Down
16 changes: 12 additions & 4 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql/driver"
"encoding/json"
"io"
"math"
"reflect"
"runtime/debug"
"testing"
Expand Down Expand Up @@ -58,9 +59,9 @@ func mockRows(isMultiStatement bool) driver.RowsNextResultSet {
"data":[
[null,1,0.312321,123213.321321,"text", "2080-12-31","1989-04-15 01:02:03","0002-01-01","1989-04-15 01:02:03.123456","1989-04-15 02:02:03.123456+00",1,[1,2,3],[[]],true, 123.12345678, [123.12345678], "\\x616263313233"],
[2,1,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1989-04-15 01:02:03.1234+05:30",1,[1,2,3],[[]],true, -123.12345678, [-123.12345678, 0.0], "\\x6162630A0AE3858D20E3858E5C"],
[3,null,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123458","1989-04-15 01:02:03+01",1,[5,2,3,2],[["TEST","TEST1"],["TEST3"]],false, 0.0, [0.0], null],
[2,1,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1111-01-05 17:04:42.123456+05:53:28",1,[1,2,3],[[]],false, 123456781234567812345678.123456781234567812345678, [123456781234567812345678.12345678123456781234567812345678], null],
[2,1,0.312321,123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1989-04-15 02:02:03.123456-01",1,[1,2,3],[[]],null, null, [null], null]
[3,null,"inf",123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123458","1989-04-15 01:02:03+01",1,[5,2,3,2],[["TEST","TEST1"],["TEST3"]],false, 0.0, [0.0], null],
[2,1,"-inf",123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1111-01-05 17:04:42.123456+05:53:28",1,[1,2,3],[[]],false, 123456781234567812345678.123456781234567812345678, [123456781234567812345678.12345678123456781234567812345678], null],
[2,1,"-nan",123213.321321,"text","1970-01-01","1970-01-01 00:00:00","0001-01-01","1989-04-15 01:02:03.123457","1989-04-15 02:02:03.123456-01",1,[1,2,3],[[]],null, null, [null], null]
],
"rows":5,
"statistics":{
Expand Down Expand Up @@ -177,7 +178,7 @@ func TestRowsNext(t *testing.T) {
assert(err, nil, t, "Next shouldn't return an error")
assert(dest[0], int32(3), t, "results not equal for int32")
assert(dest[1], nil, t, "results not equal for int64")
assert(dest[2], float32(0.312321), t, "results not equal for float32")
assert(dest[2], float32(math.Inf(1)), t, "results not equal for float32")
assert(dest[3], float64(123213.321321), t, "results not equal for float64")
assert(dest[4], "text", t, "results not equal for string")
assert(dest[13], false, t, "results not equal for boolean")
Expand All @@ -189,6 +190,7 @@ func TestRowsNext(t *testing.T) {
// Fourth row
err = rows.Next(dest)
assert(err, nil, t, "Next shouldn't return an error")
assert(dest[2], float32(math.Inf(-1)), t, "results not equal for float32")
assertDates(dest[9].(time.Time), time.Date(1111, 01, 5, 11, 11, 14, 123456000, loc), t, "")
assert(dest[13], false, t, "results not equal for boolean")
var long_double = 123456781234567812345678.12345678123456781234567812345678
Expand All @@ -200,6 +202,12 @@ func TestRowsNext(t *testing.T) {
// Fifth row
err = rows.Next(dest)
assert(err, nil, t, "Next shouldn't return an error")
// Cannot do assert since NaN != NaN according to the standard
// math.IsNaN only works for float64, converting float32 NaN to float64 results in 0
if !(dest[2].(float32) != dest[2].(float32)) {
t.Log(string(debug.Stack()))
t.Errorf("results not equal for float32 Expected: NaN Got: %s", dest[2])
}
assertDates(dest[9].(time.Time), time.Date(1989, 4, 15, 3, 2, 3, 123456000, loc), t, "")
assert(dest[13], nil, t, "results not equal for boolean")
assert(dest[14], nil, t, "results not equal for decimal")
Expand Down

0 comments on commit 782b0f6

Please sign in to comment.