config: semantic check on variable references

This commit is contained in:
Mitchell Hashimoto 2014-07-02 21:00:06 -07:00
parent c8c7d6baa3
commit 532cc33189
6 changed files with 209 additions and 2 deletions

View File

@ -89,12 +89,52 @@ func (r *Resource) Id() string {
// Validate does some basic semantic checking of the configuration. // Validate does some basic semantic checking of the configuration.
func (c *Config) Validate() error { func (c *Config) Validate() error {
// TODO(mitchellh): make sure all referenced variables exist // Check for references to user variables that do not actually
// TODO(mitchellh): make sure types/names have valid values (characters) // exist and record those errors.
var errs []error
for source, v := range c.allVariables() {
uv, ok := v.(*UserVariable)
if !ok {
continue
}
if _, ok := c.Variables[uv.Name]; !ok {
errs = append(errs, fmt.Errorf(
"%s: unknown variable referenced: %s",
source,
uv.Name))
}
}
if len(errs) > 0 {
return &MultiError{Errors: errs}
}
return nil return nil
} }
// allVariables is a helper that returns a mapping of all the interpolated
// variables within the configuration. This is used to verify references
// are valid in the Validate step.
func (c *Config) allVariables() map[string]InterpolatedVariable {
result := make(map[string]InterpolatedVariable)
for n, pc := range c.ProviderConfigs {
source := fmt.Sprintf("provider config '%s'", n)
for _, v := range pc.RawConfig.Variables {
result[source] = v
}
}
for _, rc := range c.Resources {
source := fmt.Sprintf("resource '%s'", rc.Id())
for _, v := range rc.RawConfig.Variables {
result[source] = v
}
}
return result
}
// Required tests whether a variable is required or not. // Required tests whether a variable is required or not.
func (v *Variable) Required() bool { func (v *Variable) Required() bool {
return !v.defaultSet return !v.defaultSet

View File

@ -1,12 +1,27 @@
package config package config
import ( import (
"path/filepath"
"testing" "testing"
) )
// This is the directory where our test fixtures are. // This is the directory where our test fixtures are.
const fixtureDir = "./test-fixtures" const fixtureDir = "./test-fixtures"
func TestConfigValidate(t *testing.T) {
c := testConfig(t, "validate-good")
if err := c.Validate(); err != nil {
t.Fatalf("err: %s", err)
}
}
func TestConfigValidate_unknownVar(t *testing.T) {
c := testConfig(t, "validate-unknownvar")
if err := c.Validate(); err == nil {
t.Fatal("should not be valid")
}
}
func TestNewResourceVariable(t *testing.T) { func TestNewResourceVariable(t *testing.T) {
v, err := NewResourceVariable("foo.bar.baz") v, err := NewResourceVariable("foo.bar.baz")
if err != nil { if err != nil {
@ -55,3 +70,12 @@ func TestProviderConfigName(t *testing.T) {
t.Fatalf("bad: %s", n) t.Fatalf("bad: %s", n)
} }
} }
func testConfig(t *testing.T, name string) *Config {
c, err := Load(filepath.Join(fixtureDir, name, "main.tf"))
if err != nil {
t.Fatalf("err: %s", err)
}
return c
}

50
config/multi_error.go Normal file
View File

@ -0,0 +1,50 @@
package config
import (
"fmt"
"strings"
)
// MultiError is an error type to track multiple errors. This is used to
// accumulate errors in cases such as configuration parsing, and returning
// them as a single error.
type MultiError struct {
Errors []error
}
func (e *MultiError) Error() string {
points := make([]string, len(e.Errors))
for i, err := range e.Errors {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d error(s) occurred:\n\n%s",
len(e.Errors), strings.Join(points, "\n"))
}
// MultiErrorAppend is a helper function that will append more errors
// onto a MultiError in order to create a larger multi-error. If the
// original error is not a MultiError, it will be turned into one.
func MultiErrorAppend(err error, errs ...error) *MultiError {
if err == nil {
err = new(MultiError)
}
switch err := err.(type) {
case *MultiError:
if err == nil {
err = new(MultiError)
}
err.Errors = append(err.Errors, errs...)
return err
default:
newErrs := make([]error, len(errs)+1)
newErrs[0] = err
copy(newErrs[1:], errs)
return &MultiError{
Errors: newErrs,
}
}
}

View File

@ -0,0 +1,56 @@
package config
import (
"errors"
"testing"
)
func TestMultiError_Impl(t *testing.T) {
var raw interface{}
raw = &MultiError{}
if _, ok := raw.(error); !ok {
t.Fatal("MultiError must implement error")
}
}
func TestMultiErrorError(t *testing.T) {
expected := `2 error(s) occurred:
* foo
* bar`
errors := []error{
errors.New("foo"),
errors.New("bar"),
}
multi := &MultiError{errors}
if multi.Error() != expected {
t.Fatalf("bad: %s", multi.Error())
}
}
func TestMultiErrorAppend_MultiError(t *testing.T) {
original := &MultiError{
Errors: []error{errors.New("foo")},
}
result := MultiErrorAppend(original, errors.New("bar"))
if len(result.Errors) != 2 {
t.Fatalf("wrong len: %d", len(result.Errors))
}
original = &MultiError{}
result = MultiErrorAppend(original, errors.New("bar"))
if len(result.Errors) != 1 {
t.Fatalf("wrong len: %d", len(result.Errors))
}
}
func TestMultiErrorAppend_NonMultiError(t *testing.T) {
original := errors.New("foo")
result := MultiErrorAppend(original, errors.New("bar"))
if len(result.Errors) != 2 {
t.Fatalf("wrong len: %d", len(result.Errors))
}
}

View File

@ -0,0 +1,29 @@
variable "foo" {
default = "bar";
description = "bar";
}
provider "aws" {
access_key = "foo";
secret_key = "bar";
}
provider "do" {
api_key = "${var.foo}";
}
resource "aws_security_group" "firewall" {
}
resource aws_instance "web" {
ami = "${var.foo}"
security_groups = [
"foo",
"${aws_security_group.firewall.foo}"
]
network_interface {
device_index = 0
description = "Main network interface"
}
}

View File

@ -0,0 +1,8 @@
variable "foo" {
default = "bar";
description = "bar";
}
provider "do" {
api_key = "${var.bar}";
}