Teach postgresql_extension about schemas.

This commit is contained in:
Sean Chittenden 2016-12-12 13:15:57 -08:00
parent 201d9b9dfd
commit daa951434f
No known key found for this signature in database
GPG Key ID: 4EBC9DC16C2E5E16
2 changed files with 128 additions and 12 deletions

View File

@ -1,7 +1,9 @@
package postgresql
import (
"bytes"
"database/sql"
"errors"
"fmt"
"log"
@ -10,21 +12,33 @@ import (
"github.com/lib/pq"
)
const (
extNameAttr = "name"
extSchemaAttr = "schema"
)
func resourcePostgreSQLExtension() *schema.Resource {
return &schema.Resource{
Create: resourcePostgreSQLExtensionCreate,
Read: resourcePostgreSQLExtensionRead,
Update: resourcePostgreSQLExtensionUpdate,
Delete: resourcePostgreSQLExtensionDelete,
Importer: &schema.ResourceImporter{
State: schema.ImportStatePassthrough,
},
Schema: map[string]*schema.Schema{
"name": {
extNameAttr: {
Type: schema.TypeString,
Required: true,
ForceNew: true,
},
extSchemaAttr: {
Type: schema.TypeString,
Optional: true,
Computed: true,
Description: "Sets the schema of an extension",
},
},
}
}
@ -37,15 +51,22 @@ func resourcePostgreSQLExtensionCreate(d *schema.ResourceData, meta interface{})
}
defer conn.Close()
extensionName := d.Get("name").(string)
extName := d.Get(extNameAttr).(string)
query := fmt.Sprintf("CREATE EXTENSION %s", pq.QuoteIdentifier(extensionName))
b := bytes.NewBufferString("CREATE EXTENSION ")
fmt.Fprintf(b, pq.QuoteIdentifier(extName))
if v, ok := d.GetOk(extSchemaAttr); ok {
fmt.Fprint(b, " SCHEMA ", pq.QuoteIdentifier(v.(string)))
}
query := b.String()
_, err = conn.Query(query)
if err != nil {
return errwrap.Wrapf("Error creating extension: {{err}}", err)
}
d.SetId(extensionName)
d.SetId(extName)
return resourcePostgreSQLExtensionRead(d, meta)
}
@ -58,11 +79,10 @@ func resourcePostgreSQLExtensionRead(d *schema.ResourceData, meta interface{}) e
}
defer conn.Close()
dbId := d.Id()
extensionName := d.Get("name").(string)
extID := d.Get(extNameAttr).(string)
var hasExtension bool
err = conn.QueryRow("SELECT TRUE from pg_catalog.pg_extension d WHERE extname=$1", dbId).Scan(&hasExtension)
var extName, extSchema string
err = conn.QueryRow("SELECT e.extname, n.nspname FROM pg_catalog.pg_extension e, pg_catalog.pg_namespace n WHERE n.oid = e.extnamespace AND e.extname = $1", extID).Scan(&extName, &extSchema)
switch {
case err == sql.ErrNoRows:
log.Printf("[WARN] PostgreSQL extension (%s) not found", d.Id())
@ -71,8 +91,9 @@ func resourcePostgreSQLExtensionRead(d *schema.ResourceData, meta interface{}) e
case err != nil:
return errwrap.Wrapf("Error reading extension: {{err}}", err)
default:
d.Set("extension", hasExtension)
d.SetId(extensionName)
d.Set(extNameAttr, extName)
d.Set(extSchemaAttr, extSchema)
d.SetId(extName)
return nil
}
}
@ -85,9 +106,9 @@ func resourcePostgreSQLExtensionDelete(d *schema.ResourceData, meta interface{})
}
defer conn.Close()
extensionName := d.Get("name").(string)
extName := d.Get(extNameAttr).(string)
query := fmt.Sprintf("DROP EXTENSION %s", pq.QuoteIdentifier(extensionName))
query := fmt.Sprintf("DROP EXTENSION %s", pq.QuoteIdentifier(extName))
_, err = conn.Query(query)
if err != nil {
return errwrap.Wrapf("Error deleting extension: {{err}}", err)
@ -97,3 +118,40 @@ func resourcePostgreSQLExtensionDelete(d *schema.ResourceData, meta interface{})
return nil
}
func resourcePostgreSQLExtensionUpdate(d *schema.ResourceData, meta interface{}) error {
c := meta.(*Client)
conn, err := c.Connect()
if err != nil {
return err
}
defer conn.Close()
// Can't rename a schema
if err := setExtSchema(conn, d); err != nil {
return err
}
return resourcePostgreSQLExtensionRead(d, meta)
}
func setExtSchema(conn *sql.DB, d *schema.ResourceData) error {
if !d.HasChange(extSchemaAttr) {
return nil
}
oraw, nraw := d.GetChange(extSchemaAttr)
o := oraw.(string)
n := nraw.(string)
if n == "" {
return errors.New("Error setting extension name to an empty string")
}
query := fmt.Sprintf("ALTER EXTENSION %s SET SCHEMA %s", pq.QuoteIdentifier(o), pq.QuoteIdentifier(n))
if _, err := conn.Query(query); err != nil {
return errwrap.Wrapf("Error updating extension SCHEMA: {{err}}", err)
}
return nil
}

View File

@ -75,6 +75,42 @@ func testAccCheckPostgresqlExtensionExists(n string) resource.TestCheckFunc {
}
}
func TestAccPostgresqlExtension_SchemaRename(t *testing.T) {
resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: testAccCheckPostgresqlExtensionDestroy,
Steps: []resource.TestStep{
{
Config: testAccPostgresqlExtensionSchemaChange1,
Check: resource.ComposeTestCheckFunc(
testAccCheckPostgresqlExtensionExists("postgresql_extension.ext1trgm"),
resource.TestCheckResourceAttr(
"postgresql_schema.ext1foo", "name", "foo"),
resource.TestCheckResourceAttr(
"postgresql_extension.ext1trgm", "name", "pg_trgm"),
resource.TestCheckResourceAttr(
"postgresql_extension.ext1trgm", "name", "pg_trgm"),
resource.TestCheckResourceAttr(
"postgresql_extension.ext1trgm", "schema", "foo"),
),
},
{
Config: testAccPostgresqlExtensionSchemaChange2,
Check: resource.ComposeTestCheckFunc(
testAccCheckPostgresqlExtensionExists("postgresql_extension.ext1trgm"),
resource.TestCheckResourceAttr(
"postgresql_schema.ext1foo", "name", "bar"),
resource.TestCheckResourceAttr(
"postgresql_extension.ext1trgm", "name", "pg_trgm"),
resource.TestCheckResourceAttr(
"postgresql_extension.ext1trgm", "schema", "bar"),
),
},
},
})
}
func checkExtensionExists(client *Client, extensionName string) (bool, error) {
conn, err := client.Connect()
if err != nil {
@ -99,3 +135,25 @@ resource "postgresql_extension" "myextension" {
name = "pg_trgm"
}
`
var testAccPostgresqlExtensionSchemaChange1 = `
resource "postgresql_schema" "ext1foo" {
name = "foo"
}
resource "postgresql_extension" "ext1trgm" {
name = "pg_trgm"
schema = "${postgresql_schema.ext1foo.name}"
}
`
var testAccPostgresqlExtensionSchemaChange2 = `
resource "postgresql_schema" "ext1foo" {
name = "bar"
}
resource "postgresql_extension" "ext1trgm" {
name = "pg_trgm"
schema = "${postgresql_schema.ext1foo.name}"
}
`