2015-10-27 11:04:19 +01:00
|
|
|
package postgresql
|
|
|
|
|
|
|
|
import (
|
2016-12-25 14:53:12 +01:00
|
|
|
"bytes"
|
2015-10-27 11:04:19 +01:00
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
2016-11-06 09:23:33 +01:00
|
|
|
"log"
|
2016-12-25 15:19:14 +01:00
|
|
|
"unicode"
|
2015-12-21 10:54:24 +01:00
|
|
|
|
2015-10-27 11:04:19 +01:00
|
|
|
_ "github.com/lib/pq" //PostgreSQL db
|
|
|
|
)
|
|
|
|
|
|
|
|
// Config holds the provider settings from which NewClient renders a
// PostgreSQL key/value connection string.
type Config struct {
	// Host is rendered as the host= keyword.
	Host string
	// Port is rendered as the port= keyword.
	Port int
	// Database is rendered as the dbname= keyword.
	Database string
	// Username is rendered as the user= keyword.
	Username string
	// Password is rendered as the password= keyword; NewClient replaces it
	// with "<redacted>" in the DSN it logs.
	Password string
	// SSLMode is rendered as the sslmode= keyword.
	SSLMode string
	// ApplicationName is rendered as the fallback_application_name= keyword.
	ApplicationName string
	// Timeout is not used when NewClient builds the DSN.
	// NOTE(review): presumably consumed elsewhere in the provider — confirm.
	Timeout int
	// ConnectTimeoutSec is rendered as the connect_timeout= keyword
	// (seconds, per the field name).
	ConnectTimeoutSec int
}
|
|
|
|
|
|
|
|
// Client carries what is needed to open PostgreSQL connections: the login
// name and the complete connection string rendered by Config.NewClient.
type Client struct {
	// username is copied from Config.Username.
	username string
	// connStr is the key/value connection string that Connect hands to
	// sql.Open; it includes the password in clear text.
	connStr string
}
|
|
|
|
|
2016-09-06 00:04:48 +02:00
|
|
|
// NewClient returns new client config
|
2015-10-27 11:04:19 +01:00
|
|
|
func (c *Config) NewClient() (*Client, error) {
|
2016-11-06 09:27:51 +01:00
|
|
|
// NOTE: dbname must come before user otherwise dbname will be set to
|
|
|
|
// user.
|
2016-12-25 14:53:12 +01:00
|
|
|
const dsnFmt = "host=%s port=%d dbname=%s user=%s password=%s sslmode=%s fallback_application_name=%s connect_timeout=%d"
|
2015-10-27 11:04:19 +01:00
|
|
|
|
2016-12-25 14:53:12 +01:00
|
|
|
// Quote empty strings or strings that contain whitespace
|
|
|
|
q := func(s string) string {
|
|
|
|
b := bytes.NewBufferString(`'`)
|
|
|
|
b.Grow(len(s) + 2)
|
2016-12-25 15:19:14 +01:00
|
|
|
var haveWhitespace bool
|
2016-12-25 14:53:12 +01:00
|
|
|
for _, r := range s {
|
2016-12-25 15:19:14 +01:00
|
|
|
if unicode.IsSpace(r) {
|
|
|
|
haveWhitespace = true
|
|
|
|
}
|
|
|
|
|
2016-12-25 14:53:12 +01:00
|
|
|
switch r {
|
|
|
|
case '\'':
|
|
|
|
b.WriteString(`\'`)
|
|
|
|
case '\\':
|
|
|
|
b.WriteString(`\\`)
|
|
|
|
default:
|
|
|
|
b.WriteRune(r)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
b.WriteString(`'`)
|
2016-12-25 15:19:14 +01:00
|
|
|
|
|
|
|
str := b.String()
|
2016-12-26 00:26:54 +01:00
|
|
|
if haveWhitespace || len(str) == 2 {
|
2016-12-25 15:19:14 +01:00
|
|
|
return str
|
|
|
|
}
|
|
|
|
return str[1 : len(str)-1]
|
2016-12-25 14:53:12 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
logDSN := fmt.Sprintf(dsnFmt, q(c.Host), c.Port, q(c.Database), q(c.Username), q("<redacted>"), q(c.SSLMode), q(c.ApplicationName), c.ConnectTimeoutSec)
|
2016-11-06 09:23:33 +01:00
|
|
|
log.Printf("[INFO] PostgreSQL DSN: `%s`", logDSN)
|
|
|
|
|
2016-12-25 14:53:12 +01:00
|
|
|
connStr := fmt.Sprintf(dsnFmt, q(c.Host), c.Port, q(c.Database), q(c.Username), q(c.Password), q(c.SSLMode), q(c.ApplicationName), c.ConnectTimeoutSec)
|
2015-10-27 11:04:19 +01:00
|
|
|
client := Client{
|
|
|
|
connStr: connStr,
|
|
|
|
username: c.Username,
|
|
|
|
}
|
|
|
|
|
|
|
|
return &client, nil
|
|
|
|
}
|
|
|
|
|
2016-09-06 00:04:48 +02:00
|
|
|
// Connect will manually connect/disconnect to prevent a large
|
|
|
|
// number or db connections being made
|
2015-10-27 11:04:19 +01:00
|
|
|
func (c *Client) Connect() (*sql.DB, error) {
|
|
|
|
db, err := sql.Open("postgres", c.connStr)
|
|
|
|
if err != nil {
|
2016-12-17 03:01:40 +01:00
|
|
|
return nil, fmt.Errorf("Error connecting to PostgreSQL server: %v", err)
|
2015-10-27 11:04:19 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
return db, nil
|
|
|
|
}
|