config/lang: make TypeCheck implementable by other nodes
This commit is contained in:
parent
57adfe53f6
commit
4302dbaf2a
|
@ -24,11 +24,21 @@ type TypeCheck struct {
|
||||||
// value is the function to call (which must be registered in the Scope).
|
// value is the function to call (which must be registered in the Scope).
|
||||||
Implicit map[ast.Type]map[ast.Type]string
|
Implicit map[ast.Type]map[ast.Type]string
|
||||||
|
|
||||||
stack []ast.Type
|
// Stack of types. This shouldn't be used directly except by implementations
|
||||||
|
// of TypeCheckNode.
|
||||||
|
Stack []ast.Type
|
||||||
|
|
||||||
err error
|
err error
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TypeCheckNode is the interface that must be implemented by any
|
||||||
|
// ast.Node that wants to support type-checking. If the type checker
|
||||||
|
// encounters a node that doesn't implement this, it will error.
|
||||||
|
type TypeCheckNode interface {
|
||||||
|
TypeCheck(*TypeCheck) (ast.Node, error)
|
||||||
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) Visit(root ast.Node) error {
|
func (v *TypeCheck) Visit(root ast.Node) error {
|
||||||
v.lock.Lock()
|
v.lock.Lock()
|
||||||
defer v.lock.Unlock()
|
defer v.lock.Unlock()
|
||||||
|
@ -42,49 +52,69 @@ func (v *TypeCheck) visit(raw ast.Node) ast.Node {
|
||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var result ast.Node
|
||||||
|
var err error
|
||||||
switch n := raw.(type) {
|
switch n := raw.(type) {
|
||||||
case *ast.Call:
|
case *ast.Call:
|
||||||
v.visitCall(n)
|
tc := &typeCheckCall{n}
|
||||||
|
result, err = tc.TypeCheck(v)
|
||||||
case *ast.Concat:
|
case *ast.Concat:
|
||||||
v.visitConcat(n)
|
tc := &typeCheckConcat{n}
|
||||||
|
result, err = tc.TypeCheck(v)
|
||||||
case *ast.LiteralNode:
|
case *ast.LiteralNode:
|
||||||
v.visitLiteral(n)
|
tc := &typeCheckLiteral{n}
|
||||||
|
result, err = tc.TypeCheck(v)
|
||||||
case *ast.VariableAccess:
|
case *ast.VariableAccess:
|
||||||
v.visitVariableAccess(n)
|
tc := &typeCheckVariableAccess{n}
|
||||||
|
result, err = tc.TypeCheck(v)
|
||||||
default:
|
default:
|
||||||
v.createErr(n, fmt.Sprintf("unknown node: %#v", raw))
|
tc, ok := raw.(TypeCheckNode)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf("unknown node: %#v", raw)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
return raw
|
result, err = tc.TypeCheck(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
pos := raw.Pos()
|
||||||
|
v.err = fmt.Errorf("At column %d, line %d: %s",
|
||||||
|
pos.Column, pos.Line, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) visitCall(n *ast.Call) {
|
type typeCheckCall struct {
|
||||||
|
n *ast.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *typeCheckCall) TypeCheck(v *TypeCheck) (ast.Node, error) {
|
||||||
// Look up the function in the map
|
// Look up the function in the map
|
||||||
function, ok := v.Scope.LookupFunc(n.Func)
|
function, ok := v.Scope.LookupFunc(tc.n.Func)
|
||||||
if !ok {
|
if !ok {
|
||||||
v.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func))
|
return nil, fmt.Errorf("unknown function called: %s", tc.n.Func)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The arguments are on the stack in reverse order, so pop them off.
|
// The arguments are on the stack in reverse order, so pop them off.
|
||||||
args := make([]ast.Type, len(n.Args))
|
args := make([]ast.Type, len(tc.n.Args))
|
||||||
for i, _ := range n.Args {
|
for i, _ := range tc.n.Args {
|
||||||
args[len(n.Args)-1-i] = v.stackPop()
|
args[len(tc.n.Args)-1-i] = v.StackPop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the args
|
// Verify the args
|
||||||
for i, expected := range function.ArgTypes {
|
for i, expected := range function.ArgTypes {
|
||||||
if args[i] != expected {
|
if args[i] != expected {
|
||||||
cn := v.implicitConversion(args[i], expected, n.Args[i])
|
cn := v.ImplicitConversion(args[i], expected, tc.n.Args[i])
|
||||||
if cn != nil {
|
if cn != nil {
|
||||||
n.Args[i] = cn
|
tc.n.Args[i] = cn
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
v.createErr(n, fmt.Sprintf(
|
return nil, fmt.Errorf(
|
||||||
"%s: argument %d should be %s, got %s",
|
"%s: argument %d should be %s, got %s",
|
||||||
n.Func, i+1, expected, args[i]))
|
tc.n.Func, i+1, expected, args[i])
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,75 +124,86 @@ func (v *TypeCheck) visitCall(n *ast.Call) {
|
||||||
for i, t := range args {
|
for i, t := range args {
|
||||||
if t != function.VariadicType {
|
if t != function.VariadicType {
|
||||||
realI := i + len(function.ArgTypes)
|
realI := i + len(function.ArgTypes)
|
||||||
cn := v.implicitConversion(
|
cn := v.ImplicitConversion(
|
||||||
t, function.VariadicType, n.Args[realI])
|
t, function.VariadicType, tc.n.Args[realI])
|
||||||
if cn != nil {
|
if cn != nil {
|
||||||
n.Args[realI] = cn
|
tc.n.Args[realI] = cn
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
v.createErr(n, fmt.Sprintf(
|
return nil, fmt.Errorf(
|
||||||
"%s: argument %d should be %s, got %s",
|
"%s: argument %d should be %s, got %s",
|
||||||
n.Func, realI,
|
tc.n.Func, realI,
|
||||||
function.VariadicType, t))
|
function.VariadicType, t)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return type
|
// Return type
|
||||||
v.stackPush(function.ReturnType)
|
v.StackPush(function.ReturnType)
|
||||||
|
|
||||||
|
return tc.n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) visitConcat(n *ast.Concat) {
|
type typeCheckConcat struct {
|
||||||
|
n *ast.Concat
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *typeCheckConcat) TypeCheck(v *TypeCheck) (ast.Node, error) {
|
||||||
|
n := tc.n
|
||||||
types := make([]ast.Type, len(n.Exprs))
|
types := make([]ast.Type, len(n.Exprs))
|
||||||
for i, _ := range n.Exprs {
|
for i, _ := range n.Exprs {
|
||||||
types[len(n.Exprs)-1-i] = v.stackPop()
|
types[len(n.Exprs)-1-i] = v.StackPop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// All concat args must be strings, so validate that
|
// All concat args must be strings, so validate that
|
||||||
for i, t := range types {
|
for i, t := range types {
|
||||||
if t != ast.TypeString {
|
if t != ast.TypeString {
|
||||||
cn := v.implicitConversion(t, ast.TypeString, n.Exprs[i])
|
cn := v.ImplicitConversion(t, ast.TypeString, n.Exprs[i])
|
||||||
if cn != nil {
|
if cn != nil {
|
||||||
n.Exprs[i] = cn
|
n.Exprs[i] = cn
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
v.createErr(n, fmt.Sprintf(
|
return nil, fmt.Errorf(
|
||||||
"argument %d must be a string", i+1))
|
"argument %d must be a string", i+1)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This always results in type string
|
// This always results in type string
|
||||||
v.stackPush(ast.TypeString)
|
v.StackPush(ast.TypeString)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) visitLiteral(n *ast.LiteralNode) {
|
type typeCheckLiteral struct {
|
||||||
v.stackPush(n.Typex)
|
n *ast.LiteralNode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) visitVariableAccess(n *ast.VariableAccess) {
|
func (tc *typeCheckLiteral) TypeCheck(v *TypeCheck) (ast.Node, error) {
|
||||||
|
v.StackPush(tc.n.Typex)
|
||||||
|
return tc.n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type typeCheckVariableAccess struct {
|
||||||
|
n *ast.VariableAccess
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *typeCheckVariableAccess) TypeCheck(v *TypeCheck) (ast.Node, error) {
|
||||||
// Look up the variable in the map
|
// Look up the variable in the map
|
||||||
variable, ok := v.Scope.LookupVar(n.Name)
|
variable, ok := v.Scope.LookupVar(tc.n.Name)
|
||||||
if !ok {
|
if !ok {
|
||||||
v.createErr(n, fmt.Sprintf(
|
return nil, fmt.Errorf(
|
||||||
"unknown variable accessed: %s", n.Name))
|
"unknown variable accessed: %s", tc.n.Name)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the type to the stack
|
// Add the type to the stack
|
||||||
v.stackPush(variable.Type)
|
v.StackPush(variable.Type)
|
||||||
|
|
||||||
|
return tc.n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) createErr(n ast.Node, str string) {
|
func (v *TypeCheck) ImplicitConversion(
|
||||||
pos := n.Pos()
|
|
||||||
v.err = fmt.Errorf("At column %d, line %d: %s",
|
|
||||||
pos.Column, pos.Line, str)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *TypeCheck) implicitConversion(
|
|
||||||
actual ast.Type, expected ast.Type, n ast.Node) ast.Node {
|
actual ast.Type, expected ast.Type, n ast.Node) ast.Node {
|
||||||
if v.Implicit == nil {
|
if v.Implicit == nil {
|
||||||
return nil
|
return nil
|
||||||
|
@ -186,16 +227,16 @@ func (v *TypeCheck) implicitConversion(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) reset() {
|
func (v *TypeCheck) reset() {
|
||||||
v.stack = nil
|
v.Stack = nil
|
||||||
v.err = nil
|
v.err = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) stackPush(t ast.Type) {
|
func (v *TypeCheck) StackPush(t ast.Type) {
|
||||||
v.stack = append(v.stack, t)
|
v.Stack = append(v.Stack, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TypeCheck) stackPop() ast.Type {
|
func (v *TypeCheck) StackPop() ast.Type {
|
||||||
var x ast.Type
|
var x ast.Type
|
||||||
x, v.stack = v.stack[len(v.stack)-1], v.stack[:len(v.stack)-1]
|
x, v.Stack = v.Stack[len(v.Stack)-1], v.Stack[:len(v.Stack)-1]
|
||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue