From 2e10ddb8780614aed8cd7fd9534e96385c7b9e56 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Thu, 3 Jul 2014 10:29:14 -0700 Subject: [PATCH] terraform: Context.Refresh --- terraform/context.go | 167 ++++++++++++++++++++++++++++++++++++++ terraform/context_test.go | 119 +++++++++++++++++++++++++++ 2 files changed, 286 insertions(+) diff --git a/terraform/context.go b/terraform/context.go index 1f322d880..b8afa7eaa 100644 --- a/terraform/context.go +++ b/terraform/context.go @@ -1,7 +1,13 @@ package terraform import ( + "fmt" + "log" + "sync" + "sync/atomic" + "github.com/hashicorp/terraform/config" + "github.com/hashicorp/terraform/depgraph" "github.com/hashicorp/terraform/helper/multierror" ) @@ -46,6 +52,28 @@ func NewContext(opts *ContextOpts) *Context { } } +// Refresh goes through all the resources in the state and refreshes them +// to their latest state. This will update the state that this context +// works with, along with returning it. +// +// Even in the case an error is returned, the state will be returned and +// will potentially be partially updated. +func (c *Context) Refresh() (*State, error) { + g, err := Graph(&GraphOpts{ + Config: c.config, + Providers: c.providers, + State: c.state, + }) + if err != nil { + return c.state, err + } + + s := new(State) + s.init() + err = g.Walk(c.refreshWalkFn(s)) + return s, err +} + // Validate validates the configuration and returns any warnings or errors. func (c *Context) Validate() ([]string, []error) { var rerr *multierror.Error @@ -67,3 +95,142 @@ func (c *Context) Validate() ([]string, []error) { return nil, errs } + +func (c *Context) refreshWalkFn(result *State) depgraph.WalkFunc { + var l sync.Mutex + + cb := func(r *Resource) (map[string]string, error) { + for _, h := range c.hooks { + handleHook(h.PreRefresh(r.Id, r.State)) + } + + rs, err := r.Provider.Refresh(r.State) + if err != nil { + return nil, err + } + if rs == nil { + rs = new(ResourceState) + } + + // Fix the type to be the type we have + rs.Type = r.State.Type + + l.Lock() + result.Resources[r.Id] = rs + l.Unlock() + + for _, h := range c.hooks { + handleHook(h.PostRefresh(r.Id, rs)) + } + + return nil, nil + } + + return c.genericWalkFn(c.variables, cb) +} + +func (c *Context) genericWalkFn( + invars map[string]string, + cb genericWalkFunc) depgraph.WalkFunc { + var l sync.RWMutex + + // Initialize the variables for application + vars := make(map[string]string) + for k, v := range invars { + vars[fmt.Sprintf("var.%s", k)] = v + } + + // This will keep track of whether we're stopped or not + var stop uint32 = 0 + + return func(n *depgraph.Noun) error { + // If it is the root node, ignore + if n.Name == GraphRootNode { + return nil + } + + // If we're stopped, return right away + if atomic.LoadUint32(&stop) != 0 { + return nil + } + + switch m := n.Meta.(type) { + case *GraphNodeResource: + case *GraphNodeResourceProvider: + var rc *ResourceConfig + if m.Config != nil { + if err := m.Config.RawConfig.Interpolate(vars); err != nil { + panic(err) + } + rc = NewResourceConfig(m.Config.RawConfig) + } + + for k, p := range m.Providers { + log.Printf("[INFO] Configuring provider: %s", k) + err := p.Configure(rc) + if err != nil { + return err + } + } + + return nil + } + + rn := n.Meta.(*GraphNodeResource) + + l.RLock() + if len(vars) > 0 && rn.Config != nil { + if err := rn.Config.RawConfig.Interpolate(vars); err != nil { + panic(fmt.Sprintf("Interpolate error: %s", err)) + } + + // Force the config to be set later + rn.Resource.Config = nil + } + l.RUnlock() + + // Make sure that at least some resource configuration is set + if !rn.Orphan { + if rn.Resource.Config == nil { + if rn.Config == nil { + rn.Resource.Config = new(ResourceConfig) + } else { + rn.Resource.Config = NewResourceConfig(rn.Config.RawConfig) + } + } + } else { + rn.Resource.Config = nil + } + + // Handle recovery of special panic scenarios + defer func() { + if v := recover(); v != nil { + if v == HookActionHalt { + atomic.StoreUint32(&stop, 1) + } else { + panic(v) + } + } + }() + + // Call the callack + log.Printf("[INFO] Walking: %s", rn.Resource.Id) + newVars, err := cb(rn.Resource) + if err != nil { + return err + } + + if len(newVars) > 0 { + // Acquire a lock since this function is called in parallel + l.Lock() + defer l.Unlock() + + // Update variables + for k, v := range newVars { + vars[k] = v + } + } + + return nil + } +} diff --git a/terraform/context_test.go b/terraform/context_test.go index a82fda892..44d34079f 100644 --- a/terraform/context_test.go +++ b/terraform/context_test.go @@ -1,6 +1,8 @@ package terraform import ( + "fmt" + "reflect" "testing" ) @@ -49,6 +51,123 @@ func TestContextValidate_requiredVar(t *testing.T) { } } +func TestContextRefresh(t *testing.T) { + p := testProvider("aws") + c := testConfig(t, "refresh-basic") + ctx := testContext(t, &ContextOpts{ + Config: c, + Providers: map[string]ResourceProviderFactory{ + "aws": testProviderFuncFixed(p), + }, + }) + + p.RefreshFn = nil + p.RefreshReturn = &ResourceState{ + ID: "foo", + } + + s, err := ctx.Refresh() + if err != nil { + t.Fatalf("err: %s", err) + } + if !p.RefreshCalled { + t.Fatal("refresh should be called") + } + if p.RefreshState.ID != "" { + t.Fatalf("bad: %#v", p.RefreshState) + } + if !reflect.DeepEqual(s.Resources["aws_instance.web"], p.RefreshReturn) { + t.Fatalf("bad: %#v", s.Resources["aws_instance.web"]) + } + + for _, r := range s.Resources { + if r.Type == "" { + t.Fatalf("no type: %#v", r) + } + } +} + +func TestContextRefresh_hook(t *testing.T) { + h := new(MockHook) + p := testProvider("aws") + c := testConfig(t, "refresh-basic") + ctx := testContext(t, &ContextOpts{ + Config: c, + Hooks: []Hook{h}, + Providers: map[string]ResourceProviderFactory{ + "aws": testProviderFuncFixed(p), + }, + }) + + if _, err := ctx.Refresh(); err != nil { + t.Fatalf("err: %s", err) + } + if !h.PreRefreshCalled { + t.Fatal("should be called") + } + if h.PreRefreshState.Type != "aws_instance" { + t.Fatalf("bad: %#v", h.PreRefreshState) + } + if !h.PostRefreshCalled { + t.Fatal("should be called") + } + if h.PostRefreshState.Type != "aws_instance" { + t.Fatalf("bad: %#v", h.PostRefreshState) + } +} + +func TestContextRefresh_state(t *testing.T) { + p := testProvider("aws") + c := testConfig(t, "refresh-basic") + state := &State{ + Resources: map[string]*ResourceState{ + "aws_instance.web": &ResourceState{ + ID: "bar", + }, + }, + } + ctx := testContext(t, &ContextOpts{ + Config: c, + Providers: map[string]ResourceProviderFactory{ + "aws": testProviderFuncFixed(p), + }, + State: state, + }) + + p.RefreshFn = nil + p.RefreshReturn = &ResourceState{ + ID: "foo", + } + + s, err := ctx.Refresh() + if err != nil { + t.Fatalf("err: %s", err) + } + if !p.RefreshCalled { + t.Fatal("refresh should be called") + } + if !reflect.DeepEqual(p.RefreshState, state.Resources["aws_instance.web"]) { + t.Fatalf("bad: %#v", p.RefreshState) + } + if !reflect.DeepEqual(s.Resources["aws_instance.web"], p.RefreshReturn) { + t.Fatalf("bad: %#v", s.Resources) + } +} + func testContext(t *testing.T, opts *ContextOpts) *Context { return NewContext(opts) } + +func testProvider(prefix string) *MockResourceProvider { + p := new(MockResourceProvider) + p.RefreshFn = func(s *ResourceState) (*ResourceState, error) { + return s, nil + } + p.ResourcesReturn = []ResourceType{ + ResourceType{ + Name: fmt.Sprintf("%s_instance", prefix), + }, + } + + return p +}