diff --git a/builtin/providers/postgresql/helpers.go b/builtin/providers/postgresql/helpers.go new file mode 100644 index 000000000..db61713dc --- /dev/null +++ b/builtin/providers/postgresql/helpers.go @@ -0,0 +1,24 @@ +package postgresql + +import ( + "fmt" + "strings" +) + +// pqQuoteLiteral returns a string literal safe for inclusion in a PostgreSQL +// query as a parameter. The resulting string still needs to be wrapped in +// single quotes in SQL (i.e. fmt.Sprintf(`'%s'`, pqQuoteLiteral("str"))). See +// quote_literal_internal() in postgresql/backend/utils/adt/quote.c:77. +func pqQuoteLiteral(in string) string { + in = strings.Replace(in, `\`, `\\`, -1) + in = strings.Replace(in, `'`, `''`, -1) + return in +} + +func validateConnLimit(v interface{}, key string) (warnings []string, errors []error) { + value := v.(int) + if value < -1 { + errors = append(errors, fmt.Errorf("%d can not be less than -1", key)) + } + return +} diff --git a/builtin/providers/postgresql/resource_postgresql_database.go b/builtin/providers/postgresql/resource_postgresql_database.go index 4d3fe297c..0f0f1c363 100644 --- a/builtin/providers/postgresql/resource_postgresql_database.go +++ b/builtin/providers/postgresql/resource_postgresql_database.go @@ -1,6 +1,7 @@ package postgresql import ( + "bytes" "database/sql" "errors" "fmt" @@ -112,134 +113,70 @@ func resourcePostgreSQLDatabaseCreate(d *schema.ResourceData, meta interface{}) } defer conn.Close() - stringOpts := []struct { - hclKey string - sqlKey string - }{ - {dbOwnerAttr, "OWNER"}, - {dbTemplateAttr, "TEMPLATE"}, - {dbEncodingAttr, "ENCODING"}, - {dbCollationAttr, "LC_COLLATE"}, - {dbCTypeAttr, "LC_CTYPE"}, - {dbTablespaceAttr, "TABLESPACE"}, - } - intOpts := []struct { - hclKey string - sqlKey string - }{ - {dbConnLimitAttr, "CONNECTION LIMIT"}, - } - boolOpts := []struct { - hclKey string - sqlKey string - }{ - {dbAllowConnsAttr, "ALLOW_CONNECTIONS"}, - {dbIsTemplateAttr, "IS_TEMPLATE"}, - } - - createOpts := make([]string, 0, len(stringOpts)+len(intOpts)+len(boolOpts)) - - for _, opt := range stringOpts { - v, ok := d.GetOk(opt.hclKey) - var val string - if !ok { - switch { - case opt.hclKey == dbOwnerAttr && v.(string) == "": - // No owner specified in the config, default to using - // the connecting username. - val = c.username - case strings.ToUpper(v.(string)) == "DEFAULT" && - (opt.hclKey == dbTemplateAttr || - opt.hclKey == dbEncodingAttr || - opt.hclKey == dbCollationAttr || - opt.hclKey == dbCTypeAttr): - - // Use the defaults from the template database - // as opposed to best practices. - fallthrough - default: - continue - } - } - - val = v.(string) - - switch { - case opt.hclKey == dbOwnerAttr && (val == "" || strings.ToUpper(val) == "DEFAULT"): - // Owner was blank/DEFAULT, default to using the connecting username. - val = c.username - d.Set(dbOwnerAttr, val) - case opt.hclKey == dbTablespaceAttr && (val == "" || strings.ToUpper(val) == "DEFAULT"): - val = "pg_default" - d.Set(dbTablespaceAttr, val) - case opt.hclKey == dbTemplateAttr: - switch { - case val == "": - val = "template0" - d.Set(dbTemplateAttr, val) - case strings.ToUpper(val) == "DEFAULT": - val = "" - default: - d.Set(dbTemplateAttr, val) - } - case opt.hclKey == dbEncodingAttr: - switch { - case val == "": - val = "UTF8" - d.Set(dbEncodingAttr, val) - case strings.ToUpper(val) == "DEFAULT": - val = "" - default: - d.Set(dbEncodingAttr, val) - } - case opt.hclKey == dbCollationAttr: - switch { - case val == "": - val = "C" - d.Set(dbCollationAttr, val) - case strings.ToUpper(val) == "DEFAULT": - val = "" - default: - d.Set(dbCollationAttr, val) - } - case opt.hclKey == dbCTypeAttr: - switch { - case val == "": - val = "C" - d.Set(dbCTypeAttr, val) - case strings.ToUpper(val) == "DEFAULT": - val = "" - default: - d.Set(dbCTypeAttr, val) - } - } - - if val != "" { - createOpts = append(createOpts, fmt.Sprintf("%s=%s", opt.sqlKey, pq.QuoteIdentifier(val))) - } - } - - for _, opt := range intOpts { - val := d.Get(opt.hclKey).(int) - createOpts = append(createOpts, fmt.Sprintf("%s=%d", opt.sqlKey, val)) - } - - for _, opt := range boolOpts { - val := d.Get(opt.hclKey).(bool) - - valStr := "FALSE" - if val { - valStr = "TRUE" - } - createOpts = append(createOpts, fmt.Sprintf("%s=%s", opt.sqlKey, valStr)) - } - dbName := d.Get(dbNameAttr).(string) - createStr := strings.Join(createOpts, " ") - if len(createOpts) > 0 { - createStr = " WITH " + createStr + b := bytes.NewBufferString("CREATE DATABASE ") + fmt.Fprint(b, pq.QuoteIdentifier(dbName)) + + // Handle each option individually and stream results into the query + // buffer. + + switch v, ok := d.GetOk(dbOwnerAttr); { + case ok: + fmt.Fprint(b, " OWNER ", pq.QuoteIdentifier(v.(string))) + default: + // No owner specified in the config, default to using + // the connecting username. + fmt.Fprint(b, " OWNER ", pq.QuoteIdentifier(c.username)) } - query := fmt.Sprintf("CREATE DATABASE %s%s", pq.QuoteIdentifier(dbName), createStr) + + switch v, ok := d.GetOk(dbTemplateAttr); { + case ok: + fmt.Fprint(b, " TEMPLATE ", pq.QuoteIdentifier(v.(string))) + case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT": + fmt.Fprint(b, " TEMPLATE template0") + } + + switch v, ok := d.GetOk(dbEncodingAttr); { + case ok: + fmt.Fprint(b, " ENCODING ", pq.QuoteIdentifier(v.(string))) + case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT": + fmt.Fprint(b, ` ENCODING "UTF8"`) + } + + switch v, ok := d.GetOk(dbCollationAttr); { + case ok: + fmt.Fprint(b, " LC_COLLATE ", pq.QuoteIdentifier(v.(string))) + case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT": + fmt.Fprint(b, ` LC_COLLATE "C"`) + } + + switch v, ok := d.GetOk(dbCTypeAttr); { + case ok: + fmt.Fprint(b, " LC_CTYPE ", pq.QuoteIdentifier(v.(string))) + case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT": + fmt.Fprint(b, ` LC_CTYPE "C"`) + } + + if v, ok := d.GetOk(dbTablespaceAttr); ok { + fmt.Fprint(b, " TABLESPACE ", pq.QuoteIdentifier(v.(string))) + } + + { + val := d.Get(dbAllowConnsAttr).(bool) + fmt.Fprint(b, " ALLOW_CONNECTIONS ", val) + } + + { + val := d.Get(dbConnLimitAttr).(int) + fmt.Fprint(b, " CONNECTION LIMIT ", val) + } + + { + val := d.Get(dbIsTemplateAttr).(bool) + fmt.Fprint(b, " IS_TEMPLATE ", val) + } + + query := b.String() _, err = conn.Query(query) if err != nil { return errwrap.Wrapf(fmt.Sprintf("Error creating database %s: {{err}}", dbName), err) @@ -296,7 +233,7 @@ func resourcePostgreSQLDatabaseRead(d *schema.ResourceData, meta interface{}) er err = conn.QueryRow("SELECT d.datname, pg_catalog.pg_get_userbyid(d.datdba) from pg_database d WHERE datname=$1", dbId).Scan(&dbName, &ownerName) switch { case err == sql.ErrNoRows: - log.Printf("[WARN] PostgreSQL database (%s) not found", d.Id()) + log.Printf("[WARN] PostgreSQL database (%s) not found", dbId) d.SetId("") return nil case err != nil: @@ -313,7 +250,7 @@ func resourcePostgreSQLDatabaseRead(d *schema.ResourceData, meta interface{}) er ) switch { case err == sql.ErrNoRows: - log.Printf("[WARN] PostgreSQL database (%s) not found", d.Id()) + log.Printf("[WARN] PostgreSQL database (%s) not found", dbId) d.SetId("") return nil case err != nil: diff --git a/builtin/providers/postgresql/validators.go b/builtin/providers/postgresql/validators.go deleted file mode 100644 index 8bc75209e..000000000 --- a/builtin/providers/postgresql/validators.go +++ /dev/null @@ -1,11 +0,0 @@ -package postgresql - -import "fmt" - -func validateConnLimit(v interface{}, key string) (warnings []string, errors []error) { - value := v.(int) - if value < -1 { - errors = append(errors, fmt.Errorf("%d can not be less than -1", key)) - } - return -}