diff --git a/plugin/client.go b/plugin/client.go new file mode 100644 index 000000000..edc755f2c --- /dev/null +++ b/plugin/client.go @@ -0,0 +1,356 @@ +package plugin + +import ( + "bufio" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/rpc" + "os" + "os/exec" + "strings" + "sync" + "time" + "unicode" +) + +// If this is true, then the "unexpected EOF" panic will not be +// raised throughout the clients. +var Killed = false + +// This is a slice of the "managed" clients which are cleaned up when +// calling Cleanup +var managedClients = make([]*Client, 0, 5) + +// Client handles the lifecycle of a plugin application, determining its +// RPC address, and returning various types of Terraform interface implementations +// across the multi-process communication layer. +type Client struct { + config *ClientConfig + exited bool + doneLogging chan struct{} + l sync.Mutex + address net.Addr + service string + client *rpc.Client +} + +// ClientConfig is the configuration used to initialize a new +// plugin client. After being used to initialize a plugin client, +// that configuration must not be modified again. +type ClientConfig struct { + // The unstarted subprocess for starting the plugin. + Cmd *exec.Cmd + + // Managed represents if the client should be managed by the + // plugin package or not. If true, then by calling CleanupClients, + // it will automatically be cleaned up. Otherwise, the client + // user is fully responsible for making sure to Kill all plugin + // clients. By default the client is _not_ managed. + Managed bool + + // The minimum and maximum port to use for communicating with + // the subprocess. If not set, this defaults to 10,000 and 25,000 + // respectively. + MinPort, MaxPort uint + + // StartTimeout is the timeout to wait for the plugin to say it + // has started successfully. + StartTimeout time.Duration + + // If non-nil, then the stderr of the client will be written to here + // (as well as the log). + Stderr io.Writer +} + +// This makes sure all the managed subprocesses are killed and properly +// logged. This should be called before the parent process running the +// plugins exits. +// +// This must only be called _once_. +func CleanupClients() { + // Set the killed to true so that we don't get unexpected panics + Killed = true + + // Kill all the managed clients in parallel and use a WaitGroup + // to wait for them all to finish up. + var wg sync.WaitGroup + for _, client := range managedClients { + wg.Add(1) + + go func(client *Client) { + client.Kill() + wg.Done() + }(client) + } + + log.Println("waiting for all plugin processes to complete...") + wg.Wait() +} + +// Creates a new plugin client which manages the lifecycle of an external +// plugin and gets the address for the RPC connection. +// +// The client must be cleaned up at some point by calling Kill(). If +// the client is a managed client (created with NewManagedClient) you +// can just call CleanupClients at the end of your program and they will +// be properly cleaned. +func NewClient(config *ClientConfig) (c *Client) { + if config.MinPort == 0 && config.MaxPort == 0 { + config.MinPort = 10000 + config.MaxPort = 25000 + } + + if config.StartTimeout == 0 { + config.StartTimeout = 1 * time.Minute + } + + if config.Stderr == nil { + config.Stderr = ioutil.Discard + } + + c = &Client{config: config} + if config.Managed { + managedClients = append(managedClients, c) + } + + return +} + +// Client returns an RPC client for the plugin. +// +// Subsequent calls to this will return the same RPC client. +func (c *Client) Client() (*rpc.Client, error) { + addr, err := c.Start() + if err != nil { + return nil, err + } + + c.l.Lock() + defer c.l.Unlock() + + if c.client != nil { + return c.client, nil + } + + conn, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + // Make sure to set keep alive so that the connection doesn't die + tcpConn.SetKeepAlive(true) + } + + c.client = rpc.NewClient(conn) + return c.client, nil +} + +// Tells whether or not the underlying process has exited. +func (c *Client) Exited() bool { + c.l.Lock() + defer c.l.Unlock() + return c.exited +} + +// End the executing subprocess (if it is running) and perform any cleanup +// tasks necessary such as capturing any remaining logs and so on. +// +// This method blocks until the process successfully exits. +// +// This method can safely be called multiple times. +func (c *Client) Kill() { + cmd := c.config.Cmd + + if cmd.Process == nil { + return + } + + cmd.Process.Kill() + + // Wait for the client to finish logging so we have a complete log + <-c.doneLogging +} + +// Service returns the name of the service to use. +func (c *Client) Service() (string, error) { + if _, err := c.Start(); err != nil { + return "", err + } + + return c.service, nil +} + +// Starts the underlying subprocess, communicating with it to negotiate +// a port for RPC connections, and returning the address to connect via RPC. +// +// This method is safe to call multiple times. Subsequent calls have no effect. +// Once a client has been started once, it cannot be started again, even if +// it was killed. +func (c *Client) Start() (addr net.Addr, err error) { + c.l.Lock() + defer c.l.Unlock() + + if c.address != nil { + return c.address, nil + } + + c.doneLogging = make(chan struct{}) + + env := []string{ + fmt.Sprintf("%s=%s", MagicCookieKey, MagicCookieValue), + fmt.Sprintf("TF_PLUGIN_MIN_PORT=%d", c.config.MinPort), + fmt.Sprintf("TF_PLUGIN_MAX_PORT=%d", c.config.MaxPort), + } + + stdout_r, stdout_w := io.Pipe() + stderr_r, stderr_w := io.Pipe() + + cmd := c.config.Cmd + cmd.Env = append(cmd.Env, os.Environ()...) + cmd.Env = append(cmd.Env, env...) + cmd.Stdin = os.Stdin + cmd.Stderr = stderr_w + cmd.Stdout = stdout_w + + log.Printf("Starting plugin: %s %#v", cmd.Path, cmd.Args) + err = cmd.Start() + if err != nil { + return + } + + // Make sure the command is properly cleaned up if there is an error + defer func() { + r := recover() + + if err != nil || r != nil { + cmd.Process.Kill() + } + + if r != nil { + panic(r) + } + }() + + // Start goroutine to wait for process to exit + exitCh := make(chan struct{}) + go func() { + // Make sure we close the write end of our stderr/stdout so + // that the readers send EOF properly. + defer stderr_w.Close() + defer stdout_w.Close() + + // Wait for the command to end. + cmd.Wait() + + // Log and make sure to flush the logs write away + log.Printf("%s: plugin process exited\n", cmd.Path) + os.Stderr.Sync() + + // Mark that we exited + close(exitCh) + + // Set that we exited, which takes a lock + c.l.Lock() + defer c.l.Unlock() + c.exited = true + }() + + // Start goroutine that logs the stderr + go c.logStderr(stderr_r) + + // Start a goroutine that is going to be reading the lines + // out of stdout + linesCh := make(chan []byte) + go func() { + defer close(linesCh) + + buf := bufio.NewReader(stdout_r) + for { + line, err := buf.ReadBytes('\n') + if line != nil { + linesCh <- line + } + + if err == io.EOF { + return + } + } + }() + + // Make sure after we exit we read the lines from stdout forever + // so they dont' block since it is an io.Pipe + defer func() { + go func() { + for _ = range linesCh { + } + }() + }() + + // Some channels for the next step + timeout := time.After(c.config.StartTimeout) + + // Start looking for the address + log.Printf("Waiting for RPC address for: %s", cmd.Path) + select { + case <-timeout: + err = errors.New("timeout while waiting for plugin to start") + case <-exitCh: + err = errors.New("plugin exited before we could connect") + case lineBytes := <-linesCh: + // Trim the line and split by "|" in order to get the parts of + // the output. + line := strings.TrimSpace(string(lineBytes)) + parts := strings.SplitN(line, "|", 4) + if len(parts) < 4 { + err = fmt.Errorf("Unrecognized remote plugin message: %s", line) + return + } + + // Test the API version + if parts[0] != APIVersion { + err = fmt.Errorf("Incompatible API version with plugin. "+ + "Plugin version: %s, Ours: %s", parts[0], APIVersion) + return + } + + switch parts[1] { + case "tcp": + addr, err = net.ResolveTCPAddr("tcp", parts[2]) + case "unix": + addr, err = net.ResolveUnixAddr("unix", parts[2]) + default: + err = fmt.Errorf("Unknown address type: %s", parts[1]) + } + + // Grab the services + c.service = parts[3] + } + + c.address = addr + return +} + +func (c *Client) logStderr(r io.Reader) { + bufR := bufio.NewReader(r) + for { + line, err := bufR.ReadString('\n') + if line != "" { + c.config.Stderr.Write([]byte(line)) + + line = strings.TrimRightFunc(line, unicode.IsSpace) + log.Printf("%s: %s", c.config.Cmd.Path, line) + } + + if err == io.EOF { + break + } + } + + // Flag that we've completed logging for others + close(c.doneLogging) +} diff --git a/plugin/client_test.go b/plugin/client_test.go new file mode 100644 index 000000000..9b3486e9a --- /dev/null +++ b/plugin/client_test.go @@ -0,0 +1,153 @@ +package plugin + +import ( + "bytes" + "io/ioutil" + "os" + "strings" + "testing" + "time" +) + +func TestClient(t *testing.T) { + process := helperProcess("mock") + c := NewClient(&ClientConfig{Cmd: process}) + defer c.Kill() + + // Test that it parses the proper address + addr, err := c.Start() + if err != nil { + t.Fatalf("err should be nil, got %s", err) + } + + if addr.Network() != "tcp" { + t.Fatalf("bad: %#v", addr) + } + + if addr.String() != ":1234" { + t.Fatalf("bad: %#v", addr) + } + + service, err := c.Service() + if err != nil { + t.Fatalf("err: %s", err) + } + if service != "foo" { + t.Fatalf("bad: %#v", service) + } + + // Test that it exits properly if killed + c.Kill() + + if process.ProcessState == nil { + t.Fatal("should have process state") + } + + // Test that it knows it is exited + if !c.Exited() { + t.Fatal("should say client has exited") + } +} + +func TestClientStart_badVersion(t *testing.T) { + config := &ClientConfig{ + Cmd: helperProcess("bad-version"), + StartTimeout: 50 * time.Millisecond, + } + + c := NewClient(config) + defer c.Kill() + + _, err := c.Start() + if err == nil { + t.Fatal("err should not be nil") + } +} + +func TestClient_Start_Timeout(t *testing.T) { + config := &ClientConfig{ + Cmd: helperProcess("start-timeout"), + StartTimeout: 50 * time.Millisecond, + } + + c := NewClient(config) + defer c.Kill() + + _, err := c.Start() + if err == nil { + t.Fatal("err should not be nil") + } +} + +func TestClient_Stderr(t *testing.T) { + stderr := new(bytes.Buffer) + process := helperProcess("stderr") + c := NewClient(&ClientConfig{ + Cmd: process, + Stderr: stderr, + }) + defer c.Kill() + + if _, err := c.Start(); err != nil { + t.Fatalf("err: %s", err) + } + + for !c.Exited() { + time.Sleep(10 * time.Millisecond) + } + + if !strings.Contains(stderr.String(), "HELLO\n") { + t.Fatalf("bad log data: '%s'", stderr.String()) + } + + if !strings.Contains(stderr.String(), "WORLD\n") { + t.Fatalf("bad log data: '%s'", stderr.String()) + } +} + +func TestClient_Stdin(t *testing.T) { + // Overwrite stdin for this test with a temporary file + tf, err := ioutil.TempFile("", "packer") + if err != nil { + t.Fatalf("err: %s", err) + } + defer os.Remove(tf.Name()) + defer tf.Close() + + if _, err = tf.WriteString("hello"); err != nil { + t.Fatalf("error: %s", err) + } + + if err = tf.Sync(); err != nil { + t.Fatalf("error: %s", err) + } + + if _, err = tf.Seek(0, 0); err != nil { + t.Fatalf("error: %s", err) + } + + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + os.Stdin = tf + + process := helperProcess("stdin") + c := NewClient(&ClientConfig{Cmd: process}) + defer c.Kill() + + _, err = c.Start() + if err != nil { + t.Fatalf("error: %s", err) + } + + for { + if c.Exited() { + break + } + + time.Sleep(50 * time.Millisecond) + } + + if !process.ProcessState.Success() { + t.Fatal("process didn't exit cleanly") + } +} diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 000000000..038bc3917 --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,92 @@ +package plugin + +import ( + "fmt" + "log" + "os" + "os/exec" + "testing" + "time" + + "github.com/hashicorp/terraform/terraform" +) + +func helperProcess(s ...string) *exec.Cmd { + cs := []string{"-test.run=TestHelperProcess", "--"} + cs = append(cs, s...) + env := []string{ + "GO_WANT_HELPER_PROCESS=1", + "TF_PLUGIN_MIN_PORT=10000", + "TF_PLUGIN_MAX_PORT=25000", + } + + cmd := exec.Command(os.Args[0], cs...) + cmd.Env = append(env, os.Environ()...) + return cmd +} + +// This is not a real test. This is just a helper process kicked off by +// tests. +func TestHelperProcess(*testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + + defer os.Exit(0) + + args := os.Args + for len(args) > 0 { + if args[0] == "--" { + args = args[1:] + break + } + + args = args[1:] + } + + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "No command\n") + os.Exit(2) + } + + cmd, args := args[0], args[1:] + switch cmd { + case "bad-version": + fmt.Printf("%s1|tcp|:1234|foo\n", APIVersion) + <-make(chan int) + case "resource-provider": + err := Serve(new(terraform.MockResourceProvider)) + if err != nil { + log.Printf("[ERR] %s", err) + os.Exit(1) + } + case "invalid-rpc-address": + fmt.Println("lolinvalid") + case "mock": + fmt.Printf("%s|tcp|:1234|foo\n", APIVersion) + <-make(chan int) + case "start-timeout": + time.Sleep(1 * time.Minute) + os.Exit(1) + case "stderr": + fmt.Printf("%s|tcp|:1234|foo\n", APIVersion) + log.Println("HELLO") + log.Println("WORLD") + case "stdin": + fmt.Printf("%s|tcp|:1234|foo\n", APIVersion) + data := make([]byte, 5) + if _, err := os.Stdin.Read(data); err != nil { + log.Printf("stdin read error: %s", err) + os.Exit(100) + } + + if string(data) == "hello" { + os.Exit(0) + } + + os.Exit(1) + default: + fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd) + os.Exit(2) + } +} diff --git a/plugin/resource_provider_test.go b/plugin/resource_provider_test.go new file mode 100644 index 000000000..805a079db --- /dev/null +++ b/plugin/resource_provider_test.go @@ -0,0 +1,23 @@ +package plugin + +import ( + "testing" +) + +func TestResourceProvider(t *testing.T) { + c := NewClient(&ClientConfig{Cmd: helperProcess("resource-provider")}) + defer c.Kill() + + _, err := c.Client() + if err != nil { + t.Fatalf("should not have error: %s", err) + } + + service, err := c.Service() + if err != nil { + t.Fatalf("err: %s", err) + } + if service == "" { + t.Fatal("service should not be blank") + } +} diff --git a/plugin/server.go b/plugin/server.go new file mode 100644 index 000000000..b651c6a69 --- /dev/null +++ b/plugin/server.go @@ -0,0 +1,139 @@ +package plugin + +import ( + "errors" + "fmt" + "io/ioutil" + "log" + "net" + "net/rpc" + "os" + "os/signal" + "runtime" + "strconv" + "sync/atomic" + + tfrpc "github.com/hashicorp/terraform/rpc" +) + +// The APIVersion is outputted along with the RPC address. The plugin +// client validates this API version and will show an error if it doesn't +// know how to speak it. +const APIVersion = "1" + +// The "magic cookie" is used to verify that the user intended to +// actually run this binary. If this cookie isn't present as an +// environmental variable, then we bail out early with an error. +const MagicCookieKey = "TF_PLUGIN_MAGIC_COOKIE" +const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2" + +func Serve(svc interface{}) error { + // First check the cookie + if os.Getenv(MagicCookieKey) != MagicCookieValue { + return errors.New( + "Please do not execute plugins directly. " + + "Terraform will execute these for you.") + } + + // Create the server to serve our interface + server := rpc.NewServer() + + // Register the service + name, err := tfrpc.Register(server, svc) + if err != nil { + return err + } + + // Register a listener so we can accept a connection + listener, err := serverListener() + if err != nil { + return err + } + defer listener.Close() + + // Output the address and service name to stdout + log.Printf("Plugin address: %s %s\n", + listener.Addr().Network(), listener.Addr().String()) + fmt.Printf("%s|%s|%s|%s\n", + APIVersion, + listener.Addr().Network(), + listener.Addr().String(), + name) + os.Stdout.Sync() + + // Accept a connection + log.Println("Waiting for connection...") + conn, err := listener.Accept() + if err != nil { + log.Printf("Error accepting connection: %s\n", err.Error()) + return err + } + + // Eat the interrupts + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + go func() { + var count int32 = 0 + for { + <-ch + newCount := atomic.AddInt32(&count, 1) + log.Printf( + "Received interrupt signal (count: %d). Ignoring.", + newCount) + } + }() + + // Serve a single connection + log.Println("Serving a plugin connection...") + server.ServeConn(conn) + return nil +} + +func serverListener() (net.Listener, error) { + if runtime.GOOS == "windows" { + return serverListener_tcp() + } + + return serverListener_unix() +} + +func serverListener_tcp() (net.Listener, error) { + minPort, err := strconv.ParseInt(os.Getenv("TF_PLUGIN_MIN_PORT"), 10, 32) + if err != nil { + return nil, err + } + + maxPort, err := strconv.ParseInt(os.Getenv("TF_PLUGIN_MAX_PORT"), 10, 32) + if err != nil { + return nil, err + } + + for port := minPort; port <= maxPort; port++ { + address := fmt.Sprintf("127.0.0.1:%d", port) + listener, err := net.Listen("tcp", address) + if err == nil { + return listener, nil + } + } + + return nil, errors.New("Couldn't bind plugin TCP listener") +} + +func serverListener_unix() (net.Listener, error) { + tf, err := ioutil.TempFile("", "tf-plugin") + if err != nil { + return nil, err + } + path := tf.Name() + + // Close the file and remove it because it has to not exist for + // the domain socket. + if err := tf.Close(); err != nil { + return nil, err + } + if err := os.Remove(path); err != nil { + return nil, err + } + + return net.Listen("unix", path) +}