From daa951434f56465c82818407c2fae46fdf50578b Mon Sep 17 00:00:00 2001 From: Sean Chittenden Date: Mon, 12 Dec 2016 13:15:57 -0800 Subject: [PATCH] Teach postgresql_extension about schemas. --- .../resource_postgresql_extension.go | 82 ++++++++++++++++--- .../resource_postgresql_extension_test.go | 58 +++++++++++++ 2 files changed, 128 insertions(+), 12 deletions(-) diff --git a/builtin/providers/postgresql/resource_postgresql_extension.go b/builtin/providers/postgresql/resource_postgresql_extension.go index d0cb83cc7..1040ef0f0 100644 --- a/builtin/providers/postgresql/resource_postgresql_extension.go +++ b/builtin/providers/postgresql/resource_postgresql_extension.go @@ -1,7 +1,9 @@ package postgresql import ( + "bytes" "database/sql" + "errors" "fmt" "log" @@ -10,21 +12,33 @@ import ( "github.com/lib/pq" ) +const ( + extNameAttr = "name" + extSchemaAttr = "schema" +) + func resourcePostgreSQLExtension() *schema.Resource { return &schema.Resource{ Create: resourcePostgreSQLExtensionCreate, Read: resourcePostgreSQLExtensionRead, + Update: resourcePostgreSQLExtensionUpdate, Delete: resourcePostgreSQLExtensionDelete, Importer: &schema.ResourceImporter{ State: schema.ImportStatePassthrough, }, Schema: map[string]*schema.Schema{ - "name": { + extNameAttr: { Type: schema.TypeString, Required: true, ForceNew: true, }, + extSchemaAttr: { + Type: schema.TypeString, + Optional: true, + Computed: true, + Description: "Sets the schema of an extension", + }, }, } } @@ -37,15 +51,22 @@ func resourcePostgreSQLExtensionCreate(d *schema.ResourceData, meta interface{}) } defer conn.Close() - extensionName := d.Get("name").(string) + extName := d.Get(extNameAttr).(string) - query := fmt.Sprintf("CREATE EXTENSION %s", pq.QuoteIdentifier(extensionName)) + b := bytes.NewBufferString("CREATE EXTENSION ") + fmt.Fprintf(b, pq.QuoteIdentifier(extName)) + + if v, ok := d.GetOk(extSchemaAttr); ok { + fmt.Fprint(b, " SCHEMA ", pq.QuoteIdentifier(v.(string))) + } + + query := b.String() _, err = conn.Query(query) if err != nil { return errwrap.Wrapf("Error creating extension: {{err}}", err) } - d.SetId(extensionName) + d.SetId(extName) return resourcePostgreSQLExtensionRead(d, meta) } @@ -58,11 +79,10 @@ func resourcePostgreSQLExtensionRead(d *schema.ResourceData, meta interface{}) e } defer conn.Close() - dbId := d.Id() - extensionName := d.Get("name").(string) + extID := d.Get(extNameAttr).(string) - var hasExtension bool - err = conn.QueryRow("SELECT TRUE from pg_catalog.pg_extension d WHERE extname=$1", dbId).Scan(&hasExtension) + var extName, extSchema string + err = conn.QueryRow("SELECT e.extname, n.nspname FROM pg_catalog.pg_extension e, pg_catalog.pg_namespace n WHERE n.oid = e.extnamespace AND e.extname = $1", extID).Scan(&extName, &extSchema) switch { case err == sql.ErrNoRows: log.Printf("[WARN] PostgreSQL extension (%s) not found", d.Id()) @@ -71,8 +91,9 @@ func resourcePostgreSQLExtensionRead(d *schema.ResourceData, meta interface{}) e case err != nil: return errwrap.Wrapf("Error reading extension: {{err}}", err) default: - d.Set("extension", hasExtension) - d.SetId(extensionName) + d.Set(extNameAttr, extName) + d.Set(extSchemaAttr, extSchema) + d.SetId(extName) return nil } } @@ -85,9 +106,9 @@ func resourcePostgreSQLExtensionDelete(d *schema.ResourceData, meta interface{}) } defer conn.Close() - extensionName := d.Get("name").(string) + extName := d.Get(extNameAttr).(string) - query := fmt.Sprintf("DROP EXTENSION %s", pq.QuoteIdentifier(extensionName)) + query := fmt.Sprintf("DROP EXTENSION %s", pq.QuoteIdentifier(extName)) _, err = conn.Query(query) if err != nil { return errwrap.Wrapf("Error deleting extension: {{err}}", err) @@ -97,3 +118,40 @@ func resourcePostgreSQLExtensionDelete(d *schema.ResourceData, meta interface{}) return nil } + +func resourcePostgreSQLExtensionUpdate(d *schema.ResourceData, meta interface{}) error { + c := meta.(*Client) + conn, err := c.Connect() + if err != nil { + return err + } + defer conn.Close() + + // Can't rename a schema + + if err := setExtSchema(conn, d); err != nil { + return err + } + + return resourcePostgreSQLExtensionRead(d, meta) +} + +func setExtSchema(conn *sql.DB, d *schema.ResourceData) error { + if !d.HasChange(extSchemaAttr) { + return nil + } + + oraw, nraw := d.GetChange(extSchemaAttr) + o := oraw.(string) + n := nraw.(string) + if n == "" { + return errors.New("Error setting extension name to an empty string") + } + + query := fmt.Sprintf("ALTER EXTENSION %s SET SCHEMA %s", pq.QuoteIdentifier(o), pq.QuoteIdentifier(n)) + if _, err := conn.Query(query); err != nil { + return errwrap.Wrapf("Error updating extension SCHEMA: {{err}}", err) + } + + return nil +} diff --git a/builtin/providers/postgresql/resource_postgresql_extension_test.go b/builtin/providers/postgresql/resource_postgresql_extension_test.go index d74a2c188..b474f3299 100644 --- a/builtin/providers/postgresql/resource_postgresql_extension_test.go +++ b/builtin/providers/postgresql/resource_postgresql_extension_test.go @@ -75,6 +75,42 @@ func testAccCheckPostgresqlExtensionExists(n string) resource.TestCheckFunc { } } +func TestAccPostgresqlExtension_SchemaRename(t *testing.T) { + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckPostgresqlExtensionDestroy, + Steps: []resource.TestStep{ + { + Config: testAccPostgresqlExtensionSchemaChange1, + Check: resource.ComposeTestCheckFunc( + testAccCheckPostgresqlExtensionExists("postgresql_extension.ext1trgm"), + resource.TestCheckResourceAttr( + "postgresql_schema.ext1foo", "name", "foo"), + resource.TestCheckResourceAttr( + "postgresql_extension.ext1trgm", "name", "pg_trgm"), + resource.TestCheckResourceAttr( + "postgresql_extension.ext1trgm", "name", "pg_trgm"), + resource.TestCheckResourceAttr( + "postgresql_extension.ext1trgm", "schema", "foo"), + ), + }, + { + Config: testAccPostgresqlExtensionSchemaChange2, + Check: resource.ComposeTestCheckFunc( + testAccCheckPostgresqlExtensionExists("postgresql_extension.ext1trgm"), + resource.TestCheckResourceAttr( + "postgresql_schema.ext1foo", "name", "bar"), + resource.TestCheckResourceAttr( + "postgresql_extension.ext1trgm", "name", "pg_trgm"), + resource.TestCheckResourceAttr( + "postgresql_extension.ext1trgm", "schema", "bar"), + ), + }, + }, + }) +} + func checkExtensionExists(client *Client, extensionName string) (bool, error) { conn, err := client.Connect() if err != nil { @@ -99,3 +135,25 @@ resource "postgresql_extension" "myextension" { name = "pg_trgm" } ` + +var testAccPostgresqlExtensionSchemaChange1 = ` +resource "postgresql_schema" "ext1foo" { + name = "foo" +} + +resource "postgresql_extension" "ext1trgm" { + name = "pg_trgm" + schema = "${postgresql_schema.ext1foo.name}" +} +` + +var testAccPostgresqlExtensionSchemaChange2 = ` +resource "postgresql_schema" "ext1foo" { + name = "bar" +} + +resource "postgresql_extension" "ext1trgm" { + name = "pg_trgm" + schema = "${postgresql_schema.ext1foo.name}" +} +`