From 42be5854a28c0cdbdc41666fc7dd1152a4d599c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Lapeyre?= Date: Sat, 3 Oct 2020 18:02:13 +0200 Subject: [PATCH] Properly quote schema_name in the pg backend configuration --- backend/remote-state/pg/backend.go | 8 ++-- backend/remote-state/pg/backend_test.go | 61 +++++++++++++++---------- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/backend/remote-state/pg/backend.go b/backend/remote-state/pg/backend.go index 71540fe20..e31f99712 100644 --- a/backend/remote-state/pg/backend.go +++ b/backend/remote-state/pg/backend.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/terraform/backend" "github.com/hashicorp/terraform/helper/schema" - _ "github.com/lib/pq" + "github.com/lib/pq" ) const ( @@ -62,7 +62,7 @@ func (b *Backend) configure(ctx context.Context) error { data := b.configData b.connStr = data.Get("conn_str").(string) - b.schemaName = data.Get("schema_name").(string) + b.schemaName = pq.QuoteIdentifier(data.Get("schema_name").(string)) db, err := sql.Open("postgres", b.connStr) if err != nil { @@ -75,8 +75,8 @@ func (b *Backend) configure(ctx context.Context) error { if !data.Get("skip_schema_creation").(bool) { // list all schemas to see if it exists var count int - query = `select count(1) from information_schema.schemata where lower(schema_name) = lower('%s')` - if err := db.QueryRow(fmt.Sprintf(query, b.schemaName)).Scan(&count); err != nil { + query = `select count(1) from information_schema.schemata where schema_name = $1` + if err := db.QueryRow(query, data.Get("schema_name").(string)).Scan(&count); err != nil { return err } diff --git a/backend/remote-state/pg/backend_test.go b/backend/remote-state/pg/backend_test.go index b6a672634..1b500e132 100644 --- a/backend/remote-state/pg/backend_test.go +++ b/backend/remote-state/pg/backend_test.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/terraform/backend" "github.com/hashicorp/terraform/states/remote" + "github.com/lib/pq" _ "github.com/lib/pq" ) @@ -36,17 +37,20 @@ func TestBackend_impl(t *testing.T) { func TestBackendConfig(t *testing.T) { testACC(t) connStr := getDatabaseUrl() - schemaName := fmt.Sprintf("terraform_%s", t.Name()) + schemaName := pq.QuoteIdentifier(fmt.Sprintf("terraform_%s", t.Name())) + + config := backend.TestWrapConfig(map[string]interface{}{ + "conn_str": connStr, + "schema_name": schemaName, + }) + schemaName = pq.QuoteIdentifier(schemaName) + dbCleaner, err := sql.Open("postgres", connStr) if err != nil { t.Fatal(err) } defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)) - config := backend.TestWrapConfig(map[string]interface{}{ - "conn_str": connStr, - "schema_name": schemaName, - }) b := backend.TestBackendConfig(t, New(), config).(*Backend) if b == nil { @@ -79,6 +83,12 @@ func TestBackendConfigSkipSchema(t *testing.T) { testACC(t) connStr := getDatabaseUrl() schemaName := fmt.Sprintf("terraform_%s", t.Name()) + config := backend.TestWrapConfig(map[string]interface{}{ + "conn_str": connStr, + "schema_name": schemaName, + "skip_schema_creation": true, + }) + schemaName = pq.QuoteIdentifier(schemaName) db, err := sql.Open("postgres", connStr) if err != nil { t.Fatal(err) @@ -88,11 +98,6 @@ func TestBackendConfigSkipSchema(t *testing.T) { db.Query(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaName)) defer db.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)) - config := backend.TestWrapConfig(map[string]interface{}{ - "conn_str": connStr, - "schema_name": schemaName, - "skip_schema_creation": true, - }) b := backend.TestBackendConfig(t, New(), config).(*Backend) if b == nil { @@ -122,24 +127,32 @@ func TestBackendConfigSkipSchema(t *testing.T) { func TestBackendStates(t *testing.T) { testACC(t) connStr := getDatabaseUrl() - schemaName := fmt.Sprintf("terraform_%s", t.Name()) - dbCleaner, err := sql.Open("postgres", connStr) - if err != nil { - t.Fatal(err) + + testCases := []string{ + fmt.Sprintf("terraform_%s", t.Name()), + fmt.Sprintf("test with spaces: %s", t.Name()), } - defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)) + for _, schemaName := range testCases { + t.Run(schemaName, func(t *testing.T) { + dbCleaner, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + defer dbCleaner.Query("DROP SCHEMA IF EXISTS %s CASCADE", pq.QuoteIdentifier(schemaName)) - config := backend.TestWrapConfig(map[string]interface{}{ - "conn_str": connStr, - "schema_name": schemaName, - }) - b := backend.TestBackendConfig(t, New(), config).(*Backend) + config := backend.TestWrapConfig(map[string]interface{}{ + "conn_str": connStr, + "schema_name": schemaName, + }) + b := backend.TestBackendConfig(t, New(), config).(*Backend) - if b == nil { - t.Fatal("Backend could not be configured") + if b == nil { + t.Fatal("Backend could not be configured") + } + + backend.TestBackendStates(t, b) + }) } - - backend.TestBackendStates(t, b) } func TestBackendStateLocks(t *testing.T) {