Skip to content

Commit

Permalink
switch the pggen tool to jackc/pgx (#160)
Browse files Browse the repository at this point in the history
As part of the continued campaign to migrate from lib/pq
to jackc/pgx, this patch switches the database driver that
pggen itself uses when it is connecting to the database to
get at the meta tables. It seems that the two divers translate
one of the more esoteric system types differently, so some
code had to be adapted for that, but otherwise there is no
need to adapt.
  • Loading branch information
ethanpailes authored Mar 19, 2021
1 parent 0fc2e11 commit e323f65
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
4 changes: 2 additions & 2 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strings"

"github.com/BurntSushi/toml"
_ "github.com/lib/pq"
_ "github.com/jackc/pgx/v4/stdlib"

"github.com/opendoor-labs/pggen/gen/internal/config"
"github.com/opendoor-labs/pggen/gen/internal/log"
Expand Down Expand Up @@ -83,7 +83,7 @@ func FromConfig(config Config) (*Generator, error) {
connStr = os.Getenv(connStr[1:])
}

db, err = sql.Open("postgres", connStr)
db, err = sql.Open("pgx", connStr)
if err != nil {
db = nil
continue
Expand Down
27 changes: 14 additions & 13 deletions gen/internal/meta/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (mc *Resolver) argsOfStmt(body string, argNamesSpec string) ([]Arg, error)
FROM pg_prepared_statements
WHERE statement = $1`, body).Scan(&types)
if err != nil {
return nil, err
return nil, fmt.Errorf("getting parameter types: %s", err.Error())
}

argNames, err := argNamesToSlice(argNamesSpec, len(types.pgTypes))
Expand All @@ -242,7 +242,7 @@ func (mc *Resolver) argsOfStmt(body string, argNamesSpec string) ([]Arg, error)
name := argNames[i]
typeInfo, err := mc.typeResolver.TypeInfoOf(t)
if err != nil {
return nil, err
return nil, fmt.Errorf("resolving type info: %s", err.Error())
}
args = append(args, Arg{
Idx: i + 1,
Expand All @@ -261,25 +261,26 @@ type RegTypeArray struct {

// Scan implements the `sql.Scanner` interface
func (r *RegTypeArray) Scan(src interface{}) error {
buff, ok := src.([]byte)
// buff, ok := src.([]byte)
regArrayString, ok := src.(string)
if !ok {
return fmt.Errorf("[]regtype Scan: expected a []byte")
return fmt.Errorf("[]regtype Scan: expected a string")
}

if buff[0] != '{' || buff[len(buff)-1] != '}' {
return fmt.Errorf("[]regtype Scan: malformed data '%s'", string(buff))
if regArrayString[0] != '{' || regArrayString[len(regArrayString)-1] != '}' {
return fmt.Errorf("[]regtype Scan: malformed data '%s'", regArrayString)
}
buff = buff[1 : len(buff)-1]
regArrayString = regArrayString[1 : len(regArrayString)-1]

if len(buff) == 0 {
if len(regArrayString) == 0 {
r.pgTypes = []string{}
return nil
}

for len(buff) > 0 {
for len(regArrayString) > 0 {
var ty string
var err error
ty, buff, err = splitType(buff)
ty, regArrayString, err = splitType(regArrayString)
if err != nil {
return err
}
Expand All @@ -291,7 +292,7 @@ func (r *RegTypeArray) Scan(src interface{}) error {

// given a comma separated list of possibly quoted values,
// splitType takes the first one off the `types` slice.
func splitType(types []byte) (ty string, rest []byte, err error) {
func splitType(types string) (ty string, rest string, err error) {
switch types[0] {
case '"':
for i := 1; i < len(types); i++ {
Expand All @@ -312,7 +313,7 @@ func splitType(types []byte) (ty string, rest []byte, err error) {
} else {
// s[len(s):] is an error rather than returning the
// empty slice, which is why we need this special case.
rest = []byte{}
rest = ""
}

return
Expand All @@ -337,7 +338,7 @@ func splitType(types []byte) (ty string, rest []byte, err error) {

// the last (non-quoted) type
ty = string(types)
rest = []byte{}
rest = ""
return
}

Expand Down
44 changes: 23 additions & 21 deletions gen/internal/meta/meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,28 +58,30 @@ func TestSplitType(t *testing.T) {
}

for i, v := range testVecs {
inputBytes := []byte(v.input)
t.Run(v.input, func(t *testing.T) {
t.Parallel()

var actual RegTypeArray
err := actual.Scan(inputBytes)
if err != nil &&
(!strings.Contains(err.Error(), v.expectedErr) ||
len(v.expectedErr) == 0) {
t.Errorf(
"\n(case %d) Error: %s\n Expected Error: %s\n",
i,
err.Error(),
v.expectedErr,
)
}
var actual RegTypeArray
err := actual.Scan(v.input)
if err != nil &&
(!strings.Contains(err.Error(), v.expectedErr) ||
len(v.expectedErr) == 0) {
t.Errorf(
"\n(case %d) Error: %s\n Expected Error: %s\n",
i,
err.Error(),
v.expectedErr,
)
}

if !reflect.DeepEqual(actual.pgTypes, v.expected) {
t.Errorf(
"\n(case %d) Actual: %#v\n Expected: %#v\n",
i,
actual.pgTypes,
v.expected,
)
}
if !reflect.DeepEqual(actual.pgTypes, v.expected) {
t.Errorf(
"\n(case %d) Actual: %#v\n Expected: %#v\n",
i,
actual.pgTypes,
v.expected,
)
}
})
}
}

0 comments on commit e323f65

Please sign in to comment.