remote: Testing put state

This commit is contained in:
Armon Dadgar 2014-10-07 12:09:51 -07:00 committed by Mitchell Hashimoto
parent 958bea4fca
commit d1e41bc992
2 changed files with 136 additions and 6 deletions

View File

@ -18,6 +18,18 @@ var (
// ErrConflict is used to indicate the upload was rejected // ErrConflict is used to indicate the upload was rejected
// due to a conflict on the state // due to a conflict on the state
ErrConflict = fmt.Errorf("Conflicting state file") ErrConflict = fmt.Errorf("Conflicting state file")
// ErrRequireAuth is used if the remote server requires
// authentication and none is provided
ErrRequireAuth = fmt.Errorf("Remote server requires authentication")
// ErrInvalidAuth is used if we provide authentication which
// is not valid
ErrInvalidAuth = fmt.Errorf("Invalid authentication")
// ErrRemoteInternal is used if we get an internal error
// from the remote server
ErrRemoteInternal = fmt.Errorf("Remote server reporting internal error")
) )
// RemoteStatePayload is used to return the remote state // RemoteStatePayload is used to return the remote state
@ -78,11 +90,11 @@ func (r *remoteStateClient) GetState() (*RemoteStatePayload, error) {
case http.StatusNotFound: case http.StatusNotFound:
return nil, nil return nil, nil
case http.StatusUnauthorized: case http.StatusUnauthorized:
return nil, fmt.Errorf("Remote server requires authentication") return nil, ErrRequireAuth
case http.StatusForbidden: case http.StatusForbidden:
return nil, fmt.Errorf("Invalid authentication") return nil, ErrInvalidAuth
case http.StatusInternalServerError: case http.StatusInternalServerError:
return nil, fmt.Errorf("Remote server reporting internal error") return nil, ErrRemoteInternal
default: default:
return nil, fmt.Errorf("Unexpected HTTP response code %d", resp.StatusCode) return nil, fmt.Errorf("Unexpected HTTP response code %d", resp.StatusCode)
} }
@ -176,11 +188,11 @@ func (r *remoteStateClient) PutState(state []byte, force bool) error {
case http.StatusConflict: case http.StatusConflict:
return ErrConflict return ErrConflict
case http.StatusUnauthorized: case http.StatusUnauthorized:
return fmt.Errorf("Remote server requires authentication") return ErrRequireAuth
case http.StatusForbidden: case http.StatusForbidden:
return fmt.Errorf("Invalid authentication") return ErrInvalidAuth
case http.StatusInternalServerError: case http.StatusInternalServerError:
return fmt.Errorf("Remote server reporting internal error") return ErrRemoteInternal
default: default:
return fmt.Errorf("Unexpected HTTP response code %d", resp.StatusCode) return fmt.Errorf("Unexpected HTTP response code %d", resp.StatusCode)
} }

View File

@ -3,6 +3,8 @@ package remote
import ( import (
"bytes" "bytes"
"crypto/md5" "crypto/md5"
"encoding/base64"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -156,3 +158,119 @@ func TestGetState(t *testing.T) {
} }
} }
} }
func TestPutState(t *testing.T) {
type tcase struct {
Code int
Path string
Header http.Header
Body []byte
ExpectMD5 []byte
Force bool
ExpectErr string
}
inp := []byte("testing")
inpMD5 := md5.Sum(inp)
hash := inpMD5[:16]
cases := []*tcase{
&tcase{
Code: http.StatusOK,
Path: "/foobar",
Body: inp,
ExpectMD5: hash,
},
&tcase{
Code: http.StatusOK,
Path: "/foobar?force=true",
Body: inp,
Force: true,
ExpectMD5: hash,
},
&tcase{
Code: http.StatusConflict,
Path: "/foobar",
Body: inp,
ExpectMD5: hash,
ExpectErr: ErrConflict.Error(),
},
&tcase{
Code: http.StatusUnauthorized,
Path: "/foobar",
Body: inp,
ExpectMD5: hash,
ExpectErr: ErrRequireAuth.Error(),
},
&tcase{
Code: http.StatusForbidden,
Path: "/foobar",
Body: inp,
ExpectMD5: hash,
ExpectErr: ErrInvalidAuth.Error(),
},
&tcase{
Code: http.StatusInternalServerError,
Path: "/foobar",
Body: inp,
ExpectMD5: hash,
ExpectErr: ErrRemoteInternal.Error(),
},
&tcase{
Code: 418,
Path: "/foobar",
Body: inp,
ExpectMD5: hash,
ExpectErr: "Unexpected HTTP response code 418",
},
}
for _, tc := range cases {
cb := func(resp http.ResponseWriter, req *http.Request) {
for k, v := range tc.Header {
resp.Header()[k] = v
}
resp.WriteHeader(tc.Code)
// Verify the body
buf := bytes.NewBuffer(nil)
io.Copy(buf, req.Body)
if !bytes.Equal(buf.Bytes(), tc.Body) {
t.Fatalf("bad body: %v", buf.Bytes())
}
// Verify the path
req.URL.Host = ""
if req.URL.String() != tc.Path {
t.Fatalf("Bad path: %v %v", req.URL.String(), tc.Path)
}
// Verify the content length
if req.ContentLength != int64(len(tc.Body)) {
t.Fatalf("bad content length: %d", req.ContentLength)
}
// Verify the Content-MD5
b64 := req.Header.Get("Content-MD5")
raw, _ := base64.StdEncoding.DecodeString(b64)
if !bytes.Equal(raw, tc.ExpectMD5) {
t.Fatalf("bad md5: %v", raw)
}
}
s := httptest.NewServer(http.HandlerFunc(cb))
defer s.Close()
remote := &terraform.RemoteState{
Name: "foobar",
Server: s.URL,
}
r := &remoteStateClient{remote}
err := r.PutState(tc.Body, tc.Force)
errStr := ""
if err != nil {
errStr = err.Error()
}
if errStr != tc.ExpectErr {
t.Fatalf("bad err: %v %v", errStr, tc.ExpectErr)
}
}
}