-
Notifications
You must be signed in to change notification settings - Fork 2
/
connection_producer.go
203 lines (167 loc) · 6.81 KB
/
connection_producer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
// Copyright (c) YugaByteDB, Inc.
//
//Licensed to YugabyteDB, Inc. under one or more contributor license agreements.
//See the NOTICE file distributed with this work for additional information regarding
//copyright ownership.
//
//YugabyteDB licenses this file to you under the MPL version 2.0 (the "License");
//you may not use this file except in compliance with the License.
//You may obtain a copy of the License at
//
//https://mozilla.org/MPL/2.0/
package ysql
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"time"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/mitchellh/mapstructure"
_ "github.com/yugabyte/pgx/v5/stdlib"
)
// YugabyteDBConnectionProducer implements ConnectionProducer and provides a generic producer for most yugabyte databases
//
// The tagged fields are populated by Init, which decodes the supplied
// configuration map using the "json" tag names. The embedded sync.Mutex is
// locked by Init and Close; Connection does not lock it itself.
type YugabyteDBConnectionProducer struct {
// ConnectionURL, when non-empty, is used by Connection in place of the
// keyword/value string built from the discrete fields below. Init skips
// the host/username/password checks when it is set and passes it through
// dbutil.QueryHelper together with the username and password values.
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
// MaxOpenConnections, MaxIdleConnections and MaxConnectionLifetimeRaw are
// decoded from the configuration.
// NOTE(review): no method visible in this file reads them — confirm whether
// they are meant to be applied to the *sql.DB.
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
// Host, Username and Password must be non-empty when ConnectionURL is
// empty (checked by Init). Together with Port and DbName they are written
// into the connection string as host, user, password, port and dbname.
Host string `json:"host" mapstructure:"host" structs:"host"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
Port int `json:"port" mapstructure:"port" structs:"port"`
DbName string `json:"db" mapstructure:"db" structs:"db"`
// LoadBalance and YbServersRefreshInterval are always written into the
// connection string as load_balance and yb_servers_refresh_interval;
// TopologyKeys is written as topology_keys only when non-empty.
LoadBalance bool `json:"load_balance" mapstructure:"load_balance" structs:"load_balance"`
YbServersRefreshInterval int `json:"yb_servers_refresh_interval" mapstructure:"yb_servers_refresh_interval" structs:"yb_servers_refresh_interval"`
TopologyKeys string `json:"topology_keys" mapstructure:"topology_keys" structs:"topology_keys"`
// SslMode is written as sslmode; Connection sets it to "prefer" when empty.
SslMode string `json:"sslmode" mapstructure:"sslmode" structs:"sslmode"`
// The remaining SSL fields are appended to the connection string under
// their tag names only when non-empty.
SslRootCert string `json:"sslrootcert" mapstructure:"sslrootcert" structs:"sslrootcert"`
SslSni string `json:"sslsni" mapstructure:"sslsni" structs:"sslsni"`
SslKey string `json:"sslkey" mapstructure:"sslkey" structs:"sslkey"`
SslCert string `json:"sslcert" mapstructure:"sslcert" structs:"sslcert"`
SslPassword string `json:"sslpassword" mapstructure:"sslpassword" structs:"sslpassword"`
// Type is not read or written by any method visible in this file.
Type string
// RawConfig holds the configuration map most recently passed to Init; Init
// returns it on success.
RawConfig map[string]interface{}
// maxConnectionLifetime is not assigned by any method visible in this file.
maxConnectionLifetime time.Duration
// Initialized is set to true by Init once the fields have been decoded and
// validated; Connection returns ErrNotInitialized while it is false.
Initialized bool
// db is the handle opened by Connection and reused on later calls; Close
// closes it and resets it to nil.
db *sql.DB
sync.Mutex
}
// ErrNotInitialized is returned by Connection when the producer's Initialized
// flag is false, i.e. before Init has decoded and validated a configuration.
var ErrNotInitialized = errors.New("connection has not been initialized")
// Initialize configures the producer from conf by delegating to Init. The
// configuration map that Init returns is discarded; only the error, if any,
// is reported to the caller.
func (c *YugabyteDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
	if _, initErr := c.Init(ctx, conf, verifyConnection); initErr != nil {
		return initErr
	}
	return nil
}
// Init decodes conf into the producer, validates it, and optionally verifies
// that a connection can be established.
//
// conf is stored as RawConfig and decoded onto the struct using the "json"
// tag names with weakly typed input. When no connection_url is supplied,
// host, username and password must all be non-empty. If verifyConnection is
// true, a connection is opened and pinged with ctx before returning.
//
// On success the stored RawConfig is returned. Errors from the verification
// step are wrapped with %w so callers can inspect the underlying cause with
// errors.Is / errors.As. The producer's mutex is held for the whole call.
func (c *YugabyteDBConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
	c.Lock()
	defer c.Unlock()

	c.RawConfig = conf

	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:           c,
		WeaklyTypedInput: true,
		TagName:          "json",
	})
	if err != nil {
		return nil, err
	}
	if err := decoder.Decode(conf); err != nil {
		return nil, err
	}

	// A connection URL replaces the discrete fields, so they are only
	// required when no URL was given.
	if len(c.ConnectionURL) == 0 {
		switch {
		case len(c.Host) == 0:
			return nil, fmt.Errorf("host cannot be empty")
		case len(c.Username) == 0:
			return nil, fmt.Errorf("username cannot be empty")
		case len(c.Password) == 0:
			return nil, fmt.Errorf("password cannot be empty")
		}
	}

	// The username and password are handed to QueryHelper unescaped.
	// QueryHelper doesn't do any SQL escaping, but if it starts to do so
	// then maybe we won't be able to use it to do URL substitution any more.
	c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
		"username": c.Username,
		"password": c.Password,
	})

	// Set initialized to true at this point since all fields are set,
	// and the connection can be established at a later time.
	c.Initialized = true

	if verifyConnection {
		if _, err := c.Connection(ctx); err != nil {
			return nil, fmt.Errorf("error verifying connection: %w", err)
		}
		// Connection does not ping a handle it has just opened, so ping
		// explicitly to prove the server is reachable.
		if err := c.db.PingContext(ctx); err != nil {
			return nil, fmt.Errorf("error verifying connection: %w", err)
		}
	}

	return c.RawConfig, nil
}
// Connection returns the producer's *sql.DB handle, opening one if needed.
//
// It returns ErrNotInitialized when the Initialized flag is false. If a handle
// is already cached it is pinged with ctx: a healthy handle is returned as-is,
// a failing one is closed and replaced. A new handle is opened with the "pgx"
// driver using ConnectionURL when that is non-empty, otherwise using the
// keyword/value string assembled from the discrete fields. On success the
// returned interface{} holds the *sql.DB.
//
// Connection does not lock the producer's mutex; Init calls it while already
// holding the lock.
// NOTE(review): presumably other callers must hold the lock too — confirm.
func (c *YugabyteDBConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
	if !c.Initialized {
		return nil, ErrNotInitialized
	}

	// If we already have a DB, test it and return
	if c.db != nil {
		if err := c.db.PingContext(ctx); err == nil {
			return c.db, nil
		}
		// If the ping was unsuccessful, close it and ignore errors as we'll be
		// reestablishing anyways
		c.db.Close()
	}

	if c.SslMode == "" {
		c.SslMode = "prefer" // default sslmode
	}

	// An explicit connection URL takes precedence over the discrete fields.
	conn := c.ConnectionURL
	if len(conn) == 0 {
		conn = c.keywordValueDSN()
	}

	// attempt to make connection
	var err error
	c.db, err = sql.Open("pgx", conn)
	if err != nil {
		return nil, err
	}
	return c.db, nil
}

// keywordValueDSN assembles the space-separated key=value connection string
// from the producer's discrete fields. Optional settings are appended only
// when non-empty. Values are inserted verbatim (no quoting is applied) and the
// string is built by concatenation, never reused as a format string, so a '%'
// in a value such as the password is passed through unchanged.
func (c *YugabyteDBConnectionProducer) keywordValueDSN() string {
	conn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s load_balance=%v yb_servers_refresh_interval=%d ",
		c.Host, c.Port, c.Username, c.Password, c.DbName, c.SslMode, c.LoadBalance, c.YbServersRefreshInterval)

	if c.TopologyKeys != "" {
		conn += "topology_keys=" + c.TopologyKeys + " "
	}
	if c.SslRootCert != "" {
		conn += "sslrootcert=" + c.SslRootCert + " "
	}
	if c.SslCert != "" {
		conn += "sslcert=" + c.SslCert + " "
	}
	if c.SslKey != "" {
		conn += "sslkey=" + c.SslKey + " "
	}
	if c.SslPassword != "" {
		conn += "sslpassword=" + c.SslPassword + " "
	}
	if c.SslSni != "" {
		conn += "sslsni=" + c.SslSni
	}
	return conn
}
// Close closes the cached database handle, if there is one, and clears it so
// that a later call to Connection opens a fresh handle. It returns the error
// reported by the handle's Close, or nil when there was nothing to close. The
// handle is cleared even if closing it fails. The producer's mutex is held for
// the duration of the call.
func (c *YugabyteDBConnectionProducer) Close() error {
	c.Lock()
	defer c.Unlock()

	if c.db == nil {
		return nil
	}

	err := c.db.Close()
	c.db = nil
	return err
}