-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
249 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package internal | ||
|
||
import ( | ||
"bytes" | ||
"go/ast" | ||
"go/format" | ||
"go/parser" | ||
"go/token" | ||
"strings" | ||
) | ||
|
||
// TODO: Must flattening the nested unnecessary if-else blocks. | ||
|
||
// improveCode refactors the input source code and returns the formatted version. | ||
func improveCode(src []byte) (string, error) { | ||
fset := token.NewFileSet() | ||
file, err := parser.ParseFile(fset, "", src, parser.ParseComments) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
err = refactorAST(file) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
return formatSource(fset, file) | ||
} | ||
|
||
// refactorAST processes the AST to modify specific patterns. | ||
func refactorAST(file *ast.File) error { | ||
ast.Inspect(file, func(n ast.Node) bool { | ||
ifStmt, ok := n.(*ast.IfStmt) | ||
if !ok || ifStmt.Else == nil { | ||
return true | ||
} | ||
|
||
blockStmt, ok := ifStmt.Else.(*ast.BlockStmt) | ||
if !ok || len(ifStmt.Body.List) == 0 { | ||
return true | ||
} | ||
|
||
_, isReturn := ifStmt.Body.List[len(ifStmt.Body.List)-1].(*ast.ReturnStmt) | ||
if !isReturn { | ||
return true | ||
} | ||
|
||
mergeElseIntoIf(file, ifStmt, blockStmt) | ||
ifStmt.Else = nil | ||
|
||
return true | ||
}) | ||
return nil | ||
} | ||
|
||
// mergeElseIntoIf merges the statements of an 'else' block into the enclosing function body. | ||
func mergeElseIntoIf(file *ast.File, ifStmt *ast.IfStmt, blockStmt *ast.BlockStmt) { | ||
for _, list := range file.Decls { | ||
decl, ok := list.(*ast.FuncDecl) | ||
if !ok { | ||
continue | ||
} | ||
for i, stmt := range decl.Body.List { | ||
if ifStmt != stmt { | ||
continue | ||
} | ||
decl.Body.List = append(decl.Body.List[:i+1], append(blockStmt.List, decl.Body.List[i+1:]...)...) | ||
break | ||
} | ||
} | ||
} | ||
|
||
// formatSource formats the AST back to source code. | ||
func formatSource(fset *token.FileSet, file *ast.File) (string, error) { | ||
var buf bytes.Buffer | ||
err := format.Node(&buf, fset, file) | ||
if err != nil { | ||
return "", err | ||
} | ||
return strings.TrimRight(buf.String(), "\n"), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
package internal | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestImproveCode(t *testing.T) { | ||
testCases := []struct { | ||
name string | ||
input string | ||
expected string | ||
}{ | ||
{ | ||
name: "don't need to modify", | ||
input: `package main | ||
func foo(x bool) int { | ||
if x { | ||
println("x") | ||
} else { | ||
println("hello") | ||
} | ||
}`, | ||
expected: `package main | ||
func foo(x bool) int { | ||
if x { | ||
println("x") | ||
} else { | ||
println("hello") | ||
} | ||
}`, | ||
}, | ||
{ | ||
name: "Remove unnecessary else", | ||
input: ` | ||
package main | ||
func unnecessaryElse() bool { | ||
if condition { | ||
return true | ||
} else { | ||
return false | ||
} | ||
}`, | ||
expected: `package main | ||
func unnecessaryElse() bool { | ||
if condition { | ||
return true | ||
} | ||
return false | ||
}`, | ||
}, | ||
{ | ||
name: "Keep necessary else", | ||
input: ` | ||
package main | ||
func necessaryElse() int { | ||
if condition { | ||
return 1 | ||
} else { | ||
doSomething() | ||
return 2 | ||
} | ||
}`, | ||
expected: `package main | ||
func necessaryElse() int { | ||
if condition { | ||
return 1 | ||
} | ||
doSomething() | ||
return 2 | ||
}`, | ||
}, | ||
// { | ||
// name: "Multiple unnecessary else", | ||
// input: ` | ||
// package main | ||
|
||
// func multipleUnnecessaryElse() int { | ||
// if condition1 { | ||
// return 1 | ||
// } else { | ||
// if condition2 { | ||
// return 2 | ||
// } else { | ||
// return 3 | ||
// } | ||
// } | ||
// }`, | ||
// expected: `package main | ||
|
||
// func multipleUnnecessaryElse() int { | ||
// if condition1 { | ||
// return 1 | ||
// } | ||
// if condition2 { | ||
// return 2 | ||
// } | ||
// return 3 | ||
// } | ||
// `, | ||
// }, | ||
// { | ||
// name: "Mixed necessary and unnecessary else", | ||
// input: ` | ||
// package main | ||
|
||
// func mixedElse() int { | ||
// if condition1 { | ||
// return 1 | ||
// } else { | ||
// if condition2 { | ||
// doSomething() | ||
// return 2 | ||
// } else { | ||
// return 3 | ||
// } | ||
// } | ||
// }`, | ||
// expected: `package main | ||
|
||
// func mixedElse() int { | ||
// if condition1 { | ||
// return 1 | ||
// } else { | ||
// if condition2 { | ||
// doSomething() | ||
// return 2 | ||
// } | ||
// return 3 | ||
// } | ||
// } | ||
// `, | ||
// }, | ||
} | ||
|
||
for _, tc := range testCases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
improved, err := improveCode([]byte(tc.input)) | ||
require.NoError(t, err) | ||
assert.Equal(t, tc.expected, improved, "Improved code does not match expected output") | ||
}) | ||
} | ||
} |