config/lang: make TypeCheck implementable by other nodes

This commit is contained in:
Mitchell Hashimoto 2015-01-14 21:18:22 -08:00
parent 57adfe53f6
commit 4302dbaf2a
1 changed files with 97 additions and 56 deletions

View File

@ -24,11 +24,21 @@ type TypeCheck struct {
// value is the function to call (which must be registered in the Scope). // value is the function to call (which must be registered in the Scope).
Implicit map[ast.Type]map[ast.Type]string Implicit map[ast.Type]map[ast.Type]string
stack []ast.Type // Stack of types. This shouldn't be used directly except by implementations
// of TypeCheckNode.
Stack []ast.Type
err error err error
lock sync.Mutex lock sync.Mutex
} }
// TypeCheckNode is the interface that must be implemented by any
// ast.Node that wants to support type-checking. If the type checker
// encounters a node that doesn't implement this, it will error.
type TypeCheckNode interface {
TypeCheck(*TypeCheck) (ast.Node, error)
}
func (v *TypeCheck) Visit(root ast.Node) error { func (v *TypeCheck) Visit(root ast.Node) error {
v.lock.Lock() v.lock.Lock()
defer v.lock.Unlock() defer v.lock.Unlock()
@ -42,49 +52,69 @@ func (v *TypeCheck) visit(raw ast.Node) ast.Node {
return raw return raw
} }
var result ast.Node
var err error
switch n := raw.(type) { switch n := raw.(type) {
case *ast.Call: case *ast.Call:
v.visitCall(n) tc := &typeCheckCall{n}
result, err = tc.TypeCheck(v)
case *ast.Concat: case *ast.Concat:
v.visitConcat(n) tc := &typeCheckConcat{n}
result, err = tc.TypeCheck(v)
case *ast.LiteralNode: case *ast.LiteralNode:
v.visitLiteral(n) tc := &typeCheckLiteral{n}
result, err = tc.TypeCheck(v)
case *ast.VariableAccess: case *ast.VariableAccess:
v.visitVariableAccess(n) tc := &typeCheckVariableAccess{n}
result, err = tc.TypeCheck(v)
default: default:
v.createErr(n, fmt.Sprintf("unknown node: %#v", raw)) tc, ok := raw.(TypeCheckNode)
if !ok {
err = fmt.Errorf("unknown node: %#v", raw)
break
} }
return raw result, err = tc.TypeCheck(v)
}
if err != nil {
pos := raw.Pos()
v.err = fmt.Errorf("At column %d, line %d: %s",
pos.Column, pos.Line, err)
}
return result
} }
func (v *TypeCheck) visitCall(n *ast.Call) { type typeCheckCall struct {
n *ast.Call
}
func (tc *typeCheckCall) TypeCheck(v *TypeCheck) (ast.Node, error) {
// Look up the function in the map // Look up the function in the map
function, ok := v.Scope.LookupFunc(n.Func) function, ok := v.Scope.LookupFunc(tc.n.Func)
if !ok { if !ok {
v.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func)) return nil, fmt.Errorf("unknown function called: %s", tc.n.Func)
return
} }
// The arguments are on the stack in reverse order, so pop them off. // The arguments are on the stack in reverse order, so pop them off.
args := make([]ast.Type, len(n.Args)) args := make([]ast.Type, len(tc.n.Args))
for i, _ := range n.Args { for i, _ := range tc.n.Args {
args[len(n.Args)-1-i] = v.stackPop() args[len(tc.n.Args)-1-i] = v.StackPop()
} }
// Verify the args // Verify the args
for i, expected := range function.ArgTypes { for i, expected := range function.ArgTypes {
if args[i] != expected { if args[i] != expected {
cn := v.implicitConversion(args[i], expected, n.Args[i]) cn := v.ImplicitConversion(args[i], expected, tc.n.Args[i])
if cn != nil { if cn != nil {
n.Args[i] = cn tc.n.Args[i] = cn
continue continue
} }
v.createErr(n, fmt.Sprintf( return nil, fmt.Errorf(
"%s: argument %d should be %s, got %s", "%s: argument %d should be %s, got %s",
n.Func, i+1, expected, args[i])) tc.n.Func, i+1, expected, args[i])
return
} }
} }
@ -94,75 +124,86 @@ func (v *TypeCheck) visitCall(n *ast.Call) {
for i, t := range args { for i, t := range args {
if t != function.VariadicType { if t != function.VariadicType {
realI := i + len(function.ArgTypes) realI := i + len(function.ArgTypes)
cn := v.implicitConversion( cn := v.ImplicitConversion(
t, function.VariadicType, n.Args[realI]) t, function.VariadicType, tc.n.Args[realI])
if cn != nil { if cn != nil {
n.Args[realI] = cn tc.n.Args[realI] = cn
continue continue
} }
v.createErr(n, fmt.Sprintf( return nil, fmt.Errorf(
"%s: argument %d should be %s, got %s", "%s: argument %d should be %s, got %s",
n.Func, realI, tc.n.Func, realI,
function.VariadicType, t)) function.VariadicType, t)
return
} }
} }
} }
// Return type // Return type
v.stackPush(function.ReturnType) v.StackPush(function.ReturnType)
return tc.n, nil
} }
func (v *TypeCheck) visitConcat(n *ast.Concat) { type typeCheckConcat struct {
n *ast.Concat
}
func (tc *typeCheckConcat) TypeCheck(v *TypeCheck) (ast.Node, error) {
n := tc.n
types := make([]ast.Type, len(n.Exprs)) types := make([]ast.Type, len(n.Exprs))
for i, _ := range n.Exprs { for i, _ := range n.Exprs {
types[len(n.Exprs)-1-i] = v.stackPop() types[len(n.Exprs)-1-i] = v.StackPop()
} }
// All concat args must be strings, so validate that // All concat args must be strings, so validate that
for i, t := range types { for i, t := range types {
if t != ast.TypeString { if t != ast.TypeString {
cn := v.implicitConversion(t, ast.TypeString, n.Exprs[i]) cn := v.ImplicitConversion(t, ast.TypeString, n.Exprs[i])
if cn != nil { if cn != nil {
n.Exprs[i] = cn n.Exprs[i] = cn
continue continue
} }
v.createErr(n, fmt.Sprintf( return nil, fmt.Errorf(
"argument %d must be a string", i+1)) "argument %d must be a string", i+1)
return
} }
} }
// This always results in type string // This always results in type string
v.stackPush(ast.TypeString) v.StackPush(ast.TypeString)
return n, nil
} }
func (v *TypeCheck) visitLiteral(n *ast.LiteralNode) { type typeCheckLiteral struct {
v.stackPush(n.Typex) n *ast.LiteralNode
} }
func (v *TypeCheck) visitVariableAccess(n *ast.VariableAccess) { func (tc *typeCheckLiteral) TypeCheck(v *TypeCheck) (ast.Node, error) {
v.StackPush(tc.n.Typex)
return tc.n, nil
}
type typeCheckVariableAccess struct {
n *ast.VariableAccess
}
func (tc *typeCheckVariableAccess) TypeCheck(v *TypeCheck) (ast.Node, error) {
// Look up the variable in the map // Look up the variable in the map
variable, ok := v.Scope.LookupVar(n.Name) variable, ok := v.Scope.LookupVar(tc.n.Name)
if !ok { if !ok {
v.createErr(n, fmt.Sprintf( return nil, fmt.Errorf(
"unknown variable accessed: %s", n.Name)) "unknown variable accessed: %s", tc.n.Name)
return
} }
// Add the type to the stack // Add the type to the stack
v.stackPush(variable.Type) v.StackPush(variable.Type)
return tc.n, nil
} }
func (v *TypeCheck) createErr(n ast.Node, str string) { func (v *TypeCheck) ImplicitConversion(
pos := n.Pos()
v.err = fmt.Errorf("At column %d, line %d: %s",
pos.Column, pos.Line, str)
}
func (v *TypeCheck) implicitConversion(
actual ast.Type, expected ast.Type, n ast.Node) ast.Node { actual ast.Type, expected ast.Type, n ast.Node) ast.Node {
if v.Implicit == nil { if v.Implicit == nil {
return nil return nil
@ -186,16 +227,16 @@ func (v *TypeCheck) implicitConversion(
} }
func (v *TypeCheck) reset() { func (v *TypeCheck) reset() {
v.stack = nil v.Stack = nil
v.err = nil v.err = nil
} }
func (v *TypeCheck) stackPush(t ast.Type) { func (v *TypeCheck) StackPush(t ast.Type) {
v.stack = append(v.stack, t) v.Stack = append(v.Stack, t)
} }
func (v *TypeCheck) stackPop() ast.Type { func (v *TypeCheck) StackPop() ast.Type {
var x ast.Type var x ast.Type
x, v.stack = v.stack[len(v.stack)-1], v.stack[:len(v.stack)-1] x, v.Stack = v.Stack[len(v.Stack)-1], v.Stack[:len(v.Stack)-1]
return x return x
} }