From a39d8e8f0dc9f5f25bfb5276ef420ad4bc991416 Mon Sep 17 00:00:00 2001 From: David Pearce Date: Thu, 23 Jan 2025 15:43:05 +1300 Subject: [PATCH] feat: flattern nested arithmetic (#594) --- pkg/hir/lower.go | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 8fa56f8e..74a5b691 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -299,7 +299,17 @@ func extractBodies(es []Expr, schema *mir.Schema) []mir.Expr { func expand(e Expr, schema sc.Schema) []Expr { if p, ok := e.(*Add); ok { return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr { - return &Add{Args: nargs} + var args []Expr + // Flatten nested sums + for _, e := range nargs { + if a, ok := e.(*Add); ok { + args = append(args, a.Args...) + } else { + args = append(args, e) + } + } + // Done + return &Add{Args: args} }, schema) } else if _, ok := e.(*Constant); ok { return []Expr{e} @@ -307,7 +317,17 @@ func expand(e Expr, schema sc.Schema) []Expr { return []Expr{e} } else if p, ok := e.(*Mul); ok { return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr { - return &Mul{Args: nargs} + var args []Expr + // Flatten nested products + for _, e := range nargs { + if a, ok := e.(*Mul); ok { + args = append(args, a.Args...) + } else { + args = append(args, e) + } + } + // Done + return &Mul{Args: args} }, schema) } else if p, ok := e.(*List); ok { ees := make([]Expr, 0)