diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 8fa56f8..74a5b69 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -299,7 +299,17 @@ func extractBodies(es []Expr, schema *mir.Schema) []mir.Expr { func expand(e Expr, schema sc.Schema) []Expr { if p, ok := e.(*Add); ok { return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr { - return &Add{Args: nargs} + var args []Expr + // Flatten nested sums + for _, e := range nargs { + if a, ok := e.(*Add); ok { + args = append(args, a.Args...) + } else { + args = append(args, e) + } + } + // Done + return &Add{Args: args} }, schema) } else if _, ok := e.(*Constant); ok { return []Expr{e} @@ -307,7 +317,17 @@ func expand(e Expr, schema sc.Schema) []Expr { return []Expr{e} } else if p, ok := e.(*Mul); ok { return expandWithNaryConstructor(p.Args, func(nargs []Expr) Expr { - return &Mul{Args: nargs} + var args []Expr + // Flatten nested products + for _, e := range nargs { + if a, ok := e.(*Mul); ok { + args = append(args, a.Args...) + } else { + args = append(args, e) + } + } + // Done + return &Mul{Args: args} }, schema) } else if p, ok := e.(*List); ok { ees := make([]Expr, 0)