provider/aws: Simplify update logic for Lambda function

This commit is contained in:
Radek Simko 2016-02-19 12:13:26 +00:00
parent b5c7521f52
commit fdc21aad25
1 changed files with 15 additions and 59 deletions

View File

@ -1,8 +1,6 @@
package aws package aws
import ( import (
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -103,11 +101,6 @@ func resourceAwsLambdaFunction() *schema.Resource {
}, },
}, },
}, },
"update_code": &schema.Schema{
Type: schema.TypeBool,
Optional: true,
Default: false,
},
"arn": &schema.Schema{ "arn": &schema.Schema{
Type: schema.TypeString, Type: schema.TypeString,
Computed: true, Computed: true,
@ -118,10 +111,7 @@ func resourceAwsLambdaFunction() *schema.Resource {
}, },
"source_code_hash": &schema.Schema{ "source_code_hash": &schema.Schema{
Type: schema.TypeString, Type: schema.TypeString,
Computed: true, Optional: true,
},
"remote_code_hash": &schema.Schema{
Type: schema.TypeString,
Computed: true, Computed: true,
}, },
}, },
@ -140,13 +130,12 @@ func resourceAwsLambdaFunctionCreate(d *schema.ResourceData, meta interface{}) e
var functionCode *lambda.FunctionCode var functionCode *lambda.FunctionCode
if v, ok := d.GetOk("filename"); ok { if v, ok := d.GetOk("filename"); ok {
zipfile, shaSum, err := loadLocalZipFile(v.(string)) file, err := loadFileContent(v.(string))
if err != nil { if err != nil {
return err return fmt.Errorf("Unable to load %q: %s", v.(string), err)
} }
d.Set("source_code_hash", shaSum)
functionCode = &lambda.FunctionCode{ functionCode = &lambda.FunctionCode{
ZipFile: zipfile, ZipFile: file,
} }
} else { } else {
s3Bucket, bucketOk := d.GetOk("s3_bucket") s3Bucket, bucketOk := d.GetOk("s3_bucket")
@ -257,26 +246,7 @@ func resourceAwsLambdaFunctionRead(d *schema.ResourceData, meta interface{}) err
if config := flattenLambdaVpcConfigResponse(function.VpcConfig); len(config) > 0 { if config := flattenLambdaVpcConfigResponse(function.VpcConfig); len(config) > 0 {
d.Set("vpc_config", config) d.Set("vpc_config", config)
} }
d.Set("source_code_hash", function.CodeSha256)
// Compare code hashes, and see if an update is required to code. If there
// is, set the "update_code" attribute.
remoteSum, err := decodeBase64(*function.CodeSha256)
if err != nil {
return err
}
_, localSum, err := loadLocalZipFile(d.Get("filename").(string))
if err != nil {
return err
}
d.Set("remote_code_hash", remoteSum)
d.Set("source_code_hash", localSum)
if remoteSum != localSum {
d.Set("update_code", true)
} else {
d.Set("update_code", false)
}
return nil return nil
} }
@ -314,16 +284,12 @@ func resourceAwsLambdaFunctionUpdate(d *schema.ResourceData, meta interface{}) e
} }
codeUpdate := false codeUpdate := false
if sourceHash, ok := d.GetOk("source_code_hash"); ok { if v, ok := d.GetOk("filename"); ok && d.HasChange("source_code_hash") {
zipfile, shaSum, err := loadLocalZipFile(d.Get("filename").(string)) file, err := loadFileContent(v.(string))
if err != nil { if err != nil {
return err return fmt.Errorf("Unable to load %q: %s", v.(string), err)
} }
if sourceHash != shaSum { codeReq.ZipFile = file
d.SetPartial("filename")
d.SetPartial("source_code_hash")
}
codeReq.ZipFile = zipfile
codeUpdate = true codeUpdate = true
} }
if d.HasChange("s3_bucket") || d.HasChange("s3_key") || d.HasChange("s3_object_version") { if d.HasChange("s3_bucket") || d.HasChange("s3_key") || d.HasChange("s3_object_version") {
@ -390,27 +356,17 @@ func resourceAwsLambdaFunctionUpdate(d *schema.ResourceData, meta interface{}) e
return resourceAwsLambdaFunctionRead(d, meta) return resourceAwsLambdaFunctionRead(d, meta)
} }
// loads the local ZIP data and the SHA sum of the data. // loadFileContent returns contents of a file in a given path
func loadLocalZipFile(v string) ([]byte, string, error) { func loadFileContent(v string) ([]byte, error) {
filename, err := homedir.Expand(v) filename, err := homedir.Expand(v)
if err != nil { if err != nil {
return nil, "", err return nil, err
} }
zipfile, err := ioutil.ReadFile(filename) fileContent, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
return nil, "", err return nil, err
} }
sum := sha256.Sum256(zipfile) return fileContent, nil
return zipfile, fmt.Sprintf("%x", sum), nil
}
// Decodes a base64 string to a string.
func decodeBase64(s string) (string, error) {
sum, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return "", err
}
return fmt.Sprintf("%x", sum), nil
} }
func validateVPCConfig(v interface{}) (map[string]interface{}, error) { func validateVPCConfig(v interface{}) (map[string]interface{}, error) {