Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions internal/codegen/config.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
package codegen

// DefaultHeader is used when Config.Header is empty.
const DefaultHeader = "// Code generated by mmapforge. DO NOT EDIT."

// Config holds the global codegen configuration, shared between all generated nodes.
type Config struct {
Target string
// Target is the output directory for generated files.
Target string

// Package is the Go package for generated code.
Package string
Header string

// Header is an optional file header. Defaults to the DefaultHeader constant.
Header string

// Templates specifics optional external templates to execute
// or override the defaults.
Templates []*Template
Hooks []Hook

// Hooks holds an optional list of Hooks to apply on the graph
// before/after code generation.
Hooks []Hook
}

//func (c *Config) header() string {
// if c.Header != "" {
// return c.Header
// }
// return DefaultHeader
//}
func (c *Config) header() string {
if c.Header != "" {
return c.Header
}
return DefaultHeader
}
23 changes: 23 additions & 0 deletions internal/codegen/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package codegen

import "testing"

func TestDefaultHeader(t *testing.T) {
if DefaultHeader == "" {
t.Fatal("DefaultHeader should not be empty")
}
}

func TestConfig_header_Custom(t *testing.T) {
c := &Config{Header: "// Custom header"}
if got := c.header(); got != "// Custom header" {
t.Errorf("header() = %q, want %q", got, "// Custom header")
}
}

func TestConfig_header_Default(t *testing.T) {
c := &Config{}
if got := c.header(); got != DefaultHeader {
t.Errorf("header() = %q, want %q", got, DefaultHeader)
}
}
89 changes: 75 additions & 14 deletions internal/codegen/graph.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
package codegen

import (
"bytes"
"errors"
"fmt"
"go/format"
"os"
"path/filepath"
"text/template/parse"

"github.com/CreditWorthy/mmapforge"
)

// Mockable functions
var computeLayoutFunc = mmapforge.ComputeLayout
var mkdirAllFunc = os.MkdirAll
var writeFileFunc = os.WriteFile
var generateFunc = generate
var writeFormattedFunc = writeFormatted

// Generator is the interface for codegen from a Graph.
type Generator interface {
Generate(*Graph) error
}
Expand All @@ -17,14 +30,18 @@ type GenerateFunc func(*Graph) error
// Generate calls f(g).
func (f GenerateFunc) Generate(g *Graph) error { return f(g) }

// Hook is "generate middleware" - wraps a Generator to inject logic
type Hook func(Generator) Generator

// Graph holds all Type nodes and derive code generation.
type Graph struct {
*Config

Nodes []*Type
}

// NewGraph builds a Graph from parsed schemas and config.
// It computes layouts, builds rich Type/Field objects, and validates.
func NewGraph(c *Config, schemas []StructSchema) (*Graph, error) {
if c.Target == "" {
return nil, fmt.Errorf("mmapforge: codegen: target directory is required")
Expand All @@ -36,7 +53,7 @@ func NewGraph(c *Config, schemas []StructSchema) (*Graph, error) {
}

for _, s := range schemas {
layout, err := mmapforge.ComputeLayout(s.Fields)
layout, err := computeLayoutFunc(s.Fields)
if err != nil {
return nil, fmt.Errorf("mmapforge: compute layout for %s: %w", s.Name, err)
}
Expand All @@ -63,31 +80,75 @@ func NewGraph(c *Config, schemas []StructSchema) (*Graph, error) {
return g, nil
}

// Gen generates all artifacts. Hooks wrap the core generation
func (g *Graph) Gen() error {
var gen Generator = GenerateFunc(generate)
var gen Generator = GenerateFunc(generateFunc)
for i := len(g.Hooks) - 1; i >= 0; i-- {
gen = g.Hooks[i](gen)
}
return gen.Generate(g)
}

func generate(g *Graph) error {
if err := os.MkdirAll(g.Target, os.ModePerm); err != nil {
if err := mkdirAllFunc(g.Target, os.ModePerm); err != nil {
return fmt.Errorf("mmapforge: create target dir: %w", err)
}

//initTemplates()
initTemplates()

for _, ext := range g.Templates {
templates.Funcs(ext.FuncMap)
for _, tmpl := range ext.Templates() {
if parse.IsEmptyTree(tmpl.Tree.Root) {
continue
}
templates = MustParse(templates.AddParseTree(tmpl.Name(), tmpl.Tree))
}
}

//for _, ext := range g.Templates {
// templates.Funcs(ext.FuncMap)
// for _, tmpl := range ext.Templates() {
// if parse.IsEmptyTree(tmpl.Root) {
// continue
// }
// templates = MustParse(templates.AddParseTree(tmpl.Name(), tmpl.Tree))
// }
//}
//
for _, node := range g.Nodes {
for _, tmpl := range TypeTemplates {
if tmpl.Cond != nil && !tmpl.Cond(node) {
continue
}
b := bytes.NewBuffer(nil)
if err := templates.ExecuteTemplate(b, tmpl.Name, node); err != nil {
return fmt.Errorf("mmapforge: execute %q for %s: %w", tmpl.Name, node.Name, err)
}
path := filepath.Join(g.Target, tmpl.Format(node))
if err := writeFormattedFunc(path, b.Bytes()); err != nil {
return err
}
}
}

for _, tmpl := range GraphTemplates {
if tmpl.Skip != nil && tmpl.Skip(g) {
continue
}

b := bytes.NewBuffer(nil)
if err := templates.ExecuteTemplate(b, tmpl.Name, g); err != nil {
return fmt.Errorf("mmapforge: execute %q: %w", tmpl.Name, err)
}
path := filepath.Join(g.Target, tmpl.Format)
if err := writeFormattedFunc(path, b.Bytes()); err != nil {
return err
}
}

return nil
}

// writeFormatted writes Go source to a file, running gofmt first.
func writeFormatted(path string, src []byte) error {
formatted, err := format.Source(src)
if err != nil {
writeErr := writeFileFunc(path, src, 0644)
return errors.Join(
fmt.Errorf("mmapforge: format %s: %w", path, err),
fmt.Errorf("mmapforge: write %s: %w", path, writeErr),
)
}
return writeFileFunc(path, formatted, 0644)
}
Loading
Loading