structs package now includes entire helper package
This commit is contained in:
parent
79c117877e
commit
9532c84f76
|
@ -0,0 +1,60 @@
|
|||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/osext"
|
||||
)
|
||||
|
||||
// Checks the current executable, then $GOPATH/bin, and finally the CWD, in that
|
||||
// order. If it can't be found, an error is returned.
|
||||
func NomadExecutable() (string, error) {
|
||||
nomadExe := "nomad"
|
||||
if runtime.GOOS == "windows" {
|
||||
nomadExe = "nomad.exe"
|
||||
}
|
||||
|
||||
// Check the current executable.
|
||||
bin, err := osext.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Failed to determine the nomad executable: %v", err)
|
||||
}
|
||||
|
||||
if filepath.Base(bin) == nomadExe {
|
||||
return bin, nil
|
||||
}
|
||||
|
||||
// Check the $PATH
|
||||
if bin, err := exec.LookPath(nomadExe); err == nil {
|
||||
return bin, nil
|
||||
}
|
||||
|
||||
// Check the $GOPATH.
|
||||
bin = filepath.Join(os.Getenv("GOPATH"), "bin", nomadExe)
|
||||
if _, err := os.Stat(bin); err == nil {
|
||||
return bin, nil
|
||||
}
|
||||
|
||||
// Check the CWD.
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not find Nomad executable (%v): %v", nomadExe, err)
|
||||
}
|
||||
|
||||
bin = filepath.Join(pwd, nomadExe)
|
||||
if _, err := os.Stat(bin); err == nil {
|
||||
return bin, nil
|
||||
}
|
||||
|
||||
// Check CWD/bin
|
||||
bin = filepath.Join(pwd, "bin", nomadExe)
|
||||
if _, err := os.Stat(bin); err == nil {
|
||||
return bin, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("Could not find Nomad executable (%v)", nomadExe)
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
package fields
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// FieldData contains the raw data and the schema that the data should adhere to
|
||||
type FieldData struct {
|
||||
Raw map[string]interface{}
|
||||
Schema map[string]*FieldSchema
|
||||
}
|
||||
|
||||
// Validate cycles through the raw data and validates conversions in the schema.
|
||||
// It also checks for the existence and value of required fields.
|
||||
func (d *FieldData) Validate() error {
|
||||
var result *multierror.Error
|
||||
|
||||
// Scan for missing required fields
|
||||
for field, schema := range d.Schema {
|
||||
if schema.Required {
|
||||
_, ok := d.Raw[field]
|
||||
if !ok {
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"field %q is required", field))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate field type and value
|
||||
for field, value := range d.Raw {
|
||||
schema, ok := d.Schema[field]
|
||||
if !ok {
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"%q is an invalid field", field))
|
||||
continue
|
||||
}
|
||||
|
||||
switch schema.Type {
|
||||
case TypeBool, TypeInt, TypeMap, TypeArray, TypeString:
|
||||
val, _, err := d.getPrimitive(field, schema)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"field %q with input %q doesn't seem to be of type %s",
|
||||
field, value, schema.Type))
|
||||
}
|
||||
// Check that we don't have an empty value for required fields
|
||||
if schema.Required && val == schema.Type.Zero() {
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"field %q is required, but no value was found", field))
|
||||
}
|
||||
default:
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"unknown field type %s for field %s", schema.Type, field))
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Get gets the value for the given field. If the key is an invalid field,
|
||||
// FieldData will panic. If you want a safer version of this method, use
|
||||
// GetOk. If the field k is not set, the default value (if set) will be
|
||||
// returned, otherwise the zero value will be returned.
|
||||
func (d *FieldData) Get(k string) interface{} {
|
||||
schema, ok := d.Schema[k]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("field %s not in the schema", k))
|
||||
}
|
||||
|
||||
value, ok := d.GetOk(k)
|
||||
if !ok {
|
||||
value = schema.DefaultOrZero()
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// GetOk gets the value for the given field. The second return value
|
||||
// will be false if the key is invalid or the key is not set at all.
|
||||
func (d *FieldData) GetOk(k string) (interface{}, bool) {
|
||||
schema, ok := d.Schema[k]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
result, ok, err := d.GetOkErr(k)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error reading %s: %s", k, err))
|
||||
}
|
||||
|
||||
if ok && result == nil {
|
||||
result = schema.DefaultOrZero()
|
||||
}
|
||||
|
||||
return result, ok
|
||||
}
|
||||
|
||||
// GetOkErr is the most conservative of all the Get methods. It returns
|
||||
// whether key is set or not, but also an error value. The error value is
|
||||
// non-nil if the field doesn't exist or there was an error parsing the
|
||||
// field value.
|
||||
func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) {
|
||||
schema, ok := d.Schema[k]
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("unknown field: %s", k)
|
||||
}
|
||||
|
||||
switch schema.Type {
|
||||
case TypeBool, TypeInt, TypeMap, TypeArray, TypeString:
|
||||
return d.getPrimitive(k, schema)
|
||||
default:
|
||||
return nil, false,
|
||||
fmt.Errorf("unknown field type %s for field %s", schema.Type, k)
|
||||
}
|
||||
}
|
||||
|
||||
// getPrimitive tries to convert the raw value of a field to its data type as
|
||||
// defined in the schema. It does strict type checking, so the value will need
|
||||
// to be able to convert to the appropriate type directly.
|
||||
func (d *FieldData) getPrimitive(
|
||||
k string, schema *FieldSchema) (interface{}, bool, error) {
|
||||
raw, ok := d.Raw[k]
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
switch schema.Type {
|
||||
case TypeBool:
|
||||
var result bool
|
||||
if err := mapstructure.Decode(raw, &result); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return result, true, nil
|
||||
|
||||
case TypeInt:
|
||||
var result int
|
||||
if err := mapstructure.Decode(raw, &result); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return result, true, nil
|
||||
|
||||
case TypeString:
|
||||
var result string
|
||||
if err := mapstructure.Decode(raw, &result); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return result, true, nil
|
||||
|
||||
case TypeMap:
|
||||
var result map[string]interface{}
|
||||
if err := mapstructure.Decode(raw, &result); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return result, true, nil
|
||||
|
||||
case TypeArray:
|
||||
var result []interface{}
|
||||
if err := mapstructure.Decode(raw, &result); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
return result, true, nil
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("Unknown type: %s", schema.Type))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package fields
|
||||
|
||||
// FieldSchema is a basic schema to describe the format of a configuration field
|
||||
type FieldSchema struct {
|
||||
Type FieldType
|
||||
Default interface{}
|
||||
Description string
|
||||
Required bool
|
||||
}
|
||||
|
||||
// DefaultOrZero returns the default value if it is set, or otherwise
|
||||
// the zero value of the type.
|
||||
func (s *FieldSchema) DefaultOrZero() interface{} {
|
||||
if s.Default != nil {
|
||||
return s.Default
|
||||
}
|
||||
|
||||
return s.Type.Zero()
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package fields
|
||||
|
||||
// FieldType is the enum of types that a field can be.
|
||||
type FieldType uint
|
||||
|
||||
const (
|
||||
TypeInvalid FieldType = 0
|
||||
TypeString FieldType = iota
|
||||
TypeInt
|
||||
TypeBool
|
||||
TypeMap
|
||||
TypeArray
|
||||
)
|
||||
|
||||
func (t FieldType) String() string {
|
||||
switch t {
|
||||
case TypeString:
|
||||
return "string"
|
||||
case TypeInt:
|
||||
return "integer"
|
||||
case TypeBool:
|
||||
return "boolean"
|
||||
case TypeMap:
|
||||
return "map"
|
||||
case TypeArray:
|
||||
return "array"
|
||||
default:
|
||||
return "unknown type"
|
||||
}
|
||||
}
|
||||
|
||||
func (t FieldType) Zero() interface{} {
|
||||
switch t {
|
||||
case TypeString:
|
||||
return ""
|
||||
case TypeInt:
|
||||
return 0
|
||||
case TypeBool:
|
||||
return false
|
||||
case TypeMap:
|
||||
return map[string]interface{}{}
|
||||
case TypeArray:
|
||||
return []interface{}{}
|
||||
default:
|
||||
panic("unknown type: " + t.String())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package flaghelper
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StringFlag implements the flag.Value interface and allows multiple
|
||||
// calls to the same variable to append a list.
|
||||
type StringFlag []string
|
||||
|
||||
func (s *StringFlag) String() string {
|
||||
return strings.Join(*s, ",")
|
||||
}
|
||||
|
||||
func (s *StringFlag) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FuncVar is a type of flag that accepts a function that is the string
|
||||
// given
|
||||
// by the user.
|
||||
type FuncVar func(s string) error
|
||||
|
||||
func (f FuncVar) Set(s string) error { return f(s) }
|
||||
func (f FuncVar) String() string { return "" }
|
||||
func (f FuncVar) IsBoolFlag() bool { return false }
|
||||
|
||||
// FuncBoolVar is a type of flag that accepts a function, converts the
|
||||
// user's
|
||||
// value to a bool, and then calls the given function.
|
||||
type FuncBoolVar func(b bool) error
|
||||
|
||||
func (f FuncBoolVar) Set(s string) error {
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f(v)
|
||||
}
|
||||
func (f FuncBoolVar) String() string { return "" }
|
||||
func (f FuncBoolVar) IsBoolFlag() bool { return true }
|
||||
|
||||
// FuncDurationVar is a type of flag that
|
||||
// accepts a function, converts the
|
||||
// user's value to a duration, and then
|
||||
// calls the given function.
|
||||
type FuncDurationVar func(d time.Duration) error
|
||||
|
||||
func (f FuncDurationVar) Set(s string) error {
|
||||
v, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f(v)
|
||||
}
|
||||
func (f FuncDurationVar) String() string { return "" }
|
||||
func (f FuncDurationVar) IsBoolFlag() bool { return false }
|
|
@ -0,0 +1,156 @@
|
|||
package helper
|
||||
|
||||
import "regexp"
|
||||
|
||||
// validUUID is used to check if a given string looks like a UUID
|
||||
var validUUID = regexp.MustCompile(`(?i)^[\da-f]{8}-[\da-f]{4}-[\da-f]{4}-[\da-f]{4}-[\da-f]{12}$`)
|
||||
|
||||
// IsUUID returns true if the given string is a valid UUID.
|
||||
func IsUUID(str string) bool {
|
||||
const uuidLen = 36
|
||||
if len(str) != uuidLen {
|
||||
return false
|
||||
}
|
||||
|
||||
return validUUID.MatchString(str)
|
||||
}
|
||||
|
||||
// boolToPtr returns the pointer to a boolean
|
||||
func BoolToPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
// MapStringStringSliceValueSet returns the set of values in a map[string][]string
|
||||
func MapStringStringSliceValueSet(m map[string][]string) []string {
|
||||
set := make(map[string]struct{})
|
||||
for _, slice := range m {
|
||||
for _, v := range slice {
|
||||
set[v] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
flat := make([]string, 0, len(set))
|
||||
for k := range set {
|
||||
flat = append(flat, k)
|
||||
}
|
||||
return flat
|
||||
}
|
||||
|
||||
func SliceStringToSet(s []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, (len(s)+1)/2)
|
||||
for _, k := range s {
|
||||
m[k] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// SliceStringIsSubset returns whether the smaller set of strings is a subset of
|
||||
// the larger. If the smaller slice is not a subset, the offending elements are
|
||||
// returned.
|
||||
func SliceStringIsSubset(larger, smaller []string) (bool, []string) {
|
||||
largerSet := make(map[string]struct{}, len(larger))
|
||||
for _, l := range larger {
|
||||
largerSet[l] = struct{}{}
|
||||
}
|
||||
|
||||
subset := true
|
||||
var offending []string
|
||||
for _, s := range smaller {
|
||||
if _, ok := largerSet[s]; !ok {
|
||||
subset = false
|
||||
offending = append(offending, s)
|
||||
}
|
||||
}
|
||||
|
||||
return subset, offending
|
||||
}
|
||||
|
||||
func SliceSetDisjoint(first, second []string) (bool, []string) {
|
||||
contained := make(map[string]struct{}, len(first))
|
||||
for _, k := range first {
|
||||
contained[k] = struct{}{}
|
||||
}
|
||||
|
||||
offending := make(map[string]struct{})
|
||||
for _, k := range second {
|
||||
if _, ok := contained[k]; ok {
|
||||
offending[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(offending) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
flattened := make([]string, 0, len(offending))
|
||||
for k := range offending {
|
||||
flattened = append(flattened, k)
|
||||
}
|
||||
return false, flattened
|
||||
}
|
||||
|
||||
// Helpers for copying generic structures.
|
||||
func CopyMapStringString(m map[string]string) map[string]string {
|
||||
l := len(m)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := make(map[string]string, l)
|
||||
for k, v := range m {
|
||||
c[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func CopyMapStringInt(m map[string]int) map[string]int {
|
||||
l := len(m)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := make(map[string]int, l)
|
||||
for k, v := range m {
|
||||
c[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func CopyMapStringFloat64(m map[string]float64) map[string]float64 {
|
||||
l := len(m)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := make(map[string]float64, l)
|
||||
for k, v := range m {
|
||||
c[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func CopySliceString(s []string) []string {
|
||||
l := len(s)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := make([]string, l)
|
||||
for i, v := range s {
|
||||
c[i] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func CopySliceInt(s []int) []int {
|
||||
l := len(s)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := make([]int, l)
|
||||
for i, v := range s {
|
||||
c[i] = v
|
||||
}
|
||||
return c
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package gatedwriter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Writer is an io.Writer implementation that buffers all of its
|
||||
// data into an internal buffer until it is told to let data through.
|
||||
type Writer struct {
|
||||
Writer io.Writer
|
||||
|
||||
buf [][]byte
|
||||
flush bool
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Flush tells the Writer to flush any buffered data and to stop
|
||||
// buffering.
|
||||
func (w *Writer) Flush() {
|
||||
w.lock.Lock()
|
||||
w.flush = true
|
||||
w.lock.Unlock()
|
||||
|
||||
for _, p := range w.buf {
|
||||
w.Write(p)
|
||||
}
|
||||
w.buf = nil
|
||||
}
|
||||
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
w.lock.RLock()
|
||||
defer w.lock.RUnlock()
|
||||
|
||||
if w.flush {
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
|
||||
p2 := make([]byte, len(p))
|
||||
copy(p2, p)
|
||||
w.buf = append(w.buf, p2)
|
||||
return len(p), nil
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package stats
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
)
|
||||
|
||||
var (
|
||||
cpuMhzPerCore float64
|
||||
cpuModelName string
|
||||
cpuNumCores int
|
||||
cpuTotalTicks float64
|
||||
|
||||
onceLer sync.Once
|
||||
)
|
||||
|
||||
func Init() error {
|
||||
var err error
|
||||
onceLer.Do(func() {
|
||||
if cpuNumCores, err = cpu.Counts(true); err != nil {
|
||||
err = fmt.Errorf("Unable to determine the number of CPU cores available: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var cpuInfo []cpu.InfoStat
|
||||
if cpuInfo, err = cpu.Info(); err != nil {
|
||||
err = fmt.Errorf("Unable to obtain CPU information: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, cpu := range cpuInfo {
|
||||
cpuModelName = cpu.ModelName
|
||||
cpuMhzPerCore = cpu.Mhz
|
||||
break
|
||||
}
|
||||
|
||||
// Floor all of the values such that small difference don't cause the
|
||||
// node to fall into a unique computed node class
|
||||
cpuMhzPerCore = math.Floor(cpuMhzPerCore)
|
||||
cpuTotalTicks = math.Floor(float64(cpuNumCores) * cpuMhzPerCore)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// CPUModelName returns the number of CPU cores available
|
||||
func CPUNumCores() int {
|
||||
return cpuNumCores
|
||||
}
|
||||
|
||||
// CPUMHzPerCore returns the MHz per CPU core
|
||||
func CPUMHzPerCore() float64 {
|
||||
return cpuMhzPerCore
|
||||
}
|
||||
|
||||
// CPUModelName returns the model name of the CPU
|
||||
func CPUModelName() string {
|
||||
return cpuModelName
|
||||
}
|
||||
|
||||
// TotalTicksAvailable calculates the total frequency available across all
|
||||
// cores
|
||||
func TotalTicksAvailable() float64 {
|
||||
return cpuTotalTicks
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
// Package testtask implements a portable set of commands useful as stand-ins
|
||||
// for user tasks.
|
||||
package testtask
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/nomad/client/driver/env"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/kardianos/osext"
|
||||
)
|
||||
|
||||
// Path returns the path to the currently running executable.
|
||||
func Path() string {
|
||||
path, err := osext.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// SetEnv configures the environment of the task so that Run executes a testtask
|
||||
// script when called from within cmd.
|
||||
func SetEnv(env *env.TaskEnvironment) {
|
||||
env.AppendEnvvars(map[string]string{"TEST_TASK": "execute"})
|
||||
}
|
||||
|
||||
// SetCmdEnv configures the environment of cmd so that Run executes a testtask
|
||||
// script when called from within cmd.
|
||||
func SetCmdEnv(cmd *exec.Cmd) {
|
||||
cmd.Env = append(os.Environ(), "TEST_TASK=execute")
|
||||
}
|
||||
|
||||
// SetTaskEnv configures the environment of t so that Run executes a testtask
|
||||
// script when called from within t.
|
||||
func SetTaskEnv(t *structs.Task) {
|
||||
if t.Env == nil {
|
||||
t.Env = map[string]string{}
|
||||
}
|
||||
t.Env["TEST_TASK"] = "execute"
|
||||
}
|
||||
|
||||
// Run interprets os.Args as a testtask script if the current program was
|
||||
// launched with an environment configured by SetCmdEnv or SetTaskEnv. It
|
||||
// returns false if the environment was not set by this package.
|
||||
func Run() bool {
|
||||
switch tm := os.Getenv("TEST_TASK"); tm {
|
||||
case "":
|
||||
return false
|
||||
case "execute":
|
||||
execute()
|
||||
return true
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unexpected value for TEST_TASK, \"%s\"\n", tm)
|
||||
os.Exit(1)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func execute() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "no command provided")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
args := os.Args[1:]
|
||||
|
||||
// popArg removes the first argument from args and returns it.
|
||||
popArg := func() string {
|
||||
s := args[0]
|
||||
args = args[1:]
|
||||
return s
|
||||
}
|
||||
|
||||
// execute a sequence of operations from args
|
||||
for len(args) > 0 {
|
||||
switch cmd := popArg(); cmd {
|
||||
|
||||
case "sleep":
|
||||
// sleep <dur>: sleep for a duration indicated by the first
|
||||
// argument
|
||||
if len(args) < 1 {
|
||||
fmt.Fprintln(os.Stderr, "expected arg for sleep")
|
||||
os.Exit(1)
|
||||
}
|
||||
dur, err := time.ParseDuration(popArg())
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "could not parse sleep time: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
time.Sleep(dur)
|
||||
|
||||
case "echo":
|
||||
// echo <msg>: write the msg followed by a newline to stdout.
|
||||
fmt.Println(popArg())
|
||||
|
||||
case "write":
|
||||
// write <msg> <file>: write a message to a file. The first
|
||||
// argument is the msg. The second argument is the path to the
|
||||
// target file.
|
||||
if len(args) < 2 {
|
||||
fmt.Fprintln(os.Stderr, "expected two args for write")
|
||||
os.Exit(1)
|
||||
}
|
||||
msg := popArg()
|
||||
file := popArg()
|
||||
ioutil.WriteFile(file, []byte(msg), 0666)
|
||||
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, "unknown command:", cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,258 @@
|
|||
package tlsutil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RegionSpecificWrapper is used to invoke a static Region and turns a
|
||||
// RegionWrapper into a Wrapper type.
|
||||
func RegionSpecificWrapper(region string, tlsWrap RegionWrapper) Wrapper {
|
||||
if tlsWrap == nil {
|
||||
return nil
|
||||
}
|
||||
return func(conn net.Conn) (net.Conn, error) {
|
||||
return tlsWrap(region, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// RegionWrapper is a function that is used to wrap a non-TLS connection and
|
||||
// returns an appropriate TLS connection or error. This takes a Region as an
|
||||
// argument.
|
||||
type RegionWrapper func(region string, conn net.Conn) (net.Conn, error)
|
||||
|
||||
// Wrapper wraps a connection and enables TLS on it.
|
||||
type Wrapper func(conn net.Conn) (net.Conn, error)
|
||||
|
||||
// Config used to create tls.Config
|
||||
type Config struct {
|
||||
// VerifyIncoming is used to verify the authenticity of incoming connections.
|
||||
// This means that TCP requests are forbidden, only allowing for TLS. TLS connections
|
||||
// must match a provided certificate authority. This can be used to force client auth.
|
||||
VerifyIncoming bool
|
||||
|
||||
// VerifyOutgoing is used to verify the authenticity of outgoing connections.
|
||||
// This means that TLS requests are used, and TCP requests are not made. TLS connections
|
||||
// must match a provided certificate authority. This is used to verify authenticity of
|
||||
// server nodes.
|
||||
VerifyOutgoing bool
|
||||
|
||||
// VerifyServerHostname is used to enable hostname verification of servers. This
|
||||
// ensures that the certificate presented is valid for server.<datacenter>.<domain>.
|
||||
// This prevents a compromised client from being restarted as a server, and then
|
||||
// intercepting request traffic as well as being added as a raft peer. This should be
|
||||
// enabled by default with VerifyOutgoing, but for legacy reasons we cannot break
|
||||
// existing clients.
|
||||
VerifyServerHostname bool
|
||||
|
||||
// CAFile is a path to a certificate authority file. This is used with VerifyIncoming
|
||||
// or VerifyOutgoing to verify the TLS connection.
|
||||
CAFile string
|
||||
|
||||
// CertFile is used to provide a TLS certificate that is used for serving TLS connections.
|
||||
// Must be provided to serve TLS connections.
|
||||
CertFile string
|
||||
|
||||
// KeyFile is used to provide a TLS key that is used for serving TLS connections.
|
||||
// Must be provided to serve TLS connections.
|
||||
KeyFile string
|
||||
}
|
||||
|
||||
// AppendCA opens and parses the CA file and adds the certificates to
|
||||
// the provided CertPool.
|
||||
func (c *Config) AppendCA(pool *x509.CertPool) error {
|
||||
if c.CAFile == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the file
|
||||
data, err := ioutil.ReadFile(c.CAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read CA file: %v", err)
|
||||
}
|
||||
|
||||
if !pool.AppendCertsFromPEM(data) {
|
||||
return fmt.Errorf("Failed to parse any CA certificates")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// KeyPair is used to open and parse a certificate and key file
|
||||
func (c *Config) KeyPair() (*tls.Certificate, error) {
|
||||
if c.CertFile == "" || c.KeyFile == "" {
|
||||
return nil, nil
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to load cert/key pair: %v", err)
|
||||
}
|
||||
return &cert, err
|
||||
}
|
||||
|
||||
// OutgoingTLSConfig generates a TLS configuration for outgoing
|
||||
// requests. It will return a nil config if this configuration should
|
||||
// not use TLS for outgoing connections.
|
||||
func (c *Config) OutgoingTLSConfig() (*tls.Config, error) {
|
||||
// If VerifyServerHostname is true, that implies VerifyOutgoing
|
||||
if c.VerifyServerHostname {
|
||||
c.VerifyOutgoing = true
|
||||
}
|
||||
if !c.VerifyOutgoing {
|
||||
return nil, nil
|
||||
}
|
||||
// Create the tlsConfig
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: x509.NewCertPool(),
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
if c.VerifyServerHostname {
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
}
|
||||
|
||||
// Ensure we have a CA if VerifyOutgoing is set
|
||||
if c.VerifyOutgoing && c.CAFile == "" {
|
||||
return nil, fmt.Errorf("VerifyOutgoing set, and no CA certificate provided!")
|
||||
}
|
||||
|
||||
// Parse the CA cert if any
|
||||
err := c.AppendCA(tlsConfig.RootCAs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add cert/key
|
||||
cert, err := c.KeyPair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if cert != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{*cert}
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// OutgoingTLSWrapper returns a a Wrapper based on the OutgoingTLS
|
||||
// configuration. If hostname verification is on, the wrapper
|
||||
// will properly generate the dynamic server name for verification.
|
||||
func (c *Config) OutgoingTLSWrapper() (RegionWrapper, error) {
|
||||
// Get the TLS config
|
||||
tlsConfig, err := c.OutgoingTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if TLS is not enabled
|
||||
if tlsConfig == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Generate the wrapper based on hostname verification
|
||||
if c.VerifyServerHostname {
|
||||
wrapper := func(region string, conn net.Conn) (net.Conn, error) {
|
||||
conf := *tlsConfig
|
||||
conf.ServerName = "server." + region + ".nomad"
|
||||
return WrapTLSClient(conn, &conf)
|
||||
}
|
||||
return wrapper, nil
|
||||
} else {
|
||||
wrapper := func(dc string, c net.Conn) (net.Conn, error) {
|
||||
return WrapTLSClient(c, tlsConfig)
|
||||
}
|
||||
return wrapper, nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Wrap a net.Conn into a client tls connection, performing any
|
||||
// additional verification as needed.
|
||||
//
|
||||
// As of go 1.3, crypto/tls only supports either doing no certificate
|
||||
// verification, or doing full verification including of the peer's
|
||||
// DNS name. For consul, we want to validate that the certificate is
|
||||
// signed by a known CA, but because consul doesn't use DNS names for
|
||||
// node names, we don't verify the certificate DNS names. Since go 1.3
|
||||
// no longer supports this mode of operation, we have to do it
|
||||
// manually.
|
||||
func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
var err error
|
||||
var tlsConn *tls.Conn
|
||||
|
||||
tlsConn = tls.Client(conn, tlsConfig)
|
||||
|
||||
// If crypto/tls is doing verification, there's no need to do
|
||||
// our own.
|
||||
if tlsConfig.InsecureSkipVerify == false {
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
if err = tlsConn.Handshake(); err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The following is lightly-modified from the doFullHandshake
|
||||
// method in crypto/tls's handshake_client.go.
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: tlsConfig.RootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
DNSName: "",
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
|
||||
certs := tlsConn.ConnectionState().PeerCertificates
|
||||
for i, cert := range certs {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
_, err = certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
tlsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tlsConn, err
|
||||
}
|
||||
|
||||
// IncomingTLSConfig generates a TLS configuration for incoming requests
|
||||
func (c *Config) IncomingTLSConfig() (*tls.Config, error) {
|
||||
// Create the tlsConfig
|
||||
tlsConfig := &tls.Config{
|
||||
ClientCAs: x509.NewCertPool(),
|
||||
ClientAuth: tls.NoClientCert,
|
||||
}
|
||||
|
||||
// Parse the CA cert if any
|
||||
err := c.AppendCA(tlsConfig.ClientCAs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add cert/key
|
||||
cert, err := c.KeyPair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if cert != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{*cert}
|
||||
}
|
||||
|
||||
// Check if we require verification
|
||||
if c.VerifyIncoming {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
if c.CAFile == "" {
|
||||
return nil, fmt.Errorf("VerifyIncoming set, and no CA certificate provided!")
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("VerifyIncoming set, and no Cert/Key pair provided!")
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
Loading…
Reference in New Issue