diff --git a/vendor/github.com/hashicorp/nomad/helper/discover/discover.go b/vendor/github.com/hashicorp/nomad/helper/discover/discover.go new file mode 100644 index 000000000..8582a0133 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/discover/discover.go @@ -0,0 +1,60 @@ +package discover + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/kardianos/osext" +) + +// Checks the current executable, then $GOPATH/bin, and finally the CWD, in that +// order. If it can't be found, an error is returned. +func NomadExecutable() (string, error) { + nomadExe := "nomad" + if runtime.GOOS == "windows" { + nomadExe = "nomad.exe" + } + + // Check the current executable. + bin, err := osext.Executable() + if err != nil { + return "", fmt.Errorf("Failed to determine the nomad executable: %v", err) + } + + if filepath.Base(bin) == nomadExe { + return bin, nil + } + + // Check the $PATH + if bin, err := exec.LookPath(nomadExe); err == nil { + return bin, nil + } + + // Check the $GOPATH. + bin = filepath.Join(os.Getenv("GOPATH"), "bin", nomadExe) + if _, err := os.Stat(bin); err == nil { + return bin, nil + } + + // Check the CWD. + pwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("Could not find Nomad executable (%v): %v", nomadExe, err) + } + + bin = filepath.Join(pwd, nomadExe) + if _, err := os.Stat(bin); err == nil { + return bin, nil + } + + // Check CWD/bin + bin = filepath.Join(pwd, "bin", nomadExe) + if _, err := os.Stat(bin); err == nil { + return bin, nil + } + + return "", fmt.Errorf("Could not find Nomad executable (%v)", nomadExe) +} diff --git a/vendor/github.com/hashicorp/nomad/helper/fields/data.go b/vendor/github.com/hashicorp/nomad/helper/fields/data.go new file mode 100644 index 000000000..fb22bbc59 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/fields/data.go @@ -0,0 +1,169 @@ +package fields + +import ( + "fmt" + + "github.com/hashicorp/go-multierror" + "github.com/mitchellh/mapstructure" +) + +// FieldData contains the raw data and the schema that the data should adhere to +type FieldData struct { + Raw map[string]interface{} + Schema map[string]*FieldSchema +} + +// Validate cycles through the raw data and validates conversions in the schema. +// It also checks for the existence and value of required fields. +func (d *FieldData) Validate() error { + var result *multierror.Error + + // Scan for missing required fields + for field, schema := range d.Schema { + if schema.Required { + _, ok := d.Raw[field] + if !ok { + result = multierror.Append(result, fmt.Errorf( + "field %q is required", field)) + } + } + } + + // Validate field type and value + for field, value := range d.Raw { + schema, ok := d.Schema[field] + if !ok { + result = multierror.Append(result, fmt.Errorf( + "%q is an invalid field", field)) + continue + } + + switch schema.Type { + case TypeBool, TypeInt, TypeMap, TypeArray, TypeString: + val, _, err := d.getPrimitive(field, schema) + if err != nil { + result = multierror.Append(result, fmt.Errorf( + "field %q with input %q doesn't seem to be of type %s", + field, value, schema.Type)) + } + // Check that we don't have an empty value for required fields + if schema.Required && val == schema.Type.Zero() { + result = multierror.Append(result, fmt.Errorf( + "field %q is required, but no value was found", field)) + } + default: + result = multierror.Append(result, fmt.Errorf( + "unknown field type %s for field %s", schema.Type, field)) + } + } + + return result.ErrorOrNil() +} + +// Get gets the value for the given field. If the key is an invalid field, +// FieldData will panic. If you want a safer version of this method, use +// GetOk. If the field k is not set, the default value (if set) will be +// returned, otherwise the zero value will be returned. +func (d *FieldData) Get(k string) interface{} { + schema, ok := d.Schema[k] + if !ok { + panic(fmt.Sprintf("field %s not in the schema", k)) + } + + value, ok := d.GetOk(k) + if !ok { + value = schema.DefaultOrZero() + } + + return value +} + +// GetOk gets the value for the given field. The second return value +// will be false if the key is invalid or the key is not set at all. +func (d *FieldData) GetOk(k string) (interface{}, bool) { + schema, ok := d.Schema[k] + if !ok { + return nil, false + } + + result, ok, err := d.GetOkErr(k) + if err != nil { + panic(fmt.Sprintf("error reading %s: %s", k, err)) + } + + if ok && result == nil { + result = schema.DefaultOrZero() + } + + return result, ok +} + +// GetOkErr is the most conservative of all the Get methods. It returns +// whether key is set or not, but also an error value. The error value is +// non-nil if the field doesn't exist or there was an error parsing the +// field value. +func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) { + schema, ok := d.Schema[k] + if !ok { + return nil, false, fmt.Errorf("unknown field: %s", k) + } + + switch schema.Type { + case TypeBool, TypeInt, TypeMap, TypeArray, TypeString: + return d.getPrimitive(k, schema) + default: + return nil, false, + fmt.Errorf("unknown field type %s for field %s", schema.Type, k) + } +} + +// getPrimitive tries to convert the raw value of a field to its data type as +// defined in the schema. It does strict type checking, so the value will need +// to be able to convert to the appropriate type directly. +func (d *FieldData) getPrimitive( + k string, schema *FieldSchema) (interface{}, bool, error) { + raw, ok := d.Raw[k] + if !ok { + return nil, false, nil + } + + switch schema.Type { + case TypeBool: + var result bool + if err := mapstructure.Decode(raw, &result); err != nil { + return nil, true, err + } + return result, true, nil + + case TypeInt: + var result int + if err := mapstructure.Decode(raw, &result); err != nil { + return nil, true, err + } + return result, true, nil + + case TypeString: + var result string + if err := mapstructure.Decode(raw, &result); err != nil { + return nil, true, err + } + return result, true, nil + + case TypeMap: + var result map[string]interface{} + if err := mapstructure.Decode(raw, &result); err != nil { + return nil, true, err + } + return result, true, nil + + case TypeArray: + var result []interface{} + if err := mapstructure.Decode(raw, &result); err != nil { + return nil, true, err + } + return result, true, nil + + default: + panic(fmt.Sprintf("Unknown type: %s", schema.Type)) + } +} diff --git a/vendor/github.com/hashicorp/nomad/helper/fields/schema.go b/vendor/github.com/hashicorp/nomad/helper/fields/schema.go new file mode 100644 index 000000000..f57a97685 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/fields/schema.go @@ -0,0 +1,19 @@ +package fields + +// FieldSchema is a basic schema to describe the format of a configuration field +type FieldSchema struct { + Type FieldType + Default interface{} + Description string + Required bool +} + +// DefaultOrZero returns the default value if it is set, or otherwise +// the zero value of the type. +func (s *FieldSchema) DefaultOrZero() interface{} { + if s.Default != nil { + return s.Default + } + + return s.Type.Zero() +} diff --git a/vendor/github.com/hashicorp/nomad/helper/fields/type.go b/vendor/github.com/hashicorp/nomad/helper/fields/type.go new file mode 100644 index 000000000..dced1b186 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/fields/type.go @@ -0,0 +1,47 @@ +package fields + +// FieldType is the enum of types that a field can be. +type FieldType uint + +const ( + TypeInvalid FieldType = 0 + TypeString FieldType = iota + TypeInt + TypeBool + TypeMap + TypeArray +) + +func (t FieldType) String() string { + switch t { + case TypeString: + return "string" + case TypeInt: + return "integer" + case TypeBool: + return "boolean" + case TypeMap: + return "map" + case TypeArray: + return "array" + default: + return "unknown type" + } +} + +func (t FieldType) Zero() interface{} { + switch t { + case TypeString: + return "" + case TypeInt: + return 0 + case TypeBool: + return false + case TypeMap: + return map[string]interface{}{} + case TypeArray: + return []interface{}{} + default: + panic("unknown type: " + t.String()) + } +} diff --git a/vendor/github.com/hashicorp/nomad/helper/flag-helpers/flag.go b/vendor/github.com/hashicorp/nomad/helper/flag-helpers/flag.go new file mode 100644 index 000000000..10a5644e2 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/flag-helpers/flag.go @@ -0,0 +1,60 @@ +package flaghelper + +import ( + "strconv" + "strings" + "time" +) + +// StringFlag implements the flag.Value interface and allows multiple +// calls to the same variable to append a list. +type StringFlag []string + +func (s *StringFlag) String() string { + return strings.Join(*s, ",") +} + +func (s *StringFlag) Set(value string) error { + *s = append(*s, value) + return nil +} + +// FuncVar is a type of flag that accepts a function that is the string +// given +// by the user. +type FuncVar func(s string) error + +func (f FuncVar) Set(s string) error { return f(s) } +func (f FuncVar) String() string { return "" } +func (f FuncVar) IsBoolFlag() bool { return false } + +// FuncBoolVar is a type of flag that accepts a function, converts the +// user's +// value to a bool, and then calls the given function. +type FuncBoolVar func(b bool) error + +func (f FuncBoolVar) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + return f(v) +} +func (f FuncBoolVar) String() string { return "" } +func (f FuncBoolVar) IsBoolFlag() bool { return true } + +// FuncDurationVar is a type of flag that +// accepts a function, converts the +// user's value to a duration, and then +// calls the given function. +type FuncDurationVar func(d time.Duration) error + +func (f FuncDurationVar) Set(s string) error { + v, err := time.ParseDuration(s) + if err != nil { + return err + } + return f(v) +} +func (f FuncDurationVar) String() string { return "" } +func (f FuncDurationVar) IsBoolFlag() bool { return false } diff --git a/vendor/github.com/hashicorp/nomad/helper/funcs.go b/vendor/github.com/hashicorp/nomad/helper/funcs.go new file mode 100644 index 000000000..89538f42c --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/funcs.go @@ -0,0 +1,156 @@ +package helper + +import "regexp" + +// validUUID is used to check if a given string looks like a UUID +var validUUID = regexp.MustCompile(`(?i)^[\da-f]{8}-[\da-f]{4}-[\da-f]{4}-[\da-f]{4}-[\da-f]{12}$`) + +// IsUUID returns true if the given string is a valid UUID. +func IsUUID(str string) bool { + const uuidLen = 36 + if len(str) != uuidLen { + return false + } + + return validUUID.MatchString(str) +} + +// boolToPtr returns the pointer to a boolean +func BoolToPtr(b bool) *bool { + return &b +} + +// MapStringStringSliceValueSet returns the set of values in a map[string][]string +func MapStringStringSliceValueSet(m map[string][]string) []string { + set := make(map[string]struct{}) + for _, slice := range m { + for _, v := range slice { + set[v] = struct{}{} + } + } + + flat := make([]string, 0, len(set)) + for k := range set { + flat = append(flat, k) + } + return flat +} + +func SliceStringToSet(s []string) map[string]struct{} { + m := make(map[string]struct{}, (len(s)+1)/2) + for _, k := range s { + m[k] = struct{}{} + } + return m +} + +// SliceStringIsSubset returns whether the smaller set of strings is a subset of +// the larger. If the smaller slice is not a subset, the offending elements are +// returned. +func SliceStringIsSubset(larger, smaller []string) (bool, []string) { + largerSet := make(map[string]struct{}, len(larger)) + for _, l := range larger { + largerSet[l] = struct{}{} + } + + subset := true + var offending []string + for _, s := range smaller { + if _, ok := largerSet[s]; !ok { + subset = false + offending = append(offending, s) + } + } + + return subset, offending +} + +func SliceSetDisjoint(first, second []string) (bool, []string) { + contained := make(map[string]struct{}, len(first)) + for _, k := range first { + contained[k] = struct{}{} + } + + offending := make(map[string]struct{}) + for _, k := range second { + if _, ok := contained[k]; ok { + offending[k] = struct{}{} + } + } + + if len(offending) == 0 { + return true, nil + } + + flattened := make([]string, 0, len(offending)) + for k := range offending { + flattened = append(flattened, k) + } + return false, flattened +} + +// Helpers for copying generic structures. +func CopyMapStringString(m map[string]string) map[string]string { + l := len(m) + if l == 0 { + return nil + } + + c := make(map[string]string, l) + for k, v := range m { + c[k] = v + } + return c +} + +func CopyMapStringInt(m map[string]int) map[string]int { + l := len(m) + if l == 0 { + return nil + } + + c := make(map[string]int, l) + for k, v := range m { + c[k] = v + } + return c +} + +func CopyMapStringFloat64(m map[string]float64) map[string]float64 { + l := len(m) + if l == 0 { + return nil + } + + c := make(map[string]float64, l) + for k, v := range m { + c[k] = v + } + return c +} + +func CopySliceString(s []string) []string { + l := len(s) + if l == 0 { + return nil + } + + c := make([]string, l) + for i, v := range s { + c[i] = v + } + return c +} + +func CopySliceInt(s []int) []int { + l := len(s) + if l == 0 { + return nil + } + + c := make([]int, l) + for i, v := range s { + c[i] = v + } + return c +} diff --git a/vendor/github.com/hashicorp/nomad/helper/gated-writer/writer.go b/vendor/github.com/hashicorp/nomad/helper/gated-writer/writer.go new file mode 100644 index 000000000..9c5aeba00 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/gated-writer/writer.go @@ -0,0 +1,43 @@ +package gatedwriter + +import ( + "io" + "sync" +) + +// Writer is an io.Writer implementation that buffers all of its +// data into an internal buffer until it is told to let data through. +type Writer struct { + Writer io.Writer + + buf [][]byte + flush bool + lock sync.RWMutex +} + +// Flush tells the Writer to flush any buffered data and to stop +// buffering. +func (w *Writer) Flush() { + w.lock.Lock() + w.flush = true + w.lock.Unlock() + + for _, p := range w.buf { + w.Write(p) + } + w.buf = nil +} + +func (w *Writer) Write(p []byte) (n int, err error) { + w.lock.RLock() + defer w.lock.RUnlock() + + if w.flush { + return w.Writer.Write(p) + } + + p2 := make([]byte, len(p)) + copy(p2, p) + w.buf = append(w.buf, p2) + return len(p), nil +} diff --git a/vendor/github.com/hashicorp/nomad/helper/stats/cpu.go b/vendor/github.com/hashicorp/nomad/helper/stats/cpu.go new file mode 100644 index 000000000..9c0cd72d8 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/stats/cpu.go @@ -0,0 +1,67 @@ +package stats + +import ( + "fmt" + "math" + "sync" + + "github.com/shirou/gopsutil/cpu" +) + +var ( + cpuMhzPerCore float64 + cpuModelName string + cpuNumCores int + cpuTotalTicks float64 + + onceLer sync.Once +) + +func Init() error { + var err error + onceLer.Do(func() { + if cpuNumCores, err = cpu.Counts(true); err != nil { + err = fmt.Errorf("Unable to determine the number of CPU cores available: %v", err) + return + } + + var cpuInfo []cpu.InfoStat + if cpuInfo, err = cpu.Info(); err != nil { + err = fmt.Errorf("Unable to obtain CPU information: %v", err) + return + } + + for _, cpu := range cpuInfo { + cpuModelName = cpu.ModelName + cpuMhzPerCore = cpu.Mhz + break + } + + // Floor all of the values such that small difference don't cause the + // node to fall into a unique computed node class + cpuMhzPerCore = math.Floor(cpuMhzPerCore) + cpuTotalTicks = math.Floor(float64(cpuNumCores) * cpuMhzPerCore) + }) + return err +} + +// CPUModelName returns the number of CPU cores available +func CPUNumCores() int { + return cpuNumCores +} + +// CPUMHzPerCore returns the MHz per CPU core +func CPUMHzPerCore() float64 { + return cpuMhzPerCore +} + +// CPUModelName returns the model name of the CPU +func CPUModelName() string { + return cpuModelName +} + +// TotalTicksAvailable calculates the total frequency available across all +// cores +func TotalTicksAvailable() float64 { + return cpuTotalTicks +} diff --git a/vendor/github.com/hashicorp/nomad/helper/testtask/testtask.go b/vendor/github.com/hashicorp/nomad/helper/testtask/testtask.go new file mode 100644 index 000000000..cfcf205f5 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/testtask/testtask.go @@ -0,0 +1,118 @@ +// Package testtask implements a portable set of commands useful as stand-ins +// for user tasks. +package testtask + +import ( + "fmt" + "io/ioutil" + "os" + "os/exec" + "time" + + "github.com/hashicorp/nomad/client/driver/env" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/kardianos/osext" +) + +// Path returns the path to the currently running executable. +func Path() string { + path, err := osext.Executable() + if err != nil { + panic(err) + } + return path +} + +// SetEnv configures the environment of the task so that Run executes a testtask +// script when called from within cmd. +func SetEnv(env *env.TaskEnvironment) { + env.AppendEnvvars(map[string]string{"TEST_TASK": "execute"}) +} + +// SetCmdEnv configures the environment of cmd so that Run executes a testtask +// script when called from within cmd. +func SetCmdEnv(cmd *exec.Cmd) { + cmd.Env = append(os.Environ(), "TEST_TASK=execute") +} + +// SetTaskEnv configures the environment of t so that Run executes a testtask +// script when called from within t. +func SetTaskEnv(t *structs.Task) { + if t.Env == nil { + t.Env = map[string]string{} + } + t.Env["TEST_TASK"] = "execute" +} + +// Run interprets os.Args as a testtask script if the current program was +// launched with an environment configured by SetCmdEnv or SetTaskEnv. It +// returns false if the environment was not set by this package. +func Run() bool { + switch tm := os.Getenv("TEST_TASK"); tm { + case "": + return false + case "execute": + execute() + return true + default: + fmt.Fprintf(os.Stderr, "unexpected value for TEST_TASK, \"%s\"\n", tm) + os.Exit(1) + return true + } +} + +func execute() { + if len(os.Args) < 2 { + fmt.Fprintln(os.Stderr, "no command provided") + os.Exit(1) + } + + args := os.Args[1:] + + // popArg removes the first argument from args and returns it. + popArg := func() string { + s := args[0] + args = args[1:] + return s + } + + // execute a sequence of operations from args + for len(args) > 0 { + switch cmd := popArg(); cmd { + + case "sleep": + // sleep : sleep for a duration indicated by the first + // argument + if len(args) < 1 { + fmt.Fprintln(os.Stderr, "expected arg for sleep") + os.Exit(1) + } + dur, err := time.ParseDuration(popArg()) + if err != nil { + fmt.Fprintf(os.Stderr, "could not parse sleep time: %v", err) + os.Exit(1) + } + time.Sleep(dur) + + case "echo": + // echo : write the msg followed by a newline to stdout. + fmt.Println(popArg()) + + case "write": + // write : write a message to a file. The first + // argument is the msg. The second argument is the path to the + // target file. + if len(args) < 2 { + fmt.Fprintln(os.Stderr, "expected two args for write") + os.Exit(1) + } + msg := popArg() + file := popArg() + ioutil.WriteFile(file, []byte(msg), 0666) + + default: + fmt.Fprintln(os.Stderr, "unknown command:", cmd) + os.Exit(1) + } + } +} diff --git a/vendor/github.com/hashicorp/nomad/helper/tlsutil/config.go b/vendor/github.com/hashicorp/nomad/helper/tlsutil/config.go new file mode 100644 index 000000000..0bcf13003 --- /dev/null +++ b/vendor/github.com/hashicorp/nomad/helper/tlsutil/config.go @@ -0,0 +1,258 @@ +package tlsutil + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "time" +) + +// RegionSpecificWrapper is used to invoke a static Region and turns a +// RegionWrapper into a Wrapper type. +func RegionSpecificWrapper(region string, tlsWrap RegionWrapper) Wrapper { + if tlsWrap == nil { + return nil + } + return func(conn net.Conn) (net.Conn, error) { + return tlsWrap(region, conn) + } +} + +// RegionWrapper is a function that is used to wrap a non-TLS connection and +// returns an appropriate TLS connection or error. This takes a Region as an +// argument. +type RegionWrapper func(region string, conn net.Conn) (net.Conn, error) + +// Wrapper wraps a connection and enables TLS on it. +type Wrapper func(conn net.Conn) (net.Conn, error) + +// Config used to create tls.Config +type Config struct { + // VerifyIncoming is used to verify the authenticity of incoming connections. + // This means that TCP requests are forbidden, only allowing for TLS. TLS connections + // must match a provided certificate authority. This can be used to force client auth. + VerifyIncoming bool + + // VerifyOutgoing is used to verify the authenticity of outgoing connections. + // This means that TLS requests are used, and TCP requests are not made. TLS connections + // must match a provided certificate authority. This is used to verify authenticity of + // server nodes. + VerifyOutgoing bool + + // VerifyServerHostname is used to enable hostname verification of servers. This + // ensures that the certificate presented is valid for server... + // This prevents a compromised client from being restarted as a server, and then + // intercepting request traffic as well as being added as a raft peer. This should be + // enabled by default with VerifyOutgoing, but for legacy reasons we cannot break + // existing clients. + VerifyServerHostname bool + + // CAFile is a path to a certificate authority file. This is used with VerifyIncoming + // or VerifyOutgoing to verify the TLS connection. + CAFile string + + // CertFile is used to provide a TLS certificate that is used for serving TLS connections. + // Must be provided to serve TLS connections. + CertFile string + + // KeyFile is used to provide a TLS key that is used for serving TLS connections. + // Must be provided to serve TLS connections. + KeyFile string +} + +// AppendCA opens and parses the CA file and adds the certificates to +// the provided CertPool. +func (c *Config) AppendCA(pool *x509.CertPool) error { + if c.CAFile == "" { + return nil + } + + // Read the file + data, err := ioutil.ReadFile(c.CAFile) + if err != nil { + return fmt.Errorf("Failed to read CA file: %v", err) + } + + if !pool.AppendCertsFromPEM(data) { + return fmt.Errorf("Failed to parse any CA certificates") + } + + return nil +} + +// KeyPair is used to open and parse a certificate and key file +func (c *Config) KeyPair() (*tls.Certificate, error) { + if c.CertFile == "" || c.KeyFile == "" { + return nil, nil + } + cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) + if err != nil { + return nil, fmt.Errorf("Failed to load cert/key pair: %v", err) + } + return &cert, err +} + +// OutgoingTLSConfig generates a TLS configuration for outgoing +// requests. It will return a nil config if this configuration should +// not use TLS for outgoing connections. +func (c *Config) OutgoingTLSConfig() (*tls.Config, error) { + // If VerifyServerHostname is true, that implies VerifyOutgoing + if c.VerifyServerHostname { + c.VerifyOutgoing = true + } + if !c.VerifyOutgoing { + return nil, nil + } + // Create the tlsConfig + tlsConfig := &tls.Config{ + RootCAs: x509.NewCertPool(), + InsecureSkipVerify: true, + } + if c.VerifyServerHostname { + tlsConfig.InsecureSkipVerify = false + } + + // Ensure we have a CA if VerifyOutgoing is set + if c.VerifyOutgoing && c.CAFile == "" { + return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!") + } + + // Parse the CA cert if any + err := c.AppendCA(tlsConfig.RootCAs) + if err != nil { + return nil, err + } + + // Add cert/key + cert, err := c.KeyPair() + if err != nil { + return nil, err + } else if cert != nil { + tlsConfig.Certificates = []tls.Certificate{*cert} + } + + return tlsConfig, nil +} + +// OutgoingTLSWrapper returns a a Wrapper based on the OutgoingTLS +// configuration. If hostname verification is on, the wrapper +// will properly generate the dynamic server name for verification. +func (c *Config) OutgoingTLSWrapper() (RegionWrapper, error) { + // Get the TLS config + tlsConfig, err := c.OutgoingTLSConfig() + if err != nil { + return nil, err + } + + // Check if TLS is not enabled + if tlsConfig == nil { + return nil, nil + } + + // Generate the wrapper based on hostname verification + if c.VerifyServerHostname { + wrapper := func(region string, conn net.Conn) (net.Conn, error) { + conf := *tlsConfig + conf.ServerName = "server." + region + ".nomad" + return WrapTLSClient(conn, &conf) + } + return wrapper, nil + } else { + wrapper := func(dc string, c net.Conn) (net.Conn, error) { + return WrapTLSClient(c, tlsConfig) + } + return wrapper, nil + } + +} + +// Wrap a net.Conn into a client tls connection, performing any +// additional verification as needed. +// +// As of go 1.3, crypto/tls only supports either doing no certificate +// verification, or doing full verification including of the peer's +// DNS name. For consul, we want to validate that the certificate is +// signed by a known CA, but because consul doesn't use DNS names for +// node names, we don't verify the certificate DNS names. Since go 1.3 +// no longer supports this mode of operation, we have to do it +// manually. +func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + var err error + var tlsConn *tls.Conn + + tlsConn = tls.Client(conn, tlsConfig) + + // If crypto/tls is doing verification, there's no need to do + // our own. + if tlsConfig.InsecureSkipVerify == false { + return tlsConn, nil + } + + if err = tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + + // The following is lightly-modified from the doFullHandshake + // method in crypto/tls's handshake_client.go. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := tlsConn.ConnectionState().PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + if err != nil { + tlsConn.Close() + return nil, err + } + + return tlsConn, err +} + +// IncomingTLSConfig generates a TLS configuration for incoming requests +func (c *Config) IncomingTLSConfig() (*tls.Config, error) { + // Create the tlsConfig + tlsConfig := &tls.Config{ + ClientCAs: x509.NewCertPool(), + ClientAuth: tls.NoClientCert, + } + + // Parse the CA cert if any + err := c.AppendCA(tlsConfig.ClientCAs) + if err != nil { + return nil, err + } + + // Add cert/key + cert, err := c.KeyPair() + if err != nil { + return nil, err + } else if cert != nil { + tlsConfig.Certificates = []tls.Certificate{*cert} + } + + // Check if we require verification + if c.VerifyIncoming { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + if c.CAFile == "" { + return nil, fmt.Errorf("VerifyIncoming set, and no CA certificate provided!") + } + if cert == nil { + return nil, fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!") + } + } + + return tlsConfig, nil +}