Skip to content

Commit

Permalink
Reconnection on db connection failure
Browse files Browse the repository at this point in the history
  • Loading branch information
safchain committed Sep 9, 2015
1 parent d44a1a1 commit 09fa245
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 93 deletions.
202 changes: 109 additions & 93 deletions ejabberd-go-auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package main

import (
"flag"
"flag"
"bufio"
"encoding/binary"
"os"
Expand All @@ -15,62 +15,68 @@ import (
)

type Config struct {
Driver string
Host string
Port string
User string
Pass string
Dbname string
Dbargs string
Table string
UserField string
PassField string
ServerField string
Driver string
Host string
Port string
User string
Pass string
Dbname string
Dbargs string
Table string
UserField string
PassField string
ServerField string
}

func auth(conf *Config, db *sql.DB, user string, server string, passwd string) bool {
func auth(conf *Config, db *sql.DB, user string,
server string, passwd string) (bool, error) {
var query string
var value string
var err error
if conf.ServerField != "" {
query = fmt.Sprintf(
"select %s from %s where %s = $1 and %s = $2 and %s = $3",
conf.UserField, conf.Table, conf.UserField, conf.PassField,
conf.ServerField)
err = db.QueryRow(query, user, passwd, server).Scan(&value)
} else {
query = fmt.Sprintf(
"select %s from %s where %s = $1 and %s = $2",
conf.UserField, conf.Table, conf.UserField, conf.PassField)
err = db.QueryRow(query, user, passwd).Scan(&value)
}

log.Printf("user: %s, server: %s, passwd: %s\n", user, server, passwd)

if conf.ServerField != "" {
query = fmt.Sprintf(
"select %s from %s where %s = $1 and %s = $2 and %s = $3",
conf.UserField, conf.Table, conf.UserField, conf.PassField,
conf.ServerField)
err = db.QueryRow(query, user, passwd, server).Scan(&value)
} else {
query = fmt.Sprintf(
"select %s from %s where %s = $1 and %s = $2",
conf.UserField, conf.Table, conf.UserField, conf.PassField)
err = db.QueryRow(query, user, passwd).Scan(&value)
}
if err != nil || value == "" {
return false
return false, err
}

return true
return true, nil
}

func isuser(conf *Config, db *sql.DB, user string, server string) bool {
func isuser(conf *Config, db *sql.DB, user string,
server string) (bool, error) {
var query string
var value string
var err error
if conf.ServerField != "" {
query = fmt.Sprintf(
"select %s from %s where %s = $1 and %s = $2",
conf.UserField, conf.Table, conf.UserField, conf.ServerField)
err = db.QueryRow(query, user, server).Scan(&value)
} else {
query = fmt.Sprintf(
"select %s from %s where %s = $1",
conf.UserField, conf.Table, conf.UserField)
err = db.QueryRow(query, user).Scan(&value)
}

if conf.ServerField != "" {
query = fmt.Sprintf(
"select %s from %s where %s = $1 and %s = $2",
conf.UserField, conf.Table, conf.UserField, conf.ServerField)
err = db.QueryRow(query, user, server).Scan(&value)
} else {
query = fmt.Sprintf(
"select %s from %s where %s = $1",
conf.UserField, conf.Table, conf.UserField)
err = db.QueryRow(query, user).Scan(&value)
}
if err != nil || value == "" {
return false
return false, err
}

return true
return true, nil
}

func GetSqlConnectionString(conf *Config) string {
Expand All @@ -79,80 +85,90 @@ func GetSqlConnectionString(conf *Config) string {
conf.Dbname, conf.Dbargs)
}

func OpenSqlConnection(conf *Config) *sql.DB {
func OpenSqlConnection(conf *Config) (*sql.DB, error) {
var err error

connectionString := GetSqlConnectionString(conf)
db, err := sql.Open(conf.Driver, connectionString)
if err != nil {
log.Fatal(err)
return nil, err
}

if err = db.Ping(); err != nil {
log.Fatal(err)
return nil, err
}

return db
return db, nil
}

func AuthLoop(conf *Config) {
db := OpenSqlConnection(conf)

bioIn := bufio.NewReader(os.Stdin)
bioOut := bufio.NewWriter(os.Stdout)

var success bool
var length uint16
var result uint16

for {
_ = binary.Read(bioIn, binary.BigEndian, &length)

buf := make([]byte, length)

r, _ := bioIn.Read(buf)
if r == 0 {
continue
}

data := strings.Split(string(buf), ":")
if data[0] == "auth" {
success = auth(conf, db, data[1], data[2], data[3])
} else if data[0] == "isuser" {
success = isuser(conf, db, data[1], data[2])
} else {
success = false
}

length = 2
binary.Write(bioOut, binary.BigEndian, &length)

if success != true {
result = 0
} else {
result = 1
}

binary.Write(bioOut, binary.BigEndian, &result)
bioOut.Flush()
}
db, err := OpenSqlConnection(conf)

bioIn := bufio.NewReader(os.Stdin)
bioOut := bufio.NewWriter(os.Stdout)

var success bool
var length uint16
var result uint16

for {
_ = binary.Read(bioIn, binary.BigEndian, &length)

buf := make([]byte, length)

r, _ := bioIn.Read(buf)
if r == 0 {
continue
}

if err != nil {
err = db.Ping()
}

if err == nil {
data := strings.Split(string(buf), ":")
if data[0] == "auth" {
success, err = auth(conf, db, data[1], data[2], data[3])
} else if data[0] == "isuser" {
success, err = isuser(conf, db, data[1], data[2])
} else {
success = false
}
} else {
success = false
}

length = 2
binary.Write(bioOut, binary.BigEndian, &length)

if success != true {
result = 0
} else {
result = 1
}

binary.Write(bioOut, binary.BigEndian, &result)
bioOut.Flush()
}
}

func main() {
filename := flag.String("conf", "/etc/ejabberd-pg-auth.ini",
"Config file with all the connection infos needed.")
flag.Parse()
filename := flag.String("conf", "/etc/ejabberd-pg-auth.ini",
"Config file with all the connection infos needed.")
flag.Parse()

cfg, err := ini.Load(*filename)
if err != nil {
log.Fatal(err)
log.Fatal(err)
}

conf := new(Config)
err = cfg.MapTo(conf)
if err != nil {
log.Fatal(err)
if err != nil {
log.Fatal(err)
}

AuthLoop(conf)
}
log.SetOutput(os.Stderr)

AuthLoop(conf)
}
1 change: 1 addition & 0 deletions ejabberd-go-auth.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ Dbargs = sslmode=disable
Table = users
UserField = user
PassField = passwd
#ServerField =

0 comments on commit 09fa245

Please sign in to comment.