306 lines
8.4 KiB
Go
306 lines
8.4 KiB
Go
|
package cos
|
||
|
|
||
|
import (
|
||
|
"crypto/hmac"
|
||
|
"crypto/sha1"
|
||
|
"fmt"
|
||
|
"hash"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"sort"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
const sha1SignAlgorithm = "sha1"
|
||
|
const privateHeaderPrefix = "x-cos-"
|
||
|
const defaultAuthExpire = time.Hour
|
||
|
|
||
|
// 需要校验的 Headers 列表
|
||
|
var needSignHeaders = map[string]bool{
|
||
|
"host": true,
|
||
|
"range": true,
|
||
|
"x-cos-acl": true,
|
||
|
"x-cos-grant-read": true,
|
||
|
"x-cos-grant-write": true,
|
||
|
"x-cos-grant-full-control": true,
|
||
|
"response-content-type": true,
|
||
|
"response-content-language": true,
|
||
|
"response-expires": true,
|
||
|
"response-cache-control": true,
|
||
|
"response-content-disposition": true,
|
||
|
"response-content-encoding": true,
|
||
|
"cache-control": true,
|
||
|
"content-disposition": true,
|
||
|
"content-encoding": true,
|
||
|
"content-type": true,
|
||
|
"content-length": true,
|
||
|
"content-md5": true,
|
||
|
"expect": true,
|
||
|
"expires": true,
|
||
|
"x-cos-content-sha1": true,
|
||
|
"x-cos-storage-class": true,
|
||
|
"if-modified-since": true,
|
||
|
"origin": true,
|
||
|
"access-control-request-method": true,
|
||
|
"access-control-request-headers": true,
|
||
|
"x-cos-object-type": true,
|
||
|
}
|
||
|
|
||
|
func safeURLEncode(s string) string {
|
||
|
s = encodeURIComponent(s)
|
||
|
s = strings.Replace(s, "!", "%21", -1)
|
||
|
s = strings.Replace(s, "'", "%27", -1)
|
||
|
s = strings.Replace(s, "(", "%28", -1)
|
||
|
s = strings.Replace(s, ")", "%29", -1)
|
||
|
s = strings.Replace(s, "*", "%2A", -1)
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
type valuesSignMap map[string][]string
|
||
|
|
||
|
func (vs valuesSignMap) Add(key, value string) {
|
||
|
key = strings.ToLower(key)
|
||
|
vs[key] = append(vs[key], value)
|
||
|
}
|
||
|
|
||
|
func (vs valuesSignMap) Encode() string {
|
||
|
var keys []string
|
||
|
for k := range vs {
|
||
|
keys = append(keys, k)
|
||
|
}
|
||
|
sort.Strings(keys)
|
||
|
|
||
|
var pairs []string
|
||
|
for _, k := range keys {
|
||
|
items := vs[k]
|
||
|
sort.Strings(items)
|
||
|
for _, val := range items {
|
||
|
pairs = append(
|
||
|
pairs,
|
||
|
fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
|
||
|
}
|
||
|
}
|
||
|
return strings.Join(pairs, "&")
|
||
|
}
|
||
|
|
||
|
// AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数
|
||
|
type AuthTime struct {
|
||
|
SignStartTime time.Time
|
||
|
SignEndTime time.Time
|
||
|
KeyStartTime time.Time
|
||
|
KeyEndTime time.Time
|
||
|
}
|
||
|
|
||
|
// NewAuthTime 生成 AuthTime 的便捷函数
|
||
|
//
|
||
|
// expire: 从现在开始多久过期.
|
||
|
func NewAuthTime(expire time.Duration) *AuthTime {
|
||
|
signStartTime := time.Now()
|
||
|
keyStartTime := signStartTime
|
||
|
signEndTime := signStartTime.Add(expire)
|
||
|
keyEndTime := signEndTime
|
||
|
return &AuthTime{
|
||
|
SignStartTime: signStartTime,
|
||
|
SignEndTime: signEndTime,
|
||
|
KeyStartTime: keyStartTime,
|
||
|
KeyEndTime: keyEndTime,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// signString return q-sign-time string
|
||
|
func (a *AuthTime) signString() string {
|
||
|
return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix())
|
||
|
}
|
||
|
|
||
|
// keyString return q-key-time string
|
||
|
func (a *AuthTime) keyString() string {
|
||
|
return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix())
|
||
|
}
|
||
|
|
||
|
// newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
|
||
|
func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
|
||
|
signTime := authTime.signString()
|
||
|
keyTime := authTime.keyString()
|
||
|
signKey := calSignKey(secretKey, keyTime)
|
||
|
|
||
|
formatHeaders := *new(string)
|
||
|
signedHeaderList := *new([]string)
|
||
|
formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
|
||
|
formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
|
||
|
formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
|
||
|
|
||
|
stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
|
||
|
signature := calSignature(signKey, stringToSign)
|
||
|
|
||
|
return genAuthorization(
|
||
|
secretID, signTime, keyTime, signature, signedHeaderList,
|
||
|
signedParameterList,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// AddAuthorizationHeader 给 req 增加签名信息
|
||
|
func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
|
||
|
if secretID == "" {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
auth := newAuthorization(secretID, secretKey, req,
|
||
|
authTime,
|
||
|
)
|
||
|
if len(sessionToken) > 0 {
|
||
|
req.Header.Set("x-cos-security-token", sessionToken)
|
||
|
}
|
||
|
req.Header.Set("Authorization", auth)
|
||
|
}
|
||
|
|
||
|
// calSignKey 计算 SignKey
|
||
|
func calSignKey(secretKey, keyTime string) string {
|
||
|
digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
|
||
|
return fmt.Sprintf("%x", digest)
|
||
|
}
|
||
|
|
||
|
// calStringToSign 计算 StringToSign
|
||
|
func calStringToSign(signAlgorithm, signTime, formatString string) string {
|
||
|
h := sha1.New()
|
||
|
h.Write([]byte(formatString))
|
||
|
return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil))
|
||
|
}
|
||
|
|
||
|
// calSignature 计算 Signature
|
||
|
func calSignature(signKey, stringToSign string) string {
|
||
|
digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
|
||
|
return fmt.Sprintf("%x", digest)
|
||
|
}
|
||
|
|
||
|
// genAuthorization 生成 Authorization
|
||
|
func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
|
||
|
return strings.Join([]string{
|
||
|
"q-sign-algorithm=" + sha1SignAlgorithm,
|
||
|
"q-ak=" + secretID,
|
||
|
"q-sign-time=" + signTime,
|
||
|
"q-key-time=" + keyTime,
|
||
|
"q-header-list=" + strings.Join(signedHeaderList, ";"),
|
||
|
"q-url-param-list=" + strings.Join(signedParameterList, ";"),
|
||
|
"q-signature=" + signature,
|
||
|
}, "&")
|
||
|
}
|
||
|
|
||
|
// genFormatString 生成 FormatString
|
||
|
func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
|
||
|
formatMethod := strings.ToLower(method)
|
||
|
formatURI := uri.Path
|
||
|
|
||
|
return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
|
||
|
formatParameters, formatHeaders,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
// genFormatParameters 生成 FormatParameters 和 SignedParameterList
|
||
|
// instead of the url.Values{}
|
||
|
func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
|
||
|
ps := valuesSignMap{}
|
||
|
for key, values := range parameters {
|
||
|
key = strings.ToLower(key)
|
||
|
for _, value := range values {
|
||
|
ps.Add(key, value)
|
||
|
signedParameterList = append(signedParameterList, key)
|
||
|
}
|
||
|
}
|
||
|
//formatParameters = strings.ToLower(ps.Encode())
|
||
|
formatParameters = ps.Encode()
|
||
|
sort.Strings(signedParameterList)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
|
||
|
func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
|
||
|
hs := valuesSignMap{}
|
||
|
for key, values := range headers {
|
||
|
key = strings.ToLower(key)
|
||
|
for _, value := range values {
|
||
|
if isSignHeader(key) {
|
||
|
hs.Add(key, value)
|
||
|
signedHeaderList = append(signedHeaderList, key)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
formatHeaders = hs.Encode()
|
||
|
sort.Strings(signedHeaderList)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// HMAC 签名
|
||
|
func calHMACDigest(key, msg, signMethod string) []byte {
|
||
|
var hashFunc func() hash.Hash
|
||
|
switch signMethod {
|
||
|
case "sha1":
|
||
|
hashFunc = sha1.New
|
||
|
default:
|
||
|
hashFunc = sha1.New
|
||
|
}
|
||
|
h := hmac.New(hashFunc, []byte(key))
|
||
|
h.Write([]byte(msg))
|
||
|
return h.Sum(nil)
|
||
|
}
|
||
|
|
||
|
func isSignHeader(key string) bool {
|
||
|
for k, v := range needSignHeaders {
|
||
|
if key == k && v {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return strings.HasPrefix(key, privateHeaderPrefix)
|
||
|
}
|
||
|
|
||
|
// AuthorizationTransport 给请求增加 Authorization header
|
||
|
type AuthorizationTransport struct {
|
||
|
SecretID string
|
||
|
SecretKey string
|
||
|
SessionToken string
|
||
|
rwLocker sync.RWMutex
|
||
|
// 签名多久过期
|
||
|
Expire time.Duration
|
||
|
Transport http.RoundTripper
|
||
|
}
|
||
|
|
||
|
// SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
|
||
|
func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
|
||
|
t.rwLocker.Lock()
|
||
|
defer t.rwLocker.Unlock()
|
||
|
t.SecretID = ak
|
||
|
t.SecretKey = sk
|
||
|
t.SessionToken = token
|
||
|
}
|
||
|
|
||
|
// GetCredential get the ak, sk, token
|
||
|
func (t *AuthorizationTransport) GetCredential() (string, string, string) {
|
||
|
t.rwLocker.RLock()
|
||
|
defer t.rwLocker.RUnlock()
|
||
|
return t.SecretID, t.SecretKey, t.SessionToken
|
||
|
}
|
||
|
|
||
|
// RoundTrip implements the RoundTripper interface.
|
||
|
func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
|
req = cloneRequest(req) // per RoundTrip contract
|
||
|
if t.Expire == time.Duration(0) {
|
||
|
t.Expire = defaultAuthExpire
|
||
|
}
|
||
|
|
||
|
ak, sk, token := t.GetCredential()
|
||
|
// 增加 Authorization header
|
||
|
authTime := NewAuthTime(t.Expire)
|
||
|
AddAuthorizationHeader(ak, sk, token, req, authTime)
|
||
|
|
||
|
resp, err := t.transport().RoundTrip(req)
|
||
|
return resp, err
|
||
|
}
|
||
|
|
||
|
func (t *AuthorizationTransport) transport() http.RoundTripper {
|
||
|
if t.Transport != nil {
|
||
|
return t.Transport
|
||
|
}
|
||
|
return http.DefaultTransport
|
||
|
}
|