Merge pull request #14746 from hashicorp/jbardin/s3-consistency
store and verify s3 remote state checksum to avoid consistency issues.
commit ef1d53934c
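The mechanism in outline: Put now records the MD5 of the uploaded state alongside the lock entries in the DynamoDB table, and Get re-reads the state from S3 until its checksum matches that record or a deadline passes, working around S3's eventual consistency on overwrites. A minimal standalone sketch of the pattern (illustrative function names, not the Terraform code itself):

// Sketch only: poll a fetched payload against an expected checksum recorded
// elsewhere until they agree or a deadline passes. fetchState and
// fetchExpectedMD5 stand in for the S3 GetObject and DynamoDB GetItem calls.
package main

import (
	"bytes"
	"crypto/md5"
	"errors"
	"fmt"
	"time"
)

func getConsistent(fetchState, fetchExpectedMD5 func() ([]byte, error)) ([]byte, error) {
	deadline := time.Now().Add(10 * time.Second)
	for {
		data, err := fetchState()
		if err != nil {
			return nil, err
		}

		expected, err := fetchExpectedMD5()
		if err != nil || len(expected) == 0 {
			// No recorded checksum to verify against; accept the read.
			return data, nil
		}

		sum := md5.Sum(data)
		if bytes.Equal(expected, sum[:]) {
			return data, nil
		}

		if time.Now().After(deadline) {
			return nil, errors.New("state checksum mismatch")
		}
		time.Sleep(2 * time.Second) // poll interval
	}
}

func main() {
	data, err := getConsistent(
		func() ([]byte, error) { return []byte("state"), nil },
		func() ([]byte, error) { s := md5.Sum([]byte("state")); return s[:], nil },
	)
	fmt.Println(string(data), err)
}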
backend/remote-state/s3/client.go

@@ -2,10 +2,14 @@ package s3

 import (
 	"bytes"
+	"crypto/md5"
+	"encoding/hex"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log"
+	"time"

 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
@@ -17,6 +21,9 @@ import (
 	"github.com/hashicorp/terraform/state/remote"
 )

+// Store the last saved serial in dynamo with this suffix for consistency checks.
+const stateIDSuffix = "-md5"
+
 type RemoteClient struct {
 	s3Client  *s3.S3
 	dynClient *dynamodb.DynamoDB
@@ -28,7 +35,55 @@ type RemoteClient struct {
 	lockTable string
 }

-func (c *RemoteClient) Get() (*remote.Payload, error) {
+var (
+	// The amount of time we will retry a state waiting for it to match the
+	// expected checksum.
+	consistencyRetryTimeout = 10 * time.Second
+
+	// delay when polling the state
+	consistencyRetryPollInterval = 2 * time.Second
+)
+
+// test hook called when checksums don't match
+var testChecksumHook func()
+
+func (c *RemoteClient) Get() (payload *remote.Payload, err error) {
+	deadline := time.Now().Add(consistencyRetryTimeout)
+
+	// If we have a checksum, and the returned payload doesn't match, we retry
+	// up until deadline.
+	for {
+		payload, err = c.get()
+		if err != nil {
+			return nil, err
+		}
+
+		// verify that this state is what we expect
+		if expected, err := c.getMD5(); err != nil {
+			log.Printf("[WARNING] failed to fetch state md5: %s", err)
+		} else if len(expected) > 0 && !bytes.Equal(expected, payload.MD5) {
+			log.Printf("[WARNING] state md5 mismatch: expected '%x', got '%x'", expected, payload.MD5)
+
+			if testChecksumHook != nil {
+				testChecksumHook()
+			}
+
+			if time.Now().Before(deadline) {
+				time.Sleep(consistencyRetryPollInterval)
+				log.Println("[INFO] retrying S3 RemoteClient.Get...")
+				continue
+			}
+
+			return nil, fmt.Errorf(errBadChecksumFmt, payload.MD5)
+		}
+
+		break
+	}
+
+	return payload, err
+}
+
+func (c *RemoteClient) get() (*remote.Payload, error) {
 	output, err := c.s3Client.GetObject(&s3.GetObjectInput{
 		Bucket: &c.bucketName,
 		Key:    &c.path,
@@ -53,8 +108,10 @@ func (c *RemoteClient) Get() (*remote.Payload, error) {
 		return nil, fmt.Errorf("Failed to read remote state: %s", err)
 	}

+	sum := md5.Sum(buf.Bytes())
 	payload := &remote.Payload{
 		Data: buf.Bytes(),
+		MD5:  sum[:],
 	}

 	// If there was no data, then return nil
@@ -92,11 +149,20 @@ func (c *RemoteClient) Put(data []byte) error {

 	log.Printf("[DEBUG] Uploading remote state to S3: %#v", i)

-	if _, err := c.s3Client.PutObject(i); err == nil {
-		return nil
-	} else {
+	_, err := c.s3Client.PutObject(i)
+	if err != nil {
 		return fmt.Errorf("Failed to upload state: %v", err)
 	}
+
+	sum := md5.Sum(data)
+	if err := c.putMD5(sum[:]); err != nil {
+		// if this errors out, we unfortunately have to error out altogether,
+		// since the next Get will inevitably fail.
+		return fmt.Errorf("failed to store state MD5: %s", err)
+
+	}
+
+	return nil
 }

 func (c *RemoteClient) Delete() error {
@@ -105,9 +171,17 @@ func (c *RemoteClient) Delete() error {
 		Key:    &c.path,
 	})

-	return err
+	if err != nil {
+		return err
+	}
+
+	if err := c.deleteMD5(); err != nil {
+		log.Printf("error deleting state md5: %s", err)
+	}
+
+	return nil
 }

 func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
 	if c.lockTable == "" {
 		return "", nil
@@ -146,9 +220,84 @@ func (c *RemoteClient) Lock(info *state.LockInfo) (string, error) {
 		}
 		return "", lockErr
 	}

 	return info.ID, nil
 }

+func (c *RemoteClient) getMD5() ([]byte, error) {
+	if c.lockTable == "" {
+		return nil, nil
+	}
+
+	getParams := &dynamodb.GetItemInput{
+		Key: map[string]*dynamodb.AttributeValue{
+			"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
+		},
+		ProjectionExpression: aws.String("LockID, Digest"),
+		TableName:            aws.String(c.lockTable),
+	}
+
+	resp, err := c.dynClient.GetItem(getParams)
+	if err != nil {
+		return nil, err
+	}
+
+	var val string
+	if v, ok := resp.Item["Digest"]; ok && v.S != nil {
+		val = *v.S
+	}
+
+	sum, err := hex.DecodeString(val)
+	if err != nil || len(sum) != md5.Size {
+		return nil, errors.New("invalid md5")
+	}
+
+	return sum, nil
+}
+
+// store the hash of the state so that clients can check for stale state files.
+func (c *RemoteClient) putMD5(sum []byte) error {
+	if c.lockTable == "" {
+		return nil
+	}
+
+	if len(sum) != md5.Size {
+		return errors.New("invalid payload md5")
+	}
+
+	putParams := &dynamodb.PutItemInput{
+		Item: map[string]*dynamodb.AttributeValue{
+			"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
+			"Digest": {S: aws.String(hex.EncodeToString(sum))},
+		},
+		TableName: aws.String(c.lockTable),
+	}
+	_, err := c.dynClient.PutItem(putParams)
+	if err != nil {
+		log.Printf("[WARNING] failed to record state serial in dynamodb: %s", err)
+	}
+
+	return nil
+}
+
+// remove the hash value for a deleted state
+func (c *RemoteClient) deleteMD5() error {
+	if c.lockTable == "" {
+		return nil
+	}
+
+	params := &dynamodb.DeleteItemInput{
+		Key: map[string]*dynamodb.AttributeValue{
+			"LockID": {S: aws.String(c.lockPath() + stateIDSuffix)},
+		},
+		TableName: aws.String(c.lockTable),
+	}
+	if _, err := c.dynClient.DeleteItem(params); err != nil {
+		return err
+	}
+	return nil
+}
+
 func (c *RemoteClient) getLockInfo() (*state.LockInfo, error) {
 	getParams := &dynamodb.GetItemInput{
 		Key: map[string]*dynamodb.AttributeValue{
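For reference, the checksum reuses the lock table's schema: the item's LockID hash key is the lock path ("bucket/path") with the "-md5" suffix appended, and the hex-encoded MD5 lives in a Digest attribute. A tiny illustrative snippet, with made-up bucket and key names, showing how that key is composed:

// Illustrative only: composes the DynamoDB key used for the checksum item
// from a hypothetical bucket and state key.
package main

import "fmt"

func main() {
	bucketName := "my-states"        // hypothetical bucket
	path := "prod/terraform.tfstate" // hypothetical key
	lockID := fmt.Sprintf("%s/%s-md5", bucketName, path)
	fmt.Println(lockID) // my-states/prod/terraform.tfstate-md5
}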
@@ -217,3 +366,12 @@ func (c *RemoteClient) Unlock(id string) error {
 func (c *RemoteClient) lockPath() string {
	return fmt.Sprintf("%s/%s", c.bucketName, c.path)
 }
+
+const errBadChecksumFmt = `state data in S3 does not have the expected content.
+
+This may be caused by unusually long delays in S3 processing a previous state
+update. Please wait for a minute or two and try again. If this problem
+persists, and neither S3 nor DynamoDB are experiencing an outage, you may need
+to manually verify the remote state and update the Digest value stored in the
+DynamoDB table to the following value: %x
+`
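Should you ever need to repair the Digest by hand as that message suggests, the sketch below (hypothetical table, bucket, and file names) computes the MD5 of a verified-good local copy of the state and overwrites the stored item with aws-sdk-go:

// Sketch, not part of this commit: recompute and store the state digest.
package main

import (
	"crypto/md5"
	"encoding/hex"
	"io/ioutil"
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/dynamodb"
)

func main() {
	// terraform.tfstate is a local copy of the state you have manually
	// verified to be correct; all names here are illustrative.
	data, err := ioutil.ReadFile("terraform.tfstate")
	if err != nil {
		log.Fatal(err)
	}
	sum := md5.Sum(data)

	svc := dynamodb.New(session.Must(session.NewSession()))
	_, err = svc.PutItem(&dynamodb.PutItemInput{
		TableName: aws.String("terraform-locks"), // your lock table
		Item: map[string]*dynamodb.AttributeValue{
			"LockID": {S: aws.String("my-states/prod/terraform.tfstate-md5")},
			"Digest": {S: aws.String(hex.EncodeToString(sum[:]))},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
}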
backend/remote-state/s3/client_test.go

@@ -1,13 +1,17 @@
 package s3

 import (
+	"bytes"
+	"crypto/md5"
 	"fmt"
+	"strings"
 	"testing"
 	"time"

 	"github.com/hashicorp/terraform/backend"
 	"github.com/hashicorp/terraform/state"
 	"github.com/hashicorp/terraform/state/remote"
+	"github.com/hashicorp/terraform/terraform"
 )

 func TestRemoteClient_impl(t *testing.T) {
@@ -150,3 +154,150 @@ func TestForceUnlock(t *testing.T) {
 		t.Fatal("failed to force-unlock named state")
 	}
 }
+
+func TestRemoteClient_clientMD5(t *testing.T) {
+	testACC(t)
+
+	bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
+	keyName := "testState"
+
+	b := backend.TestBackendConfig(t, New(), map[string]interface{}{
+		"bucket":     bucketName,
+		"key":        keyName,
+		"lock_table": bucketName,
+	}).(*Backend)
+
+	createDynamoDBTable(t, b.dynClient, bucketName)
+	defer deleteDynamoDBTable(t, b.dynClient, bucketName)
+
+	s, err := b.State(backend.DefaultStateName)
+	if err != nil {
+		t.Fatal(err)
+	}
+	client := s.(*remote.State).Client.(*RemoteClient)
+
+	sum := md5.Sum([]byte("test"))
+
+	if err := client.putMD5(sum[:]); err != nil {
+		t.Fatal(err)
+	}
+
+	getSum, err := client.getMD5()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !bytes.Equal(getSum, sum[:]) {
+		t.Fatalf("getMD5 returned the wrong checksum: expected %x, got %x", sum[:], getSum)
+	}
+
+	if err := client.deleteMD5(); err != nil {
+		t.Fatal(err)
+	}
+
+	if getSum, err := client.getMD5(); err == nil {
+		t.Fatalf("expected getMD5 error, got none. checksum: %x", getSum)
+	}
+}
+
+// verify that a client won't return a state with an incorrect checksum.
+func TestRemoteClient_stateChecksum(t *testing.T) {
+	testACC(t)
+
+	bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
+	keyName := "testState"
+
+	b1 := backend.TestBackendConfig(t, New(), map[string]interface{}{
+		"bucket":     bucketName,
+		"key":        keyName,
+		"lock_table": bucketName,
+	}).(*Backend)
+
+	createS3Bucket(t, b1.s3Client, bucketName)
+	defer deleteS3Bucket(t, b1.s3Client, bucketName)
+	createDynamoDBTable(t, b1.dynClient, bucketName)
+	defer deleteDynamoDBTable(t, b1.dynClient, bucketName)
+
+	s1, err := b1.State(backend.DefaultStateName)
+	if err != nil {
+		t.Fatal(err)
+	}
+	client1 := s1.(*remote.State).Client
+
+	// create an old and a new state version to persist
+	s := state.TestStateInitial()
+	var oldState bytes.Buffer
+	if err := terraform.WriteState(s, &oldState); err != nil {
+		t.Fatal(err)
+	}
+	s.Serial++
+	var newState bytes.Buffer
+	if err := terraform.WriteState(s, &newState); err != nil {
+		t.Fatal(err)
+	}
+
+	// Use b2, configured without a lock_table, to write the state directly,
+	// bypassing the checksum update. client2 will write the "incorrect"
+	// state, simulating S3 eventual consistency delays.
+	b2 := backend.TestBackendConfig(t, New(), map[string]interface{}{
+		"bucket": bucketName,
+		"key":    keyName,
+	}).(*Backend)
+	s2, err := b2.State(backend.DefaultStateName)
+	if err != nil {
+		t.Fatal(err)
+	}
+	client2 := s2.(*remote.State).Client
+
+	// write the new state through client2 so that there is no checksum yet
+	if err := client2.Put(newState.Bytes()); err != nil {
+		t.Fatal(err)
+	}
+
+	// verify that we can pull a state without a checksum
+	if _, err := client1.Get(); err != nil {
+		t.Fatal(err)
+	}
+
+	// write the new state back with its checksum
+	if err := client1.Put(newState.Bytes()); err != nil {
+		t.Fatal(err)
+	}
+
+	// put the old state in place of the new, without updating the checksum
+	if err := client2.Put(oldState.Bytes()); err != nil {
+		t.Fatal(err)
+	}
+
+	// remove the timeouts so we can fail immediately
+	origTimeout := consistencyRetryTimeout
+	origInterval := consistencyRetryPollInterval
+	defer func() {
+		consistencyRetryTimeout = origTimeout
+		consistencyRetryPollInterval = origInterval
+	}()
+	consistencyRetryTimeout = 0
+	consistencyRetryPollInterval = 0
+
+	// fetching the state through client1 should now error out due to a
+	// mismatched checksum.
+	if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
+		t.Fatalf("expected state checksum error: got %s", err)
+	}
+
+	// update the state with the correct one after the next Get
+	testChecksumHook = func() {
+		if err := client2.Put(newState.Bytes()); err != nil {
+			t.Fatal(err)
+		}
+		testChecksumHook = nil
+	}
+
+	consistencyRetryTimeout = origTimeout
+
+	// this final Get will at first fail the checksum verification, the hook
+	// above will then update the state with the correct version, and Get
+	// should retry and succeed automatically.
+	if _, err := client1.Get(); err != nil {
+		t.Fatal(err)
+	}
+}