store and verify s3 remote state checksum

Updates to objects in S3 are only eventually consistent. If the
RemoteClient has a DynamoDB table available, use that to store a
checksum of the last written state, so the object can be verified by the
next client to call Get.

Terraform currently doesn't have any sort of user feedback around
RefreshState/Get, so we poll only for a short time before returning an
error.
commit 0022d224e8 (parent adfa7aedfb)
Author: James Bardin
Date:   2017-05-19 11:51:46 -04:00
2 changed files with 308 additions and 5 deletions
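For illustration only (not part of the diff; the bucket, key, and digest values are hypothetical): assuming lockPath() joins the bucket name and key, a state written to bucket "tf-state" at key "prod/terraform.tfstate" would have its checksum recorded in the lock table as a DynamoDB item roughly like

    map[string]*dynamodb.AttributeValue{
        "LockID": {S: aws.String("tf-state/prod/terraform.tfstate-md5")}, // lockPath() + stateIDSuffix
        "Digest": {S: aws.String("9e107d9d372bb6826bd81d3542a419d6")},    // hex-encoded MD5 of the state bytes
    }

The "-md5" suffix keeps this consistency record separate from the lock item stored under the bare lockPath(), so the same lock_table serves both locking and checksum verification.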


@@ -2,10 +2,14 @@ package s3
import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -17,6 +21,9 @@ import (
"github.com/hashicorp/terraform/state/remote"
)
// Store the md5 checksum of the last written state in DynamoDB under a key with this suffix, for consistency checks.
const stateIDSuffix = "-md5"
type RemoteClient struct {
s3Client *s3.S3
dynClient *dynamodb.DynamoDB
@@ -28,7 +35,58 @@ type RemoteClient struct {
lockTable string
}
var (
// The amount of time we will retry a state waiting for it to match the
// expected checksum.
consistencyRetryTimeout = 10 * time.Second
// delay when polling the state
consistencyRetryPollInterval = 2 * time.Second
// checksum didn't match the remote state
errBadChecksum = errors.New("invalid state checksum")
)
// test hook called when checksums don't match
var testChecksumHook func()
func (c *RemoteClient) Get() (payload *remote.Payload, err error) {
deadline := time.Now().Add(consistencyRetryTimeout)
// If we have a checksum, and the returned payload doesn't match, we retry
// up until deadline.
for {
payload, err = c.get()
if err != nil {
return nil, err
}
// verify that this state is what we expect
if expected, err := c.getMD5(); err != nil {
log.Printf("[WARNING] failed to fetch state md5: %s", err)
} else if len(expected) > 0 && !bytes.Equal(expected, payload.MD5) {
log.Printf("[WARNING] state md5 mismatch: expected '%x', got '%x'", expected, payload.MD5)
if testChecksumHook != nil {
testChecksumHook()
}
if time.Now().Before(deadline) {
time.Sleep(consistencyRetryPollInterval)
log.Println("[INFO] retrying S3 RemoteClient.Get...")
continue
}
return nil, errBadChecksum
}
break
}
return payload, err
}
func (c *RemoteClient) get() (*remote.Payload, error) {
output, err := c.s3Client.GetObject(&s3.GetObjectInput{
Bucket: &c.bucketName,
Key: &c.path,
@@ -53,8 +111,10 @@ func (c *RemoteClient) Get() (*remote.Payload, error) {
return nil, fmt.Errorf("Failed to read remote state: %s", err)
}
sum := md5.Sum(buf.Bytes())
payload := &remote.Payload{
Data: buf.Bytes(),
MD5: sum[:],
}
// If there was no data, then return nil
@@ -92,11 +152,20 @@ func (c *RemoteClient) Put(data []byte) error {
log.Printf("[DEBUG] Uploading remote state to S3: %#v", i)
_, err := c.s3Client.PutObject(i)
if err != nil {
return fmt.Errorf("Failed to upload state: %v", err)
}
sum := md5.Sum(data)
if err := c.putMD5(sum[:]); err != nil {
// if this errors out, we unfortunately have to error out altogether,
// since the next Get will inevitably fail.
return fmt.Errorf("failed to store state MD5: %s", err)
}
return nil
}
func (c *RemoteClient) Delete() error {
@@ -105,7 +174,15 @@ func (c *RemoteClient) Delete() error {
Key: &c.path,
})
if err != nil {
return err
}
if err := c.deleteMD5(); err != nil {
log.Printf("error deleting state md5: %s", err)
}
return nil
}
func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
@@ -146,9 +223,84 @@ func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
}
return "", lockErr
}
return info.ID, nil
}
func (c *RemoteClient) getMD5() ([]byte, error) {
if c.lockTable == "" {
return nil, nil
}
getParams := &dynamodb.GetItemInput{
Key: map[string]*dynamodb.AttributeValue{
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
},
ProjectionExpression: aws.String("LockID, Digest"),
TableName: aws.String(c.lockTable),
}
resp, err := c.dynClient.GetItem(getParams)
if err != nil {
return nil, err
}
var val string
if v, ok := resp.Item["Digest"]; ok && v.S != nil {
val = *v.S
}
sum, err := hex.DecodeString(val)
if err != nil || len(sum) != md5.Size {
return nil, errors.New("invalid md5")
}
return sum, nil
}
// store the hash of the state so that clients can check for stale state files.
func (c *RemoteClient) putMD5(sum []byte) error {
if c.lockTable == "" {
return nil
}
if len(sum) != md5.Size {
return errors.New("invalid payload md5")
}
putParams := &dynamodb.PutItemInput{
Item: map[string]*dynamodb.AttributeValue{
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
"Digest": {S: aws.String(hex.EncodeToString(sum))},
},
TableName: aws.String(c.lockTable),
}
_, err := c.dynClient.PutItem(putParams)
if err != nil {
log.Printf("[WARNING] failed to record state serial in dynamodb: %s", err)
}
return nil
}
// remove the hash value for a deleted state
func (c *RemoteClient) deleteMD5() error {
if c.lockTable == "" {
return nil
}
params := &dynamodb.DeleteItemInput{
Key: map[string]*dynamodb.AttributeValue{
"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
},
TableName: aws.String(c.lockTable),
}
if _, err := c.dynClient.DeleteItem(params); err != nil {
return err
}
return nil
}
func (c *RemoteClient) getLockInfo() (*state.LockInfo, error) {
getParams := &dynamodb.GetItemInput{
Key: map[string]*dynamodb.AttributeValue{


@@ -1,12 +1,16 @@
package s3
import (
"bytes"
"crypto/md5"
"fmt"
"testing"
"time"
"github.com/hashicorp/terraform/backend"
"github.com/hashicorp/terraform/state"
"github.com/hashicorp/terraform/state/remote"
"github.com/hashicorp/terraform/terraform"
)
func TestRemoteClient_impl(t *testing.T) {
@@ -74,3 +78,150 @@ func TestRemoteClientLocks(t *testing.T) {
remote.TestRemoteLocks(t, s1.(*remote.State).Client, s2.(*remote.State).Client)
}
func TestRemoteClient_clientMD5(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"
b := backend.TestBackendConfig(t, New(), map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"lock_table": bucketName,
}).(*Backend)
createDynamoDBTable(t, b.dynClient, bucketName)
defer deleteDynamoDBTable(t, b.dynClient, bucketName)
s, err := b.State(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client := s.(*remote.State).Client.(*RemoteClient)
sum := md5.Sum([]byte("test"))
if err := client.putMD5(sum[:]); err != nil {
t.Fatal(err)
}
getSum, err := client.getMD5()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(getSum, sum[:]) {
t.Fatalf("getMD5 returned the wrong checksum: expected %x, got %x", sum[:], getSum)
}
if err := client.deleteMD5(); err != nil {
t.Fatal(err)
}
if getSum, err := client.getMD5(); err == nil {
t.Fatalf("expected getMD5 error, got none. checksum: %x", getSum)
}
}
// verify that a client won't return a state with an incorrect checksum.
func TestRemoteClient_stateChecksum(t *testing.T) {
testACC(t)
bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"
b1 := backend.TestBackendConfig(t, New(), map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"lock_table": bucketName,
}).(*Backend)
createS3Bucket(t, b1.s3Client, bucketName)
defer deleteS3Bucket(t, b1.s3Client, bucketName)
createDynamoDBTable(t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(t, b1.dynClient, bucketName)
s1, err := b1.State(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client1 := s1.(*remote.State).Client
// create an old and a new state version to persist
s := state.TestStateInitial()
var oldState bytes.Buffer
if err := terraform.WriteState(s, &oldState); err != nil {
t.Fatal(err)
}
s.Serial++
var newState bytes.Buffer
if err := terraform.WriteState(s, &newState); err != nil {
t.Fatal(err)
}
// Use b2 without a lock_table so it bypasses the checksum record and writes the state directly.
// client2 will write the "incorrect" state, simulating S3's eventual consistency delays
b2 := backend.TestBackendConfig(t, New(), map[string]interface{}{
"bucket": bucketName,
"key": keyName,
}).(*Backend)
s2, err := b2.State(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client2 := s2.(*remote.State).Client
// write the new state through client2 so that there is no checksum yet
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
// verify that we can pull a state without a checksum
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
// write the new state back with its checksum
if err := client1.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
// put the old state in place of the new, without updating the checksum
if err := client2.Put(oldState.Bytes()); err != nil {
t.Fatal(err)
}
// remove the timeouts so we can fail immediately
origTimeout := consistencyRetryTimeout
origInterval := consistencyRetryPollInterval
defer func() {
consistencyRetryTimeout = origTimeout
consistencyRetryPollInterval = origInterval
}()
consistencyRetryTimeout = 0
consistencyRetryPollInterval = 0
// fetching the state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); err != errBadChecksum {
t.Fatalf("expected state checksum error: got %s", err)
}
// update the state with the correct one after we Get again
testChecksumHook = func() {
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
testChecksumHook = nil
}
consistencyRetryTimeout = origTimeout
// this final Get will initially fail the checksum verification; the above
// callback will then update the state with the correct version, and Get
// should retry automatically and succeed.
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
}
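Both tests added here are acceptance tests: testACC skips them unless acceptance testing is enabled, and they create and delete real DynamoDB tables (and, for the checksum test, an S3 bucket), so AWS credentials must be available in the environment. An example invocation from the s3 backend package directory, assuming testACC is gated on the conventional TF_ACC environment variable:

    TF_ACC=1 go test -v -run 'TestRemoteClient_(clientMD5|stateChecksum)' .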