diff --git a/state/remote/s3.go b/state/remote/s3.go index b5c245ad2..94c93d81a 100644 --- a/state/remote/s3.go +++ b/state/remote/s3.go @@ -231,28 +231,9 @@ func (c *S3Client) Lock(info *state.LockInfo) (string, error) { _, err := c.dynClient.PutItem(putParams) if err != nil { - getParams := &dynamodb.GetItemInput{ - Key: map[string]*dynamodb.AttributeValue{ - "LockID": {S: aws.String(fmt.Sprintf("%s/%s", c.bucketName, c.keyName))}, - }, - ProjectionExpression: aws.String("LockID, Created, Info"), - TableName: aws.String(c.lockTable), - } - - resp, err := c.dynClient.GetItem(getParams) + lockInfo, err := c.getLockInfo() if err != nil { - return info.ID, fmt.Errorf("s3 state file %q locked, failed to retrieve info: %s", stateName, err) - } - - var infoData string - if v, ok := resp.Item["Info"]; ok && v.S != nil { - infoData = *v.S - } - - lockInfo := &state.LockInfo{} - err = json.Unmarshal([]byte(infoData), lockInfo) - if err != nil { - return info.ID, fmt.Errorf("s3 state file %q locked, failed get lock info: %s", stateName, err) + return "", fmt.Errorf("s3 state file %q locked, failed to retrieve info: %s", stateName, err) } return info.ID, lockInfo.Err() @@ -260,18 +241,58 @@ func (c *S3Client) Lock(info *state.LockInfo) (string, error) { return info.ID, nil } -func (c *S3Client) Unlock(string) error { +func (c *S3Client) getLockInfo() (*state.LockInfo, error) { + getParams := &dynamodb.GetItemInput{ + Key: map[string]*dynamodb.AttributeValue{ + "LockID": {S: aws.String(fmt.Sprintf("%s/%s", c.bucketName, c.keyName))}, + }, + ProjectionExpression: aws.String("LockID, Info"), + TableName: aws.String(c.lockTable), + } + + resp, err := c.dynClient.GetItem(getParams) + if err != nil { + return nil, err + } + + var infoData string + if v, ok := resp.Item["Info"]; ok && v.S != nil { + infoData = *v.S + } + + lockInfo := &state.LockInfo{} + err = json.Unmarshal([]byte(infoData), lockInfo) + if err != nil { + return nil, err + } + + return lockInfo, nil +} + +func (c *S3Client) Unlock(id string) error { if c.lockTable == "" { return nil } + // TODO: store the path and lock ID in separate fields, and have proper + // projection expression only delete the lock if both match, rather than + // checking the ID from the info field first. + lockInfo, err := c.getLockInfo() + if err != nil { + return fmt.Errorf("failed to retrieve lock info: %s", err) + } + + if lockInfo.ID != id { + return fmt.Errorf("lock id %q does not match existing lock", id) + } + params := &dynamodb.DeleteItemInput{ Key: map[string]*dynamodb.AttributeValue{ "LockID": {S: aws.String(fmt.Sprintf("%s/%s", c.bucketName, c.keyName))}, }, TableName: aws.String(c.lockTable), } - _, err := c.dynClient.DeleteItem(params) + _, err = c.dynClient.DeleteItem(params) if err != nil { return err