Skip to content

Commit 2b6f48a

Browse files
committed
Fix up catch
Signed-off-by: James Hamlin <jfhamlin@gmail.com>
1 parent 693d020 commit 2b6f48a

5 files changed

Lines changed: 89 additions & 82 deletions

File tree

pkg/codegen/codegen.go

Lines changed: 36 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ type Generator struct {
3434
w io.Writer
3535
varScopes []varScope // stack of variable scopes
3636
recurStack []recurContext // stack of recur contexts for nested loops
37+
38+
imports map[string]bool // set of imported packages to avoid duplicates
3739
}
3840

3941
// New creates a new code generator
@@ -43,6 +45,7 @@ func New(w io.Writer) *Generator {
4345
w: w,
4446
varScopes: []varScope{{nextNum: 0, names: make(map[string]string)}},
4547
recurStack: []recurContext{},
48+
imports: make(map[string]bool),
4649
}
4750
}
4851

@@ -53,41 +56,6 @@ func (g *Generator) Generate(ns *lang.Namespace) error {
5356
var buf bytes.Buffer
5457
g.w = &buf
5558

56-
// Check if we need fmt import (for functions with arity checks)
57-
needsFmt := false
58-
mappings := ns.Mappings()
59-
60-
// Only check vars that are interned in this namespace
61-
for seq := mappings.Seq(); seq != nil; seq = seq.Next() {
62-
entry := seq.First()
63-
name, ok := lang.First(entry).(*lang.Symbol)
64-
if !ok {
65-
continue
66-
}
67-
second, _ := lang.Nth(entry, 1)
68-
vr, ok := second.(*lang.Var)
69-
if !ok {
70-
continue
71-
}
72-
73-
// Skip non-interned mappings
74-
if !(vr.Namespace() == ns && lang.Equals(vr.Symbol(), name)) {
75-
continue
76-
}
77-
78-
if vr.IsBound() {
79-
if _, ok := vr.Get().(*runtime.Fn); ok {
80-
needsFmt = true
81-
break
82-
}
83-
}
84-
}
85-
86-
// Write package header
87-
if err := g.writeHeader(needsFmt); err != nil {
88-
return err
89-
}
90-
9159
g.writef("func init() {\n")
9260

9361
g.writef(" ns := lang.FindOrCreateNamespace(lang.NewSymbol(\"%s\"))\n", ns.Name().String())
@@ -96,6 +64,7 @@ func (g *Generator) Generate(ns *lang.Namespace) error {
9664
// 1. Iterate through ns.Mappings()
9765
// 2. Generate Go code for each var
9866
// 3. Create initialization functions
67+
mappings := ns.Mappings()
9968
for seq := mappings.Seq(); seq != nil; seq = seq.Next() {
10069
entry := seq.First()
10170
name, ok := lang.First(entry).(*lang.Symbol)
@@ -119,8 +88,12 @@ func (g *Generator) Generate(ns *lang.Namespace) error {
11988

12089
g.writef("}\n")
12190

91+
// Write package header
92+
sourceBytes := []byte(g.header())
93+
sourceBytes = append(sourceBytes, buf.Bytes()...)
94+
12295
// Format the generated code
123-
formatted, err := format.Source(buf.Bytes())
96+
formatted, err := format.Source(sourceBytes)
12497
if err != nil {
12598
// If formatting fails, write the unformatted code with the error
12699
return fmt.Errorf("formatting failed: %w\n\nGenerated code:\n%s", err, buf.String())
@@ -273,6 +246,7 @@ func (g *Generator) generateFn(fn *runtime.Fn) string {
273246

274247
g.writef("%s := lang.IFnFunc(func(args ...any) any {\n", fnVar)
275248

249+
g.addImport("fmt") // Import fmt for error formatting
276250
// Check arity
277251
g.writef(" if len(args) != %d {\n", methodNode.FixedArity)
278252
g.writef(" panic(lang.NewIllegalArgumentError(\"wrong number of arguments (\" + fmt.Sprint(len(args)) + \")\"))\n")
@@ -373,7 +347,6 @@ func (g *Generator) generateASTNode(node *ast.Node) string {
373347
// OpSet
374348
// OpLetFn
375349
// OpQuote
376-
// OpGoBuiltin
377350
// OpGo
378351
// OpHostCall
379352
// OpHostInterop
@@ -407,10 +380,7 @@ func (g *Generator) generateASTNode(node *ast.Node) string {
407380
case ast.OpRecur:
408381
return g.generateRecur(node)
409382
case ast.OpGoBuiltin:
410-
// For now, just return a reference to the go type
411-
// This is used for catch clauses with go/any
412-
goBuiltinNode := node.Sub.(*ast.GoBuiltinNode)
413-
return fmt.Sprintf("lang.NewSymbol(\"%s\")", goBuiltinNode.Sym.Name())
383+
return g.generateGoBuiltin(node)
414384
default:
415385
fmt.Printf("Generating code for AST node: %T %+v\n", node.Sub, node.Sub)
416386
panic(fmt.Sprintf("unsupported AST node type %T", node.Sub))
@@ -720,9 +690,8 @@ func (g *Generator) generateTry(node *ast.Node) string {
720690
g.writef("} else ")
721691
}
722692

723-
// Simple implementation: check for "any" or assume it catches
724-
// In a full implementation, we'd need to check types properly
725-
g.writef("if true { // TODO: implement catchMatches(r, %s)\n", classExpr)
693+
// Check if the exception matches this catch type
694+
g.writef("if lang.CatchMatches(r, %s) {\n", classExpr)
726695

727696
// Create new scope for catch binding
728697
g.pushVarScope()
@@ -758,25 +727,39 @@ func (g *Generator) generateTry(node *ast.Node) string {
758727
return resultVar
759728
}
760729

730+
func (g *Generator) generateGoBuiltin(node *ast.Node) string {
731+
goBuiltinNode := node.Sub.(*ast.GoBuiltinNode)
732+
sym := goBuiltinNode.Sym
733+
734+
_, ok := lang.Builtins[sym.Name()]
735+
if !ok {
736+
panic(fmt.Sprintf("unknown Go builtin: %s", sym.Name()))
737+
}
738+
739+
return "lang.Builtins[\"" + sym.Name() + "\"]"
740+
}
741+
761742
////////////////////////////////////////////////////////////////////////////////
762743

763-
func (g *Generator) writeHeader(needsFmt bool) error {
744+
func (g *Generator) addImport(pkg string) {
745+
g.imports[pkg] = true
746+
}
747+
748+
func (g *Generator) header() string {
764749
header := `// Code generated by glojure codegen. DO NOT EDIT.
765750
766751
package generated
767752
768753
import (
754+
"github.com/glojurelang/glojure/pkg/lang"
769755
`
770-
if needsFmt {
771-
header += ` "fmt"
772-
`
756+
757+
for pkg := range g.imports {
758+
header += fmt.Sprintf(" \"%s\"\n", pkg)
773759
}
774-
header += ` "github.com/glojurelang/glojure/pkg/lang"
775-
)
776760

777-
`
778-
_, err := io.WriteString(g.w, header)
779-
return err
761+
header += ")\n"
762+
return header
780763
}
781764

782765
func (g *Generator) writef(format string, args ...any) error {

pkg/codegen/testdata/codegen/test/try_advanced.go

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/codegen/testdata/codegen/test/try_basic.go

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/lang/catch.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package lang
2+
3+
import (
4+
"errors"
5+
"reflect"
6+
)
7+
8+
var (
9+
errorType = reflect.TypeOf((*error)(nil)).Elem()
10+
)
11+
12+
// CatchMatches checks if a recovered panic value matches an expected catch type.
13+
// This implements the semantics of Clojure's try/catch matching.
14+
func CatchMatches(r, expect any) bool {
15+
if IsNil(expect) {
16+
return false
17+
}
18+
19+
// Special case: the symbol "any" catches everything (for go/any)
20+
if sym, ok := expect.(*Symbol); ok && sym.Name() == "any" {
21+
return true
22+
}
23+
24+
// If expect is an error type, check if r is an instance of it
25+
if rErr, ok := r.(error); ok {
26+
if expectTyp, ok := expect.(reflect.Type); ok && expectTyp.Implements(errorType) {
27+
expectVal := reflect.New(expectTyp).Elem().Interface().(error)
28+
if errors.Is(rErr, expectVal) {
29+
return true
30+
}
31+
}
32+
}
33+
34+
// General type check
35+
if expectTyp, ok := expect.(reflect.Type); ok {
36+
return reflect.TypeOf(r).AssignableTo(expectTyp)
37+
}
38+
39+
// For interface{} type (go/any), catch everything
40+
if expectTyp, ok := expect.(reflect.Type); ok && expectTyp.Kind() == reflect.Interface && expectTyp.NumMethod() == 0 {
41+
return true
42+
}
43+
44+
return false
45+
}

pkg/runtime/evalast.go

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -689,27 +689,6 @@ func (env *environment) EvalASTNew(n *ast.Node) (interface{}, error) {
689689
return reflect.New(classValTyp).Interface(), nil
690690
}
691691

692-
var (
693-
errorType = reflect.TypeOf((*error)(nil)).Elem()
694-
)
695-
696-
func catchMatches(r, expect any) bool {
697-
if lang.IsNil(expect) {
698-
return false
699-
}
700-
701-
// if expect is an error type, check if r is an instance of it
702-
if rErr, ok := r.(error); ok {
703-
if expectTyp, ok := expect.(reflect.Type); ok && expectTyp.Implements(errorType) {
704-
expectVal := reflect.New(expectTyp).Elem().Interface().(error)
705-
if errors.Is(rErr, expectVal) {
706-
return true
707-
}
708-
}
709-
}
710-
711-
return reflect.TypeOf(r).AssignableTo(expect.(reflect.Type))
712-
}
713692

714693
func (env *environment) EvalASTTry(n *ast.Node) (res interface{}, err error) {
715694
tryNode := n.Sub.(*ast.TryNode)
@@ -735,7 +714,7 @@ func (env *environment) EvalASTTry(n *ast.Node) (res interface{}, err error) {
735714
panic(classErr)
736715
}
737716

738-
if !catchMatches(r, classVal) {
717+
if !lang.CatchMatches(r, classVal) {
739718
continue
740719
}
741720

0 commit comments

Comments
 (0)