store and verify s3 remote state checksum
Updates to objects in S3 are only eventually consistent. If the RemoteClient has a DynamoDB table available, use that to store a checksum of the last written state, so the object can be verified by the next client to call Get. Terraform currently doesn't have any sort of user feedback around RefreshState/Get, so we poll only for a short time before returning an error.
This commit is contained in:
parent
adfa7aedfb
commit
0022d224e8
|
@ -2,10 +2,14 @@ package s3
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
|
@ -17,6 +21,9 @@ import (
|
|||
"github.com/hashicorp/terraform/state/remote"
|
||||
)
|
||||
|
||||
// Store the last saved serial in dynamo with this suffix for consistency checks.
|
||||
const stateIDSuffix = "-md5"
|
||||
|
||||
type RemoteClient struct {
|
||||
s3Client *s3.S3
|
||||
dynClient *dynamodb.DynamoDB
|
||||
|
@ -28,7 +35,58 @@ type RemoteClient struct {
|
|||
lockTable string
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Get() (*remote.Payload, error) {
|
||||
var (
|
||||
// The amount of time we will retry a state waiting for it to match the
|
||||
// expected checksum.
|
||||
consistencyRetryTimeout = 10 * time.Second
|
||||
|
||||
// delay when polling the state
|
||||
consistencyRetryPollInterval = 2 * time.Second
|
||||
|
||||
// checksum didn't match the remote state
|
||||
errBadChecksum = errors.New("invalid state checksum")
|
||||
)
|
||||
|
||||
// test hook called when checksums don't match
|
||||
var testChecksumHook func()
|
||||
|
||||
func (c *RemoteClient) Get() (payload *remote.Payload, err error) {
|
||||
deadline := time.Now().Add(consistencyRetryTimeout)
|
||||
|
||||
// If we have a checksum, and the returned payload doesn't match, we retry
|
||||
// up until deadline.
|
||||
for {
|
||||
payload, err = c.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// verify that this state is what we expect
|
||||
if expected, err := c.getMD5(); err != nil {
|
||||
log.Printf("[WARNING] failed to fetch state md5: %s", err)
|
||||
} else if len(expected) > 0 && !bytes.Equal(expected, payload.MD5) {
|
||||
log.Printf("[WARNING] state md5 mismatch: expected '%x', got '%x'", expected, payload.MD5)
|
||||
|
||||
if testChecksumHook != nil {
|
||||
testChecksumHook()
|
||||
}
|
||||
|
||||
if time.Now().Before(deadline) {
|
||||
time.Sleep(consistencyRetryPollInterval)
|
||||
log.Println("[INFO] retrying S3 RemoteClient.Get...")
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, errBadChecksum
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return payload, err
|
||||
}
|
||||
|
||||
func (c *RemoteClient) get() (*remote.Payload, error) {
|
||||
output, err := c.s3Client.GetObject(&s3.GetObjectInput{
|
||||
Bucket: &c.bucketName,
|
||||
Key: &c.path,
|
||||
|
@ -53,8 +111,10 @@ func (c *RemoteClient) Get() (*remote.Payload, error) {
|
|||
return nil, fmt.Errorf("Failed to read remote state: %s", err)
|
||||
}
|
||||
|
||||
sum := md5.Sum(buf.Bytes())
|
||||
payload := &remote.Payload{
|
||||
Data: buf.Bytes(),
|
||||
MD5: sum[:],
|
||||
}
|
||||
|
||||
// If there was no data, then return nil
|
||||
|
@ -92,11 +152,20 @@ func (c *RemoteClient) Put(data []byte) error {
|
|||
|
||||
log.Printf("[DEBUG] Uploading remote state to S3: %#v", i)
|
||||
|
||||
if _, err := c.s3Client.PutObject(i); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
_, err := c.s3Client.PutObject(i)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to upload state: %v", err)
|
||||
}
|
||||
|
||||
sum := md5.Sum(data)
|
||||
if err := c.putMD5(sum[:]); err != nil {
|
||||
// if this errors out, we unfortunately have to error out altogether,
|
||||
// since the next Get will inevitably fail.
|
||||
return fmt.Errorf("failed to store state MD5: %s", err)
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Delete() error {
|
||||
|
@ -105,7 +174,15 @@ func (c *RemoteClient) Delete() error {
|
|||
Key: &c.path,
|
||||
})
|
||||
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.deleteMD5(); err != nil {
|
||||
log.Printf("error deleting state md5: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
|
||||
|
@ -146,9 +223,84 @@ func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
|
|||
}
|
||||
return "", lockErr
|
||||
}
|
||||
|
||||
return info.ID, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) getMD5() ([]byte, error) {
|
||||
if c.lockTable == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
getParams := &dynamodb.GetItemInput{
|
||||
Key: map[string]*dynamodb.AttributeValue{
|
||||
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
|
||||
},
|
||||
ProjectionExpression: aws.String("LockID, Digest"),
|
||||
TableName: aws.String(c.lockTable),
|
||||
}
|
||||
|
||||
resp, err := c.dynClient.GetItem(getParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var val string
|
||||
if v, ok := resp.Item["Digest"]; ok && v.S != nil {
|
||||
val = *v.S
|
||||
}
|
||||
|
||||
sum, err := hex.DecodeString(val)
|
||||
if err != nil || len(sum) != md5.Size {
|
||||
return nil, errors.New("invalid md5")
|
||||
}
|
||||
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
// store the hash of the state to that clients can check for stale state files.
|
||||
func (c *RemoteClient) putMD5(sum []byte) error {
|
||||
if c.lockTable == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(sum) != md5.Size {
|
||||
return errors.New("invalid payload md5")
|
||||
}
|
||||
|
||||
putParams := &dynamodb.PutItemInput{
|
||||
Item: map[string]*dynamodb.AttributeValue{
|
||||
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
|
||||
"Digest": {S: aws.String(hex.EncodeToString(sum))},
|
||||
},
|
||||
TableName: aws.String(c.lockTable),
|
||||
}
|
||||
_, err := c.dynClient.PutItem(putParams)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] failed to record state serial in dynamodb: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// remove the hash value for a deleted state
|
||||
func (c *RemoteClient) deleteMD5() error {
|
||||
if c.lockTable == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := &dynamodb.DeleteItemInput{
|
||||
Key: map[string]*dynamodb.AttributeValue{
|
||||
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
|
||||
},
|
||||
TableName: aws.String(c.lockTable),
|
||||
}
|
||||
if _, err := c.dynClient.DeleteItem(params); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) getLockInfo() (*state.LockInfo, error) {
|
||||
getParams := &dynamodb.GetItemInput{
|
||||
Key: map[string]*dynamodb.AttributeValue{
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/terraform/backend"
|
||||
"github.com/hashicorp/terraform/state"
|
||||
"github.com/hashicorp/terraform/state/remote"
|
||||
"github.com/hashicorp/terraform/terraform"
|
||||
)
|
||||
|
||||
func TestRemoteClient_impl(t *testing.T) {
|
||||
|
@ -74,3 +78,150 @@ func TestRemoteClientLocks(t *testing.T) {
|
|||
|
||||
remote.TestRemoteLocks(t, s1.(*remote.State).Client, s2.(*remote.State).Client)
|
||||
}
|
||||
|
||||
func TestRemoteClient_clientMD5(t *testing.T) {
|
||||
testACC(t)
|
||||
|
||||
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
|
||||
keyName := "testState"
|
||||
|
||||
b := backend.TestBackendConfig(t, New(), map[string]interface{}{
|
||||
"bucket": bucketName,
|
||||
"key": keyName,
|
||||
"lock_table": bucketName,
|
||||
}).(*Backend)
|
||||
|
||||
createDynamoDBTable(t, b.dynClient, bucketName)
|
||||
defer deleteDynamoDBTable(t, b.dynClient, bucketName)
|
||||
|
||||
s, err := b.State(backend.DefaultStateName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client := s.(*remote.State).Client.(*RemoteClient)
|
||||
|
||||
sum := md5.Sum([]byte("test"))
|
||||
|
||||
if err := client.putMD5(sum[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
getSum, err := client.getMD5()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(getSum, sum[:]) {
|
||||
t.Fatalf("getMD5 returned the wrong checksum: expected %x, got %x", sum[:], getSum)
|
||||
}
|
||||
|
||||
if err := client.deleteMD5(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if getSum, err := client.getMD5(); err == nil {
|
||||
t.Fatalf("expected getMD5 error, got none. checksum: %x", getSum)
|
||||
}
|
||||
}
|
||||
|
||||
// verify that a client won't return a state with an incorrect checksum.
|
||||
func TestRemoteClient_stateChecksum(t *testing.T) {
|
||||
testACC(t)
|
||||
|
||||
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
|
||||
keyName := "testState"
|
||||
|
||||
b1 := backend.TestBackendConfig(t, New(), map[string]interface{}{
|
||||
"bucket": bucketName,
|
||||
"key": keyName,
|
||||
"lock_table": bucketName,
|
||||
}).(*Backend)
|
||||
|
||||
createS3Bucket(t, b1.s3Client, bucketName)
|
||||
defer deleteS3Bucket(t, b1.s3Client, bucketName)
|
||||
createDynamoDBTable(t, b1.dynClient, bucketName)
|
||||
defer deleteDynamoDBTable(t, b1.dynClient, bucketName)
|
||||
|
||||
s1, err := b1.State(backend.DefaultStateName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client1 := s1.(*remote.State).Client
|
||||
|
||||
// create a old and new state version to persist
|
||||
s := state.TestStateInitial()
|
||||
var oldState bytes.Buffer
|
||||
if err := terraform.WriteState(s, &oldState); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s.Serial++
|
||||
var newState bytes.Buffer
|
||||
if err := terraform.WriteState(s, &newState); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Use b2 without a lock_table to bypass the lock table to write the state directly.
|
||||
// client2 will write the "incorrect" state, simulating s3 eventually consistency delays
|
||||
b2 := backend.TestBackendConfig(t, New(), map[string]interface{}{
|
||||
"bucket": bucketName,
|
||||
"key": keyName,
|
||||
}).(*Backend)
|
||||
s2, err := b2.State(backend.DefaultStateName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client2 := s2.(*remote.State).Client
|
||||
|
||||
// write the new state through client2 so that there is no checksum yet
|
||||
if err := client2.Put(newState.Bytes()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// verify that we can pull a state without a checksum
|
||||
if _, err := client1.Get(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// write the new state back with its checksum
|
||||
if err := client1.Put(newState.Bytes()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// put the old state in place of the new, without updating the checksum
|
||||
if err := client2.Put(oldState.Bytes()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// remove the timeouts so we can fail immediately
|
||||
origTimeout := consistencyRetryTimeout
|
||||
origInterval := consistencyRetryPollInterval
|
||||
defer func() {
|
||||
consistencyRetryTimeout = origTimeout
|
||||
consistencyRetryPollInterval = origInterval
|
||||
}()
|
||||
consistencyRetryTimeout = 0
|
||||
consistencyRetryPollInterval = 0
|
||||
|
||||
// fetching the state through client1 should now error out due to a
|
||||
// mismatched checksum.
|
||||
if _, err := client1.Get(); err != errBadChecksum {
|
||||
t.Fatalf("expected state checksum error: got %s", err)
|
||||
}
|
||||
|
||||
// update the state with the correct one after we Get again
|
||||
testChecksumHook = func() {
|
||||
if err := client2.Put(newState.Bytes()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testChecksumHook = nil
|
||||
}
|
||||
|
||||
consistencyRetryTimeout = origTimeout
|
||||
|
||||
// this final Get will fail to fail the checksum verification, the above
|
||||
// callback will update the state with the correct version, and Get should
|
||||
// retry automatically.
|
||||
if _, err := client1.Get(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue