Merge pull request #26476 from remilapeyre/postgres-backend-escape-schema_name

Properly quote schema_name in the pg backend configuration
This commit is contained in:
Pam Selle 2020-10-05 14:57:53 -04:00 committed by GitHub
commit f84a7c1d57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 28 deletions

View File

@ -7,7 +7,7 @@ import (
"github.com/hashicorp/terraform/backend" "github.com/hashicorp/terraform/backend"
"github.com/hashicorp/terraform/helper/schema" "github.com/hashicorp/terraform/helper/schema"
_ "github.com/lib/pq" "github.com/lib/pq"
) )
const ( const (
@ -62,7 +62,7 @@ func (b *Backend) configure(ctx context.Context) error {
data := b.configData data := b.configData
b.connStr = data.Get("conn_str").(string) b.connStr = data.Get("conn_str").(string)
b.schemaName = data.Get("schema_name").(string) b.schemaName = pq.QuoteIdentifier(data.Get("schema_name").(string))
db, err := sql.Open("postgres", b.connStr) db, err := sql.Open("postgres", b.connStr)
if err != nil { if err != nil {
@ -75,8 +75,8 @@ func (b *Backend) configure(ctx context.Context) error {
if !data.Get("skip_schema_creation").(bool) { if !data.Get("skip_schema_creation").(bool) {
// list all schemas to see if it exists // list all schemas to see if it exists
var count int var count int
query = `select count(1) from information_schema.schemata where lower(schema_name) = lower('%s')` query = `select count(1) from information_schema.schemata where schema_name = $1`
if err := db.QueryRow(fmt.Sprintf(query, b.schemaName)).Scan(&count); err != nil { if err := db.QueryRow(query, data.Get("schema_name").(string)).Scan(&count); err != nil {
return err return err
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/terraform/backend" "github.com/hashicorp/terraform/backend"
"github.com/hashicorp/terraform/states/remote" "github.com/hashicorp/terraform/states/remote"
"github.com/lib/pq"
_ "github.com/lib/pq" _ "github.com/lib/pq"
) )
@ -36,17 +37,20 @@ func TestBackend_impl(t *testing.T) {
func TestBackendConfig(t *testing.T) { func TestBackendConfig(t *testing.T) {
testACC(t) testACC(t)
connStr := getDatabaseUrl() connStr := getDatabaseUrl()
schemaName := fmt.Sprintf("terraform_%s", t.Name()) schemaName := pq.QuoteIdentifier(fmt.Sprintf("terraform_%s", t.Name()))
config := backend.TestWrapConfig(map[string]interface{}{
"conn_str": connStr,
"schema_name": schemaName,
})
schemaName = pq.QuoteIdentifier(schemaName)
dbCleaner, err := sql.Open("postgres", connStr) dbCleaner, err := sql.Open("postgres", connStr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)) defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
config := backend.TestWrapConfig(map[string]interface{}{
"conn_str": connStr,
"schema_name": schemaName,
})
b := backend.TestBackendConfig(t, New(), config).(*Backend) b := backend.TestBackendConfig(t, New(), config).(*Backend)
if b == nil { if b == nil {
@ -79,6 +83,12 @@ func TestBackendConfigSkipSchema(t *testing.T) {
testACC(t) testACC(t)
connStr := getDatabaseUrl() connStr := getDatabaseUrl()
schemaName := fmt.Sprintf("terraform_%s", t.Name()) schemaName := fmt.Sprintf("terraform_%s", t.Name())
config := backend.TestWrapConfig(map[string]interface{}{
"conn_str": connStr,
"schema_name": schemaName,
"skip_schema_creation": true,
})
schemaName = pq.QuoteIdentifier(schemaName)
db, err := sql.Open("postgres", connStr) db, err := sql.Open("postgres", connStr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -88,11 +98,6 @@ func TestBackendConfigSkipSchema(t *testing.T) {
db.Query(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaName)) db.Query(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaName))
defer db.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)) defer db.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName))
config := backend.TestWrapConfig(map[string]interface{}{
"conn_str": connStr,
"schema_name": schemaName,
"skip_schema_creation": true,
})
b := backend.TestBackendConfig(t, New(), config).(*Backend) b := backend.TestBackendConfig(t, New(), config).(*Backend)
if b == nil { if b == nil {
@ -122,12 +127,18 @@ func TestBackendConfigSkipSchema(t *testing.T) {
func TestBackendStates(t *testing.T) { func TestBackendStates(t *testing.T) {
testACC(t) testACC(t)
connStr := getDatabaseUrl() connStr := getDatabaseUrl()
schemaName := fmt.Sprintf("terraform_%s", t.Name())
testCases := []string{
fmt.Sprintf("terraform_%s", t.Name()),
fmt.Sprintf("test with spaces: %s", t.Name()),
}
for _, schemaName := range testCases {
t.Run(schemaName, func(t *testing.T) {
dbCleaner, err := sql.Open("postgres", connStr) dbCleaner, err := sql.Open("postgres", connStr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer dbCleaner.Query(fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)) defer dbCleaner.Query("DROP SCHEMA IF EXISTS %s CASCADE", pq.QuoteIdentifier(schemaName))
config := backend.TestWrapConfig(map[string]interface{}{ config := backend.TestWrapConfig(map[string]interface{}{
"conn_str": connStr, "conn_str": connStr,
@ -140,6 +151,8 @@ func TestBackendStates(t *testing.T) {
} }
backend.TestBackendStates(t, b) backend.TestBackendStates(t, b)
})
}
} }
func TestBackendStateLocks(t *testing.T) { func TestBackendStateLocks(t *testing.T) {