Skip to content

Commit

Permalink
flatten if-else
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon committed Jul 17, 2024
1 parent 75e8004 commit f0117bf
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 0 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ require github.com/stretchr/testify v1.9.0

require (
github.com/BurntSushi/toml v1.2.1 // indirect
github.com/fatih/color v1.17.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
golang.org/x/mod v0.19.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
)

require (
Expand Down
11 changes: 11 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
Expand All @@ -10,6 +17,10 @@ golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8=
golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg=
golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
Expand Down
81 changes: 81 additions & 0 deletions internal/fixer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package internal

import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"strings"
)

// TODO: Must flattening the nested unnecessary if-else blocks.

// improveCode refactors the input source code and returns the formatted version.
func improveCode(src []byte) (string, error) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
if err != nil {
return "", err
}

err = refactorAST(file)
if err != nil {
return "", err
}

return formatSource(fset, file)
}

// refactorAST processes the AST to modify specific patterns.
func refactorAST(file *ast.File) error {
ast.Inspect(file, func(n ast.Node) bool {
ifStmt, ok := n.(*ast.IfStmt)
if !ok || ifStmt.Else == nil {
return true
}

blockStmt, ok := ifStmt.Else.(*ast.BlockStmt)
if !ok || len(ifStmt.Body.List) == 0 {
return true
}

_, isReturn := ifStmt.Body.List[len(ifStmt.Body.List)-1].(*ast.ReturnStmt)
if !isReturn {
return true
}

mergeElseIntoIf(file, ifStmt, blockStmt)
ifStmt.Else = nil

return true
})
return nil
}

// mergeElseIntoIf merges the statements of an 'else' block into the enclosing function body.
func mergeElseIntoIf(file *ast.File, ifStmt *ast.IfStmt, blockStmt *ast.BlockStmt) {
for _, list := range file.Decls {
decl, ok := list.(*ast.FuncDecl)
if !ok {
continue
}
for i, stmt := range decl.Body.List {
if ifStmt != stmt {
continue
}
decl.Body.List = append(decl.Body.List[:i+1], append(blockStmt.List, decl.Body.List[i+1:]...)...)
break
}
}
}

// formatSource formats the AST back to source code.
func formatSource(fset *token.FileSet, file *ast.File) (string, error) {
var buf bytes.Buffer
err := format.Node(&buf, fset, file)
if err != nil {
return "", err
}
return strings.TrimRight(buf.String(), "\n"), nil
}
153 changes: 153 additions & 0 deletions internal/fixer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package internal

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestImproveCode(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "don't need to modify",
input: `package main
func foo(x bool) int {
if x {
println("x")
} else {
println("hello")
}
}`,
expected: `package main
func foo(x bool) int {
if x {
println("x")
} else {
println("hello")
}
}`,
},
{
name: "Remove unnecessary else",
input: `
package main
func unnecessaryElse() bool {
if condition {
return true
} else {
return false
}
}`,
expected: `package main
func unnecessaryElse() bool {
if condition {
return true
}
return false
}`,
},
{
name: "Keep necessary else",
input: `
package main
func necessaryElse() int {
if condition {
return 1
} else {
doSomething()
return 2
}
}`,
expected: `package main
func necessaryElse() int {
if condition {
return 1
}
doSomething()
return 2
}`,
},
// {
// name: "Multiple unnecessary else",
// input: `
// package main

// func multipleUnnecessaryElse() int {
// if condition1 {
// return 1
// } else {
// if condition2 {
// return 2
// } else {
// return 3
// }
// }
// }`,
// expected: `package main

// func multipleUnnecessaryElse() int {
// if condition1 {
// return 1
// }
// if condition2 {
// return 2
// }
// return 3
// }
// `,
// },
// {
// name: "Mixed necessary and unnecessary else",
// input: `
// package main

// func mixedElse() int {
// if condition1 {
// return 1
// } else {
// if condition2 {
// doSomething()
// return 2
// } else {
// return 3
// }
// }
// }`,
// expected: `package main

// func mixedElse() int {
// if condition1 {
// return 1
// } else {
// if condition2 {
// doSomething()
// return 2
// }
// return 3
// }
// }
// `,
// },
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
improved, err := improveCode([]byte(tc.input))
require.NoError(t, err)
assert.Equal(t, tc.expected, improved, "Improved code does not match expected output")
})
}
}

0 comments on commit f0117bf

Please sign in to comment.