diff --git a/internal/codegen/config.go b/internal/codegen/config.go index f917a37..f779b88 100644 --- a/internal/codegen/config.go +++ b/internal/codegen/config.go @@ -1,19 +1,31 @@ package codegen +// DefaultHeader is used when Config.Header is empty. const DefaultHeader = "// Code generated by mmapforge. DO NOT EDIT." +// Config holds the global codegen configuration, shared between all generated nodes. type Config struct { - Target string + // Target is the output directory for generated files. + Target string + + // Package is the Go package for generated code. Package string - Header string + // Header is an optional file header. Defaults to the DefaultHeader constant. + Header string + + // Templates specifics optional external templates to execute + // or override the defaults. Templates []*Template - Hooks []Hook + + // Hooks holds an optional list of Hooks to apply on the graph + // before/after code generation. + Hooks []Hook } -//func (c *Config) header() string { -// if c.Header != "" { -// return c.Header -// } -// return DefaultHeader -//} +func (c *Config) header() string { + if c.Header != "" { + return c.Header + } + return DefaultHeader +} diff --git a/internal/codegen/config_test.go b/internal/codegen/config_test.go new file mode 100644 index 0000000..c3f91a1 --- /dev/null +++ b/internal/codegen/config_test.go @@ -0,0 +1,23 @@ +package codegen + +import "testing" + +func TestDefaultHeader(t *testing.T) { + if DefaultHeader == "" { + t.Fatal("DefaultHeader should not be empty") + } +} + +func TestConfig_header_Custom(t *testing.T) { + c := &Config{Header: "// Custom header"} + if got := c.header(); got != "// Custom header" { + t.Errorf("header() = %q, want %q", got, "// Custom header") + } +} + +func TestConfig_header_Default(t *testing.T) { + c := &Config{} + if got := c.header(); got != DefaultHeader { + t.Errorf("header() = %q, want %q", got, DefaultHeader) + } +} diff --git a/internal/codegen/graph.go b/internal/codegen/graph.go index 883b3fa..ef92296 100644 --- a/internal/codegen/graph.go +++ b/internal/codegen/graph.go @@ -1,12 +1,25 @@ package codegen import ( + "bytes" + "errors" "fmt" + "go/format" "os" + "path/filepath" + "text/template/parse" "github.com/CreditWorthy/mmapforge" ) +// Mockable functions +var computeLayoutFunc = mmapforge.ComputeLayout +var mkdirAllFunc = os.MkdirAll +var writeFileFunc = os.WriteFile +var generateFunc = generate +var writeFormattedFunc = writeFormatted + +// Generator is the interface for codegen from a Graph. type Generator interface { Generate(*Graph) error } @@ -17,14 +30,18 @@ type GenerateFunc func(*Graph) error // Generate calls f(g). func (f GenerateFunc) Generate(g *Graph) error { return f(g) } +// Hook is "generate middleware" - wraps a Generator to inject logic type Hook func(Generator) Generator +// Graph holds all Type nodes and derive code generation. type Graph struct { *Config Nodes []*Type } +// NewGraph builds a Graph from parsed schemas and config. +// It computes layouts, builds rich Type/Field objects, and validates. func NewGraph(c *Config, schemas []StructSchema) (*Graph, error) { if c.Target == "" { return nil, fmt.Errorf("mmapforge: codegen: target directory is required") @@ -36,7 +53,7 @@ func NewGraph(c *Config, schemas []StructSchema) (*Graph, error) { } for _, s := range schemas { - layout, err := mmapforge.ComputeLayout(s.Fields) + layout, err := computeLayoutFunc(s.Fields) if err != nil { return nil, fmt.Errorf("mmapforge: compute layout for %s: %w", s.Name, err) } @@ -63,8 +80,9 @@ func NewGraph(c *Config, schemas []StructSchema) (*Graph, error) { return g, nil } +// Gen generates all artifacts. Hooks wrap the core generation func (g *Graph) Gen() error { - var gen Generator = GenerateFunc(generate) + var gen Generator = GenerateFunc(generateFunc) for i := len(g.Hooks) - 1; i >= 0; i-- { gen = g.Hooks[i](gen) } @@ -72,22 +90,65 @@ func (g *Graph) Gen() error { } func generate(g *Graph) error { - if err := os.MkdirAll(g.Target, os.ModePerm); err != nil { + if err := mkdirAllFunc(g.Target, os.ModePerm); err != nil { return fmt.Errorf("mmapforge: create target dir: %w", err) } - //initTemplates() + initTemplates() + + for _, ext := range g.Templates { + templates.Funcs(ext.FuncMap) + for _, tmpl := range ext.Templates() { + if parse.IsEmptyTree(tmpl.Tree.Root) { + continue + } + templates = MustParse(templates.AddParseTree(tmpl.Name(), tmpl.Tree)) + } + } - //for _, ext := range g.Templates { - // templates.Funcs(ext.FuncMap) - // for _, tmpl := range ext.Templates() { - // if parse.IsEmptyTree(tmpl.Root) { - // continue - // } - // templates = MustParse(templates.AddParseTree(tmpl.Name(), tmpl.Tree)) - // } - //} - // + for _, node := range g.Nodes { + for _, tmpl := range TypeTemplates { + if tmpl.Cond != nil && !tmpl.Cond(node) { + continue + } + b := bytes.NewBuffer(nil) + if err := templates.ExecuteTemplate(b, tmpl.Name, node); err != nil { + return fmt.Errorf("mmapforge: execute %q for %s: %w", tmpl.Name, node.Name, err) + } + path := filepath.Join(g.Target, tmpl.Format(node)) + if err := writeFormattedFunc(path, b.Bytes()); err != nil { + return err + } + } + } + + for _, tmpl := range GraphTemplates { + if tmpl.Skip != nil && tmpl.Skip(g) { + continue + } + + b := bytes.NewBuffer(nil) + if err := templates.ExecuteTemplate(b, tmpl.Name, g); err != nil { + return fmt.Errorf("mmapforge: execute %q: %w", tmpl.Name, err) + } + path := filepath.Join(g.Target, tmpl.Format) + if err := writeFormattedFunc(path, b.Bytes()); err != nil { + return err + } + } return nil } + +// writeFormatted writes Go source to a file, running gofmt first. +func writeFormatted(path string, src []byte) error { + formatted, err := format.Source(src) + if err != nil { + writeErr := writeFileFunc(path, src, 0644) + return errors.Join( + fmt.Errorf("mmapforge: format %s: %w", path, err), + fmt.Errorf("mmapforge: write %s: %w", path, writeErr), + ) + } + return writeFileFunc(path, formatted, 0644) +} diff --git a/internal/codegen/graph_test.go b/internal/codegen/graph_test.go new file mode 100644 index 0000000..82cb782 --- /dev/null +++ b/internal/codegen/graph_test.go @@ -0,0 +1,643 @@ +package codegen + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/CreditWorthy/mmapforge" +) + +func TestGenerateFunc_Generate(t *testing.T) { + called := false + fn := GenerateFunc(func(_ *Graph) error { + called = true + return nil + }) + err := fn.Generate(&Graph{Config: &Config{}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("function should have been called") + } +} + +func TestGenerateFunc_Generate_Error(t *testing.T) { + fn := GenerateFunc(func(_ *Graph) error { + return errors.New("fail") + }) + if err := fn.Generate(&Graph{Config: &Config{}}); err == nil { + t.Error("expected error") + } +} + +func TestNewGraph_EmptyTarget(t *testing.T) { + _, err := NewGraph(&Config{Target: ""}, nil) + if err == nil { + t.Fatal("expected error for empty target") + } +} + +func TestNewGraph_ComputeLayoutError(t *testing.T) { + orig := computeLayoutFunc + defer func() { computeLayoutFunc = orig }() + + computeLayoutFunc = func(_ []mmapforge.FieldDef) (*mmapforge.RecordLayout, error) { + return nil, errors.New("layout error") + } + + schemas := []StructSchema{ + {Name: "Foo", Fields: []mmapforge.FieldDef{{Name: "X", Type: mmapforge.FieldUint32}}}, + } + _, err := NewGraph(&Config{Target: "/tmp/test"}, schemas) + if err == nil { + t.Fatal("expected error from ComputeLayout") + } +} + +func TestNewGraph_Success_SchemaPackage(t *testing.T) { + orig := computeLayoutFunc + defer func() { computeLayoutFunc = orig }() + + computeLayoutFunc = func(_ []mmapforge.FieldDef) (*mmapforge.RecordLayout, error) { + return &mmapforge.RecordLayout{ + Fields: []mmapforge.FieldLayout{{FieldDef: mmapforge.FieldDef{Name: "X", Type: mmapforge.FieldUint32}}}, + RecordSize: 4, + }, nil + } + + schemas := []StructSchema{ + {Name: "Foo", Package: "mypkg", Fields: []mmapforge.FieldDef{{Name: "X", Type: mmapforge.FieldUint32}}}, + } + g, err := NewGraph(&Config{Target: "/tmp/test"}, schemas) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + found := false + for _, n := range g.Nodes { + if n != nil && n.Name == "Foo" { + found = true + if n.Package != "mypkg" { + t.Errorf("Package = %q, want %q", n.Package, "mypkg") + } + } + } + if !found { + t.Error("Foo node not found") + } +} + +func TestNewGraph_Success_ConfigPackageOverride(t *testing.T) { + orig := computeLayoutFunc + defer func() { computeLayoutFunc = orig }() + + computeLayoutFunc = func(_ []mmapforge.FieldDef) (*mmapforge.RecordLayout, error) { + return &mmapforge.RecordLayout{ + Fields: []mmapforge.FieldLayout{{FieldDef: mmapforge.FieldDef{Name: "X", Type: mmapforge.FieldUint32}}}, + RecordSize: 4, + }, nil + } + + schemas := []StructSchema{ + {Name: "Bar", Package: "original", Fields: []mmapforge.FieldDef{{Name: "X", Type: mmapforge.FieldUint32}}}, + } + g, err := NewGraph(&Config{Target: "/tmp/test", Package: "override"}, schemas) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + found := false + for _, n := range g.Nodes { + if n != nil && n.Name == "Bar" { + found = true + if n.Package != "override" { + t.Errorf("Package = %q, want %q", n.Package, "override") + } + } + } + if !found { + t.Error("Bar node not found") + } +} + +func TestNewGraph_MultipleFields(t *testing.T) { + orig := computeLayoutFunc + defer func() { computeLayoutFunc = orig }() + + computeLayoutFunc = func(fields []mmapforge.FieldDef) (*mmapforge.RecordLayout, error) { + layouts := make([]mmapforge.FieldLayout, len(fields)) + for i, f := range fields { + layouts[i] = mmapforge.FieldLayout{FieldDef: f} + } + return &mmapforge.RecordLayout{ + Fields: layouts, + RecordSize: 8, + }, nil + } + + schemas := []StructSchema{ + { + Name: "Multi", + Package: "pkg", + Fields: []mmapforge.FieldDef{ + {Name: "A", Type: mmapforge.FieldUint32}, + {Name: "B", Type: mmapforge.FieldUint64}, + }, + }, + } + g, err := NewGraph(&Config{Target: "/tmp/test"}, schemas) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, n := range g.Nodes { + if n != nil && n.Name == "Multi" { + if len(n.Fields) != 2 { + t.Errorf("Fields count = %d, want 2", len(n.Fields)) + } + return + } + } + t.Error("Multi node not found") +} + +func TestGraph_Gen_NoHooks(t *testing.T) { + called := false + orig := generateFunc + defer func() { generateFunc = orig }() + + generateFunc = func(_ *Graph) error { + called = true + return nil + } + + g := &Graph{Config: &Config{}} + if err := g.Gen(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Error("generate should have been called") + } +} + +func TestGraph_Gen_WithHooks(t *testing.T) { + var order []string + orig := generateFunc + defer func() { generateFunc = orig }() + + generateFunc = func(_ *Graph) error { + order = append(order, "generate") + return nil + } + + hook1 := func(next Generator) Generator { + return GenerateFunc(func(g *Graph) error { + order = append(order, "hook1-before") + err := next.Generate(g) + order = append(order, "hook1-after") + return err + }) + } + hook2 := func(next Generator) Generator { + return GenerateFunc(func(g *Graph) error { + order = append(order, "hook2-before") + err := next.Generate(g) + order = append(order, "hook2-after") + return err + }) + } + + g := &Graph{Config: &Config{Hooks: []Hook{hook1, hook2}}} + if err := g.Gen(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := []string{"hook1-before", "hook2-before", "generate", "hook2-after", "hook1-after"} + if len(order) != len(expected) { + t.Fatalf("order = %v, want %v", order, expected) + } + for i := range expected { + if order[i] != expected[i] { + t.Errorf("order[%d] = %q, want %q", i, order[i], expected[i]) + } + } +} + +func TestGraph_Gen_HookError(t *testing.T) { + orig := generateFunc + defer func() { generateFunc = orig }() + + generateFunc = func(_ *Graph) error { return nil } + + hook := func(_ Generator) Generator { + return GenerateFunc(func(_ *Graph) error { + return errors.New("hook error") + }) + } + + g := &Graph{Config: &Config{Hooks: []Hook{hook}}} + if err := g.Gen(); err == nil { + t.Error("expected error from hook") + } +} + +func TestGenerate_MkdirAllError(t *testing.T) { + orig := mkdirAllFunc + defer func() { mkdirAllFunc = orig }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { + return errors.New("mkdir error") + } + + g := &Graph{Config: &Config{Target: "/tmp/test"}} + err := generate(g) + if err == nil { + t.Fatal("expected error from MkdirAll") + } +} + +func TestGenerate_TypeTemplate_CondSkips(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + GraphTemplates = []GraphTemplate{} + TypeTemplates = []TypeTemplate{ + { + Name: "store", + Cond: func(_ *Type) bool { return false }, + Format: func(_ *Type) string { return "skipped.go" }, + }, + } + + g := &Graph{ + Config: &Config{Target: t.TempDir()}, + Nodes: []*Type{{Config: &Config{}, Name: "Foo", Package: "pkg"}}, + } + err := generate(g) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGenerate_TypeTemplate_ExecuteError(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + GraphTemplates = []GraphTemplate{} + TypeTemplates = []TypeTemplate{ + { + Name: "nonexistent_template", + Format: func(_ *Type) string { return "out.go" }, + }, + } + + g := &Graph{ + Config: &Config{Target: t.TempDir()}, + Nodes: []*Type{{Config: &Config{}, Name: "Foo", Package: "pkg"}}, + } + err := generate(g) + if err == nil { + t.Fatal("expected error from ExecuteTemplate") + } +} + +func TestGenerate_GraphTemplate_SkipTrue(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + TypeTemplates = []TypeTemplate{} + GraphTemplates = []GraphTemplate{ + { + Name: "skipped", + Skip: func(_ *Graph) bool { return true }, + Format: "skipped.go", + }, + } + + g := &Graph{Config: &Config{Target: t.TempDir()}} + err := generate(g) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGenerate_GraphTemplate_ExecuteError(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + TypeTemplates = []TypeTemplate{} + GraphTemplates = []GraphTemplate{ + { + Name: "nonexistent_graph_template", + Format: "out.go", + }, + } + + g := &Graph{Config: &Config{Target: t.TempDir()}} + err := generate(g) + if err == nil { + t.Fatal("expected error from graph ExecuteTemplate") + } +} + +func TestGenerate_TypeTemplate_WriteFormattedError(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + origWrite := writeFileFunc + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + writeFileFunc = origWrite + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + GraphTemplates = []GraphTemplate{} + + ext := NewTemplate("test") + if _, err := ext.Parse(`{{ define "testtype" }}package main{{ end }}`); err != nil { + t.Fatal(err) + } + + TypeTemplates = []TypeTemplate{ + { + Name: "testtype", + Format: func(_ *Type) string { return "out.go" }, + }, + } + + writeFileFunc = func(_ string, _ []byte, _ os.FileMode) error { + return errors.New("disk full") + } + + g := &Graph{ + Config: &Config{ + Target: t.TempDir(), + Templates: []*Template{ext}, + }, + Nodes: []*Type{{Config: &Config{}, Name: "Foo", Package: "pkg"}}, + } + err := generate(g) + if err == nil { + t.Fatal("expected writeFormatted error from TypeTemplate loop") + } +} + +func TestGenerate_GraphTemplate_WriteFormattedError(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + origWrite := writeFileFunc + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + writeFileFunc = origWrite + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + TypeTemplates = []TypeTemplate{} + + ext := NewTemplate("test") + if _, err := ext.Parse(`{{ define "graphtmpl" }}package main{{ end }}`); err != nil { + t.Fatal(err) + } + + GraphTemplates = []GraphTemplate{ + { + Name: "graphtmpl", + Format: "graph_out.go", + }, + } + + writeFileFunc = func(_ string, _ []byte, _ os.FileMode) error { + return errors.New("disk full") + } + + g := &Graph{Config: &Config{ + Target: t.TempDir(), + Templates: []*Template{ext}, + }} + err := generate(g) + if err == nil { + t.Fatal("expected writeFormatted error from GraphTemplate loop") + } +} + +func TestGenerate_GraphTemplate_Success(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + origWrite := writeFileFunc + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + writeFileFunc = origWrite + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + TypeTemplates = []TypeTemplate{} + + ext := NewTemplate("test") + if _, err := ext.Parse(`{{ define "graphsuccess" }}package main{{ end }}`); err != nil { + t.Fatal(err) + } + + GraphTemplates = []GraphTemplate{ + { + Name: "graphsuccess", + Format: "graph_success.go", + }, + } + + dir := t.TempDir() + writeFileFunc = func(name string, data []byte, perm os.FileMode) error { + return os.WriteFile(name, data, perm) + } + + g := &Graph{Config: &Config{ + Target: dir, + Templates: []*Template{ext}, + }} + err := generate(g) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + path := filepath.Join(dir, "graph_success.go") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("expected graph_success.go to be created") + } +} + +func TestGenerate_ExternalTemplates(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + origWrite := writeFileFunc + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + writeFileFunc = origWrite + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + GraphTemplates = []GraphTemplate{} + + ext := NewTemplate("ext") + if _, err := ext.Parse(`{{ define "exttype" }}package main{{ end }}`); err != nil { + t.Fatal(err) + } + + TypeTemplates = []TypeTemplate{ + { + Name: "exttype", + Format: func(_ *Type) string { return "ext_out.go" }, + }, + } + + dir := t.TempDir() + writeFileFunc = func(name string, data []byte, perm os.FileMode) error { + return os.WriteFile(name, data, perm) + } + + g := &Graph{ + Config: &Config{ + Target: dir, + Templates: []*Template{ext}, + }, + Nodes: []*Type{{Config: &Config{}, Name: "Foo", Package: "pkg"}}, + } + err := generate(g) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + path := filepath.Join(dir, "ext_out.go") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("expected ext_out.go to be created") + } +} + +func TestGenerate_ExternalTemplates_EmptyTree(t *testing.T) { + origMkdir := mkdirAllFunc + origTypes := TypeTemplates + origGraph := GraphTemplates + defer func() { + mkdirAllFunc = origMkdir + TypeTemplates = origTypes + GraphTemplates = origGraph + }() + + mkdirAllFunc = func(_ string, _ os.FileMode) error { return nil } + TypeTemplates = []TypeTemplate{} + GraphTemplates = []GraphTemplate{} + + ext := NewTemplate("ext") + if _, err := ext.Parse(`{{ define "empty" }}{{ end }}`); err != nil { + t.Fatal(err) + } + + g := &Graph{ + Config: &Config{ + Target: t.TempDir(), + Templates: []*Template{ext}, + }, + } + err := generate(g) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWriteFormatted_ValidGo(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "valid.go") + src := []byte("package main\n\nfunc main() {}\n") + if err := writeFormatted(path, src); err != nil { + t.Fatalf("unexpected error: %v", err) + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read error: %v", err) + } + if len(data) == 0 { + t.Error("file should not be empty") + } +} + +func TestWriteFormatted_InvalidGo(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "invalid.go") + src := []byte("not valid go code {{{") + err := writeFormatted(path, src) + if err == nil { + t.Fatal("expected error for invalid Go source") + } + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read error: %v", err) + } + if string(data) != string(src) { + t.Error("should have written raw source on format failure") + } +} + +func TestWriteFormatted_FormatError_WriteAlsoFails(t *testing.T) { + path := "/nonexistent/dir/file.go" + src := []byte("not valid go {{{") + err := writeFormatted(path, src) + if err == nil { + t.Fatal("expected error") + } + errMsg := fmt.Sprintf("%v", err) + if len(errMsg) == 0 { + t.Error("error message should not be empty") + } +} + +func TestWriteFormatted_WriteError(t *testing.T) { + origWrite := writeFileFunc + defer func() { writeFileFunc = origWrite }() + + writeFileFunc = func(_ string, _ []byte, _ os.FileMode) error { + return errors.New("write error") + } + + dir := t.TempDir() + path := filepath.Join(dir, "fail.go") + src := []byte("package main\n\nfunc main() {}\n") + err := writeFormatted(path, src) + if err == nil { + t.Fatal("expected write error") + } +} diff --git a/internal/codegen/parser.go b/internal/codegen/parser.go index 9d0202b..809c332 100644 --- a/internal/codegen/parser.go +++ b/internal/codegen/parser.go @@ -11,13 +11,24 @@ import ( "github.com/CreditWorthy/mmapforge" ) +// StructSchema is the parsed output of a Go struct annotated with +// // mmapforge:schema version=N. It carries everything the emitter type StructSchema struct { - Name string - Package string - Fields []mmapforge.FieldDef + // Name in the Go struct name + Name string + + // Package is the Go package name + Package string + + // Fields are the parsed field definitions. + Fields []mmapforge.FieldDef + + // SchemaVersion is from the version=N directive. SchemaVersion uint32 } +// ParseFile parses a Go source file and extracts all structs annotated +// with // mmapforge:schema version=N. func ParseFile(path string) ([]StructSchema, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) @@ -66,6 +77,7 @@ func ParseFile(path string) ([]StructSchema, error) { return schemas, nil } +// findDirective looks for "mmapforge:schema version=N" in the doc func findDirective(f *ast.File, fset *token.FileSet, gen *ast.GenDecl, declIdx int) (uint32, bool) { if gen.Doc != nil { for _, c := range gen.Doc.List { @@ -91,6 +103,8 @@ func findDirective(f *ast.File, fset *token.FileSet, gen *ast.GenDecl, declIdx i return 0, false } +// parseVersionFromDirective parses "// mmapforge:schema version=N" +// and returns (N, true) on success. func parseVersionFromDirective(text string) (uint32, bool) { text = strings.TrimPrefix(text, "//") text = strings.TrimSpace(text) @@ -112,6 +126,7 @@ func parseVersionFromDirective(text string) (uint32, bool) { return 0, false } +// parseFields extracts mmapforge.FieldDef entries from a struct's AST. func parseFields(st *ast.StructType) ([]mmapforge.FieldDef, error) { fields := make([]mmapforge.FieldDef, 0, len(st.Fields.List)) for _, field := range st.Fields.List { diff --git a/internal/codegen/template.go b/internal/codegen/template.go index 87ae453..1100413 100644 --- a/internal/codegen/template.go +++ b/internal/codegen/template.go @@ -9,14 +9,22 @@ import ( ) //go:embed template/* -var _ embed.FS +var templateDir embed.FS var defaultFuncMap = template.FuncMap{ "lower": strings.ToLower, "upper": strings.ToUpper, } -//var templates *Template +var parseFilesFunc = func(inner *template.Template, filenames ...string) (*template.Template, error) { + return inner.ParseFiles(filenames...) +} + +var addParseTreeFunc = func(inner *template.Template, name string, tree *parse.Tree) (*template.Template, error) { + return inner.AddParseTree(name, tree) +} + +var templates *Template type Template struct { *template.Template @@ -24,9 +32,9 @@ type Template struct { condition func(*Graph) bool } -//func initTemplates() { -// //templates = MustParse(NewTemplate("mmapforge").ParseFS(templateDir, "template/*.tmpl")) -//} +func initTemplates() { + templates = MustParse(NewTemplate("mmapforge").ParseFS(templateDir, "template/*.tmpl")) +} func NewTemplate(name string) *Template { t := &Template{Template: template.New(name)} @@ -73,14 +81,14 @@ func (t *Template) Parse(text string) (*Template, error) { } func (t *Template) ParseFiles(filenames ...string) (*Template, error) { - if _, err := t.Template.ParseFiles(filenames...); err != nil { + if _, err := parseFilesFunc(t.Template, filenames...); err != nil { return nil, err } return t, nil } func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error) { - if _, err := t.Template.AddParseTree(name, tree); err != nil { + if _, err := addParseTreeFunc(t.Template, name, tree); err != nil { return nil, err } return t, nil diff --git a/internal/codegen/template_registry_test.go b/internal/codegen/template_registry_test.go new file mode 100644 index 0000000..b8c5d83 --- /dev/null +++ b/internal/codegen/template_registry_test.go @@ -0,0 +1,118 @@ +package codegen + +import "testing" + +func TestTypeTemplate_Format(t *testing.T) { + if len(TypeTemplates) == 0 { + t.Fatal("TypeTemplates should not be empty") + } + + tmpl := TypeTemplates[0] + + if tmpl.Name != "store" { + t.Errorf("Name = %q, want %q", tmpl.Name, "store") + } + + typ := &Type{Name: "Account"} + got := tmpl.Format(typ) + if got != "account_store.go" { + t.Errorf("Format() = %q, want %q", got, "account_store.go") + } +} + +func TestTypeTemplate_Cond_Nil(t *testing.T) { + tmpl := TypeTemplates[0] + if tmpl.Cond != nil { + t.Error("default store template Cond should be nil") + } +} + +func TestTypeTemplate_Cond_True(t *testing.T) { + tmpl := TypeTemplate{ + Cond: func(_ *Type) bool { return true }, + Format: func(_ *Type) string { return "test.go" }, + Name: "test", + } + if !tmpl.Cond(&Type{}) { + t.Error("Cond should return true") + } + if tmpl.Name != "test" { + t.Errorf("Name = %q, want %q", tmpl.Name, "test") + } + if got := tmpl.Format(&Type{}); got != "test.go" { + t.Errorf("Format = %q, want %q", got, "test.go") + } +} + +func TestTypeTemplate_Cond_False(t *testing.T) { + tmpl := TypeTemplate{ + Cond: func(_ *Type) bool { return false }, + Format: func(_ *Type) string { return "test.go" }, + Name: "test", + } + if tmpl.Cond(&Type{}) { + t.Error("Cond should return false") + } + if tmpl.Name != "test" { + t.Errorf("Name = %q, want %q", tmpl.Name, "test") + } + if got := tmpl.Format(&Type{}); got != "test.go" { + t.Errorf("Format = %q, want %q", got, "test.go") + } +} + +func TestGraphTemplate_Skip_Nil(t *testing.T) { + tmpl := GraphTemplate{ + Name: "graph", + Format: "graph.go", + } + if tmpl.Skip != nil { + t.Error("Skip should be nil") + } + if tmpl.Name != "graph" { + t.Errorf("Name = %q, want %q", tmpl.Name, "graph") + } + if tmpl.Format != "graph.go" { + t.Errorf("Format = %q, want %q", tmpl.Format, "graph.go") + } +} + +func TestGraphTemplate_Skip_True(t *testing.T) { + tmpl := GraphTemplate{ + Name: "graph", + Skip: func(_ *Graph) bool { return true }, + Format: "graph.go", + } + if !tmpl.Skip(&Graph{Config: &Config{}}) { + t.Error("Skip should return true") + } + if tmpl.Name != "graph" { + t.Errorf("Name = %q, want %q", tmpl.Name, "graph") + } + if tmpl.Format != "graph.go" { + t.Errorf("Format = %q, want %q", tmpl.Format, "graph.go") + } +} + +func TestGraphTemplate_Skip_False(t *testing.T) { + tmpl := GraphTemplate{ + Name: "graph", + Skip: func(_ *Graph) bool { return false }, + Format: "graph.go", + } + if tmpl.Skip(&Graph{Config: &Config{}}) { + t.Error("Skip should return false") + } + if tmpl.Name != "graph" { + t.Errorf("Name = %q, want %q", tmpl.Name, "graph") + } + if tmpl.Format != "graph.go" { + t.Errorf("Format = %q, want %q", tmpl.Format, "graph.go") + } +} + +func TestGraphTemplates_Empty(t *testing.T) { + if len(GraphTemplates) != 0 { + t.Errorf("GraphTemplates should be empty, got %d", len(GraphTemplates)) + } +} diff --git a/internal/codegen/template_test.go b/internal/codegen/template_test.go new file mode 100644 index 0000000..37a525e --- /dev/null +++ b/internal/codegen/template_test.go @@ -0,0 +1,263 @@ +package codegen + +import ( + "errors" + "html/template" + "os" + "path/filepath" + "testing" + "testing/fstest" + "text/template/parse" +) + +func TestInitTemplates(t *testing.T) { + templates = nil + initTemplates() + if templates == nil { + t.Fatal("templates should not be nil after initTemplates") + } +} + +func TestNewTemplate(t *testing.T) { + tmpl := NewTemplate("test") + if tmpl == nil { + t.Fatal("NewTemplate returned nil") + } + if tmpl.Template == nil { + t.Fatal("inner template should not be nil") + } + if tmpl.FuncMap == nil { + t.Fatal("FuncMap should not be nil") + } + if _, ok := tmpl.FuncMap["lower"]; !ok { + t.Error("FuncMap missing 'lower'") + } + if _, ok := tmpl.FuncMap["upper"]; !ok { + t.Error("FuncMap missing 'upper'") + } +} + +func TestTemplate_Funcs_MergesNew(t *testing.T) { + tmpl := NewTemplate("test") + custom := template.FuncMap{ + "custom": func() string { return "hi" }, + } + ret := tmpl.Funcs(custom) + if ret != tmpl { + t.Error("Funcs should return the same Template") + } + if _, ok := tmpl.FuncMap["custom"]; !ok { + t.Error("FuncMap missing 'custom'") + } +} + +func TestTemplate_Funcs_DoesNotOverwrite(t *testing.T) { + tmpl := NewTemplate("test") + originalLower := tmpl.FuncMap["lower"] + tmpl.Funcs(template.FuncMap{ + "lower": func(_ string) string { return "override" }, + }) + if tmpl.FuncMap["lower"] == nil { + t.Fatal("lower should still exist") + } + _ = originalLower +} + +func TestTemplate_Funcs_NilFuncMap(t *testing.T) { + tmpl := &Template{Template: template.New("bare")} + tmpl.FuncMap = nil + tmpl.Funcs(template.FuncMap{ + "foo": func() string { return "bar" }, + }) + if tmpl.FuncMap == nil { + t.Fatal("FuncMap should be initialized") + } + if _, ok := tmpl.FuncMap["foo"]; !ok { + t.Error("FuncMap missing 'foo'") + } +} + +func TestMustParse_Success(t *testing.T) { + tmpl := NewTemplate("test") + parsed, err := tmpl.Parse("{{ . }}") + if err != nil { + t.Fatal(err) + } + got := MustParse(parsed, nil) + if got != parsed { + t.Error("MustParse should return the template on success") + } +} + +func TestMustParse_Panics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("MustParse should panic on error") + } + }() + MustParse(nil, errTest) +} + +var errTest = func() error { + return &testError{} +}() + +type testError struct{} + +func (e *testError) Error() string { return "test error" } + +func TestTemplate_ParseFS_Success(t *testing.T) { + fsys := fstest.MapFS{ + "test.tmpl": &fstest.MapFile{Data: []byte(`{{ define "test" }}hello{{ end }}`)}, + } + tmpl := NewTemplate("test") + got, err := tmpl.ParseFS(fsys, "*.tmpl") + if err != nil { + t.Fatalf("ParseFS error: %v", err) + } + if got != tmpl { + t.Error("ParseFS should return the same Template") + } +} + +func TestTemplate_ParseFS_Error(t *testing.T) { + fsys := fstest.MapFS{} + tmpl := NewTemplate("test") + _, err := tmpl.ParseFS(fsys, "nonexistent/*.tmpl") + if err == nil { + t.Error("ParseFS should return error for missing pattern") + } +} + +func TestTemplate_SkipIf(t *testing.T) { + tmpl := NewTemplate("test") + if tmpl.condition != nil { + t.Error("condition should be nil initially") + } + cond := func(_ *Graph) bool { return true } + ret := tmpl.SkipIf(cond) + if ret != tmpl { + t.Error("SkipIf should return the same Template") + } + if tmpl.condition == nil { + t.Error("condition should be set") + } + if !tmpl.condition(&Graph{Config: &Config{}}) { + t.Error("condition should return true") + } +} + +func TestTemplate_Parse_Success(t *testing.T) { + tmpl := NewTemplate("test") + got, err := tmpl.Parse("hello {{ lower . }}") + if err != nil { + t.Fatalf("Parse error: %v", err) + } + if got != tmpl { + t.Error("Parse should return the same Template") + } +} + +func TestTemplate_Parse_Error(t *testing.T) { + tmpl := NewTemplate("test") + _, err := tmpl.Parse("{{ .Broken") + if err == nil { + t.Error("Parse should return error for invalid template") + } +} + +func TestTemplate_ParseFiles_Error(t *testing.T) { + tmpl := NewTemplate("test") + _, err := tmpl.ParseFiles("/nonexistent/file.tmpl") + if err == nil { + t.Error("ParseFiles should return error for missing file") + } +} + +func TestTemplate_AddParseTree_Success(t *testing.T) { + tmpl := NewTemplate("test") + if _, err := tmpl.Parse("base"); err != nil { + t.Fatal(err) + } + + tree := &parse.Tree{ + Name: "sub", + Root: &parse.ListNode{ + NodeType: parse.NodeList, + }, + } + got, err := tmpl.AddParseTree("sub", tree) + if err != nil { + t.Fatalf("AddParseTree error: %v", err) + } + if got != tmpl { + t.Error("AddParseTree should return the same Template") + } +} + +func TestDefaultFuncMap_Lower(t *testing.T) { + fn, ok := defaultFuncMap["lower"] + if !ok { + t.Fatal("defaultFuncMap missing 'lower'") + } + lower, ok := fn.(func(string) string) + if !ok { + t.Fatal("lower is not func(string) string") + } + if got := lower("HELLO"); got != "hello" { + t.Errorf("lower(HELLO) = %q, want %q", got, "hello") + } +} + +func TestDefaultFuncMap_Upper(t *testing.T) { + fn, ok := defaultFuncMap["upper"] + if !ok { + t.Fatal("defaultFuncMap missing 'upper'") + } + upper, ok := fn.(func(string) string) + if !ok { + t.Fatal("upper is not func(string) string") + } + if got := upper("hello"); got != "HELLO" { + t.Errorf("upper(hello) = %q, want %q", got, "HELLO") + } +} + +func TestTemplate_ParseFiles_Success(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.tmpl") + if err := os.WriteFile(path, []byte(`{{ define "hello" }}world{{ end }}`), 0600); err != nil { + t.Fatal(err) + } + + tmpl := NewTemplate("test") + got, err := tmpl.ParseFiles(path) + if err != nil { + t.Fatalf("ParseFiles error: %v", err) + } + if got != tmpl { + t.Error("ParseFiles should return the same Template") + } +} + +func TestTemplate_AddParseTree_Error(t *testing.T) { + tmpl := NewTemplate("test") + orig := addParseTreeFunc + defer func() { addParseTreeFunc = orig }() + + addParseTreeFunc = func(_ *template.Template, _ string, _ *parse.Tree) (*template.Template, error) { + return nil, errors.New("mock add parse tree error") + } + + tree := &parse.Tree{ + Name: "sub", + Root: &parse.ListNode{NodeType: parse.NodeList}, + } + got, err := tmpl.AddParseTree("sub", tree) + if err == nil { + t.Fatal("AddParseTree should return error") + } + if got != nil { + t.Error("AddParseTree should return nil on error") + } +} diff --git a/internal/codegen/type.go b/internal/codegen/type.go index c06eb85..b1a83c0 100644 --- a/internal/codegen/type.go +++ b/internal/codegen/type.go @@ -1,16 +1,258 @@ package codegen -import "github.com/CreditWorthy/mmapforge" +import ( + "fmt" + "strings" + "github.com/CreditWorthy/mmapforge" +) + +// Type represents a single mmapforge schema node — the central object type Type struct { *Config - Name string - Package string - Fields []*Field + + // Name is the Go struct name from the schema + Name string + + // Package is the Go package name for the generated file. + Package string + + // Fields holds the computed field layouts for this type. + Fields []*Field + + // SchemaVersion is the schema migration version. SchemaVersion uint32 - RecordSize uint32 + + // RecordSize is the total byte size of one record. + RecordSize uint32 } +// Field wraps mmapforge.FieldLayout and adds template helper methods. type Field struct { mmapforge.FieldLayout } + +// Header returns the file header for generated code. +func (t *Type) Header() string { + return t.Config.header() +} + +// Label returns the snake_case name of the type +func (t *Type) Label() string { + return strings.ToLower(t.Name) +} + +// StoreName returns the generated store struct name +func (t *Type) StoreName() string { + return t.Name + "Store" +} + +// RecordName returns the generated record struct name +func (t *Type) RecordName() string { + return t.Name + "Record" +} + +// LayoutFuncName returns the name of the Layout() func +func (t *Type) LayoutFuncName() string { + return t.Name + "Layout" +} + +// NewStoreFuncName returns the name for CreateStore +func (t *Type) NewStoreFuncName() string { + return "New" + t.Name + "Store" +} + +// OpenStoreFuncName returns the name for OpenStore +func (t *Type) OpenStoreFuncName() string { + return "Open" + t.Name + "Store" +} + +// Receiver returns a short receiver variable name for store methods. +func (t *Type) Receiver() string { + return "s" +} + +// HasStringField reports if any field is a string. +func (t *Type) HasStringField() bool { + for _, f := range t.Fields { + if f.IsString() { + return true + } + } + return false +} + +// HasBytesField reports if any field is a bytes field. +func (t *Type) HasBytesField() bool { + for _, f := range t.Fields { + if f.IsBytes() { + return true + } + } + return false +} + +// HasVarLenField reports if any field is variable-length (string or bytes). +func (t *Type) HasVarLenField() bool { + return t.HasStringField() || t.HasBytesField() +} + +// GoType returns the Go type string for this field. +func (f *Field) GoType() string { + switch f.Type { + case mmapforge.FieldBool: + return "bool" + case mmapforge.FieldInt8: + return "int8" + case mmapforge.FieldUint8: + return "uint8" + case mmapforge.FieldInt16: + return "int16" + case mmapforge.FieldUint16: + return "uint16" + case mmapforge.FieldInt32: + return "int32" + case mmapforge.FieldUint32: + return "uint32" + case mmapforge.FieldInt64: + return "int64" + case mmapforge.FieldUint64: + return "uint64" + case mmapforge.FieldFloat32: + return "float32" + case mmapforge.FieldFloat64: + return "float64" + case mmapforge.FieldString: + return "string" + case mmapforge.FieldBytes: + return "[]byte" + default: + return "unknown" + } +} + +// GetterName returns the name for the getter method +func (f *Field) GetterName() string { + return "Get" + f.GoName +} + +// SetterName returns the name for the setter method +func (f *Field) SetterName() string { + return "Set" + f.GoName +} + +// IsString reports if the field is a string. +func (f *Field) IsString() bool { + return f.Type == mmapforge.FieldString +} + +// IsBytes reports if the field is a []byte. +func (f *Field) IsBytes() bool { + return f.Type == mmapforge.FieldBytes +} + +// IsVarLen reports if the field is variable-length. +func (f *Field) IsVarLen() bool { + return f.IsString() || f.IsBytes() +} + +// IsNumeric reports if the field is a numeric type. +func (f *Field) IsNumeric() bool { + switch f.Type { + case mmapforge.FieldInt8, mmapforge.FieldUint8, + mmapforge.FieldInt16, mmapforge.FieldUint16, + mmapforge.FieldInt32, mmapforge.FieldUint32, + mmapforge.FieldInt64, mmapforge.FieldUint64, + mmapforge.FieldFloat32, mmapforge.FieldFloat64: + return true + default: + return false + } +} + +// IsBool reports if the field is a bool. +func (f *Field) IsBool() bool { + return f.Type == mmapforge.FieldBool +} + +// TypeConstant returns the fmmap.FieldType integer for template use. +func (f *Field) TypeConstant() int { + return int(f.Type) +} + +// ReadCall returns the Store.Read* method call expression for this field. +func (f *Field) ReadCall() string { + switch f.Type { + case mmapforge.FieldBool: + return fmt.Sprintf("s.ReadBool(idx, %d)", f.Offset) + case mmapforge.FieldInt8: + return fmt.Sprintf("s.ReadInt8(idx, %d)", f.Offset) + case mmapforge.FieldUint8: + return fmt.Sprintf("s.ReadUint8(idx, %d)", f.Offset) + case mmapforge.FieldInt16: + return fmt.Sprintf("s.ReadInt16(idx, %d)", f.Offset) + case mmapforge.FieldUint16: + return fmt.Sprintf("s.ReadUint16(idx, %d)", f.Offset) + case mmapforge.FieldInt32: + return fmt.Sprintf("s.ReadInt32(idx, %d)", f.Offset) + case mmapforge.FieldUint32: + return fmt.Sprintf("s.ReadUint32(idx, %d)", f.Offset) + case mmapforge.FieldInt64: + return fmt.Sprintf("s.ReadInt64(idx, %d)", f.Offset) + case mmapforge.FieldUint64: + return fmt.Sprintf("s.ReadUint64(idx, %d)", f.Offset) + case mmapforge.FieldFloat32: + return fmt.Sprintf("s.ReadFloat32(idx, %d)", f.Offset) + case mmapforge.FieldFloat64: + return fmt.Sprintf("s.ReadFloat64(idx, %d)", f.Offset) + case mmapforge.FieldString: + return fmt.Sprintf("s.ReadString(idx, %d, %d, %d)", f.Offset, f.Size, f.MaxSize) + case mmapforge.FieldBytes: + return fmt.Sprintf("s.ReadBytes(idx, %d, %d, %d)", f.Offset, f.Size, f.MaxSize) + default: + return "nil, nil // unsupported type" + } +} + +// WriteCall returns the Store.Write* method call using "val" as the value arg. +func (f *Field) WriteCall() string { + return f.writeCallWith("val") +} + +// WriteCallRec returns the Store.Write* method call using "rec." as the value. +func (f *Field) WriteCallRec() string { + return f.writeCallWith("rec." + f.GoName) +} + +func (f *Field) writeCallWith(val string) string { + switch f.Type { + case mmapforge.FieldBool: + return fmt.Sprintf("s.WriteBool(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldInt8: + return fmt.Sprintf("s.WriteInt8(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldUint8: + return fmt.Sprintf("s.WriteUint8(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldInt16: + return fmt.Sprintf("s.WriteInt16(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldUint16: + return fmt.Sprintf("s.WriteUint16(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldInt32: + return fmt.Sprintf("s.WriteInt32(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldUint32: + return fmt.Sprintf("s.WriteUint32(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldInt64: + return fmt.Sprintf("s.WriteInt64(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldUint64: + return fmt.Sprintf("s.WriteUint64(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldFloat32: + return fmt.Sprintf("s.WriteFloat32(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldFloat64: + return fmt.Sprintf("s.WriteFloat64(idx, %d, %s)", f.Offset, val) + case mmapforge.FieldString: + return fmt.Sprintf("s.WriteString(idx, %d, %d, %d, %s)", f.Offset, f.Size, f.MaxSize, val) + case mmapforge.FieldBytes: + return fmt.Sprintf("s.WriteBytes(idx, %d, %d, %d, %s)", f.Offset, f.Size, f.MaxSize, val) + default: + return "nil // unsupported type" + } +} diff --git a/store.go b/store.go index 1aaeeb4..759a554 100644 --- a/store.go +++ b/store.go @@ -15,6 +15,7 @@ const initialCapacity = 64 var statFileFunc = func(f *os.File) (os.FileInfo, error) { return f.Stat() } var encodeHeaderFunc = EncodeHeader +// Store is the base mmap-backed record store. type Store struct { region *Region layout *RecordLayout @@ -27,6 +28,7 @@ type Store struct { writable bool } +// CreateStore creates a new mmapforge file at path with the given layout and schema version. func CreateStore(path string, layout *RecordLayout, schemaVersion uint32) (*Store, error) { f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644) if err != nil { @@ -76,6 +78,7 @@ func CreateStore(path string, layout *RecordLayout, schemaVersion uint32) (*Stor return s, nil } +// OpenStore opens an existing mmapforge file and validates the schema hash. func OpenStore(path string, layout *RecordLayout) (*Store, error) { f, err := os.OpenFile(path, os.O_RDWR, 0) if err != nil { @@ -141,6 +144,7 @@ func OpenStore(path string, layout *RecordLayout) (*Store, error) { return s, nil } +// Close syncs and closes the store. All references into store memory become invalid. func (s *Store) Close() error { if s.region == nil { return fmt.Errorf("mmapforge: close %s: %w", s.path, ErrClosed) @@ -169,6 +173,7 @@ func (s *Store) Close() error { return err } +// Sync flushes the header and dirty pages to disk. func (s *Store) Sync() error { if s.region == nil { return fmt.Errorf("mmapforge: sync %s: %w", s.path, ErrClosed) @@ -179,6 +184,7 @@ func (s *Store) Sync() error { return s.region.Sync() } +// Len returns the number of records in the store. func (s *Store) Len() int { v := s.recordCountPtr.Load() if v > uint64(math.MaxInt) { @@ -187,6 +193,7 @@ func (s *Store) Len() int { return int(v) } +// Cap returns how many records fit in the current file mapping. func (s *Store) Cap() int { v := s.capacityPtr.Load() if v > uint64(math.MaxInt) { @@ -195,6 +202,7 @@ func (s *Store) Cap() int { return int(v) } +// Append adds a new zero-filled record and returns its index. func (s *Store) Append() (int, error) { if s.region == nil { return 0, fmt.Errorf("mmapforge: append %s: %w", s.path, ErrClosed) @@ -220,11 +228,13 @@ func (s *Store) Append() (int, error) { return int(idx), nil } +// flushHeader writes the live record count back to the header and encodes it. func (s *Store) flushHeader() error { s.header.RecordCount = s.recordCountPtr.Load() return encodeHeaderFunc(s.region.Slice(0, HeaderSize), s.header) } +// grow doubles the capacity of the store. func (s *Store) grow() error { newCap := s.capacityPtr.Load() * 2 if newCap == 0 {