plugin: Client/Server uses new RPC client/server

This commit is contained in:
Mitchell Hashimoto 2014-09-28 11:19:24 -07:00
parent bc6db2bd1b
commit 04858e1a15
9 changed files with 61 additions and 161 deletions

View File

@ -8,7 +8,6 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/rpc"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@ -16,6 +15,8 @@ import (
"sync" "sync"
"time" "time"
"unicode" "unicode"
tfrpc "github.com/hashicorp/terraform/rpc"
) )
// If this is true, then the "unexpected EOF" panic will not be // If this is true, then the "unexpected EOF" panic will not be
@ -35,8 +36,7 @@ type Client struct {
doneLogging chan struct{} doneLogging chan struct{}
l sync.Mutex l sync.Mutex
address net.Addr address net.Addr
service string client *tfrpc.Client
client *rpc.Client
} }
// ClientConfig is the configuration used to initialize a new // ClientConfig is the configuration used to initialize a new
@ -124,7 +124,7 @@ func NewClient(config *ClientConfig) (c *Client) {
// Client returns an RPC client for the plugin. // Client returns an RPC client for the plugin.
// //
// Subsequent calls to this will return the same RPC client. // Subsequent calls to this will return the same RPC client.
func (c *Client) Client() (*rpc.Client, error) { func (c *Client) Client() (*tfrpc.Client, error) {
addr, err := c.Start() addr, err := c.Start()
if err != nil { if err != nil {
return nil, err return nil, err
@ -137,17 +137,11 @@ func (c *Client) Client() (*rpc.Client, error) {
return c.client, nil return c.client, nil
} }
conn, err := net.Dial(addr.Network(), addr.String()) c.client, err = tfrpc.Dial(addr.Network(), addr.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if tcpConn, ok := conn.(*net.TCPConn); ok {
// Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true)
}
c.client = rpc.NewClient(conn)
return c.client, nil return c.client, nil
} }
@ -177,15 +171,6 @@ func (c *Client) Kill() {
<-c.doneLogging <-c.doneLogging
} }
// Service returns the name of the service to use.
func (c *Client) Service() (string, error) {
if _, err := c.Start(); err != nil {
return "", err
}
return c.service, nil
}
// Starts the underlying subprocess, communicating with it to negotiate // Starts the underlying subprocess, communicating with it to negotiate
// a port for RPC connections, and returning the address to connect via RPC. // a port for RPC connections, and returning the address to connect via RPC.
// //
@ -306,8 +291,8 @@ func (c *Client) Start() (addr net.Addr, err error) {
// Trim the line and split by "|" in order to get the parts of // Trim the line and split by "|" in order to get the parts of
// the output. // the output.
line := strings.TrimSpace(string(lineBytes)) line := strings.TrimSpace(string(lineBytes))
parts := strings.SplitN(line, "|", 4) parts := strings.SplitN(line, "|", 3)
if len(parts) < 4 { if len(parts) < 3 {
err = fmt.Errorf("Unrecognized remote plugin message: %s", line) err = fmt.Errorf("Unrecognized remote plugin message: %s", line)
return return
} }
@ -327,9 +312,6 @@ func (c *Client) Start() (addr net.Addr, err error) {
default: default:
err = fmt.Errorf("Unknown address type: %s", parts[1]) err = fmt.Errorf("Unknown address type: %s", parts[1])
} }
// Grab the services
c.service = parts[3]
} }
c.address = addr c.address = addr

View File

@ -28,14 +28,6 @@ func TestClient(t *testing.T) {
t.Fatalf("bad: %#v", addr) t.Fatalf("bad: %#v", addr)
} }
service, err := c.Service()
if err != nil {
t.Fatalf("err: %s", err)
}
if service != "foo" {
t.Fatalf("bad: %#v", service)
}
// Test that it exits properly if killed // Test that it exits properly if killed
c.Kill() c.Kill()

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
tfrpc "github.com/hashicorp/terraform/rpc"
"github.com/hashicorp/terraform/terraform" "github.com/hashicorp/terraform/terraform"
) )
@ -52,34 +53,31 @@ func TestHelperProcess(*testing.T) {
cmd, args := args[0], args[1:] cmd, args := args[0], args[1:]
switch cmd { switch cmd {
case "bad-version": case "bad-version":
fmt.Printf("%s1|tcp|:1234|foo\n", APIVersion) fmt.Printf("%s1|tcp|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "resource-provider": case "resource-provider":
err := Serve(new(terraform.MockResourceProvider)) Serve(&ServeOpts{
if err != nil { ProviderFunc: testProviderFixed(new(terraform.MockResourceProvider)),
log.Printf("[ERR] %s", err) })
os.Exit(1)
}
case "resource-provisioner": case "resource-provisioner":
err := Serve(new(terraform.MockResourceProvisioner)) Serve(&ServeOpts{
if err != nil { ProvisionerFunc: testProvisionerFixed(
log.Printf("[ERR] %s", err) new(terraform.MockResourceProvisioner)),
os.Exit(1) })
}
case "invalid-rpc-address": case "invalid-rpc-address":
fmt.Println("lolinvalid") fmt.Println("lolinvalid")
case "mock": case "mock":
fmt.Printf("%s|tcp|:1234|foo\n", APIVersion) fmt.Printf("%s|tcp|:1234\n", APIVersion)
<-make(chan int) <-make(chan int)
case "start-timeout": case "start-timeout":
time.Sleep(1 * time.Minute) time.Sleep(1 * time.Minute)
os.Exit(1) os.Exit(1)
case "stderr": case "stderr":
fmt.Printf("%s|tcp|:1234|foo\n", APIVersion) fmt.Printf("%s|tcp|:1234\n", APIVersion)
log.Println("HELLO") log.Println("HELLO")
log.Println("WORLD") log.Println("WORLD")
case "stdin": case "stdin":
fmt.Printf("%s|tcp|:1234|foo\n", APIVersion) fmt.Printf("%s|tcp|:1234\n", APIVersion)
data := make([]byte, 5) data := make([]byte, 5)
if _, err := os.Stdin.Read(data); err != nil { if _, err := os.Stdin.Read(data); err != nil {
log.Printf("stdin read error: %s", err) log.Printf("stdin read error: %s", err)
@ -96,3 +94,15 @@ func TestHelperProcess(*testing.T) {
os.Exit(2) os.Exit(2)
} }
} }
func testProviderFixed(p terraform.ResourceProvider) tfrpc.ProviderFunc {
return func() terraform.ResourceProvider {
return p
}
}
func testProvisionerFixed(p terraform.ResourceProvisioner) tfrpc.ProvisionerFunc {
return func() terraform.ResourceProvisioner {
return p
}
}

View File

@ -1,35 +0,0 @@
package plugin
import (
"os/exec"
tfrpc "github.com/hashicorp/terraform/rpc"
"github.com/hashicorp/terraform/terraform"
)
// ResourceProviderFactory returns a Terraform ResourceProviderFactory
// that executes a plugin and connects to it.
func ResourceProviderFactory(cmd *exec.Cmd) terraform.ResourceProviderFactory {
return func() (terraform.ResourceProvider, error) {
config := &ClientConfig{
Cmd: cmd,
Managed: true,
}
client := NewClient(config)
rpcClient, err := client.Client()
if err != nil {
return nil, err
}
rpcName, err := client.Service()
if err != nil {
return nil, err
}
return &tfrpc.ResourceProvider{
Client: rpcClient,
Name: rpcName,
}, nil
}
}

View File

@ -12,12 +12,4 @@ func TestResourceProvider(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("should not have error: %s", err) t.Fatalf("should not have error: %s", err)
} }
service, err := c.Service()
if err != nil {
t.Fatalf("err: %s", err)
}
if service == "" {
t.Fatal("service should not be blank")
}
} }

View File

@ -1,35 +0,0 @@
package plugin
import (
"os/exec"
tfrpc "github.com/hashicorp/terraform/rpc"
"github.com/hashicorp/terraform/terraform"
)
// ResourceProvisionerFactory returns a Terraform ResourceProvisionerFactory
// that executes a plugin and connects to it.
func ResourceProvisionerFactory(cmd *exec.Cmd) terraform.ResourceProvisionerFactory {
return func() (terraform.ResourceProvisioner, error) {
config := &ClientConfig{
Cmd: cmd,
Managed: true,
}
client := NewClient(config)
rpcClient, err := client.Client()
if err != nil {
return nil, err
}
rpcName, err := client.Service()
if err != nil {
return nil, err
}
return &tfrpc.ResourceProvisioner{
Client: rpcClient,
Name: rpcName,
}, nil
}
}

View File

@ -12,12 +12,4 @@ func TestResourceProvisioner(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("should not have error: %s", err) t.Fatalf("should not have error: %s", err)
} }
service, err := c.Service()
if err != nil {
t.Fatalf("err: %s", err)
}
if service == "" {
t.Fatal("service should not be blank")
}
} }

View File

@ -6,7 +6,6 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/rpc"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
@ -27,7 +26,17 @@ const APIVersion = "2"
const MagicCookieKey = "TF_PLUGIN_MAGIC_COOKIE" const MagicCookieKey = "TF_PLUGIN_MAGIC_COOKIE"
const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2" const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2"
func Serve(svc interface{}) error { // ServeOpts configures what sorts of plugins are served.
type ServeOpts struct {
ProviderFunc tfrpc.ProviderFunc
ProvisionerFunc tfrpc.ProvisionerFunc
}
// Serve serves the plugins given by ServeOpts.
//
// Serve doesn't return until the plugin is done being executed. Any
// errors will be outputted to the log.
func Serve(opts *ServeOpts) {
// First check the cookie // First check the cookie
if os.Getenv(MagicCookieKey) != MagicCookieValue { if os.Getenv(MagicCookieKey) != MagicCookieValue {
fmt.Fprintf(os.Stderr, fmt.Fprintf(os.Stderr,
@ -37,40 +46,30 @@ func Serve(svc interface{}) error {
os.Exit(1) os.Exit(1)
} }
// Create the server to serve our interface
server := rpc.NewServer()
// Register the service
name, err := tfrpc.Register(server, svc)
if err != nil {
return err
}
// Register a listener so we can accept a connection // Register a listener so we can accept a connection
listener, err := serverListener() listener, err := serverListener()
if err != nil { if err != nil {
return err log.Printf("[ERR] plugin init: %s", err)
return
} }
defer listener.Close() defer listener.Close()
// Output the address and service name to stdout // Create the RPC server to dispense
server := &tfrpc.Server{
ProviderFunc: opts.ProviderFunc,
ProvisionerFunc: opts.ProvisionerFunc,
}
// Output the address and service name to stdout so that Terraform
// core can bring it up.
log.Printf("Plugin address: %s %s\n", log.Printf("Plugin address: %s %s\n",
listener.Addr().Network(), listener.Addr().String()) listener.Addr().Network(), listener.Addr().String())
fmt.Printf("%s|%s|%s|%s\n", fmt.Printf("%s|%s|%s\n",
APIVersion, APIVersion,
listener.Addr().Network(), listener.Addr().Network(),
listener.Addr().String(), listener.Addr().String())
name)
os.Stdout.Sync() os.Stdout.Sync()
// Accept a connection
log.Println("Waiting for connection...")
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %s\n", err.Error())
return err
}
// Eat the interrupts // Eat the interrupts
ch := make(chan os.Signal, 1) ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt) signal.Notify(ch, os.Interrupt)
@ -85,10 +84,8 @@ func Serve(svc interface{}) error {
} }
}() }()
// Serve a single connection // Serve
log.Println("Serving a plugin connection...") server.Accept(listener)
server.ServeConn(conn)
return nil
} }
func serverListener() (net.Listener, error) { func serverListener() (net.Listener, error) {

View File

@ -23,6 +23,11 @@ func Dial(network, address string) (*Client, error) {
return nil, err return nil, err
} }
if tcpConn, ok := conn.(*net.TCPConn); ok {
// Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true)
}
return NewClient(conn) return NewClient(conn)
} }