remote: implement refresh state

This commit is contained in:
Armon Dadgar 2014-10-03 11:14:39 -07:00 committed by Mitchell Hashimoto
parent d077a82db2
commit d332b8ad58
3 changed files with 198 additions and 21 deletions

View File

@ -99,18 +99,18 @@ func (c *InitCommand) Run(args []string) int {
// Handle remote state if configured
if !remoteConf.Empty() {
// Read the updated state file
remoteR, err := remote.ReadState(&remoteConf)
change, err := remote.RefreshState(&remoteConf)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Failed to read remote state: %v", err))
"Failed to refresh from remote state: %v", err))
return 1
}
// Persist the remote state
if err := remote.Persist(remoteR); err != nil {
c.Ui.Error(fmt.Sprintf(
"Failed to persist state: %v", err))
// Log the change that took place
c.Ui.Output(fmt.Sprintf("%s", change))
// Use an error exit code if the update was not a success
if !change.SuccessfulPull() {
return 1
}
}

View File

@ -2,8 +2,10 @@ package remote
import (
"bytes"
"crypto/md5"
"fmt"
"io"
"io/ioutil"
"net/url"
"os"
"path/filepath"
@ -29,6 +31,82 @@ const (
DefaultServer = "http://www.hashicorp.com/"
)
// StateChangeResult is used to communicate to a caller
// what actions have been taken when updating a state file
type StateChangeResult int
const (
// StateChangeNoop indicates nothing has happened,
// but that does not indicate an error. Everything is
// just up to date. (Push/Pull)
StateChangeNoop StateChangeResult = iota
// StateChangeUpdateLocal indicates the local state
// was updated. (Pull)
StateChangeUpdateLocal
// StateChangeUpdateRemote indicates the remote state
// was updated. (Push)
StateChangeUpdateRemote
// StateChangeLocalNewer means the pull was a no-op
// because the local state is newer than that of the
// server. This means a Push should take place. (Pull)
StateChangeLocalNewer
// StateChangeRemoteNewer means the push was a no-op
// because the remote state is newer than that of the
// local state. This means a Pull should take place.
// (Push)
StateChangeRemoteNewer
// StateChangeConflict means that the push or pull
// was a no-op because there is a conflict. This means
// there are multiple state definitions at the same
// serial number with different contents. This requires
// an operator to intervene and resolve the conflict.
// Shame on the user for doing concurrent apply.
// (Push/Pull)
StateChangeConflict
)
func (sc StateChangeResult) String() string {
switch sc {
case StateChangeNoop:
return "Local and remote state in sync"
case StateChangeUpdateLocal:
return "Local state updated"
case StateChangeUpdateRemote:
return "Remote state updated"
case StateChangeLocalNewer:
return "Local state is newer than remote state, push required"
case StateChangeRemoteNewer:
return "Remote state is newer than local state, pull required"
case StateChangeConflict:
return "Local and remote state conflict, manual resolution required"
default:
return fmt.Sprintf("Unknown state change type: %d", sc)
}
}
// SuccessfulPull is used to clasify the StateChangeResult for
// a pull operation. This is different by operation, but can be used
// to determine a proper exit code.
func (sc StateChangeResult) SuccessfulPull() bool {
switch sc {
case StateChangeNoop:
return true
case StateChangeUpdateLocal:
return true
case StateChangeLocalNewer:
return false
case StateChangeConflict:
return false
default:
return false
}
}
// EnsureDirectory is used to make sure the local storage
// directory exists
func EnsureDirectory() error {
@ -126,26 +204,113 @@ This is likely a bug, please report it.`)
return nil
}
// ReadState is used to read the remote state given
// the configuration for the remote endpoint. We return
// a boolean indicating if the remote state exists, along
// with the state, and possible error.
func ReadState(conf *terraform.RemoteState) (io.Reader, error) {
// TODO: Read actually from a server
// RefreshState is used to read the remote state given
// the configuration for the remote endpoint, and update
// the local state if necessary.
func RefreshState(conf *terraform.RemoteState) (StateChangeResult, error) {
// Read the state from the server
payload, err := GetState(conf)
if err != nil {
return StateChangeNoop,
fmt.Errorf("Failed to read remote state: %v", err)
}
// Return the blank state, which is done if the server
// returns a "not found" or equivalent
return blankState(conf)
// Parse the remote state
var remoteState *terraform.State
if payload != nil {
remoteState, err = terraform.ReadState(bytes.NewReader(payload.State))
if err != nil {
return StateChangeNoop,
fmt.Errorf("Failed to parse remote state: %v", err)
}
// Ensure we understand the remote version!
if remoteState.Version > terraform.StateVersion {
return StateChangeNoop, fmt.Errorf(
`Remote state is version %d, this version of Terraform only understands up to %d`, remoteState.Version, terraform.StateVersion)
}
}
// Get the path to the state file
path, err := HiddenStatePath()
if err != nil {
return StateChangeNoop, err
}
// Get the existing state file
raw, err := ioutil.ReadFile(path)
if err != nil && !os.IsNotExist(err) {
return StateChangeNoop, fmt.Errorf("Failed to read local state: %v", err)
}
// Decode the state
var localState *terraform.State
if raw != nil {
localState, err = terraform.ReadState(bytes.NewReader(raw))
if err != nil {
return StateChangeNoop,
fmt.Errorf("Failed to decode state file '%s': %v", path, err)
}
}
// We need to handle the matrix of cases in reconciling
// the local and remote state. Primarily the concern is
// around the Serial number which should grow monotonically.
// Additionally, we use the MD5 to detect a conflict for
// a given Serial.
switch {
case remoteState == nil && localState == nil:
// Initialize a blank state
out, _ := blankState(conf)
if err := Persist(bytes.NewReader(out)); err != nil {
return StateChangeNoop,
fmt.Errorf("Failed to persist state: %v", err)
}
return StateChangeNoop, nil
case remoteState == nil && localState != nil:
fallthrough
case remoteState.Serial < localState.Serial:
// User should probably do a push, nothing to do
return StateChangeLocalNewer, nil
case remoteState != nil && localState == nil:
fallthrough
case remoteState.Serial > localState.Serial:
// Update the local state from the remote state
if err := Persist(bytes.NewReader(payload.State)); err != nil {
return StateChangeNoop,
fmt.Errorf("Failed to persist state: %v", err)
}
return StateChangeUpdateLocal, nil
case remoteState.Serial == localState.Serial:
// Check for a hash collision on the local/remote state
localMD5 := md5.Sum(raw)
if bytes.Equal(localMD5[:md5.Size], payload.MD5) {
// Hash collision, everything is up-to-date
return StateChangeNoop, nil
} else {
// This is very bad. This means we have 2 state files
// with the same Serial but a different hash. Most probably
// explaination is two parallel apply operations. This
// requires a manual reconciliation.
return StateChangeConflict, nil
}
}
// We should not reach this point
panic("Unhandled remote update case")
}
// blankState is used to return a serialized form of a blank state
// with only the remote info.
func blankState(conf *terraform.RemoteState) (io.Reader, error) {
func blankState(conf *terraform.RemoteState) ([]byte, error) {
blank := terraform.NewState()
blank.Remote = conf
buf := bytes.NewBuffer(nil)
err := terraform.WriteState(blank, buf)
return buf, err
return buf.Bytes(), err
}
// Persist is used to write out the state given by a reader (likely

View File

@ -68,7 +68,19 @@ func TestValidateConfig(t *testing.T) {
// TODO:
}
func TestReadState(t *testing.T) {
func TestRefreshState_Blank(t *testing.T) {
// TODO
}
func TestRefreshState_Update_Newer(t *testing.T) {
// TODO
}
func TestRefreshState_Update_Older(t *testing.T) {
// TODO
}
func TestRefreshState_Noop(t *testing.T) {
// TODO
}
@ -82,7 +94,7 @@ func TestBlankState(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
s, err := terraform.ReadState(r)
s, err := terraform.ReadState(bytes.NewReader(r))
if err != nil {
t.Fatalf("err: %v", err)
}
@ -116,7 +128,7 @@ func TestPersist(t *testing.T) {
AuthToken: "foobar",
}
blank, _ := blankState(remote)
if err := Persist(blank); err != nil {
if err := Persist(bytes.NewReader(blank)); err != nil {
t.Fatalf("err: %v", err)
}