Remove LGPL dependencies
This changeset performs the following: - Updates `masterzen/winrm` vendor to include change from (https://github.com/masterzen/winrm/pull/73) - Updates `dylanmei/winrmtest` vendor to include change from (https://github.com/dylanmei/winrmtest/pull/4) - Updates `packer-community/winrmcp` vendor to include the removal of the `masterzen/winrm/winrm` sub-class as a result of the `winrm` CLI tool being removed from the `masterzen/winrm` repository. - Changes `communicator/winrm/communicator.go` to conform to the new ABI in the `masterzen/winrm` library. This should completely remove any LGPL licensed dependencies inside of the Terraform project. ``` $ make test ==> Checking that code complies with gofmt requirements... go generate $(go list ./... | grep -v /terraform/vendor/) 2017/08/20 13:40:16 Generated command/internal_plugin_list.go go test -i $(go list ./... | grep -v '/terraform/vendor/' | grep -v '/builtin/bins/') || exit 1 echo $(go list ./... | grep -v '/terraform/vendor/' | grep -v '/builtin/bins/') | \ xargs -t -n4 go test -timeout=60s -parallel=4 go test -timeout=60s -parallel=4 github.com/hashicorp/terraform github.com/hashicorp/terraform/backend github.com/hashicorp/terraform/backend/atlas github.com/hashicorp/terraform/backend/init ok github.com/hashicorp/terraform 0.011s ok github.com/hashicorp/terraform/backend 0.020s ok github.com/hashicorp/terraform/backend/atlas 0.634s ok github.com/hashicorp/terraform/backend/init 0.007s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/backend/legacy github.com/hashicorp/terraform/backend/local github.com/hashicorp/terraform/backend/remote-state github.com/hashicorp/terraf orm/backend/remote-state/azure ok github.com/hashicorp/terraform/backend/legacy 0.009s ok github.com/hashicorp/terraform/backend/local 0.211s ok github.com/hashicorp/terraform/backend/remote-state 0.006s ok github.com/hashicorp/terraform/backend/remote-state/azure 0.010s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/backend/remote-state/consul github.com/hashicorp/terraform/backend/remote-state/inmem github.com/hashicorp/terraform/backend/remote-state/s 3 github.com/hashicorp/terraform/backend/remote-state/swift ok github.com/hashicorp/terraform/backend/remote-state/consul 0.007s ok github.com/hashicorp/terraform/backend/remote-state/inmem 0.013s ok github.com/hashicorp/terraform/backend/remote-state/s3 0.007s ok github.com/hashicorp/terraform/backend/remote-state/swift 0.013s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/builtin/providers/test github.com/hashicorp/terraform/builtin/provisioners/chef github.com/hashicorp/terraform/builtin/provisioners/file gi thub.com/hashicorp/terraform/builtin/provisioners/local-exec ok github.com/hashicorp/terraform/builtin/providers/test 1.544s ok github.com/hashicorp/terraform/builtin/provisioners/chef 0.017s ok github.com/hashicorp/terraform/builtin/provisioners/file 0.006s ok github.com/hashicorp/terraform/builtin/provisioners/local-exec 0.078s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/builtin/provisioners/remote-exec github.com/hashicorp/terraform/builtin/provisioners/salt-masterless github.com/hashicorp/terraform/command github.com/hashicorp/terraform/command/clistate ok github.com/hashicorp/terraform/builtin/provisioners/remote-exec 1.037s ok github.com/hashicorp/terraform/builtin/provisioners/salt-masterless 0.008s ok github.com/hashicorp/terraform/command 14.589s ? github.com/hashicorp/terraform/command/clistate [no test files] go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/command/e2etest github.com/hashicorp/terraform/command/format github.com/hashicorp/terraform/communicator github.com/hashicorp/terraform/co mmunicator/remote ok github.com/hashicorp/terraform/command/e2etest 3.729s ok github.com/hashicorp/terraform/command/format 0.004s ok github.com/hashicorp/terraform/communicator 0.005s ok github.com/hashicorp/terraform/communicator/remote 0.003s [no tests to run] go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/communicator/shared github.com/hashicorp/terraform/communicator/ssh github.com/hashicorp/terraform/communicator/winrm github.com/hashicorp/ terraform/config ok github.com/hashicorp/terraform/communicator/shared 0.007s ok github.com/hashicorp/terraform/communicator/ssh 0.016s ok github.com/hashicorp/terraform/communicator/winrm 0.018s ok github.com/hashicorp/terraform/config 0.213s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/config/module github.com/hashicorp/terraform/dag github.com/hashicorp/terraform/digraph github.com/hashicorp/terraform/flatmap ok github.com/hashicorp/terraform/config/module 0.044s ok github.com/hashicorp/terraform/dag 0.010s ok github.com/hashicorp/terraform/digraph 0.002s ok github.com/hashicorp/terraform/flatmap 0.002s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/helper/acctest github.com/hashicorp/terraform/helper/config github.com/hashicorp/terraform/helper/copy github.com/hashicorp/terraform/helpe r/diff ? github.com/hashicorp/terraform/helper/acctest [no test files] ok github.com/hashicorp/terraform/helper/config 0.005s ? github.com/hashicorp/terraform/helper/copy [no test files] ok github.com/hashicorp/terraform/helper/diff 0.005s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/helper/encryption github.com/hashicorp/terraform/helper/experiment github.com/hashicorp/terraform/helper/hashcode github.com/hashicorp/terr aform/helper/hilmapstructure ? github.com/hashicorp/terraform/helper/encryption [no test files] ok github.com/hashicorp/terraform/helper/experiment 0.001s ok github.com/hashicorp/terraform/helper/hashcode 0.001s ? github.com/hashicorp/terraform/helper/hilmapstructure [no test files] go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/helper/logging github.com/hashicorp/terraform/helper/mutexkv github.com/hashicorp/terraform/helper/pathorcontents github.com/hashicorp/terr aform/helper/resource ? github.com/hashicorp/terraform/helper/logging [no test files] ok github.com/hashicorp/terraform/helper/mutexkv 0.055s ok github.com/hashicorp/terraform/helper/pathorcontents 0.002s ok github.com/hashicorp/terraform/helper/resource 2.659s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/helper/schema github.com/hashicorp/terraform/helper/shadow github.com/hashicorp/terraform/helper/signalwrapper github.com/hashicorp/terrafo rm/helper/slowmessage ok github.com/hashicorp/terraform/helper/schema 0.063s ok github.com/hashicorp/terraform/helper/shadow 0.156s ok github.com/hashicorp/terraform/helper/signalwrapper 0.022s ok github.com/hashicorp/terraform/helper/slowmessage 0.102s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/helper/structure github.com/hashicorp/terraform/helper/validation github.com/hashicorp/terraform/helper/variables github.com/hashicorp/terr aform/helper/wrappedreadline ok github.com/hashicorp/terraform/helper/structure 0.004s ok github.com/hashicorp/terraform/helper/validation 0.004s ok github.com/hashicorp/terraform/helper/variables 0.006s ? github.com/hashicorp/terraform/helper/wrappedreadline [no test files] go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/helper/wrappedstreams github.com/hashicorp/terraform/moduledeps github.com/hashicorp/terraform/plugin github.com/hashicorp/terraform/plugin /discovery ? github.com/hashicorp/terraform/helper/wrappedstreams [no test files] ok github.com/hashicorp/terraform/moduledeps 0.004s ok github.com/hashicorp/terraform/plugin 0.046s ok github.com/hashicorp/terraform/plugin/discovery 0.029s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/repl github.com/hashicorp/terraform/scripts github.com/hashicorp/terraform/state github.com/hashicorp/terraform/state/remote ok github.com/hashicorp/terraform/repl 0.006s ok github.com/hashicorp/terraform/scripts 0.008s ok github.com/hashicorp/terraform/state 2.617s ok github.com/hashicorp/terraform/state/remote 0.025s go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/terraform github.com/hashicorp/terraform/tools/terraform-bundle go test -timeout=60s -parallel=4 github.com/hashicorp/terraform/terraform github.com/hashicorp/terraform/tools/terraform-bundle ok github.com/hashicorp/terraform/terraform 4.222s ? github.com/hashicorp/terraform/tools/terraform-bundle [no test files] ```
This commit is contained in:
parent
e359930530
commit
6e599672e1
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/hashicorp/terraform/communicator/remote"
|
||||
"github.com/hashicorp/terraform/terraform"
|
||||
"github.com/masterzen/winrm/winrm"
|
||||
"github.com/masterzen/winrm"
|
||||
"github.com/packer-community/winrmcp/winrmcp"
|
||||
|
||||
// This import is a bit strange, but it's needed so `make updatedeps` can see and download it
|
||||
|
@ -39,7 +39,10 @@ func New(s *terraform.InstanceState) (*Communicator, error) {
|
|||
Port: connInfo.Port,
|
||||
HTTPS: connInfo.HTTPS,
|
||||
Insecure: connInfo.Insecure,
|
||||
CACert: connInfo.CACert,
|
||||
}
|
||||
|
||||
if connInfo.CACert != nil {
|
||||
endpoint.CACert = *connInfo.CACert
|
||||
}
|
||||
|
||||
comm := &Communicator{
|
||||
|
@ -58,7 +61,7 @@ func (c *Communicator) Connect(o terraform.UIOutput) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
params := winrm.DefaultParameters()
|
||||
params := winrm.DefaultParameters
|
||||
params.Timeout = formatDuration(c.Timeout())
|
||||
|
||||
client, err := winrm.NewClientWithParameters(
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Microsoft
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,29 @@
|
|||
# go-ntlmssp
|
||||
Golang package that provides NTLM/Negotiate authentication over HTTP
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/Azure/go-ntlmssp?status.svg)](https://godoc.org/github.com/Azure/go-ntlmssp) [![Build Status](https://travis-ci.org/Azure/go-ntlmssp.svg?branch=dev)](https://travis-ci.org/Azure/go-ntlmssp)
|
||||
|
||||
Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx
|
||||
Implementation hints from http://davenport.sourceforge.net/ntlm.html
|
||||
|
||||
This package only implements authentication, no key exchange or encryption. It
|
||||
only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
|
||||
This package implements NTLMv2.
|
||||
|
||||
# Usage
|
||||
|
||||
```
|
||||
url, user, password := "http://www.example.com/secrets", "robpike", "pw123"
|
||||
client := &http.Client{
|
||||
Transport: ntlmssp.Negotiator{
|
||||
RoundTripper:&http.Transport{},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.SetBasicAuth(user, password)
|
||||
res, _ := client.Do(req)
|
||||
```
|
||||
|
||||
-----
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,128 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type authenicateMessage struct {
|
||||
LmChallengeResponse []byte
|
||||
NtChallengeResponse []byte
|
||||
|
||||
TargetName string
|
||||
UserName string
|
||||
|
||||
// only set if negotiateFlag_NTLMSSP_NEGOTIATE_KEY_EXCH
|
||||
EncryptedRandomSessionKey []byte
|
||||
|
||||
NegotiateFlags negotiateFlags
|
||||
|
||||
MIC []byte
|
||||
}
|
||||
|
||||
type authenticateMessageFields struct {
|
||||
messageHeader
|
||||
LmChallengeResponse varField
|
||||
NtChallengeResponse varField
|
||||
TargetName varField
|
||||
UserName varField
|
||||
Workstation varField
|
||||
_ [8]byte
|
||||
NegotiateFlags negotiateFlags
|
||||
}
|
||||
|
||||
func (m authenicateMessage) MarshalBinary() ([]byte, error) {
|
||||
if !m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) {
|
||||
return nil, errors.New("Only unicode is supported")
|
||||
}
|
||||
|
||||
target, user := toUnicode(m.TargetName), toUnicode(m.UserName)
|
||||
workstation := toUnicode("go-ntlmssp")
|
||||
|
||||
ptr := binary.Size(&authenticateMessageFields{})
|
||||
f := authenticateMessageFields{
|
||||
messageHeader: newMessageHeader(3),
|
||||
NegotiateFlags: m.NegotiateFlags,
|
||||
LmChallengeResponse: newVarField(&ptr, len(m.LmChallengeResponse)),
|
||||
NtChallengeResponse: newVarField(&ptr, len(m.NtChallengeResponse)),
|
||||
TargetName: newVarField(&ptr, len(target)),
|
||||
UserName: newVarField(&ptr, len(user)),
|
||||
Workstation: newVarField(&ptr, len(workstation)),
|
||||
}
|
||||
|
||||
f.NegotiateFlags.Unset(negotiateFlagNTLMSSPNEGOTIATEVERSION)
|
||||
|
||||
b := bytes.Buffer{}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &m.LmChallengeResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &m.NtChallengeResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &workstation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
//ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message
|
||||
//that was received from the server
|
||||
func ProcessChallenge(challengeMessageData []byte, user, password string) ([]byte, error) {
|
||||
if user == "" && password == "" {
|
||||
return nil, errors.New("Anonymous authentication not supported")
|
||||
}
|
||||
|
||||
var cm challengeMessage
|
||||
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
|
||||
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
|
||||
}
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
|
||||
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
|
||||
}
|
||||
|
||||
am := authenicateMessage{
|
||||
UserName: user,
|
||||
TargetName: cm.TargetName,
|
||||
NegotiateFlags: cm.NegotiateFlags,
|
||||
}
|
||||
|
||||
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
|
||||
if timestamp == nil { // no time sent, take current time
|
||||
ft := uint64(time.Now().UnixNano()) / 100
|
||||
ft += 116444736000000000 // add time between unix & windows offset
|
||||
timestamp = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(timestamp, ft)
|
||||
}
|
||||
|
||||
clientChallenge := make([]byte, 8)
|
||||
rand.Reader.Read(clientChallenge)
|
||||
|
||||
ntlmV2Hash := getNtlmV2Hash(password, user, cm.TargetName)
|
||||
|
||||
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
|
||||
|
||||
if cm.TargetInfoRaw == nil {
|
||||
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge)
|
||||
}
|
||||
|
||||
return am.MarshalBinary()
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type authheader string
|
||||
|
||||
func (h authheader) IsBasic() bool {
|
||||
return strings.HasPrefix(string(h), "Basic ")
|
||||
}
|
||||
|
||||
func (h authheader) IsNegotiate() bool {
|
||||
return strings.HasPrefix(string(h), "Negotiate")
|
||||
}
|
||||
|
||||
func (h authheader) IsNTLM() bool {
|
||||
return strings.HasPrefix(string(h), "NTLM")
|
||||
}
|
||||
|
||||
func (h authheader) GetData() ([]byte, error) {
|
||||
p := strings.Split(string(h), " ")
|
||||
if len(p) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(string(p[1]))
|
||||
}
|
||||
|
||||
func (h authheader) GetBasicCreds() (username, password string, err error) {
|
||||
d, err := h.GetData()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
parts := strings.SplitN(string(d), ":", 2)
|
||||
return parts[0], parts[1], nil
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package ntlmssp
|
||||
|
||||
type avID uint16
|
||||
|
||||
const (
|
||||
avIDMsvAvEOL avID = iota
|
||||
avIDMsvAvNbComputerName
|
||||
avIDMsvAvNbDomainName
|
||||
avIDMsvAvDNSComputerName
|
||||
avIDMsvAvDNSDomainName
|
||||
avIDMsvAvDNSTreeName
|
||||
avIDMsvAvFlags
|
||||
avIDMsvAvTimestamp
|
||||
avIDMsvAvSingleHost
|
||||
avIDMsvAvTargetName
|
||||
avIDMsvChannelBindings
|
||||
)
|
|
@ -0,0 +1,82 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type challengeMessageFields struct {
|
||||
messageHeader
|
||||
TargetName varField
|
||||
NegotiateFlags negotiateFlags
|
||||
ServerChallenge [8]byte
|
||||
_ [8]byte
|
||||
TargetInfo varField
|
||||
}
|
||||
|
||||
func (m challengeMessageFields) IsValid() bool {
|
||||
return m.messageHeader.IsValid() && m.MessageType == 2
|
||||
}
|
||||
|
||||
type challengeMessage struct {
|
||||
challengeMessageFields
|
||||
TargetName string
|
||||
TargetInfo map[avID][]byte
|
||||
TargetInfoRaw []byte
|
||||
}
|
||||
|
||||
func (m *challengeMessage) UnmarshalBinary(data []byte) error {
|
||||
r := bytes.NewReader(data)
|
||||
err := binary.Read(r, binary.LittleEndian, &m.challengeMessageFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !m.challengeMessageFields.IsValid() {
|
||||
return fmt.Errorf("Message is not a valid challenge message: %+v", m.challengeMessageFields.messageHeader)
|
||||
}
|
||||
|
||||
if m.challengeMessageFields.TargetName.Len > 0 {
|
||||
m.TargetName, err = m.challengeMessageFields.TargetName.ReadStringFrom(data, m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if m.challengeMessageFields.TargetInfo.Len > 0 {
|
||||
d, err := m.challengeMessageFields.TargetInfo.ReadFrom(data)
|
||||
m.TargetInfoRaw = d
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.TargetInfo = make(map[avID][]byte)
|
||||
r := bytes.NewReader(d)
|
||||
for {
|
||||
var id avID
|
||||
var l uint16
|
||||
err = binary.Read(r, binary.LittleEndian, &id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if id == avIDMsvAvEOL {
|
||||
break
|
||||
}
|
||||
|
||||
err = binary.Read(r, binary.LittleEndian, &l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := make([]byte, l)
|
||||
n, err := r.Read(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != int(l) {
|
||||
return fmt.Errorf("Expected to read %d bytes, got only %d", l, n)
|
||||
}
|
||||
m.TargetInfo[id] = value
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
var signature = [8]byte{'N', 'T', 'L', 'M', 'S', 'S', 'P', 0}
|
||||
|
||||
type messageHeader struct {
|
||||
Signature [8]byte
|
||||
MessageType uint32
|
||||
}
|
||||
|
||||
func (h messageHeader) IsValid() bool {
|
||||
return bytes.Equal(h.Signature[:], signature[:]) &&
|
||||
h.MessageType > 0 && h.MessageType < 4
|
||||
}
|
||||
|
||||
func newMessageHeader(messageType uint32) messageHeader {
|
||||
return messageHeader{signature, messageType}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package ntlmssp
|
||||
|
||||
type negotiateFlags uint32
|
||||
|
||||
const (
|
||||
/*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE negotiateFlags = 1 << 0
|
||||
/*B*/ negotiateFlagNTLMNEGOTIATEOEM = 1 << 1
|
||||
/*C*/ negotiateFlagNTLMSSPREQUESTTARGET = 1 << 2
|
||||
|
||||
/*D*/
|
||||
negotiateFlagNTLMSSPNEGOTIATESIGN = 1 << 4
|
||||
/*E*/ negotiateFlagNTLMSSPNEGOTIATESEAL = 1 << 5
|
||||
/*F*/ negotiateFlagNTLMSSPNEGOTIATEDATAGRAM = 1 << 6
|
||||
/*G*/ negotiateFlagNTLMSSPNEGOTIATELMKEY = 1 << 7
|
||||
|
||||
/*H*/
|
||||
negotiateFlagNTLMSSPNEGOTIATENTLM = 1 << 9
|
||||
|
||||
/*J*/
|
||||
negotiateFlagANONYMOUS = 1 << 11
|
||||
/*K*/ negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED = 1 << 12
|
||||
/*L*/ negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED = 1 << 13
|
||||
|
||||
/*M*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN = 1 << 15
|
||||
/*N*/ negotiateFlagNTLMSSPTARGETTYPEDOMAIN = 1 << 16
|
||||
/*O*/ negotiateFlagNTLMSSPTARGETTYPESERVER = 1 << 17
|
||||
|
||||
/*P*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY = 1 << 19
|
||||
/*Q*/ negotiateFlagNTLMSSPNEGOTIATEIDENTIFY = 1 << 20
|
||||
|
||||
/*R*/
|
||||
negotiateFlagNTLMSSPREQUESTNONNTSESSIONKEY = 1 << 22
|
||||
/*S*/ negotiateFlagNTLMSSPNEGOTIATETARGETINFO = 1 << 23
|
||||
|
||||
/*T*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEVERSION = 1 << 25
|
||||
|
||||
/*U*/
|
||||
negotiateFlagNTLMSSPNEGOTIATE128 = 1 << 29
|
||||
/*V*/ negotiateFlagNTLMSSPNEGOTIATEKEYEXCH = 1 << 30
|
||||
/*W*/ negotiateFlagNTLMSSPNEGOTIATE56 = 1 << 31
|
||||
)
|
||||
|
||||
func (field negotiateFlags) Has(flags negotiateFlags) bool {
|
||||
return field&flags == flags
|
||||
}
|
||||
|
||||
func (field *negotiateFlags) Unset(flags negotiateFlags) {
|
||||
*field = *field ^ (*field & flags)
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type negotiateMessageFields struct {
|
||||
messageHeader
|
||||
NegotiateFlags negotiateFlags
|
||||
}
|
||||
|
||||
//NewNegotiateMessage creates a new NEGOTIATE message with the
|
||||
//flags that this package supports.
|
||||
func NewNegotiateMessage() []byte {
|
||||
m := negotiateMessageFields{
|
||||
messageHeader: newMessageHeader(1),
|
||||
}
|
||||
|
||||
m.NegotiateFlags = negotiateFlagNTLMSSPREQUESTTARGET |
|
||||
negotiateFlagNTLMSSPNEGOTIATENTLM |
|
||||
negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN |
|
||||
negotiateFlagNTLMSSPNEGOTIATEUNICODE
|
||||
|
||||
b := bytes.Buffer{}
|
||||
err := binary.Write(&b, binary.LittleEndian, &m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//Negotiator is a http.Roundtripper decorator that automatically
|
||||
//converts basic authentication to NTLM/Negotiate authentication when appropriate.
|
||||
type Negotiator struct{ http.RoundTripper }
|
||||
|
||||
//RoundTrip sends the request to the server, handling any authentication
|
||||
//re-sends as needed.
|
||||
func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
// Use default round tripper if not provided
|
||||
rt := l.RoundTripper
|
||||
if rt == nil {
|
||||
rt = http.DefaultTransport
|
||||
}
|
||||
// If it is not basic auth, just round trip the request as usual
|
||||
reqauth := authheader(req.Header.Get("Authorization"))
|
||||
if !reqauth.IsBasic() {
|
||||
return rt.RoundTrip(req)
|
||||
}
|
||||
// Save request body
|
||||
body := bytes.Buffer{}
|
||||
if req.Body != nil {
|
||||
_, err = body.ReadFrom(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Body.Close()
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
}
|
||||
// first try anonymous, in case the server still finds us
|
||||
// authenticated from previous traffic
|
||||
req.Header.Del("Authorization")
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
return res, err
|
||||
}
|
||||
|
||||
resauth := authheader(res.Header.Get("Www-Authenticate"))
|
||||
if !resauth.IsNegotiate() && !resauth.IsNTLM() {
|
||||
// Unauthorized, Negotiate not requested, let's try with basic auth
|
||||
req.Header.Set("Authorization", string(reqauth))
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
return res, err
|
||||
}
|
||||
resauth = authheader(res.Header.Get("Www-Authenticate"))
|
||||
}
|
||||
|
||||
if resauth.IsNegotiate() || resauth.IsNTLM() {
|
||||
// 401 with request:Basic and response:Negotiate
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
// recycle credentials
|
||||
u, p, err := reqauth.GetBasicCreds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// send negotiate
|
||||
negotiateMessage := NewNegotiateMessage()
|
||||
if resauth.IsNTLM() {
|
||||
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// receive challenge?
|
||||
resauth = authheader(res.Header.Get("Www-Authenticate"))
|
||||
challengeMessage, err := resauth.GetData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
|
||||
// Negotiation failed, let client deal with response
|
||||
return res, nil
|
||||
}
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
// send authenticate
|
||||
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resauth.IsNTLM() {
|
||||
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
// Package ntlmssp provides NTLM/Negotiate authentication over HTTP
|
||||
//
|
||||
// Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx,
|
||||
// implementation hints from http://davenport.sourceforge.net/ntlm.html .
|
||||
// This package only implements authentication, no key exchange or encryption. It
|
||||
// only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
|
||||
// This package implements NTLMv2.
|
||||
package ntlmssp
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"golang.org/x/crypto/md4"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getNtlmV2Hash(password, username, target string) []byte {
|
||||
return hmacMd5(getNtlmHash(password), toUnicode(strings.ToUpper(username)+target))
|
||||
}
|
||||
|
||||
func getNtlmHash(password string) []byte {
|
||||
hash := md4.New()
|
||||
hash.Write(toUnicode(password))
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge,
|
||||
timestamp, targetInfo []byte) []byte {
|
||||
|
||||
temp := []byte{1, 1, 0, 0, 0, 0, 0, 0}
|
||||
temp = append(temp, timestamp...)
|
||||
temp = append(temp, clientChallenge...)
|
||||
temp = append(temp, 0, 0, 0, 0)
|
||||
temp = append(temp, targetInfo...)
|
||||
temp = append(temp, 0, 0, 0, 0)
|
||||
|
||||
NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, temp)
|
||||
return append(NTProofStr, temp...)
|
||||
}
|
||||
|
||||
func computeLmV2Response(ntlmV2Hash, serverChallenge, clientChallenge []byte) []byte {
|
||||
return append(hmacMd5(ntlmV2Hash, serverChallenge, clientChallenge), clientChallenge...)
|
||||
}
|
||||
|
||||
func hmacMd5(key []byte, data ...[]byte) []byte {
|
||||
mac := hmac.New(md5.New, key)
|
||||
for _, d := range data {
|
||||
mac.Write(d)
|
||||
}
|
||||
return mac.Sum(nil)
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"unicode/utf16"
|
||||
)
|
||||
|
||||
// helper func's for dealing with Windows Unicode (UTF16LE)
|
||||
|
||||
func fromUnicode(d []byte) (string, error) {
|
||||
if len(d)%2 > 0 {
|
||||
return "", errors.New("Unicode (UTF 16 LE) specified, but uneven data length")
|
||||
}
|
||||
s := make([]uint16, len(d)/2)
|
||||
err := binary.Read(bytes.NewReader(d), binary.LittleEndian, &s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(utf16.Decode(s)), nil
|
||||
}
|
||||
|
||||
func toUnicode(s string) []byte {
|
||||
uints := utf16.Encode([]rune(s))
|
||||
b := bytes.Buffer{}
|
||||
binary.Write(&b, binary.LittleEndian, &uints)
|
||||
return b.Bytes()
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type varField struct {
|
||||
Len uint16
|
||||
MaxLen uint16
|
||||
BufferOffset uint32
|
||||
}
|
||||
|
||||
func (f varField) ReadFrom(buffer []byte) ([]byte, error) {
|
||||
if len(buffer) < int(f.BufferOffset+uint32(f.Len)) {
|
||||
return nil, errors.New("Error reading data, varField extends beyond buffer")
|
||||
}
|
||||
return buffer[f.BufferOffset : f.BufferOffset+uint32(f.Len)], nil
|
||||
}
|
||||
|
||||
func (f varField) ReadStringFrom(buffer []byte, unicode bool) (string, error) {
|
||||
d, err := f.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if unicode { // UTF-16LE encoding scheme
|
||||
return fromUnicode(d)
|
||||
}
|
||||
// OEM encoding, close enough to ASCII, since no code page is specified
|
||||
return string(d), err
|
||||
}
|
||||
|
||||
func newVarField(ptr *int, fieldsize int) varField {
|
||||
f := varField{
|
||||
Len: uint16(fieldsize),
|
||||
MaxLen: uint16(fieldsize),
|
||||
BufferOffset: uint32(*ptr),
|
||||
}
|
||||
*ptr += fieldsize
|
||||
return f
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 ChrisTrenkamp
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# goxpath [![GoDoc](https://godoc.org/gopkg.in/src-d/go-git.v2?status.svg)](https://godoc.org/github.com/ChrisTrenkamp/goxpath) [![Build Status](https://travis-ci.org/ChrisTrenkamp/goxpath.svg?branch=master)](https://travis-ci.org/ChrisTrenkamp/goxpath) [![codecov.io](https://codecov.io/github/ChrisTrenkamp/goxpath/coverage.svg?branch=master)](https://codecov.io/github/ChrisTrenkamp/goxpath?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/ChrisTrenkamp/goxpath)](https://goreportcard.com/report/github.com/ChrisTrenkamp/goxpath)
|
||||
An XPath 1.0 implementation written in Go. See the [wiki](https://github.com/ChrisTrenkamp/goxpath/wiki) for more information.
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
go get github.com/ChrisTrenkamp/goxpath/cmd/goxpath
|
||||
if [ $? != 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
go test >/dev/null 2>&1
|
||||
if [ $? != 0 ]; then
|
||||
go test
|
||||
exit 1
|
||||
fi
|
||||
gometalinter --deadline=1m ./...
|
||||
go list -f '{{if gt (len .TestGoFiles) 0}}"go test -covermode count -coverprofile {{.Name}}.coverprofile -coverpkg ./... {{.ImportPath}}"{{end}} >/dev/null' ./... | xargs -I {} bash -c {} 2>/dev/null
|
||||
gocovmerge `ls *.coverprofile` > coverage.txt
|
||||
go tool cover -html=coverage.txt -o coverage.html
|
||||
firefox coverage.html
|
||||
rm coverage.html coverage.txt *.coverprofile
|
|
@ -0,0 +1,117 @@
|
|||
package goxpath
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/execxp"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
//Opts defines namespace mappings and custom functions for XPath expressions.
|
||||
type Opts struct {
|
||||
NS map[string]string
|
||||
Funcs map[xml.Name]tree.Wrap
|
||||
Vars map[string]tree.Result
|
||||
}
|
||||
|
||||
//FuncOpts is a function wrapper for Opts.
|
||||
type FuncOpts func(*Opts)
|
||||
|
||||
//XPathExec is the XPath executor, compiled from an XPath string
|
||||
type XPathExec struct {
|
||||
n *parser.Node
|
||||
}
|
||||
|
||||
//Parse parses the XPath expression, xp, returning an XPath executor.
|
||||
func Parse(xp string) (XPathExec, error) {
|
||||
n, err := parser.Parse(xp)
|
||||
return XPathExec{n: n}, err
|
||||
}
|
||||
|
||||
//MustParse is like Parse, but panics instead of returning an error.
|
||||
func MustParse(xp string) XPathExec {
|
||||
ret, err := Parse(xp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
//Exec executes the XPath expression, xp, against the tree, t, with the
|
||||
//namespace mappings, ns, and returns the result as a stringer.
|
||||
func (xp XPathExec) Exec(t tree.Node, opts ...FuncOpts) (tree.Result, error) {
|
||||
o := &Opts{
|
||||
NS: make(map[string]string),
|
||||
Funcs: make(map[xml.Name]tree.Wrap),
|
||||
Vars: make(map[string]tree.Result),
|
||||
}
|
||||
for _, i := range opts {
|
||||
i(o)
|
||||
}
|
||||
return execxp.Exec(xp.n, t, o.NS, o.Funcs, o.Vars)
|
||||
}
|
||||
|
||||
//ExecBool is like Exec, except it will attempt to convert the result to its boolean value.
|
||||
func (xp XPathExec) ExecBool(t tree.Node, opts ...FuncOpts) (bool, error) {
|
||||
res, err := xp.Exec(t, opts...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
b, ok := res.(tree.IsBool)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("Cannot convert result to a boolean")
|
||||
}
|
||||
|
||||
return bool(b.Bool()), nil
|
||||
}
|
||||
|
||||
//ExecNum is like Exec, except it will attempt to convert the result to its number value.
|
||||
func (xp XPathExec) ExecNum(t tree.Node, opts ...FuncOpts) (float64, error) {
|
||||
res, err := xp.Exec(t, opts...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n, ok := res.(tree.IsNum)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Cannot convert result to a number")
|
||||
}
|
||||
|
||||
return float64(n.Num()), nil
|
||||
}
|
||||
|
||||
//ExecNode is like Exec, except it will attempt to return the result as a node-set.
|
||||
func (xp XPathExec) ExecNode(t tree.Node, opts ...FuncOpts) (tree.NodeSet, error) {
|
||||
res, err := xp.Exec(t, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, ok := res.(tree.NodeSet)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert result to a node-set")
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
//MustExec is like Exec, but panics instead of returning an error.
|
||||
func (xp XPathExec) MustExec(t tree.Node, opts ...FuncOpts) tree.Result {
|
||||
res, err := xp.Exec(t, opts...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
//ParseExec parses the XPath string, xpstr, and runs Exec.
|
||||
func ParseExec(xpstr string, t tree.Node, opts ...FuncOpts) (tree.Result, error) {
|
||||
xp, err := Parse(xpstr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return xp.Exec(t, opts...)
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package execxp
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
//Exec executes the XPath expression, xp, against the tree, t, with the
|
||||
//namespace mappings, ns.
|
||||
func Exec(n *parser.Node, t tree.Node, ns map[string]string, fns map[xml.Name]tree.Wrap, v map[string]tree.Result) (tree.Result, error) {
|
||||
f := xpFilt{
|
||||
t: t,
|
||||
ns: ns,
|
||||
ctx: tree.NodeSet{t},
|
||||
fns: fns,
|
||||
variables: v,
|
||||
}
|
||||
|
||||
return exec(&f, n)
|
||||
}
|
||||
|
||||
func exec(f *xpFilt, n *parser.Node) (tree.Result, error) {
|
||||
err := xfExec(f, n)
|
||||
return f.ctx, err
|
||||
}
|
212
vendor/github.com/ChrisTrenkamp/goxpath/internal/execxp/operators.go
generated
vendored
Normal file
212
vendor/github.com/ChrisTrenkamp/goxpath/internal/execxp/operators.go
generated
vendored
Normal file
|
@ -0,0 +1,212 @@
|
|||
package execxp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
func bothNodeOperator(left tree.NodeSet, right tree.NodeSet, f *xpFilt, op string) error {
|
||||
var err error
|
||||
for _, l := range left {
|
||||
for _, r := range right {
|
||||
lStr := l.ResValue()
|
||||
rStr := r.ResValue()
|
||||
|
||||
if eqOps[op] {
|
||||
err = equalsOperator(tree.String(lStr), tree.String(rStr), f, op)
|
||||
if err == nil && f.ctx.String() == tree.True {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
err = numberOperator(tree.String(lStr), tree.String(rStr), f, op)
|
||||
if err == nil && f.ctx.String() == tree.True {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.ctx = tree.Bool(false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func leftNodeOperator(left tree.NodeSet, right tree.Result, f *xpFilt, op string) error {
|
||||
var err error
|
||||
for _, l := range left {
|
||||
lStr := l.ResValue()
|
||||
|
||||
if eqOps[op] {
|
||||
err = equalsOperator(tree.String(lStr), right, f, op)
|
||||
if err == nil && f.ctx.String() == tree.True {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
err = numberOperator(tree.String(lStr), right, f, op)
|
||||
if err == nil && f.ctx.String() == tree.True {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.ctx = tree.Bool(false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func rightNodeOperator(left tree.Result, right tree.NodeSet, f *xpFilt, op string) error {
|
||||
var err error
|
||||
for _, r := range right {
|
||||
rStr := r.ResValue()
|
||||
|
||||
if eqOps[op] {
|
||||
err = equalsOperator(left, tree.String(rStr), f, op)
|
||||
if err == nil && f.ctx.String() == "true" {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
err = numberOperator(left, tree.String(rStr), f, op)
|
||||
if err == nil && f.ctx.String() == "true" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.ctx = tree.Bool(false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func equalsOperator(left, right tree.Result, f *xpFilt, op string) error {
|
||||
_, lOK := left.(tree.Bool)
|
||||
_, rOK := right.(tree.Bool)
|
||||
|
||||
if lOK || rOK {
|
||||
lTest, lt := left.(tree.IsBool)
|
||||
rTest, rt := right.(tree.IsBool)
|
||||
if !lt || !rt {
|
||||
return fmt.Errorf("Cannot convert argument to boolean")
|
||||
}
|
||||
|
||||
if op == "=" {
|
||||
f.ctx = tree.Bool(lTest.Bool() == rTest.Bool())
|
||||
} else {
|
||||
f.ctx = tree.Bool(lTest.Bool() != rTest.Bool())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
_, lOK = left.(tree.Num)
|
||||
_, rOK = right.(tree.Num)
|
||||
if lOK || rOK {
|
||||
return numberOperator(left, right, f, op)
|
||||
}
|
||||
|
||||
lStr := left.String()
|
||||
rStr := right.String()
|
||||
|
||||
if op == "=" {
|
||||
f.ctx = tree.Bool(lStr == rStr)
|
||||
} else {
|
||||
f.ctx = tree.Bool(lStr != rStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func numberOperator(left, right tree.Result, f *xpFilt, op string) error {
|
||||
lt, lOK := left.(tree.IsNum)
|
||||
rt, rOK := right.(tree.IsNum)
|
||||
if !lOK || !rOK {
|
||||
return fmt.Errorf("Cannot convert data type to number")
|
||||
}
|
||||
|
||||
ln, rn := lt.Num(), rt.Num()
|
||||
|
||||
switch op {
|
||||
case "*":
|
||||
f.ctx = ln * rn
|
||||
case "div":
|
||||
if rn != 0 {
|
||||
f.ctx = ln / rn
|
||||
} else {
|
||||
if ln == 0 {
|
||||
f.ctx = tree.Num(math.NaN())
|
||||
} else {
|
||||
if math.Signbit(float64(ln)) == math.Signbit(float64(rn)) {
|
||||
f.ctx = tree.Num(math.Inf(1))
|
||||
} else {
|
||||
f.ctx = tree.Num(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
}
|
||||
case "mod":
|
||||
f.ctx = tree.Num(int(ln) % int(rn))
|
||||
case "+":
|
||||
f.ctx = ln + rn
|
||||
case "-":
|
||||
f.ctx = ln - rn
|
||||
case "=":
|
||||
f.ctx = tree.Bool(ln == rn)
|
||||
case "!=":
|
||||
f.ctx = tree.Bool(ln != rn)
|
||||
case "<":
|
||||
f.ctx = tree.Bool(ln < rn)
|
||||
case "<=":
|
||||
f.ctx = tree.Bool(ln <= rn)
|
||||
case ">":
|
||||
f.ctx = tree.Bool(ln > rn)
|
||||
case ">=":
|
||||
f.ctx = tree.Bool(ln >= rn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func andOrOperator(left, right tree.Result, f *xpFilt, op string) error {
|
||||
lt, lOK := left.(tree.IsBool)
|
||||
rt, rOK := right.(tree.IsBool)
|
||||
|
||||
if !lOK || !rOK {
|
||||
return fmt.Errorf("Cannot convert argument to boolean")
|
||||
}
|
||||
|
||||
l, r := lt.Bool(), rt.Bool()
|
||||
|
||||
if op == "and" {
|
||||
f.ctx = l && r
|
||||
} else {
|
||||
f.ctx = l || r
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unionOperator(left, right tree.Result, f *xpFilt, op string) error {
|
||||
lNode, lOK := left.(tree.NodeSet)
|
||||
rNode, rOK := right.(tree.NodeSet)
|
||||
|
||||
if !lOK || !rOK {
|
||||
return fmt.Errorf("Cannot convert data type to node-set")
|
||||
}
|
||||
|
||||
uniq := make(map[int]tree.Node)
|
||||
for _, i := range lNode {
|
||||
uniq[i.Pos()] = i
|
||||
}
|
||||
for _, i := range rNode {
|
||||
uniq[i.Pos()] = i
|
||||
}
|
||||
|
||||
res := make(tree.NodeSet, 0, len(uniq))
|
||||
for _, v := range uniq {
|
||||
res = append(res, v)
|
||||
}
|
||||
|
||||
f.ctx = res
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,397 @@
|
|||
package execxp
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser/findutil"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser/intfns"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/xconst"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/xsort"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/lexer"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser/pathexpr"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
type xpFilt struct {
|
||||
t tree.Node
|
||||
ctx tree.Result
|
||||
expr pathexpr.PathExpr
|
||||
ns map[string]string
|
||||
ctxPos int
|
||||
ctxSize int
|
||||
proxPos map[int]int
|
||||
fns map[xml.Name]tree.Wrap
|
||||
variables map[string]tree.Result
|
||||
}
|
||||
|
||||
type xpExecFn func(*xpFilt, string)
|
||||
|
||||
var xpFns = map[lexer.XItemType]xpExecFn{
|
||||
lexer.XItemAbsLocPath: xfAbsLocPath,
|
||||
lexer.XItemAbbrAbsLocPath: xfAbbrAbsLocPath,
|
||||
lexer.XItemRelLocPath: xfRelLocPath,
|
||||
lexer.XItemAbbrRelLocPath: xfAbbrRelLocPath,
|
||||
lexer.XItemAxis: xfAxis,
|
||||
lexer.XItemAbbrAxis: xfAbbrAxis,
|
||||
lexer.XItemNCName: xfNCName,
|
||||
lexer.XItemQName: xfQName,
|
||||
lexer.XItemNodeType: xfNodeType,
|
||||
lexer.XItemProcLit: xfProcInstLit,
|
||||
lexer.XItemStrLit: xfStrLit,
|
||||
lexer.XItemNumLit: xfNumLit,
|
||||
}
|
||||
|
||||
func xfExec(f *xpFilt, n *parser.Node) (err error) {
|
||||
for n != nil {
|
||||
if fn, ok := xpFns[n.Val.Typ]; ok {
|
||||
fn(f, n.Val.Val)
|
||||
n = n.Left
|
||||
} else if n.Val.Typ == lexer.XItemPredicate {
|
||||
if err = xfPredicate(f, n.Left); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n = n.Right
|
||||
} else if n.Val.Typ == lexer.XItemFunction {
|
||||
if err = xfFunction(f, n); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n = n.Right
|
||||
} else if n.Val.Typ == lexer.XItemOperator {
|
||||
lf := xpFilt{
|
||||
t: f.t,
|
||||
ns: f.ns,
|
||||
ctx: f.ctx,
|
||||
ctxPos: f.ctxPos,
|
||||
ctxSize: f.ctxSize,
|
||||
proxPos: f.proxPos,
|
||||
fns: f.fns,
|
||||
variables: f.variables,
|
||||
}
|
||||
left, err := exec(&lf, n.Left)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rf := xpFilt{
|
||||
t: f.t,
|
||||
ns: f.ns,
|
||||
ctx: f.ctx,
|
||||
fns: f.fns,
|
||||
variables: f.variables,
|
||||
}
|
||||
right, err := exec(&rf, n.Right)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return xfOperator(left, right, f, n.Val.Val)
|
||||
} else if n.Val.Typ == lexer.XItemVariable {
|
||||
if res, ok := f.variables[n.Val.Val]; ok {
|
||||
f.ctx = res
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Invalid variable '%s'", n.Val.Val)
|
||||
} else if string(n.Val.Typ) == "" {
|
||||
n = n.Left
|
||||
//} else {
|
||||
// return fmt.Errorf("Cannot process " + string(n.Val.Typ))
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func xfPredicate(f *xpFilt, n *parser.Node) (err error) {
|
||||
res := f.ctx.(tree.NodeSet)
|
||||
newRes := make(tree.NodeSet, 0, len(res))
|
||||
|
||||
for i := range res {
|
||||
pf := xpFilt{
|
||||
t: f.t,
|
||||
ns: f.ns,
|
||||
ctxPos: i,
|
||||
ctxSize: f.ctxSize,
|
||||
ctx: tree.NodeSet{res[i]},
|
||||
fns: f.fns,
|
||||
variables: f.variables,
|
||||
}
|
||||
|
||||
predRes, err := exec(&pf, n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ok, err := checkPredRes(predRes, f, res[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok {
|
||||
newRes = append(newRes, res[i])
|
||||
}
|
||||
}
|
||||
|
||||
f.proxPos = make(map[int]int)
|
||||
for pos, j := range newRes {
|
||||
f.proxPos[j.Pos()] = pos + 1
|
||||
}
|
||||
|
||||
f.ctx = newRes
|
||||
f.ctxSize = len(newRes)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func checkPredRes(ret tree.Result, f *xpFilt, node tree.Node) (bool, error) {
|
||||
if num, ok := ret.(tree.Num); ok {
|
||||
if float64(f.proxPos[node.Pos()]) == float64(num) {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if b, ok := ret.(tree.IsBool); ok {
|
||||
return bool(b.Bool()), nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("Cannot convert argument to boolean")
|
||||
}
|
||||
|
||||
func xfFunction(f *xpFilt, n *parser.Node) error {
|
||||
spl := strings.Split(n.Val.Val, ":")
|
||||
var name xml.Name
|
||||
if len(spl) == 1 {
|
||||
name.Local = spl[0]
|
||||
} else {
|
||||
name.Space = f.ns[spl[0]]
|
||||
name.Local = spl[1]
|
||||
}
|
||||
fn, ok := intfns.BuiltIn[name]
|
||||
if !ok {
|
||||
fn, ok = f.fns[name]
|
||||
}
|
||||
|
||||
if ok {
|
||||
args := []tree.Result{}
|
||||
param := n.Left
|
||||
|
||||
for param != nil {
|
||||
pf := xpFilt{
|
||||
t: f.t,
|
||||
ctx: f.ctx,
|
||||
ns: f.ns,
|
||||
ctxPos: f.ctxPos,
|
||||
ctxSize: f.ctxSize,
|
||||
fns: f.fns,
|
||||
variables: f.variables,
|
||||
}
|
||||
res, err := exec(&pf, param.Left)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args = append(args, res)
|
||||
param = param.Right
|
||||
}
|
||||
|
||||
filt, err := fn.Call(tree.Ctx{NodeSet: f.ctx.(tree.NodeSet), Size: f.ctxSize, Pos: f.ctxPos + 1}, args...)
|
||||
f.ctx = filt
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("Unknown function: %s", n.Val.Val)
|
||||
}
|
||||
|
||||
var eqOps = map[string]bool{
|
||||
"=": true,
|
||||
"!=": true,
|
||||
}
|
||||
|
||||
var booleanOps = map[string]bool{
|
||||
"=": true,
|
||||
"!=": true,
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
}
|
||||
|
||||
var numOps = map[string]bool{
|
||||
"*": true,
|
||||
"div": true,
|
||||
"mod": true,
|
||||
"+": true,
|
||||
"-": true,
|
||||
"=": true,
|
||||
"!=": true,
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
}
|
||||
|
||||
var andOrOps = map[string]bool{
|
||||
"and": true,
|
||||
"or": true,
|
||||
}
|
||||
|
||||
func xfOperator(left, right tree.Result, f *xpFilt, op string) error {
|
||||
if booleanOps[op] {
|
||||
lNode, lOK := left.(tree.NodeSet)
|
||||
rNode, rOK := right.(tree.NodeSet)
|
||||
if lOK && rOK {
|
||||
return bothNodeOperator(lNode, rNode, f, op)
|
||||
}
|
||||
|
||||
if lOK {
|
||||
return leftNodeOperator(lNode, right, f, op)
|
||||
}
|
||||
|
||||
if rOK {
|
||||
return rightNodeOperator(left, rNode, f, op)
|
||||
}
|
||||
|
||||
if eqOps[op] {
|
||||
return equalsOperator(left, right, f, op)
|
||||
}
|
||||
}
|
||||
|
||||
if numOps[op] {
|
||||
return numberOperator(left, right, f, op)
|
||||
}
|
||||
|
||||
if andOrOps[op] {
|
||||
return andOrOperator(left, right, f, op)
|
||||
}
|
||||
|
||||
//if op == "|" {
|
||||
return unionOperator(left, right, f, op)
|
||||
//}
|
||||
|
||||
//return fmt.Errorf("Unknown operator " + op)
|
||||
}
|
||||
|
||||
func xfAbsLocPath(f *xpFilt, val string) {
|
||||
i := f.t
|
||||
for i.GetNodeType() != tree.NtRoot {
|
||||
i = i.GetParent()
|
||||
}
|
||||
f.ctx = tree.NodeSet{i}
|
||||
}
|
||||
|
||||
func xfAbbrAbsLocPath(f *xpFilt, val string) {
|
||||
i := f.t
|
||||
for i.GetNodeType() != tree.NtRoot {
|
||||
i = i.GetParent()
|
||||
}
|
||||
f.ctx = tree.NodeSet{i}
|
||||
f.expr = abbrPathExpr()
|
||||
find(f)
|
||||
}
|
||||
|
||||
func xfRelLocPath(f *xpFilt, val string) {
|
||||
}
|
||||
|
||||
func xfAbbrRelLocPath(f *xpFilt, val string) {
|
||||
f.expr = abbrPathExpr()
|
||||
find(f)
|
||||
}
|
||||
|
||||
func xfAxis(f *xpFilt, val string) {
|
||||
f.expr.Axis = val
|
||||
}
|
||||
|
||||
func xfAbbrAxis(f *xpFilt, val string) {
|
||||
f.expr.Axis = xconst.AxisAttribute
|
||||
}
|
||||
|
||||
func xfNCName(f *xpFilt, val string) {
|
||||
f.expr.Name.Space = val
|
||||
}
|
||||
|
||||
func xfQName(f *xpFilt, val string) {
|
||||
f.expr.Name.Local = val
|
||||
find(f)
|
||||
}
|
||||
|
||||
func xfNodeType(f *xpFilt, val string) {
|
||||
f.expr.NodeType = val
|
||||
find(f)
|
||||
}
|
||||
|
||||
func xfProcInstLit(f *xpFilt, val string) {
|
||||
filt := tree.NodeSet{}
|
||||
for _, i := range f.ctx.(tree.NodeSet) {
|
||||
if i.GetToken().(xml.ProcInst).Target == val {
|
||||
filt = append(filt, i)
|
||||
}
|
||||
}
|
||||
f.ctx = filt
|
||||
}
|
||||
|
||||
func xfStrLit(f *xpFilt, val string) {
|
||||
f.ctx = tree.String(val)
|
||||
}
|
||||
|
||||
func xfNumLit(f *xpFilt, val string) {
|
||||
num, _ := strconv.ParseFloat(val, 64)
|
||||
f.ctx = tree.Num(num)
|
||||
}
|
||||
|
||||
func abbrPathExpr() pathexpr.PathExpr {
|
||||
return pathexpr.PathExpr{
|
||||
Name: xml.Name{},
|
||||
Axis: xconst.AxisDescendentOrSelf,
|
||||
NodeType: xconst.NodeTypeNode,
|
||||
}
|
||||
}
|
||||
|
||||
func find(f *xpFilt) {
|
||||
dupFilt := make(map[int]tree.Node)
|
||||
f.proxPos = make(map[int]int)
|
||||
|
||||
if f.expr.Axis == "" && f.expr.NodeType == "" && f.expr.Name.Space == "" {
|
||||
if f.expr.Name.Local == "." {
|
||||
f.expr = pathexpr.PathExpr{
|
||||
Name: xml.Name{},
|
||||
Axis: xconst.AxisSelf,
|
||||
NodeType: xconst.NodeTypeNode,
|
||||
}
|
||||
}
|
||||
|
||||
if f.expr.Name.Local == ".." {
|
||||
f.expr = pathexpr.PathExpr{
|
||||
Name: xml.Name{},
|
||||
Axis: xconst.AxisParent,
|
||||
NodeType: xconst.NodeTypeNode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.expr.NS = f.ns
|
||||
|
||||
for _, i := range f.ctx.(tree.NodeSet) {
|
||||
for pos, j := range findutil.Find(i, f.expr) {
|
||||
dupFilt[j.Pos()] = j
|
||||
f.proxPos[j.Pos()] = pos + 1
|
||||
}
|
||||
}
|
||||
|
||||
res := make(tree.NodeSet, 0, len(dupFilt))
|
||||
for _, i := range dupFilt {
|
||||
res = append(res, i)
|
||||
}
|
||||
|
||||
xsort.SortNodes(res)
|
||||
|
||||
f.expr = pathexpr.PathExpr{}
|
||||
f.ctxSize = len(res)
|
||||
f.ctx = res
|
||||
}
|
|
@ -0,0 +1,419 @@
|
|||
package lexer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
//XItemError is an error with the parser input
|
||||
XItemError XItemType = "Error"
|
||||
//XItemAbsLocPath is an absolute path
|
||||
XItemAbsLocPath = "Absolute path"
|
||||
//XItemAbbrAbsLocPath represents an abbreviated absolute path
|
||||
XItemAbbrAbsLocPath = "Abbreviated absolute path"
|
||||
//XItemAbbrRelLocPath marks the start of a path expression
|
||||
XItemAbbrRelLocPath = "Abbreviated relative path"
|
||||
//XItemRelLocPath represents a relative location path
|
||||
XItemRelLocPath = "Relative path"
|
||||
//XItemEndPath marks the end of a path
|
||||
XItemEndPath = "End path instruction"
|
||||
//XItemAxis marks an axis specifier of a path
|
||||
XItemAxis = "Axis"
|
||||
//XItemAbbrAxis marks an abbreviated axis specifier (just @ at this point)
|
||||
XItemAbbrAxis = "Abbreviated attribute axis"
|
||||
//XItemNCName marks a namespace name in a node test
|
||||
XItemNCName = "Namespace"
|
||||
//XItemQName marks the local name in an a node test
|
||||
XItemQName = "Local name"
|
||||
//XItemNodeType marks a node type in a node test
|
||||
XItemNodeType = "Node type"
|
||||
//XItemProcLit marks a processing-instruction literal
|
||||
XItemProcLit = "processing-instruction"
|
||||
//XItemFunction marks a function call
|
||||
XItemFunction = "function"
|
||||
//XItemArgument marks a function argument
|
||||
XItemArgument = "function argument"
|
||||
//XItemEndFunction marks the end of a function
|
||||
XItemEndFunction = "end of function"
|
||||
//XItemPredicate marks a predicate in an axis
|
||||
XItemPredicate = "predicate"
|
||||
//XItemEndPredicate marks a predicate in an axis
|
||||
XItemEndPredicate = "end of predicate"
|
||||
//XItemStrLit marks a string literal
|
||||
XItemStrLit = "string literal"
|
||||
//XItemNumLit marks a numeric literal
|
||||
XItemNumLit = "numeric literal"
|
||||
//XItemOperator marks an operator
|
||||
XItemOperator = "operator"
|
||||
//XItemVariable marks a variable reference
|
||||
XItemVariable = "variable"
|
||||
)
|
||||
|
||||
const (
|
||||
eof = -(iota + 1)
|
||||
)
|
||||
|
||||
//XItemType is the parser token types
|
||||
type XItemType string
|
||||
|
||||
//XItem is the token emitted from the parser
|
||||
type XItem struct {
|
||||
Typ XItemType
|
||||
Val string
|
||||
}
|
||||
|
||||
type stateFn func(*Lexer) stateFn
|
||||
|
||||
//Lexer lexes out XPath expressions
|
||||
type Lexer struct {
|
||||
input string
|
||||
start int
|
||||
pos int
|
||||
width int
|
||||
items chan XItem
|
||||
}
|
||||
|
||||
//Lex an XPath expresion on the io.Reader
|
||||
func Lex(xpath string) chan XItem {
|
||||
l := &Lexer{
|
||||
input: xpath,
|
||||
items: make(chan XItem),
|
||||
}
|
||||
go l.run()
|
||||
return l.items
|
||||
}
|
||||
|
||||
func (l *Lexer) run() {
|
||||
for state := startState; state != nil; {
|
||||
state = state(l)
|
||||
}
|
||||
|
||||
if l.peek() != eof {
|
||||
l.errorf("Malformed XPath expression")
|
||||
}
|
||||
|
||||
close(l.items)
|
||||
}
|
||||
|
||||
func (l *Lexer) emit(t XItemType) {
|
||||
l.items <- XItem{t, l.input[l.start:l.pos]}
|
||||
l.start = l.pos
|
||||
}
|
||||
|
||||
func (l *Lexer) emitVal(t XItemType, val string) {
|
||||
l.items <- XItem{t, val}
|
||||
l.start = l.pos
|
||||
}
|
||||
|
||||
func (l *Lexer) next() (r rune) {
|
||||
if l.pos >= len(l.input) {
|
||||
l.width = 0
|
||||
return eof
|
||||
}
|
||||
|
||||
r, l.width = utf8.DecodeRuneInString(l.input[l.pos:])
|
||||
|
||||
l.pos += l.width
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (l *Lexer) ignore() {
|
||||
l.start = l.pos
|
||||
}
|
||||
|
||||
func (l *Lexer) backup() {
|
||||
l.pos -= l.width
|
||||
}
|
||||
|
||||
func (l *Lexer) peek() rune {
|
||||
r := l.next()
|
||||
|
||||
l.backup()
|
||||
return r
|
||||
}
|
||||
|
||||
func (l *Lexer) peekAt(n int) rune {
|
||||
if n <= 1 {
|
||||
return l.peek()
|
||||
}
|
||||
|
||||
width := 0
|
||||
var ret rune
|
||||
|
||||
for count := 0; count < n; count++ {
|
||||
r, s := utf8.DecodeRuneInString(l.input[l.pos+width:])
|
||||
width += s
|
||||
|
||||
if l.pos+width > len(l.input) {
|
||||
return eof
|
||||
}
|
||||
|
||||
ret = r
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (l *Lexer) accept(valid string) bool {
|
||||
if strings.ContainsRune(valid, l.next()) {
|
||||
return true
|
||||
}
|
||||
|
||||
l.backup()
|
||||
return false
|
||||
}
|
||||
|
||||
func (l *Lexer) acceptRun(valid string) {
|
||||
for strings.ContainsRune(valid, l.next()) {
|
||||
}
|
||||
l.backup()
|
||||
}
|
||||
|
||||
func (l *Lexer) skip(num int) {
|
||||
for i := 0; i < num; i++ {
|
||||
l.next()
|
||||
}
|
||||
l.ignore()
|
||||
}
|
||||
|
||||
func (l *Lexer) skipWS(ig bool) {
|
||||
for {
|
||||
n := l.next()
|
||||
|
||||
if n == eof || !unicode.IsSpace(n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
l.backup()
|
||||
|
||||
if ig {
|
||||
l.ignore()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lexer) errorf(format string, args ...interface{}) stateFn {
|
||||
l.items <- XItem{
|
||||
XItemError,
|
||||
fmt.Sprintf(format, args...),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isElemChar(r rune) bool {
|
||||
return string(r) != ":" && string(r) != "/" &&
|
||||
(unicode.Is(first, r) || unicode.Is(second, r) || string(r) == "*") &&
|
||||
r != eof
|
||||
}
|
||||
|
||||
func startState(l *Lexer) stateFn {
|
||||
l.skipWS(true)
|
||||
|
||||
if string(l.peek()) == "/" {
|
||||
l.next()
|
||||
l.ignore()
|
||||
|
||||
if string(l.next()) == "/" {
|
||||
l.ignore()
|
||||
return abbrAbsLocPathState
|
||||
}
|
||||
|
||||
l.backup()
|
||||
return absLocPathState
|
||||
} else if string(l.peek()) == `'` || string(l.peek()) == `"` {
|
||||
if err := getStrLit(l, XItemStrLit); err != nil {
|
||||
return l.errorf(err.Error())
|
||||
}
|
||||
|
||||
if l.peek() != eof {
|
||||
return startState
|
||||
}
|
||||
} else if getNumLit(l) {
|
||||
l.skipWS(true)
|
||||
if l.peek() != eof {
|
||||
return startState
|
||||
}
|
||||
} else if string(l.peek()) == "$" {
|
||||
l.next()
|
||||
l.ignore()
|
||||
r := l.peek()
|
||||
for unicode.Is(first, r) || unicode.Is(second, r) {
|
||||
l.next()
|
||||
r = l.peek()
|
||||
}
|
||||
tok := l.input[l.start:l.pos]
|
||||
if len(tok) == 0 {
|
||||
return l.errorf("Empty variable name")
|
||||
}
|
||||
l.emit(XItemVariable)
|
||||
l.skipWS(true)
|
||||
if l.peek() != eof {
|
||||
return startState
|
||||
}
|
||||
} else if st := findOperatorState(l); st != nil {
|
||||
return st
|
||||
} else {
|
||||
if isElemChar(l.peek()) {
|
||||
colons := 0
|
||||
|
||||
for {
|
||||
if isElemChar(l.peek()) {
|
||||
l.next()
|
||||
} else if string(l.peek()) == ":" {
|
||||
l.next()
|
||||
colons++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if string(l.peek()) == "(" && colons <= 1 {
|
||||
tok := l.input[l.start:l.pos]
|
||||
err := procFunc(l, tok)
|
||||
if err != nil {
|
||||
return l.errorf(err.Error())
|
||||
}
|
||||
|
||||
l.skipWS(true)
|
||||
|
||||
if string(l.peek()) == "/" {
|
||||
l.next()
|
||||
l.ignore()
|
||||
|
||||
if string(l.next()) == "/" {
|
||||
l.ignore()
|
||||
return abbrRelLocPathState
|
||||
}
|
||||
|
||||
l.backup()
|
||||
return relLocPathState
|
||||
}
|
||||
|
||||
return startState
|
||||
}
|
||||
|
||||
l.pos = l.start
|
||||
return relLocPathState
|
||||
} else if string(l.peek()) == "@" {
|
||||
return relLocPathState
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func strPeek(str string, l *Lexer) bool {
|
||||
for i := 0; i < len(str); i++ {
|
||||
if string(l.peekAt(i+1)) != string(str[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func findOperatorState(l *Lexer) stateFn {
|
||||
l.skipWS(true)
|
||||
|
||||
switch string(l.peek()) {
|
||||
case ">", "<", "!":
|
||||
l.next()
|
||||
if string(l.peek()) == "=" {
|
||||
l.next()
|
||||
}
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
case "|", "+", "-", "*", "=":
|
||||
l.next()
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
case "(":
|
||||
l.next()
|
||||
l.emit(XItemOperator)
|
||||
for state := startState; state != nil; {
|
||||
state = state(l)
|
||||
}
|
||||
l.skipWS(true)
|
||||
if string(l.next()) != ")" {
|
||||
return l.errorf("Missing end )")
|
||||
}
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
}
|
||||
|
||||
if strPeek("and", l) {
|
||||
l.next()
|
||||
l.next()
|
||||
l.next()
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
}
|
||||
|
||||
if strPeek("or", l) {
|
||||
l.next()
|
||||
l.next()
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
}
|
||||
|
||||
if strPeek("mod", l) {
|
||||
l.next()
|
||||
l.next()
|
||||
l.next()
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
}
|
||||
|
||||
if strPeek("div", l) {
|
||||
l.next()
|
||||
l.next()
|
||||
l.next()
|
||||
l.emit(XItemOperator)
|
||||
return startState
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStrLit(l *Lexer, tok XItemType) error {
|
||||
q := l.next()
|
||||
var r rune
|
||||
|
||||
l.ignore()
|
||||
|
||||
for r != q {
|
||||
r = l.next()
|
||||
if r == eof {
|
||||
return fmt.Errorf("Unexpected end of string literal.")
|
||||
}
|
||||
}
|
||||
|
||||
l.backup()
|
||||
l.emit(tok)
|
||||
l.next()
|
||||
l.ignore()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getNumLit(l *Lexer) bool {
|
||||
const dig = "0123456789"
|
||||
l.accept("-")
|
||||
start := l.pos
|
||||
l.acceptRun(dig)
|
||||
|
||||
if l.pos == start {
|
||||
return false
|
||||
}
|
||||
|
||||
if l.accept(".") {
|
||||
l.acceptRun(dig)
|
||||
}
|
||||
|
||||
l.emit(XItemNumLit)
|
||||
return true
|
||||
}
|
|
@ -0,0 +1,219 @@
|
|||
package lexer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/xconst"
|
||||
)
|
||||
|
||||
func absLocPathState(l *Lexer) stateFn {
|
||||
l.emit(XItemAbsLocPath)
|
||||
return stepState
|
||||
}
|
||||
|
||||
func abbrAbsLocPathState(l *Lexer) stateFn {
|
||||
l.emit(XItemAbbrAbsLocPath)
|
||||
return stepState
|
||||
}
|
||||
|
||||
func relLocPathState(l *Lexer) stateFn {
|
||||
l.emit(XItemRelLocPath)
|
||||
return stepState
|
||||
}
|
||||
|
||||
func abbrRelLocPathState(l *Lexer) stateFn {
|
||||
l.emit(XItemAbbrRelLocPath)
|
||||
return stepState
|
||||
}
|
||||
|
||||
func stepState(l *Lexer) stateFn {
|
||||
l.skipWS(true)
|
||||
r := l.next()
|
||||
|
||||
for isElemChar(r) {
|
||||
r = l.next()
|
||||
}
|
||||
|
||||
l.backup()
|
||||
tok := l.input[l.start:l.pos]
|
||||
|
||||
state, err := parseSeparators(l, tok)
|
||||
if err != nil {
|
||||
return l.errorf(err.Error())
|
||||
}
|
||||
|
||||
return getNextPathState(l, state)
|
||||
}
|
||||
|
||||
func parseSeparators(l *Lexer, tok string) (XItemType, error) {
|
||||
l.skipWS(false)
|
||||
state := XItemType(XItemQName)
|
||||
r := l.peek()
|
||||
|
||||
if string(r) == ":" && string(l.peekAt(2)) == ":" {
|
||||
var err error
|
||||
if state, err = getAxis(l, tok); err != nil {
|
||||
return state, fmt.Errorf(err.Error())
|
||||
}
|
||||
} else if string(r) == ":" {
|
||||
state = XItemNCName
|
||||
l.emitVal(state, tok)
|
||||
l.skip(1)
|
||||
l.skipWS(true)
|
||||
} else if string(r) == "@" {
|
||||
state = XItemAbbrAxis
|
||||
l.emitVal(state, tok)
|
||||
l.skip(1)
|
||||
l.skipWS(true)
|
||||
} else if string(r) == "(" {
|
||||
var err error
|
||||
if state, err = getNT(l, tok); err != nil {
|
||||
return state, fmt.Errorf(err.Error())
|
||||
}
|
||||
} else if len(tok) > 0 {
|
||||
l.emitVal(state, tok)
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func getAxis(l *Lexer, tok string) (XItemType, error) {
|
||||
var state XItemType
|
||||
for i := range xconst.AxisNames {
|
||||
if tok == xconst.AxisNames[i] {
|
||||
state = XItemAxis
|
||||
}
|
||||
}
|
||||
if state != XItemAxis {
|
||||
return state, fmt.Errorf("Invalid Axis specifier, %s", tok)
|
||||
}
|
||||
l.emitVal(state, tok)
|
||||
l.skip(2)
|
||||
l.skipWS(true)
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func getNT(l *Lexer, tok string) (XItemType, error) {
|
||||
isNT := false
|
||||
for _, i := range xconst.NodeTypes {
|
||||
if tok == i {
|
||||
isNT = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNT {
|
||||
return procNT(l, tok)
|
||||
}
|
||||
|
||||
return XItemError, fmt.Errorf("Invalid node-type " + tok)
|
||||
}
|
||||
|
||||
func procNT(l *Lexer, tok string) (XItemType, error) {
|
||||
state := XItemType(XItemNodeType)
|
||||
l.emitVal(state, tok)
|
||||
l.skip(1)
|
||||
l.skipWS(true)
|
||||
n := l.peek()
|
||||
if tok == xconst.NodeTypeProcInst && (string(n) == `"` || string(n) == `'`) {
|
||||
if err := getStrLit(l, XItemProcLit); err != nil {
|
||||
return state, fmt.Errorf(err.Error())
|
||||
}
|
||||
l.skipWS(true)
|
||||
n = l.next()
|
||||
}
|
||||
|
||||
if string(n) != ")" {
|
||||
return state, fmt.Errorf("Missing ) at end of NodeType declaration.")
|
||||
}
|
||||
|
||||
l.skip(1)
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func procFunc(l *Lexer, tok string) error {
|
||||
state := XItemType(XItemFunction)
|
||||
l.emitVal(state, tok)
|
||||
l.skip(1)
|
||||
l.skipWS(true)
|
||||
if string(l.peek()) != ")" {
|
||||
l.emit(XItemArgument)
|
||||
for {
|
||||
for state := startState; state != nil; {
|
||||
state = state(l)
|
||||
}
|
||||
l.skipWS(true)
|
||||
|
||||
if string(l.peek()) == "," {
|
||||
l.emit(XItemArgument)
|
||||
l.skip(1)
|
||||
} else if string(l.peek()) == ")" {
|
||||
l.emit(XItemEndFunction)
|
||||
l.skip(1)
|
||||
break
|
||||
} else if l.peek() == eof {
|
||||
return fmt.Errorf("Missing ) at end of function declaration.")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
l.emit(XItemEndFunction)
|
||||
l.skip(1)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getNextPathState(l *Lexer, state XItemType) stateFn {
|
||||
isMultiPart := state == XItemAxis || state == XItemAbbrAxis || state == XItemNCName
|
||||
|
||||
l.skipWS(true)
|
||||
|
||||
for string(l.peek()) == "[" {
|
||||
if err := getPred(l); err != nil {
|
||||
return l.errorf(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if string(l.peek()) == "/" && !isMultiPart {
|
||||
l.skip(1)
|
||||
if string(l.peek()) == "/" {
|
||||
l.skip(1)
|
||||
return abbrRelLocPathState
|
||||
}
|
||||
l.skipWS(true)
|
||||
return relLocPathState
|
||||
} else if isMultiPart && isElemChar(l.peek()) {
|
||||
return stepState
|
||||
}
|
||||
|
||||
if isMultiPart {
|
||||
return l.errorf("Step is not complete")
|
||||
}
|
||||
|
||||
l.emit(XItemEndPath)
|
||||
return findOperatorState
|
||||
}
|
||||
|
||||
func getPred(l *Lexer) error {
|
||||
l.emit(XItemPredicate)
|
||||
l.skip(1)
|
||||
l.skipWS(true)
|
||||
|
||||
if string(l.peek()) == "]" {
|
||||
return fmt.Errorf("Missing content in predicate.")
|
||||
}
|
||||
|
||||
for state := startState; state != nil; {
|
||||
state = state(l)
|
||||
}
|
||||
|
||||
l.skipWS(true)
|
||||
if string(l.peek()) != "]" {
|
||||
return fmt.Errorf("Missing ] at end of predicate.")
|
||||
}
|
||||
l.skip(1)
|
||||
l.emit(XItemEndPredicate)
|
||||
l.skipWS(true)
|
||||
|
||||
return nil
|
||||
}
|
316
vendor/github.com/ChrisTrenkamp/goxpath/internal/lexer/xmlchars.go
generated
vendored
Normal file
316
vendor/github.com/ChrisTrenkamp/goxpath/internal/lexer/xmlchars.go
generated
vendored
Normal file
|
@ -0,0 +1,316 @@
|
|||
package lexer
|
||||
|
||||
import "unicode"
|
||||
|
||||
//first and second was copied from src/encoding/xml/xml.go
|
||||
var first = &unicode.RangeTable{
|
||||
R16: []unicode.Range16{
|
||||
{0x003A, 0x003A, 1},
|
||||
{0x0041, 0x005A, 1},
|
||||
{0x005F, 0x005F, 1},
|
||||
{0x0061, 0x007A, 1},
|
||||
{0x00C0, 0x00D6, 1},
|
||||
{0x00D8, 0x00F6, 1},
|
||||
{0x00F8, 0x00FF, 1},
|
||||
{0x0100, 0x0131, 1},
|
||||
{0x0134, 0x013E, 1},
|
||||
{0x0141, 0x0148, 1},
|
||||
{0x014A, 0x017E, 1},
|
||||
{0x0180, 0x01C3, 1},
|
||||
{0x01CD, 0x01F0, 1},
|
||||
{0x01F4, 0x01F5, 1},
|
||||
{0x01FA, 0x0217, 1},
|
||||
{0x0250, 0x02A8, 1},
|
||||
{0x02BB, 0x02C1, 1},
|
||||
{0x0386, 0x0386, 1},
|
||||
{0x0388, 0x038A, 1},
|
||||
{0x038C, 0x038C, 1},
|
||||
{0x038E, 0x03A1, 1},
|
||||
{0x03A3, 0x03CE, 1},
|
||||
{0x03D0, 0x03D6, 1},
|
||||
{0x03DA, 0x03E0, 2},
|
||||
{0x03E2, 0x03F3, 1},
|
||||
{0x0401, 0x040C, 1},
|
||||
{0x040E, 0x044F, 1},
|
||||
{0x0451, 0x045C, 1},
|
||||
{0x045E, 0x0481, 1},
|
||||
{0x0490, 0x04C4, 1},
|
||||
{0x04C7, 0x04C8, 1},
|
||||
{0x04CB, 0x04CC, 1},
|
||||
{0x04D0, 0x04EB, 1},
|
||||
{0x04EE, 0x04F5, 1},
|
||||
{0x04F8, 0x04F9, 1},
|
||||
{0x0531, 0x0556, 1},
|
||||
{0x0559, 0x0559, 1},
|
||||
{0x0561, 0x0586, 1},
|
||||
{0x05D0, 0x05EA, 1},
|
||||
{0x05F0, 0x05F2, 1},
|
||||
{0x0621, 0x063A, 1},
|
||||
{0x0641, 0x064A, 1},
|
||||
{0x0671, 0x06B7, 1},
|
||||
{0x06BA, 0x06BE, 1},
|
||||
{0x06C0, 0x06CE, 1},
|
||||
{0x06D0, 0x06D3, 1},
|
||||
{0x06D5, 0x06D5, 1},
|
||||
{0x06E5, 0x06E6, 1},
|
||||
{0x0905, 0x0939, 1},
|
||||
{0x093D, 0x093D, 1},
|
||||
{0x0958, 0x0961, 1},
|
||||
{0x0985, 0x098C, 1},
|
||||
{0x098F, 0x0990, 1},
|
||||
{0x0993, 0x09A8, 1},
|
||||
{0x09AA, 0x09B0, 1},
|
||||
{0x09B2, 0x09B2, 1},
|
||||
{0x09B6, 0x09B9, 1},
|
||||
{0x09DC, 0x09DD, 1},
|
||||
{0x09DF, 0x09E1, 1},
|
||||
{0x09F0, 0x09F1, 1},
|
||||
{0x0A05, 0x0A0A, 1},
|
||||
{0x0A0F, 0x0A10, 1},
|
||||
{0x0A13, 0x0A28, 1},
|
||||
{0x0A2A, 0x0A30, 1},
|
||||
{0x0A32, 0x0A33, 1},
|
||||
{0x0A35, 0x0A36, 1},
|
||||
{0x0A38, 0x0A39, 1},
|
||||
{0x0A59, 0x0A5C, 1},
|
||||
{0x0A5E, 0x0A5E, 1},
|
||||
{0x0A72, 0x0A74, 1},
|
||||
{0x0A85, 0x0A8B, 1},
|
||||
{0x0A8D, 0x0A8D, 1},
|
||||
{0x0A8F, 0x0A91, 1},
|
||||
{0x0A93, 0x0AA8, 1},
|
||||
{0x0AAA, 0x0AB0, 1},
|
||||
{0x0AB2, 0x0AB3, 1},
|
||||
{0x0AB5, 0x0AB9, 1},
|
||||
{0x0ABD, 0x0AE0, 0x23},
|
||||
{0x0B05, 0x0B0C, 1},
|
||||
{0x0B0F, 0x0B10, 1},
|
||||
{0x0B13, 0x0B28, 1},
|
||||
{0x0B2A, 0x0B30, 1},
|
||||
{0x0B32, 0x0B33, 1},
|
||||
{0x0B36, 0x0B39, 1},
|
||||
{0x0B3D, 0x0B3D, 1},
|
||||
{0x0B5C, 0x0B5D, 1},
|
||||
{0x0B5F, 0x0B61, 1},
|
||||
{0x0B85, 0x0B8A, 1},
|
||||
{0x0B8E, 0x0B90, 1},
|
||||
{0x0B92, 0x0B95, 1},
|
||||
{0x0B99, 0x0B9A, 1},
|
||||
{0x0B9C, 0x0B9C, 1},
|
||||
{0x0B9E, 0x0B9F, 1},
|
||||
{0x0BA3, 0x0BA4, 1},
|
||||
{0x0BA8, 0x0BAA, 1},
|
||||
{0x0BAE, 0x0BB5, 1},
|
||||
{0x0BB7, 0x0BB9, 1},
|
||||
{0x0C05, 0x0C0C, 1},
|
||||
{0x0C0E, 0x0C10, 1},
|
||||
{0x0C12, 0x0C28, 1},
|
||||
{0x0C2A, 0x0C33, 1},
|
||||
{0x0C35, 0x0C39, 1},
|
||||
{0x0C60, 0x0C61, 1},
|
||||
{0x0C85, 0x0C8C, 1},
|
||||
{0x0C8E, 0x0C90, 1},
|
||||
{0x0C92, 0x0CA8, 1},
|
||||
{0x0CAA, 0x0CB3, 1},
|
||||
{0x0CB5, 0x0CB9, 1},
|
||||
{0x0CDE, 0x0CDE, 1},
|
||||
{0x0CE0, 0x0CE1, 1},
|
||||
{0x0D05, 0x0D0C, 1},
|
||||
{0x0D0E, 0x0D10, 1},
|
||||
{0x0D12, 0x0D28, 1},
|
||||
{0x0D2A, 0x0D39, 1},
|
||||
{0x0D60, 0x0D61, 1},
|
||||
{0x0E01, 0x0E2E, 1},
|
||||
{0x0E30, 0x0E30, 1},
|
||||
{0x0E32, 0x0E33, 1},
|
||||
{0x0E40, 0x0E45, 1},
|
||||
{0x0E81, 0x0E82, 1},
|
||||
{0x0E84, 0x0E84, 1},
|
||||
{0x0E87, 0x0E88, 1},
|
||||
{0x0E8A, 0x0E8D, 3},
|
||||
{0x0E94, 0x0E97, 1},
|
||||
{0x0E99, 0x0E9F, 1},
|
||||
{0x0EA1, 0x0EA3, 1},
|
||||
{0x0EA5, 0x0EA7, 2},
|
||||
{0x0EAA, 0x0EAB, 1},
|
||||
{0x0EAD, 0x0EAE, 1},
|
||||
{0x0EB0, 0x0EB0, 1},
|
||||
{0x0EB2, 0x0EB3, 1},
|
||||
{0x0EBD, 0x0EBD, 1},
|
||||
{0x0EC0, 0x0EC4, 1},
|
||||
{0x0F40, 0x0F47, 1},
|
||||
{0x0F49, 0x0F69, 1},
|
||||
{0x10A0, 0x10C5, 1},
|
||||
{0x10D0, 0x10F6, 1},
|
||||
{0x1100, 0x1100, 1},
|
||||
{0x1102, 0x1103, 1},
|
||||
{0x1105, 0x1107, 1},
|
||||
{0x1109, 0x1109, 1},
|
||||
{0x110B, 0x110C, 1},
|
||||
{0x110E, 0x1112, 1},
|
||||
{0x113C, 0x1140, 2},
|
||||
{0x114C, 0x1150, 2},
|
||||
{0x1154, 0x1155, 1},
|
||||
{0x1159, 0x1159, 1},
|
||||
{0x115F, 0x1161, 1},
|
||||
{0x1163, 0x1169, 2},
|
||||
{0x116D, 0x116E, 1},
|
||||
{0x1172, 0x1173, 1},
|
||||
{0x1175, 0x119E, 0x119E - 0x1175},
|
||||
{0x11A8, 0x11AB, 0x11AB - 0x11A8},
|
||||
{0x11AE, 0x11AF, 1},
|
||||
{0x11B7, 0x11B8, 1},
|
||||
{0x11BA, 0x11BA, 1},
|
||||
{0x11BC, 0x11C2, 1},
|
||||
{0x11EB, 0x11F0, 0x11F0 - 0x11EB},
|
||||
{0x11F9, 0x11F9, 1},
|
||||
{0x1E00, 0x1E9B, 1},
|
||||
{0x1EA0, 0x1EF9, 1},
|
||||
{0x1F00, 0x1F15, 1},
|
||||
{0x1F18, 0x1F1D, 1},
|
||||
{0x1F20, 0x1F45, 1},
|
||||
{0x1F48, 0x1F4D, 1},
|
||||
{0x1F50, 0x1F57, 1},
|
||||
{0x1F59, 0x1F5B, 0x1F5B - 0x1F59},
|
||||
{0x1F5D, 0x1F5D, 1},
|
||||
{0x1F5F, 0x1F7D, 1},
|
||||
{0x1F80, 0x1FB4, 1},
|
||||
{0x1FB6, 0x1FBC, 1},
|
||||
{0x1FBE, 0x1FBE, 1},
|
||||
{0x1FC2, 0x1FC4, 1},
|
||||
{0x1FC6, 0x1FCC, 1},
|
||||
{0x1FD0, 0x1FD3, 1},
|
||||
{0x1FD6, 0x1FDB, 1},
|
||||
{0x1FE0, 0x1FEC, 1},
|
||||
{0x1FF2, 0x1FF4, 1},
|
||||
{0x1FF6, 0x1FFC, 1},
|
||||
{0x2126, 0x2126, 1},
|
||||
{0x212A, 0x212B, 1},
|
||||
{0x212E, 0x212E, 1},
|
||||
{0x2180, 0x2182, 1},
|
||||
{0x3007, 0x3007, 1},
|
||||
{0x3021, 0x3029, 1},
|
||||
{0x3041, 0x3094, 1},
|
||||
{0x30A1, 0x30FA, 1},
|
||||
{0x3105, 0x312C, 1},
|
||||
{0x4E00, 0x9FA5, 1},
|
||||
{0xAC00, 0xD7A3, 1},
|
||||
},
|
||||
}
|
||||
|
||||
var second = &unicode.RangeTable{
|
||||
R16: []unicode.Range16{
|
||||
{0x002D, 0x002E, 1},
|
||||
{0x0030, 0x0039, 1},
|
||||
{0x00B7, 0x00B7, 1},
|
||||
{0x02D0, 0x02D1, 1},
|
||||
{0x0300, 0x0345, 1},
|
||||
{0x0360, 0x0361, 1},
|
||||
{0x0387, 0x0387, 1},
|
||||
{0x0483, 0x0486, 1},
|
||||
{0x0591, 0x05A1, 1},
|
||||
{0x05A3, 0x05B9, 1},
|
||||
{0x05BB, 0x05BD, 1},
|
||||
{0x05BF, 0x05BF, 1},
|
||||
{0x05C1, 0x05C2, 1},
|
||||
{0x05C4, 0x0640, 0x0640 - 0x05C4},
|
||||
{0x064B, 0x0652, 1},
|
||||
{0x0660, 0x0669, 1},
|
||||
{0x0670, 0x0670, 1},
|
||||
{0x06D6, 0x06DC, 1},
|
||||
{0x06DD, 0x06DF, 1},
|
||||
{0x06E0, 0x06E4, 1},
|
||||
{0x06E7, 0x06E8, 1},
|
||||
{0x06EA, 0x06ED, 1},
|
||||
{0x06F0, 0x06F9, 1},
|
||||
{0x0901, 0x0903, 1},
|
||||
{0x093C, 0x093C, 1},
|
||||
{0x093E, 0x094C, 1},
|
||||
{0x094D, 0x094D, 1},
|
||||
{0x0951, 0x0954, 1},
|
||||
{0x0962, 0x0963, 1},
|
||||
{0x0966, 0x096F, 1},
|
||||
{0x0981, 0x0983, 1},
|
||||
{0x09BC, 0x09BC, 1},
|
||||
{0x09BE, 0x09BF, 1},
|
||||
{0x09C0, 0x09C4, 1},
|
||||
{0x09C7, 0x09C8, 1},
|
||||
{0x09CB, 0x09CD, 1},
|
||||
{0x09D7, 0x09D7, 1},
|
||||
{0x09E2, 0x09E3, 1},
|
||||
{0x09E6, 0x09EF, 1},
|
||||
{0x0A02, 0x0A3C, 0x3A},
|
||||
{0x0A3E, 0x0A3F, 1},
|
||||
{0x0A40, 0x0A42, 1},
|
||||
{0x0A47, 0x0A48, 1},
|
||||
{0x0A4B, 0x0A4D, 1},
|
||||
{0x0A66, 0x0A6F, 1},
|
||||
{0x0A70, 0x0A71, 1},
|
||||
{0x0A81, 0x0A83, 1},
|
||||
{0x0ABC, 0x0ABC, 1},
|
||||
{0x0ABE, 0x0AC5, 1},
|
||||
{0x0AC7, 0x0AC9, 1},
|
||||
{0x0ACB, 0x0ACD, 1},
|
||||
{0x0AE6, 0x0AEF, 1},
|
||||
{0x0B01, 0x0B03, 1},
|
||||
{0x0B3C, 0x0B3C, 1},
|
||||
{0x0B3E, 0x0B43, 1},
|
||||
{0x0B47, 0x0B48, 1},
|
||||
{0x0B4B, 0x0B4D, 1},
|
||||
{0x0B56, 0x0B57, 1},
|
||||
{0x0B66, 0x0B6F, 1},
|
||||
{0x0B82, 0x0B83, 1},
|
||||
{0x0BBE, 0x0BC2, 1},
|
||||
{0x0BC6, 0x0BC8, 1},
|
||||
{0x0BCA, 0x0BCD, 1},
|
||||
{0x0BD7, 0x0BD7, 1},
|
||||
{0x0BE7, 0x0BEF, 1},
|
||||
{0x0C01, 0x0C03, 1},
|
||||
{0x0C3E, 0x0C44, 1},
|
||||
{0x0C46, 0x0C48, 1},
|
||||
{0x0C4A, 0x0C4D, 1},
|
||||
{0x0C55, 0x0C56, 1},
|
||||
{0x0C66, 0x0C6F, 1},
|
||||
{0x0C82, 0x0C83, 1},
|
||||
{0x0CBE, 0x0CC4, 1},
|
||||
{0x0CC6, 0x0CC8, 1},
|
||||
{0x0CCA, 0x0CCD, 1},
|
||||
{0x0CD5, 0x0CD6, 1},
|
||||
{0x0CE6, 0x0CEF, 1},
|
||||
{0x0D02, 0x0D03, 1},
|
||||
{0x0D3E, 0x0D43, 1},
|
||||
{0x0D46, 0x0D48, 1},
|
||||
{0x0D4A, 0x0D4D, 1},
|
||||
{0x0D57, 0x0D57, 1},
|
||||
{0x0D66, 0x0D6F, 1},
|
||||
{0x0E31, 0x0E31, 1},
|
||||
{0x0E34, 0x0E3A, 1},
|
||||
{0x0E46, 0x0E46, 1},
|
||||
{0x0E47, 0x0E4E, 1},
|
||||
{0x0E50, 0x0E59, 1},
|
||||
{0x0EB1, 0x0EB1, 1},
|
||||
{0x0EB4, 0x0EB9, 1},
|
||||
{0x0EBB, 0x0EBC, 1},
|
||||
{0x0EC6, 0x0EC6, 1},
|
||||
{0x0EC8, 0x0ECD, 1},
|
||||
{0x0ED0, 0x0ED9, 1},
|
||||
{0x0F18, 0x0F19, 1},
|
||||
{0x0F20, 0x0F29, 1},
|
||||
{0x0F35, 0x0F39, 2},
|
||||
{0x0F3E, 0x0F3F, 1},
|
||||
{0x0F71, 0x0F84, 1},
|
||||
{0x0F86, 0x0F8B, 1},
|
||||
{0x0F90, 0x0F95, 1},
|
||||
{0x0F97, 0x0F97, 1},
|
||||
{0x0F99, 0x0FAD, 1},
|
||||
{0x0FB1, 0x0FB7, 1},
|
||||
{0x0FB9, 0x0FB9, 1},
|
||||
{0x20D0, 0x20DC, 1},
|
||||
{0x20E1, 0x3005, 0x3005 - 0x20E1},
|
||||
{0x302A, 0x302F, 1},
|
||||
{0x3031, 0x3035, 1},
|
||||
{0x3099, 0x309A, 1},
|
||||
{0x309D, 0x309E, 1},
|
||||
{0x30FC, 0x30FE, 1},
|
||||
},
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package parser
|
||||
|
||||
import "github.com/ChrisTrenkamp/goxpath/internal/lexer"
|
||||
|
||||
//NodeType enumerations
|
||||
const (
|
||||
Empty lexer.XItemType = ""
|
||||
)
|
||||
|
||||
//Node builds an AST tree for operating on XPath expressions
|
||||
type Node struct {
|
||||
Val lexer.XItem
|
||||
Left *Node
|
||||
Right *Node
|
||||
Parent *Node
|
||||
next *Node
|
||||
}
|
||||
|
||||
var beginPathType = map[lexer.XItemType]bool{
|
||||
lexer.XItemAbsLocPath: true,
|
||||
lexer.XItemAbbrAbsLocPath: true,
|
||||
lexer.XItemAbbrRelLocPath: true,
|
||||
lexer.XItemRelLocPath: true,
|
||||
lexer.XItemFunction: true,
|
||||
}
|
||||
|
||||
func (n *Node) add(i lexer.XItem) {
|
||||
if n.Val.Typ == Empty {
|
||||
n.Val = i
|
||||
} else if n.Left == nil {
|
||||
n.Left = &Node{Val: n.Val, Parent: n}
|
||||
n.Val = i
|
||||
} else if beginPathType[n.Val.Typ] {
|
||||
next := &Node{Val: n.Val, Left: n.Left, Parent: n}
|
||||
n.Left = next
|
||||
n.Val = i
|
||||
} else if n.Right == nil {
|
||||
n.Right = &Node{Val: i, Parent: n}
|
||||
} else {
|
||||
next := &Node{Val: n.Val, Left: n.Left, Right: n.Right, Parent: n}
|
||||
n.Left, n.Right = next, nil
|
||||
n.Val = i
|
||||
}
|
||||
n.next = n
|
||||
}
|
||||
|
||||
func (n *Node) push(i lexer.XItem) {
|
||||
if n.Left == nil {
|
||||
n.Left = &Node{Val: i, Parent: n}
|
||||
n.next = n.Left
|
||||
} else if n.Right == nil {
|
||||
n.Right = &Node{Val: i, Parent: n}
|
||||
n.next = n.Right
|
||||
} else {
|
||||
next := &Node{Val: i, Left: n.Right, Parent: n}
|
||||
n.Right = next
|
||||
n.next = n.Right
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Node) pushNotEmpty(i lexer.XItem) {
|
||||
if n.Val.Typ == Empty {
|
||||
n.add(i)
|
||||
} else {
|
||||
n.push(i)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func (n *Node) prettyPrint(depth, width int) {
|
||||
nodes := []*Node{}
|
||||
n.getLine(depth, &nodes)
|
||||
fmt.Printf("%*s", (width-depth)*2, "")
|
||||
toggle := true
|
||||
if len(nodes) > 1 {
|
||||
for _, i := range nodes {
|
||||
if i != nil {
|
||||
if toggle {
|
||||
fmt.Print("/ ")
|
||||
} else {
|
||||
fmt.Print("\\ ")
|
||||
}
|
||||
}
|
||||
toggle = !toggle
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Printf("%*s", (width-depth)*2, "")
|
||||
}
|
||||
for _, i := range nodes {
|
||||
if i != nil {
|
||||
fmt.Print(i.Val.Val, " ")
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func (n *Node) getLine(depth int, ret *[]*Node) {
|
||||
if depth <= 0 && n != nil {
|
||||
*ret = append(*ret, n)
|
||||
return
|
||||
}
|
||||
if n.Left != nil {
|
||||
n.Left.getLine(depth-1, ret)
|
||||
} else if depth-1 <= 0 {
|
||||
*ret = append(*ret, nil)
|
||||
}
|
||||
if n.Right != nil {
|
||||
n.Right.getLine(depth-1, ret)
|
||||
} else if depth-1 <= 0 {
|
||||
*ret = append(*ret, nil)
|
||||
}
|
||||
}
|
||||
*/
|
307
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/findutil/findUtil.go
generated
vendored
Normal file
307
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/findutil/findUtil.go
generated
vendored
Normal file
|
@ -0,0 +1,307 @@
|
|||
package findutil
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/parser/pathexpr"
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/xconst"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
const (
|
||||
wildcard = "*"
|
||||
)
|
||||
|
||||
type findFunc func(tree.Node, *pathexpr.PathExpr, *[]tree.Node)
|
||||
|
||||
var findMap = map[string]findFunc{
|
||||
xconst.AxisAncestor: findAncestor,
|
||||
xconst.AxisAncestorOrSelf: findAncestorOrSelf,
|
||||
xconst.AxisAttribute: findAttribute,
|
||||
xconst.AxisChild: findChild,
|
||||
xconst.AxisDescendent: findDescendent,
|
||||
xconst.AxisDescendentOrSelf: findDescendentOrSelf,
|
||||
xconst.AxisFollowing: findFollowing,
|
||||
xconst.AxisFollowingSibling: findFollowingSibling,
|
||||
xconst.AxisNamespace: findNamespace,
|
||||
xconst.AxisParent: findParent,
|
||||
xconst.AxisPreceding: findPreceding,
|
||||
xconst.AxisPrecedingSibling: findPrecedingSibling,
|
||||
xconst.AxisSelf: findSelf,
|
||||
}
|
||||
|
||||
//Find finds nodes based on the pathexpr.PathExpr
|
||||
func Find(x tree.Node, p pathexpr.PathExpr) []tree.Node {
|
||||
ret := []tree.Node{}
|
||||
|
||||
if p.Axis == "" {
|
||||
findChild(x, &p, &ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
f := findMap[p.Axis]
|
||||
f(x, &p, &ret)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func findAncestor(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if x.GetNodeType() == tree.NtRoot {
|
||||
return
|
||||
}
|
||||
|
||||
addNode(x.GetParent(), p, ret)
|
||||
findAncestor(x.GetParent(), p, ret)
|
||||
}
|
||||
|
||||
func findAncestorOrSelf(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
findSelf(x, p, ret)
|
||||
findAncestor(x, p, ret)
|
||||
}
|
||||
|
||||
func findAttribute(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if ele, ok := x.(tree.Elem); ok {
|
||||
for _, i := range ele.GetAttrs() {
|
||||
addNode(i, p, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findChild(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if ele, ok := x.(tree.Elem); ok {
|
||||
ch := ele.GetChildren()
|
||||
for i := range ch {
|
||||
addNode(ch[i], p, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findDescendent(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if ele, ok := x.(tree.Elem); ok {
|
||||
ch := ele.GetChildren()
|
||||
for i := range ch {
|
||||
addNode(ch[i], p, ret)
|
||||
findDescendent(ch[i], p, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findDescendentOrSelf(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
findSelf(x, p, ret)
|
||||
findDescendent(x, p, ret)
|
||||
}
|
||||
|
||||
func findFollowing(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if x.GetNodeType() == tree.NtRoot {
|
||||
return
|
||||
}
|
||||
par := x.GetParent()
|
||||
ch := par.GetChildren()
|
||||
i := 0
|
||||
for x != ch[i] {
|
||||
i++
|
||||
}
|
||||
i++
|
||||
for i < len(ch) {
|
||||
findDescendentOrSelf(ch[i], p, ret)
|
||||
i++
|
||||
}
|
||||
findFollowing(par, p, ret)
|
||||
}
|
||||
|
||||
func findFollowingSibling(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if x.GetNodeType() == tree.NtRoot {
|
||||
return
|
||||
}
|
||||
par := x.GetParent()
|
||||
ch := par.GetChildren()
|
||||
i := 0
|
||||
for x != ch[i] {
|
||||
i++
|
||||
}
|
||||
i++
|
||||
for i < len(ch) {
|
||||
findSelf(ch[i], p, ret)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func findNamespace(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if ele, ok := x.(tree.NSElem); ok {
|
||||
ns := tree.BuildNS(ele)
|
||||
for _, i := range ns {
|
||||
addNode(i, p, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findParent(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if x.GetNodeType() != tree.NtRoot {
|
||||
addNode(x.GetParent(), p, ret)
|
||||
}
|
||||
}
|
||||
|
||||
func findPreceding(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if x.GetNodeType() == tree.NtRoot {
|
||||
return
|
||||
}
|
||||
par := x.GetParent()
|
||||
ch := par.GetChildren()
|
||||
i := len(ch) - 1
|
||||
for x != ch[i] {
|
||||
i--
|
||||
}
|
||||
i--
|
||||
for i >= 0 {
|
||||
findDescendentOrSelf(ch[i], p, ret)
|
||||
i--
|
||||
}
|
||||
findPreceding(par, p, ret)
|
||||
}
|
||||
|
||||
func findPrecedingSibling(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
if x.GetNodeType() == tree.NtRoot {
|
||||
return
|
||||
}
|
||||
par := x.GetParent()
|
||||
ch := par.GetChildren()
|
||||
i := len(ch) - 1
|
||||
for x != ch[i] {
|
||||
i--
|
||||
}
|
||||
i--
|
||||
for i >= 0 {
|
||||
findSelf(ch[i], p, ret)
|
||||
i--
|
||||
}
|
||||
}
|
||||
|
||||
func findSelf(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
addNode(x, p, ret)
|
||||
}
|
||||
|
||||
func addNode(x tree.Node, p *pathexpr.PathExpr, ret *[]tree.Node) {
|
||||
add := false
|
||||
tok := x.GetToken()
|
||||
|
||||
switch x.GetNodeType() {
|
||||
case tree.NtAttr:
|
||||
add = evalAttr(p, tok.(xml.Attr))
|
||||
case tree.NtChd:
|
||||
add = evalChd(p)
|
||||
case tree.NtComm:
|
||||
add = evalComm(p)
|
||||
case tree.NtElem, tree.NtRoot:
|
||||
add = evalEle(p, tok.(xml.StartElement))
|
||||
case tree.NtNs:
|
||||
add = evalNS(p, tok.(xml.Attr))
|
||||
case tree.NtPi:
|
||||
add = evalPI(p)
|
||||
}
|
||||
|
||||
if add {
|
||||
*ret = append(*ret, x)
|
||||
}
|
||||
}
|
||||
|
||||
func evalAttr(p *pathexpr.PathExpr, a xml.Attr) bool {
|
||||
if p.NodeType == "" {
|
||||
if p.Name.Space != wildcard {
|
||||
if a.Name.Space != p.NS[p.Name.Space] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if p.Name.Local == wildcard && p.Axis == xconst.AxisAttribute {
|
||||
return true
|
||||
}
|
||||
|
||||
if p.Name.Local == a.Name.Local {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if p.NodeType == xconst.NodeTypeNode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func evalChd(p *pathexpr.PathExpr) bool {
|
||||
if p.NodeType == xconst.NodeTypeText || p.NodeType == xconst.NodeTypeNode {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func evalComm(p *pathexpr.PathExpr) bool {
|
||||
if p.NodeType == xconst.NodeTypeComment || p.NodeType == xconst.NodeTypeNode {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func evalEle(p *pathexpr.PathExpr, ele xml.StartElement) bool {
|
||||
if p.NodeType == "" {
|
||||
return checkNameAndSpace(p, ele)
|
||||
}
|
||||
|
||||
if p.NodeType == xconst.NodeTypeNode {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkNameAndSpace(p *pathexpr.PathExpr, ele xml.StartElement) bool {
|
||||
if p.Name.Local == wildcard && p.Name.Space == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if p.Name.Space != wildcard && ele.Name.Space != p.NS[p.Name.Space] {
|
||||
return false
|
||||
}
|
||||
|
||||
if p.Name.Local == wildcard && p.Axis != xconst.AxisAttribute && p.Axis != xconst.AxisNamespace {
|
||||
return true
|
||||
}
|
||||
|
||||
if p.Name.Local == ele.Name.Local {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func evalNS(p *pathexpr.PathExpr, ns xml.Attr) bool {
|
||||
if p.NodeType == "" {
|
||||
if p.Name.Space != "" && p.Name.Space != wildcard {
|
||||
return false
|
||||
}
|
||||
|
||||
if p.Name.Local == wildcard && p.Axis == xconst.AxisNamespace {
|
||||
return true
|
||||
}
|
||||
|
||||
if p.Name.Local == ns.Name.Local {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if p.NodeType == xconst.NodeTypeNode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func evalPI(p *pathexpr.PathExpr) bool {
|
||||
if p.NodeType == xconst.NodeTypeProcInst || p.NodeType == xconst.NodeTypeNode {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
74
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/boolfns.go
generated
vendored
Normal file
74
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/boolfns.go
generated
vendored
Normal file
|
@ -0,0 +1,74 @@
|
|||
package intfns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
func boolean(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
if b, ok := args[0].(tree.IsBool); ok {
|
||||
return b.Bool(), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Cannot convert object to a boolean")
|
||||
}
|
||||
|
||||
func not(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
b, ok := args[0].(tree.IsBool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a boolean")
|
||||
}
|
||||
return !b.Bool(), nil
|
||||
}
|
||||
|
||||
func _true(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
return tree.Bool(true), nil
|
||||
}
|
||||
|
||||
func _false(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
return tree.Bool(false), nil
|
||||
}
|
||||
|
||||
func lang(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
lStr := args[0].String()
|
||||
|
||||
var n tree.Elem
|
||||
|
||||
for _, i := range c.NodeSet {
|
||||
if i.GetNodeType() == tree.NtElem {
|
||||
n = i.(tree.Elem)
|
||||
} else {
|
||||
n = i.GetParent()
|
||||
}
|
||||
|
||||
for n.GetNodeType() != tree.NtRoot {
|
||||
if attr, ok := tree.GetAttribute(n, "lang", tree.XMLSpace); ok {
|
||||
return checkLang(lStr, attr.Value), nil
|
||||
}
|
||||
n = n.GetParent()
|
||||
}
|
||||
}
|
||||
|
||||
return tree.Bool(false), nil
|
||||
}
|
||||
|
||||
func checkLang(srcStr, targStr string) tree.Bool {
|
||||
srcLang := language.Make(srcStr)
|
||||
srcRegion, srcRegionConf := srcLang.Region()
|
||||
|
||||
targLang := language.Make(targStr)
|
||||
targRegion, targRegionConf := targLang.Region()
|
||||
|
||||
if srcRegionConf == language.Exact && targRegionConf != language.Exact {
|
||||
return tree.Bool(false)
|
||||
}
|
||||
|
||||
if srcRegion != targRegion && srcRegionConf == language.Exact && targRegionConf == language.Exact {
|
||||
return tree.Bool(false)
|
||||
}
|
||||
|
||||
_, _, conf := language.NewMatcher([]language.Tag{srcLang}).Match(targLang)
|
||||
return tree.Bool(conf >= language.High)
|
||||
}
|
41
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/intfns.go
generated
vendored
Normal file
41
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/intfns.go
generated
vendored
Normal file
|
@ -0,0 +1,41 @@
|
|||
package intfns
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
//BuiltIn contains the list of built-in XPath functions
|
||||
var BuiltIn = map[xml.Name]tree.Wrap{
|
||||
//String functions
|
||||
{Local: "string"}: {Fn: _string, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
{Local: "concat"}: {Fn: concat, NArgs: 3, LastArgOpt: tree.Variadic},
|
||||
{Local: "starts-with"}: {Fn: startsWith, NArgs: 2},
|
||||
{Local: "contains"}: {Fn: contains, NArgs: 2},
|
||||
{Local: "substring-before"}: {Fn: substringBefore, NArgs: 2},
|
||||
{Local: "substring-after"}: {Fn: substringAfter, NArgs: 2},
|
||||
{Local: "substring"}: {Fn: substring, NArgs: 3, LastArgOpt: tree.Optional},
|
||||
{Local: "string-length"}: {Fn: stringLength, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
{Local: "normalize-space"}: {Fn: normalizeSpace, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
{Local: "translate"}: {Fn: translate, NArgs: 3},
|
||||
//Node set functions
|
||||
{Local: "last"}: {Fn: last},
|
||||
{Local: "position"}: {Fn: position},
|
||||
{Local: "count"}: {Fn: count, NArgs: 1},
|
||||
{Local: "local-name"}: {Fn: localName, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
{Local: "namespace-uri"}: {Fn: namespaceURI, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
{Local: "name"}: {Fn: name, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
//boolean functions
|
||||
{Local: "boolean"}: {Fn: boolean, NArgs: 1},
|
||||
{Local: "not"}: {Fn: not, NArgs: 1},
|
||||
{Local: "true"}: {Fn: _true},
|
||||
{Local: "false"}: {Fn: _false},
|
||||
{Local: "lang"}: {Fn: lang, NArgs: 1},
|
||||
//number functions
|
||||
{Local: "number"}: {Fn: number, NArgs: 1, LastArgOpt: tree.Optional},
|
||||
{Local: "sum"}: {Fn: sum, NArgs: 1},
|
||||
{Local: "floor"}: {Fn: floor, NArgs: 1},
|
||||
{Local: "ceiling"}: {Fn: ceiling, NArgs: 1},
|
||||
{Local: "round"}: {Fn: round, NArgs: 1},
|
||||
}
|
131
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/nodesetfns.go
generated
vendored
Normal file
131
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/nodesetfns.go
generated
vendored
Normal file
|
@ -0,0 +1,131 @@
|
|||
package intfns
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
func last(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
return tree.Num(c.Size), nil
|
||||
}
|
||||
|
||||
func position(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
return tree.Num(c.Pos), nil
|
||||
}
|
||||
|
||||
func count(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
n, ok := args[0].(tree.NodeSet)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a node-set")
|
||||
}
|
||||
|
||||
return tree.Num(len(n)), nil
|
||||
}
|
||||
|
||||
func localName(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
var n tree.NodeSet
|
||||
ok := true
|
||||
if len(args) == 1 {
|
||||
n, ok = args[0].(tree.NodeSet)
|
||||
} else {
|
||||
n = c.NodeSet
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a node-set")
|
||||
}
|
||||
|
||||
ret := ""
|
||||
if len(n) == 0 {
|
||||
return tree.String(ret), nil
|
||||
}
|
||||
node := n[0]
|
||||
|
||||
tok := node.GetToken()
|
||||
|
||||
switch node.GetNodeType() {
|
||||
case tree.NtElem:
|
||||
ret = tok.(xml.StartElement).Name.Local
|
||||
case tree.NtAttr:
|
||||
ret = tok.(xml.Attr).Name.Local
|
||||
case tree.NtPi:
|
||||
ret = tok.(xml.ProcInst).Target
|
||||
}
|
||||
|
||||
return tree.String(ret), nil
|
||||
}
|
||||
|
||||
func namespaceURI(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
var n tree.NodeSet
|
||||
ok := true
|
||||
if len(args) == 1 {
|
||||
n, ok = args[0].(tree.NodeSet)
|
||||
} else {
|
||||
n = c.NodeSet
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a node-set")
|
||||
}
|
||||
|
||||
ret := ""
|
||||
if len(n) == 0 {
|
||||
return tree.String(ret), nil
|
||||
}
|
||||
node := n[0]
|
||||
|
||||
tok := node.GetToken()
|
||||
|
||||
switch node.GetNodeType() {
|
||||
case tree.NtElem:
|
||||
ret = tok.(xml.StartElement).Name.Space
|
||||
case tree.NtAttr:
|
||||
ret = tok.(xml.Attr).Name.Space
|
||||
}
|
||||
|
||||
return tree.String(ret), nil
|
||||
}
|
||||
|
||||
func name(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
var n tree.NodeSet
|
||||
ok := true
|
||||
if len(args) == 1 {
|
||||
n, ok = args[0].(tree.NodeSet)
|
||||
} else {
|
||||
n = c.NodeSet
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a node-set")
|
||||
}
|
||||
|
||||
ret := ""
|
||||
if len(n) == 0 {
|
||||
return tree.String(ret), nil
|
||||
}
|
||||
node := n[0]
|
||||
|
||||
switch node.GetNodeType() {
|
||||
case tree.NtElem:
|
||||
t := node.GetToken().(xml.StartElement)
|
||||
space := ""
|
||||
|
||||
if t.Name.Space != "" {
|
||||
space = fmt.Sprintf("{%s}", t.Name.Space)
|
||||
}
|
||||
|
||||
ret = fmt.Sprintf("%s%s", space, t.Name.Local)
|
||||
case tree.NtAttr:
|
||||
t := node.GetToken().(xml.Attr)
|
||||
space := ""
|
||||
|
||||
if t.Name.Space != "" {
|
||||
space = fmt.Sprintf("{%s}", t.Name.Space)
|
||||
}
|
||||
|
||||
ret = fmt.Sprintf("%s%s", space, t.Name.Local)
|
||||
case tree.NtPi:
|
||||
ret = fmt.Sprintf("%s", node.GetToken().(xml.ProcInst).Target)
|
||||
}
|
||||
|
||||
return tree.String(ret), nil
|
||||
}
|
71
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/numfns.go
generated
vendored
Normal file
71
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/numfns.go
generated
vendored
Normal file
|
@ -0,0 +1,71 @@
|
|||
package intfns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
func number(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
if b, ok := args[0].(tree.IsNum); ok {
|
||||
return b.Num(), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Cannot convert object to a number")
|
||||
}
|
||||
|
||||
func sum(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
n, ok := args[0].(tree.NodeSet)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a node-set")
|
||||
}
|
||||
|
||||
ret := 0.0
|
||||
for _, i := range n {
|
||||
ret += float64(tree.GetNodeNum(i))
|
||||
}
|
||||
|
||||
return tree.Num(ret), nil
|
||||
}
|
||||
|
||||
func floor(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
n, ok := args[0].(tree.IsNum)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a number")
|
||||
}
|
||||
|
||||
return tree.Num(math.Floor(float64(n.Num()))), nil
|
||||
}
|
||||
|
||||
func ceiling(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
n, ok := args[0].(tree.IsNum)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a number")
|
||||
}
|
||||
|
||||
return tree.Num(math.Ceil(float64(n.Num()))), nil
|
||||
}
|
||||
|
||||
func round(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
isn, ok := args[0].(tree.IsNum)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Cannot convert object to a number")
|
||||
}
|
||||
|
||||
n := isn.Num()
|
||||
|
||||
if math.IsNaN(float64(n)) || math.IsInf(float64(n), 0) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if n < -0.5 {
|
||||
n = tree.Num(int(n - 0.5))
|
||||
} else if n > 0.5 {
|
||||
n = tree.Num(int(n + 0.5))
|
||||
} else {
|
||||
n = 0
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
141
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/stringfns.go
generated
vendored
Normal file
141
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/intfns/stringfns.go
generated
vendored
Normal file
|
@ -0,0 +1,141 @@
|
|||
package intfns
|
||||
|
||||
import (
|
||||
"math"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
func _string(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
if len(args) == 1 {
|
||||
return tree.String(args[0].String()), nil
|
||||
}
|
||||
|
||||
return tree.String(c.NodeSet.String()), nil
|
||||
}
|
||||
|
||||
func concat(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
ret := ""
|
||||
|
||||
for _, i := range args {
|
||||
ret += i.String()
|
||||
}
|
||||
|
||||
return tree.String(ret), nil
|
||||
}
|
||||
|
||||
func startsWith(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
return tree.Bool(strings.Index(args[0].String(), args[1].String()) == 0), nil
|
||||
}
|
||||
|
||||
func contains(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
return tree.Bool(strings.Contains(args[0].String(), args[1].String())), nil
|
||||
}
|
||||
|
||||
func substringBefore(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
ind := strings.Index(args[0].String(), args[1].String())
|
||||
if ind == -1 {
|
||||
return tree.String(""), nil
|
||||
}
|
||||
|
||||
return tree.String(args[0].String()[:ind]), nil
|
||||
}
|
||||
|
||||
func substringAfter(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
ind := strings.Index(args[0].String(), args[1].String())
|
||||
if ind == -1 {
|
||||
return tree.String(""), nil
|
||||
}
|
||||
|
||||
return tree.String(args[0].String()[ind+len(args[1].String()):]), nil
|
||||
}
|
||||
|
||||
func substring(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
str := args[0].String()
|
||||
|
||||
bNum, bErr := round(c, args[1])
|
||||
if bErr != nil {
|
||||
return nil, bErr
|
||||
}
|
||||
|
||||
b := bNum.(tree.Num).Num()
|
||||
|
||||
if float64(b-1) >= float64(len(str)) || math.IsNaN(float64(b)) {
|
||||
return tree.String(""), nil
|
||||
}
|
||||
|
||||
if len(args) == 2 {
|
||||
if b <= 1 {
|
||||
b = 1
|
||||
}
|
||||
|
||||
return tree.String(str[int(b)-1:]), nil
|
||||
}
|
||||
|
||||
eNum, eErr := round(c, args[2])
|
||||
if eErr != nil {
|
||||
return nil, eErr
|
||||
}
|
||||
|
||||
e := eNum.(tree.Num).Num()
|
||||
|
||||
if e <= 0 || math.IsNaN(float64(e)) || (math.IsInf(float64(b), 0) && math.IsInf(float64(e), 0)) {
|
||||
return tree.String(""), nil
|
||||
}
|
||||
|
||||
if b <= 1 {
|
||||
e = b + e - 1
|
||||
b = 1
|
||||
}
|
||||
|
||||
if float64(b+e-1) >= float64(len(str)) {
|
||||
e = tree.Num(len(str)) - b + 1
|
||||
}
|
||||
|
||||
return tree.String(str[int(b)-1 : int(b+e)-1]), nil
|
||||
}
|
||||
|
||||
func stringLength(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
var str string
|
||||
if len(args) == 1 {
|
||||
str = args[0].String()
|
||||
} else {
|
||||
str = c.NodeSet.String()
|
||||
}
|
||||
|
||||
return tree.Num(len(str)), nil
|
||||
}
|
||||
|
||||
var spaceTrim = regexp.MustCompile(`\s+`)
|
||||
|
||||
func normalizeSpace(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
var str string
|
||||
if len(args) == 1 {
|
||||
str = args[0].String()
|
||||
} else {
|
||||
str = c.NodeSet.String()
|
||||
}
|
||||
|
||||
str = strings.TrimSpace(str)
|
||||
|
||||
return tree.String(spaceTrim.ReplaceAllString(str, " ")), nil
|
||||
}
|
||||
|
||||
func translate(c tree.Ctx, args ...tree.Result) (tree.Result, error) {
|
||||
ret := args[0].String()
|
||||
src := args[1].String()
|
||||
repl := args[2].String()
|
||||
|
||||
for i := range src {
|
||||
r := ""
|
||||
if i < len(repl) {
|
||||
r = string(repl[i])
|
||||
}
|
||||
|
||||
ret = strings.Replace(ret, string(src[i]), r, -1)
|
||||
}
|
||||
|
||||
return tree.String(ret), nil
|
||||
}
|
194
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/parser.go
generated
vendored
Normal file
194
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/parser.go
generated
vendored
Normal file
|
@ -0,0 +1,194 @@
|
|||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/internal/lexer"
|
||||
)
|
||||
|
||||
type stateType int
|
||||
|
||||
const (
|
||||
defState stateType = iota
|
||||
xpathState
|
||||
funcState
|
||||
paramState
|
||||
predState
|
||||
parenState
|
||||
)
|
||||
|
||||
type parseStack struct {
|
||||
stack []*Node
|
||||
stateTypes []stateType
|
||||
cur *Node
|
||||
}
|
||||
|
||||
func (p *parseStack) push(t stateType) {
|
||||
p.stack = append(p.stack, p.cur)
|
||||
p.stateTypes = append(p.stateTypes, t)
|
||||
}
|
||||
|
||||
func (p *parseStack) pop() {
|
||||
stackPos := len(p.stack) - 1
|
||||
|
||||
p.cur = p.stack[stackPos]
|
||||
p.stack = p.stack[:stackPos]
|
||||
p.stateTypes = p.stateTypes[:stackPos]
|
||||
}
|
||||
|
||||
func (p *parseStack) curState() stateType {
|
||||
if len(p.stateTypes) == 0 {
|
||||
return defState
|
||||
}
|
||||
return p.stateTypes[len(p.stateTypes)-1]
|
||||
}
|
||||
|
||||
type lexFn func(*parseStack, lexer.XItem)
|
||||
|
||||
var parseMap = map[lexer.XItemType]lexFn{
|
||||
lexer.XItemAbsLocPath: xiXPath,
|
||||
lexer.XItemAbbrAbsLocPath: xiXPath,
|
||||
lexer.XItemAbbrRelLocPath: xiXPath,
|
||||
lexer.XItemRelLocPath: xiXPath,
|
||||
lexer.XItemEndPath: xiEndPath,
|
||||
lexer.XItemAxis: xiXPath,
|
||||
lexer.XItemAbbrAxis: xiXPath,
|
||||
lexer.XItemNCName: xiXPath,
|
||||
lexer.XItemQName: xiXPath,
|
||||
lexer.XItemNodeType: xiXPath,
|
||||
lexer.XItemProcLit: xiXPath,
|
||||
lexer.XItemFunction: xiFunc,
|
||||
lexer.XItemArgument: xiFuncArg,
|
||||
lexer.XItemEndFunction: xiEndFunc,
|
||||
lexer.XItemPredicate: xiPred,
|
||||
lexer.XItemEndPredicate: xiEndPred,
|
||||
lexer.XItemStrLit: xiValue,
|
||||
lexer.XItemNumLit: xiValue,
|
||||
lexer.XItemOperator: xiOp,
|
||||
lexer.XItemVariable: xiValue,
|
||||
}
|
||||
|
||||
var opPrecedence = map[string]int{
|
||||
"|": 1,
|
||||
"*": 2,
|
||||
"div": 2,
|
||||
"mod": 2,
|
||||
"+": 3,
|
||||
"-": 3,
|
||||
"=": 4,
|
||||
"!=": 4,
|
||||
"<": 4,
|
||||
"<=": 4,
|
||||
">": 4,
|
||||
">=": 4,
|
||||
"and": 5,
|
||||
"or": 6,
|
||||
}
|
||||
|
||||
//Parse creates an AST tree for XPath expressions.
|
||||
func Parse(xp string) (*Node, error) {
|
||||
var err error
|
||||
c := lexer.Lex(xp)
|
||||
n := &Node{}
|
||||
p := &parseStack{cur: n}
|
||||
|
||||
for next := range c {
|
||||
if next.Typ != lexer.XItemError {
|
||||
parseMap[next.Typ](p, next)
|
||||
} else if err == nil {
|
||||
err = fmt.Errorf(next.Val)
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func xiXPath(p *parseStack, i lexer.XItem) {
|
||||
if p.curState() == xpathState {
|
||||
p.cur.push(i)
|
||||
p.cur = p.cur.next
|
||||
return
|
||||
}
|
||||
|
||||
if p.cur.Val.Typ == lexer.XItemFunction {
|
||||
p.cur.Right = &Node{Val: i, Parent: p.cur}
|
||||
p.cur.next = p.cur.Right
|
||||
p.push(xpathState)
|
||||
p.cur = p.cur.next
|
||||
return
|
||||
}
|
||||
|
||||
p.cur.pushNotEmpty(i)
|
||||
p.push(xpathState)
|
||||
p.cur = p.cur.next
|
||||
}
|
||||
|
||||
func xiEndPath(p *parseStack, i lexer.XItem) {
|
||||
p.pop()
|
||||
}
|
||||
|
||||
func xiFunc(p *parseStack, i lexer.XItem) {
|
||||
p.cur.push(i)
|
||||
p.cur = p.cur.next
|
||||
p.push(funcState)
|
||||
}
|
||||
|
||||
func xiFuncArg(p *parseStack, i lexer.XItem) {
|
||||
if p.curState() != funcState {
|
||||
p.pop()
|
||||
}
|
||||
|
||||
p.cur.push(i)
|
||||
p.cur = p.cur.next
|
||||
p.push(paramState)
|
||||
p.cur.push(lexer.XItem{Typ: Empty, Val: ""})
|
||||
p.cur = p.cur.next
|
||||
}
|
||||
|
||||
func xiEndFunc(p *parseStack, i lexer.XItem) {
|
||||
if p.curState() == paramState {
|
||||
p.pop()
|
||||
}
|
||||
p.pop()
|
||||
}
|
||||
|
||||
func xiPred(p *parseStack, i lexer.XItem) {
|
||||
p.cur.push(i)
|
||||
p.cur = p.cur.next
|
||||
p.push(predState)
|
||||
p.cur.push(lexer.XItem{Typ: Empty, Val: ""})
|
||||
p.cur = p.cur.next
|
||||
}
|
||||
|
||||
func xiEndPred(p *parseStack, i lexer.XItem) {
|
||||
p.pop()
|
||||
}
|
||||
|
||||
func xiValue(p *parseStack, i lexer.XItem) {
|
||||
p.cur.add(i)
|
||||
}
|
||||
|
||||
func xiOp(p *parseStack, i lexer.XItem) {
|
||||
if i.Val == "(" {
|
||||
p.cur.push(lexer.XItem{Typ: Empty, Val: ""})
|
||||
p.push(parenState)
|
||||
p.cur = p.cur.next
|
||||
return
|
||||
}
|
||||
|
||||
if i.Val == ")" {
|
||||
p.pop()
|
||||
return
|
||||
}
|
||||
|
||||
if p.cur.Val.Typ == lexer.XItemOperator {
|
||||
if opPrecedence[p.cur.Val.Val] <= opPrecedence[i.Val] {
|
||||
p.cur.add(i)
|
||||
} else {
|
||||
p.cur.push(i)
|
||||
}
|
||||
} else {
|
||||
p.cur.add(i)
|
||||
}
|
||||
p.cur = p.cur.next
|
||||
}
|
11
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/pathexpr/pathexpr.go
generated
vendored
Normal file
11
vendor/github.com/ChrisTrenkamp/goxpath/internal/parser/pathexpr/pathexpr.go
generated
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
package pathexpr
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
//PathExpr represents XPath step's. xmltree.XMLTree uses it to find nodes.
|
||||
type PathExpr struct {
|
||||
Name xml.Name
|
||||
Axis string
|
||||
NodeType string
|
||||
NS map[string]string
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package xconst
|
||||
|
||||
const (
|
||||
//AxisAncestor represents the "ancestor" axis
|
||||
AxisAncestor = "ancestor"
|
||||
//AxisAncestorOrSelf represents the "ancestor-or-self" axis
|
||||
AxisAncestorOrSelf = "ancestor-or-self"
|
||||
//AxisAttribute represents the "attribute" axis
|
||||
AxisAttribute = "attribute"
|
||||
//AxisChild represents the "child" axis
|
||||
AxisChild = "child"
|
||||
//AxisDescendent represents the "descendant" axis
|
||||
AxisDescendent = "descendant"
|
||||
//AxisDescendentOrSelf represents the "descendant-or-self" axis
|
||||
AxisDescendentOrSelf = "descendant-or-self"
|
||||
//AxisFollowing represents the "following" axis
|
||||
AxisFollowing = "following"
|
||||
//AxisFollowingSibling represents the "following-sibling" axis
|
||||
AxisFollowingSibling = "following-sibling"
|
||||
//AxisNamespace represents the "namespace" axis
|
||||
AxisNamespace = "namespace"
|
||||
//AxisParent represents the "parent" axis
|
||||
AxisParent = "parent"
|
||||
//AxisPreceding represents the "preceding" axis
|
||||
AxisPreceding = "preceding"
|
||||
//AxisPrecedingSibling represents the "preceding-sibling" axis
|
||||
AxisPrecedingSibling = "preceding-sibling"
|
||||
//AxisSelf represents the "self" axis
|
||||
AxisSelf = "self"
|
||||
)
|
||||
|
||||
//AxisNames is all the possible Axis identifiers wrapped in an array for convenience
|
||||
var AxisNames = []string{
|
||||
AxisAncestor,
|
||||
AxisAncestorOrSelf,
|
||||
AxisAttribute,
|
||||
AxisChild,
|
||||
AxisDescendent,
|
||||
AxisDescendentOrSelf,
|
||||
AxisFollowing,
|
||||
AxisFollowingSibling,
|
||||
AxisNamespace,
|
||||
AxisParent,
|
||||
AxisPreceding,
|
||||
AxisPrecedingSibling,
|
||||
AxisSelf,
|
||||
}
|
||||
|
||||
const (
|
||||
//NodeTypeComment represents the "comment" node test
|
||||
NodeTypeComment = "comment"
|
||||
//NodeTypeText represents the "text" node test
|
||||
NodeTypeText = "text"
|
||||
//NodeTypeProcInst represents the "processing-instruction" node test
|
||||
NodeTypeProcInst = "processing-instruction"
|
||||
//NodeTypeNode represents the "node" node test
|
||||
NodeTypeNode = "node"
|
||||
)
|
||||
|
||||
//NodeTypes is all the possible node tests wrapped in an array for convenience
|
||||
var NodeTypes = []string{
|
||||
NodeTypeComment,
|
||||
NodeTypeText,
|
||||
NodeTypeProcInst,
|
||||
NodeTypeNode,
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package xsort
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
type nodeSort []tree.Node
|
||||
|
||||
func (ns nodeSort) Len() int { return len(ns) }
|
||||
func (ns nodeSort) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] }
|
||||
func (ns nodeSort) Less(i, j int) bool {
|
||||
return ns[i].Pos() < ns[j].Pos()
|
||||
}
|
||||
|
||||
//SortNodes sorts the array by the node document order
|
||||
func SortNodes(res []tree.Node) {
|
||||
sort.Sort(nodeSort(res))
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package goxpath
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
//Marshal prints the result tree, r, in XML form to w.
|
||||
func Marshal(n tree.Node, w io.Writer) error {
|
||||
return marshal(n, w)
|
||||
}
|
||||
|
||||
//MarshalStr is like Marhal, but returns a string.
|
||||
func MarshalStr(n tree.Node) (string, error) {
|
||||
ret := bytes.NewBufferString("")
|
||||
err := marshal(n, ret)
|
||||
|
||||
return ret.String(), err
|
||||
}
|
||||
|
||||
func marshal(n tree.Node, w io.Writer) error {
|
||||
e := xml.NewEncoder(w)
|
||||
err := encTok(n, e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e.Flush()
|
||||
}
|
||||
|
||||
func encTok(n tree.Node, e *xml.Encoder) error {
|
||||
switch n.GetNodeType() {
|
||||
case tree.NtAttr:
|
||||
return encAttr(n.GetToken().(xml.Attr), e)
|
||||
case tree.NtElem:
|
||||
return encEle(n.(tree.Elem), e)
|
||||
case tree.NtNs:
|
||||
return encNS(n.GetToken().(xml.Attr), e)
|
||||
case tree.NtRoot:
|
||||
for _, i := range n.(tree.Elem).GetChildren() {
|
||||
err := encTok(i, e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//case tree.NtChd, tree.NtComm, tree.NtPi:
|
||||
return e.EncodeToken(n.GetToken())
|
||||
}
|
||||
|
||||
func encAttr(a xml.Attr, e *xml.Encoder) error {
|
||||
str := a.Name.Local + `="` + a.Value + `"`
|
||||
|
||||
if a.Name.Space != "" {
|
||||
str += ` xmlns="` + a.Name.Space + `"`
|
||||
}
|
||||
|
||||
pi := xml.ProcInst{
|
||||
Target: "attribute",
|
||||
Inst: ([]byte)(str),
|
||||
}
|
||||
|
||||
return e.EncodeToken(pi)
|
||||
}
|
||||
|
||||
func encNS(ns xml.Attr, e *xml.Encoder) error {
|
||||
pi := xml.ProcInst{
|
||||
Target: "namespace",
|
||||
Inst: ([]byte)(ns.Value),
|
||||
}
|
||||
return e.EncodeToken(pi)
|
||||
}
|
||||
|
||||
func encEle(n tree.Elem, e *xml.Encoder) error {
|
||||
ele := xml.StartElement{
|
||||
Name: n.GetToken().(xml.StartElement).Name,
|
||||
}
|
||||
|
||||
attrs := n.GetAttrs()
|
||||
ele.Attr = make([]xml.Attr, len(attrs))
|
||||
for i := range attrs {
|
||||
ele.Attr[i] = attrs[i].GetToken().(xml.Attr)
|
||||
}
|
||||
|
||||
err := e.EncodeToken(ele)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if x, ok := n.(tree.Elem); ok {
|
||||
for _, i := range x.GetChildren() {
|
||||
err := encTok(i, e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return e.EncodeToken(ele.End())
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package tree
|
||||
|
||||
import "fmt"
|
||||
|
||||
//Result is used for all data types.
|
||||
type Result interface {
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
//IsBool is used for the XPath boolean function. It turns the data type to a bool.
|
||||
type IsBool interface {
|
||||
Bool() Bool
|
||||
}
|
||||
|
||||
//IsNum is used for the XPath number function. It turns the data type to a number.
|
||||
type IsNum interface {
|
||||
Num() Num
|
||||
}
|
|
@ -0,0 +1,221 @@
|
|||
package tree
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"sort"
|
||||
)
|
||||
|
||||
//XMLSpace is the W3C XML namespace
|
||||
const XMLSpace = "http://www.w3.org/XML/1998/namespace"
|
||||
|
||||
//NodePos is a helper for representing the node's document order
|
||||
type NodePos int
|
||||
|
||||
//Pos returns the node's document order position
|
||||
func (n NodePos) Pos() int {
|
||||
return int(n)
|
||||
}
|
||||
|
||||
//NodeType is a safer way to determine a node's type than type assertions.
|
||||
type NodeType int
|
||||
|
||||
//GetNodeType returns the node's type.
|
||||
func (t NodeType) GetNodeType() NodeType {
|
||||
return t
|
||||
}
|
||||
|
||||
//These are all the possible node types
|
||||
const (
|
||||
NtAttr NodeType = iota
|
||||
NtChd
|
||||
NtComm
|
||||
NtElem
|
||||
NtNs
|
||||
NtRoot
|
||||
NtPi
|
||||
)
|
||||
|
||||
//Node is a XPath result that is a node except elements
|
||||
type Node interface {
|
||||
//ResValue prints the node's string value
|
||||
ResValue() string
|
||||
//Pos returns the node's position in the document order
|
||||
Pos() int
|
||||
//GetToken returns the xml.Token representation of the node
|
||||
GetToken() xml.Token
|
||||
//GetParent returns the parent node, which will always be an XML element
|
||||
GetParent() Elem
|
||||
//GetNodeType returns the node's type
|
||||
GetNodeType() NodeType
|
||||
}
|
||||
|
||||
//Elem is a XPath result that is an element node
|
||||
type Elem interface {
|
||||
Node
|
||||
//GetChildren returns the elements children.
|
||||
GetChildren() []Node
|
||||
//GetAttrs returns the attributes of the element
|
||||
GetAttrs() []Node
|
||||
}
|
||||
|
||||
//NSElem is a node that keeps track of namespaces.
|
||||
type NSElem interface {
|
||||
Elem
|
||||
GetNS() map[xml.Name]string
|
||||
}
|
||||
|
||||
//NSBuilder is a helper-struct for satisfying the NSElem interface
|
||||
type NSBuilder struct {
|
||||
NS map[xml.Name]string
|
||||
}
|
||||
|
||||
//GetNS returns the namespaces found on the current element. It should not be
|
||||
//confused with BuildNS, which actually resolves the namespace nodes.
|
||||
func (ns NSBuilder) GetNS() map[xml.Name]string {
|
||||
return ns.NS
|
||||
}
|
||||
|
||||
type nsValueSort []NS
|
||||
|
||||
func (ns nsValueSort) Len() int { return len(ns) }
|
||||
func (ns nsValueSort) Swap(i, j int) {
|
||||
ns[i], ns[j] = ns[j], ns[i]
|
||||
}
|
||||
func (ns nsValueSort) Less(i, j int) bool {
|
||||
return ns[i].Value < ns[j].Value
|
||||
}
|
||||
|
||||
//BuildNS resolves all the namespace nodes of the element and returns them
|
||||
func BuildNS(t Elem) (ret []NS) {
|
||||
vals := make(map[xml.Name]string)
|
||||
|
||||
if nselem, ok := t.(NSElem); ok {
|
||||
buildNS(nselem, vals)
|
||||
|
||||
ret = make([]NS, 0, len(vals))
|
||||
i := 1
|
||||
|
||||
for k, v := range vals {
|
||||
if !(k.Local == "xmlns" && k.Space == "" && v == "") {
|
||||
ret = append(ret, NS{
|
||||
Attr: xml.Attr{Name: k, Value: v},
|
||||
Parent: t,
|
||||
NodeType: NtNs,
|
||||
})
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(nsValueSort(ret))
|
||||
for i := range ret {
|
||||
ret[i].NodePos = NodePos(t.Pos() + i + 1)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func buildNS(x NSElem, ret map[xml.Name]string) {
|
||||
if x.GetNodeType() == NtRoot {
|
||||
return
|
||||
}
|
||||
|
||||
if nselem, ok := x.GetParent().(NSElem); ok {
|
||||
buildNS(nselem, ret)
|
||||
}
|
||||
|
||||
for k, v := range x.GetNS() {
|
||||
ret[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
//NS is a namespace node.
|
||||
type NS struct {
|
||||
xml.Attr
|
||||
Parent Elem
|
||||
NodePos
|
||||
NodeType
|
||||
}
|
||||
|
||||
//GetToken returns the xml.Token representation of the namespace.
|
||||
func (ns NS) GetToken() xml.Token {
|
||||
return ns.Attr
|
||||
}
|
||||
|
||||
//GetParent returns the parent node of the namespace.
|
||||
func (ns NS) GetParent() Elem {
|
||||
return ns.Parent
|
||||
}
|
||||
|
||||
//ResValue returns the string value of the namespace
|
||||
func (ns NS) ResValue() string {
|
||||
return ns.Attr.Value
|
||||
}
|
||||
|
||||
//GetAttribute is a convenience function for getting the specified attribute from an element.
|
||||
//false is returned if the attribute is not found.
|
||||
func GetAttribute(n Elem, local, space string) (xml.Attr, bool) {
|
||||
attrs := n.GetAttrs()
|
||||
for _, i := range attrs {
|
||||
attr := i.GetToken().(xml.Attr)
|
||||
if local == attr.Name.Local && space == attr.Name.Space {
|
||||
return attr, true
|
||||
}
|
||||
}
|
||||
return xml.Attr{}, false
|
||||
}
|
||||
|
||||
//GetAttributeVal is like GetAttribute, except it returns the attribute's value.
|
||||
func GetAttributeVal(n Elem, local, space string) (string, bool) {
|
||||
attr, ok := GetAttribute(n, local, space)
|
||||
return attr.Value, ok
|
||||
}
|
||||
|
||||
//GetAttrValOrEmpty is like GetAttributeVal, except it returns an empty string if
|
||||
//the attribute is not found instead of false.
|
||||
func GetAttrValOrEmpty(n Elem, local, space string) string {
|
||||
val, ok := GetAttributeVal(n, local, space)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
//FindNodeByPos finds a node from the given position. Returns nil if the node
|
||||
//is not found.
|
||||
func FindNodeByPos(n Node, pos int) Node {
|
||||
if n.Pos() == pos {
|
||||
return n
|
||||
}
|
||||
|
||||
if elem, ok := n.(Elem); ok {
|
||||
chldrn := elem.GetChildren()
|
||||
for i := 1; i < len(chldrn); i++ {
|
||||
if chldrn[i-1].Pos() <= pos && chldrn[i].Pos() > pos {
|
||||
return FindNodeByPos(chldrn[i-1], pos)
|
||||
}
|
||||
}
|
||||
|
||||
if len(chldrn) > 0 {
|
||||
if chldrn[len(chldrn)-1].Pos() <= pos {
|
||||
return FindNodeByPos(chldrn[len(chldrn)-1], pos)
|
||||
}
|
||||
}
|
||||
|
||||
attrs := elem.GetAttrs()
|
||||
for _, i := range attrs {
|
||||
if i.Pos() == pos {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
ns := BuildNS(elem)
|
||||
for _, i := range ns {
|
||||
if i.Pos() == pos {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package tree
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
//Ctx represents the current context position, size, node, and the current filtered result
|
||||
type Ctx struct {
|
||||
NodeSet
|
||||
Pos int
|
||||
Size int
|
||||
}
|
||||
|
||||
//Fn is a XPath function, written in Go
|
||||
type Fn func(c Ctx, args ...Result) (Result, error)
|
||||
|
||||
//LastArgOpt sets whether the last argument in a function is optional, variadic, or neither
|
||||
type LastArgOpt int
|
||||
|
||||
//LastArgOpt options
|
||||
const (
|
||||
None LastArgOpt = iota
|
||||
Optional
|
||||
Variadic
|
||||
)
|
||||
|
||||
//Wrap interfaces XPath function calls with Go
|
||||
type Wrap struct {
|
||||
Fn Fn
|
||||
//NArgs represents the number of arguments to the XPath function. -1 represents a single optional argument
|
||||
NArgs int
|
||||
LastArgOpt LastArgOpt
|
||||
}
|
||||
|
||||
//Call checks the arguments and calls Fn if they are valid
|
||||
func (w Wrap) Call(c Ctx, args ...Result) (Result, error) {
|
||||
switch w.LastArgOpt {
|
||||
case Optional:
|
||||
if len(args) == w.NArgs || len(args) == w.NArgs-1 {
|
||||
return w.Fn(c, args...)
|
||||
}
|
||||
case Variadic:
|
||||
if len(args) >= w.NArgs-1 {
|
||||
return w.Fn(c, args...)
|
||||
}
|
||||
default:
|
||||
if len(args) == w.NArgs {
|
||||
return w.Fn(c, args...)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Invalid number of arguments")
|
||||
}
|
25
vendor/github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlbuilder/xmlbuilder.go
generated
vendored
Normal file
25
vendor/github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlbuilder/xmlbuilder.go
generated
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
package xmlbuilder
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
//BuilderOpts supplies all the information needed to create an XML node.
|
||||
type BuilderOpts struct {
|
||||
Dec *xml.Decoder
|
||||
Tok xml.Token
|
||||
NodeType tree.NodeType
|
||||
NS map[xml.Name]string
|
||||
Attrs []*xml.Attr
|
||||
NodePos int
|
||||
AttrStartPos int
|
||||
}
|
||||
|
||||
//XMLBuilder is an interface for creating XML structures.
|
||||
type XMLBuilder interface {
|
||||
tree.Node
|
||||
CreateNode(*BuilderOpts) XMLBuilder
|
||||
EndElem() XMLBuilder
|
||||
}
|
106
vendor/github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlele/xmlele.go
generated
vendored
Normal file
106
vendor/github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlele/xmlele.go
generated
vendored
Normal file
|
@ -0,0 +1,106 @@
|
|||
package xmlele
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlbuilder"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlnode"
|
||||
)
|
||||
|
||||
//XMLEle is an implementation of XPRes for XML elements
|
||||
type XMLEle struct {
|
||||
xml.StartElement
|
||||
tree.NSBuilder
|
||||
Attrs []tree.Node
|
||||
Children []tree.Node
|
||||
Parent tree.Elem
|
||||
tree.NodePos
|
||||
tree.NodeType
|
||||
}
|
||||
|
||||
//Root is the default root node builder for xmltree.ParseXML
|
||||
func Root() xmlbuilder.XMLBuilder {
|
||||
return &XMLEle{NodeType: tree.NtRoot}
|
||||
}
|
||||
|
||||
//CreateNode is an implementation of xmlbuilder.XMLBuilder. It appends the node
|
||||
//specified in opts and returns the child if it is an element. Otherwise, it returns x.
|
||||
func (x *XMLEle) CreateNode(opts *xmlbuilder.BuilderOpts) xmlbuilder.XMLBuilder {
|
||||
if opts.NodeType == tree.NtElem {
|
||||
ele := &XMLEle{
|
||||
StartElement: opts.Tok.(xml.StartElement),
|
||||
NSBuilder: tree.NSBuilder{NS: opts.NS},
|
||||
Attrs: make([]tree.Node, len(opts.Attrs)),
|
||||
Parent: x,
|
||||
NodePos: tree.NodePos(opts.NodePos),
|
||||
NodeType: opts.NodeType,
|
||||
}
|
||||
for i := range opts.Attrs {
|
||||
ele.Attrs[i] = xmlnode.XMLNode{
|
||||
Token: opts.Attrs[i],
|
||||
NodePos: tree.NodePos(opts.AttrStartPos + i),
|
||||
NodeType: tree.NtAttr,
|
||||
Parent: ele,
|
||||
}
|
||||
}
|
||||
x.Children = append(x.Children, ele)
|
||||
return ele
|
||||
}
|
||||
|
||||
node := xmlnode.XMLNode{
|
||||
Token: opts.Tok,
|
||||
NodePos: tree.NodePos(opts.NodePos),
|
||||
NodeType: opts.NodeType,
|
||||
Parent: x,
|
||||
}
|
||||
x.Children = append(x.Children, node)
|
||||
return x
|
||||
}
|
||||
|
||||
//EndElem is an implementation of xmlbuilder.XMLBuilder. It returns x's parent.
|
||||
func (x *XMLEle) EndElem() xmlbuilder.XMLBuilder {
|
||||
return x.Parent.(*XMLEle)
|
||||
}
|
||||
|
||||
//GetToken returns the xml.Token representation of the node
|
||||
func (x *XMLEle) GetToken() xml.Token {
|
||||
return x.StartElement
|
||||
}
|
||||
|
||||
//GetParent returns the parent node, or itself if it's the root
|
||||
func (x *XMLEle) GetParent() tree.Elem {
|
||||
return x.Parent
|
||||
}
|
||||
|
||||
//GetChildren returns all child nodes of the element
|
||||
func (x *XMLEle) GetChildren() []tree.Node {
|
||||
ret := make([]tree.Node, len(x.Children))
|
||||
|
||||
for i := range x.Children {
|
||||
ret[i] = x.Children[i]
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
//GetAttrs returns all attributes of the element
|
||||
func (x *XMLEle) GetAttrs() []tree.Node {
|
||||
ret := make([]tree.Node, len(x.Attrs))
|
||||
for i := range x.Attrs {
|
||||
ret[i] = x.Attrs[i]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
//ResValue returns the string value of the element and children
|
||||
func (x *XMLEle) ResValue() string {
|
||||
ret := ""
|
||||
for i := range x.Children {
|
||||
switch x.Children[i].GetNodeType() {
|
||||
case tree.NtChd, tree.NtElem, tree.NtRoot:
|
||||
ret += x.Children[i].ResValue()
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
43
vendor/github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlnode/xmlnode.go
generated
vendored
Normal file
43
vendor/github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlnode/xmlnode.go
generated
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
package xmlnode
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
)
|
||||
|
||||
//XMLNode will represent an attribute, character data, comment, or processing instruction node
|
||||
type XMLNode struct {
|
||||
xml.Token
|
||||
tree.NodePos
|
||||
tree.NodeType
|
||||
Parent tree.Elem
|
||||
}
|
||||
|
||||
//GetToken returns the xml.Token representation of the node
|
||||
func (a XMLNode) GetToken() xml.Token {
|
||||
if a.NodeType == tree.NtAttr {
|
||||
ret := a.Token.(*xml.Attr)
|
||||
return *ret
|
||||
}
|
||||
return a.Token
|
||||
}
|
||||
|
||||
//GetParent returns the parent node
|
||||
func (a XMLNode) GetParent() tree.Elem {
|
||||
return a.Parent
|
||||
}
|
||||
|
||||
//ResValue returns the string value of the attribute
|
||||
func (a XMLNode) ResValue() string {
|
||||
switch a.NodeType {
|
||||
case tree.NtAttr:
|
||||
return a.Token.(*xml.Attr).Value
|
||||
case tree.NtChd:
|
||||
return string(a.Token.(xml.CharData))
|
||||
case tree.NtComm:
|
||||
return string(a.Token.(xml.Comment))
|
||||
}
|
||||
//case tree.NtPi:
|
||||
return string(a.Token.(xml.ProcInst).Inst)
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
package xmltree
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"io"
|
||||
|
||||
"golang.org/x/net/html/charset"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlbuilder"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree/xmltree/xmlele"
|
||||
)
|
||||
|
||||
//ParseOptions is a set of methods and function pointers that alter
|
||||
//the way the XML decoder works and the Node types that are created.
|
||||
//Options that are not set will default to what is set in internal/defoverride.go
|
||||
type ParseOptions struct {
|
||||
Strict bool
|
||||
XMLRoot func() xmlbuilder.XMLBuilder
|
||||
}
|
||||
|
||||
//DirectiveParser is an optional interface extended from XMLBuilder that handles
|
||||
//XML directives.
|
||||
type DirectiveParser interface {
|
||||
xmlbuilder.XMLBuilder
|
||||
Directive(xml.Directive, *xml.Decoder)
|
||||
}
|
||||
|
||||
//ParseSettings is a function for setting the ParseOptions you want when
|
||||
//parsing an XML tree.
|
||||
type ParseSettings func(s *ParseOptions)
|
||||
|
||||
//MustParseXML is like ParseXML, but panics instead of returning an error.
|
||||
func MustParseXML(r io.Reader, op ...ParseSettings) tree.Node {
|
||||
ret, err := ParseXML(r, op...)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
//ParseXML creates an XMLTree structure from an io.Reader.
|
||||
func ParseXML(r io.Reader, op ...ParseSettings) (tree.Node, error) {
|
||||
ov := ParseOptions{
|
||||
Strict: true,
|
||||
XMLRoot: xmlele.Root,
|
||||
}
|
||||
for _, i := range op {
|
||||
i(&ov)
|
||||
}
|
||||
|
||||
dec := xml.NewDecoder(r)
|
||||
dec.CharsetReader = charset.NewReaderLabel
|
||||
dec.Strict = ov.Strict
|
||||
|
||||
ordrPos := 1
|
||||
xmlTree := ov.XMLRoot()
|
||||
|
||||
t, err := dec.Token()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if head, ok := t.(xml.ProcInst); ok && head.Target == "xml" {
|
||||
t, err = dec.Token()
|
||||
}
|
||||
|
||||
opts := xmlbuilder.BuilderOpts{
|
||||
Dec: dec,
|
||||
}
|
||||
|
||||
for err == nil {
|
||||
switch xt := t.(type) {
|
||||
case xml.StartElement:
|
||||
setEle(&opts, xmlTree, xt, &ordrPos)
|
||||
xmlTree = xmlTree.CreateNode(&opts)
|
||||
case xml.CharData:
|
||||
setNode(&opts, xmlTree, xt, tree.NtChd, &ordrPos)
|
||||
xmlTree = xmlTree.CreateNode(&opts)
|
||||
case xml.Comment:
|
||||
setNode(&opts, xmlTree, xt, tree.NtComm, &ordrPos)
|
||||
xmlTree = xmlTree.CreateNode(&opts)
|
||||
case xml.ProcInst:
|
||||
setNode(&opts, xmlTree, xt, tree.NtPi, &ordrPos)
|
||||
xmlTree = xmlTree.CreateNode(&opts)
|
||||
case xml.EndElement:
|
||||
xmlTree = xmlTree.EndElem()
|
||||
case xml.Directive:
|
||||
if dp, ok := xmlTree.(DirectiveParser); ok {
|
||||
dp.Directive(xt.Copy(), dec)
|
||||
}
|
||||
}
|
||||
|
||||
t, err = dec.Token()
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
|
||||
return xmlTree, err
|
||||
}
|
||||
|
||||
func setEle(opts *xmlbuilder.BuilderOpts, xmlTree xmlbuilder.XMLBuilder, ele xml.StartElement, ordrPos *int) {
|
||||
opts.NodePos = *ordrPos
|
||||
opts.Tok = ele
|
||||
opts.Attrs = opts.Attrs[0:0:cap(opts.Attrs)]
|
||||
opts.NS = make(map[xml.Name]string)
|
||||
opts.NodeType = tree.NtElem
|
||||
|
||||
for i := range ele.Attr {
|
||||
attr := ele.Attr[i].Name
|
||||
val := ele.Attr[i].Value
|
||||
|
||||
if (attr.Local == "xmlns" && attr.Space == "") || attr.Space == "xmlns" {
|
||||
opts.NS[attr] = val
|
||||
} else {
|
||||
opts.Attrs = append(opts.Attrs, &ele.Attr[i])
|
||||
}
|
||||
}
|
||||
|
||||
if nstree, ok := xmlTree.(tree.NSElem); ok {
|
||||
ns := make(map[xml.Name]string)
|
||||
|
||||
for _, i := range tree.BuildNS(nstree) {
|
||||
ns[i.Name] = i.Value
|
||||
}
|
||||
|
||||
for k, v := range opts.NS {
|
||||
ns[k] = v
|
||||
}
|
||||
|
||||
if ns[xml.Name{Local: "xmlns"}] == "" {
|
||||
delete(ns, xml.Name{Local: "xmlns"})
|
||||
}
|
||||
|
||||
for k, v := range ns {
|
||||
opts.NS[k] = v
|
||||
}
|
||||
|
||||
if xmlTree.GetNodeType() == tree.NtRoot {
|
||||
opts.NS[xml.Name{Space: "xmlns", Local: "xml"}] = tree.XMLSpace
|
||||
}
|
||||
}
|
||||
|
||||
opts.AttrStartPos = len(opts.NS) + len(opts.Attrs) + *ordrPos
|
||||
*ordrPos = opts.AttrStartPos + 1
|
||||
}
|
||||
|
||||
func setNode(opts *xmlbuilder.BuilderOpts, xmlTree xmlbuilder.XMLBuilder, tok xml.Token, nt tree.NodeType, ordrPos *int) {
|
||||
opts.Tok = xml.CopyToken(tok)
|
||||
opts.NodeType = nt
|
||||
opts.NodePos = *ordrPos
|
||||
*ordrPos++
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
package tree
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//Boolean strings
|
||||
const (
|
||||
True = "true"
|
||||
False = "false"
|
||||
)
|
||||
|
||||
//Bool is a boolean XPath type
|
||||
type Bool bool
|
||||
|
||||
//ResValue satisfies the Res interface for Bool
|
||||
func (b Bool) String() string {
|
||||
if b {
|
||||
return True
|
||||
}
|
||||
|
||||
return False
|
||||
}
|
||||
|
||||
//Bool satisfies the HasBool interface for Bool's
|
||||
func (b Bool) Bool() Bool {
|
||||
return b
|
||||
}
|
||||
|
||||
//Num satisfies the HasNum interface for Bool's
|
||||
func (b Bool) Num() Num {
|
||||
if b {
|
||||
return Num(1)
|
||||
}
|
||||
|
||||
return Num(0)
|
||||
}
|
||||
|
||||
//Num is a number XPath type
|
||||
type Num float64
|
||||
|
||||
//ResValue satisfies the Res interface for Num
|
||||
func (n Num) String() string {
|
||||
if math.IsInf(float64(n), 0) {
|
||||
if math.IsInf(float64(n), 1) {
|
||||
return "Infinity"
|
||||
}
|
||||
return "-Infinity"
|
||||
}
|
||||
return fmt.Sprintf("%g", float64(n))
|
||||
}
|
||||
|
||||
//Bool satisfies the HasBool interface for Num's
|
||||
func (n Num) Bool() Bool {
|
||||
return n != 0
|
||||
}
|
||||
|
||||
//Num satisfies the HasNum interface for Num's
|
||||
func (n Num) Num() Num {
|
||||
return n
|
||||
}
|
||||
|
||||
//String is string XPath type
|
||||
type String string
|
||||
|
||||
//ResValue satisfies the Res interface for String
|
||||
func (s String) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
//Bool satisfies the HasBool interface for String's
|
||||
func (s String) Bool() Bool {
|
||||
return Bool(len(s) > 0)
|
||||
}
|
||||
|
||||
//Num satisfies the HasNum interface for String's
|
||||
func (s String) Num() Num {
|
||||
num, err := strconv.ParseFloat(strings.TrimSpace(string(s)), 64)
|
||||
if err != nil {
|
||||
return Num(math.NaN())
|
||||
}
|
||||
return Num(num)
|
||||
}
|
||||
|
||||
//NodeSet is a node-set XPath type
|
||||
type NodeSet []Node
|
||||
|
||||
//GetNodeNum converts the node to a string-value and to a number
|
||||
func GetNodeNum(n Node) Num {
|
||||
return String(n.ResValue()).Num()
|
||||
}
|
||||
|
||||
//String satisfies the Res interface for NodeSet
|
||||
func (n NodeSet) String() string {
|
||||
if len(n) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return n[0].ResValue()
|
||||
}
|
||||
|
||||
//Bool satisfies the HasBool interface for node-set's
|
||||
func (n NodeSet) Bool() Bool {
|
||||
return Bool(len(n) > 0)
|
||||
}
|
||||
|
||||
//Num satisfies the HasNum interface for NodeSet's
|
||||
func (n NodeSet) Num() Num {
|
||||
return String(n.String()).Num()
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,119 @@
|
|||
XPath
|
||||
====
|
||||
[![GoDoc](https://godoc.org/github.com/antchfx/xpath?status.svg)](https://godoc.org/github.com/antchfx/xpath)
|
||||
[![Coverage Status](https://coveralls.io/repos/github/antchfx/xpath/badge.svg?branch=master)](https://coveralls.io/github/antchfx/xpath?branch=master)
|
||||
[![Build Status](https://travis-ci.org/antchfx/xpath.svg?branch=master)](https://travis-ci.org/antchfx/xpath)
|
||||
[![Go Report Card](https://goreportcard.com/badge/github.com/antchfx/xpath)](https://goreportcard.com/report/github.com/antchfx/xpath)
|
||||
|
||||
XPath is Go package provides selecting nodes from XML, HTML or other documents using XPath expression.
|
||||
|
||||
[XQuery](https://github.com/antchfx/xquery) : lets you extract data from HTML/XML documents using XPath package.
|
||||
|
||||
### Features
|
||||
|
||||
#### The basic XPath patterns.
|
||||
|
||||
> The basic XPath patterns cover 90% of the cases that most stylesheets will need.
|
||||
|
||||
- `node` : Selects all child elements with nodeName of node.
|
||||
|
||||
- `*` : Selects all child elements.
|
||||
|
||||
- `@attr` : Selects the attribute attr.
|
||||
|
||||
- `@*` : Selects all attributes.
|
||||
|
||||
- `node()` : Matches an org.w3c.dom.Node.
|
||||
|
||||
- `text()` : Matches a org.w3c.dom.Text node.
|
||||
|
||||
- `comment()` : Matches a comment.
|
||||
|
||||
- `.` : Selects the current node.
|
||||
|
||||
- `..` : Selects the parent of current node.
|
||||
|
||||
- `/` : Selects the document node.
|
||||
|
||||
- `a[expr]` : Select only those nodes matching a which also satisfy the expression expr.
|
||||
|
||||
- `a[n]` : Selects the nth matching node matching a When a filter's expression is a number, XPath selects based on position.
|
||||
|
||||
- `a/b` : For each node matching a, add the nodes matching b to the result.
|
||||
|
||||
- `a//b` : For each node matching a, add the descendant nodes matching b to the result.
|
||||
|
||||
- `//b` : Returns elements in the entire document matching b.
|
||||
|
||||
- `a|b` : All nodes matching a or b.
|
||||
|
||||
#### Node Axes
|
||||
|
||||
- `child::*` : The child axis selects children of the current node.
|
||||
|
||||
- `descendant::*` : The descendant axis selects descendants of the current node. It is equivalent to '//'.
|
||||
|
||||
- `descendant-or-self::*` : Selects descendants including the current node.
|
||||
|
||||
- `attribute::*` : Selects attributes of the current element. It is equivalent to @*
|
||||
|
||||
- `following-sibling::*` : Selects nodes after the current node.
|
||||
|
||||
- `preceding-sibling::*` : Selects nodes before the current node.
|
||||
|
||||
- `following::*` : Selects the first matching node following in document order, excluding descendants.
|
||||
|
||||
- `preceding::*` : Selects the first matching node preceding in document order, excluding ancestors.
|
||||
|
||||
- `parent::*` : Selects the parent if it matches. The '..' pattern from the core is equivalent to 'parent::node()'.
|
||||
|
||||
- `ancestor::*` : Selects matching ancestors.
|
||||
|
||||
- `ancestor-or-self::*` : Selects ancestors including the current node.
|
||||
|
||||
- `self::*` : Selects the current node. '.' is equivalent to 'self::node()'.
|
||||
|
||||
#### Expressions
|
||||
|
||||
The gxpath supported three types: number, boolean, string.
|
||||
|
||||
- `path` : Selects nodes based on the path.
|
||||
|
||||
- `a = b` : Standard comparisons.
|
||||
|
||||
* a = b True if a equals b.
|
||||
* a != b True if a is not equal to b.
|
||||
* a < b True if a is less than b.
|
||||
* a <= b True if a is less than or equal to b.
|
||||
* a > b True if a is greater than b.
|
||||
* a >= b True if a is greater than or equal to b.
|
||||
|
||||
- `a + b` : Arithmetic expressions.
|
||||
|
||||
* `- a` Unary minus
|
||||
* a + b Add
|
||||
* a - b Substract
|
||||
* a * b Multiply
|
||||
* a div b Divide
|
||||
* a mod b Floating point mod, like Java.
|
||||
|
||||
- `(expr)` : Parenthesized expressions.
|
||||
|
||||
- `fun(arg1, ..., argn)` : Function calls.
|
||||
|
||||
* position()
|
||||
* last()
|
||||
* count( node-set )
|
||||
* name()
|
||||
* starts-with( string, string )
|
||||
* normalize-space( string )
|
||||
* substring( string , start [, length] )
|
||||
* not( expression )
|
||||
* string-length( [string] )
|
||||
* contains( string, string )
|
||||
* sum( node-set )
|
||||
* concat( string1 , string2 [, stringn]* )
|
||||
|
||||
- `a or b` : Boolean or.
|
||||
|
||||
- `a and b` : Boolean and.
|
|
@ -0,0 +1,359 @@
|
|||
package xpath
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type flag int
|
||||
|
||||
const (
|
||||
noneFlag flag = iota
|
||||
filterFlag
|
||||
)
|
||||
|
||||
// builder provides building an XPath expressions.
|
||||
type builder struct {
|
||||
depth int
|
||||
flag flag
|
||||
firstInput query
|
||||
}
|
||||
|
||||
// axisPredicate creates a predicate to predicating for this axis node.
|
||||
func axisPredicate(root *axisNode) func(NodeNavigator) bool {
|
||||
// get current axix node type.
|
||||
typ := ElementNode
|
||||
if root.AxeType == "attribute" {
|
||||
typ = AttributeNode
|
||||
} else {
|
||||
switch root.Prop {
|
||||
case "comment":
|
||||
typ = CommentNode
|
||||
case "text":
|
||||
typ = TextNode
|
||||
// case "processing-instruction":
|
||||
// typ = ProcessingInstructionNode
|
||||
case "node":
|
||||
typ = ElementNode
|
||||
}
|
||||
}
|
||||
predicate := func(n NodeNavigator) bool {
|
||||
if typ == n.NodeType() || typ == TextNode {
|
||||
if root.LocalName == "" || (root.LocalName == n.LocalName() && root.Prefix == n.Prefix()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return predicate
|
||||
}
|
||||
|
||||
// processAxisNode processes a query for the XPath axis node.
|
||||
func (b *builder) processAxisNode(root *axisNode) (query, error) {
|
||||
var (
|
||||
err error
|
||||
qyInput query
|
||||
qyOutput query
|
||||
predicate = axisPredicate(root)
|
||||
)
|
||||
|
||||
if root.Input == nil {
|
||||
qyInput = &contextQuery{}
|
||||
} else {
|
||||
if b.flag&filterFlag == 0 {
|
||||
if root.AxeType == "child" && (root.Input.Type() == nodeAxis) {
|
||||
if input := root.Input.(*axisNode); input.AxeType == "descendant-or-self" {
|
||||
var qyGrandInput query
|
||||
if input.Input != nil {
|
||||
qyGrandInput, _ = b.processNode(input.Input)
|
||||
} else {
|
||||
qyGrandInput = &contextQuery{}
|
||||
}
|
||||
qyOutput = &descendantQuery{Input: qyGrandInput, Predicate: predicate, Self: true}
|
||||
return qyOutput, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
qyInput, err = b.processNode(root.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
switch root.AxeType {
|
||||
case "ancestor":
|
||||
qyOutput = &ancestorQuery{Input: qyInput, Predicate: predicate}
|
||||
case "ancestor-or-self":
|
||||
qyOutput = &ancestorQuery{Input: qyInput, Predicate: predicate, Self: true}
|
||||
case "attribute":
|
||||
qyOutput = &attributeQuery{Input: qyInput, Predicate: predicate}
|
||||
case "child":
|
||||
filter := func(n NodeNavigator) bool {
|
||||
v := predicate(n)
|
||||
switch root.Prop {
|
||||
case "text":
|
||||
v = v && n.NodeType() == TextNode
|
||||
case "node":
|
||||
v = v && (n.NodeType() == ElementNode || n.NodeType() == TextNode)
|
||||
case "comment":
|
||||
v = v && n.NodeType() == CommentNode
|
||||
}
|
||||
return v
|
||||
}
|
||||
qyOutput = &childQuery{Input: qyInput, Predicate: filter}
|
||||
case "descendant":
|
||||
qyOutput = &descendantQuery{Input: qyInput, Predicate: predicate}
|
||||
case "descendant-or-self":
|
||||
qyOutput = &descendantQuery{Input: qyInput, Predicate: predicate, Self: true}
|
||||
case "following":
|
||||
qyOutput = &followingQuery{Input: qyInput, Predicate: predicate}
|
||||
case "following-sibling":
|
||||
qyOutput = &followingQuery{Input: qyInput, Predicate: predicate, Sibling: true}
|
||||
case "parent":
|
||||
qyOutput = &parentQuery{Input: qyInput, Predicate: predicate}
|
||||
case "preceding":
|
||||
qyOutput = &precedingQuery{Input: qyInput, Predicate: predicate}
|
||||
case "preceding-sibling":
|
||||
qyOutput = &precedingQuery{Input: qyInput, Predicate: predicate, Sibling: true}
|
||||
case "self":
|
||||
qyOutput = &selfQuery{Input: qyInput, Predicate: predicate}
|
||||
case "namespace":
|
||||
// haha,what will you do someting??
|
||||
default:
|
||||
err = fmt.Errorf("unknown axe type: %s", root.AxeType)
|
||||
return nil, err
|
||||
}
|
||||
return qyOutput, nil
|
||||
}
|
||||
|
||||
// processFilterNode builds query for the XPath filter predicate.
|
||||
func (b *builder) processFilterNode(root *filterNode) (query, error) {
|
||||
b.flag |= filterFlag
|
||||
|
||||
qyInput, err := b.processNode(root.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyCond, err := b.processNode(root.Condition)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput := &filterQuery{Input: qyInput, Predicate: qyCond}
|
||||
return qyOutput, nil
|
||||
}
|
||||
|
||||
// processFunctionNode processes query for the XPath function node.
|
||||
func (b *builder) processFunctionNode(root *functionNode) (query, error) {
|
||||
var qyOutput query
|
||||
switch root.FuncName {
|
||||
case "starts-with":
|
||||
arg1, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arg2, err := b.processNode(root.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: startwithFunc(arg1, arg2)}
|
||||
case "contains":
|
||||
arg1, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arg2, err := b.processNode(root.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: containsFunc(arg1, arg2)}
|
||||
case "substring":
|
||||
//substring( string , start [, length] )
|
||||
if len(root.Args) < 2 {
|
||||
return nil, errors.New("xpath: substring function must have at least two parameter")
|
||||
}
|
||||
var (
|
||||
arg1, arg2, arg3 query
|
||||
err error
|
||||
)
|
||||
if arg1, err = b.processNode(root.Args[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if arg2, err = b.processNode(root.Args[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(root.Args) == 3 {
|
||||
if arg3, err = b.processNode(root.Args[2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: substringFunc(arg1, arg2, arg3)}
|
||||
case "string-length":
|
||||
// string-length( [string] )
|
||||
if len(root.Args) < 1 {
|
||||
return nil, errors.New("xpath: string-length function must have at least one parameter")
|
||||
}
|
||||
arg1, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: stringLengthFunc(arg1)}
|
||||
case "normalize-space":
|
||||
if len(root.Args) == 0 {
|
||||
return nil, errors.New("xpath: normalize-space function must have at least one parameter")
|
||||
}
|
||||
argQuery, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput = &functionQuery{Input: argQuery, Func: normalizespaceFunc}
|
||||
case "not":
|
||||
if len(root.Args) == 0 {
|
||||
return nil, errors.New("xpath: not function must have at least one parameter")
|
||||
}
|
||||
argQuery, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput = &functionQuery{Input: argQuery, Func: notFunc}
|
||||
case "name":
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: nameFunc}
|
||||
case "last":
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: lastFunc}
|
||||
case "position":
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: positionFunc}
|
||||
case "count":
|
||||
//if b.firstInput == nil {
|
||||
// return nil, errors.New("xpath: expression must evaluate to node-set")
|
||||
//}
|
||||
if len(root.Args) == 0 {
|
||||
return nil, fmt.Errorf("xpath: count(node-sets) function must with have parameters node-sets")
|
||||
}
|
||||
argQuery, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput = &functionQuery{Input: argQuery, Func: countFunc}
|
||||
case "sum":
|
||||
if len(root.Args) == 0 {
|
||||
return nil, fmt.Errorf("xpath: sum(node-sets) function must with have parameters node-sets")
|
||||
}
|
||||
argQuery, err := b.processNode(root.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qyOutput = &functionQuery{Input: argQuery, Func: sumFunc}
|
||||
case "concat":
|
||||
if len(root.Args) < 2 {
|
||||
return nil, fmt.Errorf("xpath: concat() must have at least two arguments")
|
||||
}
|
||||
var args []query
|
||||
for _, v := range root.Args {
|
||||
q, err := b.processNode(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, q)
|
||||
}
|
||||
qyOutput = &functionQuery{Input: b.firstInput, Func: concatFunc(args...)}
|
||||
default:
|
||||
return nil, fmt.Errorf("not yet support this function %s()", root.FuncName)
|
||||
}
|
||||
return qyOutput, nil
|
||||
}
|
||||
|
||||
func (b *builder) processOperatorNode(root *operatorNode) (query, error) {
|
||||
left, err := b.processNode(root.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := b.processNode(root.Right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var qyOutput query
|
||||
switch root.Op {
|
||||
case "+", "-", "div", "mod": // Numeric operator
|
||||
var exprFunc func(interface{}, interface{}) interface{}
|
||||
switch root.Op {
|
||||
case "+":
|
||||
exprFunc = plusFunc
|
||||
case "-":
|
||||
exprFunc = minusFunc
|
||||
case "div":
|
||||
exprFunc = divFunc
|
||||
case "mod":
|
||||
exprFunc = modFunc
|
||||
}
|
||||
qyOutput = &numericQuery{Left: left, Right: right, Do: exprFunc}
|
||||
case "=", ">", ">=", "<", "<=", "!=":
|
||||
var exprFunc func(iterator, interface{}, interface{}) interface{}
|
||||
switch root.Op {
|
||||
case "=":
|
||||
exprFunc = eqFunc
|
||||
case ">":
|
||||
exprFunc = gtFunc
|
||||
case ">=":
|
||||
exprFunc = geFunc
|
||||
case "<":
|
||||
exprFunc = ltFunc
|
||||
case "<=":
|
||||
exprFunc = leFunc
|
||||
case "!=":
|
||||
exprFunc = neFunc
|
||||
}
|
||||
qyOutput = &logicalQuery{Left: left, Right: right, Do: exprFunc}
|
||||
case "or", "and", "|":
|
||||
isOr := false
|
||||
if root.Op == "or" || root.Op == "|" {
|
||||
isOr = true
|
||||
}
|
||||
qyOutput = &booleanQuery{Left: left, Right: right, IsOr: isOr}
|
||||
}
|
||||
return qyOutput, nil
|
||||
}
|
||||
|
||||
func (b *builder) processNode(root node) (q query, err error) {
|
||||
if b.depth = b.depth + 1; b.depth > 1024 {
|
||||
err = errors.New("the xpath expressions is too complex")
|
||||
return
|
||||
}
|
||||
|
||||
switch root.Type() {
|
||||
case nodeConstantOperand:
|
||||
n := root.(*operandNode)
|
||||
q = &constantQuery{Val: n.Val}
|
||||
case nodeRoot:
|
||||
q = &contextQuery{Root: true}
|
||||
case nodeAxis:
|
||||
q, err = b.processAxisNode(root.(*axisNode))
|
||||
b.firstInput = q
|
||||
case nodeFilter:
|
||||
q, err = b.processFilterNode(root.(*filterNode))
|
||||
case nodeFunction:
|
||||
q, err = b.processFunctionNode(root.(*functionNode))
|
||||
case nodeOperator:
|
||||
q, err = b.processOperatorNode(root.(*operatorNode))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// build builds a specified XPath expressions expr.
|
||||
func build(expr string) (q query, err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
switch x := e.(type) {
|
||||
case string:
|
||||
err = errors.New(x)
|
||||
case error:
|
||||
err = x
|
||||
default:
|
||||
err = errors.New("unknown panic")
|
||||
}
|
||||
}
|
||||
}()
|
||||
root := parse(expr)
|
||||
b := &builder{}
|
||||
return b.processNode(root)
|
||||
}
|
|
@ -0,0 +1,254 @@
|
|||
package xpath
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// The XPath function list.
|
||||
|
||||
func predicate(q query) func(NodeNavigator) bool {
|
||||
type Predicater interface {
|
||||
Test(NodeNavigator) bool
|
||||
}
|
||||
if p, ok := q.(Predicater); ok {
|
||||
return p.Test
|
||||
}
|
||||
return func(NodeNavigator) bool { return true }
|
||||
}
|
||||
|
||||
// positionFunc is a XPath Node Set functions position().
|
||||
func positionFunc(q query, t iterator) interface{} {
|
||||
var (
|
||||
count = 1
|
||||
node = t.Current()
|
||||
)
|
||||
test := predicate(q)
|
||||
for node.MoveToPrevious() {
|
||||
if test(node) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return float64(count)
|
||||
}
|
||||
|
||||
// lastFunc is a XPath Node Set functions last().
|
||||
func lastFunc(q query, t iterator) interface{} {
|
||||
var (
|
||||
count = 0
|
||||
node = t.Current()
|
||||
)
|
||||
node.MoveToFirst()
|
||||
test := predicate(q)
|
||||
for {
|
||||
if test(node) {
|
||||
count++
|
||||
}
|
||||
if !node.MoveToNext() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return float64(count)
|
||||
}
|
||||
|
||||
// countFunc is a XPath Node Set functions count(node-set).
|
||||
func countFunc(q query, t iterator) interface{} {
|
||||
var count = 0
|
||||
test := predicate(q)
|
||||
switch typ := q.Evaluate(t).(type) {
|
||||
case query:
|
||||
for node := typ.Select(t); node != nil; node = typ.Select(t) {
|
||||
if test(node) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
return float64(count)
|
||||
}
|
||||
|
||||
// sumFunc is a XPath Node Set functions sum(node-set).
|
||||
func sumFunc(q query, t iterator) interface{} {
|
||||
var sum float64
|
||||
switch typ := q.Evaluate(t).(type) {
|
||||
case query:
|
||||
for node := typ.Select(t); node != nil; node = typ.Select(t) {
|
||||
if v, err := strconv.ParseFloat(node.Value(), 64); err == nil {
|
||||
sum += v
|
||||
}
|
||||
}
|
||||
case float64:
|
||||
sum = typ
|
||||
case string:
|
||||
if v, err := strconv.ParseFloat(typ, 64); err != nil {
|
||||
sum = v
|
||||
}
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
// nameFunc is a XPath functions name([node-set]).
|
||||
func nameFunc(q query, t iterator) interface{} {
|
||||
return t.Current().LocalName()
|
||||
}
|
||||
|
||||
// startwithFunc is a XPath functions starts-with(string, string).
|
||||
func startwithFunc(arg1, arg2 query) func(query, iterator) interface{} {
|
||||
return func(q query, t iterator) interface{} {
|
||||
var (
|
||||
m, n string
|
||||
ok bool
|
||||
)
|
||||
switch typ := arg1.Evaluate(t).(type) {
|
||||
case string:
|
||||
m = typ
|
||||
case query:
|
||||
node := typ.Select(t)
|
||||
if node == nil {
|
||||
return false
|
||||
}
|
||||
m = node.Value()
|
||||
default:
|
||||
panic(errors.New("starts-with() function argument type must be string"))
|
||||
}
|
||||
n, ok = arg2.Evaluate(t).(string)
|
||||
if !ok {
|
||||
panic(errors.New("starts-with() function argument type must be string"))
|
||||
}
|
||||
return strings.HasPrefix(m, n)
|
||||
}
|
||||
}
|
||||
|
||||
// containsFunc is a XPath functions contains(string or @attr, string).
|
||||
func containsFunc(arg1, arg2 query) func(query, iterator) interface{} {
|
||||
return func(q query, t iterator) interface{} {
|
||||
var (
|
||||
m, n string
|
||||
ok bool
|
||||
)
|
||||
|
||||
switch typ := arg1.Evaluate(t).(type) {
|
||||
case string:
|
||||
m = typ
|
||||
case query:
|
||||
node := typ.Select(t)
|
||||
if node == nil {
|
||||
return false
|
||||
}
|
||||
m = node.Value()
|
||||
default:
|
||||
panic(errors.New("contains() function argument type must be string"))
|
||||
}
|
||||
|
||||
n, ok = arg2.Evaluate(t).(string)
|
||||
if !ok {
|
||||
panic(errors.New("contains() function argument type must be string"))
|
||||
}
|
||||
|
||||
return strings.Contains(m, n)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizespaceFunc is XPath functions normalize-space(string?)
|
||||
func normalizespaceFunc(q query, t iterator) interface{} {
|
||||
var m string
|
||||
switch typ := q.Evaluate(t).(type) {
|
||||
case string:
|
||||
m = typ
|
||||
case query:
|
||||
node := typ.Select(t)
|
||||
if node == nil {
|
||||
return false
|
||||
}
|
||||
m = node.Value()
|
||||
}
|
||||
return strings.TrimSpace(m)
|
||||
}
|
||||
|
||||
// substringFunc is XPath functions substring function returns a part of a given string.
|
||||
func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} {
|
||||
return func(q query, t iterator) interface{} {
|
||||
var m string
|
||||
switch typ := arg1.Evaluate(t).(type) {
|
||||
case string:
|
||||
m = typ
|
||||
case query:
|
||||
node := typ.Select(t)
|
||||
if node == nil {
|
||||
return false
|
||||
}
|
||||
m = node.Value()
|
||||
}
|
||||
|
||||
var start, length float64
|
||||
var ok bool
|
||||
|
||||
if start, ok = arg2.Evaluate(t).(float64); !ok {
|
||||
panic(errors.New("substring() function first argument type must be int"))
|
||||
}
|
||||
if arg3 != nil {
|
||||
if length, ok = arg3.Evaluate(t).(float64); !ok {
|
||||
panic(errors.New("substring() function second argument type must be int"))
|
||||
}
|
||||
}
|
||||
if (len(m) - int(start)) < int(length) {
|
||||
panic(errors.New("substring() function start and length argument out of range"))
|
||||
}
|
||||
if length > 0 {
|
||||
return m[int(start):int(length+start)]
|
||||
}
|
||||
return m[int(start):]
|
||||
}
|
||||
}
|
||||
|
||||
// stringLengthFunc is XPATH string-length( [string] ) function that returns a number
|
||||
// equal to the number of characters in a given string.
|
||||
func stringLengthFunc(arg1 query) func(query, iterator) interface{} {
|
||||
return func(q query, t iterator) interface{} {
|
||||
switch v := arg1.Evaluate(t).(type) {
|
||||
case string:
|
||||
return float64(len(v))
|
||||
case query:
|
||||
node := v.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
return float64(len(node.Value()))
|
||||
}
|
||||
return float64(0)
|
||||
}
|
||||
}
|
||||
|
||||
// notFunc is XPATH functions not(expression) function operation.
|
||||
func notFunc(q query, t iterator) interface{} {
|
||||
switch v := q.Evaluate(t).(type) {
|
||||
case bool:
|
||||
return !v
|
||||
case query:
|
||||
node := v.Select(t)
|
||||
return node == nil
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// concatFunc is the concat function concatenates two or more
|
||||
// strings and returns the resulting string.
|
||||
// concat( string1 , string2 [, stringn]* )
|
||||
func concatFunc(args ...query) func(query, iterator) interface{} {
|
||||
return func(q query, t iterator) interface{} {
|
||||
var a []string
|
||||
for _, v := range args {
|
||||
switch v := v.Evaluate(t).(type) {
|
||||
case string:
|
||||
a = append(a, v)
|
||||
case query:
|
||||
node := v.Select(t)
|
||||
if node != nil {
|
||||
a = append(a, node.Value())
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(a, "")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,295 @@
|
|||
package xpath
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// The XPath number operator function list.
|
||||
|
||||
// valueType is a return value type.
|
||||
type valueType int
|
||||
|
||||
const (
|
||||
booleanType valueType = iota
|
||||
numberType
|
||||
stringType
|
||||
nodeSetType
|
||||
)
|
||||
|
||||
func getValueType(i interface{}) valueType {
|
||||
v := reflect.ValueOf(i)
|
||||
switch v.Kind() {
|
||||
case reflect.Float64:
|
||||
return numberType
|
||||
case reflect.String:
|
||||
return stringType
|
||||
case reflect.Bool:
|
||||
return booleanType
|
||||
default:
|
||||
if _, ok := i.(query); ok {
|
||||
return nodeSetType
|
||||
}
|
||||
}
|
||||
panic(fmt.Errorf("xpath unknown value type: %v", v.Kind()))
|
||||
}
|
||||
|
||||
type logical func(iterator, string, interface{}, interface{}) bool
|
||||
|
||||
var logicalFuncs = [][]logical{
|
||||
{cmpBooleanBoolean, nil, nil, nil},
|
||||
{nil, cmpNumericNumeric, cmpNumericString, cmpNumericNodeSet},
|
||||
{nil, cmpStringNumeric, cmpStringString, cmpStringNodeSet},
|
||||
{nil, cmpNodeSetNumeric, cmpNodeSetString, cmpNodeSetNodeSet},
|
||||
}
|
||||
|
||||
// number vs number
|
||||
func cmpNumberNumberF(op string, a, b float64) bool {
|
||||
switch op {
|
||||
case "=":
|
||||
return a == b
|
||||
case ">":
|
||||
return a > b
|
||||
case "<":
|
||||
return a < b
|
||||
case ">=":
|
||||
return a >= b
|
||||
case "<=":
|
||||
return a <= b
|
||||
case "!=":
|
||||
return a != b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// string vs string
|
||||
func cmpStringStringF(op string, a, b string) bool {
|
||||
switch op {
|
||||
case "=":
|
||||
return a == b
|
||||
case ">":
|
||||
return a > b
|
||||
case "<":
|
||||
return a < b
|
||||
case ">=":
|
||||
return a >= b
|
||||
case "<=":
|
||||
return a <= b
|
||||
case "!=":
|
||||
return a != b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpBooleanBooleanF(op string, a, b bool) bool {
|
||||
switch op {
|
||||
case "or":
|
||||
return a || b
|
||||
case "and":
|
||||
return a && b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpNumericNumeric(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(float64)
|
||||
b := n.(float64)
|
||||
return cmpNumberNumberF(op, a, b)
|
||||
}
|
||||
|
||||
func cmpNumericString(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(float64)
|
||||
b := n.(string)
|
||||
num, err := strconv.ParseFloat(b, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cmpNumberNumberF(op, a, num)
|
||||
}
|
||||
|
||||
func cmpNumericNodeSet(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(float64)
|
||||
b := n.(query)
|
||||
|
||||
for {
|
||||
node := b.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
num, err := strconv.ParseFloat(node.Value(), 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if cmpNumberNumberF(op, a, num) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpNodeSetNumeric(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(query)
|
||||
b := n.(float64)
|
||||
for {
|
||||
node := a.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
num, err := strconv.ParseFloat(node.Value(), 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if cmpNumberNumberF(op, num, b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpNodeSetString(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(query)
|
||||
b := n.(string)
|
||||
for {
|
||||
node := a.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
if cmpStringStringF(op, b, node.Value()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpNodeSetNodeSet(t iterator, op string, m, n interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpStringNumeric(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(string)
|
||||
b := n.(float64)
|
||||
num, err := strconv.ParseFloat(a, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cmpNumberNumberF(op, b, num)
|
||||
}
|
||||
|
||||
func cmpStringString(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(string)
|
||||
b := n.(string)
|
||||
return cmpStringStringF(op, a, b)
|
||||
}
|
||||
|
||||
func cmpStringNodeSet(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(string)
|
||||
b := n.(query)
|
||||
for {
|
||||
node := b.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
if cmpStringStringF(op, a, node.Value()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cmpBooleanBoolean(t iterator, op string, m, n interface{}) bool {
|
||||
a := m.(bool)
|
||||
b := n.(bool)
|
||||
return cmpBooleanBooleanF(op, a, b)
|
||||
}
|
||||
|
||||
// eqFunc is an `=` operator.
|
||||
func eqFunc(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, "=", m, n)
|
||||
}
|
||||
|
||||
// gtFunc is an `>` operator.
|
||||
func gtFunc(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, ">", m, n)
|
||||
}
|
||||
|
||||
// geFunc is an `>=` operator.
|
||||
func geFunc(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, ">=", m, n)
|
||||
}
|
||||
|
||||
// ltFunc is an `<` operator.
|
||||
func ltFunc(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, "<", m, n)
|
||||
}
|
||||
|
||||
// leFunc is an `<=` operator.
|
||||
func leFunc(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, "<=", m, n)
|
||||
}
|
||||
|
||||
// neFunc is an `!=` operator.
|
||||
func neFunc(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, "!=", m, n)
|
||||
}
|
||||
|
||||
// orFunc is an `or` operator.
|
||||
var orFunc = func(t iterator, m, n interface{}) interface{} {
|
||||
t1 := getValueType(m)
|
||||
t2 := getValueType(n)
|
||||
return logicalFuncs[t1][t2](t, "or", m, n)
|
||||
}
|
||||
|
||||
func numericExpr(m, n interface{}, cb func(float64, float64) float64) float64 {
|
||||
typ := reflect.TypeOf(float64(0))
|
||||
a := reflect.ValueOf(m).Convert(typ)
|
||||
b := reflect.ValueOf(n).Convert(typ)
|
||||
return cb(a.Float(), b.Float())
|
||||
}
|
||||
|
||||
// plusFunc is an `+` operator.
|
||||
var plusFunc = func(m, n interface{}) interface{} {
|
||||
return numericExpr(m, n, func(a, b float64) float64 {
|
||||
return a + b
|
||||
})
|
||||
}
|
||||
|
||||
// minusFunc is an `-` operator.
|
||||
var minusFunc = func(m, n interface{}) interface{} {
|
||||
return numericExpr(m, n, func(a, b float64) float64 {
|
||||
return a - b
|
||||
})
|
||||
}
|
||||
|
||||
// mulFunc is an `*` operator.
|
||||
var mulFunc = func(m, n interface{}) interface{} {
|
||||
return numericExpr(m, n, func(a, b float64) float64 {
|
||||
return a * b
|
||||
})
|
||||
}
|
||||
|
||||
// divFunc is an `DIV` operator.
|
||||
var divFunc = func(m, n interface{}) interface{} {
|
||||
return numericExpr(m, n, func(a, b float64) float64 {
|
||||
return a / b
|
||||
})
|
||||
}
|
||||
|
||||
// modFunc is an 'MOD' operator.
|
||||
var modFunc = func(m, n interface{}) interface{} {
|
||||
return numericExpr(m, n, func(a, b float64) float64 {
|
||||
return float64(int(a) % int(b))
|
||||
})
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,728 @@
|
|||
package xpath
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type iterator interface {
|
||||
Current() NodeNavigator
|
||||
}
|
||||
|
||||
// An XPath query interface.
|
||||
type query interface {
|
||||
// Select traversing iterator returns a query matched node NodeNavigator.
|
||||
Select(iterator) NodeNavigator
|
||||
|
||||
// Evaluate evaluates query and returns values of the current query.
|
||||
Evaluate(iterator) interface{}
|
||||
|
||||
Clone() query
|
||||
}
|
||||
|
||||
// contextQuery is returns current node on the iterator object query.
|
||||
type contextQuery struct {
|
||||
count int
|
||||
Root bool // Moving to root-level node in the current context iterator.
|
||||
}
|
||||
|
||||
func (c *contextQuery) Select(t iterator) (n NodeNavigator) {
|
||||
if c.count == 0 {
|
||||
c.count++
|
||||
n = t.Current().Copy()
|
||||
if c.Root {
|
||||
n.MoveToRoot()
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *contextQuery) Evaluate(iterator) interface{} {
|
||||
c.count = 0
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *contextQuery) Clone() query {
|
||||
return &contextQuery{count: 0, Root: c.Root}
|
||||
}
|
||||
|
||||
// ancestorQuery is an XPath ancestor node query.(ancestor::*|ancestor-self::*)
|
||||
type ancestorQuery struct {
|
||||
iterator func() NodeNavigator
|
||||
|
||||
Self bool
|
||||
Input query
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (a *ancestorQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
if a.iterator == nil {
|
||||
node := a.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
first := true
|
||||
a.iterator = func() NodeNavigator {
|
||||
if first && a.Self {
|
||||
first = false
|
||||
if a.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
for node.MoveToParent() {
|
||||
if !a.Predicate(node) {
|
||||
break
|
||||
}
|
||||
return node
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if node := a.iterator(); node != nil {
|
||||
return node
|
||||
}
|
||||
a.iterator = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ancestorQuery) Evaluate(t iterator) interface{} {
|
||||
a.Input.Evaluate(t)
|
||||
a.iterator = nil
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *ancestorQuery) Test(n NodeNavigator) bool {
|
||||
return a.Predicate(n)
|
||||
}
|
||||
|
||||
func (a *ancestorQuery) Clone() query {
|
||||
return &ancestorQuery{Self: a.Self, Input: a.Input.Clone(), Predicate: a.Predicate}
|
||||
}
|
||||
|
||||
// attributeQuery is an XPath attribute node query.(@*)
|
||||
type attributeQuery struct {
|
||||
iterator func() NodeNavigator
|
||||
|
||||
Input query
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (a *attributeQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
if a.iterator == nil {
|
||||
node := a.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = node.Copy()
|
||||
a.iterator = func() NodeNavigator {
|
||||
for {
|
||||
onAttr := node.MoveToNextAttribute()
|
||||
if !onAttr {
|
||||
return nil
|
||||
}
|
||||
if a.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if node := a.iterator(); node != nil {
|
||||
return node
|
||||
}
|
||||
a.iterator = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *attributeQuery) Evaluate(t iterator) interface{} {
|
||||
a.Input.Evaluate(t)
|
||||
a.iterator = nil
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *attributeQuery) Test(n NodeNavigator) bool {
|
||||
return a.Predicate(n)
|
||||
}
|
||||
|
||||
func (a *attributeQuery) Clone() query {
|
||||
return &attributeQuery{Input: a.Input.Clone(), Predicate: a.Predicate}
|
||||
}
|
||||
|
||||
// childQuery is an XPath child node query.(child::*)
|
||||
type childQuery struct {
|
||||
posit int
|
||||
iterator func() NodeNavigator
|
||||
|
||||
Input query
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (c *childQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
if c.iterator == nil {
|
||||
c.posit = 0
|
||||
node := c.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = node.Copy()
|
||||
first := true
|
||||
c.iterator = func() NodeNavigator {
|
||||
for {
|
||||
if (first && !node.MoveToChild()) || (!first && !node.MoveToNext()) {
|
||||
return nil
|
||||
}
|
||||
first = false
|
||||
if c.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if node := c.iterator(); node != nil {
|
||||
c.posit++
|
||||
return node
|
||||
}
|
||||
c.iterator = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *childQuery) Evaluate(t iterator) interface{} {
|
||||
c.Input.Evaluate(t)
|
||||
c.iterator = nil
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *childQuery) Test(n NodeNavigator) bool {
|
||||
return c.Predicate(n)
|
||||
}
|
||||
|
||||
func (c *childQuery) Clone() query {
|
||||
return &childQuery{Input: c.Input.Clone(), Predicate: c.Predicate}
|
||||
}
|
||||
|
||||
// position returns a position of current NodeNavigator.
|
||||
func (c *childQuery) position() int {
|
||||
return c.posit
|
||||
}
|
||||
|
||||
// descendantQuery is an XPath descendant node query.(descendant::* | descendant-or-self::*)
|
||||
type descendantQuery struct {
|
||||
iterator func() NodeNavigator
|
||||
posit int
|
||||
|
||||
Self bool
|
||||
Input query
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (d *descendantQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
if d.iterator == nil {
|
||||
d.posit = 0
|
||||
node := d.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = node.Copy()
|
||||
level := 0
|
||||
first := true
|
||||
d.iterator = func() NodeNavigator {
|
||||
if first && d.Self {
|
||||
first = false
|
||||
if d.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
if node.MoveToChild() {
|
||||
level++
|
||||
} else {
|
||||
for {
|
||||
if level == 0 {
|
||||
return nil
|
||||
}
|
||||
if node.MoveToNext() {
|
||||
break
|
||||
}
|
||||
node.MoveToParent()
|
||||
level--
|
||||
}
|
||||
}
|
||||
if d.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if node := d.iterator(); node != nil {
|
||||
d.posit++
|
||||
return node
|
||||
}
|
||||
d.iterator = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *descendantQuery) Evaluate(t iterator) interface{} {
|
||||
d.Input.Evaluate(t)
|
||||
d.iterator = nil
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *descendantQuery) Test(n NodeNavigator) bool {
|
||||
return d.Predicate(n)
|
||||
}
|
||||
|
||||
// position returns a position of current NodeNavigator.
|
||||
func (d *descendantQuery) position() int {
|
||||
return d.posit
|
||||
}
|
||||
|
||||
func (d *descendantQuery) Clone() query {
|
||||
return &descendantQuery{Self: d.Self, Input: d.Input.Clone(), Predicate: d.Predicate}
|
||||
}
|
||||
|
||||
// followingQuery is an XPath following node query.(following::*|following-sibling::*)
|
||||
type followingQuery struct {
|
||||
iterator func() NodeNavigator
|
||||
|
||||
Input query
|
||||
Sibling bool // The matching sibling node of current node.
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (f *followingQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
if f.iterator == nil {
|
||||
node := f.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = node.Copy()
|
||||
if f.Sibling {
|
||||
f.iterator = func() NodeNavigator {
|
||||
for {
|
||||
if !node.MoveToNext() {
|
||||
return nil
|
||||
}
|
||||
if f.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var q query // descendant query
|
||||
f.iterator = func() NodeNavigator {
|
||||
for {
|
||||
if q == nil {
|
||||
for !node.MoveToNext() {
|
||||
if !node.MoveToParent() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
q = &descendantQuery{
|
||||
Self: true,
|
||||
Input: &contextQuery{},
|
||||
Predicate: f.Predicate,
|
||||
}
|
||||
t.Current().MoveTo(node)
|
||||
}
|
||||
if node := q.Select(t); node != nil {
|
||||
return node
|
||||
}
|
||||
q = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if node := f.iterator(); node != nil {
|
||||
return node
|
||||
}
|
||||
f.iterator = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *followingQuery) Evaluate(t iterator) interface{} {
|
||||
f.Input.Evaluate(t)
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *followingQuery) Test(n NodeNavigator) bool {
|
||||
return f.Predicate(n)
|
||||
}
|
||||
|
||||
func (f *followingQuery) Clone() query {
|
||||
return &followingQuery{Input: f.Input.Clone(), Sibling: f.Sibling, Predicate: f.Predicate}
|
||||
}
|
||||
|
||||
// precedingQuery is an XPath preceding node query.(preceding::*)
|
||||
type precedingQuery struct {
|
||||
iterator func() NodeNavigator
|
||||
Input query
|
||||
Sibling bool // The matching sibling node of current node.
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (p *precedingQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
if p.iterator == nil {
|
||||
node := p.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = node.Copy()
|
||||
if p.Sibling {
|
||||
p.iterator = func() NodeNavigator {
|
||||
for {
|
||||
for !node.MoveToPrevious() {
|
||||
return nil
|
||||
}
|
||||
if p.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var q query
|
||||
p.iterator = func() NodeNavigator {
|
||||
for {
|
||||
if q == nil {
|
||||
for !node.MoveToPrevious() {
|
||||
if !node.MoveToParent() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
q = &descendantQuery{
|
||||
Self: true,
|
||||
Input: &contextQuery{},
|
||||
Predicate: p.Predicate,
|
||||
}
|
||||
t.Current().MoveTo(node)
|
||||
}
|
||||
if node := q.Select(t); node != nil {
|
||||
return node
|
||||
}
|
||||
q = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if node := p.iterator(); node != nil {
|
||||
return node
|
||||
}
|
||||
p.iterator = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *precedingQuery) Evaluate(t iterator) interface{} {
|
||||
p.Input.Evaluate(t)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *precedingQuery) Test(n NodeNavigator) bool {
|
||||
return p.Predicate(n)
|
||||
}
|
||||
|
||||
func (p *precedingQuery) Clone() query {
|
||||
return &precedingQuery{Input: p.Input.Clone(), Sibling: p.Sibling, Predicate: p.Predicate}
|
||||
}
|
||||
|
||||
// parentQuery is an XPath parent node query.(parent::*)
|
||||
type parentQuery struct {
|
||||
Input query
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (p *parentQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
node := p.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = node.Copy()
|
||||
if node.MoveToParent() && p.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parentQuery) Evaluate(t iterator) interface{} {
|
||||
p.Input.Evaluate(t)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *parentQuery) Clone() query {
|
||||
return &parentQuery{Input: p.Input.Clone(), Predicate: p.Predicate}
|
||||
}
|
||||
|
||||
func (p *parentQuery) Test(n NodeNavigator) bool {
|
||||
return p.Predicate(n)
|
||||
}
|
||||
|
||||
// selfQuery is an Self node query.(self::*)
|
||||
type selfQuery struct {
|
||||
Input query
|
||||
Predicate func(NodeNavigator) bool
|
||||
}
|
||||
|
||||
func (s *selfQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
node := s.Input.Select(t)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.Predicate(node) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *selfQuery) Evaluate(t iterator) interface{} {
|
||||
s.Input.Evaluate(t)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *selfQuery) Test(n NodeNavigator) bool {
|
||||
return s.Predicate(n)
|
||||
}
|
||||
|
||||
func (s *selfQuery) Clone() query {
|
||||
return &selfQuery{Input: s.Input.Clone(), Predicate: s.Predicate}
|
||||
}
|
||||
|
||||
// filterQuery is an XPath query for predicate filter.
|
||||
type filterQuery struct {
|
||||
Input query
|
||||
Predicate query
|
||||
}
|
||||
|
||||
func (f *filterQuery) do(t iterator) bool {
|
||||
val := reflect.ValueOf(f.Predicate.Evaluate(t))
|
||||
switch val.Kind() {
|
||||
case reflect.Bool:
|
||||
return val.Bool()
|
||||
case reflect.String:
|
||||
return len(val.String()) > 0
|
||||
case reflect.Float64:
|
||||
pt := float64(getNodePosition(f.Input))
|
||||
return int(val.Float()) == int(pt)
|
||||
default:
|
||||
if q, ok := f.Predicate.(query); ok {
|
||||
return q.Select(t) != nil
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *filterQuery) Select(t iterator) NodeNavigator {
|
||||
for {
|
||||
node := f.Input.Select(t)
|
||||
if node == nil {
|
||||
return node
|
||||
}
|
||||
node = node.Copy()
|
||||
//fmt.Println(node.LocalName())
|
||||
|
||||
t.Current().MoveTo(node)
|
||||
if f.do(t) {
|
||||
return node
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *filterQuery) Evaluate(t iterator) interface{} {
|
||||
f.Input.Evaluate(t)
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *filterQuery) Clone() query {
|
||||
return &filterQuery{Input: f.Input.Clone(), Predicate: f.Predicate.Clone()}
|
||||
}
|
||||
|
||||
// functionQuery is an XPath function that call a function to returns
|
||||
// value of current NodeNavigator node.
|
||||
type functionQuery struct {
|
||||
Input query // Node Set
|
||||
Func func(query, iterator) interface{} // The xpath function.
|
||||
}
|
||||
|
||||
func (f *functionQuery) Select(t iterator) NodeNavigator {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Evaluate call a specified function that will returns the
|
||||
// following value type: number,string,boolean.
|
||||
func (f *functionQuery) Evaluate(t iterator) interface{} {
|
||||
return f.Func(f.Input, t)
|
||||
}
|
||||
|
||||
func (f *functionQuery) Clone() query {
|
||||
return &functionQuery{Input: f.Input.Clone(), Func: f.Func}
|
||||
}
|
||||
|
||||
// constantQuery is an XPath constant operand.
|
||||
type constantQuery struct {
|
||||
Val interface{}
|
||||
}
|
||||
|
||||
func (c *constantQuery) Select(t iterator) NodeNavigator {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *constantQuery) Evaluate(t iterator) interface{} {
|
||||
return c.Val
|
||||
}
|
||||
|
||||
func (c *constantQuery) Clone() query {
|
||||
return c
|
||||
}
|
||||
|
||||
// logicalQuery is an XPath logical expression.
|
||||
type logicalQuery struct {
|
||||
Left, Right query
|
||||
|
||||
Do func(iterator, interface{}, interface{}) interface{}
|
||||
}
|
||||
|
||||
func (l *logicalQuery) Select(t iterator) NodeNavigator {
|
||||
// When a XPath expr is logical expression.
|
||||
node := t.Current().Copy()
|
||||
val := l.Evaluate(t)
|
||||
switch val.(type) {
|
||||
case bool:
|
||||
if val.(bool) == true {
|
||||
return node
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *logicalQuery) Evaluate(t iterator) interface{} {
|
||||
m := l.Left.Evaluate(t)
|
||||
n := l.Right.Evaluate(t)
|
||||
return l.Do(t, m, n)
|
||||
}
|
||||
|
||||
func (l *logicalQuery) Clone() query {
|
||||
return &logicalQuery{Left: l.Left.Clone(), Right: l.Right.Clone(), Do: l.Do}
|
||||
}
|
||||
|
||||
// numericQuery is an XPath numeric operator expression.
|
||||
type numericQuery struct {
|
||||
Left, Right query
|
||||
|
||||
Do func(interface{}, interface{}) interface{}
|
||||
}
|
||||
|
||||
func (n *numericQuery) Select(t iterator) NodeNavigator {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *numericQuery) Evaluate(t iterator) interface{} {
|
||||
m := n.Left.Evaluate(t)
|
||||
k := n.Right.Evaluate(t)
|
||||
return n.Do(m, k)
|
||||
}
|
||||
|
||||
func (n *numericQuery) Clone() query {
|
||||
return &numericQuery{Left: n.Left.Clone(), Right: n.Right.Clone(), Do: n.Do}
|
||||
}
|
||||
|
||||
type booleanQuery struct {
|
||||
IsOr bool
|
||||
Left, Right query
|
||||
iterator func() NodeNavigator
|
||||
}
|
||||
|
||||
func (b *booleanQuery) Select(t iterator) NodeNavigator {
|
||||
if b.iterator == nil {
|
||||
var list []NodeNavigator
|
||||
i := 0
|
||||
root := t.Current().Copy()
|
||||
if b.IsOr {
|
||||
for {
|
||||
node := b.Left.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
node = node.Copy()
|
||||
list = append(list, node)
|
||||
}
|
||||
t.Current().MoveTo(root)
|
||||
for {
|
||||
node := b.Right.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
node = node.Copy()
|
||||
list = append(list, node)
|
||||
}
|
||||
} else {
|
||||
var m []NodeNavigator
|
||||
var n []NodeNavigator
|
||||
for {
|
||||
node := b.Left.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
node = node.Copy()
|
||||
list = append(m, node)
|
||||
}
|
||||
t.Current().MoveTo(root)
|
||||
for {
|
||||
node := b.Right.Select(t)
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
node = node.Copy()
|
||||
list = append(n, node)
|
||||
}
|
||||
for _, k := range m {
|
||||
for _, j := range n {
|
||||
if k == j {
|
||||
list = append(list, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.iterator = func() NodeNavigator {
|
||||
if i >= len(list) {
|
||||
return nil
|
||||
}
|
||||
node := list[i]
|
||||
i++
|
||||
return node
|
||||
}
|
||||
}
|
||||
return b.iterator()
|
||||
}
|
||||
|
||||
func (b *booleanQuery) Evaluate(t iterator) interface{} {
|
||||
m := b.Left.Evaluate(t)
|
||||
if m.(bool) == b.IsOr {
|
||||
return m
|
||||
}
|
||||
return b.Right.Evaluate(t)
|
||||
}
|
||||
|
||||
func (b *booleanQuery) Clone() query {
|
||||
return &booleanQuery{IsOr: b.IsOr, Left: b.Left.Clone(), Right: b.Right.Clone()}
|
||||
}
|
||||
|
||||
func getNodePosition(q query) int {
|
||||
type Position interface {
|
||||
position() int
|
||||
}
|
||||
if count, ok := q.(Position); ok {
|
||||
return count.position()
|
||||
}
|
||||
return 1
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
package xpath
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// NodeType represents a type of XPath node.
|
||||
type NodeType int
|
||||
|
||||
const (
|
||||
// RootNode is a root node of the XML document or node tree.
|
||||
RootNode NodeType = iota
|
||||
|
||||
// ElementNode is an element, such as <element>.
|
||||
ElementNode
|
||||
|
||||
// AttributeNode is an attribute, such as id='123'.
|
||||
AttributeNode
|
||||
|
||||
// TextNode is the text content of a node.
|
||||
TextNode
|
||||
|
||||
// CommentNode is a comment node, such as <!-- my comment -->
|
||||
CommentNode
|
||||
)
|
||||
|
||||
// NodeNavigator provides cursor model for navigating XML data.
|
||||
type NodeNavigator interface {
|
||||
// NodeType returns the XPathNodeType of the current node.
|
||||
NodeType() NodeType
|
||||
|
||||
// LocalName gets the Name of the current node.
|
||||
LocalName() string
|
||||
|
||||
// Prefix returns namespace prefix associated with the current node.
|
||||
Prefix() string
|
||||
|
||||
// Value gets the value of current node.
|
||||
Value() string
|
||||
|
||||
// Copy does a deep copy of the NodeNavigator and all its components.
|
||||
Copy() NodeNavigator
|
||||
|
||||
// MoveToRoot moves the NodeNavigator to the root node of the current node.
|
||||
MoveToRoot()
|
||||
|
||||
// MoveToParent moves the NodeNavigator to the parent node of the current node.
|
||||
MoveToParent() bool
|
||||
|
||||
// MoveToNextAttribute moves the NodeNavigator to the next attribute on current node.
|
||||
MoveToNextAttribute() bool
|
||||
|
||||
// MoveToChild moves the NodeNavigator to the first child node of the current node.
|
||||
MoveToChild() bool
|
||||
|
||||
// MoveToFirst moves the NodeNavigator to the first sibling node of the current node.
|
||||
MoveToFirst() bool
|
||||
|
||||
// MoveToNext moves the NodeNavigator to the next sibling node of the current node.
|
||||
MoveToNext() bool
|
||||
|
||||
// MoveToPrevious moves the NodeNavigator to the previous sibling node of the current node.
|
||||
MoveToPrevious() bool
|
||||
|
||||
// MoveTo moves the NodeNavigator to the same position as the specified NodeNavigator.
|
||||
MoveTo(NodeNavigator) bool
|
||||
}
|
||||
|
||||
// NodeIterator holds all matched Node object.
|
||||
type NodeIterator struct {
|
||||
node NodeNavigator
|
||||
query query
|
||||
}
|
||||
|
||||
// Current returns current node which matched.
|
||||
func (t *NodeIterator) Current() NodeNavigator {
|
||||
return t.node
|
||||
}
|
||||
|
||||
// MoveNext moves Navigator to the next match node.
|
||||
func (t *NodeIterator) MoveNext() bool {
|
||||
n := t.query.Select(t)
|
||||
if n != nil {
|
||||
if !t.node.MoveTo(n) {
|
||||
t.node = n.Copy()
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Select selects a node set using the specified XPath expression.
|
||||
// This method is deprecated, recommend using Expr.Select() method instead.
|
||||
func Select(root NodeNavigator, expr string) *NodeIterator {
|
||||
exp, err := Compile(expr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return exp.Select(root)
|
||||
}
|
||||
|
||||
// Expr is an XPath expression for query.
|
||||
type Expr struct {
|
||||
s string
|
||||
q query
|
||||
}
|
||||
|
||||
type iteratorFunc func() NodeNavigator
|
||||
|
||||
func (f iteratorFunc) Current() NodeNavigator {
|
||||
return f()
|
||||
}
|
||||
|
||||
// Evaluate returns the result of the expression.
|
||||
// The result type of the expression is one of the follow: bool,float64,string,NodeIterator).
|
||||
func (expr *Expr) Evaluate(root NodeNavigator) interface{} {
|
||||
val := expr.q.Evaluate(iteratorFunc(func() NodeNavigator { return root }))
|
||||
switch val.(type) {
|
||||
case query:
|
||||
return &NodeIterator{query: expr.q.Clone(), node: root}
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// Select selects a node set using the specified XPath expression.
|
||||
func (expr *Expr) Select(root NodeNavigator) *NodeIterator {
|
||||
return &NodeIterator{query: expr.q.Clone(), node: root}
|
||||
}
|
||||
|
||||
// String returns XPath expression string.
|
||||
func (expr *Expr) String() string {
|
||||
return expr.s
|
||||
}
|
||||
|
||||
// Compile compiles an XPath expression string.
|
||||
func Compile(expr string) (*Expr, error) {
|
||||
if expr == "" {
|
||||
return nil, errors.New("expr expression is nil")
|
||||
}
|
||||
qy, err := build(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Expr{s: expr, q: qy}, nil
|
||||
}
|
||||
|
||||
// MustCompile compiles an XPath expression string and ignored error.
|
||||
func MustCompile(expr string) *Expr {
|
||||
exp, err := Compile(expr)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return exp
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,252 @@
|
|||
package xmlquery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html/charset"
|
||||
)
|
||||
|
||||
// A NodeType is the type of a Node.
|
||||
type NodeType uint
|
||||
|
||||
const (
|
||||
// DocumentNode is a document object that, as the root of the document tree,
|
||||
// provides access to the entire XML document.
|
||||
DocumentNode NodeType = iota
|
||||
// DeclarationNode is the document type declaration, indicated by the following
|
||||
// tag (for example, <!DOCTYPE...> ).
|
||||
DeclarationNode
|
||||
// ElementNode is an element (for example, <item> ).
|
||||
ElementNode
|
||||
// TextNode is the text content of a node.
|
||||
TextNode
|
||||
// CommentNode a comment (for example, <!-- my comment --> ).
|
||||
CommentNode
|
||||
)
|
||||
|
||||
// A Node consists of a NodeType and some Data (tag name for
|
||||
// element nodes, content for text) and are part of a tree of Nodes.
|
||||
type Node struct {
|
||||
Parent, FirstChild, LastChild, PrevSibling, NextSibling *Node
|
||||
|
||||
Type NodeType
|
||||
Data string
|
||||
Prefix string
|
||||
NamespaceURI string
|
||||
Attr []xml.Attr
|
||||
|
||||
level int // node level in the tree
|
||||
}
|
||||
|
||||
// InnerText returns the text between the start and end tags of the object.
|
||||
func (n *Node) InnerText() string {
|
||||
var output func(*bytes.Buffer, *Node)
|
||||
output = func(buf *bytes.Buffer, n *Node) {
|
||||
switch n.Type {
|
||||
case TextNode:
|
||||
buf.WriteString(n.Data)
|
||||
return
|
||||
case CommentNode:
|
||||
return
|
||||
}
|
||||
for child := n.FirstChild; child != nil; child = child.NextSibling {
|
||||
output(buf, child)
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
output(&buf, n)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func outputXML(buf *bytes.Buffer, n *Node) {
|
||||
if n.Type == TextNode || n.Type == CommentNode {
|
||||
buf.WriteString(strings.TrimSpace(n.Data))
|
||||
return
|
||||
}
|
||||
buf.WriteString("<" + n.Data)
|
||||
for _, attr := range n.Attr {
|
||||
if attr.Name.Space != "" {
|
||||
buf.WriteString(fmt.Sprintf(` %s:%s="%s"`, attr.Name.Space, attr.Name.Local, attr.Value))
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf(` %s="%s"`, attr.Name.Local, attr.Value))
|
||||
}
|
||||
}
|
||||
buf.WriteString(">")
|
||||
for child := n.FirstChild; child != nil; child = child.NextSibling {
|
||||
outputXML(buf, child)
|
||||
}
|
||||
buf.WriteString(fmt.Sprintf("</%s>", n.Data))
|
||||
}
|
||||
|
||||
// OutputXML returns the text that including tags name.
|
||||
func (n *Node) OutputXML(self bool) string {
|
||||
var buf bytes.Buffer
|
||||
if self {
|
||||
outputXML(&buf, n)
|
||||
} else {
|
||||
for n := n.FirstChild; n != nil; n = n.NextSibling {
|
||||
outputXML(&buf, n)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func addAttr(n *Node, key, val string) {
|
||||
var attr xml.Attr
|
||||
if i := strings.Index(key, ":"); i > 0 {
|
||||
attr = xml.Attr{
|
||||
Name: xml.Name{Space: key[:i], Local: key[i+1:]},
|
||||
Value: val,
|
||||
}
|
||||
} else {
|
||||
attr = xml.Attr{
|
||||
Name: xml.Name{Local: key},
|
||||
Value: val,
|
||||
}
|
||||
}
|
||||
|
||||
n.Attr = append(n.Attr, attr)
|
||||
}
|
||||
|
||||
func addChild(parent, n *Node) {
|
||||
n.Parent = parent
|
||||
if parent.FirstChild == nil {
|
||||
parent.FirstChild = n
|
||||
} else {
|
||||
parent.LastChild.NextSibling = n
|
||||
n.PrevSibling = parent.LastChild
|
||||
}
|
||||
|
||||
parent.LastChild = n
|
||||
}
|
||||
|
||||
func addSibling(sibling, n *Node) {
|
||||
n.Parent = sibling.Parent
|
||||
sibling.NextSibling = n
|
||||
n.PrevSibling = sibling
|
||||
if sibling.Parent != nil {
|
||||
sibling.Parent.LastChild = n
|
||||
}
|
||||
}
|
||||
|
||||
// LoadURL loads the XML document from the specified URL.
|
||||
func LoadURL(url string) (*Node, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parse(resp.Body)
|
||||
}
|
||||
|
||||
func parse(r io.Reader) (*Node, error) {
|
||||
var (
|
||||
decoder = xml.NewDecoder(r)
|
||||
doc = &Node{Type: DocumentNode}
|
||||
space2prefix = make(map[string]string)
|
||||
level = 0
|
||||
)
|
||||
decoder.CharsetReader = charset.NewReaderLabel
|
||||
prev := doc
|
||||
for {
|
||||
tok, err := decoder.Token()
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
goto quit
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch tok := tok.(type) {
|
||||
case xml.StartElement:
|
||||
if level == 0 {
|
||||
// mising XML declaration
|
||||
node := &Node{Type: DeclarationNode, Data: "xml", level: 1}
|
||||
addChild(prev, node)
|
||||
level = 1
|
||||
prev = node
|
||||
}
|
||||
node := &Node{
|
||||
Type: ElementNode,
|
||||
Data: tok.Name.Local,
|
||||
Prefix: space2prefix[tok.Name.Space],
|
||||
NamespaceURI: tok.Name.Space,
|
||||
Attr: tok.Attr,
|
||||
level: level,
|
||||
}
|
||||
for _, att := range tok.Attr {
|
||||
if att.Name.Space == "xmlns" {
|
||||
space2prefix[att.Value] = att.Name.Local
|
||||
}
|
||||
}
|
||||
//fmt.Println(fmt.Sprintf("start > %s : %d", node.Data, level))
|
||||
if level == prev.level {
|
||||
addSibling(prev, node)
|
||||
} else if level > prev.level {
|
||||
addChild(prev, node)
|
||||
} else if level < prev.level {
|
||||
for i := prev.level - level; i > 1; i-- {
|
||||
prev = prev.Parent
|
||||
}
|
||||
addSibling(prev.Parent, node)
|
||||
}
|
||||
prev = node
|
||||
level++
|
||||
case xml.EndElement:
|
||||
level--
|
||||
case xml.CharData:
|
||||
node := &Node{Type: TextNode, Data: string(tok), level: level}
|
||||
if level == prev.level {
|
||||
addSibling(prev, node)
|
||||
} else if level > prev.level {
|
||||
addChild(prev, node)
|
||||
}
|
||||
case xml.Comment:
|
||||
node := &Node{Type: CommentNode, Data: string(tok), level: level}
|
||||
if level == prev.level {
|
||||
addSibling(prev, node)
|
||||
} else if level > prev.level {
|
||||
addChild(prev, node)
|
||||
}
|
||||
case xml.ProcInst: // Processing Instruction
|
||||
if prev.Type != DeclarationNode {
|
||||
level++
|
||||
}
|
||||
node := &Node{Type: DeclarationNode, Data: tok.Target, level: level}
|
||||
pairs := strings.Split(string(tok.Inst), " ")
|
||||
for _, pair := range pairs {
|
||||
pair = strings.TrimSpace(pair)
|
||||
if i := strings.Index(pair, "="); i > 0 {
|
||||
addAttr(node, pair[:i], strings.Trim(pair[i+1:], `"`))
|
||||
}
|
||||
}
|
||||
if level == prev.level {
|
||||
addSibling(prev, node)
|
||||
} else if level > prev.level {
|
||||
addChild(prev, node)
|
||||
}
|
||||
prev = node
|
||||
case xml.Directive:
|
||||
}
|
||||
|
||||
}
|
||||
quit:
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// Parse returns the parse tree for the XML from the given Reader.
|
||||
func Parse(r io.Reader) (*Node, error) {
|
||||
return parse(r)
|
||||
}
|
||||
|
||||
// ParseXML returns the parse tree for the XML from the given Reader.Deprecated.
|
||||
func ParseXML(r io.Reader) (*Node, error) {
|
||||
return parse(r)
|
||||
}
|
|
@ -0,0 +1,230 @@
|
|||
/*
|
||||
Package xmlquery provides extract data from XML documents using XPath expression.
|
||||
*/
|
||||
package xmlquery
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/antchfx/xpath"
|
||||
)
|
||||
|
||||
// SelectElements finds child elements with the specified name.
|
||||
func (n *Node) SelectElements(name string) []*Node {
|
||||
return Find(n, name)
|
||||
}
|
||||
|
||||
// SelectElement finds child elements with the specified name.
|
||||
func (n *Node) SelectElement(name string) *Node {
|
||||
return FindOne(n, name)
|
||||
}
|
||||
|
||||
// SelectAttr returns the attribute value with the specified name.
|
||||
func (n *Node) SelectAttr(name string) string {
|
||||
var local, space string
|
||||
local = name
|
||||
if i := strings.Index(name, ":"); i > 0 {
|
||||
space = name[:i]
|
||||
local = name[i+1:]
|
||||
}
|
||||
for _, attr := range n.Attr {
|
||||
if attr.Name.Local == local && attr.Name.Space == space {
|
||||
return attr.Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ xpath.NodeNavigator = &NodeNavigator{}
|
||||
|
||||
// CreateXPathNavigator creates a new xpath.NodeNavigator for the specified html.Node.
|
||||
func CreateXPathNavigator(top *Node) *NodeNavigator {
|
||||
return &NodeNavigator{curr: top, root: top, attr: -1}
|
||||
}
|
||||
|
||||
// Find searches the Node that matches by the specified XPath expr.
|
||||
func Find(top *Node, expr string) []*Node {
|
||||
exp, err := xpath.Compile(expr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
t := exp.Select(CreateXPathNavigator(top))
|
||||
var elems []*Node
|
||||
for t.MoveNext() {
|
||||
elems = append(elems, (t.Current().(*NodeNavigator)).curr)
|
||||
}
|
||||
return elems
|
||||
}
|
||||
|
||||
// FindOne searches the Node that matches by the specified XPath expr,
|
||||
// and returns first element of matched.
|
||||
func FindOne(top *Node, expr string) *Node {
|
||||
exp, err := xpath.Compile(expr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
t := exp.Select(CreateXPathNavigator(top))
|
||||
var elem *Node
|
||||
if t.MoveNext() {
|
||||
elem = (t.Current().(*NodeNavigator)).curr
|
||||
}
|
||||
return elem
|
||||
}
|
||||
|
||||
// FindEach searches the html.Node and calls functions cb.
|
||||
func FindEach(top *Node, expr string, cb func(int, *Node)) {
|
||||
exp, err := xpath.Compile(expr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
t := exp.Select(CreateXPathNavigator(top))
|
||||
var i int
|
||||
for t.MoveNext() {
|
||||
cb(i, (t.Current().(*NodeNavigator)).curr)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
type NodeNavigator struct {
|
||||
root, curr *Node
|
||||
attr int
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) Current() *Node {
|
||||
return x.curr
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) NodeType() xpath.NodeType {
|
||||
switch x.curr.Type {
|
||||
case CommentNode:
|
||||
return xpath.CommentNode
|
||||
case TextNode:
|
||||
return xpath.TextNode
|
||||
case DeclarationNode, DocumentNode:
|
||||
return xpath.RootNode
|
||||
case ElementNode:
|
||||
if x.attr != -1 {
|
||||
return xpath.AttributeNode
|
||||
}
|
||||
return xpath.ElementNode
|
||||
}
|
||||
panic(fmt.Sprintf("unknown XML node type: %v", x.curr.Type))
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) LocalName() string {
|
||||
if x.attr != -1 {
|
||||
return x.curr.Attr[x.attr].Name.Local
|
||||
}
|
||||
return x.curr.Data
|
||||
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) Prefix() string {
|
||||
return x.curr.Prefix
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) Value() string {
|
||||
switch x.curr.Type {
|
||||
case CommentNode:
|
||||
return x.curr.Data
|
||||
case ElementNode:
|
||||
if x.attr != -1 {
|
||||
return x.curr.Attr[x.attr].Value
|
||||
}
|
||||
return x.curr.InnerText()
|
||||
case TextNode:
|
||||
return x.curr.Data
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) Copy() xpath.NodeNavigator {
|
||||
n := *x
|
||||
return &n
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToRoot() {
|
||||
x.curr = x.root
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToParent() bool {
|
||||
if x.attr != -1 {
|
||||
x.attr = -1
|
||||
return true
|
||||
} else if node := x.curr.Parent; node != nil {
|
||||
x.curr = node
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToNextAttribute() bool {
|
||||
if x.attr >= len(x.curr.Attr)-1 {
|
||||
return false
|
||||
}
|
||||
x.attr++
|
||||
return true
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToChild() bool {
|
||||
if x.attr != -1 {
|
||||
return false
|
||||
}
|
||||
if node := x.curr.FirstChild; node != nil {
|
||||
x.curr = node
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToFirst() bool {
|
||||
if x.attr != -1 || x.curr.PrevSibling == nil {
|
||||
return false
|
||||
}
|
||||
for {
|
||||
node := x.curr.PrevSibling
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
x.curr = node
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) String() string {
|
||||
return x.Value()
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToNext() bool {
|
||||
if x.attr != -1 {
|
||||
return false
|
||||
}
|
||||
if node := x.curr.NextSibling; node != nil {
|
||||
x.curr = node
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveToPrevious() bool {
|
||||
if x.attr != -1 {
|
||||
return false
|
||||
}
|
||||
if node := x.curr.PrevSibling; node != nil {
|
||||
x.curr = node
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *NodeNavigator) MoveTo(other xpath.NodeNavigator) bool {
|
||||
node, ok := other.(*NodeNavigator)
|
||||
if !ok || node.root != x.root {
|
||||
return false
|
||||
}
|
||||
|
||||
x.curr = node.curr
|
||||
x.attr = node.attr
|
||||
return true
|
||||
}
|
|
@ -8,8 +8,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/masterzen/winrm/soap"
|
||||
"github.com/masterzen/xmlpath"
|
||||
"github.com/antchfx/xquery/xml"
|
||||
"github.com/satori/go.uuid"
|
||||
)
|
||||
|
||||
|
@ -57,7 +56,7 @@ func (w *wsman) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
rw.Header().Add("Content-Type", "application/soap+xml")
|
||||
|
||||
defer r.Body.Close()
|
||||
env, err := xmlpath.Parse(r.Body)
|
||||
env, err := xmlquery.Parse(r.Body)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -130,41 +129,32 @@ func (w *wsman) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func readAction(env *xmlpath.Node) string {
|
||||
xpath, err := xmlpath.CompileWithNamespace(
|
||||
"//a:Action", soap.GetAllNamespaces())
|
||||
|
||||
if err != nil {
|
||||
func readAction(env *xmlquery.Node) string {
|
||||
xpath := xmlquery.FindOne(env, "//a:Action")
|
||||
if xpath == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
action, _ := xpath.String(env)
|
||||
return action
|
||||
return xpath.InnerText()
|
||||
}
|
||||
|
||||
func readCommand(env *xmlpath.Node) string {
|
||||
xpath, err := xmlpath.CompileWithNamespace(
|
||||
"//rsp:Command", soap.GetAllNamespaces())
|
||||
|
||||
if err != nil {
|
||||
func readCommand(env *xmlquery.Node) string {
|
||||
xpath := xmlquery.FindOne(env, "//rsp:Command")
|
||||
if xpath == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
command, _ := xpath.String(env)
|
||||
if unquoted, err := strconv.Unquote(command); err == nil {
|
||||
if unquoted, err := strconv.Unquote(xpath.InnerText()); err == nil {
|
||||
return unquoted
|
||||
}
|
||||
return command
|
||||
return xpath.InnerText()
|
||||
}
|
||||
|
||||
func readCommandIDFromDesiredStream(env *xmlpath.Node) string {
|
||||
xpath, err := xmlpath.CompileWithNamespace(
|
||||
"//rsp:DesiredStream/@CommandId", soap.GetAllNamespaces())
|
||||
|
||||
if err != nil {
|
||||
func readCommandIDFromDesiredStream(env *xmlquery.Node) string {
|
||||
xpath := xmlquery.FindOne(env, "//rsp:DesiredStream")
|
||||
if xpath == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
id, _ := xpath.String(env)
|
||||
return id
|
||||
return xpath.SelectAttr("CommandId")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2016 Microsoft Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
203
vendor/github.com/masterzen/azure-sdk-for-go/core/http/chunked.go
generated
vendored
Normal file
203
vendor/github.com/masterzen/azure-sdk-for-go/core/http/chunked.go
generated
vendored
Normal file
|
@ -0,0 +1,203 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
|
||||
|
||||
// This code is duplicated in net/http and net/http/httputil.
|
||||
// Please make any changes in both files.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
|
||||
|
||||
var ErrLineTooLong = errors.New("header line too long")
|
||||
|
||||
// newChunkedReader returns a new chunkedReader that translates the data read from r
|
||||
// out of HTTP "chunked" format before returning it.
|
||||
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
|
||||
//
|
||||
// newChunkedReader is not needed by normal applications. The http package
|
||||
// automatically decodes chunking when reading response bodies.
|
||||
func newChunkedReader(r io.Reader) io.Reader {
|
||||
br, ok := r.(*bufio.Reader)
|
||||
if !ok {
|
||||
br = bufio.NewReader(r)
|
||||
}
|
||||
return &chunkedReader{r: br}
|
||||
}
|
||||
|
||||
type chunkedReader struct {
|
||||
r *bufio.Reader
|
||||
n uint64 // unread bytes in chunk
|
||||
err error
|
||||
buf [2]byte
|
||||
}
|
||||
|
||||
func (cr *chunkedReader) beginChunk() {
|
||||
// chunk-size CRLF
|
||||
var line []byte
|
||||
line, cr.err = readLine(cr.r)
|
||||
if cr.err != nil {
|
||||
return
|
||||
}
|
||||
cr.n, cr.err = parseHexUint(line)
|
||||
if cr.err != nil {
|
||||
return
|
||||
}
|
||||
if cr.n == 0 {
|
||||
cr.err = io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func (cr *chunkedReader) chunkHeaderAvailable() bool {
|
||||
n := cr.r.Buffered()
|
||||
if n > 0 {
|
||||
peek, _ := cr.r.Peek(n)
|
||||
return bytes.IndexByte(peek, '\n') >= 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
|
||||
for cr.err == nil {
|
||||
if cr.n == 0 {
|
||||
if n > 0 && !cr.chunkHeaderAvailable() {
|
||||
// We've read enough. Don't potentially block
|
||||
// reading a new chunk header.
|
||||
break
|
||||
}
|
||||
cr.beginChunk()
|
||||
continue
|
||||
}
|
||||
if len(b) == 0 {
|
||||
break
|
||||
}
|
||||
rbuf := b
|
||||
if uint64(len(rbuf)) > cr.n {
|
||||
rbuf = rbuf[:cr.n]
|
||||
}
|
||||
var n0 int
|
||||
n0, cr.err = cr.r.Read(rbuf)
|
||||
n += n0
|
||||
b = b[n0:]
|
||||
cr.n -= uint64(n0)
|
||||
// If we're at the end of a chunk, read the next two
|
||||
// bytes to verify they are "\r\n".
|
||||
if cr.n == 0 && cr.err == nil {
|
||||
if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
|
||||
if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
|
||||
cr.err = errors.New("malformed chunked encoding")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, cr.err
|
||||
}
|
||||
|
||||
// Read a line of bytes (up to \n) from b.
|
||||
// Give up if the line exceeds maxLineLength.
|
||||
// The returned bytes are a pointer into storage in
|
||||
// the bufio, so they are only valid until the next bufio read.
|
||||
func readLine(b *bufio.Reader) (p []byte, err error) {
|
||||
if p, err = b.ReadSlice('\n'); err != nil {
|
||||
// We always know when EOF is coming.
|
||||
// If the caller asked for a line, there should be a line.
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
} else if err == bufio.ErrBufferFull {
|
||||
err = ErrLineTooLong
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if len(p) >= maxLineLength {
|
||||
return nil, ErrLineTooLong
|
||||
}
|
||||
return trimTrailingWhitespace(p), nil
|
||||
}
|
||||
|
||||
func trimTrailingWhitespace(b []byte) []byte {
|
||||
for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
|
||||
b = b[:len(b)-1]
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func isASCIISpace(b byte) bool {
|
||||
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
|
||||
}
|
||||
|
||||
// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP
|
||||
// "chunked" format before writing them to w. Closing the returned chunkedWriter
|
||||
// sends the final 0-length chunk that marks the end of the stream.
|
||||
//
|
||||
// newChunkedWriter is not needed by normal applications. The http
|
||||
// package adds chunking automatically if handlers don't set a
|
||||
// Content-Length header. Using newChunkedWriter inside a handler
|
||||
// would result in double chunking or chunking with a Content-Length
|
||||
// length, both of which are wrong.
|
||||
func newChunkedWriter(w io.Writer) io.WriteCloser {
|
||||
return &chunkedWriter{w}
|
||||
}
|
||||
|
||||
// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
|
||||
// Encoding wire format to the underlying Wire chunkedWriter.
|
||||
type chunkedWriter struct {
|
||||
Wire io.Writer
|
||||
}
|
||||
|
||||
// Write the contents of data as one chunk to Wire.
|
||||
// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has
|
||||
// a bug since it does not check for success of io.WriteString
|
||||
func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
|
||||
|
||||
// Don't send 0-length data. It looks like EOF for chunked encoding.
|
||||
if len(data) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n, err = cw.Wire.Write(data); err != nil {
|
||||
return
|
||||
}
|
||||
if n != len(data) {
|
||||
err = io.ErrShortWrite
|
||||
return
|
||||
}
|
||||
_, err = io.WriteString(cw.Wire, "\r\n")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (cw *chunkedWriter) Close() error {
|
||||
_, err := io.WriteString(cw.Wire, "0\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func parseHexUint(v []byte) (n uint64, err error) {
|
||||
for _, b := range v {
|
||||
n <<= 4
|
||||
switch {
|
||||
case '0' <= b && b <= '9':
|
||||
b = b - '0'
|
||||
case 'a' <= b && b <= 'f':
|
||||
b = b - 'a' + 10
|
||||
case 'A' <= b && b <= 'F':
|
||||
b = b - 'A' + 10
|
||||
default:
|
||||
return 0, errors.New("invalid byte in chunk length")
|
||||
}
|
||||
n |= uint64(b)
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,487 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP client. See RFC 2616.
|
||||
//
|
||||
// This is the high-level Client interface.
|
||||
// The low-level implementation is in transport.go.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Client is an HTTP client. Its zero value (DefaultClient) is a
|
||||
// usable client that uses DefaultTransport.
|
||||
//
|
||||
// The Client's Transport typically has internal state (cached TCP
|
||||
// connections), so Clients should be reused instead of created as
|
||||
// needed. Clients are safe for concurrent use by multiple goroutines.
|
||||
//
|
||||
// A Client is higher-level than a RoundTripper (such as Transport)
|
||||
// and additionally handles HTTP details such as cookies and
|
||||
// redirects.
|
||||
type Client struct {
|
||||
// Transport specifies the mechanism by which individual
|
||||
// HTTP requests are made.
|
||||
// If nil, DefaultTransport is used.
|
||||
Transport RoundTripper
|
||||
|
||||
// CheckRedirect specifies the policy for handling redirects.
|
||||
// If CheckRedirect is not nil, the client calls it before
|
||||
// following an HTTP redirect. The arguments req and via are
|
||||
// the upcoming request and the requests made already, oldest
|
||||
// first. If CheckRedirect returns an error, the Client's Get
|
||||
// method returns both the previous Response and
|
||||
// CheckRedirect's error (wrapped in a url.Error) instead of
|
||||
// issuing the Request req.
|
||||
//
|
||||
// If CheckRedirect is nil, the Client uses its default policy,
|
||||
// which is to stop after 10 consecutive requests.
|
||||
CheckRedirect func(req *Request, via []*Request) error
|
||||
|
||||
// Jar specifies the cookie jar.
|
||||
// If Jar is nil, cookies are not sent in requests and ignored
|
||||
// in responses.
|
||||
Jar CookieJar
|
||||
|
||||
// Timeout specifies a time limit for requests made by this
|
||||
// Client. The timeout includes connection time, any
|
||||
// redirects, and reading the response body. The timer remains
|
||||
// running after Get, Head, Post, or Do return and will
|
||||
// interrupt reading of the Response.Body.
|
||||
//
|
||||
// A Timeout of zero means no timeout.
|
||||
//
|
||||
// The Client's Transport must support the CancelRequest
|
||||
// method or Client will return errors when attempting to make
|
||||
// a request with Get, Head, Post, or Do. Client's default
|
||||
// Transport (DefaultTransport) supports CancelRequest.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultClient is the default Client and is used by Get, Head, and Post.
|
||||
var DefaultClient = &Client{}
|
||||
|
||||
// RoundTripper is an interface representing the ability to execute a
|
||||
// single HTTP transaction, obtaining the Response for a given Request.
|
||||
//
|
||||
// A RoundTripper must be safe for concurrent use by multiple
|
||||
// goroutines.
|
||||
type RoundTripper interface {
|
||||
// RoundTrip executes a single HTTP transaction, returning
|
||||
// the Response for the request req. RoundTrip should not
|
||||
// attempt to interpret the response. In particular,
|
||||
// RoundTrip must return err == nil if it obtained a response,
|
||||
// regardless of the response's HTTP status code. A non-nil
|
||||
// err should be reserved for failure to obtain a response.
|
||||
// Similarly, RoundTrip should not attempt to handle
|
||||
// higher-level protocol details such as redirects,
|
||||
// authentication, or cookies.
|
||||
//
|
||||
// RoundTrip should not modify the request, except for
|
||||
// consuming and closing the Body, including on errors. The
|
||||
// request's URL and Header fields are guaranteed to be
|
||||
// initialized.
|
||||
RoundTrip(*Request) (*Response, error)
|
||||
}
|
||||
|
||||
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
|
||||
// return true if the string includes a port.
|
||||
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
|
||||
|
||||
// Used in Send to implement io.ReadCloser by bundling together the
|
||||
// bufio.Reader through which we read the response, and the underlying
|
||||
// network connection.
|
||||
type readClose struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
func (c *Client) send(req *Request) (*Response, error) {
|
||||
if c.Jar != nil {
|
||||
for _, cookie := range c.Jar.Cookies(req.URL) {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
}
|
||||
resp, err := send(req, c.transport())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.Jar != nil {
|
||||
if rc := resp.Cookies(); len(rc) > 0 {
|
||||
c.Jar.SetCookies(req.URL, rc)
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Do sends an HTTP request and returns an HTTP response, following
|
||||
// policy (e.g. redirects, cookies, auth) as configured on the client.
|
||||
//
|
||||
// An error is returned if caused by client policy (such as
|
||||
// CheckRedirect), or if there was an HTTP protocol error.
|
||||
// A non-2xx response doesn't cause an error.
|
||||
//
|
||||
// When err is nil, resp always contains a non-nil resp.Body.
|
||||
//
|
||||
// Callers should close resp.Body when done reading from it. If
|
||||
// resp.Body is not closed, the Client's underlying RoundTripper
|
||||
// (typically Transport) may not be able to re-use a persistent TCP
|
||||
// connection to the server for a subsequent "keep-alive" request.
|
||||
//
|
||||
// The request Body, if non-nil, will be closed by the underlying
|
||||
// Transport, even on errors.
|
||||
//
|
||||
// Generally Get, Post, or PostForm will be used instead of Do.
|
||||
func (c *Client) Do(req *Request) (resp *Response, err error) {
|
||||
if req.Method == "GET" || req.Method == "HEAD" {
|
||||
return c.doFollowingRedirects(req, shouldRedirectGet)
|
||||
}
|
||||
if req.Method == "POST" || req.Method == "PUT" {
|
||||
return c.doFollowingRedirects(req, shouldRedirectPost)
|
||||
}
|
||||
return c.send(req)
|
||||
}
|
||||
|
||||
func (c *Client) transport() RoundTripper {
|
||||
if c.Transport != nil {
|
||||
return c.Transport
|
||||
}
|
||||
return DefaultTransport
|
||||
}
|
||||
|
||||
// send issues an HTTP request.
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
func send(req *Request, t RoundTripper) (resp *Response, err error) {
|
||||
if t == nil {
|
||||
req.closeBody()
|
||||
return nil, errors.New("http: no Client.Transport or DefaultTransport")
|
||||
}
|
||||
|
||||
if req.URL == nil {
|
||||
req.closeBody()
|
||||
return nil, errors.New("http: nil Request.URL")
|
||||
}
|
||||
|
||||
if req.RequestURI != "" {
|
||||
req.closeBody()
|
||||
return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
|
||||
}
|
||||
|
||||
// Most the callers of send (Get, Post, et al) don't need
|
||||
// Headers, leaving it uninitialized. We guarantee to the
|
||||
// Transport that this has been initialized, though.
|
||||
if req.Header == nil {
|
||||
req.Header = make(Header)
|
||||
}
|
||||
|
||||
if u := req.URL.User; u != nil {
|
||||
username := u.Username()
|
||||
password, _ := u.Password()
|
||||
req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
|
||||
}
|
||||
resp, err = t.RoundTrip(req)
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
log.Printf("RoundTripper returned a response & error; ignoring response")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
|
||||
// "To receive authorization, the client sends the userid and password,
|
||||
// separated by a single colon (":") character, within a base64
|
||||
// encoded string in the credentials."
|
||||
// It is not meant to be urlencoded.
|
||||
func basicAuth(username, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
// True if the specified HTTP status code is one for which the Get utility should
|
||||
// automatically redirect.
|
||||
func shouldRedirectGet(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// True if the specified HTTP status code is one for which the Post utility should
|
||||
// automatically redirect.
|
||||
func shouldRedirectPost(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case StatusFound, StatusSeeOther:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Get issues a GET to the specified URL. If the response is one of the following
|
||||
// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
|
||||
//
|
||||
// 301 (Moved Permanently)
|
||||
// 302 (Found)
|
||||
// 303 (See Other)
|
||||
// 307 (Temporary Redirect)
|
||||
//
|
||||
// An error is returned if there were too many redirects or if there
|
||||
// was an HTTP protocol error. A non-2xx response doesn't cause an
|
||||
// error.
|
||||
//
|
||||
// When err is nil, resp always contains a non-nil resp.Body.
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
//
|
||||
// Get is a wrapper around DefaultClient.Get.
|
||||
func Get(url string) (resp *Response, err error) {
|
||||
return DefaultClient.Get(url)
|
||||
}
|
||||
|
||||
// Get issues a GET to the specified URL. If the response is one of the
|
||||
// following redirect codes, Get follows the redirect after calling the
|
||||
// Client's CheckRedirect function.
|
||||
//
|
||||
// 301 (Moved Permanently)
|
||||
// 302 (Found)
|
||||
// 303 (See Other)
|
||||
// 307 (Temporary Redirect)
|
||||
//
|
||||
// An error is returned if the Client's CheckRedirect function fails
|
||||
// or if there was an HTTP protocol error. A non-2xx response doesn't
|
||||
// cause an error.
|
||||
//
|
||||
// When err is nil, resp always contains a non-nil resp.Body.
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
func (c *Client) Get(url string) (resp *Response, err error) {
|
||||
req, err := NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.doFollowingRedirects(req, shouldRedirectGet)
|
||||
}
|
||||
|
||||
func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
|
||||
var base *url.URL
|
||||
redirectChecker := c.CheckRedirect
|
||||
if redirectChecker == nil {
|
||||
redirectChecker = defaultCheckRedirect
|
||||
}
|
||||
var via []*Request
|
||||
|
||||
if ireq.URL == nil {
|
||||
ireq.closeBody()
|
||||
return nil, errors.New("http: nil Request.URL")
|
||||
}
|
||||
|
||||
var reqmu sync.Mutex // guards req
|
||||
req := ireq
|
||||
|
||||
var timer *time.Timer
|
||||
if c.Timeout > 0 {
|
||||
type canceler interface {
|
||||
CancelRequest(*Request)
|
||||
}
|
||||
tr, ok := c.transport().(canceler)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
|
||||
}
|
||||
timer = time.AfterFunc(c.Timeout, func() {
|
||||
reqmu.Lock()
|
||||
defer reqmu.Unlock()
|
||||
tr.CancelRequest(req)
|
||||
})
|
||||
}
|
||||
|
||||
urlStr := "" // next relative or absolute URL to fetch (after first request)
|
||||
redirectFailed := false
|
||||
for redirect := 0; ; redirect++ {
|
||||
if redirect != 0 {
|
||||
nreq := new(Request)
|
||||
nreq.Method = ireq.Method
|
||||
if ireq.Method == "POST" || ireq.Method == "PUT" {
|
||||
nreq.Method = "GET"
|
||||
}
|
||||
nreq.Header = make(Header)
|
||||
nreq.URL, err = base.Parse(urlStr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if len(via) > 0 {
|
||||
// Add the Referer header.
|
||||
lastReq := via[len(via)-1]
|
||||
if lastReq.URL.Scheme != "https" {
|
||||
nreq.Header.Set("Referer", lastReq.URL.String())
|
||||
}
|
||||
|
||||
err = redirectChecker(nreq, via)
|
||||
if err != nil {
|
||||
redirectFailed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
reqmu.Lock()
|
||||
req = nreq
|
||||
reqmu.Unlock()
|
||||
}
|
||||
|
||||
urlStr = req.URL.String()
|
||||
if resp, err = c.send(req); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if shouldRedirect(resp.StatusCode) {
|
||||
// Read the body if small so underlying TCP connection will be re-used.
|
||||
// No need to check for errors: if it fails, Transport won't reuse it anyway.
|
||||
const maxBodySlurpSize = 2 << 10
|
||||
if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
|
||||
io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if urlStr = resp.Header.Get("Location"); urlStr == "" {
|
||||
err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
|
||||
break
|
||||
}
|
||||
base = req.URL
|
||||
via = append(via, req)
|
||||
continue
|
||||
}
|
||||
if timer != nil {
|
||||
resp.Body = &cancelTimerBody{timer, resp.Body}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
method := ireq.Method
|
||||
urlErr := &url.Error{
|
||||
Op: method[0:1] + strings.ToLower(method[1:]),
|
||||
URL: urlStr,
|
||||
Err: err,
|
||||
}
|
||||
|
||||
if redirectFailed {
|
||||
// Special case for Go 1 compatibility: return both the response
|
||||
// and an error if the CheckRedirect function failed.
|
||||
// See http://golang.org/issue/3795
|
||||
return resp, urlErr
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
return nil, urlErr
|
||||
}
|
||||
|
||||
func defaultCheckRedirect(req *Request, via []*Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Post issues a POST to the specified URL.
|
||||
//
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
//
|
||||
// Post is a wrapper around DefaultClient.Post
|
||||
func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
|
||||
return DefaultClient.Post(url, bodyType, body)
|
||||
}
|
||||
|
||||
// Post issues a POST to the specified URL.
|
||||
//
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
//
|
||||
// If the provided body is also an io.Closer, it is closed after the
|
||||
// request.
|
||||
func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
|
||||
req, err := NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", bodyType)
|
||||
return c.doFollowingRedirects(req, shouldRedirectPost)
|
||||
}
|
||||
|
||||
// PostForm issues a POST to the specified URL, with data's keys and
|
||||
// values URL-encoded as the request body.
|
||||
//
|
||||
// When err is nil, resp always contains a non-nil resp.Body.
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
//
|
||||
// PostForm is a wrapper around DefaultClient.PostForm
|
||||
func PostForm(url string, data url.Values) (resp *Response, err error) {
|
||||
return DefaultClient.PostForm(url, data)
|
||||
}
|
||||
|
||||
// PostForm issues a POST to the specified URL,
|
||||
// with data's keys and values urlencoded as the request body.
|
||||
//
|
||||
// When err is nil, resp always contains a non-nil resp.Body.
|
||||
// Caller should close resp.Body when done reading from it.
|
||||
func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
|
||||
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
|
||||
}
|
||||
|
||||
// Head issues a HEAD to the specified URL. If the response is one of the
|
||||
// following redirect codes, Head follows the redirect after calling the
|
||||
// Client's CheckRedirect function.
|
||||
//
|
||||
// 301 (Moved Permanently)
|
||||
// 302 (Found)
|
||||
// 303 (See Other)
|
||||
// 307 (Temporary Redirect)
|
||||
//
|
||||
// Head is a wrapper around DefaultClient.Head
|
||||
func Head(url string) (resp *Response, err error) {
|
||||
return DefaultClient.Head(url)
|
||||
}
|
||||
|
||||
// Head issues a HEAD to the specified URL. If the response is one of the
|
||||
// following redirect codes, Head follows the redirect after calling the
|
||||
// Client's CheckRedirect function.
|
||||
//
|
||||
// 301 (Moved Permanently)
|
||||
// 302 (Found)
|
||||
// 303 (See Other)
|
||||
// 307 (Temporary Redirect)
|
||||
func (c *Client) Head(url string) (resp *Response, err error) {
|
||||
req, err := NewRequest("HEAD", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.doFollowingRedirects(req, shouldRedirectGet)
|
||||
}
|
||||
|
||||
type cancelTimerBody struct {
|
||||
t *time.Timer
|
||||
rc io.ReadCloser
|
||||
}
|
||||
|
||||
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
|
||||
n, err = b.rc.Read(p)
|
||||
if err == io.EOF {
|
||||
b.t.Stop()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (b *cancelTimerBody) Close() error {
|
||||
err := b.rc.Close()
|
||||
b.t.Stop()
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,363 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This implementation is done according to RFC 6265:
|
||||
//
|
||||
// http://tools.ietf.org/html/rfc6265
|
||||
|
||||
// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
|
||||
// HTTP response or the Cookie header of an HTTP request.
|
||||
type Cookie struct {
|
||||
Name string
|
||||
Value string
|
||||
Path string
|
||||
Domain string
|
||||
Expires time.Time
|
||||
RawExpires string
|
||||
|
||||
// MaxAge=0 means no 'Max-Age' attribute specified.
|
||||
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
Raw string
|
||||
Unparsed []string // Raw text of unparsed attribute-value pairs
|
||||
}
|
||||
|
||||
// readSetCookies parses all "Set-Cookie" values from
|
||||
// the header h and returns the successfully parsed Cookies.
|
||||
func readSetCookies(h Header) []*Cookie {
|
||||
cookies := []*Cookie{}
|
||||
for _, line := range h["Set-Cookie"] {
|
||||
parts := strings.Split(strings.TrimSpace(line), ";")
|
||||
if len(parts) == 1 && parts[0] == "" {
|
||||
continue
|
||||
}
|
||||
parts[0] = strings.TrimSpace(parts[0])
|
||||
j := strings.Index(parts[0], "=")
|
||||
if j < 0 {
|
||||
continue
|
||||
}
|
||||
name, value := parts[0][:j], parts[0][j+1:]
|
||||
if !isCookieNameValid(name) {
|
||||
continue
|
||||
}
|
||||
value, success := parseCookieValue(value)
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
c := &Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Raw: line,
|
||||
}
|
||||
for i := 1; i < len(parts); i++ {
|
||||
parts[i] = strings.TrimSpace(parts[i])
|
||||
if len(parts[i]) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
attr, val := parts[i], ""
|
||||
if j := strings.Index(attr, "="); j >= 0 {
|
||||
attr, val = attr[:j], attr[j+1:]
|
||||
}
|
||||
lowerAttr := strings.ToLower(attr)
|
||||
val, success = parseCookieValue(val)
|
||||
if !success {
|
||||
c.Unparsed = append(c.Unparsed, parts[i])
|
||||
continue
|
||||
}
|
||||
switch lowerAttr {
|
||||
case "secure":
|
||||
c.Secure = true
|
||||
continue
|
||||
case "httponly":
|
||||
c.HttpOnly = true
|
||||
continue
|
||||
case "domain":
|
||||
c.Domain = val
|
||||
continue
|
||||
case "max-age":
|
||||
secs, err := strconv.Atoi(val)
|
||||
if err != nil || secs != 0 && val[0] == '0' {
|
||||
break
|
||||
}
|
||||
if secs <= 0 {
|
||||
c.MaxAge = -1
|
||||
} else {
|
||||
c.MaxAge = secs
|
||||
}
|
||||
continue
|
||||
case "expires":
|
||||
c.RawExpires = val
|
||||
exptime, err := time.Parse(time.RFC1123, val)
|
||||
if err != nil {
|
||||
exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val)
|
||||
if err != nil {
|
||||
c.Expires = time.Time{}
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Expires = exptime.UTC()
|
||||
continue
|
||||
case "path":
|
||||
c.Path = val
|
||||
continue
|
||||
}
|
||||
c.Unparsed = append(c.Unparsed, parts[i])
|
||||
}
|
||||
cookies = append(cookies, c)
|
||||
}
|
||||
return cookies
|
||||
}
|
||||
|
||||
// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers.
|
||||
func SetCookie(w ResponseWriter, cookie *Cookie) {
|
||||
w.Header().Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
|
||||
// String returns the serialization of the cookie for use in a Cookie
|
||||
// header (if only Name and Value are set) or a Set-Cookie response
|
||||
// header (if other fields are set).
|
||||
func (c *Cookie) String() string {
|
||||
var b bytes.Buffer
|
||||
fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
|
||||
if len(c.Path) > 0 {
|
||||
fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path))
|
||||
}
|
||||
if len(c.Domain) > 0 {
|
||||
if validCookieDomain(c.Domain) {
|
||||
// A c.Domain containing illegal characters is not
|
||||
// sanitized but simply dropped which turns the cookie
|
||||
// into a host-only cookie. A leading dot is okay
|
||||
// but won't be sent.
|
||||
d := c.Domain
|
||||
if d[0] == '.' {
|
||||
d = d[1:]
|
||||
}
|
||||
fmt.Fprintf(&b, "; Domain=%s", d)
|
||||
} else {
|
||||
log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute",
|
||||
c.Domain)
|
||||
}
|
||||
}
|
||||
if c.Expires.Unix() > 0 {
|
||||
fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(time.RFC1123))
|
||||
}
|
||||
if c.MaxAge > 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge)
|
||||
} else if c.MaxAge < 0 {
|
||||
fmt.Fprintf(&b, "; Max-Age=0")
|
||||
}
|
||||
if c.HttpOnly {
|
||||
fmt.Fprintf(&b, "; HttpOnly")
|
||||
}
|
||||
if c.Secure {
|
||||
fmt.Fprintf(&b, "; Secure")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// readCookies parses all "Cookie" values from the header h and
|
||||
// returns the successfully parsed Cookies.
|
||||
//
|
||||
// if filter isn't empty, only cookies of that name are returned
|
||||
func readCookies(h Header, filter string) []*Cookie {
|
||||
cookies := []*Cookie{}
|
||||
lines, ok := h["Cookie"]
|
||||
if !ok {
|
||||
return cookies
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
parts := strings.Split(strings.TrimSpace(line), ";")
|
||||
if len(parts) == 1 && parts[0] == "" {
|
||||
continue
|
||||
}
|
||||
// Per-line attributes
|
||||
parsedPairs := 0
|
||||
for i := 0; i < len(parts); i++ {
|
||||
parts[i] = strings.TrimSpace(parts[i])
|
||||
if len(parts[i]) == 0 {
|
||||
continue
|
||||
}
|
||||
name, val := parts[i], ""
|
||||
if j := strings.Index(name, "="); j >= 0 {
|
||||
name, val = name[:j], name[j+1:]
|
||||
}
|
||||
if !isCookieNameValid(name) {
|
||||
continue
|
||||
}
|
||||
if filter != "" && filter != name {
|
||||
continue
|
||||
}
|
||||
val, success := parseCookieValue(val)
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
cookies = append(cookies, &Cookie{Name: name, Value: val})
|
||||
parsedPairs++
|
||||
}
|
||||
}
|
||||
return cookies
|
||||
}
|
||||
|
||||
// validCookieDomain returns wheter v is a valid cookie domain-value.
|
||||
func validCookieDomain(v string) bool {
|
||||
if isCookieDomainName(v) {
|
||||
return true
|
||||
}
|
||||
if net.ParseIP(v) != nil && !strings.Contains(v, ":") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isCookieDomainName returns whether s is a valid domain name or a valid
|
||||
// domain name with a leading dot '.'. It is almost a direct copy of
|
||||
// package net's isDomainName.
|
||||
func isCookieDomainName(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
if len(s) > 255 {
|
||||
return false
|
||||
}
|
||||
|
||||
if s[0] == '.' {
|
||||
// A cookie a domain attribute may start with a leading dot.
|
||||
s = s[1:]
|
||||
}
|
||||
last := byte('.')
|
||||
ok := false // Ok once we've seen a letter.
|
||||
partlen := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
switch {
|
||||
default:
|
||||
return false
|
||||
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
|
||||
// No '_' allowed here (in contrast to package net).
|
||||
ok = true
|
||||
partlen++
|
||||
case '0' <= c && c <= '9':
|
||||
// fine
|
||||
partlen++
|
||||
case c == '-':
|
||||
// Byte before dash cannot be dot.
|
||||
if last == '.' {
|
||||
return false
|
||||
}
|
||||
partlen++
|
||||
case c == '.':
|
||||
// Byte before dot cannot be dot, dash.
|
||||
if last == '.' || last == '-' {
|
||||
return false
|
||||
}
|
||||
if partlen > 63 || partlen == 0 {
|
||||
return false
|
||||
}
|
||||
partlen = 0
|
||||
}
|
||||
last = c
|
||||
}
|
||||
if last == '-' || partlen > 63 {
|
||||
return false
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
|
||||
|
||||
func sanitizeCookieName(n string) string {
|
||||
return cookieNameSanitizer.Replace(n)
|
||||
}
|
||||
|
||||
// http://tools.ietf.org/html/rfc6265#section-4.1.1
|
||||
// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE )
|
||||
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
|
||||
// ; US-ASCII characters excluding CTLs,
|
||||
// ; whitespace DQUOTE, comma, semicolon,
|
||||
// ; and backslash
|
||||
// We loosen this as spaces and commas are common in cookie values
|
||||
// but we produce a quoted cookie-value in when value starts or ends
|
||||
// with a comma or space.
|
||||
// See http://golang.org/issue/7243 for the discussion.
|
||||
func sanitizeCookieValue(v string) string {
|
||||
v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v)
|
||||
if len(v) == 0 {
|
||||
return v
|
||||
}
|
||||
if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' {
|
||||
return `"` + v + `"`
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func validCookieValueByte(b byte) bool {
|
||||
return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'
|
||||
}
|
||||
|
||||
// path-av = "Path=" path-value
|
||||
// path-value = <any CHAR except CTLs or ";">
|
||||
func sanitizeCookiePath(v string) string {
|
||||
return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v)
|
||||
}
|
||||
|
||||
func validCookiePathByte(b byte) bool {
|
||||
return 0x20 <= b && b < 0x7f && b != ';'
|
||||
}
|
||||
|
||||
func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string {
|
||||
ok := true
|
||||
for i := 0; i < len(v); i++ {
|
||||
if valid(v[i]) {
|
||||
continue
|
||||
}
|
||||
log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName)
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
if ok {
|
||||
return v
|
||||
}
|
||||
buf := make([]byte, 0, len(v))
|
||||
for i := 0; i < len(v); i++ {
|
||||
if b := v[i]; valid(b) {
|
||||
buf = append(buf, b)
|
||||
}
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func parseCookieValue(raw string) (string, bool) {
|
||||
// Strip the quotes, if present.
|
||||
if len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' {
|
||||
raw = raw[1 : len(raw)-1]
|
||||
}
|
||||
for i := 0; i < len(raw); i++ {
|
||||
if !validCookieValueByte(raw[i]) {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
return raw, true
|
||||
}
|
||||
|
||||
func isCookieNameValid(raw string) bool {
|
||||
return strings.IndexFunc(raw, isNotToken) < 0
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package http provides HTTP client and server implementations.
|
||||
|
||||
Get, Head, Post, and PostForm make HTTP (or HTTPS) requests:
|
||||
|
||||
resp, err := http.Get("http://example.com/")
|
||||
...
|
||||
resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf)
|
||||
...
|
||||
resp, err := http.PostForm("http://example.com/form",
|
||||
url.Values{"key": {"Value"}, "id": {"123"}})
|
||||
|
||||
The client must close the response body when finished with it:
|
||||
|
||||
resp, err := http.Get("http://example.com/")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
// ...
|
||||
|
||||
For control over HTTP client headers, redirect policy, and other
|
||||
settings, create a Client:
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: redirectPolicyFunc,
|
||||
}
|
||||
|
||||
resp, err := client.Get("http://example.com")
|
||||
// ...
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
// ...
|
||||
req.Header.Add("If-None-Match", `W/"wyzzy"`)
|
||||
resp, err := client.Do(req)
|
||||
// ...
|
||||
|
||||
For control over proxies, TLS configuration, keep-alives,
|
||||
compression, and other settings, create a Transport:
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{RootCAs: pool},
|
||||
DisableCompression: true,
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
resp, err := client.Get("https://example.com")
|
||||
|
||||
Clients and Transports are safe for concurrent use by multiple
|
||||
goroutines and for efficiency should only be created once and re-used.
|
||||
|
||||
ListenAndServe starts an HTTP server with a given address and handler.
|
||||
The handler is usually nil, which means to use DefaultServeMux.
|
||||
Handle and HandleFunc add handlers to DefaultServeMux:
|
||||
|
||||
http.Handle("/foo", fooHandler)
|
||||
|
||||
http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
|
||||
})
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", nil))
|
||||
|
||||
More control over the server's behavior is available by creating a
|
||||
custom Server:
|
||||
|
||||
s := &http.Server{
|
||||
Addr: ":8080",
|
||||
Handler: myHandler,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
log.Fatal(s.ListenAndServe())
|
||||
*/
|
||||
package http
|
123
vendor/github.com/masterzen/azure-sdk-for-go/core/http/filetransport.go
generated
vendored
Normal file
123
vendor/github.com/masterzen/azure-sdk-for-go/core/http/filetransport.go
generated
vendored
Normal file
|
@ -0,0 +1,123 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// fileTransport implements RoundTripper for the 'file' protocol.
|
||||
type fileTransport struct {
|
||||
fh fileHandler
|
||||
}
|
||||
|
||||
// NewFileTransport returns a new RoundTripper, serving the provided
|
||||
// FileSystem. The returned RoundTripper ignores the URL host in its
|
||||
// incoming requests, as well as most other properties of the
|
||||
// request.
|
||||
//
|
||||
// The typical use case for NewFileTransport is to register the "file"
|
||||
// protocol with a Transport, as in:
|
||||
//
|
||||
// t := &http.Transport{}
|
||||
// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
|
||||
// c := &http.Client{Transport: t}
|
||||
// res, err := c.Get("file:///etc/passwd")
|
||||
// ...
|
||||
func NewFileTransport(fs FileSystem) RoundTripper {
|
||||
return fileTransport{fileHandler{fs}}
|
||||
}
|
||||
|
||||
func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) {
|
||||
// We start ServeHTTP in a goroutine, which may take a long
|
||||
// time if the file is large. The newPopulateResponseWriter
|
||||
// call returns a channel which either ServeHTTP or finish()
|
||||
// sends our *Response on, once the *Response itself has been
|
||||
// populated (even if the body itself is still being
|
||||
// written to the res.Body, a pipe)
|
||||
rw, resc := newPopulateResponseWriter()
|
||||
go func() {
|
||||
t.fh.ServeHTTP(rw, req)
|
||||
rw.finish()
|
||||
}()
|
||||
return <-resc, nil
|
||||
}
|
||||
|
||||
func newPopulateResponseWriter() (*populateResponse, <-chan *Response) {
|
||||
pr, pw := io.Pipe()
|
||||
rw := &populateResponse{
|
||||
ch: make(chan *Response),
|
||||
pw: pw,
|
||||
res: &Response{
|
||||
Proto: "HTTP/1.0",
|
||||
ProtoMajor: 1,
|
||||
Header: make(Header),
|
||||
Close: true,
|
||||
Body: pr,
|
||||
},
|
||||
}
|
||||
return rw, rw.ch
|
||||
}
|
||||
|
||||
// populateResponse is a ResponseWriter that populates the *Response
|
||||
// in res, and writes its body to a pipe connected to the response
|
||||
// body. Once writes begin or finish() is called, the response is sent
|
||||
// on ch.
|
||||
type populateResponse struct {
|
||||
res *Response
|
||||
ch chan *Response
|
||||
wroteHeader bool
|
||||
hasContent bool
|
||||
sentResponse bool
|
||||
pw *io.PipeWriter
|
||||
}
|
||||
|
||||
func (pr *populateResponse) finish() {
|
||||
if !pr.wroteHeader {
|
||||
pr.WriteHeader(500)
|
||||
}
|
||||
if !pr.sentResponse {
|
||||
pr.sendResponse()
|
||||
}
|
||||
pr.pw.Close()
|
||||
}
|
||||
|
||||
func (pr *populateResponse) sendResponse() {
|
||||
if pr.sentResponse {
|
||||
return
|
||||
}
|
||||
pr.sentResponse = true
|
||||
|
||||
if pr.hasContent {
|
||||
pr.res.ContentLength = -1
|
||||
}
|
||||
pr.ch <- pr.res
|
||||
}
|
||||
|
||||
func (pr *populateResponse) Header() Header {
|
||||
return pr.res.Header
|
||||
}
|
||||
|
||||
func (pr *populateResponse) WriteHeader(code int) {
|
||||
if pr.wroteHeader {
|
||||
return
|
||||
}
|
||||
pr.wroteHeader = true
|
||||
|
||||
pr.res.StatusCode = code
|
||||
pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code))
|
||||
}
|
||||
|
||||
func (pr *populateResponse) Write(p []byte) (n int, err error) {
|
||||
if !pr.wroteHeader {
|
||||
pr.WriteHeader(StatusOK)
|
||||
}
|
||||
pr.hasContent = true
|
||||
if !pr.sentResponse {
|
||||
pr.sendResponse()
|
||||
}
|
||||
return pr.pw.Write(p)
|
||||
}
|
|
@ -0,0 +1,549 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP file system request handler
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Dir implements http.FileSystem using the native file
|
||||
// system restricted to a specific directory tree.
|
||||
//
|
||||
// An empty Dir is treated as ".".
|
||||
type Dir string
|
||||
|
||||
func (d Dir) Open(name string) (File, error) {
|
||||
if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 ||
|
||||
strings.Contains(name, "\x00") {
|
||||
return nil, errors.New("http: invalid character in file path")
|
||||
}
|
||||
dir := string(d)
|
||||
if dir == "" {
|
||||
dir = "."
|
||||
}
|
||||
f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// A FileSystem implements access to a collection of named files.
|
||||
// The elements in a file path are separated by slash ('/', U+002F)
|
||||
// characters, regardless of host operating system convention.
|
||||
type FileSystem interface {
|
||||
Open(name string) (File, error)
|
||||
}
|
||||
|
||||
// A File is returned by a FileSystem's Open method and can be
|
||||
// served by the FileServer implementation.
|
||||
//
|
||||
// The methods should behave the same as those on an *os.File.
|
||||
type File interface {
|
||||
io.Closer
|
||||
io.Reader
|
||||
Readdir(count int) ([]os.FileInfo, error)
|
||||
Seek(offset int64, whence int) (int64, error)
|
||||
Stat() (os.FileInfo, error)
|
||||
}
|
||||
|
||||
func dirList(w ResponseWriter, f File) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprintf(w, "<pre>\n")
|
||||
for {
|
||||
dirs, err := f.Readdir(100)
|
||||
if err != nil || len(dirs) == 0 {
|
||||
break
|
||||
}
|
||||
for _, d := range dirs {
|
||||
name := d.Name()
|
||||
if d.IsDir() {
|
||||
name += "/"
|
||||
}
|
||||
// name may contain '?' or '#', which must be escaped to remain
|
||||
// part of the URL path, and not indicate the start of a query
|
||||
// string or fragment.
|
||||
url := url.URL{Path: name}
|
||||
fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(w, "</pre>\n")
|
||||
}
|
||||
|
||||
// ServeContent replies to the request using the content in the
|
||||
// provided ReadSeeker. The main benefit of ServeContent over io.Copy
|
||||
// is that it handles Range requests properly, sets the MIME type, and
|
||||
// handles If-Modified-Since requests.
|
||||
//
|
||||
// If the response's Content-Type header is not set, ServeContent
|
||||
// first tries to deduce the type from name's file extension and,
|
||||
// if that fails, falls back to reading the first block of the content
|
||||
// and passing it to DetectContentType.
|
||||
// The name is otherwise unused; in particular it can be empty and is
|
||||
// never sent in the response.
|
||||
//
|
||||
// If modtime is not the zero time, ServeContent includes it in a
|
||||
// Last-Modified header in the response. If the request includes an
|
||||
// If-Modified-Since header, ServeContent uses modtime to decide
|
||||
// whether the content needs to be sent at all.
|
||||
//
|
||||
// The content's Seek method must work: ServeContent uses
|
||||
// a seek to the end of the content to determine its size.
|
||||
//
|
||||
// If the caller has set w's ETag header, ServeContent uses it to
|
||||
// handle requests using If-Range and If-None-Match.
|
||||
//
|
||||
// Note that *os.File implements the io.ReadSeeker interface.
|
||||
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
|
||||
sizeFunc := func() (int64, error) {
|
||||
size, err := content.Seek(0, os.SEEK_END)
|
||||
if err != nil {
|
||||
return 0, errSeeker
|
||||
}
|
||||
_, err = content.Seek(0, os.SEEK_SET)
|
||||
if err != nil {
|
||||
return 0, errSeeker
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
serveContent(w, req, name, modtime, sizeFunc, content)
|
||||
}
|
||||
|
||||
// errSeeker is returned by ServeContent's sizeFunc when the content
|
||||
// doesn't seek properly. The underlying Seeker's error text isn't
|
||||
// included in the sizeFunc reply so it's not sent over HTTP to end
|
||||
// users.
|
||||
var errSeeker = errors.New("seeker can't seek")
|
||||
|
||||
// if name is empty, filename is unknown. (used for mime type, before sniffing)
|
||||
// if modtime.IsZero(), modtime is unknown.
|
||||
// content must be seeked to the beginning of the file.
|
||||
// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response.
|
||||
func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) {
|
||||
if checkLastModified(w, r, modtime) {
|
||||
return
|
||||
}
|
||||
rangeReq, done := checkETag(w, r)
|
||||
if done {
|
||||
return
|
||||
}
|
||||
|
||||
code := StatusOK
|
||||
|
||||
// If Content-Type isn't set, use the file's extension to find it, but
|
||||
// if the Content-Type is unset explicitly, do not sniff the type.
|
||||
ctypes, haveType := w.Header()["Content-Type"]
|
||||
var ctype string
|
||||
if !haveType {
|
||||
ctype = mime.TypeByExtension(filepath.Ext(name))
|
||||
if ctype == "" {
|
||||
// read a chunk to decide between utf-8 text and binary
|
||||
var buf [sniffLen]byte
|
||||
n, _ := io.ReadFull(content, buf[:])
|
||||
ctype = DetectContentType(buf[:n])
|
||||
_, err := content.Seek(0, os.SEEK_SET) // rewind to output whole file
|
||||
if err != nil {
|
||||
Error(w, "seeker can't seek", StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", ctype)
|
||||
} else if len(ctypes) > 0 {
|
||||
ctype = ctypes[0]
|
||||
}
|
||||
|
||||
size, err := sizeFunc()
|
||||
if err != nil {
|
||||
Error(w, err.Error(), StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// handle Content-Range header.
|
||||
sendSize := size
|
||||
var sendContent io.Reader = content
|
||||
if size >= 0 {
|
||||
ranges, err := parseRange(rangeReq, size)
|
||||
if err != nil {
|
||||
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
if sumRangesSize(ranges) > size {
|
||||
// The total number of bytes in all the ranges
|
||||
// is larger than the size of the file by
|
||||
// itself, so this is probably an attack, or a
|
||||
// dumb client. Ignore the range request.
|
||||
ranges = nil
|
||||
}
|
||||
switch {
|
||||
case len(ranges) == 1:
|
||||
// RFC 2616, Section 14.16:
|
||||
// "When an HTTP message includes the content of a single
|
||||
// range (for example, a response to a request for a
|
||||
// single range, or to a request for a set of ranges
|
||||
// that overlap without any holes), this content is
|
||||
// transmitted with a Content-Range header, and a
|
||||
// Content-Length header showing the number of bytes
|
||||
// actually transferred.
|
||||
// ...
|
||||
// A response to a request for a single range MUST NOT
|
||||
// be sent using the multipart/byteranges media type."
|
||||
ra := ranges[0]
|
||||
if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
|
||||
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
sendSize = ra.length
|
||||
code = StatusPartialContent
|
||||
w.Header().Set("Content-Range", ra.contentRange(size))
|
||||
case len(ranges) > 1:
|
||||
for _, ra := range ranges {
|
||||
if ra.start > size {
|
||||
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
}
|
||||
sendSize = rangesMIMESize(ranges, ctype, size)
|
||||
code = StatusPartialContent
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
mw := multipart.NewWriter(pw)
|
||||
w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
|
||||
sendContent = pr
|
||||
defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
|
||||
go func() {
|
||||
for _, ra := range ranges {
|
||||
part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
|
||||
if err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
if _, err := io.CopyN(part, content, ra.length); err != nil {
|
||||
pw.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
mw.Close()
|
||||
pw.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
if w.Header().Get("Content-Encoding") == "" {
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10))
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(code)
|
||||
|
||||
if r.Method != "HEAD" {
|
||||
io.CopyN(w, sendContent, sendSize)
|
||||
}
|
||||
}
|
||||
|
||||
// modtime is the modification time of the resource to be served, or IsZero().
|
||||
// return value is whether this request is now complete.
|
||||
func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
|
||||
if modtime.IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
// The Date-Modified header truncates sub-second precision, so
|
||||
// use mtime < t+1s instead of mtime <= t to check for unmodified.
|
||||
if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) {
|
||||
h := w.Header()
|
||||
delete(h, "Content-Type")
|
||||
delete(h, "Content-Length")
|
||||
w.WriteHeader(StatusNotModified)
|
||||
return true
|
||||
}
|
||||
w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat))
|
||||
return false
|
||||
}
|
||||
|
||||
// checkETag implements If-None-Match and If-Range checks.
|
||||
// The ETag must have been previously set in the ResponseWriter's headers.
|
||||
//
|
||||
// The return value is the effective request "Range" header to use and
|
||||
// whether this request is now considered done.
|
||||
func checkETag(w ResponseWriter, r *Request) (rangeReq string, done bool) {
|
||||
etag := w.Header().get("Etag")
|
||||
rangeReq = r.Header.get("Range")
|
||||
|
||||
// Invalidate the range request if the entity doesn't match the one
|
||||
// the client was expecting.
|
||||
// "If-Range: version" means "ignore the Range: header unless version matches the
|
||||
// current file."
|
||||
// We only support ETag versions.
|
||||
// The caller must have set the ETag on the response already.
|
||||
if ir := r.Header.get("If-Range"); ir != "" && ir != etag {
|
||||
// TODO(bradfitz): handle If-Range requests with Last-Modified
|
||||
// times instead of ETags? I'd rather not, at least for
|
||||
// now. That seems like a bug/compromise in the RFC 2616, and
|
||||
// I've never heard of anybody caring about that (yet).
|
||||
rangeReq = ""
|
||||
}
|
||||
|
||||
if inm := r.Header.get("If-None-Match"); inm != "" {
|
||||
// Must know ETag.
|
||||
if etag == "" {
|
||||
return rangeReq, false
|
||||
}
|
||||
|
||||
// TODO(bradfitz): non-GET/HEAD requests require more work:
|
||||
// sending a different status code on matches, and
|
||||
// also can't use weak cache validators (those with a "W/
|
||||
// prefix). But most users of ServeContent will be using
|
||||
// it on GET or HEAD, so only support those for now.
|
||||
if r.Method != "GET" && r.Method != "HEAD" {
|
||||
return rangeReq, false
|
||||
}
|
||||
|
||||
// TODO(bradfitz): deal with comma-separated or multiple-valued
|
||||
// list of If-None-match values. For now just handle the common
|
||||
// case of a single item.
|
||||
if inm == etag || inm == "*" {
|
||||
h := w.Header()
|
||||
delete(h, "Content-Type")
|
||||
delete(h, "Content-Length")
|
||||
w.WriteHeader(StatusNotModified)
|
||||
return "", true
|
||||
}
|
||||
}
|
||||
return rangeReq, false
|
||||
}
|
||||
|
||||
// name is '/'-separated, not filepath.Separator.
|
||||
func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
|
||||
const indexPage = "/index.html"
|
||||
|
||||
// redirect .../index.html to .../
|
||||
// can't use Redirect() because that would make the path absolute,
|
||||
// which would be a problem running under StripPrefix
|
||||
if strings.HasSuffix(r.URL.Path, indexPage) {
|
||||
localRedirect(w, r, "./")
|
||||
return
|
||||
}
|
||||
|
||||
f, err := fs.Open(name)
|
||||
if err != nil {
|
||||
// TODO expose actual error?
|
||||
NotFound(w, r)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
d, err1 := f.Stat()
|
||||
if err1 != nil {
|
||||
// TODO expose actual error?
|
||||
NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if redirect {
|
||||
// redirect to canonical path: / at end of directory url
|
||||
// r.URL.Path always begins with /
|
||||
url := r.URL.Path
|
||||
if d.IsDir() {
|
||||
if url[len(url)-1] != '/' {
|
||||
localRedirect(w, r, path.Base(url)+"/")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if url[len(url)-1] == '/' {
|
||||
localRedirect(w, r, "../"+path.Base(url))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// use contents of index.html for directory, if present
|
||||
if d.IsDir() {
|
||||
index := name + indexPage
|
||||
ff, err := fs.Open(index)
|
||||
if err == nil {
|
||||
defer ff.Close()
|
||||
dd, err := ff.Stat()
|
||||
if err == nil {
|
||||
name = index
|
||||
d = dd
|
||||
f = ff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Still a directory? (we didn't find an index.html file)
|
||||
if d.IsDir() {
|
||||
if checkLastModified(w, r, d.ModTime()) {
|
||||
return
|
||||
}
|
||||
dirList(w, f)
|
||||
return
|
||||
}
|
||||
|
||||
// serverContent will check modification time
|
||||
sizeFunc := func() (int64, error) { return d.Size(), nil }
|
||||
serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f)
|
||||
}
|
||||
|
||||
// localRedirect gives a Moved Permanently response.
|
||||
// It does not convert relative paths to absolute paths like Redirect does.
|
||||
func localRedirect(w ResponseWriter, r *Request, newPath string) {
|
||||
if q := r.URL.RawQuery; q != "" {
|
||||
newPath += "?" + q
|
||||
}
|
||||
w.Header().Set("Location", newPath)
|
||||
w.WriteHeader(StatusMovedPermanently)
|
||||
}
|
||||
|
||||
// ServeFile replies to the request with the contents of the named file or directory.
|
||||
func ServeFile(w ResponseWriter, r *Request, name string) {
|
||||
dir, file := filepath.Split(name)
|
||||
serveFile(w, r, Dir(dir), file, false)
|
||||
}
|
||||
|
||||
type fileHandler struct {
|
||||
root FileSystem
|
||||
}
|
||||
|
||||
// FileServer returns a handler that serves HTTP requests
|
||||
// with the contents of the file system rooted at root.
|
||||
//
|
||||
// To use the operating system's file system implementation,
|
||||
// use http.Dir:
|
||||
//
|
||||
// http.Handle("/", http.FileServer(http.Dir("/tmp")))
|
||||
func FileServer(root FileSystem) Handler {
|
||||
return &fileHandler{root}
|
||||
}
|
||||
|
||||
func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
|
||||
upath := r.URL.Path
|
||||
if !strings.HasPrefix(upath, "/") {
|
||||
upath = "/" + upath
|
||||
r.URL.Path = upath
|
||||
}
|
||||
serveFile(w, r, f.root, path.Clean(upath), true)
|
||||
}
|
||||
|
||||
// httpRange specifies the byte range to be sent to the client.
|
||||
type httpRange struct {
|
||||
start, length int64
|
||||
}
|
||||
|
||||
func (r httpRange) contentRange(size int64) string {
|
||||
return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
|
||||
}
|
||||
|
||||
func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
|
||||
return textproto.MIMEHeader{
|
||||
"Content-Range": {r.contentRange(size)},
|
||||
"Content-Type": {contentType},
|
||||
}
|
||||
}
|
||||
|
||||
// parseRange parses a Range header string as per RFC 2616.
|
||||
func parseRange(s string, size int64) ([]httpRange, error) {
|
||||
if s == "" {
|
||||
return nil, nil // header not present
|
||||
}
|
||||
const b = "bytes="
|
||||
if !strings.HasPrefix(s, b) {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
var ranges []httpRange
|
||||
for _, ra := range strings.Split(s[len(b):], ",") {
|
||||
ra = strings.TrimSpace(ra)
|
||||
if ra == "" {
|
||||
continue
|
||||
}
|
||||
i := strings.Index(ra, "-")
|
||||
if i < 0 {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:])
|
||||
var r httpRange
|
||||
if start == "" {
|
||||
// If no start is specified, end specifies the
|
||||
// range start relative to the end of the file.
|
||||
i, err := strconv.ParseInt(end, 10, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
if i > size {
|
||||
i = size
|
||||
}
|
||||
r.start = size - i
|
||||
r.length = size - r.start
|
||||
} else {
|
||||
i, err := strconv.ParseInt(start, 10, 64)
|
||||
if err != nil || i > size || i < 0 {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
r.start = i
|
||||
if end == "" {
|
||||
// If no end is specified, range extends to end of the file.
|
||||
r.length = size - r.start
|
||||
} else {
|
||||
i, err := strconv.ParseInt(end, 10, 64)
|
||||
if err != nil || r.start > i {
|
||||
return nil, errors.New("invalid range")
|
||||
}
|
||||
if i >= size {
|
||||
i = size - 1
|
||||
}
|
||||
r.length = i - r.start + 1
|
||||
}
|
||||
}
|
||||
ranges = append(ranges, r)
|
||||
}
|
||||
return ranges, nil
|
||||
}
|
||||
|
||||
// countingWriter counts how many bytes have been written to it.
|
||||
type countingWriter int64
|
||||
|
||||
func (w *countingWriter) Write(p []byte) (n int, err error) {
|
||||
*w += countingWriter(len(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// rangesMIMESize returns the number of bytes it takes to encode the
|
||||
// provided ranges as a multipart response.
|
||||
func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
|
||||
var w countingWriter
|
||||
mw := multipart.NewWriter(&w)
|
||||
for _, ra := range ranges {
|
||||
mw.CreatePart(ra.mimeHeader(contentType, contentSize))
|
||||
encSize += ra.length
|
||||
}
|
||||
mw.Close()
|
||||
encSize += int64(w)
|
||||
return
|
||||
}
|
||||
|
||||
func sumRangesSize(ranges []httpRange) (size int64) {
|
||||
for _, ra := range ranges {
|
||||
size += ra.length
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/textproto"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var raceEnabled = false // set by race.go
|
||||
|
||||
// A Header represents the key-value pairs in an HTTP header.
|
||||
type Header map[string][]string
|
||||
|
||||
// Add adds the key, value pair to the header.
|
||||
// It appends to any existing values associated with key.
|
||||
func (h Header) Add(key, value string) {
|
||||
textproto.MIMEHeader(h).Add(key, value)
|
||||
}
|
||||
|
||||
// Set sets the header entries associated with key to
|
||||
// the single element value. It replaces any existing
|
||||
// values associated with key.
|
||||
func (h Header) Set(key, value string) {
|
||||
textproto.MIMEHeader(h).Set(key, value)
|
||||
}
|
||||
|
||||
// Get gets the first value associated with the given key.
|
||||
// If there are no values associated with the key, Get returns "".
|
||||
// To access multiple values of a key, access the map directly
|
||||
// with CanonicalHeaderKey.
|
||||
func (h Header) Get(key string) string {
|
||||
return textproto.MIMEHeader(h).Get(key)
|
||||
}
|
||||
|
||||
// get is like Get, but key must already be in CanonicalHeaderKey form.
|
||||
func (h Header) get(key string) string {
|
||||
if v := h[key]; len(v) > 0 {
|
||||
return v[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Del deletes the values associated with key.
|
||||
func (h Header) Del(key string) {
|
||||
textproto.MIMEHeader(h).Del(key)
|
||||
}
|
||||
|
||||
// Write writes a header in wire format.
|
||||
func (h Header) Write(w io.Writer) error {
|
||||
return h.WriteSubset(w, nil)
|
||||
}
|
||||
|
||||
func (h Header) clone() Header {
|
||||
h2 := make(Header, len(h))
|
||||
for k, vv := range h {
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
h2[k] = vv2
|
||||
}
|
||||
return h2
|
||||
}
|
||||
|
||||
var timeFormats = []string{
|
||||
TimeFormat,
|
||||
time.RFC850,
|
||||
time.ANSIC,
|
||||
}
|
||||
|
||||
// ParseTime parses a time header (such as the Date: header),
|
||||
// trying each of the three formats allowed by HTTP/1.1:
|
||||
// TimeFormat, time.RFC850, and time.ANSIC.
|
||||
func ParseTime(text string) (t time.Time, err error) {
|
||||
for _, layout := range timeFormats {
|
||||
t, err = time.Parse(layout, text)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
|
||||
|
||||
type writeStringer interface {
|
||||
WriteString(string) (int, error)
|
||||
}
|
||||
|
||||
// stringWriter implements WriteString on a Writer.
|
||||
type stringWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (w stringWriter) WriteString(s string) (n int, err error) {
|
||||
return w.w.Write([]byte(s))
|
||||
}
|
||||
|
||||
type keyValues struct {
|
||||
key string
|
||||
values []string
|
||||
}
|
||||
|
||||
// A headerSorter implements sort.Interface by sorting a []keyValues
|
||||
// by key. It's used as a pointer, so it can fit in a sort.Interface
|
||||
// interface value without allocation.
|
||||
type headerSorter struct {
|
||||
kvs []keyValues
|
||||
}
|
||||
|
||||
func (s *headerSorter) Len() int { return len(s.kvs) }
|
||||
func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
|
||||
func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
|
||||
|
||||
var headerSorterPool = sync.Pool{
|
||||
New: func() interface{} { return new(headerSorter) },
|
||||
}
|
||||
|
||||
// sortedKeyValues returns h's keys sorted in the returned kvs
|
||||
// slice. The headerSorter used to sort is also returned, for possible
|
||||
// return to headerSorterCache.
|
||||
func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
|
||||
hs = headerSorterPool.Get().(*headerSorter)
|
||||
if cap(hs.kvs) < len(h) {
|
||||
hs.kvs = make([]keyValues, 0, len(h))
|
||||
}
|
||||
kvs = hs.kvs[:0]
|
||||
for k, vv := range h {
|
||||
if !exclude[k] {
|
||||
kvs = append(kvs, keyValues{k, vv})
|
||||
}
|
||||
}
|
||||
hs.kvs = kvs
|
||||
sort.Sort(hs)
|
||||
return kvs, hs
|
||||
}
|
||||
|
||||
// WriteSubset writes a header in wire format.
|
||||
// If exclude is not nil, keys where exclude[key] == true are not written.
|
||||
func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
|
||||
ws, ok := w.(writeStringer)
|
||||
if !ok {
|
||||
ws = stringWriter{w}
|
||||
}
|
||||
kvs, sorter := h.sortedKeyValues(exclude)
|
||||
for _, kv := range kvs {
|
||||
for _, v := range kv.values {
|
||||
v = headerNewlineToSpace.Replace(v)
|
||||
v = textproto.TrimString(v)
|
||||
for _, s := range []string{kv.key, ": ", v, "\r\n"} {
|
||||
if _, err := ws.WriteString(s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
headerSorterPool.Put(sorter)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanonicalHeaderKey returns the canonical format of the
|
||||
// header key s. The canonicalization converts the first
|
||||
// letter and any letter following a hyphen to upper case;
|
||||
// the rest are converted to lowercase. For example, the
|
||||
// canonical key for "accept-encoding" is "Accept-Encoding".
|
||||
func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
|
||||
|
||||
// hasToken reports whether token appears with v, ASCII
|
||||
// case-insensitive, with space or comma boundaries.
|
||||
// token must be all lowercase.
|
||||
// v may contain mixed cased.
|
||||
func hasToken(v, token string) bool {
|
||||
if len(token) > len(v) || token == "" {
|
||||
return false
|
||||
}
|
||||
if v == token {
|
||||
return true
|
||||
}
|
||||
for sp := 0; sp <= len(v)-len(token); sp++ {
|
||||
// Check that first character is good.
|
||||
// The token is ASCII, so checking only a single byte
|
||||
// is sufficient. We skip this potential starting
|
||||
// position if both the first byte and its potential
|
||||
// ASCII uppercase equivalent (b|0x20) don't match.
|
||||
// False positives ('^' => '~') are caught by EqualFold.
|
||||
if b := v[sp]; b != token[0] && b|0x20 != token[0] {
|
||||
continue
|
||||
}
|
||||
// Check that start pos is on a valid token boundary.
|
||||
if sp > 0 && !isTokenBoundary(v[sp-1]) {
|
||||
continue
|
||||
}
|
||||
// Check that end pos is on a valid token boundary.
|
||||
if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(v[sp:sp+len(token)], token) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isTokenBoundary(b byte) bool {
|
||||
return b == ' ' || b == ',' || b == '\t'
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// A CookieJar manages storage and use of cookies in HTTP requests.
|
||||
//
|
||||
// Implementations of CookieJar must be safe for concurrent use by multiple
|
||||
// goroutines.
|
||||
//
|
||||
// The net/http/cookiejar package provides a CookieJar implementation.
|
||||
type CookieJar interface {
|
||||
// SetCookies handles the receipt of the cookies in a reply for the
|
||||
// given URL. It may or may not choose to save the cookies, depending
|
||||
// on the jar's policy and implementation.
|
||||
SetCookies(u *url.URL, cookies []*Cookie)
|
||||
|
||||
// Cookies returns the cookies to send in a request for the given URL.
|
||||
// It is up to the implementation to honor the standard cookie use
|
||||
// restrictions such as in RFC 6265.
|
||||
Cookies(u *url.URL) []*Cookie
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
// This file deals with lexical matters of HTTP
|
||||
|
||||
var isTokenTable = [127]bool{
|
||||
'!': true,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'W': true,
|
||||
'V': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'|': true,
|
||||
'~': true,
|
||||
}
|
||||
|
||||
func isToken(r rune) bool {
|
||||
i := int(r)
|
||||
return i < len(isTokenTable) && isTokenTable[i]
|
||||
}
|
||||
|
||||
func isNotToken(r rune) bool {
|
||||
return !isToken(r)
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build race
|
||||
|
||||
package http
|
||||
|
||||
func init() {
|
||||
raceEnabled = true
|
||||
}
|
875
vendor/github.com/masterzen/azure-sdk-for-go/core/http/request.go
generated
vendored
Normal file
875
vendor/github.com/masterzen/azure-sdk-for-go/core/http/request.go
generated
vendored
Normal file
|
@ -0,0 +1,875 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP Request reading and parsing.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
maxValueLength = 4096
|
||||
maxHeaderLines = 1024
|
||||
chunkSize = 4 << 10 // 4 KB chunks
|
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
)
|
||||
|
||||
// ErrMissingFile is returned by FormFile when the provided file field name
|
||||
// is either not present in the request or not a file field.
|
||||
var ErrMissingFile = errors.New("http: no such file")
|
||||
|
||||
// HTTP request parsing errors.
|
||||
type ProtocolError struct {
|
||||
ErrorString string
|
||||
}
|
||||
|
||||
func (err *ProtocolError) Error() string { return err.ErrorString }
|
||||
|
||||
var (
|
||||
ErrHeaderTooLong = &ProtocolError{"header too long"}
|
||||
ErrShortBody = &ProtocolError{"entity body too short"}
|
||||
ErrNotSupported = &ProtocolError{"feature not supported"}
|
||||
ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
|
||||
ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
|
||||
ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
|
||||
ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"}
|
||||
)
|
||||
|
||||
type badStringError struct {
|
||||
what string
|
||||
str string
|
||||
}
|
||||
|
||||
func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
|
||||
|
||||
// Headers that Request.Write handles itself and should be skipped.
|
||||
var reqWriteExcludeHeader = map[string]bool{
|
||||
"Host": true, // not in Header map anyway
|
||||
"User-Agent": true,
|
||||
"Content-Length": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Trailer": true,
|
||||
}
|
||||
|
||||
// A Request represents an HTTP request received by a server
|
||||
// or to be sent by a client.
|
||||
//
|
||||
// The field semantics differ slightly between client and server
|
||||
// usage. In addition to the notes on the fields below, see the
|
||||
// documentation for Request.Write and RoundTripper.
|
||||
type Request struct {
|
||||
// Method specifies the HTTP method (GET, POST, PUT, etc.).
|
||||
// For client requests an empty string means GET.
|
||||
Method string
|
||||
|
||||
// URL specifies either the URI being requested (for server
|
||||
// requests) or the URL to access (for client requests).
|
||||
//
|
||||
// For server requests the URL is parsed from the URI
|
||||
// supplied on the Request-Line as stored in RequestURI. For
|
||||
// most requests, fields other than Path and RawQuery will be
|
||||
// empty. (See RFC 2616, Section 5.1.2)
|
||||
//
|
||||
// For client requests, the URL's Host specifies the server to
|
||||
// connect to, while the Request's Host field optionally
|
||||
// specifies the Host header value to send in the HTTP
|
||||
// request.
|
||||
URL *url.URL
|
||||
|
||||
// The protocol version for incoming requests.
|
||||
// Client requests always use HTTP/1.1.
|
||||
Proto string // "HTTP/1.0"
|
||||
ProtoMajor int // 1
|
||||
ProtoMinor int // 0
|
||||
|
||||
// A header maps request lines to their values.
|
||||
// If the header says
|
||||
//
|
||||
// accept-encoding: gzip, deflate
|
||||
// Accept-Language: en-us
|
||||
// Connection: keep-alive
|
||||
//
|
||||
// then
|
||||
//
|
||||
// Header = map[string][]string{
|
||||
// "Accept-Encoding": {"gzip, deflate"},
|
||||
// "Accept-Language": {"en-us"},
|
||||
// "Connection": {"keep-alive"},
|
||||
// }
|
||||
//
|
||||
// HTTP defines that header names are case-insensitive.
|
||||
// The request parser implements this by canonicalizing the
|
||||
// name, making the first character and any characters
|
||||
// following a hyphen uppercase and the rest lowercase.
|
||||
//
|
||||
// For client requests certain headers are automatically
|
||||
// added and may override values in Header.
|
||||
//
|
||||
// See the documentation for the Request.Write method.
|
||||
Header Header
|
||||
|
||||
// Body is the request's body.
|
||||
//
|
||||
// For client requests a nil body means the request has no
|
||||
// body, such as a GET request. The HTTP Client's Transport
|
||||
// is responsible for calling the Close method.
|
||||
//
|
||||
// For server requests the Request Body is always non-nil
|
||||
// but will return EOF immediately when no body is present.
|
||||
// The Server will close the request body. The ServeHTTP
|
||||
// Handler does not need to.
|
||||
Body io.ReadCloser
|
||||
|
||||
// ContentLength records the length of the associated content.
|
||||
// The value -1 indicates that the length is unknown.
|
||||
// Values >= 0 indicate that the given number of bytes may
|
||||
// be read from Body.
|
||||
// For client requests, a value of 0 means unknown if Body is not nil.
|
||||
ContentLength int64
|
||||
|
||||
// TransferEncoding lists the transfer encodings from outermost to
|
||||
// innermost. An empty list denotes the "identity" encoding.
|
||||
// TransferEncoding can usually be ignored; chunked encoding is
|
||||
// automatically added and removed as necessary when sending and
|
||||
// receiving requests.
|
||||
TransferEncoding []string
|
||||
|
||||
// Close indicates whether to close the connection after
|
||||
// replying to this request (for servers) or after sending
|
||||
// the request (for clients).
|
||||
Close bool
|
||||
|
||||
// For server requests Host specifies the host on which the
|
||||
// URL is sought. Per RFC 2616, this is either the value of
|
||||
// the "Host" header or the host name given in the URL itself.
|
||||
// It may be of the form "host:port".
|
||||
//
|
||||
// For client requests Host optionally overrides the Host
|
||||
// header to send. If empty, the Request.Write method uses
|
||||
// the value of URL.Host.
|
||||
Host string
|
||||
|
||||
// Form contains the parsed form data, including both the URL
|
||||
// field's query parameters and the POST or PUT form data.
|
||||
// This field is only available after ParseForm is called.
|
||||
// The HTTP client ignores Form and uses Body instead.
|
||||
Form url.Values
|
||||
|
||||
// PostForm contains the parsed form data from POST or PUT
|
||||
// body parameters.
|
||||
// This field is only available after ParseForm is called.
|
||||
// The HTTP client ignores PostForm and uses Body instead.
|
||||
PostForm url.Values
|
||||
|
||||
// MultipartForm is the parsed multipart form, including file uploads.
|
||||
// This field is only available after ParseMultipartForm is called.
|
||||
// The HTTP client ignores MultipartForm and uses Body instead.
|
||||
MultipartForm *multipart.Form
|
||||
|
||||
// Trailer specifies additional headers that are sent after the request
|
||||
// body.
|
||||
//
|
||||
// For server requests the Trailer map initially contains only the
|
||||
// trailer keys, with nil values. (The client declares which trailers it
|
||||
// will later send.) While the handler is reading from Body, it must
|
||||
// not reference Trailer. After reading from Body returns EOF, Trailer
|
||||
// can be read again and will contain non-nil values, if they were sent
|
||||
// by the client.
|
||||
//
|
||||
// For client requests Trailer must be initialized to a map containing
|
||||
// the trailer keys to later send. The values may be nil or their final
|
||||
// values. The ContentLength must be 0 or -1, to send a chunked request.
|
||||
// After the HTTP request is sent the map values can be updated while
|
||||
// the request body is read. Once the body returns EOF, the caller must
|
||||
// not mutate Trailer.
|
||||
//
|
||||
// Few HTTP clients, servers, or proxies support HTTP trailers.
|
||||
Trailer Header
|
||||
|
||||
// RemoteAddr allows HTTP servers and other software to record
|
||||
// the network address that sent the request, usually for
|
||||
// logging. This field is not filled in by ReadRequest and
|
||||
// has no defined format. The HTTP server in this package
|
||||
// sets RemoteAddr to an "IP:port" address before invoking a
|
||||
// handler.
|
||||
// This field is ignored by the HTTP client.
|
||||
RemoteAddr string
|
||||
|
||||
// RequestURI is the unmodified Request-URI of the
|
||||
// Request-Line (RFC 2616, Section 5.1) as sent by the client
|
||||
// to a server. Usually the URL field should be used instead.
|
||||
// It is an error to set this field in an HTTP client request.
|
||||
RequestURI string
|
||||
|
||||
// TLS allows HTTP servers and other software to record
|
||||
// information about the TLS connection on which the request
|
||||
// was received. This field is not filled in by ReadRequest.
|
||||
// The HTTP server in this package sets the field for
|
||||
// TLS-enabled connections before invoking a handler;
|
||||
// otherwise it leaves the field nil.
|
||||
// This field is ignored by the HTTP client.
|
||||
TLS *tls.ConnectionState
|
||||
}
|
||||
|
||||
// ProtoAtLeast reports whether the HTTP protocol used
|
||||
// in the request is at least major.minor.
|
||||
func (r *Request) ProtoAtLeast(major, minor int) bool {
|
||||
return r.ProtoMajor > major ||
|
||||
r.ProtoMajor == major && r.ProtoMinor >= minor
|
||||
}
|
||||
|
||||
// UserAgent returns the client's User-Agent, if sent in the request.
|
||||
func (r *Request) UserAgent() string {
|
||||
return r.Header.Get("User-Agent")
|
||||
}
|
||||
|
||||
// Cookies parses and returns the HTTP cookies sent with the request.
|
||||
func (r *Request) Cookies() []*Cookie {
|
||||
return readCookies(r.Header, "")
|
||||
}
|
||||
|
||||
var ErrNoCookie = errors.New("http: named cookie not present")
|
||||
|
||||
// Cookie returns the named cookie provided in the request or
|
||||
// ErrNoCookie if not found.
|
||||
func (r *Request) Cookie(name string) (*Cookie, error) {
|
||||
for _, c := range readCookies(r.Header, name) {
|
||||
return c, nil
|
||||
}
|
||||
return nil, ErrNoCookie
|
||||
}
|
||||
|
||||
// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4,
|
||||
// AddCookie does not attach more than one Cookie header field. That
|
||||
// means all cookies, if any, are written into the same line,
|
||||
// separated by semicolon.
|
||||
func (r *Request) AddCookie(c *Cookie) {
|
||||
s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
|
||||
if c := r.Header.Get("Cookie"); c != "" {
|
||||
r.Header.Set("Cookie", c+"; "+s)
|
||||
} else {
|
||||
r.Header.Set("Cookie", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Referer returns the referring URL, if sent in the request.
|
||||
//
|
||||
// Referer is misspelled as in the request itself, a mistake from the
|
||||
// earliest days of HTTP. This value can also be fetched from the
|
||||
// Header map as Header["Referer"]; the benefit of making it available
|
||||
// as a method is that the compiler can diagnose programs that use the
|
||||
// alternate (correct English) spelling req.Referrer() but cannot
|
||||
// diagnose programs that use Header["Referrer"].
|
||||
func (r *Request) Referer() string {
|
||||
return r.Header.Get("Referer")
|
||||
}
|
||||
|
||||
// multipartByReader is a sentinel value.
|
||||
// Its presence in Request.MultipartForm indicates that parsing of the request
|
||||
// body has been handed off to a MultipartReader instead of ParseMultipartFrom.
|
||||
var multipartByReader = &multipart.Form{
|
||||
Value: make(map[string][]string),
|
||||
File: make(map[string][]*multipart.FileHeader),
|
||||
}
|
||||
|
||||
// MultipartReader returns a MIME multipart reader if this is a
|
||||
// multipart/form-data POST request, else returns nil and an error.
|
||||
// Use this function instead of ParseMultipartForm to
|
||||
// process the request body as a stream.
|
||||
func (r *Request) MultipartReader() (*multipart.Reader, error) {
|
||||
if r.MultipartForm == multipartByReader {
|
||||
return nil, errors.New("http: MultipartReader called twice")
|
||||
}
|
||||
if r.MultipartForm != nil {
|
||||
return nil, errors.New("http: multipart handled by ParseMultipartForm")
|
||||
}
|
||||
r.MultipartForm = multipartByReader
|
||||
return r.multipartReader()
|
||||
}
|
||||
|
||||
func (r *Request) multipartReader() (*multipart.Reader, error) {
|
||||
v := r.Header.Get("Content-Type")
|
||||
if v == "" {
|
||||
return nil, ErrNotMultipart
|
||||
}
|
||||
d, params, err := mime.ParseMediaType(v)
|
||||
if err != nil || d != "multipart/form-data" {
|
||||
return nil, ErrNotMultipart
|
||||
}
|
||||
boundary, ok := params["boundary"]
|
||||
if !ok {
|
||||
return nil, ErrMissingBoundary
|
||||
}
|
||||
return multipart.NewReader(r.Body, boundary), nil
|
||||
}
|
||||
|
||||
// Return value if nonempty, def otherwise.
|
||||
func valueOrDefault(value, def string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// NOTE: This is not intended to reflect the actual Go version being used.
|
||||
// It was changed from "Go http package" to "Go 1.1 package http" at the
|
||||
// time of the Go 1.1 release because the former User-Agent had ended up
|
||||
// on a blacklist for some intrusion detection systems.
|
||||
// See https://codereview.appspot.com/7532043.
|
||||
const defaultUserAgent = "Go 1.1 package http"
|
||||
|
||||
// Write writes an HTTP/1.1 request -- header and body -- in wire format.
|
||||
// This method consults the following fields of the request:
|
||||
// Host
|
||||
// URL
|
||||
// Method (defaults to "GET")
|
||||
// Header
|
||||
// ContentLength
|
||||
// TransferEncoding
|
||||
// Body
|
||||
//
|
||||
// If Body is present, Content-Length is <= 0 and TransferEncoding
|
||||
// hasn't been set to "identity", Write adds "Transfer-Encoding:
|
||||
// chunked" to the header. Body is closed after it is sent.
|
||||
func (r *Request) Write(w io.Writer) error {
|
||||
return r.write(w, false, nil)
|
||||
}
|
||||
|
||||
// WriteProxy is like Write but writes the request in the form
|
||||
// expected by an HTTP proxy. In particular, WriteProxy writes the
|
||||
// initial Request-URI line of the request with an absolute URI, per
|
||||
// section 5.1.2 of RFC 2616, including the scheme and host.
|
||||
// In either case, WriteProxy also writes a Host header, using
|
||||
// either r.Host or r.URL.Host.
|
||||
func (r *Request) WriteProxy(w io.Writer) error {
|
||||
return r.write(w, true, nil)
|
||||
}
|
||||
|
||||
// extraHeaders may be nil
|
||||
func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
if req.URL == nil {
|
||||
return errors.New("http: Request.Write on Request with no Host or URL set")
|
||||
}
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
ruri := req.URL.RequestURI()
|
||||
if usingProxy && req.URL.Scheme != "" && req.URL.Opaque == "" {
|
||||
ruri = req.URL.Scheme + "://" + host + ruri
|
||||
} else if req.Method == "CONNECT" && req.URL.Path == "" {
|
||||
// CONNECT requests normally give just the host and port, not a full URL.
|
||||
ruri = host
|
||||
}
|
||||
// TODO(bradfitz): escape at least newlines in ruri?
|
||||
|
||||
// Wrap the writer in a bufio Writer if it's not already buffered.
|
||||
// Don't always call NewWriter, as that forces a bytes.Buffer
|
||||
// and other small bufio Writers to have a minimum 4k buffer
|
||||
// size.
|
||||
var bw *bufio.Writer
|
||||
if _, ok := w.(io.ByteWriter); !ok {
|
||||
bw = bufio.NewWriter(w)
|
||||
w = bw
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
|
||||
|
||||
// Header lines
|
||||
fmt.Fprintf(w, "Host: %s\r\n", host)
|
||||
|
||||
// Use the defaultUserAgent unless the Header contains one, which
|
||||
// may be blank to not send the header.
|
||||
userAgent := defaultUserAgent
|
||||
if req.Header != nil {
|
||||
if ua := req.Header["User-Agent"]; len(ua) > 0 {
|
||||
userAgent = ua[0]
|
||||
}
|
||||
}
|
||||
if userAgent != "" {
|
||||
fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
|
||||
}
|
||||
|
||||
// Process Body,ContentLength,Close,Trailer
|
||||
tw, err := newTransferWriter(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tw.WriteHeader(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.Header.WriteSubset(w, reqWriteExcludeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if extraHeaders != nil {
|
||||
err = extraHeaders.Write(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
io.WriteString(w, "\r\n")
|
||||
|
||||
// Write body and trailer
|
||||
err = tw.WriteBody(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bw != nil {
|
||||
return bw.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseHTTPVersion parses a HTTP version string.
|
||||
// "HTTP/1.0" returns (1, 0, true).
|
||||
func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
|
||||
const Big = 1000000 // arbitrary upper bound
|
||||
switch vers {
|
||||
case "HTTP/1.1":
|
||||
return 1, 1, true
|
||||
case "HTTP/1.0":
|
||||
return 1, 0, true
|
||||
}
|
||||
if !strings.HasPrefix(vers, "HTTP/") {
|
||||
return 0, 0, false
|
||||
}
|
||||
dot := strings.Index(vers, ".")
|
||||
if dot < 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
major, err := strconv.Atoi(vers[5:dot])
|
||||
if err != nil || major < 0 || major > Big {
|
||||
return 0, 0, false
|
||||
}
|
||||
minor, err = strconv.Atoi(vers[dot+1:])
|
||||
if err != nil || minor < 0 || minor > Big {
|
||||
return 0, 0, false
|
||||
}
|
||||
return major, minor, true
|
||||
}
|
||||
|
||||
// NewRequest returns a new Request given a method, URL, and optional body.
|
||||
//
|
||||
// If the provided body is also an io.Closer, the returned
|
||||
// Request.Body is set to body and will be closed by the Client
|
||||
// methods Do, Post, and PostForm, and Transport.RoundTrip.
|
||||
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rc, ok := body.(io.ReadCloser)
|
||||
if !ok && body != nil {
|
||||
rc = ioutil.NopCloser(body)
|
||||
}
|
||||
req := &Request{
|
||||
Method: method,
|
||||
URL: u,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(Header),
|
||||
Body: rc,
|
||||
Host: u.Host,
|
||||
}
|
||||
if body != nil {
|
||||
switch v := body.(type) {
|
||||
case *bytes.Buffer:
|
||||
req.ContentLength = int64(v.Len())
|
||||
case *bytes.Reader:
|
||||
req.ContentLength = int64(v.Len())
|
||||
case *strings.Reader:
|
||||
req.ContentLength = int64(v.Len())
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// SetBasicAuth sets the request's Authorization header to use HTTP
|
||||
// Basic Authentication with the provided username and password.
|
||||
//
|
||||
// With HTTP Basic Authentication the provided username and password
|
||||
// are not encrypted.
|
||||
func (r *Request) SetBasicAuth(username, password string) {
|
||||
r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
|
||||
}
|
||||
|
||||
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
|
||||
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
|
||||
s1 := strings.Index(line, " ")
|
||||
s2 := strings.Index(line[s1+1:], " ")
|
||||
if s1 < 0 || s2 < 0 {
|
||||
return
|
||||
}
|
||||
s2 += s1 + 1
|
||||
return line[:s1], line[s1+1 : s2], line[s2+1:], true
|
||||
}
|
||||
|
||||
var textprotoReaderPool sync.Pool
|
||||
|
||||
func newTextprotoReader(br *bufio.Reader) *textproto.Reader {
|
||||
if v := textprotoReaderPool.Get(); v != nil {
|
||||
tr := v.(*textproto.Reader)
|
||||
tr.R = br
|
||||
return tr
|
||||
}
|
||||
return textproto.NewReader(br)
|
||||
}
|
||||
|
||||
func putTextprotoReader(r *textproto.Reader) {
|
||||
r.R = nil
|
||||
textprotoReaderPool.Put(r)
|
||||
}
|
||||
|
||||
// ReadRequest reads and parses a request from b.
|
||||
func ReadRequest(b *bufio.Reader) (req *Request, err error) {
|
||||
|
||||
tp := newTextprotoReader(b)
|
||||
req = new(Request)
|
||||
|
||||
// First line: GET /index.html HTTP/1.0
|
||||
var s string
|
||||
if s, err = tp.ReadLine(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
putTextprotoReader(tp)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}()
|
||||
|
||||
var ok bool
|
||||
req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s)
|
||||
if !ok {
|
||||
return nil, &badStringError{"malformed HTTP request", s}
|
||||
}
|
||||
rawurl := req.RequestURI
|
||||
if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok {
|
||||
return nil, &badStringError{"malformed HTTP version", req.Proto}
|
||||
}
|
||||
|
||||
// CONNECT requests are used two different ways, and neither uses a full URL:
|
||||
// The standard use is to tunnel HTTPS through an HTTP proxy.
|
||||
// It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is
|
||||
// just the authority section of a URL. This information should go in req.URL.Host.
|
||||
//
|
||||
// The net/rpc package also uses CONNECT, but there the parameter is a path
|
||||
// that starts with a slash. It can be parsed with the regular URL parser,
|
||||
// and the path will end up in req.URL.Path, where it needs to be in order for
|
||||
// RPC to work.
|
||||
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
|
||||
if justAuthority {
|
||||
rawurl = "http://" + rawurl
|
||||
}
|
||||
|
||||
if req.URL, err = url.ParseRequestURI(rawurl); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if justAuthority {
|
||||
// Strip the bogus "http://" back off.
|
||||
req.URL.Scheme = ""
|
||||
}
|
||||
|
||||
// Subsequent lines: Key: value.
|
||||
mimeHeader, err := tp.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header = Header(mimeHeader)
|
||||
|
||||
// RFC2616: Must treat
|
||||
// GET /index.html HTTP/1.1
|
||||
// Host: www.google.com
|
||||
// and
|
||||
// GET http://www.google.com/index.html HTTP/1.1
|
||||
// Host: doesntmatter
|
||||
// the same. In the second case, any Host line is ignored.
|
||||
req.Host = req.URL.Host
|
||||
if req.Host == "" {
|
||||
req.Host = req.Header.get("Host")
|
||||
}
|
||||
delete(req.Header, "Host")
|
||||
|
||||
fixPragmaCacheControl(req.Header)
|
||||
|
||||
err = readTransfer(req, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// MaxBytesReader is similar to io.LimitReader but is intended for
|
||||
// limiting the size of incoming request bodies. In contrast to
|
||||
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
|
||||
// non-EOF error for a Read beyond the limit, and Closes the
|
||||
// underlying reader when its Close method is called.
|
||||
//
|
||||
// MaxBytesReader prevents clients from accidentally or maliciously
|
||||
// sending a large request and wasting server resources.
|
||||
func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
|
||||
return &maxBytesReader{w: w, r: r, n: n}
|
||||
}
|
||||
|
||||
type maxBytesReader struct {
|
||||
w ResponseWriter
|
||||
r io.ReadCloser // underlying reader
|
||||
n int64 // max bytes remaining
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
|
||||
if l.n <= 0 {
|
||||
if !l.stopped {
|
||||
l.stopped = true
|
||||
if res, ok := l.w.(*response); ok {
|
||||
res.requestTooLarge()
|
||||
}
|
||||
}
|
||||
return 0, errors.New("http: request body too large")
|
||||
}
|
||||
if int64(len(p)) > l.n {
|
||||
p = p[:l.n]
|
||||
}
|
||||
n, err = l.r.Read(p)
|
||||
l.n -= int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *maxBytesReader) Close() error {
|
||||
return l.r.Close()
|
||||
}
|
||||
|
||||
func copyValues(dst, src url.Values) {
|
||||
for k, vs := range src {
|
||||
for _, value := range vs {
|
||||
dst.Add(k, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parsePostForm(r *Request) (vs url.Values, err error) {
|
||||
if r.Body == nil {
|
||||
err = errors.New("missing form body")
|
||||
return
|
||||
}
|
||||
ct := r.Header.Get("Content-Type")
|
||||
// RFC 2616, section 7.2.1 - empty type
|
||||
// SHOULD be treated as application/octet-stream
|
||||
if ct == "" {
|
||||
ct = "application/octet-stream"
|
||||
}
|
||||
ct, _, err = mime.ParseMediaType(ct)
|
||||
switch {
|
||||
case ct == "application/x-www-form-urlencoded":
|
||||
var reader io.Reader = r.Body
|
||||
maxFormSize := int64(1<<63 - 1)
|
||||
if _, ok := r.Body.(*maxBytesReader); !ok {
|
||||
maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
|
||||
reader = io.LimitReader(r.Body, maxFormSize+1)
|
||||
}
|
||||
b, e := ioutil.ReadAll(reader)
|
||||
if e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
}
|
||||
break
|
||||
}
|
||||
if int64(len(b)) > maxFormSize {
|
||||
err = errors.New("http: POST too large")
|
||||
return
|
||||
}
|
||||
vs, e = url.ParseQuery(string(b))
|
||||
if err == nil {
|
||||
err = e
|
||||
}
|
||||
case ct == "multipart/form-data":
|
||||
// handled by ParseMultipartForm (which is calling us, or should be)
|
||||
// TODO(bradfitz): there are too many possible
|
||||
// orders to call too many functions here.
|
||||
// Clean this up and write more tests.
|
||||
// request_test.go contains the start of this,
|
||||
// in TestParseMultipartFormOrder and others.
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ParseForm parses the raw query from the URL and updates r.Form.
|
||||
//
|
||||
// For POST or PUT requests, it also parses the request body as a form and
|
||||
// put the results into both r.PostForm and r.Form.
|
||||
// POST and PUT body parameters take precedence over URL query string values
|
||||
// in r.Form.
|
||||
//
|
||||
// If the request Body's size has not already been limited by MaxBytesReader,
|
||||
// the size is capped at 10MB.
|
||||
//
|
||||
// ParseMultipartForm calls ParseForm automatically.
|
||||
// It is idempotent.
|
||||
func (r *Request) ParseForm() error {
|
||||
var err error
|
||||
if r.PostForm == nil {
|
||||
if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
|
||||
r.PostForm, err = parsePostForm(r)
|
||||
}
|
||||
if r.PostForm == nil {
|
||||
r.PostForm = make(url.Values)
|
||||
}
|
||||
}
|
||||
if r.Form == nil {
|
||||
if len(r.PostForm) > 0 {
|
||||
r.Form = make(url.Values)
|
||||
copyValues(r.Form, r.PostForm)
|
||||
}
|
||||
var newValues url.Values
|
||||
if r.URL != nil {
|
||||
var e error
|
||||
newValues, e = url.ParseQuery(r.URL.RawQuery)
|
||||
if err == nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
if newValues == nil {
|
||||
newValues = make(url.Values)
|
||||
}
|
||||
if r.Form == nil {
|
||||
r.Form = newValues
|
||||
} else {
|
||||
copyValues(r.Form, newValues)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ParseMultipartForm parses a request body as multipart/form-data.
|
||||
// The whole request body is parsed and up to a total of maxMemory bytes of
|
||||
// its file parts are stored in memory, with the remainder stored on
|
||||
// disk in temporary files.
|
||||
// ParseMultipartForm calls ParseForm if necessary.
|
||||
// After one call to ParseMultipartForm, subsequent calls have no effect.
|
||||
func (r *Request) ParseMultipartForm(maxMemory int64) error {
|
||||
if r.MultipartForm == multipartByReader {
|
||||
return errors.New("http: multipart handled by MultipartReader")
|
||||
}
|
||||
if r.Form == nil {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.MultipartForm != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mr, err := r.multipartReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := mr.ReadForm(maxMemory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range f.Value {
|
||||
r.Form[k] = append(r.Form[k], v...)
|
||||
}
|
||||
r.MultipartForm = f
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FormValue returns the first value for the named component of the query.
|
||||
// POST and PUT body parameters take precedence over URL query string values.
|
||||
// FormValue calls ParseMultipartForm and ParseForm if necessary.
|
||||
// To access multiple values of the same key use ParseForm.
|
||||
func (r *Request) FormValue(key string) string {
|
||||
if r.Form == nil {
|
||||
r.ParseMultipartForm(defaultMaxMemory)
|
||||
}
|
||||
if vs := r.Form[key]; len(vs) > 0 {
|
||||
return vs[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// PostFormValue returns the first value for the named component of the POST
|
||||
// or PUT request body. URL query parameters are ignored.
|
||||
// PostFormValue calls ParseMultipartForm and ParseForm if necessary.
|
||||
func (r *Request) PostFormValue(key string) string {
|
||||
if r.PostForm == nil {
|
||||
r.ParseMultipartForm(defaultMaxMemory)
|
||||
}
|
||||
if vs := r.PostForm[key]; len(vs) > 0 {
|
||||
return vs[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// FormFile returns the first file for the provided form key.
|
||||
// FormFile calls ParseMultipartForm and ParseForm if necessary.
|
||||
func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
|
||||
if r.MultipartForm == multipartByReader {
|
||||
return nil, nil, errors.New("http: multipart handled by MultipartReader")
|
||||
}
|
||||
if r.MultipartForm == nil {
|
||||
err := r.ParseMultipartForm(defaultMaxMemory)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
if r.MultipartForm != nil && r.MultipartForm.File != nil {
|
||||
if fhs := r.MultipartForm.File[key]; len(fhs) > 0 {
|
||||
f, err := fhs[0].Open()
|
||||
return f, fhs[0], err
|
||||
}
|
||||
}
|
||||
return nil, nil, ErrMissingFile
|
||||
}
|
||||
|
||||
func (r *Request) expectsContinue() bool {
|
||||
return hasToken(r.Header.get("Expect"), "100-continue")
|
||||
}
|
||||
|
||||
func (r *Request) wantsHttp10KeepAlive() bool {
|
||||
if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
|
||||
return false
|
||||
}
|
||||
return hasToken(r.Header.get("Connection"), "keep-alive")
|
||||
}
|
||||
|
||||
func (r *Request) wantsClose() bool {
|
||||
return hasToken(r.Header.get("Connection"), "close")
|
||||
}
|
||||
|
||||
func (r *Request) closeBody() {
|
||||
if r.Body != nil {
|
||||
r.Body.Close()
|
||||
}
|
||||
}
|
291
vendor/github.com/masterzen/azure-sdk-for-go/core/http/response.go
generated
vendored
Normal file
291
vendor/github.com/masterzen/azure-sdk-for-go/core/http/response.go
generated
vendored
Normal file
|
@ -0,0 +1,291 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// HTTP Response reading and parsing.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/masterzen/azure-sdk-for-go/core/tls"
|
||||
"io"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var respExcludeHeader = map[string]bool{
|
||||
"Content-Length": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Trailer": true,
|
||||
}
|
||||
|
||||
// Response represents the response from an HTTP request.
|
||||
//
|
||||
type Response struct {
|
||||
Status string // e.g. "200 OK"
|
||||
StatusCode int // e.g. 200
|
||||
Proto string // e.g. "HTTP/1.0"
|
||||
ProtoMajor int // e.g. 1
|
||||
ProtoMinor int // e.g. 0
|
||||
|
||||
// Header maps header keys to values. If the response had multiple
|
||||
// headers with the same key, they may be concatenated, with comma
|
||||
// delimiters. (Section 4.2 of RFC 2616 requires that multiple headers
|
||||
// be semantically equivalent to a comma-delimited sequence.) Values
|
||||
// duplicated by other fields in this struct (e.g., ContentLength) are
|
||||
// omitted from Header.
|
||||
//
|
||||
// Keys in the map are canonicalized (see CanonicalHeaderKey).
|
||||
Header Header
|
||||
|
||||
// Body represents the response body.
|
||||
//
|
||||
// The http Client and Transport guarantee that Body is always
|
||||
// non-nil, even on responses without a body or responses with
|
||||
// a zero-length body. It is the caller's responsibility to
|
||||
// close Body.
|
||||
//
|
||||
// The Body is automatically dechunked if the server replied
|
||||
// with a "chunked" Transfer-Encoding.
|
||||
Body io.ReadCloser
|
||||
|
||||
// ContentLength records the length of the associated content. The
|
||||
// value -1 indicates that the length is unknown. Unless Request.Method
|
||||
// is "HEAD", values >= 0 indicate that the given number of bytes may
|
||||
// be read from Body.
|
||||
ContentLength int64
|
||||
|
||||
// Contains transfer encodings from outer-most to inner-most. Value is
|
||||
// nil, means that "identity" encoding is used.
|
||||
TransferEncoding []string
|
||||
|
||||
// Close records whether the header directed that the connection be
|
||||
// closed after reading Body. The value is advice for clients: neither
|
||||
// ReadResponse nor Response.Write ever closes a connection.
|
||||
Close bool
|
||||
|
||||
// Trailer maps trailer keys to values, in the same
|
||||
// format as the header.
|
||||
Trailer Header
|
||||
|
||||
// The Request that was sent to obtain this Response.
|
||||
// Request's Body is nil (having already been consumed).
|
||||
// This is only populated for Client requests.
|
||||
Request *Request
|
||||
|
||||
// TLS contains information about the TLS connection on which the
|
||||
// response was received. It is nil for unencrypted responses.
|
||||
// The pointer is shared between responses and should not be
|
||||
// modified.
|
||||
TLS *tls.ConnectionState
|
||||
}
|
||||
|
||||
// Cookies parses and returns the cookies set in the Set-Cookie headers.
|
||||
func (r *Response) Cookies() []*Cookie {
|
||||
return readSetCookies(r.Header)
|
||||
}
|
||||
|
||||
var ErrNoLocation = errors.New("http: no Location header in response")
|
||||
|
||||
// Location returns the URL of the response's "Location" header,
|
||||
// if present. Relative redirects are resolved relative to
|
||||
// the Response's Request. ErrNoLocation is returned if no
|
||||
// Location header is present.
|
||||
func (r *Response) Location() (*url.URL, error) {
|
||||
lv := r.Header.Get("Location")
|
||||
if lv == "" {
|
||||
return nil, ErrNoLocation
|
||||
}
|
||||
if r.Request != nil && r.Request.URL != nil {
|
||||
return r.Request.URL.Parse(lv)
|
||||
}
|
||||
return url.Parse(lv)
|
||||
}
|
||||
|
||||
// ReadResponse reads and returns an HTTP response from r.
|
||||
// The req parameter optionally specifies the Request that corresponds
|
||||
// to this Response. If nil, a GET request is assumed.
|
||||
// Clients must call resp.Body.Close when finished reading resp.Body.
|
||||
// After that call, clients can inspect resp.Trailer to find key/value
|
||||
// pairs included in the response trailer.
|
||||
func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
|
||||
tp := textproto.NewReader(r)
|
||||
resp := &Response{
|
||||
Request: req,
|
||||
}
|
||||
|
||||
// Parse the first line of the response.
|
||||
line, err := tp.ReadLine()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
f := strings.SplitN(line, " ", 3)
|
||||
if len(f) < 2 {
|
||||
return nil, &badStringError{"malformed HTTP response", line}
|
||||
}
|
||||
reasonPhrase := ""
|
||||
if len(f) > 2 {
|
||||
reasonPhrase = f[2]
|
||||
}
|
||||
resp.Status = f[1] + " " + reasonPhrase
|
||||
resp.StatusCode, err = strconv.Atoi(f[1])
|
||||
if err != nil {
|
||||
return nil, &badStringError{"malformed HTTP status code", f[1]}
|
||||
}
|
||||
|
||||
resp.Proto = f[0]
|
||||
var ok bool
|
||||
if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
|
||||
return nil, &badStringError{"malformed HTTP version", resp.Proto}
|
||||
}
|
||||
|
||||
// Parse the response headers.
|
||||
mimeHeader, err := tp.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
resp.Header = Header(mimeHeader)
|
||||
|
||||
fixPragmaCacheControl(resp.Header)
|
||||
|
||||
err = readTransfer(resp, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// RFC2616: Should treat
|
||||
// Pragma: no-cache
|
||||
// like
|
||||
// Cache-Control: no-cache
|
||||
func fixPragmaCacheControl(header Header) {
|
||||
if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" {
|
||||
if _, presentcc := header["Cache-Control"]; !presentcc {
|
||||
header["Cache-Control"] = []string{"no-cache"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProtoAtLeast reports whether the HTTP protocol used
|
||||
// in the response is at least major.minor.
|
||||
func (r *Response) ProtoAtLeast(major, minor int) bool {
|
||||
return r.ProtoMajor > major ||
|
||||
r.ProtoMajor == major && r.ProtoMinor >= minor
|
||||
}
|
||||
|
||||
// Writes the response (header, body and trailer) in wire format. This method
|
||||
// consults the following fields of the response:
|
||||
//
|
||||
// StatusCode
|
||||
// ProtoMajor
|
||||
// ProtoMinor
|
||||
// Request.Method
|
||||
// TransferEncoding
|
||||
// Trailer
|
||||
// Body
|
||||
// ContentLength
|
||||
// Header, values for non-canonical keys will have unpredictable behavior
|
||||
//
|
||||
// Body is closed after it is sent.
|
||||
func (r *Response) Write(w io.Writer) error {
|
||||
// Status line
|
||||
text := r.Status
|
||||
if text == "" {
|
||||
var ok bool
|
||||
text, ok = statusText[r.StatusCode]
|
||||
if !ok {
|
||||
text = "status code " + strconv.Itoa(r.StatusCode)
|
||||
}
|
||||
}
|
||||
protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor)
|
||||
statusCode := strconv.Itoa(r.StatusCode) + " "
|
||||
text = strings.TrimPrefix(text, statusCode)
|
||||
if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clone it, so we can modify r1 as needed.
|
||||
r1 := new(Response)
|
||||
*r1 = *r
|
||||
if r1.ContentLength == 0 && r1.Body != nil {
|
||||
// Is it actually 0 length? Or just unknown?
|
||||
var buf [1]byte
|
||||
n, err := r1.Body.Read(buf[:])
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
// Reset it to a known zero reader, in case underlying one
|
||||
// is unhappy being read repeatedly.
|
||||
r1.Body = eofReader
|
||||
} else {
|
||||
r1.ContentLength = -1
|
||||
r1.Body = struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{
|
||||
io.MultiReader(bytes.NewReader(buf[:1]), r.Body),
|
||||
r.Body,
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we're sending a non-chunked HTTP/1.1 response without a
|
||||
// content-length, the only way to do that is the old HTTP/1.0
|
||||
// way, by noting the EOF with a connection close, so we need
|
||||
// to set Close.
|
||||
if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) {
|
||||
r1.Close = true
|
||||
}
|
||||
|
||||
// Process Body,ContentLength,Close,Trailer
|
||||
tw, err := newTransferWriter(r1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tw.WriteHeader(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rest of header
|
||||
err = r.Header.WriteSubset(w, respExcludeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// contentLengthAlreadySent may have been already sent for
|
||||
// POST/PUT requests, even if zero length. See Issue 8180.
|
||||
contentLengthAlreadySent := tw.shouldSendContentLength()
|
||||
if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent {
|
||||
if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// End-of-header
|
||||
if _, err := io.WriteString(w, "\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write body and trailer
|
||||
err = tw.WriteBody(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Success
|
||||
return nil
|
||||
}
|
2052
vendor/github.com/masterzen/azure-sdk-for-go/core/http/server.go
generated
vendored
Normal file
2052
vendor/github.com/masterzen/azure-sdk-for-go/core/http/server.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,214 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// The algorithm uses at most sniffLen bytes to make its decision.
|
||||
const sniffLen = 512
|
||||
|
||||
// DetectContentType implements the algorithm described
|
||||
// at http://mimesniff.spec.whatwg.org/ to determine the
|
||||
// Content-Type of the given data. It considers at most the
|
||||
// first 512 bytes of data. DetectContentType always returns
|
||||
// a valid MIME type: if it cannot determine a more specific one, it
|
||||
// returns "application/octet-stream".
|
||||
func DetectContentType(data []byte) string {
|
||||
if len(data) > sniffLen {
|
||||
data = data[:sniffLen]
|
||||
}
|
||||
|
||||
// Index of the first non-whitespace byte in data.
|
||||
firstNonWS := 0
|
||||
for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ {
|
||||
}
|
||||
|
||||
for _, sig := range sniffSignatures {
|
||||
if ct := sig.match(data, firstNonWS); ct != "" {
|
||||
return ct
|
||||
}
|
||||
}
|
||||
|
||||
return "application/octet-stream" // fallback
|
||||
}
|
||||
|
||||
func isWS(b byte) bool {
|
||||
return bytes.IndexByte([]byte("\t\n\x0C\r "), b) != -1
|
||||
}
|
||||
|
||||
type sniffSig interface {
|
||||
// match returns the MIME type of the data, or "" if unknown.
|
||||
match(data []byte, firstNonWS int) string
|
||||
}
|
||||
|
||||
// Data matching the table in section 6.
|
||||
var sniffSignatures = []sniffSig{
|
||||
htmlSig("<!DOCTYPE HTML"),
|
||||
htmlSig("<HTML"),
|
||||
htmlSig("<HEAD"),
|
||||
htmlSig("<SCRIPT"),
|
||||
htmlSig("<IFRAME"),
|
||||
htmlSig("<H1"),
|
||||
htmlSig("<DIV"),
|
||||
htmlSig("<FONT"),
|
||||
htmlSig("<TABLE"),
|
||||
htmlSig("<A"),
|
||||
htmlSig("<STYLE"),
|
||||
htmlSig("<TITLE"),
|
||||
htmlSig("<B"),
|
||||
htmlSig("<BODY"),
|
||||
htmlSig("<BR"),
|
||||
htmlSig("<P"),
|
||||
htmlSig("<!--"),
|
||||
|
||||
&maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"},
|
||||
|
||||
&exactSig{[]byte("%PDF-"), "application/pdf"},
|
||||
&exactSig{[]byte("%!PS-Adobe-"), "application/postscript"},
|
||||
|
||||
// UTF BOMs.
|
||||
&maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"},
|
||||
&maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"},
|
||||
&maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"},
|
||||
|
||||
&exactSig{[]byte("GIF87a"), "image/gif"},
|
||||
&exactSig{[]byte("GIF89a"), "image/gif"},
|
||||
&exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"},
|
||||
&exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
|
||||
&exactSig{[]byte("BM"), "image/bmp"},
|
||||
&maskedSig{
|
||||
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
|
||||
pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"),
|
||||
ct: "image/webp",
|
||||
},
|
||||
&exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"},
|
||||
&exactSig{[]byte("\x4F\x67\x67\x53\x00"), "application/ogg"},
|
||||
&maskedSig{
|
||||
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
|
||||
pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
|
||||
ct: "audio/wave",
|
||||
},
|
||||
&exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
|
||||
&exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"},
|
||||
&exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"},
|
||||
&exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
|
||||
|
||||
// TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
|
||||
//mp4Sig(0),
|
||||
|
||||
textSig(0), // should be last
|
||||
}
|
||||
|
||||
type exactSig struct {
|
||||
sig []byte
|
||||
ct string
|
||||
}
|
||||
|
||||
func (e *exactSig) match(data []byte, firstNonWS int) string {
|
||||
if bytes.HasPrefix(data, e.sig) {
|
||||
return e.ct
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type maskedSig struct {
|
||||
mask, pat []byte
|
||||
skipWS bool
|
||||
ct string
|
||||
}
|
||||
|
||||
func (m *maskedSig) match(data []byte, firstNonWS int) string {
|
||||
if m.skipWS {
|
||||
data = data[firstNonWS:]
|
||||
}
|
||||
if len(data) < len(m.mask) {
|
||||
return ""
|
||||
}
|
||||
for i, mask := range m.mask {
|
||||
db := data[i] & mask
|
||||
if db != m.pat[i] {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return m.ct
|
||||
}
|
||||
|
||||
type htmlSig []byte
|
||||
|
||||
func (h htmlSig) match(data []byte, firstNonWS int) string {
|
||||
data = data[firstNonWS:]
|
||||
if len(data) < len(h)+1 {
|
||||
return ""
|
||||
}
|
||||
for i, b := range h {
|
||||
db := data[i]
|
||||
if 'A' <= b && b <= 'Z' {
|
||||
db &= 0xDF
|
||||
}
|
||||
if b != db {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
// Next byte must be space or right angle bracket.
|
||||
if db := data[len(h)]; db != ' ' && db != '>' {
|
||||
return ""
|
||||
}
|
||||
return "text/html; charset=utf-8"
|
||||
}
|
||||
|
||||
type mp4Sig int
|
||||
|
||||
func (mp4Sig) match(data []byte, firstNonWS int) string {
|
||||
// c.f. section 6.1.
|
||||
if len(data) < 8 {
|
||||
return ""
|
||||
}
|
||||
boxSize := int(binary.BigEndian.Uint32(data[:4]))
|
||||
if boxSize%4 != 0 || len(data) < boxSize {
|
||||
return ""
|
||||
}
|
||||
if !bytes.Equal(data[4:8], []byte("ftyp")) {
|
||||
return ""
|
||||
}
|
||||
for st := 8; st < boxSize; st += 4 {
|
||||
if st == 12 {
|
||||
// minor version number
|
||||
continue
|
||||
}
|
||||
seg := string(data[st : st+3])
|
||||
switch seg {
|
||||
case "mp4", "iso", "M4V", "M4P", "M4B":
|
||||
return "video/mp4"
|
||||
/* The remainder are not in the spec.
|
||||
case "M4A":
|
||||
return "audio/mp4"
|
||||
case "3gp":
|
||||
return "video/3gpp"
|
||||
case "jp2":
|
||||
return "image/jp2" // JPEG 2000
|
||||
*/
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type textSig int
|
||||
|
||||
func (textSig) match(data []byte, firstNonWS int) string {
|
||||
// c.f. section 5, step 4.
|
||||
for _, b := range data[firstNonWS:] {
|
||||
switch {
|
||||
case 0x00 <= b && b <= 0x08,
|
||||
b == 0x0B,
|
||||
0x0E <= b && b <= 0x1A,
|
||||
0x1C <= b && b <= 0x1F:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return "text/plain; charset=utf-8"
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
// HTTP status codes, defined in RFC 2616.
|
||||
const (
|
||||
StatusContinue = 100
|
||||
StatusSwitchingProtocols = 101
|
||||
|
||||
StatusOK = 200
|
||||
StatusCreated = 201
|
||||
StatusAccepted = 202
|
||||
StatusNonAuthoritativeInfo = 203
|
||||
StatusNoContent = 204
|
||||
StatusResetContent = 205
|
||||
StatusPartialContent = 206
|
||||
|
||||
StatusMultipleChoices = 300
|
||||
StatusMovedPermanently = 301
|
||||
StatusFound = 302
|
||||
StatusSeeOther = 303
|
||||
StatusNotModified = 304
|
||||
StatusUseProxy = 305
|
||||
StatusTemporaryRedirect = 307
|
||||
|
||||
StatusBadRequest = 400
|
||||
StatusUnauthorized = 401
|
||||
StatusPaymentRequired = 402
|
||||
StatusForbidden = 403
|
||||
StatusNotFound = 404
|
||||
StatusMethodNotAllowed = 405
|
||||
StatusNotAcceptable = 406
|
||||
StatusProxyAuthRequired = 407
|
||||
StatusRequestTimeout = 408
|
||||
StatusConflict = 409
|
||||
StatusGone = 410
|
||||
StatusLengthRequired = 411
|
||||
StatusPreconditionFailed = 412
|
||||
StatusRequestEntityTooLarge = 413
|
||||
StatusRequestURITooLong = 414
|
||||
StatusUnsupportedMediaType = 415
|
||||
StatusRequestedRangeNotSatisfiable = 416
|
||||
StatusExpectationFailed = 417
|
||||
StatusTeapot = 418
|
||||
|
||||
StatusInternalServerError = 500
|
||||
StatusNotImplemented = 501
|
||||
StatusBadGateway = 502
|
||||
StatusServiceUnavailable = 503
|
||||
StatusGatewayTimeout = 504
|
||||
StatusHTTPVersionNotSupported = 505
|
||||
|
||||
// New HTTP status codes from RFC 6585. Not exported yet in Go 1.1.
|
||||
// See discussion at https://codereview.appspot.com/7678043/
|
||||
statusPreconditionRequired = 428
|
||||
statusTooManyRequests = 429
|
||||
statusRequestHeaderFieldsTooLarge = 431
|
||||
statusNetworkAuthenticationRequired = 511
|
||||
)
|
||||
|
||||
var statusText = map[int]string{
|
||||
StatusContinue: "Continue",
|
||||
StatusSwitchingProtocols: "Switching Protocols",
|
||||
|
||||
StatusOK: "OK",
|
||||
StatusCreated: "Created",
|
||||
StatusAccepted: "Accepted",
|
||||
StatusNonAuthoritativeInfo: "Non-Authoritative Information",
|
||||
StatusNoContent: "No Content",
|
||||
StatusResetContent: "Reset Content",
|
||||
StatusPartialContent: "Partial Content",
|
||||
|
||||
StatusMultipleChoices: "Multiple Choices",
|
||||
StatusMovedPermanently: "Moved Permanently",
|
||||
StatusFound: "Found",
|
||||
StatusSeeOther: "See Other",
|
||||
StatusNotModified: "Not Modified",
|
||||
StatusUseProxy: "Use Proxy",
|
||||
StatusTemporaryRedirect: "Temporary Redirect",
|
||||
|
||||
StatusBadRequest: "Bad Request",
|
||||
StatusUnauthorized: "Unauthorized",
|
||||
StatusPaymentRequired: "Payment Required",
|
||||
StatusForbidden: "Forbidden",
|
||||
StatusNotFound: "Not Found",
|
||||
StatusMethodNotAllowed: "Method Not Allowed",
|
||||
StatusNotAcceptable: "Not Acceptable",
|
||||
StatusProxyAuthRequired: "Proxy Authentication Required",
|
||||
StatusRequestTimeout: "Request Timeout",
|
||||
StatusConflict: "Conflict",
|
||||
StatusGone: "Gone",
|
||||
StatusLengthRequired: "Length Required",
|
||||
StatusPreconditionFailed: "Precondition Failed",
|
||||
StatusRequestEntityTooLarge: "Request Entity Too Large",
|
||||
StatusRequestURITooLong: "Request URI Too Long",
|
||||
StatusUnsupportedMediaType: "Unsupported Media Type",
|
||||
StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
|
||||
StatusExpectationFailed: "Expectation Failed",
|
||||
StatusTeapot: "I'm a teapot",
|
||||
|
||||
StatusInternalServerError: "Internal Server Error",
|
||||
StatusNotImplemented: "Not Implemented",
|
||||
StatusBadGateway: "Bad Gateway",
|
||||
StatusServiceUnavailable: "Service Unavailable",
|
||||
StatusGatewayTimeout: "Gateway Timeout",
|
||||
StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
|
||||
|
||||
statusPreconditionRequired: "Precondition Required",
|
||||
statusTooManyRequests: "Too Many Requests",
|
||||
statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large",
|
||||
statusNetworkAuthenticationRequired: "Network Authentication Required",
|
||||
}
|
||||
|
||||
// StatusText returns a text for the HTTP status code. It returns the empty
|
||||
// string if the code is unknown.
|
||||
func StatusText(code int) string {
|
||||
return statusText[code]
|
||||
}
|
730
vendor/github.com/masterzen/azure-sdk-for-go/core/http/transfer.go
generated
vendored
Normal file
730
vendor/github.com/masterzen/azure-sdk-for-go/core/http/transfer.go
generated
vendored
Normal file
|
@ -0,0 +1,730 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/textproto"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type errorReader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *errorReader) Read(p []byte) (n int, err error) {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
// transferWriter inspects the fields of a user-supplied Request or Response,
|
||||
// sanitizes them without changing the user object and provides methods for
|
||||
// writing the respective header, body and trailer in wire format.
|
||||
type transferWriter struct {
|
||||
Method string
|
||||
Body io.Reader
|
||||
BodyCloser io.Closer
|
||||
ResponseToHEAD bool
|
||||
ContentLength int64 // -1 means unknown, 0 means exactly none
|
||||
Close bool
|
||||
TransferEncoding []string
|
||||
Trailer Header
|
||||
}
|
||||
|
||||
func newTransferWriter(r interface{}) (t *transferWriter, err error) {
|
||||
t = &transferWriter{}
|
||||
|
||||
// Extract relevant fields
|
||||
atLeastHTTP11 := false
|
||||
switch rr := r.(type) {
|
||||
case *Request:
|
||||
if rr.ContentLength != 0 && rr.Body == nil {
|
||||
return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
|
||||
}
|
||||
t.Method = rr.Method
|
||||
t.Body = rr.Body
|
||||
t.BodyCloser = rr.Body
|
||||
t.ContentLength = rr.ContentLength
|
||||
t.Close = rr.Close
|
||||
t.TransferEncoding = rr.TransferEncoding
|
||||
t.Trailer = rr.Trailer
|
||||
atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
|
||||
if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 {
|
||||
if t.ContentLength == 0 {
|
||||
// Test to see if it's actually zero or just unset.
|
||||
var buf [1]byte
|
||||
n, rerr := io.ReadFull(t.Body, buf[:])
|
||||
if rerr != nil && rerr != io.EOF {
|
||||
t.ContentLength = -1
|
||||
t.Body = &errorReader{rerr}
|
||||
} else if n == 1 {
|
||||
// Oh, guess there is data in this Body Reader after all.
|
||||
// The ContentLength field just wasn't set.
|
||||
// Stich the Body back together again, re-attaching our
|
||||
// consumed byte.
|
||||
t.ContentLength = -1
|
||||
t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body)
|
||||
} else {
|
||||
// Body is actually empty.
|
||||
t.Body = nil
|
||||
t.BodyCloser = nil
|
||||
}
|
||||
}
|
||||
if t.ContentLength < 0 {
|
||||
t.TransferEncoding = []string{"chunked"}
|
||||
}
|
||||
}
|
||||
case *Response:
|
||||
if rr.Request != nil {
|
||||
t.Method = rr.Request.Method
|
||||
}
|
||||
t.Body = rr.Body
|
||||
t.BodyCloser = rr.Body
|
||||
t.ContentLength = rr.ContentLength
|
||||
t.Close = rr.Close
|
||||
t.TransferEncoding = rr.TransferEncoding
|
||||
t.Trailer = rr.Trailer
|
||||
atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
|
||||
t.ResponseToHEAD = noBodyExpected(t.Method)
|
||||
}
|
||||
|
||||
// Sanitize Body,ContentLength,TransferEncoding
|
||||
if t.ResponseToHEAD {
|
||||
t.Body = nil
|
||||
if chunked(t.TransferEncoding) {
|
||||
t.ContentLength = -1
|
||||
}
|
||||
} else {
|
||||
if !atLeastHTTP11 || t.Body == nil {
|
||||
t.TransferEncoding = nil
|
||||
}
|
||||
if chunked(t.TransferEncoding) {
|
||||
t.ContentLength = -1
|
||||
} else if t.Body == nil { // no chunking, no body
|
||||
t.ContentLength = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Sanitize Trailer
|
||||
if !chunked(t.TransferEncoding) {
|
||||
t.Trailer = nil
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func noBodyExpected(requestMethod string) bool {
|
||||
return requestMethod == "HEAD"
|
||||
}
|
||||
|
||||
func (t *transferWriter) shouldSendContentLength() bool {
|
||||
if chunked(t.TransferEncoding) {
|
||||
return false
|
||||
}
|
||||
if t.ContentLength > 0 {
|
||||
return true
|
||||
}
|
||||
// Many servers expect a Content-Length for these methods
|
||||
if t.Method == "POST" || t.Method == "PUT" {
|
||||
return true
|
||||
}
|
||||
if t.ContentLength == 0 && isIdentity(t.TransferEncoding) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *transferWriter) WriteHeader(w io.Writer) error {
|
||||
if t.Close {
|
||||
if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write Content-Length and/or Transfer-Encoding whose values are a
|
||||
// function of the sanitized field triple (Body, ContentLength,
|
||||
// TransferEncoding)
|
||||
if t.shouldSendContentLength() {
|
||||
if _, err := io.WriteString(w, "Content-Length: "); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if chunked(t.TransferEncoding) {
|
||||
if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Write Trailer header
|
||||
if t.Trailer != nil {
|
||||
keys := make([]string, 0, len(t.Trailer))
|
||||
for k := range t.Trailer {
|
||||
k = CanonicalHeaderKey(k)
|
||||
switch k {
|
||||
case "Transfer-Encoding", "Trailer", "Content-Length":
|
||||
return &badStringError{"invalid Trailer key", k}
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
sort.Strings(keys)
|
||||
// TODO: could do better allocation-wise here, but trailers are rare,
|
||||
// so being lazy for now.
|
||||
if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *transferWriter) WriteBody(w io.Writer) error {
|
||||
var err error
|
||||
var ncopy int64
|
||||
|
||||
// Write body
|
||||
if t.Body != nil {
|
||||
if chunked(t.TransferEncoding) {
|
||||
cw := newChunkedWriter(w)
|
||||
_, err = io.Copy(cw, t.Body)
|
||||
if err == nil {
|
||||
err = cw.Close()
|
||||
}
|
||||
} else if t.ContentLength == -1 {
|
||||
ncopy, err = io.Copy(w, t.Body)
|
||||
} else {
|
||||
ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var nextra int64
|
||||
nextra, err = io.Copy(ioutil.Discard, t.Body)
|
||||
ncopy += nextra
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = t.BodyCloser.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy {
|
||||
return fmt.Errorf("http: Request.ContentLength=%d with Body length %d",
|
||||
t.ContentLength, ncopy)
|
||||
}
|
||||
|
||||
// TODO(petar): Place trailer writer code here.
|
||||
if chunked(t.TransferEncoding) {
|
||||
// Write Trailer header
|
||||
if t.Trailer != nil {
|
||||
if err := t.Trailer.Write(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Last chunk, empty trailer
|
||||
_, err = io.WriteString(w, "\r\n")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type transferReader struct {
|
||||
// Input
|
||||
Header Header
|
||||
StatusCode int
|
||||
RequestMethod string
|
||||
ProtoMajor int
|
||||
ProtoMinor int
|
||||
// Output
|
||||
Body io.ReadCloser
|
||||
ContentLength int64
|
||||
TransferEncoding []string
|
||||
Close bool
|
||||
Trailer Header
|
||||
}
|
||||
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var (
|
||||
suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"}
|
||||
suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"}
|
||||
)
|
||||
|
||||
func suppressedHeaders(status int) []string {
|
||||
switch {
|
||||
case status == 304:
|
||||
// RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
|
||||
return suppressedHeaders304
|
||||
case !bodyAllowedForStatus(status):
|
||||
return suppressedHeadersNoBody
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// msg is *Request or *Response.
|
||||
func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
|
||||
t := &transferReader{RequestMethod: "GET"}
|
||||
|
||||
// Unify input
|
||||
isResponse := false
|
||||
switch rr := msg.(type) {
|
||||
case *Response:
|
||||
t.Header = rr.Header
|
||||
t.StatusCode = rr.StatusCode
|
||||
t.ProtoMajor = rr.ProtoMajor
|
||||
t.ProtoMinor = rr.ProtoMinor
|
||||
t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header)
|
||||
isResponse = true
|
||||
if rr.Request != nil {
|
||||
t.RequestMethod = rr.Request.Method
|
||||
}
|
||||
case *Request:
|
||||
t.Header = rr.Header
|
||||
t.ProtoMajor = rr.ProtoMajor
|
||||
t.ProtoMinor = rr.ProtoMinor
|
||||
// Transfer semantics for Requests are exactly like those for
|
||||
// Responses with status code 200, responding to a GET method
|
||||
t.StatusCode = 200
|
||||
default:
|
||||
panic("unexpected type")
|
||||
}
|
||||
|
||||
// Default to HTTP/1.1
|
||||
if t.ProtoMajor == 0 && t.ProtoMinor == 0 {
|
||||
t.ProtoMajor, t.ProtoMinor = 1, 1
|
||||
}
|
||||
|
||||
// Transfer encoding, content length
|
||||
t.TransferEncoding, err = fixTransferEncoding(t.RequestMethod, t.Header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isResponse && t.RequestMethod == "HEAD" {
|
||||
if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
|
||||
return err
|
||||
} else {
|
||||
t.ContentLength = n
|
||||
}
|
||||
} else {
|
||||
t.ContentLength = realLength
|
||||
}
|
||||
|
||||
// Trailer
|
||||
t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there is no Content-Length or chunked Transfer-Encoding on a *Response
|
||||
// and the status is not 1xx, 204 or 304, then the body is unbounded.
|
||||
// See RFC2616, section 4.4.
|
||||
switch msg.(type) {
|
||||
case *Response:
|
||||
if realLength == -1 &&
|
||||
!chunked(t.TransferEncoding) &&
|
||||
bodyAllowedForStatus(t.StatusCode) {
|
||||
// Unbounded body.
|
||||
t.Close = true
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare body reader. ContentLength < 0 means chunked encoding
|
||||
// or close connection when finished, since multipart is not supported yet
|
||||
switch {
|
||||
case chunked(t.TransferEncoding):
|
||||
if noBodyExpected(t.RequestMethod) {
|
||||
t.Body = eofReader
|
||||
} else {
|
||||
t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
|
||||
}
|
||||
case realLength == 0:
|
||||
t.Body = eofReader
|
||||
case realLength > 0:
|
||||
t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
|
||||
default:
|
||||
// realLength < 0, i.e. "Content-Length" not mentioned in header
|
||||
if t.Close {
|
||||
// Close semantics (i.e. HTTP/1.0)
|
||||
t.Body = &body{src: r, closing: t.Close}
|
||||
} else {
|
||||
// Persistent connection (i.e. HTTP/1.1)
|
||||
t.Body = eofReader
|
||||
}
|
||||
}
|
||||
|
||||
// Unify output
|
||||
switch rr := msg.(type) {
|
||||
case *Request:
|
||||
rr.Body = t.Body
|
||||
rr.ContentLength = t.ContentLength
|
||||
rr.TransferEncoding = t.TransferEncoding
|
||||
rr.Close = t.Close
|
||||
rr.Trailer = t.Trailer
|
||||
case *Response:
|
||||
rr.Body = t.Body
|
||||
rr.ContentLength = t.ContentLength
|
||||
rr.TransferEncoding = t.TransferEncoding
|
||||
rr.Close = t.Close
|
||||
rr.Trailer = t.Trailer
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Checks whether chunked is part of the encodings stack
|
||||
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
|
||||
|
||||
// Checks whether the encoding is explicitly "identity".
|
||||
func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
|
||||
|
||||
// Sanitize transfer encoding
|
||||
func fixTransferEncoding(requestMethod string, header Header) ([]string, error) {
|
||||
raw, present := header["Transfer-Encoding"]
|
||||
if !present {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
delete(header, "Transfer-Encoding")
|
||||
|
||||
encodings := strings.Split(raw[0], ",")
|
||||
te := make([]string, 0, len(encodings))
|
||||
// TODO: Even though we only support "identity" and "chunked"
|
||||
// encodings, the loop below is designed with foresight. One
|
||||
// invariant that must be maintained is that, if present,
|
||||
// chunked encoding must always come first.
|
||||
for _, encoding := range encodings {
|
||||
encoding = strings.ToLower(strings.TrimSpace(encoding))
|
||||
// "identity" encoding is not recorded
|
||||
if encoding == "identity" {
|
||||
break
|
||||
}
|
||||
if encoding != "chunked" {
|
||||
return nil, &badStringError{"unsupported transfer encoding", encoding}
|
||||
}
|
||||
te = te[0 : len(te)+1]
|
||||
te[len(te)-1] = encoding
|
||||
}
|
||||
if len(te) > 1 {
|
||||
return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")}
|
||||
}
|
||||
if len(te) > 0 {
|
||||
// Chunked encoding trumps Content-Length. See RFC 2616
|
||||
// Section 4.4. Currently len(te) > 0 implies chunked
|
||||
// encoding.
|
||||
delete(header, "Content-Length")
|
||||
return te, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Determine the expected body length, using RFC 2616 Section 4.4. This
|
||||
// function is not a method, because ultimately it should be shared by
|
||||
// ReadResponse and ReadRequest.
|
||||
func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) {
|
||||
|
||||
// Logic based on response type or status
|
||||
if noBodyExpected(requestMethod) {
|
||||
return 0, nil
|
||||
}
|
||||
if status/100 == 1 {
|
||||
return 0, nil
|
||||
}
|
||||
switch status {
|
||||
case 204, 304:
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Logic based on Transfer-Encoding
|
||||
if chunked(te) {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// Logic based on Content-Length
|
||||
cl := strings.TrimSpace(header.get("Content-Length"))
|
||||
if cl != "" {
|
||||
n, err := parseContentLength(cl)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return n, nil
|
||||
} else {
|
||||
header.Del("Content-Length")
|
||||
}
|
||||
|
||||
if !isResponse && requestMethod == "GET" {
|
||||
// RFC 2616 doesn't explicitly permit nor forbid an
|
||||
// entity-body on a GET request so we permit one if
|
||||
// declared, but we default to 0 here (not -1 below)
|
||||
// if there's no mention of a body.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Body-EOF logic based on other methods (like closing, or chunked coding)
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// Determine whether to hang up after sending a request and body, or
|
||||
// receiving a response and body
|
||||
// 'header' is the request headers
|
||||
func shouldClose(major, minor int, header Header) bool {
|
||||
if major < 1 {
|
||||
return true
|
||||
} else if major == 1 && minor == 0 {
|
||||
if !strings.Contains(strings.ToLower(header.get("Connection")), "keep-alive") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
} else {
|
||||
// TODO: Should split on commas, toss surrounding white space,
|
||||
// and check each field.
|
||||
if strings.ToLower(header.get("Connection")) == "close" {
|
||||
header.Del("Connection")
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the trailer header
|
||||
func fixTrailer(header Header, te []string) (Header, error) {
|
||||
raw := header.get("Trailer")
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
header.Del("Trailer")
|
||||
trailer := make(Header)
|
||||
keys := strings.Split(raw, ",")
|
||||
for _, key := range keys {
|
||||
key = CanonicalHeaderKey(strings.TrimSpace(key))
|
||||
switch key {
|
||||
case "Transfer-Encoding", "Trailer", "Content-Length":
|
||||
return nil, &badStringError{"bad trailer key", key}
|
||||
}
|
||||
trailer[key] = nil
|
||||
}
|
||||
if len(trailer) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if !chunked(te) {
|
||||
// Trailer and no chunking
|
||||
return nil, ErrUnexpectedTrailer
|
||||
}
|
||||
return trailer, nil
|
||||
}
|
||||
|
||||
// body turns a Reader into a ReadCloser.
|
||||
// Close ensures that the body has been fully read
|
||||
// and then reads the trailer if necessary.
|
||||
type body struct {
|
||||
src io.Reader
|
||||
hdr interface{} // non-nil (Response or Request) value means read trailer
|
||||
r *bufio.Reader // underlying wire-format reader for the trailer
|
||||
closing bool // is the connection to be closed after reading body?
|
||||
|
||||
mu sync.Mutex // guards closed, and calls to Read and Close
|
||||
closed bool
|
||||
}
|
||||
|
||||
// ErrBodyReadAfterClose is returned when reading a Request or Response
|
||||
// Body after the body has been closed. This typically happens when the body is
|
||||
// read after an HTTP Handler calls WriteHeader or Write on its
|
||||
// ResponseWriter.
|
||||
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
|
||||
|
||||
func (b *body) Read(p []byte) (n int, err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return 0, ErrBodyReadAfterClose
|
||||
}
|
||||
return b.readLocked(p)
|
||||
}
|
||||
|
||||
// Must hold b.mu.
|
||||
func (b *body) readLocked(p []byte) (n int, err error) {
|
||||
n, err = b.src.Read(p)
|
||||
|
||||
if err == io.EOF {
|
||||
// Chunked case. Read the trailer.
|
||||
if b.hdr != nil {
|
||||
if e := b.readTrailer(); e != nil {
|
||||
err = e
|
||||
}
|
||||
b.hdr = nil
|
||||
} else {
|
||||
// If the server declared the Content-Length, our body is a LimitedReader
|
||||
// and we need to check whether this EOF arrived early.
|
||||
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we can return an EOF here along with the read data, do
|
||||
// so. This is optional per the io.Reader contract, but doing
|
||||
// so helps the HTTP transport code recycle its connection
|
||||
// earlier (since it will see this EOF itself), even if the
|
||||
// client doesn't do future reads or Close.
|
||||
if err == nil && n > 0 {
|
||||
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 {
|
||||
err = io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
var (
|
||||
singleCRLF = []byte("\r\n")
|
||||
doubleCRLF = []byte("\r\n\r\n")
|
||||
)
|
||||
|
||||
func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
|
||||
for peekSize := 4; ; peekSize++ {
|
||||
// This loop stops when Peek returns an error,
|
||||
// which it does when r's buffer has been filled.
|
||||
buf, err := r.Peek(peekSize)
|
||||
if bytes.HasSuffix(buf, doubleCRLF) {
|
||||
return true
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var errTrailerEOF = errors.New("http: unexpected EOF reading trailer")
|
||||
|
||||
func (b *body) readTrailer() error {
|
||||
// The common case, since nobody uses trailers.
|
||||
buf, err := b.r.Peek(2)
|
||||
if bytes.Equal(buf, singleCRLF) {
|
||||
b.r.ReadByte()
|
||||
b.r.ReadByte()
|
||||
return nil
|
||||
}
|
||||
if len(buf) < 2 {
|
||||
return errTrailerEOF
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure there's a header terminator coming up, to prevent
|
||||
// a DoS with an unbounded size Trailer. It's not easy to
|
||||
// slip in a LimitReader here, as textproto.NewReader requires
|
||||
// a concrete *bufio.Reader. Also, we can't get all the way
|
||||
// back up to our conn's LimitedReader that *might* be backing
|
||||
// this bufio.Reader. Instead, a hack: we iteratively Peek up
|
||||
// to the bufio.Reader's max size, looking for a double CRLF.
|
||||
// This limits the trailer to the underlying buffer size, typically 4kB.
|
||||
if !seeUpcomingDoubleCRLF(b.r) {
|
||||
return errors.New("http: suspiciously long trailer after chunked body")
|
||||
}
|
||||
|
||||
hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return errTrailerEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
switch rr := b.hdr.(type) {
|
||||
case *Request:
|
||||
mergeSetHeader(&rr.Trailer, Header(hdr))
|
||||
case *Response:
|
||||
mergeSetHeader(&rr.Trailer, Header(hdr))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeSetHeader(dst *Header, src Header) {
|
||||
if *dst == nil {
|
||||
*dst = src
|
||||
return
|
||||
}
|
||||
for k, vv := range src {
|
||||
(*dst)[k] = vv
|
||||
}
|
||||
}
|
||||
|
||||
func (b *body) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
switch {
|
||||
case b.hdr == nil && b.closing:
|
||||
// no trailer and closing the connection next.
|
||||
// no point in reading to EOF.
|
||||
default:
|
||||
// Fully consume the body, which will also lead to us reading
|
||||
// the trailer headers after the body, if present.
|
||||
_, err = io.Copy(ioutil.Discard, bodyLocked{b})
|
||||
}
|
||||
b.closed = true
|
||||
return err
|
||||
}
|
||||
|
||||
// bodyLocked is a io.Reader reading from a *body when its mutex is
|
||||
// already held.
|
||||
type bodyLocked struct {
|
||||
b *body
|
||||
}
|
||||
|
||||
func (bl bodyLocked) Read(p []byte) (n int, err error) {
|
||||
if bl.b.closed {
|
||||
return 0, ErrBodyReadAfterClose
|
||||
}
|
||||
return bl.b.readLocked(p)
|
||||
}
|
||||
|
||||
// parseContentLength trims whitespace from s and returns -1 if no value
|
||||
// is set, or the value if it's >= 0.
|
||||
func parseContentLength(cl string) (int64, error) {
|
||||
cl = strings.TrimSpace(cl)
|
||||
if cl == "" {
|
||||
return -1, nil
|
||||
}
|
||||
n, err := strconv.ParseInt(cl, 10, 64)
|
||||
if err != nil || n < 0 {
|
||||
return 0, &badStringError{"bad Content-Length", cl}
|
||||
}
|
||||
return n, nil
|
||||
|
||||
}
|
1208
vendor/github.com/masterzen/azure-sdk-for-go/core/http/transport.go
generated
vendored
Normal file
1208
vendor/github.com/masterzen/azure-sdk-for-go/core/http/transport.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,77 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import "strconv"
|
||||
|
||||
type alert uint8
|
||||
|
||||
const (
|
||||
// alert level
|
||||
alertLevelWarning = 1
|
||||
alertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
alertCloseNotify alert = 0
|
||||
alertUnexpectedMessage alert = 10
|
||||
alertBadRecordMAC alert = 20
|
||||
alertDecryptionFailed alert = 21
|
||||
alertRecordOverflow alert = 22
|
||||
alertDecompressionFailure alert = 30
|
||||
alertHandshakeFailure alert = 40
|
||||
alertBadCertificate alert = 42
|
||||
alertUnsupportedCertificate alert = 43
|
||||
alertCertificateRevoked alert = 44
|
||||
alertCertificateExpired alert = 45
|
||||
alertCertificateUnknown alert = 46
|
||||
alertIllegalParameter alert = 47
|
||||
alertUnknownCA alert = 48
|
||||
alertAccessDenied alert = 49
|
||||
alertDecodeError alert = 50
|
||||
alertDecryptError alert = 51
|
||||
alertProtocolVersion alert = 70
|
||||
alertInsufficientSecurity alert = 71
|
||||
alertInternalError alert = 80
|
||||
alertUserCanceled alert = 90
|
||||
alertNoRenegotiation alert = 100
|
||||
)
|
||||
|
||||
var alertText = map[alert]string{
|
||||
alertCloseNotify: "close notify",
|
||||
alertUnexpectedMessage: "unexpected message",
|
||||
alertBadRecordMAC: "bad record MAC",
|
||||
alertDecryptionFailed: "decryption failed",
|
||||
alertRecordOverflow: "record overflow",
|
||||
alertDecompressionFailure: "decompression failure",
|
||||
alertHandshakeFailure: "handshake failure",
|
||||
alertBadCertificate: "bad certificate",
|
||||
alertUnsupportedCertificate: "unsupported certificate",
|
||||
alertCertificateRevoked: "revoked certificate",
|
||||
alertCertificateExpired: "expired certificate",
|
||||
alertCertificateUnknown: "unknown certificate",
|
||||
alertIllegalParameter: "illegal parameter",
|
||||
alertUnknownCA: "unknown certificate authority",
|
||||
alertAccessDenied: "access denied",
|
||||
alertDecodeError: "error decoding message",
|
||||
alertDecryptError: "error decrypting message",
|
||||
alertProtocolVersion: "protocol version not supported",
|
||||
alertInsufficientSecurity: "insufficient security level",
|
||||
alertInternalError: "internal error",
|
||||
alertUserCanceled: "user canceled",
|
||||
alertNoRenegotiation: "no renegotiation",
|
||||
}
|
||||
|
||||
func (e alert) String() string {
|
||||
s, ok := alertText[e]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
return "alert(" + strconv.Itoa(int(e)) + ")"
|
||||
}
|
||||
|
||||
func (e alert) Error() string {
|
||||
return e.String()
|
||||
}
|
270
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/cipher_suites.go
generated
vendored
Normal file
270
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/cipher_suites.go
generated
vendored
Normal file
|
@ -0,0 +1,270 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/hmac"
|
||||
"crypto/rc4"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// a keyAgreement implements the client and server side of a TLS key agreement
|
||||
// protocol by generating and processing key exchange messages.
|
||||
type keyAgreement interface {
|
||||
// On the server side, the first two methods are called in order.
|
||||
|
||||
// In the case that the key agreement protocol doesn't use a
|
||||
// ServerKeyExchange message, generateServerKeyExchange can return nil,
|
||||
// nil.
|
||||
generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
|
||||
processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
|
||||
|
||||
// On the client side, the next two methods are called in order.
|
||||
|
||||
// This method may not be called if the server doesn't send a
|
||||
// ServerKeyExchange message.
|
||||
processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
|
||||
generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
|
||||
}
|
||||
|
||||
const (
|
||||
// suiteECDH indicates that the cipher suite involves elliptic curve
|
||||
// Diffie-Hellman. This means that it should only be selected when the
|
||||
// client indicates that it supports ECC with a curve and point format
|
||||
// that we're happy with.
|
||||
suiteECDHE = 1 << iota
|
||||
// suiteECDSA indicates that the cipher suite involves an ECDSA
|
||||
// signature and therefore may only be selected when the server's
|
||||
// certificate is ECDSA. If this is not set then the cipher suite is
|
||||
// RSA based.
|
||||
suiteECDSA
|
||||
// suiteTLS12 indicates that the cipher suite should only be advertised
|
||||
// and accepted when using TLS 1.2.
|
||||
suiteTLS12
|
||||
)
|
||||
|
||||
// A cipherSuite is a specific combination of key agreement, cipher and MAC
|
||||
// function. All cipher suites currently assume RSA key agreement.
|
||||
type cipherSuite struct {
|
||||
id uint16
|
||||
// the lengths, in bytes, of the key material needed for each component.
|
||||
keyLen int
|
||||
macLen int
|
||||
ivLen int
|
||||
ka func(version uint16) keyAgreement
|
||||
// flags is a bitmask of the suite* values, above.
|
||||
flags int
|
||||
cipher func(key, iv []byte, isRead bool) interface{}
|
||||
mac func(version uint16, macKey []byte) macFunction
|
||||
aead func(key, fixedNonce []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
var cipherSuites = []*cipherSuite{
|
||||
// Ciphersuite order is chosen so that ECDHE comes before plain RSA
|
||||
// and RC4 comes before AES (because of the Lucky13 attack).
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12, nil, nil, aeadAESGCM},
|
||||
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherRC4, macSHA1, nil},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
|
||||
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
|
||||
}
|
||||
|
||||
func cipherRC4(key, iv []byte, isRead bool) interface{} {
|
||||
cipher, _ := rc4.NewCipher(key)
|
||||
return cipher
|
||||
}
|
||||
|
||||
func cipher3DES(key, iv []byte, isRead bool) interface{} {
|
||||
block, _ := des.NewTripleDESCipher(key)
|
||||
if isRead {
|
||||
return cipher.NewCBCDecrypter(block, iv)
|
||||
}
|
||||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
func cipherAES(key, iv []byte, isRead bool) interface{} {
|
||||
block, _ := aes.NewCipher(key)
|
||||
if isRead {
|
||||
return cipher.NewCBCDecrypter(block, iv)
|
||||
}
|
||||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
// macSHA1 returns a macFunction for the given protocol version.
|
||||
func macSHA1(version uint16, key []byte) macFunction {
|
||||
if version == VersionSSL30 {
|
||||
mac := ssl30MAC{
|
||||
h: sha1.New(),
|
||||
key: make([]byte, len(key)),
|
||||
}
|
||||
copy(mac.key, key)
|
||||
return mac
|
||||
}
|
||||
return tls10MAC{hmac.New(sha1.New, key)}
|
||||
}
|
||||
|
||||
type macFunction interface {
|
||||
Size() int
|
||||
MAC(digestBuf, seq, header, data []byte) []byte
|
||||
}
|
||||
|
||||
// fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
|
||||
// each call.
|
||||
type fixedNonceAEAD struct {
|
||||
// sealNonce and openNonce are buffers where the larger nonce will be
|
||||
// constructed. Since a seal and open operation may be running
|
||||
// concurrently, there is a separate buffer for each.
|
||||
sealNonce, openNonce []byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *fixedNonceAEAD) NonceSize() int { return 8 }
|
||||
func (f *fixedNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
|
||||
func (f *fixedNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
copy(f.sealNonce[len(f.sealNonce)-8:], nonce)
|
||||
return f.aead.Seal(out, f.sealNonce, plaintext, additionalData)
|
||||
}
|
||||
|
||||
func (f *fixedNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) {
|
||||
copy(f.openNonce[len(f.openNonce)-8:], nonce)
|
||||
return f.aead.Open(out, f.openNonce, plaintext, additionalData)
|
||||
}
|
||||
|
||||
func aeadAESGCM(key, fixedNonce []byte) cipher.AEAD {
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
nonce1, nonce2 := make([]byte, 12), make([]byte, 12)
|
||||
copy(nonce1, fixedNonce)
|
||||
copy(nonce2, fixedNonce)
|
||||
|
||||
return &fixedNonceAEAD{nonce1, nonce2, aead}
|
||||
}
|
||||
|
||||
// ssl30MAC implements the SSLv3 MAC function, as defined in
|
||||
// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 5.2.3.1
|
||||
type ssl30MAC struct {
|
||||
h hash.Hash
|
||||
key []byte
|
||||
}
|
||||
|
||||
func (s ssl30MAC) Size() int {
|
||||
return s.h.Size()
|
||||
}
|
||||
|
||||
var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36}
|
||||
|
||||
var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c}
|
||||
|
||||
func (s ssl30MAC) MAC(digestBuf, seq, header, data []byte) []byte {
|
||||
padLength := 48
|
||||
if s.h.Size() == 20 {
|
||||
padLength = 40
|
||||
}
|
||||
|
||||
s.h.Reset()
|
||||
s.h.Write(s.key)
|
||||
s.h.Write(ssl30Pad1[:padLength])
|
||||
s.h.Write(seq)
|
||||
s.h.Write(header[:1])
|
||||
s.h.Write(header[3:5])
|
||||
s.h.Write(data)
|
||||
digestBuf = s.h.Sum(digestBuf[:0])
|
||||
|
||||
s.h.Reset()
|
||||
s.h.Write(s.key)
|
||||
s.h.Write(ssl30Pad2[:padLength])
|
||||
s.h.Write(digestBuf)
|
||||
return s.h.Sum(digestBuf[:0])
|
||||
}
|
||||
|
||||
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3.
|
||||
type tls10MAC struct {
|
||||
h hash.Hash
|
||||
}
|
||||
|
||||
func (s tls10MAC) Size() int {
|
||||
return s.h.Size()
|
||||
}
|
||||
|
||||
func (s tls10MAC) MAC(digestBuf, seq, header, data []byte) []byte {
|
||||
s.h.Reset()
|
||||
s.h.Write(seq)
|
||||
s.h.Write(header)
|
||||
s.h.Write(data)
|
||||
return s.h.Sum(digestBuf[:0])
|
||||
}
|
||||
|
||||
func rsaKA(version uint16) keyAgreement {
|
||||
return rsaKeyAgreement{}
|
||||
}
|
||||
|
||||
func ecdheECDSAKA(version uint16) keyAgreement {
|
||||
return &ecdheKeyAgreement{
|
||||
sigType: signatureECDSA,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func ecdheRSAKA(version uint16) keyAgreement {
|
||||
return &ecdheKeyAgreement{
|
||||
sigType: signatureRSA,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// mutualCipherSuite returns a cipherSuite given a list of supported
|
||||
// ciphersuites and the id requested by the peer.
|
||||
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
|
||||
for _, id := range have {
|
||||
if id == want {
|
||||
for _, suite := range cipherSuites {
|
||||
if suite.id == want {
|
||||
return suite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A list of the possible cipher suite ids. Taken from
|
||||
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml
|
||||
const (
|
||||
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
|
||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
|
||||
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
|
||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
|
||||
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
|
||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
|
||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
|
||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
|
||||
)
|
|
@ -0,0 +1,438 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"math/big"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
VersionSSL30 = 0x0300
|
||||
VersionTLS10 = 0x0301
|
||||
VersionTLS11 = 0x0302
|
||||
VersionTLS12 = 0x0303
|
||||
)
|
||||
|
||||
const (
|
||||
maxPlaintext = 16384 // maximum plaintext payload length
|
||||
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
|
||||
recordHeaderLen = 5 // record header length
|
||||
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
|
||||
|
||||
minVersion = VersionSSL30
|
||||
maxVersion = VersionTLS12
|
||||
)
|
||||
|
||||
// TLS record types.
|
||||
type recordType uint8
|
||||
|
||||
const (
|
||||
recordTypeChangeCipherSpec recordType = 20
|
||||
recordTypeAlert recordType = 21
|
||||
recordTypeHandshake recordType = 22
|
||||
recordTypeApplicationData recordType = 23
|
||||
)
|
||||
|
||||
// TLS handshake message types.
|
||||
const (
|
||||
typeHelloRequest uint8 = 0
|
||||
typeClientHello uint8 = 1
|
||||
typeServerHello uint8 = 2
|
||||
typeNewSessionTicket uint8 = 4
|
||||
typeCertificate uint8 = 11
|
||||
typeServerKeyExchange uint8 = 12
|
||||
typeCertificateRequest uint8 = 13
|
||||
typeServerHelloDone uint8 = 14
|
||||
typeCertificateVerify uint8 = 15
|
||||
typeClientKeyExchange uint8 = 16
|
||||
typeFinished uint8 = 20
|
||||
typeCertificateStatus uint8 = 22
|
||||
typeNextProtocol uint8 = 67 // Not IANA assigned
|
||||
)
|
||||
|
||||
// TLS compression types.
|
||||
const (
|
||||
compressionNone uint8 = 0
|
||||
)
|
||||
|
||||
// TLS extension numbers
|
||||
var (
|
||||
extensionServerName uint16 = 0
|
||||
extensionStatusRequest uint16 = 5
|
||||
extensionSupportedCurves uint16 = 10
|
||||
extensionSupportedPoints uint16 = 11
|
||||
extensionSignatureAlgorithms uint16 = 13
|
||||
extensionSessionTicket uint16 = 35
|
||||
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
|
||||
)
|
||||
|
||||
// TLS Elliptic Curves
|
||||
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8
|
||||
var (
|
||||
curveP256 uint16 = 23
|
||||
curveP384 uint16 = 24
|
||||
curveP521 uint16 = 25
|
||||
)
|
||||
|
||||
// TLS Elliptic Curve Point Formats
|
||||
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
|
||||
var (
|
||||
pointFormatUncompressed uint8 = 0
|
||||
)
|
||||
|
||||
// TLS CertificateStatusType (RFC 3546)
|
||||
const (
|
||||
statusTypeOCSP uint8 = 1
|
||||
)
|
||||
|
||||
// Certificate types (for certificateRequestMsg)
|
||||
const (
|
||||
certTypeRSASign = 1 // A certificate containing an RSA key
|
||||
certTypeDSSSign = 2 // A certificate containing a DSA key
|
||||
certTypeRSAFixedDH = 3 // A certificate containing a static DH key
|
||||
certTypeDSSFixedDH = 4 // A certificate containing a static DH key
|
||||
|
||||
// See RFC4492 sections 3 and 5.5.
|
||||
certTypeECDSASign = 64 // A certificate containing an ECDSA-capable public key, signed with ECDSA.
|
||||
certTypeRSAFixedECDH = 65 // A certificate containing an ECDH-capable public key, signed with RSA.
|
||||
certTypeECDSAFixedECDH = 66 // A certificate containing an ECDH-capable public key, signed with ECDSA.
|
||||
|
||||
// Rest of these are reserved by the TLS spec
|
||||
)
|
||||
|
||||
// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1)
|
||||
const (
|
||||
hashSHA1 uint8 = 2
|
||||
hashSHA256 uint8 = 4
|
||||
)
|
||||
|
||||
// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1)
|
||||
const (
|
||||
signatureRSA uint8 = 1
|
||||
signatureECDSA uint8 = 3
|
||||
)
|
||||
|
||||
// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See
|
||||
// RFC 5246, section A.4.1.
|
||||
type signatureAndHash struct {
|
||||
hash, signature uint8
|
||||
}
|
||||
|
||||
// supportedSKXSignatureAlgorithms contains the signature and hash algorithms
|
||||
// that the code advertises as supported in a TLS 1.2 ClientHello.
|
||||
var supportedSKXSignatureAlgorithms = []signatureAndHash{
|
||||
{hashSHA256, signatureRSA},
|
||||
{hashSHA256, signatureECDSA},
|
||||
{hashSHA1, signatureRSA},
|
||||
{hashSHA1, signatureECDSA},
|
||||
}
|
||||
|
||||
// supportedClientCertSignatureAlgorithms contains the signature and hash
|
||||
// algorithms that the code advertises as supported in a TLS 1.2
|
||||
// CertificateRequest.
|
||||
var supportedClientCertSignatureAlgorithms = []signatureAndHash{
|
||||
{hashSHA256, signatureRSA},
|
||||
{hashSHA256, signatureECDSA},
|
||||
}
|
||||
|
||||
// ConnectionState records basic TLS details about the connection.
|
||||
type ConnectionState struct {
|
||||
HandshakeComplete bool // TLS handshake is complete
|
||||
DidResume bool // connection resumes a previous TLS connection
|
||||
CipherSuite uint16 // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...)
|
||||
NegotiatedProtocol string // negotiated next protocol (from Config.NextProtos)
|
||||
NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server
|
||||
ServerName string // server name requested by client, if any (server side only)
|
||||
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
|
||||
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
|
||||
}
|
||||
|
||||
// ClientAuthType declares the policy the server will follow for
|
||||
// TLS Client Authentication.
|
||||
type ClientAuthType int
|
||||
|
||||
const (
|
||||
NoClientCert ClientAuthType = iota
|
||||
RequestClientCert
|
||||
RequireAnyClientCert
|
||||
VerifyClientCertIfGiven
|
||||
RequireAndVerifyClientCert
|
||||
)
|
||||
|
||||
// A Config structure is used to configure a TLS client or server. After one
|
||||
// has been passed to a TLS function it must not be modified.
|
||||
type Config struct {
|
||||
// Rand provides the source of entropy for nonces and RSA blinding.
|
||||
// If Rand is nil, TLS uses the cryptographic random reader in package
|
||||
// crypto/rand.
|
||||
Rand io.Reader
|
||||
|
||||
// Time returns the current time as the number of seconds since the epoch.
|
||||
// If Time is nil, TLS uses time.Now.
|
||||
Time func() time.Time
|
||||
|
||||
// Certificates contains one or more certificate chains
|
||||
// to present to the other side of the connection.
|
||||
// Server configurations must include at least one certificate.
|
||||
Certificates []Certificate
|
||||
|
||||
// NameToCertificate maps from a certificate name to an element of
|
||||
// Certificates. Note that a certificate name can be of the form
|
||||
// '*.example.com' and so doesn't have to be a domain name as such.
|
||||
// See Config.BuildNameToCertificate
|
||||
// The nil value causes the first element of Certificates to be used
|
||||
// for all connections.
|
||||
NameToCertificate map[string]*Certificate
|
||||
|
||||
// RootCAs defines the set of root certificate authorities
|
||||
// that clients use when verifying server certificates.
|
||||
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||
RootCAs *x509.CertPool
|
||||
|
||||
// NextProtos is a list of supported, application level protocols.
|
||||
NextProtos []string
|
||||
|
||||
// ServerName is included in the client's handshake to support virtual
|
||||
// hosting.
|
||||
ServerName string
|
||||
|
||||
// ClientAuth determines the server's policy for
|
||||
// TLS Client Authentication. The default is NoClientCert.
|
||||
ClientAuth ClientAuthType
|
||||
|
||||
// ClientCAs defines the set of root certificate authorities
|
||||
// that servers use if required to verify a client certificate
|
||||
// by the policy in ClientAuth.
|
||||
ClientCAs *x509.CertPool
|
||||
|
||||
// InsecureSkipVerify controls whether a client verifies the
|
||||
// server's certificate chain and host name.
|
||||
// If InsecureSkipVerify is true, TLS accepts any certificate
|
||||
// presented by the server and any host name in that certificate.
|
||||
// In this mode, TLS is susceptible to man-in-the-middle attacks.
|
||||
// This should be used only for testing.
|
||||
InsecureSkipVerify bool
|
||||
|
||||
// CipherSuites is a list of supported cipher suites. If CipherSuites
|
||||
// is nil, TLS uses a list of suites supported by the implementation.
|
||||
CipherSuites []uint16
|
||||
|
||||
// PreferServerCipherSuites controls whether the server selects the
|
||||
// client's most preferred ciphersuite, or the server's most preferred
|
||||
// ciphersuite. If true then the server's preference, as expressed in
|
||||
// the order of elements in CipherSuites, is used.
|
||||
PreferServerCipherSuites bool
|
||||
|
||||
// SessionTicketsDisabled may be set to true to disable session ticket
|
||||
// (resumption) support.
|
||||
SessionTicketsDisabled bool
|
||||
|
||||
// SessionTicketKey is used by TLS servers to provide session
|
||||
// resumption. See RFC 5077. If zero, it will be filled with
|
||||
// random data before the first server handshake.
|
||||
//
|
||||
// If multiple servers are terminating connections for the same host
|
||||
// they should all have the same SessionTicketKey. If the
|
||||
// SessionTicketKey leaks, previously recorded and future TLS
|
||||
// connections using that key are compromised.
|
||||
SessionTicketKey [32]byte
|
||||
|
||||
// MinVersion contains the minimum SSL/TLS version that is acceptable.
|
||||
// If zero, then SSLv3 is taken as the minimum.
|
||||
MinVersion uint16
|
||||
|
||||
// MaxVersion contains the maximum SSL/TLS version that is acceptable.
|
||||
// If zero, then the maximum version supported by this package is used,
|
||||
// which is currently TLS 1.2.
|
||||
MaxVersion uint16
|
||||
|
||||
serverInitOnce sync.Once // guards calling (*Config).serverInit
|
||||
}
|
||||
|
||||
func (c *Config) serverInit() {
|
||||
if c.SessionTicketsDisabled {
|
||||
return
|
||||
}
|
||||
|
||||
// If the key has already been set then we have nothing to do.
|
||||
for _, b := range c.SessionTicketKey {
|
||||
if b != 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
|
||||
c.SessionTicketsDisabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) rand() io.Reader {
|
||||
r := c.Rand
|
||||
if r == nil {
|
||||
return rand.Reader
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (c *Config) time() time.Time {
|
||||
t := c.Time
|
||||
if t == nil {
|
||||
t = time.Now
|
||||
}
|
||||
return t()
|
||||
}
|
||||
|
||||
func (c *Config) cipherSuites() []uint16 {
|
||||
s := c.CipherSuites
|
||||
if s == nil {
|
||||
s = defaultCipherSuites()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *Config) minVersion() uint16 {
|
||||
if c == nil || c.MinVersion == 0 {
|
||||
return minVersion
|
||||
}
|
||||
return c.MinVersion
|
||||
}
|
||||
|
||||
func (c *Config) maxVersion() uint16 {
|
||||
if c == nil || c.MaxVersion == 0 {
|
||||
return maxVersion
|
||||
}
|
||||
return c.MaxVersion
|
||||
}
|
||||
|
||||
// mutualVersion returns the protocol version to use given the advertised
|
||||
// version of the peer.
|
||||
func (c *Config) mutualVersion(vers uint16) (uint16, bool) {
|
||||
minVersion := c.minVersion()
|
||||
maxVersion := c.maxVersion()
|
||||
|
||||
if vers < minVersion {
|
||||
return 0, false
|
||||
}
|
||||
if vers > maxVersion {
|
||||
vers = maxVersion
|
||||
}
|
||||
return vers, true
|
||||
}
|
||||
|
||||
// getCertificateForName returns the best certificate for the given name,
|
||||
// defaulting to the first element of c.Certificates if there are no good
|
||||
// options.
|
||||
func (c *Config) getCertificateForName(name string) *Certificate {
|
||||
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
|
||||
// There's only one choice, so no point doing any work.
|
||||
return &c.Certificates[0]
|
||||
}
|
||||
|
||||
name = strings.ToLower(name)
|
||||
for len(name) > 0 && name[len(name)-1] == '.' {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
if cert, ok := c.NameToCertificate[name]; ok {
|
||||
return cert
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a
|
||||
// match.
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok := c.NameToCertificate[candidate]; ok {
|
||||
return cert
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing matches, return the first certificate.
|
||||
return &c.Certificates[0]
|
||||
}
|
||||
|
||||
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
|
||||
// from the CommonName and SubjectAlternateName fields of each of the leaf
|
||||
// certificates.
|
||||
func (c *Config) BuildNameToCertificate() {
|
||||
c.NameToCertificate = make(map[string]*Certificate)
|
||||
for i := range c.Certificates {
|
||||
cert := &c.Certificates[i]
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(x509Cert.Subject.CommonName) > 0 {
|
||||
c.NameToCertificate[x509Cert.Subject.CommonName] = cert
|
||||
}
|
||||
for _, san := range x509Cert.DNSNames {
|
||||
c.NameToCertificate[san] = cert
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A Certificate is a chain of one or more certificates, leaf first.
|
||||
type Certificate struct {
|
||||
Certificate [][]byte
|
||||
PrivateKey crypto.PrivateKey // supported types: *rsa.PrivateKey, *ecdsa.PrivateKey
|
||||
// OCSPStaple contains an optional OCSP response which will be served
|
||||
// to clients that request it.
|
||||
OCSPStaple []byte
|
||||
// Leaf is the parsed form of the leaf certificate, which may be
|
||||
// initialized using x509.ParseCertificate to reduce per-handshake
|
||||
// processing for TLS clients doing client authentication. If nil, the
|
||||
// leaf certificate will be parsed as needed.
|
||||
Leaf *x509.Certificate
|
||||
}
|
||||
|
||||
// A TLS record.
|
||||
type record struct {
|
||||
contentType recordType
|
||||
major, minor uint8
|
||||
payload []byte
|
||||
}
|
||||
|
||||
type handshakeMessage interface {
|
||||
marshal() []byte
|
||||
unmarshal([]byte) bool
|
||||
}
|
||||
|
||||
// TODO(jsing): Make these available to both crypto/x509 and crypto/tls.
|
||||
type dsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
type ecdsaSignature dsaSignature
|
||||
|
||||
var emptyConfig Config
|
||||
|
||||
func defaultConfig() *Config {
|
||||
return &emptyConfig
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
varDefaultCipherSuites []uint16
|
||||
)
|
||||
|
||||
func defaultCipherSuites() []uint16 {
|
||||
once.Do(initDefaultCipherSuites)
|
||||
return varDefaultCipherSuites
|
||||
}
|
||||
|
||||
func initDefaultCipherSuites() {
|
||||
varDefaultCipherSuites = make([]uint16, len(cipherSuites))
|
||||
for i, suite := range cipherSuites {
|
||||
varDefaultCipherSuites[i] = suite.id
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
411
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/handshake_client.go
generated
vendored
Normal file
411
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/handshake_client.go
generated
vendored
Normal file
|
@ -0,0 +1,411 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (c *Conn) clientHandshake() error {
|
||||
if c.config == nil {
|
||||
c.config = defaultConfig()
|
||||
}
|
||||
|
||||
hello := &clientHelloMsg{
|
||||
vers: c.config.maxVersion(),
|
||||
compressionMethods: []uint8{compressionNone},
|
||||
random: make([]byte, 32),
|
||||
ocspStapling: true,
|
||||
serverName: c.config.ServerName,
|
||||
supportedCurves: []uint16{curveP256, curveP384, curveP521},
|
||||
supportedPoints: []uint8{pointFormatUncompressed},
|
||||
nextProtoNeg: len(c.config.NextProtos) > 0,
|
||||
}
|
||||
|
||||
possibleCipherSuites := c.config.cipherSuites()
|
||||
hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
|
||||
|
||||
NextCipherSuite:
|
||||
for _, suiteId := range possibleCipherSuites {
|
||||
for _, suite := range cipherSuites {
|
||||
if suite.id != suiteId {
|
||||
continue
|
||||
}
|
||||
// Don't advertise TLS 1.2-only cipher suites unless
|
||||
// we're attempting TLS 1.2.
|
||||
if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
|
||||
continue
|
||||
}
|
||||
hello.cipherSuites = append(hello.cipherSuites, suiteId)
|
||||
continue NextCipherSuite
|
||||
}
|
||||
}
|
||||
|
||||
t := uint32(c.config.time().Unix())
|
||||
hello.random[0] = byte(t >> 24)
|
||||
hello.random[1] = byte(t >> 16)
|
||||
hello.random[2] = byte(t >> 8)
|
||||
hello.random[3] = byte(t)
|
||||
_, err := io.ReadFull(c.config.rand(), hello.random[4:])
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("short read from Rand")
|
||||
}
|
||||
|
||||
if hello.vers >= VersionTLS12 {
|
||||
hello.signatureAndHashes = supportedSKXSignatureAlgorithms
|
||||
}
|
||||
|
||||
c.writeRecord(recordTypeHandshake, hello.marshal())
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverHello, ok := msg.(*serverHelloMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
vers, ok := c.config.mutualVersion(serverHello.vers)
|
||||
if !ok || vers < VersionTLS10 {
|
||||
// TLS 1.0 is the minimum version supported as a client.
|
||||
return c.sendAlert(alertProtocolVersion)
|
||||
}
|
||||
c.vers = vers
|
||||
c.haveVers = true
|
||||
|
||||
finishedHash := newFinishedHash(c.vers)
|
||||
finishedHash.Write(hello.marshal())
|
||||
finishedHash.Write(serverHello.marshal())
|
||||
|
||||
if serverHello.compressionMethod != compressionNone {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
if !hello.nextProtoNeg && serverHello.nextProtoNeg {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("server advertised unrequested NPN")
|
||||
}
|
||||
|
||||
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
|
||||
if suite == nil {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certMsg, ok := msg.(*certificateMsg)
|
||||
if !ok || len(certMsg.certificates) == 0 {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(certMsg.marshal())
|
||||
|
||||
certs := make([]*x509.Certificate, len(certMsg.certificates))
|
||||
for i, asn1Data := range certMsg.certificates {
|
||||
cert, err := x509.ParseCertificate(asn1Data)
|
||||
if err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||
}
|
||||
certs[i] = cert
|
||||
}
|
||||
|
||||
if !c.config.InsecureSkipVerify {
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: c.config.RootCAs,
|
||||
CurrentTime: c.config.time(),
|
||||
DNSName: c.config.ServerName,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
for i, cert := range certs {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
c.verifiedChains, err = certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch certs[0].PublicKey.(type) {
|
||||
case *rsa.PublicKey, *ecdsa.PublicKey:
|
||||
break
|
||||
default:
|
||||
return c.sendAlert(alertUnsupportedCertificate)
|
||||
}
|
||||
|
||||
c.peerCertificates = certs
|
||||
|
||||
if serverHello.ocspStapling {
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs, ok := msg.(*certificateStatusMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(cs.marshal())
|
||||
|
||||
if cs.statusType == statusTypeOCSP {
|
||||
c.ocspResponse = cs.response
|
||||
}
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyAgreement := suite.ka(c.vers)
|
||||
|
||||
skx, ok := msg.(*serverKeyExchangeMsg)
|
||||
if ok {
|
||||
finishedHash.Write(skx.marshal())
|
||||
err = keyAgreement.processServerKeyExchange(c.config, hello, serverHello, certs[0], skx)
|
||||
if err != nil {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var chainToSend *Certificate
|
||||
var certRequested bool
|
||||
certReq, ok := msg.(*certificateRequestMsg)
|
||||
if ok {
|
||||
certRequested = true
|
||||
|
||||
// RFC 4346 on the certificateAuthorities field:
|
||||
// A list of the distinguished names of acceptable certificate
|
||||
// authorities. These distinguished names may specify a desired
|
||||
// distinguished name for a root CA or for a subordinate CA;
|
||||
// thus, this message can be used to describe both known roots
|
||||
// and a desired authorization space. If the
|
||||
// certificate_authorities list is empty then the client MAY
|
||||
// send any certificate of the appropriate
|
||||
// ClientCertificateType, unless there is some external
|
||||
// arrangement to the contrary.
|
||||
|
||||
finishedHash.Write(certReq.marshal())
|
||||
|
||||
var rsaAvail, ecdsaAvail bool
|
||||
for _, certType := range certReq.certificateTypes {
|
||||
switch certType {
|
||||
case certTypeRSASign:
|
||||
rsaAvail = true
|
||||
case certTypeECDSASign:
|
||||
ecdsaAvail = true
|
||||
}
|
||||
}
|
||||
|
||||
// We need to search our list of client certs for one
|
||||
// where SignatureAlgorithm is RSA and the Issuer is in
|
||||
// certReq.certificateAuthorities
|
||||
findCert:
|
||||
for i, chain := range c.config.Certificates {
|
||||
if !rsaAvail && !ecdsaAvail {
|
||||
continue
|
||||
}
|
||||
|
||||
for j, cert := range chain.Certificate {
|
||||
x509Cert := chain.Leaf
|
||||
// parse the certificate if this isn't the leaf
|
||||
// node, or if chain.Leaf was nil
|
||||
if j != 0 || x509Cert == nil {
|
||||
if x509Cert, err = x509.ParseCertificate(cert); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
|
||||
case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
|
||||
default:
|
||||
continue findCert
|
||||
}
|
||||
|
||||
if len(certReq.certificateAuthorities) == 0 {
|
||||
// they gave us an empty list, so just take the
|
||||
// first RSA cert from c.config.Certificates
|
||||
chainToSend = &chain
|
||||
break findCert
|
||||
}
|
||||
|
||||
for _, ca := range certReq.certificateAuthorities {
|
||||
if bytes.Equal(x509Cert.RawIssuer, ca) {
|
||||
chainToSend = &chain
|
||||
break findCert
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
shd, ok := msg.(*serverHelloDoneMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(shd.marshal())
|
||||
|
||||
// If the server requested a certificate then we have to send a
|
||||
// Certificate message, even if it's empty because we don't have a
|
||||
// certificate to send.
|
||||
if certRequested {
|
||||
certMsg = new(certificateMsg)
|
||||
if chainToSend != nil {
|
||||
certMsg.certificates = chainToSend.Certificate
|
||||
}
|
||||
finishedHash.Write(certMsg.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certMsg.marshal())
|
||||
}
|
||||
|
||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hello, certs[0])
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
if ckx != nil {
|
||||
finishedHash.Write(ckx.marshal())
|
||||
c.writeRecord(recordTypeHandshake, ckx.marshal())
|
||||
}
|
||||
|
||||
if chainToSend != nil {
|
||||
var signed []byte
|
||||
certVerify := &certificateVerifyMsg{
|
||||
hasSignatureAndHash: c.vers >= VersionTLS12,
|
||||
}
|
||||
|
||||
switch key := c.config.Certificates[0].PrivateKey.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
digest, _, hashId := finishedHash.hashForClientCertificate(signatureECDSA)
|
||||
r, s, err := ecdsa.Sign(c.config.rand(), key, digest)
|
||||
if err == nil {
|
||||
signed, err = asn1.Marshal(ecdsaSignature{r, s})
|
||||
}
|
||||
certVerify.signatureAndHash.signature = signatureECDSA
|
||||
certVerify.signatureAndHash.hash = hashId
|
||||
case *rsa.PrivateKey:
|
||||
digest, hashFunc, hashId := finishedHash.hashForClientCertificate(signatureRSA)
|
||||
signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest)
|
||||
certVerify.signatureAndHash.signature = signatureRSA
|
||||
certVerify.signatureAndHash.hash = hashId
|
||||
default:
|
||||
err = errors.New("unknown private key type")
|
||||
}
|
||||
if err != nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
certVerify.signature = signed
|
||||
|
||||
finishedHash.Write(certVerify.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certVerify.marshal())
|
||||
}
|
||||
|
||||
masterSecret := masterFromPreMasterSecret(c.vers, preMasterSecret, hello.random, serverHello.random)
|
||||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||
keysFromMasterSecret(c.vers, masterSecret, hello.random, serverHello.random, suite.macLen, suite.keyLen, suite.ivLen)
|
||||
|
||||
var clientCipher interface{}
|
||||
var clientHash macFunction
|
||||
if suite.cipher != nil {
|
||||
clientCipher = suite.cipher(clientKey, clientIV, false /* not for reading */)
|
||||
clientHash = suite.mac(c.vers, clientMAC)
|
||||
} else {
|
||||
clientCipher = suite.aead(clientKey, clientIV)
|
||||
}
|
||||
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
|
||||
if serverHello.nextProtoNeg {
|
||||
nextProto := new(nextProtoMsg)
|
||||
proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos)
|
||||
nextProto.proto = proto
|
||||
c.clientProtocol = proto
|
||||
c.clientProtocolFallback = fallback
|
||||
|
||||
finishedHash.Write(nextProto.marshal())
|
||||
c.writeRecord(recordTypeHandshake, nextProto.marshal())
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = finishedHash.clientSum(masterSecret)
|
||||
finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
|
||||
var serverCipher interface{}
|
||||
var serverHash macFunction
|
||||
if suite.cipher != nil {
|
||||
serverCipher = suite.cipher(serverKey, serverIV, true /* for reading */)
|
||||
serverHash = suite.mac(c.vers, serverMAC)
|
||||
} else {
|
||||
serverCipher = suite.aead(serverKey, serverIV)
|
||||
}
|
||||
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
|
||||
c.readRecord(recordTypeChangeCipherSpec)
|
||||
if err := c.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverFinished, ok := msg.(*finishedMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
verify := finishedHash.serverSum(masterSecret)
|
||||
if len(verify) != len(serverFinished.verifyData) ||
|
||||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
c.cipherSuite = suite.id
|
||||
return nil
|
||||
}
|
||||
|
||||
// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the
|
||||
// set of client and server supported protocols. The set of client supported
|
||||
// protocols must not be empty. It returns the resulting protocol and flag
|
||||
// indicating if the fallback case was reached.
|
||||
func mutualProtocol(clientProtos, serverProtos []string) (string, bool) {
|
||||
for _, s := range serverProtos {
|
||||
for _, c := range clientProtos {
|
||||
if s == c {
|
||||
return s, false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clientProtos[0], true
|
||||
}
|
1304
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/handshake_messages.go
generated
vendored
Normal file
1304
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/handshake_messages.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
638
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/handshake_server.go
generated
vendored
Normal file
638
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/handshake_server.go
generated
vendored
Normal file
|
@ -0,0 +1,638 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// serverHandshakeState contains details of a server handshake in progress.
|
||||
// It's discarded once the handshake has completed.
|
||||
type serverHandshakeState struct {
|
||||
c *Conn
|
||||
clientHello *clientHelloMsg
|
||||
hello *serverHelloMsg
|
||||
suite *cipherSuite
|
||||
ellipticOk bool
|
||||
ecdsaOk bool
|
||||
sessionState *sessionState
|
||||
finishedHash finishedHash
|
||||
masterSecret []byte
|
||||
certsFromClient [][]byte
|
||||
cert *Certificate
|
||||
}
|
||||
|
||||
// serverHandshake performs a TLS handshake as a server.
|
||||
func (c *Conn) serverHandshake() error {
|
||||
config := c.config
|
||||
|
||||
// If this is the first server handshake, we generate a random key to
|
||||
// encrypt the tickets with.
|
||||
config.serverInitOnce.Do(config.serverInit)
|
||||
|
||||
hs := serverHandshakeState{
|
||||
c: c,
|
||||
}
|
||||
isResume, err := hs.readClientHello()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
|
||||
if isResume {
|
||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
||||
if err := hs.doResumeHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.didResume = true
|
||||
} else {
|
||||
// The client didn't include a session ticket, or it wasn't
|
||||
// valid so we do a full handshake.
|
||||
if err := hs.doFullHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.handshakeComplete = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readClientHello reads a ClientHello message from the client and decides
|
||||
// whether we will perform session resumption.
|
||||
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
|
||||
config := hs.c.config
|
||||
c := hs.c
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var ok bool
|
||||
hs.clientHello, ok = msg.(*clientHelloMsg)
|
||||
if !ok {
|
||||
return false, c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
c.vers, ok = config.mutualVersion(hs.clientHello.vers)
|
||||
if !ok {
|
||||
return false, c.sendAlert(alertProtocolVersion)
|
||||
}
|
||||
c.haveVers = true
|
||||
|
||||
hs.finishedHash = newFinishedHash(c.vers)
|
||||
hs.finishedHash.Write(hs.clientHello.marshal())
|
||||
|
||||
hs.hello = new(serverHelloMsg)
|
||||
|
||||
supportedCurve := false
|
||||
Curves:
|
||||
for _, curve := range hs.clientHello.supportedCurves {
|
||||
switch curve {
|
||||
case curveP256, curveP384, curveP521:
|
||||
supportedCurve = true
|
||||
break Curves
|
||||
}
|
||||
}
|
||||
|
||||
supportedPointFormat := false
|
||||
for _, pointFormat := range hs.clientHello.supportedPoints {
|
||||
if pointFormat == pointFormatUncompressed {
|
||||
supportedPointFormat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
hs.ellipticOk = supportedCurve && supportedPointFormat
|
||||
|
||||
foundCompression := false
|
||||
// We only support null compression, so check that the client offered it.
|
||||
for _, compression := range hs.clientHello.compressionMethods {
|
||||
if compression == compressionNone {
|
||||
foundCompression = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundCompression {
|
||||
return false, c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
hs.hello.vers = c.vers
|
||||
t := uint32(config.time().Unix())
|
||||
hs.hello.random = make([]byte, 32)
|
||||
hs.hello.random[0] = byte(t >> 24)
|
||||
hs.hello.random[1] = byte(t >> 16)
|
||||
hs.hello.random[2] = byte(t >> 8)
|
||||
hs.hello.random[3] = byte(t)
|
||||
_, err = io.ReadFull(config.rand(), hs.hello.random[4:])
|
||||
if err != nil {
|
||||
return false, c.sendAlert(alertInternalError)
|
||||
}
|
||||
hs.hello.compressionMethod = compressionNone
|
||||
if len(hs.clientHello.serverName) > 0 {
|
||||
c.serverName = hs.clientHello.serverName
|
||||
}
|
||||
// Although sending an empty NPN extension is reasonable, Firefox has
|
||||
// had a bug around this. Best to send nothing at all if
|
||||
// config.NextProtos is empty. See
|
||||
// https://code.google.com/p/go/issues/detail?id=5445.
|
||||
if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
|
||||
hs.hello.nextProtoNeg = true
|
||||
hs.hello.nextProtos = config.NextProtos
|
||||
}
|
||||
|
||||
if len(config.Certificates) == 0 {
|
||||
return false, c.sendAlert(alertInternalError)
|
||||
}
|
||||
hs.cert = &config.Certificates[0]
|
||||
if len(hs.clientHello.serverName) > 0 {
|
||||
hs.cert = config.getCertificateForName(hs.clientHello.serverName)
|
||||
}
|
||||
|
||||
_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
|
||||
|
||||
if hs.checkForResumption() {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var preferenceList, supportedList []uint16
|
||||
if c.config.PreferServerCipherSuites {
|
||||
preferenceList = c.config.cipherSuites()
|
||||
supportedList = hs.clientHello.cipherSuites
|
||||
} else {
|
||||
preferenceList = hs.clientHello.cipherSuites
|
||||
supportedList = c.config.cipherSuites()
|
||||
}
|
||||
|
||||
for _, id := range preferenceList {
|
||||
if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, hs.ellipticOk, hs.ecdsaOk); hs.suite != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hs.suite == nil {
|
||||
return false, c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// checkForResumption returns true if we should perform resumption on this connection.
|
||||
func (hs *serverHandshakeState) checkForResumption() bool {
|
||||
c := hs.c
|
||||
|
||||
var ok bool
|
||||
if hs.sessionState, ok = c.decryptTicket(hs.clientHello.sessionTicket); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if hs.sessionState.vers > hs.clientHello.vers {
|
||||
return false
|
||||
}
|
||||
if vers, ok := c.config.mutualVersion(hs.sessionState.vers); !ok || vers != hs.sessionState.vers {
|
||||
return false
|
||||
}
|
||||
|
||||
cipherSuiteOk := false
|
||||
// Check that the client is still offering the ciphersuite in the session.
|
||||
for _, id := range hs.clientHello.cipherSuites {
|
||||
if id == hs.sessionState.cipherSuite {
|
||||
cipherSuiteOk = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !cipherSuiteOk {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that we also support the ciphersuite from the session.
|
||||
hs.suite = c.tryCipherSuite(hs.sessionState.cipherSuite, c.config.cipherSuites(), hs.sessionState.vers, hs.ellipticOk, hs.ecdsaOk)
|
||||
if hs.suite == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
|
||||
needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert
|
||||
if needClientCerts && !sessionHasClientCerts {
|
||||
return false
|
||||
}
|
||||
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) doResumeHandshake() error {
|
||||
c := hs.c
|
||||
|
||||
hs.hello.cipherSuite = hs.suite.id
|
||||
// We echo the client's session ID in the ServerHello to let it know
|
||||
// that we're doing a resumption.
|
||||
hs.hello.sessionId = hs.clientHello.sessionId
|
||||
hs.finishedHash.Write(hs.hello.marshal())
|
||||
c.writeRecord(recordTypeHandshake, hs.hello.marshal())
|
||||
|
||||
if len(hs.sessionState.certificates) > 0 {
|
||||
if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.masterSecret = hs.sessionState.masterSecret
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) doFullHandshake() error {
|
||||
config := hs.c.config
|
||||
c := hs.c
|
||||
|
||||
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
|
||||
hs.hello.ocspStapling = true
|
||||
}
|
||||
|
||||
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled
|
||||
hs.hello.cipherSuite = hs.suite.id
|
||||
hs.finishedHash.Write(hs.hello.marshal())
|
||||
c.writeRecord(recordTypeHandshake, hs.hello.marshal())
|
||||
|
||||
certMsg := new(certificateMsg)
|
||||
certMsg.certificates = hs.cert.Certificate
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certMsg.marshal())
|
||||
|
||||
if hs.hello.ocspStapling {
|
||||
certStatus := new(certificateStatusMsg)
|
||||
certStatus.statusType = statusTypeOCSP
|
||||
certStatus.response = hs.cert.OCSPStaple
|
||||
hs.finishedHash.Write(certStatus.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certStatus.marshal())
|
||||
}
|
||||
|
||||
keyAgreement := hs.suite.ka(c.vers)
|
||||
skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello)
|
||||
if err != nil {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return err
|
||||
}
|
||||
if skx != nil {
|
||||
hs.finishedHash.Write(skx.marshal())
|
||||
c.writeRecord(recordTypeHandshake, skx.marshal())
|
||||
}
|
||||
|
||||
if config.ClientAuth >= RequestClientCert {
|
||||
// Request a client certificate
|
||||
certReq := new(certificateRequestMsg)
|
||||
certReq.certificateTypes = []byte{
|
||||
byte(certTypeRSASign),
|
||||
byte(certTypeECDSASign),
|
||||
}
|
||||
if c.vers >= VersionTLS12 {
|
||||
certReq.hasSignatureAndHash = true
|
||||
certReq.signatureAndHashes = supportedClientCertSignatureAlgorithms
|
||||
}
|
||||
|
||||
// An empty list of certificateAuthorities signals to
|
||||
// the client that it may send any certificate in response
|
||||
// to our request. When we know the CAs we trust, then
|
||||
// we can send them down, so that the client can choose
|
||||
// an appropriate certificate to give to us.
|
||||
if config.ClientCAs != nil {
|
||||
certReq.certificateAuthorities = config.ClientCAs.Subjects()
|
||||
}
|
||||
hs.finishedHash.Write(certReq.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certReq.marshal())
|
||||
}
|
||||
|
||||
helloDone := new(serverHelloDoneMsg)
|
||||
hs.finishedHash.Write(helloDone.marshal())
|
||||
c.writeRecord(recordTypeHandshake, helloDone.marshal())
|
||||
|
||||
var pub crypto.PublicKey // public key for client auth, if any
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ok bool
|
||||
// If we requested a client certificate, then the client must send a
|
||||
// certificate message, even if it's empty.
|
||||
if config.ClientAuth >= RequestClientCert {
|
||||
if certMsg, ok = msg.(*certificateMsg); !ok {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
|
||||
if len(certMsg.certificates) == 0 {
|
||||
// The client didn't actually send a certificate
|
||||
switch config.ClientAuth {
|
||||
case RequireAnyClientCert, RequireAndVerifyClientCert:
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: client didn't provide a certificate")
|
||||
}
|
||||
}
|
||||
|
||||
pub, err = hs.processCertsFromClient(certMsg.certificates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get client key exchange
|
||||
ckx, ok := msg.(*clientKeyExchangeMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
hs.finishedHash.Write(ckx.marshal())
|
||||
|
||||
// If we received a client cert in response to our certificate request message,
|
||||
// the client will send us a certificateVerifyMsg immediately after the
|
||||
// clientKeyExchangeMsg. This message is a digest of all preceding
|
||||
// handshake-layer messages that is signed using the private key corresponding
|
||||
// to the client's certificate. This allows us to verify that the client is in
|
||||
// possession of the private key of the certificate.
|
||||
if len(c.peerCertificates) > 0 {
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
switch key := pub.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
ecdsaSig := new(ecdsaSignature)
|
||||
if _, err = asn1.Unmarshal(certVerify.signature, ecdsaSig); err != nil {
|
||||
break
|
||||
}
|
||||
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||
err = errors.New("ECDSA signature contained zero or negative values")
|
||||
break
|
||||
}
|
||||
digest, _, _ := hs.finishedHash.hashForClientCertificate(signatureECDSA)
|
||||
if !ecdsa.Verify(key, digest, ecdsaSig.R, ecdsaSig.S) {
|
||||
err = errors.New("ECDSA verification failure")
|
||||
break
|
||||
}
|
||||
case *rsa.PublicKey:
|
||||
digest, hashFunc, _ := hs.finishedHash.hashForClientCertificate(signatureRSA)
|
||||
err = rsa.VerifyPKCS1v15(key, hashFunc, digest, certVerify.signature)
|
||||
}
|
||||
if err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("could not validate signature of connection nonces: " + err.Error())
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(certVerify.marshal())
|
||||
}
|
||||
|
||||
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers)
|
||||
if err != nil {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return err
|
||||
}
|
||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, preMasterSecret, hs.clientHello.random, hs.hello.random)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) establishKeys() error {
|
||||
c := hs.c
|
||||
|
||||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||
keysFromMasterSecret(c.vers, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
||||
|
||||
var clientCipher, serverCipher interface{}
|
||||
var clientHash, serverHash macFunction
|
||||
|
||||
if hs.suite.aead == nil {
|
||||
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
|
||||
clientHash = hs.suite.mac(c.vers, clientMAC)
|
||||
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
|
||||
serverHash = hs.suite.mac(c.vers, serverMAC)
|
||||
} else {
|
||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
||||
}
|
||||
|
||||
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) readFinished() error {
|
||||
c := hs.c
|
||||
|
||||
c.readRecord(recordTypeChangeCipherSpec)
|
||||
if err := c.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hs.hello.nextProtoNeg {
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextProto, ok := msg.(*nextProtoMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
hs.finishedHash.Write(nextProto.marshal())
|
||||
c.clientProtocol = nextProto.proto
|
||||
}
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientFinished, ok := msg.(*finishedMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
verify := hs.finishedHash.clientSum(hs.masterSecret)
|
||||
if len(verify) != len(clientFinished.verifyData) ||
|
||||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(clientFinished.marshal())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) sendSessionTicket() error {
|
||||
if !hs.hello.ticketSupported {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := hs.c
|
||||
m := new(newSessionTicketMsg)
|
||||
|
||||
var err error
|
||||
state := sessionState{
|
||||
vers: c.vers,
|
||||
cipherSuite: hs.suite.id,
|
||||
masterSecret: hs.masterSecret,
|
||||
certificates: hs.certsFromClient,
|
||||
}
|
||||
m.ticket, err = c.encryptTicket(&state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(m.marshal())
|
||||
c.writeRecord(recordTypeHandshake, m.marshal())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) sendFinished() error {
|
||||
c := hs.c
|
||||
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
|
||||
hs.finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
|
||||
c.cipherSuite = hs.suite.id
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processCertsFromClient takes a chain of client certificates either from a
|
||||
// Certificates message or from a sessionState and verifies them. It returns
|
||||
// the public key of the leaf certificate.
|
||||
func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (crypto.PublicKey, error) {
|
||||
c := hs.c
|
||||
|
||||
hs.certsFromClient = certificates
|
||||
certs := make([]*x509.Certificate, len(certificates))
|
||||
var err error
|
||||
for i, asn1Data := range certificates {
|
||||
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return nil, errors.New("tls: failed to parse client certificate: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: c.config.ClientCAs,
|
||||
CurrentTime: c.config.time(),
|
||||
Intermediates: x509.NewCertPool(),
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
chains, err := certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return nil, errors.New("tls: failed to verify client's certificate: " + err.Error())
|
||||
}
|
||||
|
||||
ok := false
|
||||
for _, ku := range certs[0].ExtKeyUsage {
|
||||
if ku == x509.ExtKeyUsageClientAuth {
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return nil, errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication")
|
||||
}
|
||||
|
||||
c.verifiedChains = chains
|
||||
}
|
||||
|
||||
if len(certs) > 0 {
|
||||
var pub crypto.PublicKey
|
||||
switch key := certs[0].PublicKey.(type) {
|
||||
case *ecdsa.PublicKey, *rsa.PublicKey:
|
||||
pub = key
|
||||
default:
|
||||
return nil, c.sendAlert(alertUnsupportedCertificate)
|
||||
}
|
||||
c.peerCertificates = certs
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// tryCipherSuite returns a cipherSuite with the given id if that cipher suite
|
||||
// is acceptable to use.
|
||||
func (c *Conn) tryCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16, ellipticOk, ecdsaOk bool) *cipherSuite {
|
||||
for _, supported := range supportedCipherSuites {
|
||||
if id == supported {
|
||||
var candidate *cipherSuite
|
||||
|
||||
for _, s := range cipherSuites {
|
||||
if s.id == id {
|
||||
candidate = s
|
||||
break
|
||||
}
|
||||
}
|
||||
if candidate == nil {
|
||||
continue
|
||||
}
|
||||
// Don't select a ciphersuite which we can't
|
||||
// support for this client.
|
||||
if (candidate.flags&suiteECDHE != 0) && !ellipticOk {
|
||||
continue
|
||||
}
|
||||
if (candidate.flags&suiteECDSA != 0) != ecdsaOk {
|
||||
continue
|
||||
}
|
||||
if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 {
|
||||
continue
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
400
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/key_agreement.go
generated
vendored
Normal file
400
vendor/github.com/masterzen/azure-sdk-for-go/core/tls/key_agreement.go
generated
vendored
Normal file
|
@ -0,0 +1,400 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/md5"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// rsaKeyAgreement implements the standard TLS key agreement where the client
|
||||
// encrypts the pre-master secret to the server's public key.
|
||||
type rsaKeyAgreement struct{}
|
||||
|
||||
func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
||||
preMasterSecret := make([]byte, 48)
|
||||
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(ckx.ciphertext) < 2 {
|
||||
return nil, errors.New("bad ClientKeyExchange")
|
||||
}
|
||||
|
||||
ciphertext := ckx.ciphertext
|
||||
if version != VersionSSL30 {
|
||||
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
|
||||
if ciphertextLen != len(ckx.ciphertext)-2 {
|
||||
return nil, errors.New("bad ClientKeyExchange")
|
||||
}
|
||||
ciphertext = ckx.ciphertext[2:]
|
||||
}
|
||||
|
||||
err = rsa.DecryptPKCS1v15SessionKey(config.rand(), cert.PrivateKey.(*rsa.PrivateKey), ciphertext, preMasterSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// We don't check the version number in the premaster secret. For one,
|
||||
// by checking it, we would leak information about the validity of the
|
||||
// encrypted pre-master secret. Secondly, it provides only a small
|
||||
// benefit against a downgrade attack and some implementations send the
|
||||
// wrong version anyway. See the discussion at the end of section
|
||||
// 7.4.7.1 of RFC 4346.
|
||||
return preMasterSecret, nil
|
||||
}
|
||||
|
||||
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
||||
return errors.New("unexpected ServerKeyExchange")
|
||||
}
|
||||
|
||||
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
||||
preMasterSecret := make([]byte, 48)
|
||||
preMasterSecret[0] = byte(clientHello.vers >> 8)
|
||||
preMasterSecret[1] = byte(clientHello.vers)
|
||||
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ckx := new(clientKeyExchangeMsg)
|
||||
ckx.ciphertext = make([]byte, len(encrypted)+2)
|
||||
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
|
||||
ckx.ciphertext[1] = byte(len(encrypted))
|
||||
copy(ckx.ciphertext[2:], encrypted)
|
||||
return preMasterSecret, ckx, nil
|
||||
}
|
||||
|
||||
// sha1Hash calculates a SHA1 hash over the given byte slices.
|
||||
func sha1Hash(slices [][]byte) []byte {
|
||||
hsha1 := sha1.New()
|
||||
for _, slice := range slices {
|
||||
hsha1.Write(slice)
|
||||
}
|
||||
return hsha1.Sum(nil)
|
||||
}
|
||||
|
||||
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
|
||||
// concatenation of an MD5 and SHA1 hash.
|
||||
func md5SHA1Hash(slices [][]byte) []byte {
|
||||
md5sha1 := make([]byte, md5.Size+sha1.Size)
|
||||
hmd5 := md5.New()
|
||||
for _, slice := range slices {
|
||||
hmd5.Write(slice)
|
||||
}
|
||||
copy(md5sha1, hmd5.Sum(nil))
|
||||
copy(md5sha1[md5.Size:], sha1Hash(slices))
|
||||
return md5sha1
|
||||
}
|
||||
|
||||
// sha256Hash implements TLS 1.2's hash function.
|
||||
func sha256Hash(slices [][]byte) []byte {
|
||||
h := sha256.New()
|
||||
for _, slice := range slices {
|
||||
h.Write(slice)
|
||||
}
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// hashForServerKeyExchange hashes the given slices and returns their digest
|
||||
// and the identifier of the hash function used. The hashFunc argument is only
|
||||
// used for >= TLS 1.2 and precisely identifies the hash function to use.
|
||||
func hashForServerKeyExchange(sigType, hashFunc uint8, version uint16, slices ...[]byte) ([]byte, crypto.Hash, error) {
|
||||
if version >= VersionTLS12 {
|
||||
switch hashFunc {
|
||||
case hashSHA256:
|
||||
return sha256Hash(slices), crypto.SHA256, nil
|
||||
case hashSHA1:
|
||||
return sha1Hash(slices), crypto.SHA1, nil
|
||||
default:
|
||||
return nil, crypto.Hash(0), errors.New("tls: unknown hash function used by peer")
|
||||
}
|
||||
}
|
||||
if sigType == signatureECDSA {
|
||||
return sha1Hash(slices), crypto.SHA1, nil
|
||||
}
|
||||
return md5SHA1Hash(slices), crypto.MD5SHA1, nil
|
||||
}
|
||||
|
||||
// pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a
|
||||
// ServerKeyExchange given the signature type being used and the client's
|
||||
// advertized list of supported signature and hash combinations.
|
||||
func pickTLS12HashForSignature(sigType uint8, clientSignatureAndHashes []signatureAndHash) (uint8, error) {
|
||||
if len(clientSignatureAndHashes) == 0 {
|
||||
// If the client didn't specify any signature_algorithms
|
||||
// extension then we can assume that it supports SHA1. See
|
||||
// http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
|
||||
return hashSHA1, nil
|
||||
}
|
||||
|
||||
for _, sigAndHash := range clientSignatureAndHashes {
|
||||
if sigAndHash.signature != sigType {
|
||||
continue
|
||||
}
|
||||
switch sigAndHash.hash {
|
||||
case hashSHA1, hashSHA256:
|
||||
return sigAndHash.hash, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, errors.New("tls: client doesn't support any common hash functions")
|
||||
}
|
||||
|
||||
// ecdheRSAKeyAgreement implements a TLS key agreement where the server
|
||||
// generates a ephemeral EC public/private key pair and signs it. The
|
||||
// pre-master secret is then calculated using ECDH. The signature may
|
||||
// either be ECDSA or RSA.
|
||||
type ecdheKeyAgreement struct {
|
||||
version uint16
|
||||
sigType uint8
|
||||
privateKey []byte
|
||||
curve elliptic.Curve
|
||||
x, y *big.Int
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
|
||||
var curveid uint16
|
||||
|
||||
Curve:
|
||||
for _, c := range clientHello.supportedCurves {
|
||||
switch c {
|
||||
case curveP256:
|
||||
ka.curve = elliptic.P256()
|
||||
curveid = c
|
||||
break Curve
|
||||
case curveP384:
|
||||
ka.curve = elliptic.P384()
|
||||
curveid = c
|
||||
break Curve
|
||||
case curveP521:
|
||||
ka.curve = elliptic.P521()
|
||||
curveid = c
|
||||
break Curve
|
||||
}
|
||||
}
|
||||
|
||||
if curveid == 0 {
|
||||
return nil, errors.New("tls: no supported elliptic curves offered")
|
||||
}
|
||||
|
||||
var x, y *big.Int
|
||||
var err error
|
||||
ka.privateKey, x, y, err = elliptic.GenerateKey(ka.curve, config.rand())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecdhePublic := elliptic.Marshal(ka.curve, x, y)
|
||||
|
||||
// http://tools.ietf.org/html/rfc4492#section-5.4
|
||||
serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic))
|
||||
serverECDHParams[0] = 3 // named curve
|
||||
serverECDHParams[1] = byte(curveid >> 8)
|
||||
serverECDHParams[2] = byte(curveid)
|
||||
serverECDHParams[3] = byte(len(ecdhePublic))
|
||||
copy(serverECDHParams[4:], ecdhePublic)
|
||||
|
||||
var tls12HashId uint8
|
||||
if ka.version >= VersionTLS12 {
|
||||
if tls12HashId, err = pickTLS12HashForSignature(ka.sigType, clientHello.signatureAndHashes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, tls12HashId, ka.version, clientHello.random, hello.random, serverECDHParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sig []byte
|
||||
switch ka.sigType {
|
||||
case signatureECDSA:
|
||||
privKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("ECDHE ECDSA requires an ECDSA server private key")
|
||||
}
|
||||
r, s, err := ecdsa.Sign(config.rand(), privKey, digest)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to sign ECDHE parameters: " + err.Error())
|
||||
}
|
||||
sig, err = asn1.Marshal(ecdsaSignature{r, s})
|
||||
case signatureRSA:
|
||||
privKey, ok := cert.PrivateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("ECDHE RSA requires a RSA server private key")
|
||||
}
|
||||
sig, err = rsa.SignPKCS1v15(config.rand(), privKey, hashFunc, digest)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to sign ECDHE parameters: " + err.Error())
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unknown ECDHE signature algorithm")
|
||||
}
|
||||
|
||||
skx := new(serverKeyExchangeMsg)
|
||||
sigAndHashLen := 0
|
||||
if ka.version >= VersionTLS12 {
|
||||
sigAndHashLen = 2
|
||||
}
|
||||
skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
|
||||
copy(skx.key, serverECDHParams)
|
||||
k := skx.key[len(serverECDHParams):]
|
||||
if ka.version >= VersionTLS12 {
|
||||
k[0] = tls12HashId
|
||||
k[1] = ka.sigType
|
||||
k = k[2:]
|
||||
}
|
||||
k[0] = byte(len(sig) >> 8)
|
||||
k[1] = byte(len(sig))
|
||||
copy(k[2:], sig)
|
||||
|
||||
return skx, nil
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
||||
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
|
||||
return nil, errors.New("bad ClientKeyExchange")
|
||||
}
|
||||
x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:])
|
||||
if x == nil {
|
||||
return nil, errors.New("bad ClientKeyExchange")
|
||||
}
|
||||
x, _ = ka.curve.ScalarMult(x, y, ka.privateKey)
|
||||
preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3)
|
||||
xBytes := x.Bytes()
|
||||
copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
|
||||
|
||||
return preMasterSecret, nil
|
||||
}
|
||||
|
||||
var errServerKeyExchange = errors.New("invalid ServerKeyExchange")
|
||||
|
||||
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
||||
if len(skx.key) < 4 {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
if skx.key[0] != 3 { // named curve
|
||||
return errors.New("server selected unsupported curve")
|
||||
}
|
||||
curveid := uint16(skx.key[1])<<8 | uint16(skx.key[2])
|
||||
|
||||
switch curveid {
|
||||
case curveP256:
|
||||
ka.curve = elliptic.P256()
|
||||
case curveP384:
|
||||
ka.curve = elliptic.P384()
|
||||
case curveP521:
|
||||
ka.curve = elliptic.P521()
|
||||
default:
|
||||
return errors.New("server selected unsupported curve")
|
||||
}
|
||||
|
||||
publicLen := int(skx.key[3])
|
||||
if publicLen+4 > len(skx.key) {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
ka.x, ka.y = elliptic.Unmarshal(ka.curve, skx.key[4:4+publicLen])
|
||||
if ka.x == nil {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
serverECDHParams := skx.key[:4+publicLen]
|
||||
|
||||
sig := skx.key[4+publicLen:]
|
||||
if len(sig) < 2 {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
|
||||
var tls12HashId uint8
|
||||
if ka.version >= VersionTLS12 {
|
||||
// handle SignatureAndHashAlgorithm
|
||||
var sigAndHash []uint8
|
||||
sigAndHash, sig = sig[:2], sig[2:]
|
||||
if sigAndHash[1] != ka.sigType {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
tls12HashId = sigAndHash[0]
|
||||
if len(sig) < 2 {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
}
|
||||
sigLen := int(sig[0])<<8 | int(sig[1])
|
||||
if sigLen+2 != len(sig) {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
sig = sig[2:]
|
||||
|
||||
digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, tls12HashId, ka.version, clientHello.random, serverHello.random, serverECDHParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch ka.sigType {
|
||||
case signatureECDSA:
|
||||
pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.New("ECDHE ECDSA requires a ECDSA server public key")
|
||||
}
|
||||
ecdsaSig := new(ecdsaSignature)
|
||||
if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
|
||||
return err
|
||||
}
|
||||
if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
|
||||
return errors.New("ECDSA signature contained zero or negative values")
|
||||
}
|
||||
if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) {
|
||||
return errors.New("ECDSA verification failure")
|
||||
}
|
||||
case signatureRSA:
|
||||
pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.New("ECDHE RSA requires a RSA server public key")
|
||||
}
|
||||
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("unknown ECDHE signature algorithm")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
||||
if ka.curve == nil {
|
||||
return nil, nil, errors.New("missing ServerKeyExchange message")
|
||||
}
|
||||
priv, mx, my, err := elliptic.GenerateKey(ka.curve, config.rand())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
x, _ := ka.curve.ScalarMult(ka.x, ka.y, priv)
|
||||
preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3)
|
||||
xBytes := x.Bytes()
|
||||
copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
|
||||
|
||||
serialized := elliptic.Marshal(ka.curve, mx, my)
|
||||
|
||||
ckx := new(clientKeyExchangeMsg)
|
||||
ckx.ciphertext = make([]byte, 1+len(serialized))
|
||||
ckx.ciphertext[0] = byte(len(serialized))
|
||||
copy(ckx.ciphertext[1:], serialized)
|
||||
|
||||
return preMasterSecret, ckx, nil
|
||||
}
|
|
@ -0,0 +1,291 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// Split a premaster secret in two as specified in RFC 4346, section 5.
|
||||
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
|
||||
s1 = secret[0 : (len(secret)+1)/2]
|
||||
s2 = secret[len(secret)/2:]
|
||||
return
|
||||
}
|
||||
|
||||
// pHash implements the P_hash function, as defined in RFC 4346, section 5.
|
||||
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
|
||||
h := hmac.New(hash, secret)
|
||||
h.Write(seed)
|
||||
a := h.Sum(nil)
|
||||
|
||||
j := 0
|
||||
for j < len(result) {
|
||||
h.Reset()
|
||||
h.Write(a)
|
||||
h.Write(seed)
|
||||
b := h.Sum(nil)
|
||||
todo := len(b)
|
||||
if j+todo > len(result) {
|
||||
todo = len(result) - j
|
||||
}
|
||||
copy(result[j:j+todo], b)
|
||||
j += todo
|
||||
|
||||
h.Reset()
|
||||
h.Write(a)
|
||||
a = h.Sum(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5.
|
||||
func prf10(result, secret, label, seed []byte) {
|
||||
hashSHA1 := sha1.New
|
||||
hashMD5 := md5.New
|
||||
|
||||
labelAndSeed := make([]byte, len(label)+len(seed))
|
||||
copy(labelAndSeed, label)
|
||||
copy(labelAndSeed[len(label):], seed)
|
||||
|
||||
s1, s2 := splitPreMasterSecret(secret)
|
||||
pHash(result, s1, labelAndSeed, hashMD5)
|
||||
result2 := make([]byte, len(result))
|
||||
pHash(result2, s2, labelAndSeed, hashSHA1)
|
||||
|
||||
for i, b := range result2 {
|
||||
result[i] ^= b
|
||||
}
|
||||
}
|
||||
|
||||
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5.
|
||||
func prf12(result, secret, label, seed []byte) {
|
||||
labelAndSeed := make([]byte, len(label)+len(seed))
|
||||
copy(labelAndSeed, label)
|
||||
copy(labelAndSeed[len(label):], seed)
|
||||
|
||||
pHash(result, secret, labelAndSeed, sha256.New)
|
||||
}
|
||||
|
||||
// prf30 implements the SSL 3.0 pseudo-random function, as defined in
|
||||
// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 6.
|
||||
func prf30(result, secret, label, seed []byte) {
|
||||
hashSHA1 := sha1.New()
|
||||
hashMD5 := md5.New()
|
||||
|
||||
done := 0
|
||||
i := 0
|
||||
// RFC5246 section 6.3 says that the largest PRF output needed is 128
|
||||
// bytes. Since no more ciphersuites will be added to SSLv3, this will
|
||||
// remain true. Each iteration gives us 16 bytes so 10 iterations will
|
||||
// be sufficient.
|
||||
var b [11]byte
|
||||
for done < len(result) {
|
||||
for j := 0; j <= i; j++ {
|
||||
b[j] = 'A' + byte(i)
|
||||
}
|
||||
|
||||
hashSHA1.Reset()
|
||||
hashSHA1.Write(b[:i+1])
|
||||
hashSHA1.Write(secret)
|
||||
hashSHA1.Write(seed)
|
||||
digest := hashSHA1.Sum(nil)
|
||||
|
||||
hashMD5.Reset()
|
||||
hashMD5.Write(secret)
|
||||
hashMD5.Write(digest)
|
||||
|
||||
done += copy(result[done:], hashMD5.Sum(nil))
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
tlsRandomLength = 32 // Length of a random nonce in TLS 1.1.
|
||||
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
|
||||
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
|
||||
)
|
||||
|
||||
var masterSecretLabel = []byte("master secret")
|
||||
var keyExpansionLabel = []byte("key expansion")
|
||||
var clientFinishedLabel = []byte("client finished")
|
||||
var serverFinishedLabel = []byte("server finished")
|
||||
|
||||
func prfForVersion(version uint16) func(result, secret, label, seed []byte) {
|
||||
switch version {
|
||||
case VersionSSL30:
|
||||
return prf30
|
||||
case VersionTLS10, VersionTLS11:
|
||||
return prf10
|
||||
case VersionTLS12:
|
||||
return prf12
|
||||
default:
|
||||
panic("unknown version")
|
||||
}
|
||||
}
|
||||
|
||||
// masterFromPreMasterSecret generates the master secret from the pre-master
|
||||
// secret. See http://tools.ietf.org/html/rfc5246#section-8.1
|
||||
func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, serverRandom []byte) []byte {
|
||||
var seed [tlsRandomLength * 2]byte
|
||||
copy(seed[0:len(clientRandom)], clientRandom)
|
||||
copy(seed[len(clientRandom):], serverRandom)
|
||||
masterSecret := make([]byte, masterSecretLength)
|
||||
prfForVersion(version)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:])
|
||||
return masterSecret
|
||||
}
|
||||
|
||||
// keysFromMasterSecret generates the connection keys from the master
|
||||
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
|
||||
// RFC 2246, section 6.3.
|
||||
func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
|
||||
var seed [tlsRandomLength * 2]byte
|
||||
copy(seed[0:len(clientRandom)], serverRandom)
|
||||
copy(seed[len(serverRandom):], clientRandom)
|
||||
|
||||
n := 2*macLen + 2*keyLen + 2*ivLen
|
||||
keyMaterial := make([]byte, n)
|
||||
prfForVersion(version)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:])
|
||||
clientMAC = keyMaterial[:macLen]
|
||||
keyMaterial = keyMaterial[macLen:]
|
||||
serverMAC = keyMaterial[:macLen]
|
||||
keyMaterial = keyMaterial[macLen:]
|
||||
clientKey = keyMaterial[:keyLen]
|
||||
keyMaterial = keyMaterial[keyLen:]
|
||||
serverKey = keyMaterial[:keyLen]
|
||||
keyMaterial = keyMaterial[keyLen:]
|
||||
clientIV = keyMaterial[:ivLen]
|
||||
keyMaterial = keyMaterial[ivLen:]
|
||||
serverIV = keyMaterial[:ivLen]
|
||||
return
|
||||
}
|
||||
|
||||
func newFinishedHash(version uint16) finishedHash {
|
||||
if version >= VersionTLS12 {
|
||||
return finishedHash{sha256.New(), sha256.New(), nil, nil, version}
|
||||
}
|
||||
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), version}
|
||||
}
|
||||
|
||||
// A finishedHash calculates the hash of a set of handshake messages suitable
|
||||
// for including in a Finished message.
|
||||
type finishedHash struct {
|
||||
client hash.Hash
|
||||
server hash.Hash
|
||||
|
||||
// Prior to TLS 1.2, an additional MD5 hash is required.
|
||||
clientMD5 hash.Hash
|
||||
serverMD5 hash.Hash
|
||||
|
||||
version uint16
|
||||
}
|
||||
|
||||
func (h finishedHash) Write(msg []byte) (n int, err error) {
|
||||
h.client.Write(msg)
|
||||
h.server.Write(msg)
|
||||
|
||||
if h.version < VersionTLS12 {
|
||||
h.clientMD5.Write(msg)
|
||||
h.serverMD5.Write(msg)
|
||||
}
|
||||
return len(msg), nil
|
||||
}
|
||||
|
||||
// finishedSum30 calculates the contents of the verify_data member of a SSLv3
|
||||
// Finished message given the MD5 and SHA1 hashes of a set of handshake
|
||||
// messages.
|
||||
func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic [4]byte) []byte {
|
||||
md5.Write(magic[:])
|
||||
md5.Write(masterSecret)
|
||||
md5.Write(ssl30Pad1[:])
|
||||
md5Digest := md5.Sum(nil)
|
||||
|
||||
md5.Reset()
|
||||
md5.Write(masterSecret)
|
||||
md5.Write(ssl30Pad2[:])
|
||||
md5.Write(md5Digest)
|
||||
md5Digest = md5.Sum(nil)
|
||||
|
||||
sha1.Write(magic[:])
|
||||
sha1.Write(masterSecret)
|
||||
sha1.Write(ssl30Pad1[:40])
|
||||
sha1Digest := sha1.Sum(nil)
|
||||
|
||||
sha1.Reset()
|
||||
sha1.Write(masterSecret)
|
||||
sha1.Write(ssl30Pad2[:40])
|
||||
sha1.Write(sha1Digest)
|
||||
sha1Digest = sha1.Sum(nil)
|
||||
|
||||
ret := make([]byte, len(md5Digest)+len(sha1Digest))
|
||||
copy(ret, md5Digest)
|
||||
copy(ret[len(md5Digest):], sha1Digest)
|
||||
return ret
|
||||
}
|
||||
|
||||
var ssl3ClientFinishedMagic = [4]byte{0x43, 0x4c, 0x4e, 0x54}
|
||||
var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52}
|
||||
|
||||
// clientSum returns the contents of the verify_data member of a client's
|
||||
// Finished message.
|
||||
func (h finishedHash) clientSum(masterSecret []byte) []byte {
|
||||
if h.version == VersionSSL30 {
|
||||
return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic)
|
||||
}
|
||||
|
||||
out := make([]byte, finishedVerifyLength)
|
||||
if h.version >= VersionTLS12 {
|
||||
seed := h.client.Sum(nil)
|
||||
prf12(out, masterSecret, clientFinishedLabel, seed)
|
||||
} else {
|
||||
seed := make([]byte, 0, md5.Size+sha1.Size)
|
||||
seed = h.clientMD5.Sum(seed)
|
||||
seed = h.client.Sum(seed)
|
||||
prf10(out, masterSecret, clientFinishedLabel, seed)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// serverSum returns the contents of the verify_data member of a server's
|
||||
// Finished message.
|
||||
func (h finishedHash) serverSum(masterSecret []byte) []byte {
|
||||
if h.version == VersionSSL30 {
|
||||
return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic)
|
||||
}
|
||||
|
||||
out := make([]byte, finishedVerifyLength)
|
||||
if h.version >= VersionTLS12 {
|
||||
seed := h.server.Sum(nil)
|
||||
prf12(out, masterSecret, serverFinishedLabel, seed)
|
||||
} else {
|
||||
seed := make([]byte, 0, md5.Size+sha1.Size)
|
||||
seed = h.serverMD5.Sum(seed)
|
||||
seed = h.server.Sum(seed)
|
||||
prf10(out, masterSecret, serverFinishedLabel, seed)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash
|
||||
// id suitable for signing by a TLS client certificate.
|
||||
func (h finishedHash) hashForClientCertificate(sigType uint8) ([]byte, crypto.Hash, uint8) {
|
||||
if h.version >= VersionTLS12 {
|
||||
digest := h.server.Sum(nil)
|
||||
return digest, crypto.SHA256, hashSHA256
|
||||
}
|
||||
if sigType == signatureECDSA {
|
||||
digest := h.server.Sum(nil)
|
||||
return digest, crypto.SHA1, hashSHA1
|
||||
}
|
||||
|
||||
digest := make([]byte, 0, 36)
|
||||
digest = h.serverMD5.Sum(digest)
|
||||
digest = h.server.Sum(digest)
|
||||
return digest, crypto.MD5SHA1, 0 /* not specified in TLS 1.2. */
|
||||
}
|
|
@ -0,0 +1,182 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// sessionState contains the information that is serialized into a session
|
||||
// ticket in order to later resume a connection.
|
||||
type sessionState struct {
|
||||
vers uint16
|
||||
cipherSuite uint16
|
||||
masterSecret []byte
|
||||
certificates [][]byte
|
||||
}
|
||||
|
||||
func (s *sessionState) equal(i interface{}) bool {
|
||||
s1, ok := i.(*sessionState)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.vers != s1.vers ||
|
||||
s.cipherSuite != s1.cipherSuite ||
|
||||
!bytes.Equal(s.masterSecret, s1.masterSecret) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.certificates) != len(s1.certificates) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range s.certificates {
|
||||
if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *sessionState) marshal() []byte {
|
||||
length := 2 + 2 + 2 + len(s.masterSecret) + 2
|
||||
for _, cert := range s.certificates {
|
||||
length += 4 + len(cert)
|
||||
}
|
||||
|
||||
ret := make([]byte, length)
|
||||
x := ret
|
||||
x[0] = byte(s.vers >> 8)
|
||||
x[1] = byte(s.vers)
|
||||
x[2] = byte(s.cipherSuite >> 8)
|
||||
x[3] = byte(s.cipherSuite)
|
||||
x[4] = byte(len(s.masterSecret) >> 8)
|
||||
x[5] = byte(len(s.masterSecret))
|
||||
x = x[6:]
|
||||
copy(x, s.masterSecret)
|
||||
x = x[len(s.masterSecret):]
|
||||
|
||||
x[0] = byte(len(s.certificates) >> 8)
|
||||
x[1] = byte(len(s.certificates))
|
||||
x = x[2:]
|
||||
|
||||
for _, cert := range s.certificates {
|
||||
x[0] = byte(len(cert) >> 24)
|
||||
x[1] = byte(len(cert) >> 16)
|
||||
x[2] = byte(len(cert) >> 8)
|
||||
x[3] = byte(len(cert))
|
||||
copy(x[4:], cert)
|
||||
x = x[4+len(cert):]
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (s *sessionState) unmarshal(data []byte) bool {
|
||||
if len(data) < 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.vers = uint16(data[0])<<8 | uint16(data[1])
|
||||
s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
|
||||
masterSecretLen := int(data[4])<<8 | int(data[5])
|
||||
data = data[6:]
|
||||
if len(data) < masterSecretLen {
|
||||
return false
|
||||
}
|
||||
|
||||
s.masterSecret = data[:masterSecretLen]
|
||||
data = data[masterSecretLen:]
|
||||
|
||||
if len(data) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
numCerts := int(data[0])<<8 | int(data[1])
|
||||
data = data[2:]
|
||||
|
||||
s.certificates = make([][]byte, numCerts)
|
||||
for i := range s.certificates {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
if certLen < 0 {
|
||||
return false
|
||||
}
|
||||
if len(data) < certLen {
|
||||
return false
|
||||
}
|
||||
s.certificates[i] = data[:certLen]
|
||||
data = data[certLen:]
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
|
||||
serialized := state.marshal()
|
||||
encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
|
||||
iv := encrypted[:aes.BlockSize]
|
||||
macBytes := encrypted[len(encrypted)-sha256.Size:]
|
||||
|
||||
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
|
||||
if err != nil {
|
||||
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
|
||||
}
|
||||
cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
|
||||
|
||||
mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
|
||||
mac.Write(encrypted[:len(encrypted)-sha256.Size])
|
||||
mac.Sum(macBytes[:0])
|
||||
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
|
||||
if len(encrypted) < aes.BlockSize+sha256.Size {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
iv := encrypted[:aes.BlockSize]
|
||||
macBytes := encrypted[len(encrypted)-sha256.Size:]
|
||||
|
||||
mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
|
||||
mac.Write(encrypted[:len(encrypted)-sha256.Size])
|
||||
expected := mac.Sum(nil)
|
||||
|
||||
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
|
||||
plaintext := ciphertext
|
||||
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
|
||||
|
||||
state := new(sessionState)
|
||||
ok := state.unmarshal(plaintext)
|
||||
return state, ok
|
||||
}
|
|
@ -0,0 +1,225 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tls partially implements TLS 1.2, as specified in RFC 5246.
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Server returns a new TLS server side connection
|
||||
// using conn as the underlying transport.
|
||||
// The configuration config must be non-nil and must have
|
||||
// at least one certificate.
|
||||
func Server(conn net.Conn, config *Config) *Conn {
|
||||
return &Conn{conn: conn, config: config}
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection
|
||||
// using conn as the underlying transport.
|
||||
// Client interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Client(conn net.Conn, config *Config) *Conn {
|
||||
return &Conn{conn: conn, config: config, isClient: true}
|
||||
}
|
||||
|
||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||
type listener struct {
|
||||
net.Listener
|
||||
config *Config
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
// The returned connection c is a *tls.Conn.
|
||||
func (l *listener) Accept() (c net.Conn, err error) {
|
||||
c, err = l.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c = Server(c, l.config)
|
||||
return
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must have
|
||||
// at least one certificate.
|
||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||
l := new(listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
return l
|
||||
}
|
||||
|
||||
// Listen creates a TLS listener accepting connections on the
|
||||
// given network address using net.Listen.
|
||||
// The configuration config must be non-nil and must have
|
||||
// at least one certificate.
|
||||
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||
if config == nil || len(config.Certificates) == 0 {
|
||||
return nil, errors.New("tls.Listen: no certificates in configuration")
|
||||
}
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config), nil
|
||||
}
|
||||
|
||||
// Dial connects to the given network address using net.Dial
|
||||
// and then initiates a TLS handshake, returning the resulting
|
||||
// TLS connection.
|
||||
// Dial interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||
raddr := addr
|
||||
c, err := net.Dial(network, raddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colonPos := strings.LastIndex(raddr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(raddr)
|
||||
}
|
||||
hostname := raddr[:colonPos]
|
||||
|
||||
if config == nil {
|
||||
config = defaultConfig()
|
||||
}
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
c := *config
|
||||
c.ServerName = hostname
|
||||
config = &c
|
||||
}
|
||||
conn := Client(c, config)
|
||||
if err = conn.Handshake(); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
|
||||
// files. The files must contain PEM encoded data.
|
||||
func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
|
||||
certPEMBlock, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keyPEMBlock, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
}
|
||||
|
||||
// X509KeyPair parses a public/private key pair from a pair of
|
||||
// PEM encoded data.
|
||||
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error) {
|
||||
var certDERBlock *pem.Block
|
||||
for {
|
||||
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
|
||||
if certDERBlock == nil {
|
||||
break
|
||||
}
|
||||
if certDERBlock.Type == "CERTIFICATE" {
|
||||
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cert.Certificate) == 0 {
|
||||
err = errors.New("crypto/tls: failed to parse certificate PEM data")
|
||||
return
|
||||
}
|
||||
|
||||
var keyDERBlock *pem.Block
|
||||
for {
|
||||
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
|
||||
if keyDERBlock == nil {
|
||||
err = errors.New("crypto/tls: failed to parse key PEM data")
|
||||
return
|
||||
}
|
||||
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// We don't need to parse the public key for TLS, but we so do anyway
|
||||
// to check that it looks sane and matches the private key.
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch pub := x509Cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
err = errors.New("crypto/tls: private key type does not match public key type")
|
||||
return
|
||||
}
|
||||
if pub.N.Cmp(priv.N) != 0 {
|
||||
err = errors.New("crypto/tls: private key does not match public key")
|
||||
return
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
err = errors.New("crypto/tls: private key type does not match public key type")
|
||||
return
|
||||
|
||||
}
|
||||
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
|
||||
err = errors.New("crypto/tls: private key does not match public key")
|
||||
return
|
||||
}
|
||||
default:
|
||||
err = errors.New("crypto/tls: unknown public key algorithm")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
|
||||
// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
|
||||
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
|
||||
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
|
||||
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
|
||||
return key, nil
|
||||
}
|
||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
||||
switch key := key.(type) {
|
||||
case *rsa.PrivateKey, *ecdsa.PrivateKey:
|
||||
return key, nil
|
||||
default:
|
||||
return nil, errors.New("crypto/tls: found unknown private key type in PKCS#8 wrapping")
|
||||
}
|
||||
}
|
||||
if key, err := x509.ParseECPrivateKey(der); err == nil {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("crypto/tls: failed to parse private key")
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
NO_COLOR=\033[0m
|
||||
OK_COLOR=\033[32;01m
|
||||
ERROR_COLOR=\033[31;01m
|
||||
WARN_COLOR=\033[33;01m
|
||||
DEPS = $(go list -f '{{range .TestImports}}{{.}} {{end}}' ./... | fgrep -v 'winrm')
|
||||
|
||||
all: deps
|
||||
@mkdir -p bin/
|
||||
@printf "$(OK_COLOR)==> Building$(NO_COLOR)\n"
|
||||
@go build github.com/masterzen/winrm
|
||||
|
||||
deps:
|
||||
@printf "$(OK_COLOR)==> Installing dependencies$(NO_COLOR)\n"
|
||||
@go get -d -v ./...
|
||||
@echo $(DEPS) | xargs -n1 go get -d
|
||||
|
||||
updatedeps:
|
||||
go list ./... | xargs go list -f '{{join .Deps "\n"}}' | grep -v github.com/masterzen/winrm | sort -u | xargs go get -f -u -v
|
||||
|
||||
clean:
|
||||
@rm -rf bin/ pkg/ src/
|
||||
|
||||
format:
|
||||
go fmt ./...
|
||||
|
||||
ci: deps
|
||||
@printf "$(OK_COLOR)==> Testing with Coveralls...$(NO_COLOR)\n"
|
||||
"$(CURDIR)/scripts/test.sh"
|
||||
|
||||
test: deps
|
||||
@printf "$(OK_COLOR)==> Testing...$(NO_COLOR)\n"
|
||||
go test ./...
|
||||
|
||||
.PHONY: all clean deps format test updatedeps
|
|
@ -0,0 +1,223 @@
|
|||
# WinRM for Go
|
||||
|
||||
_Note_: if you're looking for the `winrm` command-line tool, this has been splitted from this project and is available at [winrm-cli](https://github.com/masterzen/winrm-cli)
|
||||
|
||||
This is a Go library to execute remote commands on Windows machines through
|
||||
the use of WinRM/WinRS.
|
||||
|
||||
_Note_: this library doesn't support domain users (it doesn't support GSSAPI nor Kerberos). It's primary target is to execute remote commands on EC2 windows machines.
|
||||
|
||||
[![Build Status](https://travis-ci.org/masterzen/winrm.svg?branch=master)](https://travis-ci.org/masterzen/winrm)
|
||||
[![Coverage Status](https://coveralls.io/repos/masterzen/winrm/badge.png)](https://coveralls.io/r/masterzen/winrm)
|
||||
|
||||
## Contact
|
||||
|
||||
- Bugs: https://github.com/masterzen/winrm/issues
|
||||
|
||||
|
||||
## Getting Started
|
||||
WinRM is available on Windows Server 2008 and up. This project natively supports basic authentication for local accounts, see the steps in the next section on how to prepare the remote Windows machine for this scenario. The authentication model is pluggable, see below for an example on using Negotiate/NTLM authentication (e.g. for connecting to vanilla Azure VMs).
|
||||
|
||||
### Preparing the remote Windows machine for Basic authentication
|
||||
This project supports only basic authentication for local accounts (domain users are not supported). The remote windows system must be prepared for winrm:
|
||||
|
||||
_For a PowerShell script to do what is described below in one go, check [Richard Downer's blog](http://www.frontiertown.co.uk/2011/12/overthere-control-windows-from-java/)_
|
||||
|
||||
On the remote host, a PowerShell prompt, using the __Run as Administrator__ option and paste in the following lines:
|
||||
|
||||
winrm quickconfig
|
||||
y
|
||||
winrm set winrm/config/service/Auth '@{Basic="true"}'
|
||||
winrm set winrm/config/service '@{AllowUnencrypted="true"}'
|
||||
winrm set winrm/config/winrs '@{MaxMemoryPerShellMB="1024"}'
|
||||
|
||||
__N.B.:__ The Windows Firewall needs to be running to run this command. See [Microsoft Knowledge Base article #2004640](http://support.microsoft.com/kb/2004640).
|
||||
|
||||
__N.B.:__ Do not disable Negotiate authentication as the `winrm` command itself uses this for internal authentication, and you risk getting a system where `winrm` doesn't work anymore.
|
||||
|
||||
__N.B.:__ The `MaxMemoryPerShellMB` option has no effects on some Windows 2008R2 systems because of a WinRM bug. Make sure to install the hotfix described [Microsoft Knowledge Base article #2842230](http://support.microsoft.com/kb/2842230) if you need to run commands that uses more than 150MB of memory.
|
||||
|
||||
For more information on WinRM, please refer to <a href="http://msdn.microsoft.com/en-us/library/windows/desktop/aa384426(v=vs.85).aspx">the online documentation at Microsoft's DevCenter</a>.
|
||||
|
||||
### Building the winrm go and executable
|
||||
|
||||
You can build winrm from source:
|
||||
|
||||
```sh
|
||||
git clone https://github.com/masterzen/winrm
|
||||
cd winrm
|
||||
make
|
||||
```
|
||||
|
||||
_Note_: this winrm code doesn't depend anymore on [Gokogiri](https://github.com/moovweb/gokogiri) which means it is now in pure Go.
|
||||
|
||||
_Note_: you need go 1.5+. Please check your installation with
|
||||
|
||||
```
|
||||
go version
|
||||
```
|
||||
|
||||
## Command-line usage
|
||||
|
||||
For command-line usage check the [winrm-cli project](https://github.com/masterzen/winrm-cli)
|
||||
|
||||
## Library Usage
|
||||
|
||||
**Warning the API might be subject to change.**
|
||||
|
||||
For the fast version (this doesn't allow to send input to the command) and it's using HTTP as the transport:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/masterzen/winrm"
|
||||
"os"
|
||||
)
|
||||
|
||||
endpoint := winrm.NewEndpoint(host, 5986, false, false, nil, nil, nil, 0)
|
||||
client, err := winrm.NewClient(endpoint, "Administrator", "secret")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
client.Run("ipconfig /all", os.Stdout, os.Stderr)
|
||||
```
|
||||
|
||||
or
|
||||
```go
|
||||
package main
|
||||
import (
|
||||
"github.com/masterzen/winrm"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
endpoint := winrm.NewEndpoint("localhost", 5985, false, false, nil, nil, nil, 0)
|
||||
client, err := winrm.NewClient(endpoint,"Administrator", "secret")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err := client.RunWithInput("ipconfig", os.Stdout, os.Stderr, os.Stdin)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
By passing a TransportDecorator in the Parameters struct it is possible to use different Transports (e.g. NTLM)
|
||||
|
||||
```go
|
||||
package main
|
||||
import (
|
||||
"github.com/masterzen/winrm"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
endpoint := winrm.NewEndpoint("localhost", 5985, false, false, nil, nil, nil, 0)
|
||||
|
||||
params := DefaultParameters
|
||||
params.TransportDecorator = func() Transporter { return &ClientNTLM{} }
|
||||
|
||||
client, err := NewClientWithParameters(endpoint, "test", "test", params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err := client.RunWithInput("ipconfig", os.Stdout, os.Stderr, os.Stdin)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
For a more complex example, it is possible to call the various functions directly:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/masterzen/winrm"
|
||||
"fmt"
|
||||
"bytes"
|
||||
"os"
|
||||
)
|
||||
|
||||
stdin := bytes.NewBufferString("ipconfig /all")
|
||||
endpoint := winrm.NewEndpoint("localhost", 5985, false, false,nil, nil, nil, 0)
|
||||
client , err := winrm.NewClient(endpoint, "Administrator", "secret")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
shell, err := client.CreateShell()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var cmd *winrm.Command
|
||||
cmd, err = shell.Execute("cmd.exe")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go io.Copy(cmd.Stdin, stdin)
|
||||
go io.Copy(os.Stdout, cmd.Stdout)
|
||||
go io.Copy(os.Stderr, cmd.Stderr)
|
||||
|
||||
cmd.Wait()
|
||||
shell.Close()
|
||||
```
|
||||
|
||||
For using HTTPS authentication with x 509 cert without checking the CA
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/masterzen/winrm"
|
||||
"os"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
clientCert, err := ioutil.ReadFile("path/to/cert")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
clientKey, err := ioutil.ReadFile("path/to/key")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
winrm.DefaultParameters.TransportDecorator = func() winrm.Transporter {
|
||||
// winrm https module
|
||||
return &winrm.ClientAuthRequest{}
|
||||
}
|
||||
|
||||
endpoint := winrm.NewEndpoint(host, 5986, false, false, clientCert, clientKey, nil, 0)
|
||||
client, err := winrm.NewClient(endpoint, "Administrator", ""
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
client.Run("ipconfig /all", os.Stdout, os.Stderr)
|
||||
```
|
||||
|
||||
## Developing on WinRM
|
||||
|
||||
If you wish to work on `winrm` itself, you'll first need [Go](http://golang.org)
|
||||
installed (version 1.5+ is _required_). Make sure you have Go properly installed,
|
||||
including setting up your [GOPATH](http://golang.org/doc/code.html#GOPATH).
|
||||
|
||||
For some additional dependencies, Go needs [Mercurial](http://mercurial.selenic.com/)
|
||||
and [Bazaar](http://bazaar.canonical.com/en/) to be installed.
|
||||
Winrm itself doesn't require these, but a dependency of a dependency does.
|
||||
|
||||
Next, clone this repository into `$GOPATH/src/github.com/masterzen/winrm` and
|
||||
then just type `make`.
|
||||
|
||||
You can run tests by typing `make test`.
|
||||
|
||||
If you make any changes to the code, run `make format` in order to automatically
|
||||
format the code according to Go standards.
|
||||
|
||||
When new dependencies are added to winrm you can use `make updatedeps` to
|
||||
get the latest and subsequently use `make` to compile.
|
|
@ -0,0 +1,106 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/masterzen/azure-sdk-for-go/core/http"
|
||||
"github.com/masterzen/azure-sdk-for-go/core/tls"
|
||||
|
||||
"github.com/masterzen/winrm/soap"
|
||||
)
|
||||
|
||||
type ClientAuthRequest struct {
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
func (c *ClientAuthRequest) Transport(endpoint *Endpoint) error {
|
||||
cert, err := tls.X509KeyPair(endpoint.Cert, endpoint.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: endpoint.Insecure,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
},
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
ResponseHeaderTimeout: endpoint.Timeout,
|
||||
}
|
||||
|
||||
if endpoint.CACert != nil && len(endpoint.CACert) > 0 {
|
||||
certPool, err := readCACerts(endpoint.CACert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig.RootCAs = certPool
|
||||
}
|
||||
|
||||
c.transport = transport
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parse func reads the response body and return it as a string
|
||||
func parse(response *http.Response) (string, error) {
|
||||
|
||||
// if we recived the content we expected
|
||||
if strings.Contains(response.Header.Get("Content-Type"), "application/soap+xml") {
|
||||
body, err := ioutil.ReadAll(response.Body)
|
||||
defer func() {
|
||||
// defer can modify the returned value before
|
||||
// it is actually passed to the calling statement
|
||||
if errClose := response.Body.Close(); errClose != nil && err == nil {
|
||||
err = errClose
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error while reading request body %s", err)
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid content type")
|
||||
}
|
||||
|
||||
func (c ClientAuthRequest) Post(client *Client, request *soap.SoapMessage) (string, error) {
|
||||
httpClient := &http.Client{Transport: c.transport}
|
||||
|
||||
req, err := http.NewRequest("POST", client.url, strings.NewReader(request.String()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("impossible to create http request %s", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", soapXML+";charset=UTF-8")
|
||||
req.Header.Set("Authorization", "http://schemas.dmtf.org/wbem/wsman/1/wsman/secprofile/https/mutual")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unknown error %s", err)
|
||||
}
|
||||
|
||||
body, err := parse(resp)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("http response error: %d - %s", resp.StatusCode, err.Error())
|
||||
}
|
||||
|
||||
// if we have different 200 http status code
|
||||
// we must replace the error
|
||||
defer func() {
|
||||
if resp.StatusCode != 200 {
|
||||
body, err = "", fmt.Errorf("http error %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
}()
|
||||
|
||||
return body, err
|
||||
}
|
|
@ -0,0 +1,187 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/masterzen/winrm/soap"
|
||||
)
|
||||
|
||||
// Client struct
|
||||
type Client struct {
|
||||
Parameters
|
||||
username string
|
||||
password string
|
||||
useHTTPS bool
|
||||
url string
|
||||
http Transporter
|
||||
}
|
||||
|
||||
// Transporter does different transporters
|
||||
// and init a Post request based on them
|
||||
type Transporter interface {
|
||||
// init request baset on the transport configurations
|
||||
Post(*Client, *soap.SoapMessage) (string, error)
|
||||
Transport(*Endpoint) error
|
||||
}
|
||||
|
||||
// NewClient will create a new remote client on url, connecting with user and password
|
||||
// This function doesn't connect (connection happens only when CreateShell is called)
|
||||
func NewClient(endpoint *Endpoint, user, password string) (*Client, error) {
|
||||
return NewClientWithParameters(endpoint, user, password, DefaultParameters)
|
||||
}
|
||||
|
||||
// NewClientWithParameters will create a new remote client on url, connecting with user and password
|
||||
// This function doesn't connect (connection happens only when CreateShell is called)
|
||||
func NewClientWithParameters(endpoint *Endpoint, user, password string, params *Parameters) (*Client, error) {
|
||||
|
||||
// alloc a new client
|
||||
client := &Client{
|
||||
Parameters: *params,
|
||||
username: user,
|
||||
password: password,
|
||||
url: endpoint.url(),
|
||||
useHTTPS: endpoint.HTTPS,
|
||||
// default transport
|
||||
http: &clientRequest{},
|
||||
}
|
||||
|
||||
// switch to other transport if provided
|
||||
if params.TransportDecorator != nil {
|
||||
client.http = params.TransportDecorator()
|
||||
}
|
||||
|
||||
// set the transport to some endpoint configuration
|
||||
if err := client.http.Transport(endpoint); err != nil {
|
||||
return nil, fmt.Errorf("Can't parse this key and certs: %s", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func readCACerts(certs []byte) (*x509.CertPool, error) {
|
||||
certPool := x509.NewCertPool()
|
||||
|
||||
if !certPool.AppendCertsFromPEM(certs) {
|
||||
return nil, fmt.Errorf("Unable to read certificates")
|
||||
}
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
// CreateShell will create a WinRM Shell,
|
||||
// which is the prealable for running commands.
|
||||
func (c *Client) CreateShell() (*Shell, error) {
|
||||
request := NewOpenShellRequest(c.url, &c.Parameters)
|
||||
defer request.Free()
|
||||
|
||||
response, err := c.sendRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shellID, err := ParseOpenShellResponse(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.NewShell(shellID), nil
|
||||
|
||||
}
|
||||
|
||||
// NewShell will create a new WinRM Shell for the given shellID
|
||||
func (c *Client) NewShell(id string) *Shell {
|
||||
return &Shell{client: c, id: id}
|
||||
}
|
||||
|
||||
// sendRequest exec the custom http func from the client
|
||||
func (c *Client) sendRequest(request *soap.SoapMessage) (string, error) {
|
||||
return c.http.Post(c, request)
|
||||
}
|
||||
|
||||
// Run will run command on the the remote host, writing the process stdout and stderr to
|
||||
// the given writers. Note with this method it isn't possible to inject stdin.
|
||||
func (c *Client) Run(command string, stdout io.Writer, stderr io.Writer) (int, error) {
|
||||
shell, err := c.CreateShell()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
defer shell.Close()
|
||||
cmd, err := shell.Execute(command)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(stdout, cmd.Stdout)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(stderr, cmd.Stderr)
|
||||
}()
|
||||
|
||||
cmd.Wait()
|
||||
wg.Wait()
|
||||
|
||||
return cmd.ExitCode(), cmd.err
|
||||
}
|
||||
|
||||
// RunWithString will run command on the the remote host, returning the process stdout and stderr
|
||||
// as strings, and using the input stdin string as the process input
|
||||
func (c *Client) RunWithString(command string, stdin string) (string, string, int, error) {
|
||||
shell, err := c.CreateShell()
|
||||
if err != nil {
|
||||
return "", "", 1, err
|
||||
}
|
||||
defer shell.Close()
|
||||
|
||||
cmd, err := shell.Execute(command)
|
||||
if err != nil {
|
||||
return "", "", 1, err
|
||||
}
|
||||
if len(stdin) > 0 {
|
||||
cmd.Stdin.Write([]byte(stdin))
|
||||
}
|
||||
|
||||
var outWriter, errWriter bytes.Buffer
|
||||
go io.Copy(&outWriter, cmd.Stdout)
|
||||
go io.Copy(&errWriter, cmd.Stderr)
|
||||
|
||||
cmd.Wait()
|
||||
|
||||
return outWriter.String(), errWriter.String(), cmd.ExitCode(), cmd.err
|
||||
}
|
||||
|
||||
// RunWithInput will run command on the the remote host, writing the process stdout and stderr to
|
||||
// the given writers, and injecting the process stdin with the stdin reader.
|
||||
// Warning stdin (not stdout/stderr) are bufferized, which means reading only one byte in stdin will
|
||||
// send a winrm http packet to the remote host. If stdin is a pipe, it might be better for
|
||||
// performance reasons to buffer it.
|
||||
func (c Client) RunWithInput(command string, stdout, stderr io.Writer, stdin io.Reader) (int, error) {
|
||||
shell, err := c.CreateShell()
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
defer shell.Close()
|
||||
cmd, err := shell.Execute(command)
|
||||
if err != nil {
|
||||
return 1, err
|
||||
}
|
||||
|
||||
go io.Copy(cmd.Stdin, stdin)
|
||||
go io.Copy(stdout, cmd.Stdout)
|
||||
go io.Copy(stderr, cmd.Stderr)
|
||||
|
||||
cmd.Wait()
|
||||
|
||||
return cmd.ExitCode(), cmd.err
|
||||
|
||||
}
|
|
@ -22,12 +22,12 @@ type commandReader struct {
|
|||
// Command represents a given command running on a Shell. This structure allows to get access
|
||||
// to the various stdout, stderr and stdin pipes.
|
||||
type Command struct {
|
||||
client *Client
|
||||
shell *Shell
|
||||
commandId string
|
||||
exitCode int
|
||||
finished bool
|
||||
err error
|
||||
client *Client
|
||||
shell *Shell
|
||||
id string
|
||||
exitCode int
|
||||
finished bool
|
||||
err error
|
||||
|
||||
Stdin *commandWriter
|
||||
Stdout *commandReader
|
||||
|
@ -37,10 +37,22 @@ type Command struct {
|
|||
cancel chan struct{}
|
||||
}
|
||||
|
||||
func newCommand(shell *Shell, commandId string) *Command {
|
||||
command := &Command{shell: shell, client: shell.client, commandId: commandId, exitCode: 1, err: nil, done: make(chan struct{}), cancel: make(chan struct{})}
|
||||
command.Stdin = &commandWriter{Command: command, eof: false}
|
||||
func newCommand(shell *Shell, ids string) *Command {
|
||||
command := &Command{
|
||||
shell: shell,
|
||||
client: shell.client,
|
||||
id: ids,
|
||||
exitCode: 0,
|
||||
err: nil,
|
||||
done: make(chan struct{}),
|
||||
cancel: make(chan struct{}),
|
||||
}
|
||||
|
||||
command.Stdout = newCommandReader("stdout", command)
|
||||
command.Stdin = &commandWriter{
|
||||
Command: command,
|
||||
eof: false,
|
||||
}
|
||||
command.Stderr = newCommandReader("stderr", command)
|
||||
|
||||
go fetchOutput(command)
|
||||
|
@ -50,7 +62,12 @@ func newCommand(shell *Shell, commandId string) *Command {
|
|||
|
||||
func newCommandReader(stream string, command *Command) *commandReader {
|
||||
read, write := io.Pipe()
|
||||
return &commandReader{Command: command, stream: stream, write: write, read: read}
|
||||
return &commandReader{
|
||||
Command: command,
|
||||
stream: stream,
|
||||
write: write,
|
||||
read: read,
|
||||
}
|
||||
}
|
||||
|
||||
func fetchOutput(command *Command) {
|
||||
|
@ -70,122 +87,132 @@ func fetchOutput(command *Command) {
|
|||
}
|
||||
}
|
||||
|
||||
func (command *Command) check() (err error) {
|
||||
if command.commandId == "" {
|
||||
func (c *Command) check() error {
|
||||
if c.id == "" {
|
||||
return errors.New("Command has already been closed")
|
||||
}
|
||||
if command.shell == nil {
|
||||
if c.shell == nil {
|
||||
return errors.New("Command has no associated shell")
|
||||
}
|
||||
if command.client == nil {
|
||||
if c.client == nil {
|
||||
return errors.New("Command has no associated client")
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close will terminate the running command
|
||||
func (command *Command) Close() (err error) {
|
||||
if err = command.check(); err != nil {
|
||||
func (c *Command) Close() error {
|
||||
if err := c.check(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select { // close cancel channel if it's still open
|
||||
case <-command.cancel:
|
||||
case <-c.cancel:
|
||||
default:
|
||||
close(command.cancel)
|
||||
close(c.cancel)
|
||||
}
|
||||
|
||||
request := NewSignalRequest(command.client.url, command.shell.ShellId, command.commandId, &command.client.Parameters)
|
||||
request := NewSignalRequest(c.client.url, c.shell.id, c.id, &c.client.Parameters)
|
||||
defer request.Free()
|
||||
|
||||
_, err = command.client.sendRequest(request)
|
||||
_, err := c.client.sendRequest(request)
|
||||
return err
|
||||
}
|
||||
|
||||
func (command *Command) slurpAllOutput() (finished bool, err error) {
|
||||
if err = command.check(); err != nil {
|
||||
command.Stderr.write.CloseWithError(err)
|
||||
command.Stdout.write.CloseWithError(err)
|
||||
func (c *Command) slurpAllOutput() (bool, error) {
|
||||
if err := c.check(); err != nil {
|
||||
c.Stderr.write.CloseWithError(err)
|
||||
c.Stdout.write.CloseWithError(err)
|
||||
return true, err
|
||||
}
|
||||
|
||||
request := NewGetOutputRequest(command.client.url, command.shell.ShellId, command.commandId, "stdout stderr", &command.client.Parameters)
|
||||
request := NewGetOutputRequest(c.client.url, c.shell.id, c.id, "stdout stderr", &c.client.Parameters)
|
||||
defer request.Free()
|
||||
|
||||
response, err := command.client.sendRequest(request)
|
||||
response, err := c.client.sendRequest(request)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "OperationTimeout") {
|
||||
// Operation timeout because there was no command output
|
||||
return
|
||||
return false, err
|
||||
}
|
||||
if strings.Contains(err.Error(), "EOF") {
|
||||
c.exitCode = 16001
|
||||
}
|
||||
|
||||
command.Stderr.write.CloseWithError(err)
|
||||
command.Stdout.write.CloseWithError(err)
|
||||
c.Stderr.write.CloseWithError(err)
|
||||
c.Stdout.write.CloseWithError(err)
|
||||
return true, err
|
||||
}
|
||||
|
||||
var exitCode int
|
||||
var stdout, stderr bytes.Buffer
|
||||
finished, exitCode, err = ParseSlurpOutputErrResponse(response, &stdout, &stderr)
|
||||
finished, exitCode, err := ParseSlurpOutputErrResponse(response, &stdout, &stderr)
|
||||
if err != nil {
|
||||
command.Stderr.write.CloseWithError(err)
|
||||
command.Stdout.write.CloseWithError(err)
|
||||
c.Stderr.write.CloseWithError(err)
|
||||
c.Stdout.write.CloseWithError(err)
|
||||
return true, err
|
||||
}
|
||||
if stdout.Len() > 0 {
|
||||
command.Stdout.write.Write(stdout.Bytes())
|
||||
c.Stdout.write.Write(stdout.Bytes())
|
||||
}
|
||||
if stderr.Len() > 0 {
|
||||
command.Stderr.write.Write(stderr.Bytes())
|
||||
c.Stderr.write.Write(stderr.Bytes())
|
||||
}
|
||||
if finished {
|
||||
command.exitCode = exitCode
|
||||
command.Stderr.write.Close()
|
||||
command.Stdout.write.Close()
|
||||
c.exitCode = exitCode
|
||||
c.Stderr.write.Close()
|
||||
c.Stdout.write.Close()
|
||||
}
|
||||
|
||||
return
|
||||
return finished, nil
|
||||
}
|
||||
|
||||
func (command *Command) sendInput(data []byte) (err error) {
|
||||
if err = command.check(); err != nil {
|
||||
func (c *Command) sendInput(data []byte) error {
|
||||
if err := c.check(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request := NewSendInputRequest(command.client.url, command.shell.ShellId, command.commandId, data, &command.client.Parameters)
|
||||
request := NewSendInputRequest(c.client.url, c.shell.id, c.id, data, &c.client.Parameters)
|
||||
defer request.Free()
|
||||
|
||||
_, err = command.client.sendRequest(request)
|
||||
return
|
||||
_, err := c.client.sendRequest(request)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExitCode returns command exit code when it is finished. Before that the result is always 0.
|
||||
func (command *Command) ExitCode() int {
|
||||
return command.exitCode
|
||||
func (c *Command) ExitCode() int {
|
||||
return c.exitCode
|
||||
}
|
||||
|
||||
// Calling this function will block the current goroutine until the remote command terminates.
|
||||
func (command *Command) Wait() {
|
||||
// Wait function will block the current goroutine until the remote command terminates.
|
||||
func (c *Command) Wait() {
|
||||
// block until finished
|
||||
<-command.done
|
||||
<-c.done
|
||||
}
|
||||
|
||||
// Write data to this Pipe
|
||||
func (w *commandWriter) Write(data []byte) (written int, err error) {
|
||||
// commandWriter implements io.Writer interface
|
||||
func (w *commandWriter) Write(data []byte) (int, error) {
|
||||
|
||||
var (
|
||||
written int
|
||||
err error
|
||||
)
|
||||
|
||||
for len(data) > 0 {
|
||||
if w.eof {
|
||||
err = io.EOF
|
||||
return
|
||||
return written, io.EOF
|
||||
}
|
||||
// never send more data than our EnvelopeSize.
|
||||
n := min(w.client.Parameters.EnvelopeSize-1000, len(data))
|
||||
if err = w.sendInput(data[:n]); err != nil {
|
||||
if err := w.sendInput(data[:n]); err != nil {
|
||||
break
|
||||
}
|
||||
data = data[n:]
|
||||
written += int(n)
|
||||
written += n
|
||||
}
|
||||
return
|
||||
|
||||
return written, err
|
||||
}
|
||||
|
||||
func min(a int, b int) int {
|
||||
|
@ -195,6 +222,8 @@ func min(a int, b int) int {
|
|||
return b
|
||||
}
|
||||
|
||||
// Close method wrapper
|
||||
// commandWriter implements io.Closer interface
|
||||
func (w *commandWriter) Close() error {
|
||||
w.eof = true
|
||||
return w.Close()
|
|
@ -0,0 +1,63 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Endpoint struct holds configurations
|
||||
// for the server endpoint
|
||||
type Endpoint struct {
|
||||
// host name or ip address
|
||||
Host string
|
||||
// port to determine if it's http or https default
|
||||
// winrm ports (http:5985, https:5986).Versions
|
||||
// of winrm can be customized to listen on other ports
|
||||
Port int
|
||||
// set the flag true for https connections
|
||||
HTTPS bool
|
||||
// set the flag true for skipping ssl verifications
|
||||
Insecure bool
|
||||
// if set, used to verify the hostname on the returned certificate
|
||||
TLSServerName string
|
||||
// pointer pem certs, and key
|
||||
CACert []byte // cert auth to intdetify the server cert
|
||||
Key []byte // public key for client auth connections
|
||||
Cert []byte // cert for client auth connections
|
||||
// duration timeout for the underling tcp conn(http/https base protocol)
|
||||
// if the time exceeds the connection is cloded/timeouts
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (ep *Endpoint) url() string {
|
||||
var scheme string
|
||||
if ep.HTTPS {
|
||||
scheme = "https"
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s:%d/wsman", scheme, ep.Host, ep.Port)
|
||||
}
|
||||
|
||||
// NewEndpoint returns new pointer to struct Endpoint, with a default 60s response header timeout
|
||||
func NewEndpoint(host string, port int, https bool, insecure bool, Cacert, cert, key []byte, timeout time.Duration) *Endpoint {
|
||||
endpoint := &Endpoint{
|
||||
Host: host,
|
||||
Port: port,
|
||||
HTTPS: https,
|
||||
Insecure: insecure,
|
||||
CACert: Cacert,
|
||||
Key: key,
|
||||
Cert: cert,
|
||||
}
|
||||
// if the timeout was set
|
||||
if timeout != 0 {
|
||||
endpoint.Timeout = timeout
|
||||
} else {
|
||||
// assign default 60sec timeout
|
||||
endpoint.Timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
return endpoint
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package winrm
|
||||
|
||||
import "fmt"
|
||||
|
||||
// errWinrm generic error struct
|
||||
type errWinrm struct {
|
||||
message string
|
||||
}
|
||||
|
||||
// ErrWinrm implements the Error type interface
|
||||
func (e errWinrm) Error() string {
|
||||
return fmt.Sprintf("%s", e.message)
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/masterzen/winrm/soap"
|
||||
)
|
||||
|
||||
var soapXML = "application/soap+xml"
|
||||
|
||||
// body func reads the response body and return it as a string
|
||||
func body(response *http.Response) (string, error) {
|
||||
|
||||
// if we recived the content we expected
|
||||
if strings.Contains(response.Header.Get("Content-Type"), "application/soap+xml") {
|
||||
body, err := ioutil.ReadAll(response.Body)
|
||||
defer func() {
|
||||
// defer can modify the returned value before
|
||||
// it is actually passed to the calling statement
|
||||
if errClose := response.Body.Close(); errClose != nil && err == nil {
|
||||
err = errClose
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error while reading request body %s", err)
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid content type")
|
||||
}
|
||||
|
||||
type clientRequest struct {
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
func (c *clientRequest) Transport(endpoint *Endpoint) error {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: endpoint.Insecure,
|
||||
ServerName: endpoint.TLSServerName,
|
||||
},
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
ResponseHeaderTimeout: endpoint.Timeout,
|
||||
}
|
||||
|
||||
if endpoint.CACert != nil && len(endpoint.CACert) > 0 {
|
||||
certPool, err := readCACerts(endpoint.CACert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig.RootCAs = certPool
|
||||
}
|
||||
|
||||
c.transport = transport
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Post make post to the winrm soap service
|
||||
func (c clientRequest) Post(client *Client, request *soap.SoapMessage) (string, error) {
|
||||
httpClient := &http.Client{Transport: c.transport}
|
||||
|
||||
req, err := http.NewRequest("POST", client.url, strings.NewReader(request.String()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("impossible to create http request %s", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", soapXML+";charset=UTF-8")
|
||||
req.SetBasicAuth(client.username, client.password)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unknown error %s", err)
|
||||
}
|
||||
|
||||
body, err := body(resp)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("http response error: %d - %s", resp.StatusCode, err.Error())
|
||||
}
|
||||
|
||||
// if we have different 200 http status code
|
||||
// we must replace the error
|
||||
defer func() {
|
||||
if resp.StatusCode != 200 {
|
||||
body, err = "", fmt.Errorf("http error %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
}()
|
||||
|
||||
return body, err
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"github.com/Azure/go-ntlmssp"
|
||||
"github.com/masterzen/winrm/soap"
|
||||
)
|
||||
|
||||
// ClientNTLM provides a transport via NTLMv2
|
||||
type ClientNTLM struct {
|
||||
clientRequest
|
||||
}
|
||||
|
||||
// Transport creates the wrapped NTLM transport
|
||||
func (c *ClientNTLM) Transport(endpoint *Endpoint) error {
|
||||
c.clientRequest.Transport(endpoint)
|
||||
c.clientRequest.transport = &ntlmssp.Negotiator{RoundTripper: c.clientRequest.transport}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Post make post to the winrm soap service (forwarded to clientRequest implementation)
|
||||
func (c ClientNTLM) Post(client *Client, request *soap.SoapMessage) (string, error) {
|
||||
return c.clientRequest.Post(client, request)
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package winrm
|
||||
|
||||
// Parameters struct defines
|
||||
// metadata information and http transport config
|
||||
type Parameters struct {
|
||||
Timeout string
|
||||
Locale string
|
||||
EnvelopeSize int
|
||||
TransportDecorator func() Transporter
|
||||
}
|
||||
|
||||
// DefaultParameters return constant config
|
||||
// of type Parameters
|
||||
var DefaultParameters = NewParameters("PT60S", "en-US", 153600)
|
||||
|
||||
// NewParameters return new struct of type Parameters
|
||||
// this struct makes the configuration for the request, size message, etc.
|
||||
func NewParameters(timeout, locale string, envelopeSize int) *Parameters {
|
||||
return &Parameters{
|
||||
Timeout: timeout,
|
||||
Locale: locale,
|
||||
EnvelopeSize: envelopeSize,
|
||||
}
|
||||
}
|
|
@ -5,7 +5,8 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// Wraps a PowerShell script and prepares it for execution by the winrm client
|
||||
// Powershell wraps a PowerShell script
|
||||
// and prepares it for execution by the winrm client
|
||||
func Powershell(psCmd string) string {
|
||||
// 2 byte chars to make PowerShell happy
|
||||
wideCmd := ""
|
|
@ -0,0 +1,155 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/masterzen/winrm/soap"
|
||||
"github.com/nu7hatch/gouuid"
|
||||
)
|
||||
|
||||
func genUUID() string {
|
||||
id, _ := uuid.NewV4()
|
||||
return "uuid:" + id.String()
|
||||
}
|
||||
|
||||
func defaultHeaders(message *soap.SoapMessage, url string, params *Parameters) *soap.SoapHeader {
|
||||
return message.
|
||||
Header().
|
||||
To(url).
|
||||
ReplyTo("http://schemas.xmlsoap.org/ws/2004/08/addressing/role/anonymous").
|
||||
MaxEnvelopeSize(params.EnvelopeSize).
|
||||
Id(genUUID()).
|
||||
Locale(params.Locale).
|
||||
Timeout(params.Timeout)
|
||||
}
|
||||
|
||||
//NewOpenShellRequest makes a new soap request
|
||||
func NewOpenShellRequest(uri string, params *Parameters) *soap.SoapMessage {
|
||||
if params == nil {
|
||||
params = DefaultParameters
|
||||
}
|
||||
|
||||
message := soap.NewMessage()
|
||||
defaultHeaders(message, uri, params).
|
||||
Action("http://schemas.xmlsoap.org/ws/2004/09/transfer/Create").
|
||||
ResourceURI("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd").
|
||||
AddOption(soap.NewHeaderOption("WINRS_NOPROFILE", "FALSE")).
|
||||
AddOption(soap.NewHeaderOption("WINRS_CODEPAGE", "65001")).
|
||||
Build()
|
||||
|
||||
body := message.CreateBodyElement("Shell", soap.DOM_NS_WIN_SHELL)
|
||||
input := message.CreateElement(body, "InputStreams", soap.DOM_NS_WIN_SHELL)
|
||||
input.SetContent("stdin")
|
||||
output := message.CreateElement(body, "OutputStreams", soap.DOM_NS_WIN_SHELL)
|
||||
output.SetContent("stdout stderr")
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
// NewDeleteShellRequest ...
|
||||
func NewDeleteShellRequest(uri, shellId string, params *Parameters) *soap.SoapMessage {
|
||||
if params == nil {
|
||||
params = DefaultParameters
|
||||
}
|
||||
message := soap.NewMessage()
|
||||
defaultHeaders(message, uri, params).
|
||||
Action("http://schemas.xmlsoap.org/ws/2004/09/transfer/Delete").
|
||||
ShellId(shellId).
|
||||
ResourceURI("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd").
|
||||
Build()
|
||||
|
||||
message.NewBody()
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
// NewExecuteCommandRequest exec command on specific shellID
|
||||
func NewExecuteCommandRequest(uri, shellId, command string, arguments []string, params *Parameters) *soap.SoapMessage {
|
||||
if params == nil {
|
||||
params = DefaultParameters
|
||||
}
|
||||
message := soap.NewMessage()
|
||||
defaultHeaders(message, uri, params).
|
||||
Action("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Command").
|
||||
ResourceURI("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd").
|
||||
ShellId(shellId).
|
||||
AddOption(soap.NewHeaderOption("WINRS_CONSOLEMODE_STDIN", "TRUE")).
|
||||
AddOption(soap.NewHeaderOption("WINRS_SKIP_CMD_SHELL", "FALSE")).
|
||||
Build()
|
||||
|
||||
body := message.CreateBodyElement("CommandLine", soap.DOM_NS_WIN_SHELL)
|
||||
|
||||
// ensure special characters like & don't mangle the request XML
|
||||
command = "<![CDATA[" + command + "]]>"
|
||||
commandElement := message.CreateElement(body, "Command", soap.DOM_NS_WIN_SHELL)
|
||||
commandElement.SetContent(command)
|
||||
|
||||
for _, arg := range arguments {
|
||||
arg = "<![CDATA[" + arg + "]]>"
|
||||
argumentsElement := message.CreateElement(body, "Arguments", soap.DOM_NS_WIN_SHELL)
|
||||
argumentsElement.SetContent(arg)
|
||||
}
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
func NewGetOutputRequest(uri, shellId, commandId, streams string, params *Parameters) *soap.SoapMessage {
|
||||
if params == nil {
|
||||
params = DefaultParameters
|
||||
}
|
||||
message := soap.NewMessage()
|
||||
defaultHeaders(message, uri, params).
|
||||
Action("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Receive").
|
||||
ResourceURI("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd").
|
||||
ShellId(shellId).
|
||||
Build()
|
||||
|
||||
receive := message.CreateBodyElement("Receive", soap.DOM_NS_WIN_SHELL)
|
||||
desiredStreams := message.CreateElement(receive, "DesiredStream", soap.DOM_NS_WIN_SHELL)
|
||||
desiredStreams.SetAttr("CommandId", commandId)
|
||||
desiredStreams.SetContent(streams)
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
func NewSendInputRequest(uri, shellId, commandId string, input []byte, params *Parameters) *soap.SoapMessage {
|
||||
if params == nil {
|
||||
params = DefaultParameters
|
||||
}
|
||||
message := soap.NewMessage()
|
||||
|
||||
defaultHeaders(message, uri, params).
|
||||
Action("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Send").
|
||||
ResourceURI("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd").
|
||||
ShellId(shellId).
|
||||
Build()
|
||||
|
||||
content := base64.StdEncoding.EncodeToString(input)
|
||||
|
||||
send := message.CreateBodyElement("Send", soap.DOM_NS_WIN_SHELL)
|
||||
streams := message.CreateElement(send, "Stream", soap.DOM_NS_WIN_SHELL)
|
||||
streams.SetAttr("Name", "stdin")
|
||||
streams.SetAttr("CommandId", commandId)
|
||||
streams.SetContent(content)
|
||||
return message
|
||||
}
|
||||
|
||||
func NewSignalRequest(uri string, shellId string, commandId string, params *Parameters) *soap.SoapMessage {
|
||||
if params == nil {
|
||||
params = DefaultParameters
|
||||
}
|
||||
message := soap.NewMessage()
|
||||
|
||||
defaultHeaders(message, uri, params).
|
||||
Action("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Signal").
|
||||
ResourceURI("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd").
|
||||
ShellId(shellId).
|
||||
Build()
|
||||
|
||||
signal := message.CreateBodyElement("Signal", soap.DOM_NS_WIN_SHELL)
|
||||
signal.SetAttr("CommandId", commandId)
|
||||
code := message.CreateElement(signal, "Code", soap.DOM_NS_WIN_SHELL)
|
||||
code.SetContent("http://schemas.microsoft.com/wbem/wsman/1/windows/shell/signal/terminate")
|
||||
|
||||
return message
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
package winrm
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ChrisTrenkamp/goxpath"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree"
|
||||
"github.com/ChrisTrenkamp/goxpath/tree/xmltree"
|
||||
"github.com/masterzen/winrm/soap"
|
||||
)
|
||||
|
||||
func first(node tree.Node, xpath string) (string, error) {
|
||||
nodes, err := xPath(node, xpath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(nodes) < 1 {
|
||||
return "", err
|
||||
}
|
||||
return nodes[0].ResValue(), nil
|
||||
}
|
||||
|
||||
func any(node tree.Node, xpath string) (bool, error) {
|
||||
nodes, err := xPath(node, xpath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(nodes) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func xPath(node tree.Node, xpath string) (tree.NodeSet, error) {
|
||||
xpExec := goxpath.MustParse(xpath)
|
||||
nodes, err := xpExec.ExecNode(node, soap.GetAllXPathNamespaces())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func ParseOpenShellResponse(response string) (string, error) {
|
||||
doc, err := xmltree.ParseXML(strings.NewReader(response))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return first(doc, "//w:Selector[@Name='ShellId']")
|
||||
}
|
||||
|
||||
func ParseExecuteCommandResponse(response string) (string, error) {
|
||||
doc, err := xmltree.ParseXML(strings.NewReader(response))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return first(doc, "//rsp:CommandId")
|
||||
}
|
||||
|
||||
func ParseSlurpOutputErrResponse(response string, stdout, stderr io.Writer) (bool, int, error) {
|
||||
var (
|
||||
finished bool
|
||||
exitCode int
|
||||
)
|
||||
|
||||
doc, err := xmltree.ParseXML(strings.NewReader(response))
|
||||
|
||||
stdouts, _ := xPath(doc, "//rsp:Stream[@Name='stdout']")
|
||||
for _, node := range stdouts {
|
||||
content, _ := base64.StdEncoding.DecodeString(node.ResValue())
|
||||
stdout.Write(content)
|
||||
}
|
||||
stderrs, _ := xPath(doc, "//rsp:Stream[@Name='stderr']")
|
||||
for _, node := range stderrs {
|
||||
content, _ := base64.StdEncoding.DecodeString(node.ResValue())
|
||||
stderr.Write(content)
|
||||
}
|
||||
|
||||
ended, _ := any(doc, "//*[@State='http://schemas.microsoft.com/wbem/wsman/1/windows/shell/CommandState/Done']")
|
||||
|
||||
if ended {
|
||||
finished = ended
|
||||
if exitBool, _ := any(doc, "//rsp:ExitCode"); exitBool {
|
||||
exit, _ := first(doc, "//rsp:ExitCode")
|
||||
exitCode, _ = strconv.Atoi(exit)
|
||||
}
|
||||
} else {
|
||||
finished = false
|
||||
}
|
||||
|
||||
return finished, exitCode, err
|
||||
}
|
||||
|
||||
func ParseSlurpOutputResponse(response string, stream io.Writer, streamType string) (bool, int, error) {
|
||||
var (
|
||||
finished bool
|
||||
exitCode int
|
||||
)
|
||||
|
||||
doc, err := xmltree.ParseXML(strings.NewReader(response))
|
||||
|
||||
nodes, _ := xPath(doc, fmt.Sprintf("//rsp:Stream[@Name='%s']", streamType))
|
||||
for _, node := range nodes {
|
||||
content, _ := base64.StdEncoding.DecodeString(node.ResValue())
|
||||
stream.Write(content)
|
||||
}
|
||||
|
||||
ended, _ := any(doc, "//*[@State='http://schemas.microsoft.com/wbem/wsman/1/windows/shell/CommandState/Done']")
|
||||
|
||||
if ended {
|
||||
finished = ended
|
||||
if exitBool, _ := any(doc, "//rsp:ExitCode"); exitBool {
|
||||
exit, _ := first(doc, "//rsp:ExitCode")
|
||||
exitCode, _ = strconv.Atoi(exit)
|
||||
}
|
||||
} else {
|
||||
finished = false
|
||||
}
|
||||
|
||||
return finished, exitCode, err
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package winrm
|
||||
|
||||
// Shell is the local view of a WinRM Shell of a given Client
|
||||
type Shell struct {
|
||||
client *Client
|
||||
id string
|
||||
}
|
||||
|
||||
// Execute command on the given Shell, returning either an error or a Command
|
||||
func (s *Shell) Execute(command string, arguments ...string) (*Command, error) {
|
||||
request := NewExecuteCommandRequest(s.client.url, s.id, command, arguments, &s.client.Parameters)
|
||||
defer request.Free()
|
||||
|
||||
response, err := s.client.sendRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commandID, err := ParseExecuteCommandResponse(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd := newCommand(s, commandID)
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// Close will terminate this shell. No commands can be issued once the shell is closed.
|
||||
func (s *Shell) Close() error {
|
||||
request := NewDeleteShellRequest(s.client.url, s.id, &s.client.Parameters)
|
||||
defer request.Free()
|
||||
|
||||
_, err := s.client.sendRequest(request)
|
||||
return err
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
package soap
|
||||
|
||||
import (
|
||||
"github.com/masterzen/simplexml/dom"
|
||||
"strconv"
|
||||
|
||||
"github.com/masterzen/simplexml/dom"
|
||||
)
|
||||
|
||||
type HeaderOption struct {
|
||||
|
@ -99,62 +100,62 @@ func (self *SoapHeader) Options(options []HeaderOption) *SoapHeader {
|
|||
}
|
||||
|
||||
func (self *SoapHeader) Build() *SoapMessage {
|
||||
header := self.createElement(self.message.envelope, "Header", NS_SOAP_ENV)
|
||||
header := self.createElement(self.message.envelope, "Header", DOM_NS_SOAP_ENV)
|
||||
|
||||
if self.to != "" {
|
||||
to := self.createElement(header, "To", NS_ADDRESSING)
|
||||
to := self.createElement(header, "To", DOM_NS_ADDRESSING)
|
||||
to.SetContent(self.to)
|
||||
}
|
||||
|
||||
if self.replyTo != "" {
|
||||
replyTo := self.createElement(header, "ReplyTo", NS_ADDRESSING)
|
||||
a := self.createMUElement(replyTo, "Address", NS_ADDRESSING, true)
|
||||
replyTo := self.createElement(header, "ReplyTo", DOM_NS_ADDRESSING)
|
||||
a := self.createMUElement(replyTo, "Address", DOM_NS_ADDRESSING, true)
|
||||
a.SetContent(self.replyTo)
|
||||
}
|
||||
|
||||
if self.maxEnvelopeSize != "" {
|
||||
envelope := self.createMUElement(header, "MaxEnvelopeSize", NS_WSMAN_DMTF, true)
|
||||
envelope := self.createMUElement(header, "MaxEnvelopeSize", DOM_NS_WSMAN_DMTF, true)
|
||||
envelope.SetContent(self.maxEnvelopeSize)
|
||||
}
|
||||
|
||||
if self.timeout != "" {
|
||||
timeout := self.createElement(header, "OperationTimeout", NS_WSMAN_DMTF)
|
||||
timeout := self.createElement(header, "OperationTimeout", DOM_NS_WSMAN_DMTF)
|
||||
timeout.SetContent(self.timeout)
|
||||
}
|
||||
|
||||
if self.id != "" {
|
||||
id := self.createElement(header, "MessageID", NS_ADDRESSING)
|
||||
id := self.createElement(header, "MessageID", DOM_NS_ADDRESSING)
|
||||
id.SetContent(self.id)
|
||||
}
|
||||
|
||||
if self.locale != "" {
|
||||
locale := self.createMUElement(header, "Locale", NS_WSMAN_DMTF, false)
|
||||
locale := self.createMUElement(header, "Locale", DOM_NS_WSMAN_DMTF, false)
|
||||
locale.SetAttr("xml:lang", self.locale)
|
||||
datalocale := self.createMUElement(header, "DataLocale", NS_WSMAN_MSFT, false)
|
||||
datalocale := self.createMUElement(header, "DataLocale", DOM_NS_WSMAN_MSFT, false)
|
||||
datalocale.SetAttr("xml:lang", self.locale)
|
||||
}
|
||||
|
||||
if self.action != "" {
|
||||
action := self.createMUElement(header, "Action", NS_ADDRESSING, true)
|
||||
action := self.createMUElement(header, "Action", DOM_NS_ADDRESSING, true)
|
||||
action.SetContent(self.action)
|
||||
}
|
||||
|
||||
if self.shellId != "" {
|
||||
selectorSet := self.createElement(header, "SelectorSet", NS_WSMAN_DMTF)
|
||||
selector := self.createElement(selectorSet, "Selector", NS_WSMAN_DMTF)
|
||||
selectorSet := self.createElement(header, "SelectorSet", DOM_NS_WSMAN_DMTF)
|
||||
selector := self.createElement(selectorSet, "Selector", DOM_NS_WSMAN_DMTF)
|
||||
selector.SetAttr("Name", "ShellId")
|
||||
selector.SetContent(self.shellId)
|
||||
}
|
||||
|
||||
if self.resourceURI != "" {
|
||||
resource := self.createMUElement(header, "ResourceURI", NS_WSMAN_DMTF, true)
|
||||
resource := self.createMUElement(header, "ResourceURI", DOM_NS_WSMAN_DMTF, true)
|
||||
resource.SetContent(self.resourceURI)
|
||||
}
|
||||
|
||||
if len(self.options) > 0 {
|
||||
set := self.createElement(header, "OptionSet", NS_WSMAN_DMTF)
|
||||
set := self.createElement(header, "OptionSet", DOM_NS_WSMAN_DMTF)
|
||||
for _, option := range self.options {
|
||||
e := self.createElement(set, "Option", NS_WSMAN_DMTF)
|
||||
e := self.createElement(set, "Option", DOM_NS_WSMAN_DMTF)
|
||||
e.SetAttr("Name", option.key)
|
||||
e.SetContent(option.value)
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue