From 2ca609f38591b7a2097ce4e60fb186ec31d052ab Mon Sep 17 00:00:00 2001 From: lenbo Date: Sat, 6 Jul 2019 18:49:29 +0800 Subject: [PATCH] Feature/1.1.0 (#14) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor & use go module * 调整菜单编排方式 * 增加分组折叠功能 * 增加别名 * 支持快速登录 * 优化帮助菜单输出 * 增加版本检测功能 * 一键安装脚本 * 增加日志保存功能 * 记录log * 删除功能支持ctrl d退出 * 增加文件上传/下载功能 --- .gitignore | 4 +- README.md | 91 ++---- build.sh | 6 +- config.example.json | 52 ++-- core/app.go | 383 ------------------------- core/params.go | 35 --- core/print.go | 39 --- go.mod | 11 + go.sum | 16 ++ install | 21 ++ main.go | 88 ------ src/app/app.go | 59 ++++ src/app/config.go | 156 ++++++++++ src/app/handle_add.go | 41 +++ src/app/handle_edit.go | 29 ++ src/app/handle_remove.go | 34 +++ src/app/io_client.go | 62 ++++ src/app/load_config.go | 31 ++ src/app/scan.go | 112 ++++++++ {core => src/app}/server.go | 309 +++++++++++--------- src/app/server_edit.go | 61 ++++ src/app/show_cp.go | 353 +++++++++++++++++++++++ src/app/show_help.go | 32 +++ src/app/show_menu.go | 74 +++++ src/app/show_servers.go | 86 ++++++ src/app/show_upgrade.go | 294 +++++++++++++++++++ src/app/show_version.go | 10 + src/utils/clean.go | 21 ++ src/utils/error_assert.go | 8 + core/util.go => src/utils/file_path.go | 69 ++--- core/log.go => src/utils/logger.go | 8 +- src/utils/printer.go | 64 +++++ src/utils/scan.go | 13 + src/utils/str.go | 46 +++ 34 files changed, 1883 insertions(+), 835 deletions(-) mode change 100755 => 100644 build.sh delete mode 100644 core/app.go delete mode 100644 core/params.go delete mode 100644 core/print.go create mode 100644 go.mod create mode 100644 go.sum create mode 100755 install delete mode 100755 main.go create mode 100644 src/app/app.go create mode 100644 src/app/config.go create mode 100644 src/app/handle_add.go create mode 100644 src/app/handle_edit.go create mode 100644 src/app/handle_remove.go create mode 100644 src/app/io_client.go create mode 100644 src/app/load_config.go create mode 100644 src/app/scan.go rename {core => src/app}/server.go (59%) create mode 100644 src/app/server_edit.go create mode 100644 src/app/show_cp.go create mode 100644 src/app/show_help.go create mode 100644 src/app/show_menu.go create mode 100644 src/app/show_servers.go create mode 100644 src/app/show_upgrade.go create mode 100644 src/app/show_version.go create mode 100644 src/utils/clean.go create mode 100644 src/utils/error_assert.go rename core/util.go => src/utils/file_path.go (72%) rename core/log.go => src/utils/logger.go (95%) create mode 100644 src/utils/printer.go create mode 100644 src/utils/scan.go create mode 100644 src/utils/str.go diff --git a/.gitignore b/.gitignore index 724d319..e566e8f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ servers.json +config.json releases/ -.idea/ +.idea main test.go app.log +vendor diff --git a/README.md b/README.md index 6b699eb..3634929 100644 --- a/README.md +++ b/README.md @@ -4,92 +4,37 @@ ![演示](https://raw.githubusercontent.com/islenbo/autossh/b3e18c35ebced882ace59be7843d9a58d1ac74d7/doc/images/ezgif-1-a4ddae192f.gif) -## 版本说明 -这是一个全新的autossh,无法兼容v0.2及以下版本,升级前请做好备份!新版配置文件由原来的`servers.json`改为`config.json`, -升级时可将旧配置文件的列表插入到新配置文件的`servers`节点下 - -注:旧版servers中method=pem需要更新为method=key - ## 功能说明 -- 支持分组 -- 支持显示/隐藏主机详情(show_detail) -- 支持options(目前仅支持ServerAliveInterval) -- 允许配置文件中server默认值为空 -- 允许指定配置文件目径 -- 修复终端窗口大小改变时无法自适应的bug +- 核心代码重构,使用go.mod管理依赖 +- 新增分组折叠功能 +- 新增自动更新检测功能 +- 新增一键安装脚本 +- 新增会话日志保存功能 +- 删除功能支持ctrl+d退出 +- 优化帮助菜单显示 +- 修复若干Bug ## 下载 [https://github.com/islenbo/autossh/releases](https://github.com/islenbo/autossh/releases) ## 安装 -- 下载编译好的二进制包autossh,放在指目录下,如`~/autossh`或`/usr/loca/autossh` -- 同级目录下新建`config.json`文件,参考`config.example.json` -- 将安装目录加入环境变量中,或指定别名`alias autossh=your autossh path/autossh` - -## config.json -```json -{ - "show_detail": true, // 显示主机详情 - "options": { // 全局配置 - "ServerAliveInterval": 30 // 发送心跳包时间,同 ssh -o ServerAliveInterval=30 - }, - "servers": [ - { - "name": "vagrant", // 显示名称 - "ip": "192.168.33.10", // 主机地址 - "port": 22, // 端口号,可省略,默认为22 - "user": "root", // 用户名 - "password": "vagrant", // 密码,使用无密码的key登录时可省略 - "method": "password", // 认证方式,可省略,默认值为password,可选项有password、key - "key": "", // 密钥路径,method=key时有效,可省略,默认为~/.ssh/id_rsa - "options": { // 自定义配置,会覆盖配置中相同的值 - "ServerAliveInterval": 20 - } - }, - { - "name": "vagrant-key", - "ip": "192.168.33.10", - "user": "root", - "method": "key" - } - ], - "groups": [ - { - "group_name": "your group name", - "prefix": "a", - "servers": [ - { - "name": "example1", - "ip": "192.168.33.10", - "user": "root", - "password": "root" - }, - { - "name": "example2", - "ip": "192.168.33.10", - "user": "root", - "password": "root" - } - ] - }, - { - "group_name": "group2", - "prefix": "b", - "servers": [ - ] - } - ] -} +- Mac/Linux用户直接下载安装包,运行install脚本即可。 +- Windows用户可手动编译,参考编译章节。 -``` +## config.json字段说明 +- TODO 字段说明 ## Q&A - Q: Downloads中为什么没有Windows的包? - A: Windows下有很多ssh工具,autossh主要是面向Mac/Linux群体。 ## 编译 +export GO111MODULE="on" +export GOFLAGS=" -mod=vendor" go build main.go -## 依赖包 -- golang.org/x/crypto/ssh +## 依赖 +- 查阅 go.mod +## 注意 +v0.X版本配置文件无法与v1.X版本兼容,请勿使用! \ No newline at end of file diff --git a/build.sh b/build.sh old mode 100755 new mode 100644 index 6c430ce..11477b6 --- a/build.sh +++ b/build.sh @@ -1,7 +1,7 @@ #!/bin/bash PROJECT="autossh" -VERSION="v1.0.0" +VERSION="v1.1.0" BUILD=`date +%FT%T%z` function build() { @@ -12,8 +12,10 @@ function build() { echo "build ${package} ..." mkdir -p "./releases/${package}" - CGO_ENABLED=0 GOOS=${os} GOARCH=${arch} go build -o "./releases/${package}/autossh" -ldflags "-X main.Version=${VERSION} -X main.Build=${BUILD}" main.go + CGO_ENABLED=0 GOOS=${os} GOARCH=${arch} go build -o "./releases/${package}/autossh" -ldflags "-X main.Version=${VERSION} -X main.Build=${BUILD}" src/main/main.go cp ./config.example.json "./releases/${package}/config.json" + chmod +x ./install + cp ./install "./releases/${package}/install" cd ./releases/ zip -r "./${package}.zip" "./${package}" echo "clean ${package}" diff --git a/config.example.json b/config.example.json index 9da3524..1828032 100644 --- a/config.example.json +++ b/config.example.json @@ -1,25 +1,26 @@ { - "show_detail": true, // 显示主机详情 - "options": { // 全局配置 - "ServerAliveInterval": 30 // 发送心跳包时间,同 ssh -o ServerAliveInterval=30 + "show_detail": true, + "options": { + "ServerAliveInterval": 30 }, "servers": [ { - "name": "vagrant", // 显示名称 - "ip": "192.168.33.10", // 主机地址 - "port": 22, // 端口号,可省略,默认为22 - "user": "root", // 用户名 - "password": "vagrant", // 密码,使用无密码的key登录时可省略 - "method": "password", // 认证方式,可省略,默认值为password,可选项有password、key - "key": "", // 密钥路径,method=key时有效,可省略,默认为~/.ssh/id_rsa - "options": { // 自定义配置,会覆盖配置中相同的值 + "name": "example-password", + "ip": "example-password", + "port": 22, + "user": "example-password", + "password": "example-password", + "method": "example-password", + "key": "", + "options": { "ServerAliveInterval": 20 - } + }, + "alias": "example" }, { - "name": "vagrant-key", - "ip": "192.168.33.10", - "user": "root", + "name": "example-key", + "ip": "example-key", + "user": "example-key", "method": "key" } ], @@ -30,23 +31,18 @@ "servers": [ { "name": "example1", - "ip": "192.168.33.10", - "user": "root", - "password": "root" + "ip": "example1", + "user": "example1", + "password": "example1" }, { "name": "example2", - "ip": "192.168.33.10", - "user": "root", - "password": "root" + "ip": "example2", + "user": "example2", + "password": "example2" } - ] - }, - { - "group_name": "group2", - "prefix": "b", - "servers": [ - ] + ], + "collapse": false } ] } diff --git a/core/app.go b/core/app.go deleted file mode 100644 index ea0eae2..0000000 --- a/core/app.go +++ /dev/null @@ -1,383 +0,0 @@ -package core - -import ( - "io/ioutil" - "encoding/json" - "errors" - "strconv" - "fmt" - "strings" - "bytes" - "os" - "io" - "path/filepath" - "time" -) - -type IndexType int - -const ( - IndexTypeServer IndexType = iota - IndexTypeGroup -) - -type Group struct { - GroupName string `json:"group_name"` - Prefix string `json:"prefix"` - Servers []Server `json:"servers"` -} - -type Config struct { - ShowDetail bool `json:"show_detail"` - Servers []Server `json:"servers"` - Groups []Group `json:"groups"` - Options map[string]interface{} `json:"options"` -} - -type ServerIndex struct { - indexType IndexType - groupIndex int - serverIndex int - server *Server -} - -type App struct { - ConfigPath string - config Config - serverIndex map[string]ServerIndex -} - -// 执行脚本 -func (app *App) Init() { - app.serverIndex = make(map[string]ServerIndex) - - // 解析配置 - app.loadConfig() - - app.loadServerMap(true) - - app.show() -} - -func (app *App) saveAndReload() { - app.saveConfig() - app.loadConfig() - app.loadServerMap(false) - app.show() -} - -func (app *App) show() { - //for { - Clear() - - // 输出server - app.showServers() - - // 监听输入 - input, isGlobal := app.checkInput() - if isGlobal { - end := app.handleGlobalCmd(input) - if end { - return - } - } else { - server := app.serverIndex[input].server - Printer.Infoln("你选择了", server.Name) - Log.Category("app").Info("select server", server.Name) - server.Connect() - } - //} -} - -func (app *App) handleGlobalCmd(cmd string) bool { - switch strings.ToLower(cmd) { - case "exit": - return true - case "edit": - app.handleEdit() - return false - case "add": - app.handleAdd() - return false - case "remove": - app.handleRemove() - return false - default: - Printer.Errorln("指令无效") - return false - } -} - -// 编辑 -func (app *App) handleEdit() { - Printer.Info("请输入相应序号(exit退出当前操作):") - id := "" - fmt.Scanln(&id) - - if strings.ToLower(id) == "exit" { - app.show() - return - } - - serverIndex, ok := app.serverIndex[id] - if !ok { - Printer.Errorln("序号不存在") - app.handleEdit() - return - } - - serverIndex.server.Edit() - app.saveAndReload() -} - -// 移除 -func (app *App) handleRemove() { - Printer.Info("请输入相应序号(exit退出当前操作):") - id := "" - fmt.Scanln(&id) - - if strings.ToLower(id) == "exit" { - app.show() - return - } - - serverIndex, ok := app.serverIndex[id] - if !ok { - Printer.Errorln("序号不存在") - app.handleEdit() - return - } - - if serverIndex.indexType == IndexTypeServer { - servers := app.config.Servers - app.config.Servers = append(servers[:serverIndex.serverIndex], servers[serverIndex.serverIndex+1:]...) - } else { - servers := app.config.Groups[serverIndex.groupIndex].Servers - servers = append(servers[:serverIndex.serverIndex], servers[serverIndex.serverIndex+1:]...) - app.config.Groups[serverIndex.groupIndex].Servers = servers - } - - app.saveAndReload() -} - -// 新增 -func (app *App) handleAdd() { - groups := make(map[string]*Group) - for i := range app.config.Groups { - group := &app.config.Groups[i] - groups[group.Prefix] = group - Printer.Info("["+group.Prefix+"]"+group.GroupName, "\t") - } - Printer.Infoln("[其他值]默认组") - Printer.Info("请输入要插入的组:") - g := "" - fmt.Scanln(&g) - - server := Server{} - server.Format() - server.Edit() - - group, ok := groups[g] - if ok { - group.Servers = append(group.Servers, server) - } else { - app.config.Servers = append(app.config.Servers, server) - } - - app.saveAndReload() -} - -// 保存配置文件 -func (app *App) saveConfig() error { - b, err := json.Marshal(app.config) - if err != nil { - return err - } - - var out bytes.Buffer - err = json.Indent(&out, b, "", "\t") - if err != nil { - return err - } - - err = app.backConfig() - if err != nil { - return err - } - - return ioutil.WriteFile(app.ConfigPath, out.Bytes(), os.ModePerm) -} - -func (app *App) backConfig() error { - srcFile, err := os.Open(app.ConfigPath) - if err != nil { - return err - } - - defer srcFile.Close() - - path, _ := filepath.Abs(filepath.Dir(app.ConfigPath)) - backupFile := path + "/config-" + time.Now().Format("20060102150405") + ".json" - desFile, err := os.Create(backupFile) - if err != nil { - return err - } - defer desFile.Close() - - _, err = io.Copy(desFile, srcFile) - if err != nil { - return err - } - - Printer.Infoln("配置文件已备份:", backupFile) - return nil -} - -// 检查输入 -func (app *App) checkInput() (string, bool) { - flag := "" - for { - fmt.Scanln(&flag) - Log.Category("app").Info("input scan:", flag) - - if app.isGlobalInput(flag) { - return flag, true - } - - if _, ok := app.serverIndex[flag]; !ok { - Printer.Errorln("输入有误,请重新输入") - } else { - return flag, false - } - } - - panic(errors.New("输入有误")) -} - -// 判断是否全局输入 -func (app *App) isGlobalInput(flag string) bool { - switch flag { - case "edit": - fallthrough - case "add": - fallthrough - case "remove": - fallthrough - case "exit": - return true - - default: - return false - } -} - -// 加载配置文件 -func (app *App) loadConfig() { - b, _ := ioutil.ReadFile(app.ConfigPath) - err := json.Unmarshal(b, &app.config) - if err != nil { - Printer.Errorln("加载配置文件失败", err) - panic(errors.New("加载配置文件失败:" + err.Error())) - } -} - -// 打印列表 -func (app *App) showServers() { - maxlen := app.separatorLength() - app.formatSeparator(" 欢迎使用 Auto SSH ", "=", maxlen) - for i, server := range app.config.Servers { - Printer.Logln(app.recordServer(strconv.Itoa(i+1), server)) - } - - for _, group := range app.config.Groups { - if len(group.Servers) == 0 { - continue - } - - app.formatSeparator(" "+group.GroupName+" ", "_", maxlen) - for i, server := range group.Servers { - Printer.Logln(app.recordServer(group.Prefix+strconv.Itoa(i+1), server)) - } - } - - app.formatSeparator("", "=", maxlen) - Printer.Logln("", "[add] 添加", " ", "[edit] 编辑", " ", "[remove] 删除") - Printer.Logln("", "[exit]\t退出") - app.formatSeparator("", "=", maxlen) - Printer.Info("请输入序号或操作: ") -} - -func (app *App) formatSeparator(title string, c string, maxlength float64) { - - charslen := int((maxlength - ZhLen(title)) / 2) - chars := "" - for i := 0; i < charslen; i ++ { - chars += c - } - - Printer.Infoln(chars + title + chars) -} - -func (app *App) separatorLength() float64 { - maxlength := 50.0 - for _, group := range app.config.Groups { - length := ZhLen(group.GroupName) - if length > maxlength { - maxlength = length + 10 - } - } - - return maxlength -} - -// 加载 -func (app *App) loadServerMap(check bool) { - Log.Category("app").Info("server count", len(app.config.Servers), "group count", len(app.config.Groups)) - - for i := range app.config.Servers { - server := &app.config.Servers[i] - server.Format() - flag := strconv.Itoa(i + 1) - - if _, ok := app.serverIndex[flag]; ok && check { - panic(errors.New("标识[" + flag + "]已存在,请检查您的配置文件")) - } - - server.MergeOptions(app.config.Options, false) - app.serverIndex[flag] = ServerIndex{ - indexType: IndexTypeServer, - groupIndex: -1, - serverIndex: i, - server: server, - } - } - - for i := range app.config.Groups { - group := &app.config.Groups[i] - for j := range group.Servers { - server := &group.Servers[j] - server.Format() - flag := group.Prefix + strconv.Itoa(j+1) - - if _, ok := app.serverIndex[flag]; ok && check { - panic(errors.New("标识[" + flag + "]已存在,请检查您的配置文件")) - } - - server.MergeOptions(app.config.Options, false) - app.serverIndex[flag] = ServerIndex{ - indexType: IndexTypeGroup, - groupIndex: i, - serverIndex: j, - server: server, - } - } - } -} - -func (app *App) recordServer(flag string, server Server) string { - if app.config.ShowDetail { - return " [" + flag + "]" + "\t" + server.Name + " [" + server.User + "@" + server.Ip + "]" - } else { - return " [" + flag + "]" + "\t" + server.Name - } -} diff --git a/core/params.go b/core/params.go deleted file mode 100644 index 5017861..0000000 --- a/core/params.go +++ /dev/null @@ -1,35 +0,0 @@ -package core - -import "os" - -var paramsMap map[string]string - -var Params params - -type params struct { -} - -func init() { - paramsMap = make(map[string]string) -} - -func (p params) Get(key string) *string { - if v, ok := paramsMap[key]; ok { - return &v - } - - for _, param := range os.Args { - if param[:len(key)] == key { - val := param[len(key)+1:] - paramsMap[key] = val - break - } - } - - v, ok := paramsMap[key] - if !ok { - return nil - } else { - return &v - } -} diff --git a/core/print.go b/core/print.go deleted file mode 100644 index 481ad00..0000000 --- a/core/print.go +++ /dev/null @@ -1,39 +0,0 @@ -package core - -import "fmt" - -type Print struct { -} - -var Printer Print - -// 打印一行信息 -// 字体颜色为默色 -func (print Print) Logln(a ...interface{}) { - fmt.Println(a...) -} - -// 打印一行信息 -// 字体颜色为绿色 -func (print Print) Infoln(a ...interface{}) { - fmt.Print("\033[32m") - fmt.Println(a...) - fmt.Print("\033[0m") -} - -// 打印信息(不换行) -// 字体颜色为绿色 -func (print Print) Info(a ...interface{}) { - fmt.Print("\033[32m") - fmt.Print(a...) - fmt.Print("\033[0m") -} - -// 打印一行错误 -// 字体颜色为红色 -func (print Print) Errorln(a ...interface{}) { - fmt.Print("\033[31m") - fmt.Println(a...) - fmt.Print("\033[0m") -} - diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..eaae545 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module autossh + +go 1.12 + +require ( + github.com/kr/fs v0.1.0 // indirect + github.com/pkg/errors v0.8.1 // indirect + github.com/pkg/sftp v1.10.0 + golang.org/x/crypto v0.0.0-20181024132630-e84da0312774c21d64ee2317962ef669b27ffb41 + golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a3bf533 --- /dev/null +++ b/go.sum @@ -0,0 +1,16 @@ +github.com/islenbo/scp v0.0.0-20170824162307-f7b48647feef3e30991f2ecc16b2b9a977d9a7c3 h1:GZQQKQUN0Bp02sJNGN6/FoqXeKRVqftdsv6WFhwhcR4= +github.com/islenbo/scp v0.0.0-20170824162307-f7b48647feef3e30991f2ecc16b2b9a977d9a7c3/go.mod h1:h2o9ndCCKJY6lGA8rN7Da+3RVv2h5d36KKTyLGuS/T8= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.10.0 h1:DGA1KlA9esU6WcicH+P8PxFZOl15O6GYtab1cIJdOlE= +github.com/pkg/sftp v1.10.0/go.mod h1:NxmoDg/QLVWluQDUYG7XBZTLUpKeFa8e3aMf1BfjyHk= +github.com/tmc/scp v0.0.0-20170824174625-f7b48647feef h1:7D6Nm4D6f0ci9yttWaKjM1TMAXrH5Su72dojqYGntFY= +github.com/tmc/scp v0.0.0-20170824174625-f7b48647feef/go.mod h1:WLFStEdnJXpjK8kd4qKLwQKX/1vrDzp5BcDyiZJBHJM= +golang.org/x/crypto v0.0.0-20181024132630-e84da0312774c21d64ee2317962ef669b27ffb41 h1:IlSGaL0SqFiWKa+4DitAJUo/USS7vA2+5VhRfkkF048= +golang.org/x/crypto v0.0.0-20181024132630-e84da0312774c21d64ee2317962ef669b27ffb41/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862 h1:rM0ROo5vb9AdYJi1110yjWGMej9ITfKddS89P3Fkhug= +golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/install b/install new file mode 100755 index 0000000..bdd5daf --- /dev/null +++ b/install @@ -0,0 +1,21 @@ +#!/bin/bash + +cd $(dirname $0) + +mkdir -p ~/autossh/ +cp ./autossh ~/autossh/ + +CONFIG_FILE=`cd ~/autossh/ && pwd`/config.json +if [[ ! -f ${CONFIG_FILE} ]]; then + echo "not exists" + cp ./config.json ~/autossh/ +fi + +HAS_ALIAS=`cat ~/.bash_profile | grep autossh | wc -l` +if [[ ${HAS_ALIAS} -eq 0 ]]; then + echo "alias autossh='~/autossh/autossh'" >> ~/.bash_profile +fi + +source ~/.bash_profile + +~/autossh/autossh diff --git a/main.go b/main.go deleted file mode 100755 index 06bccce..0000000 --- a/main.go +++ /dev/null @@ -1,88 +0,0 @@ -package main - -import ( - "autossh/core" - "os" - "path/filepath" - "fmt" - "strings" -) - -var ( - Version = "unknown" - Build = "unknown" -) - -func main() { - configPath := "" - if len(os.Args) > 1 { - option := strings.Split(os.Args[1], "=") - - switch option[0] { - case "--config": - configPath = *core.Params.Get("--config") - case "-c": - configPath = *core.Params.Get("-c") - - case "--help": - fallthrough - case "-h": - help() - return - - case "--version": - fallthrough - case "-v": - version() - return - } - } - - defer func() { - if err := recover(); err != nil { - core.Log.Category("main").Error("recover", err) - } - }() - - if configPath == "" { - configPath, _ = filepath.Abs(filepath.Dir(os.Args[0])) - configPath = configPath + "/config.json" - } else { - configPath, _ = core.ParsePath(configPath) - } - - core.Log.Category("main").Info("config path=", configPath) - - _, err := os.Stat(configPath) - if err != nil { - if os.IsNotExist(err) { - core.Printer.Errorln("config file", configPath+" not exists") - core.Log.Category("main").Error("config file not exists") - } else { - core.Printer.Errorln("unknown error", err) - core.Log.Category("main").Error("unknown error", err) - } - - return - } - - app := core.App{ - ConfigPath: configPath, - } - app.Init() -} - -// 版本信息 -func version() { - fmt.Println("autossh " + Version + " Build " + Build + "。") - fmt.Println("由 Lenbo 编写,项目地址:https://github.com/islenbo/autossh。") -} - -// 显示帮助信息 -func help() { - fmt.Println("一个ssh远程客户端,可一键登录远程服务器,主要用来弥补Mac/Linux Terminal ssh无法保存密码的不足。") - fmt.Println("参数:") - fmt.Println(" -c, --config ", "default=./config.json \t", "指定配置文件。") - fmt.Println(" -h, --help ", " \t", "显示帮助信息。") - fmt.Println(" -v, --version", " \t", "显示 autossh 的版本信息。") -} diff --git a/src/app/app.go b/src/app/app.go new file mode 100644 index 0000000..526f523 --- /dev/null +++ b/src/app/app.go @@ -0,0 +1,59 @@ +package app + +import ( + "autossh/src/utils" + "flag" +) + +var ( + Version string + Build string + + varVersion bool + varHelp bool + varUpgrade bool + varCp bool + varConfig = "./config.json" +) + +func init() { + flag.BoolVar(&varVersion, "v", varVersion, "版本信息") + flag.BoolVar(&varVersion, "version", varVersion, "版本信息") + flag.BoolVar(&varHelp, "h", varHelp, "帮助信息") + flag.BoolVar(&varHelp, "help", varHelp, "帮助信息") + flag.StringVar(&varConfig, "c", varConfig, "指定配置文件路径") + flag.StringVar(&varConfig, "config", varConfig, "指定配置文件路径") + + flag.Parse() + + if len(flag.Args()) > 0 { + arg := flag.Arg(0) + switch arg { + case "upgrade": + varUpgrade = true + case "cp": + varCp = true + default: + defaultServer = arg + } + } +} + +func Run() { + if exists, _ := utils.FileIsExists(varConfig); !exists { + utils.Errorln("Can't read config file", varConfig) + return + } + + if varVersion { + showVersion() + } else if varHelp { + showHelp() + } else if varUpgrade { + showUpgrade() + } else if varCp { + showCp(varConfig) + } else { + showServers(varConfig) + } +} diff --git a/src/app/config.go b/src/app/config.go new file mode 100644 index 0000000..c07bc63 --- /dev/null +++ b/src/app/config.go @@ -0,0 +1,156 @@ +package app + +import ( + "autossh/src/utils" + "bytes" + "encoding/json" + "io" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "time" +) + +type Config struct { + ShowDetail bool `json:"show_detail"` + Servers []Server `json:"servers"` + Groups []Group `json:"groups"` + Options map[string]interface{} `json:"options"` + + // 服务器map索引,可通过编号、别名快速定位到某一个服务器 + serverIndex map[string]ServerIndex + file string +} + +type Group struct { + GroupName string `json:"group_name"` + Prefix string `json:"prefix"` + Servers []Server `json:"servers"` + Collapse bool `json:"collapse"` +} + +type LogMode string + +const ( + LogModeCover LogMode = "cover" + LogModeAppend LogMode = "append" +) + +type ServerLog struct { + Enable bool `json:"enable"` + Filename string `json:"filename"` + Mode LogMode `json:"mode"` +} + +const ( + IndexTypeServer IndexType = iota + IndexTypeGroup +) + +type IndexType int +type ServerIndex struct { + indexType IndexType + groupIndex int + serverIndex int + server *Server +} + +// 创建服务器索引 +func (cfg *Config) createServerIndex() { + cfg.serverIndex = make(map[string]ServerIndex) + for i := range cfg.Servers { + server := &cfg.Servers[i] + server.Format() + index := strconv.Itoa(i + 1) + + if _, ok := cfg.serverIndex[index]; ok { + continue + } + + server.MergeOptions(cfg.Options, false) + cfg.serverIndex[index] = ServerIndex{ + indexType: IndexTypeServer, + groupIndex: -1, + serverIndex: i, + server: server, + } + if server.Alias != "" { + cfg.serverIndex[server.Alias] = cfg.serverIndex[index] + } + } + + for i := range cfg.Groups { + group := &cfg.Groups[i] + for j := range group.Servers { + server := &group.Servers[j] + server.Format() + server.groupName = group.GroupName + index := group.Prefix + strconv.Itoa(j+1) + + if _, ok := cfg.serverIndex[index]; ok { + continue + } + + server.MergeOptions(cfg.Options, false) + cfg.serverIndex[index] = ServerIndex{ + indexType: IndexTypeGroup, + groupIndex: i, + serverIndex: j, + server: server, + } + if server.Alias != "" { + cfg.serverIndex[server.Alias] = cfg.serverIndex[index] + } + } + } +} + +// 保存配置文件 +func (cfg *Config) saveConfig(backup bool) error { + b, err := json.Marshal(cfg) + if err != nil { + return err + } + + var out bytes.Buffer + err = json.Indent(&out, b, "", "\t") + if err != nil { + return err + } + + if backup { + err = cfg.backup() + if err != nil { + return err + } + } + + return ioutil.WriteFile(cfg.file, out.Bytes(), os.ModePerm) +} + +// 备份配置文件 +func (cfg *Config) backup() error { + srcFile, err := os.Open(cfg.file) + if err != nil { + return err + } + + defer srcFile.Close() + + path, _ := filepath.Abs(filepath.Dir(cfg.file)) + backupFile := path + "/config-" + time.Now().Format("20060102150405") + ".json" + desFile, err := os.Create(backupFile) + if err != nil { + return err + } + defer desFile.Close() + + _, err = io.Copy(desFile, srcFile) + if err != nil { + return err + } + + utils.Infoln("配置文件已备份:", backupFile) + return nil +} diff --git a/src/app/handle_add.go b/src/app/handle_add.go new file mode 100644 index 0000000..44e2539 --- /dev/null +++ b/src/app/handle_add.go @@ -0,0 +1,41 @@ +package app + +import ( + "autossh/src/utils" + "fmt" + "io" +) + +func handleAdd(cfg *Config, _ []string) error { + groups := make(map[string]*Group) + for i := range cfg.Groups { + group := &cfg.Groups[i] + groups[group.Prefix] = group + utils.Info("["+group.Prefix+"]"+group.GroupName, "\t") + } + utils.Infoln("[其他值]默认组") + utils.Info("请输入要插入的组:") + g := "" + if _, err := fmt.Scanln(&g); err == io.EOF { + return nil + } + + server := Server{} + server.Format() + if err := server.Edit(); err != nil { + if err == io.EOF { + return nil + } + return err + } + + group, ok := groups[g] + if ok { + group.Servers = append(group.Servers, server) + server.groupName = group.GroupName + } else { + cfg.Servers = append(cfg.Servers, server) + } + + return cfg.saveConfig(true) +} diff --git a/src/app/handle_edit.go b/src/app/handle_edit.go new file mode 100644 index 0000000..ad53ffb --- /dev/null +++ b/src/app/handle_edit.go @@ -0,0 +1,29 @@ +package app + +import ( + "autossh/src/utils" + "fmt" + "io" +) + +func handleEdit(cfg *Config, args []string) error { + utils.Info("请输入相应序号:") + id := "" + if _, err := fmt.Scanln(&id); err == io.EOF { + return nil + } + + serverIndex, ok := cfg.serverIndex[id] + if !ok { + utils.Errorln("序号不存在") + return handleEdit(cfg, args) + } + + if err := serverIndex.server.Edit(); err != nil { + if err == io.EOF { + return nil + } + return err + } + return cfg.saveConfig(true) +} diff --git a/src/app/handle_remove.go b/src/app/handle_remove.go new file mode 100644 index 0000000..01e14a0 --- /dev/null +++ b/src/app/handle_remove.go @@ -0,0 +1,34 @@ +package app + +import ( + "autossh/src/utils" + "fmt" + "io" +) + +func handleRemove(cfg *Config, args []string) error { + utils.Info("请输入相应序号:") + + id := "" + _, err := fmt.Scanln(&id) + if err == io.EOF { + return nil + } + + serverIndex, ok := cfg.serverIndex[id] + if !ok { + utils.Errorln("序号不存在") + return handleRemove(cfg, args) + } + + if serverIndex.indexType == IndexTypeServer { + servers := cfg.Servers + cfg.Servers = append(servers[:serverIndex.serverIndex], servers[serverIndex.serverIndex+1:]...) + } else { + servers := cfg.Groups[serverIndex.groupIndex].Servers + servers = append(servers[:serverIndex.serverIndex], servers[serverIndex.serverIndex+1:]...) + cfg.Groups[serverIndex.groupIndex].Servers = servers + } + + return cfg.saveConfig(true) +} diff --git a/src/app/io_client.go b/src/app/io_client.go new file mode 100644 index 0000000..f55f89c --- /dev/null +++ b/src/app/io_client.go @@ -0,0 +1,62 @@ +package app + +import ( + "github.com/pkg/sftp" + "os" +) + +type IOClientType int + +type FileLike interface { + Name() string + Stat() (os.FileInfo, error) + Read([]byte) (int, error) + Close() error + Write(p []byte) (n int, err error) +} + +const ( + IOClientLocal IOClientType = iota + IOClientSftp +) + +type IOClient struct { + ClientType IOClientType + SftpClient *sftp.Client +} + +// io stat +func (client *IOClient) Stat(file string) (os.FileInfo, error) { + switch client.ClientType { + case IOClientLocal: + return os.Stat(file) + case IOClientSftp: + return client.SftpClient.Stat(file) + default: + return os.Stat(file) + } +} + +// io mkdir +func (client *IOClient) Mkdir(path string) error { + switch client.ClientType { + case IOClientLocal: + return os.Mkdir(path, 0755) + case IOClientSftp: + return client.SftpClient.Mkdir(path) + default: + return os.Mkdir(path, 0755) + } +} + +// io create +func (client *IOClient) Create(file string) (FileLike, error) { + switch client.ClientType { + case IOClientLocal: + return os.Create(file) + case IOClientSftp: + return client.SftpClient.Create(file) + default: + return os.Create(file) + } +} diff --git a/src/app/load_config.go b/src/app/load_config.go new file mode 100644 index 0000000..4a33669 --- /dev/null +++ b/src/app/load_config.go @@ -0,0 +1,31 @@ +package app + +import ( + "autossh/src/utils" + "encoding/json" + "github.com/pkg/errors" + "io/ioutil" +) + +// 加载配置 +func loadConfig(configFile string) (cfg *Config, err error) { + configFile, err = utils.ParsePath(configFile) + if err != nil { + return cfg, err + } + + if exists, _ := utils.FileIsExists(configFile); !exists { + return cfg, errors.New("Can't read configFile file:" + configFile) + } + + b, _ := ioutil.ReadFile(configFile) + err = json.Unmarshal(b, &cfg) + if err != nil { + return cfg, err + } + + cfg.file = configFile + cfg.createServerIndex() + + return cfg, nil +} diff --git a/src/app/scan.go b/src/app/scan.go new file mode 100644 index 0000000..85a6dcf --- /dev/null +++ b/src/app/scan.go @@ -0,0 +1,112 @@ +package app + +import ( + "autossh/src/utils" + "strings" +) + +const ( + InputCmdOpt int = iota + InputCmdServer + InputCmdGroupPrefix +) + +var defaultServer = "" + +// 获取输入 +func scanInput(cfg *Config) (loop bool, clear bool, reload bool) { + cmd, inputCmd, extInfo := checkInput(cfg) + switch inputCmd { + case InputCmdOpt: + { + operation := operations[cmd] + if operation.Process != nil { + if err := operation.Process(cfg, extInfo.([]string)); err != nil { + utils.Errorln(err) + loop = false + return + } + + if !operation.End { + return true, true, true + } + } + return + } + case InputCmdServer: + { + server := cfg.serverIndex[cmd].server + utils.Infoln("你选择了", server.Name) + err := server.Connect() + if err != nil { + utils.Logger.Error("server connect error ", err) + utils.Errorln(err) + } + return true, true, false + } + case InputCmdGroupPrefix: + { + group := cfg.Groups[extInfo.(int)] + group.Collapse = !group.Collapse + err := cfg.saveConfig(false) + if err != nil { + utils.Errorln("备份失败", err) + loop = false + return + } else { + return true, true, true + } + } + } + + loop = true + return +} + +// 检查输入 +func checkInput(cfg *Config) (cmd string, inputCmd int, extInfo interface{}) { + for { + ipt := "" + skipOpt := false + if defaultServer == "" { + utils.Scanln(&ipt) + } else { + ipt = defaultServer + defaultServer = "" + skipOpt = true + } + + ipts := strings.Split(ipt, " ") + cmd = ipts[0] + + if !skipOpt { + if _, exists := operations[cmd]; exists { + inputCmd = InputCmdOpt + extInfo = ipts[1:] + break + } + } + + if _, ok := cfg.serverIndex[cmd]; ok { + inputCmd = InputCmdServer + break + } + + groupIndex := -1 + for index, group := range cfg.Groups { + if group.Prefix == cmd { + inputCmd = InputCmdGroupPrefix + groupIndex = index + extInfo = index + break + } + } + if groupIndex != -1 { + break + } + + utils.Errorln("输入有误,请重新输入") + } + + return cmd, inputCmd, extInfo +} diff --git a/core/server.go b/src/app/server.go similarity index 59% rename from core/server.go rename to src/app/server.go index 2098452..8ecf460 100644 --- a/core/server.go +++ b/src/app/server.go @@ -1,15 +1,17 @@ -package core +package app import ( - "os" - "net" - "strconv" + "autossh/src/utils" + "errors" + "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" "io/ioutil" - "time" - "fmt" + "net" + "os" + "strconv" "strings" + "time" ) type Server struct { @@ -21,11 +23,15 @@ type Server struct { Method string `json:"method"` Key string `json:"key"` Options map[string]interface{} `json:"options"` + Alias string `json:"alias"` + Log ServerLog `json:"log"` termWidth int termHeight int + groupName string } +// 格式化,赋予默认值 func (server *Server) Format() { if server.Port == 0 { server.Port = 22 @@ -36,19 +42,48 @@ func (server *Server) Format() { } } -// 执行远程连接 -func (server *Server) Connect() { - auths, err := parseAuthMethods(server) +// 合并选项 +func (server *Server) MergeOptions(options map[string]interface{}, overwrite bool) { + if server.Options == nil { + server.Options = make(map[string]interface{}) + } + + for k, v := range options { + if overwrite { + server.Options[k] = v + } else { + if _, ok := server.Options[k]; !ok { + server.Options[k] = v + } + } + + } +} + +// 格式化输出,用于打印 +func (server *Server) FormatPrint(flag string, ShowDetail bool) string { + alias := "" + if server.Alias != "" { + alias = "|" + server.Alias + } + + if ShowDetail { + return " [" + flag + alias + "]" + "\t" + server.Name + " [" + server.User + "@" + server.Ip + "]" + } else { + return " [" + flag + alias + "]" + "\t" + server.Name + } +} +// 生成SSH Client +func (server *Server) GetSshClient() (*ssh.Client, error) { + auth, err := parseAuthMethods(server) if err != nil { - Printer.Errorln("鉴权出错:", err) - Log.Category("server").Error("auth fail", err) - return + return nil, err } config := &ssh.ClientConfig{ User: server.User, - Auth: auths, + Auth: auth, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, @@ -60,24 +95,40 @@ func (server *Server) Connect() { } addr := server.Ip + ":" + strconv.Itoa(server.Port) + client, err := ssh.Dial("tcp", addr, config) if err != nil { - if ErrorAssert(err, "ssh: unable to authenticate") { - Printer.Errorln("连接失败,请检查密码/密钥是否有误") - return + return nil, err + } + + return client, nil +} + +// 生成Sftp Client +func (server *Server) GetSftpClient() (*sftp.Client, error) { + sshClient, err := server.GetSshClient() + if err == nil { + return sftp.NewClient(sshClient) + } else { + return nil, err + } +} + +// 执行远程连接 +func (server *Server) Connect() error { + client, err := server.GetSshClient() + if err != nil { + if utils.ErrorAssert(err, "ssh: unable to authenticate") { + return errors.New("连接失败,请检查密码/密钥是否有误") } - Printer.Errorln("ssh dial fail:", err) - Log.Category("server").Error("ssh dial fail", err) - return + return errors.New("ssh dial fail:" + err.Error()) } defer client.Close() session, err := client.NewSession() if err != nil { - Printer.Errorln("create session fail:", err) - Log.Category("server").Error("create session fail", err) - return + return errors.New("create session fail:" + err.Error()) } defer session.Close() @@ -85,17 +136,13 @@ func (server *Server) Connect() { fd := int(os.Stdin.Fd()) oldState, err := terminal.MakeRaw(fd) if err != nil { - Printer.Errorln("创建文件描述符出错:", err) - Log.Category("server").Error("创建文件描述符出错", err) - return + return errors.New("创建文件描述符出错:" + err.Error()) } stopKeepAliveLoop := server.startKeepAliveLoop(session) defer close(stopKeepAliveLoop) - session.Stdout = os.Stdout - session.Stderr = os.Stderr - session.Stdin = os.Stdin + server.stdIO(session) defer terminal.Restore(fd, oldState) @@ -107,9 +154,7 @@ func (server *Server) Connect() { server.termWidth, server.termHeight, _ = terminal.GetSize(fd) if err := session.RequestPty("xterm-256color", server.termHeight, server.termWidth, modes); err != nil { - Printer.Errorln("创建终端出错:", err) - Log.Category("server").Error("创建终端出错", err) - return + return errors.New("创建终端出错:" + err.Error()) } winChange := server.listenWindowChange(session, fd) @@ -117,91 +162,75 @@ func (server *Server) Connect() { err = session.Shell() if err != nil { - Printer.Errorln("执行Shell出错:", err) - Log.Category("server").Error("执行Shell出错", err) - return + return errors.New("执行Shell出错:" + err.Error()) } - err = session.Wait() - if err != nil { - //Printer.Errorln("执行Wait出错:", err) - Log.Category("server").Error("执行Wait出错", err) - return - } + _ = session.Wait() + //if err != nil { + // return errors.New("执行Wait出错:" + err.Error()) + //} + + return nil } -// 监听终端窗口变化 -func (server *Server) listenWindowChange(session *ssh.Session, fd int) chan struct{} { - terminate := make(chan struct{}) - go func() { - for { - select { - case <-terminate: - return - default: - termWidth, termHeight, _ := terminal.GetSize(fd) +// 重定向标准输入输出 +func (server *Server) stdIO(session *ssh.Session) { + session.Stderr = os.Stderr + session.Stdin = os.Stdin - if server.termWidth != termWidth || server.termHeight != termHeight { - server.termHeight = termHeight - server.termWidth = termWidth - session.WindowChange(termHeight, termWidth) - } + if server.Log.Enable { + ch, _ := session.StdoutPipe() - time.Sleep(time.Millisecond * 3) + go func() { + flag := os.O_RDWR | os.O_CREATE + switch server.Log.Mode { + case LogModeAppend: + flag = flag | os.O_APPEND + case LogModeCover: } - } - }() - return terminate -} + f, _ := os.OpenFile(server.formatLogFilename(server.Log.Filename), flag, 0644) -// 发送心跳包 -func (server *Server) startKeepAliveLoop(session *ssh.Session) chan struct{} { - terminate := make(chan struct{}) - go func() { - for { - select { - case <-terminate: - return - default: - if val, ok := server.Options["ServerAliveInterval"]; ok && val != nil { - _, err := session.SendRequest("keepalive@bbr", true, nil) - if err != nil { - Log.Category("server").Error("keepAliveLoop fail", err) + for { + buff := [4096]byte{} + n, _ := ch.Read(buff[:]) + if n > 0 { + if _, err := f.Write(buff[:n]); err != nil { + utils.Logger.Error("Write file buffer fail ", err) } - t := time.Duration(server.Options["ServerAliveInterval"].(float64)) - time.Sleep(time.Second * t) - } else { - return + if _, err := os.Stdout.Write(buff[:n]); err != nil { + utils.Logger.Error("Write stdout buffer fail ", err) + } } } - } - }() - return terminate + }() + } else { + session.Stdout = os.Stdout + } } -// 合并选项 -func (server *Server) MergeOptions(options map[string]interface{}, overwrite bool) { - if server.Options == nil { - server.Options = make(map[string]interface{}) +// 格式化日志文件名 +func (server *Server) formatLogFilename(filename string) string { + kvs := map[string]string{ + "%g": server.groupName, + "%n": server.Name, + "%dt": time.Now().Format("2006-01-02-15-04-05"), + "%d": time.Now().Format("2006-01-02"), + "%u": server.User, + "%a": server.Alias, } - for k, v := range options { - if overwrite { - server.Options[k] = v - } else { - if _, ok := server.Options[k]; !ok { - server.Options[k] = v - } - } - + for k, v := range kvs { + filename = strings.ReplaceAll(filename, k, v) } + + return filename } // 解析鉴权方式 func parseAuthMethods(server *Server) ([]ssh.AuthMethod, error) { - sshs := []ssh.AuthMethod{} + var sshs []ssh.AuthMethod switch strings.ToLower(server.Method) { case "password": @@ -229,7 +258,7 @@ func pemKey(server *Server) (ssh.AuthMethod, error) { if server.Key == "" { server.Key = "~/.ssh/id_rsa" } - server.Key, _ = ParsePath(server.Key) + server.Key, _ = utils.ParsePath(server.Key) pemBytes, err := ioutil.ReadFile(server.Key) if err != nil { @@ -250,55 +279,53 @@ func pemKey(server *Server) (ssh.AuthMethod, error) { return ssh.PublicKeys(signer), nil } -func (server *Server) Edit() { - input := "" - Printer.Info("Name(default=" + server.Name + "):") - fmt.Scanln(&input) - if input != "" { - server.Name = input - input = "" - } - - Printer.Info("Ip(default=" + server.Ip + "):") - fmt.Scanln(&input) - if input != "" { - server.Ip = input - input = "" - } +// 发送心跳包 +func (server *Server) startKeepAliveLoop(session *ssh.Session) chan struct{} { + terminate := make(chan struct{}) + go func() { + for { + select { + case <-terminate: + return + default: + if val, ok := server.Options["ServerAliveInterval"]; ok && val != nil { + _, err := session.SendRequest("keepalive@bbr", true, nil) + if err != nil { + utils.Logger.Category("server").Error("keepAliveLoop fail", err) + } - Printer.Info("Port(default=" + strconv.Itoa(server.Port) + "):") - fmt.Scanln(&input) - if input != "" { - port, _ := strconv.Atoi(input) - server.Port = port - input = "" - } + t := time.Duration(server.Options["ServerAliveInterval"].(float64)) + time.Sleep(time.Second * t) + } else { + return + } + } + } + }() + return terminate +} - Printer.Info("User(default=" + server.User + "):") - fmt.Scanln(&input) - if input != "" { - server.User = input - input = "" - } +// 监听终端窗口变化 +func (server *Server) listenWindowChange(session *ssh.Session, fd int) chan struct{} { + terminate := make(chan struct{}) + go func() { + for { + select { + case <-terminate: + return + default: + termWidth, termHeight, _ := terminal.GetSize(fd) - Printer.Info("Password(default=" + server.Password + "):") - fmt.Scanln(&input) - if input != "" { - server.Password = input - input = "" - } + if server.termWidth != termWidth || server.termHeight != termHeight { + server.termHeight = termHeight + server.termWidth = termWidth + session.WindowChange(termHeight, termWidth) + } - Printer.Info("Method(default=" + server.Method + "):") - fmt.Scanln(&input) - if input != "" { - server.Method = input - input = "" - } + time.Sleep(time.Millisecond * 3) + } + } + }() - Printer.Info("Key(default=" + server.Key + "):") - fmt.Scanln(&input) - if input != "" { - server.Key = input - input = "" - } + return terminate } diff --git a/src/app/server_edit.go b/src/app/server_edit.go new file mode 100644 index 0000000..d257bb8 --- /dev/null +++ b/src/app/server_edit.go @@ -0,0 +1,61 @@ +package app + +import ( + "autossh/src/utils" + "fmt" + "io" + "reflect" + "strconv" +) + +// 编辑 +func (server *Server) Edit() error { + keys := []string{"Name", "Ip", "Port", "User", "Password", "Method", "Key", "Alias"} + for _, key := range keys { + if err := server.scanVal(key); err != nil { + return err + } + } + + return nil +} + +func deftVal(val string) string { + if val != "" { + return "(default=" + val + ")" + } else { + return "" + } +} + +func (server *Server) scanVal(fieldName string) (err error) { + elem := reflect.ValueOf(server).Elem() + field := elem.FieldByName(fieldName) + switch field.Type().String() { + case "int": + utils.Info(fieldName + deftVal(strconv.FormatInt(field.Int(), 10)) + ":") + var ipt int + if _, err = fmt.Scanln(&ipt); err == nil { + field.SetInt(int64(ipt)) + } + case "string": + utils.Info(fieldName + deftVal(field.String()) + ":") + var ipt string + if _, err = fmt.Scanln(&ipt); err == nil { + field.SetString(ipt) + } + } + + if err != nil { + if err == io.EOF { + return err + } + + // 允许输入空行 + if err.Error() == "unexpected newline" { + return nil + } + } + + return nil +} diff --git a/src/app/show_cp.go b/src/app/show_cp.go new file mode 100644 index 0000000..d638aa3 --- /dev/null +++ b/src/app/show_cp.go @@ -0,0 +1,353 @@ +package app + +import ( + "autossh/src/utils" + "flag" + "fmt" + "github.com/pkg/errors" + "github.com/pkg/sftp" + "io" + "io/ioutil" + "os" + "path" + "strings" +) + +type TransferObjectType int + +const ( + TransferObjectTypeLocal TransferObjectType = iota + TransferObjectTypeRemote +) + +type TransferObject struct { + raw string // 原始数据,如 vagrant:/root/example.txt + cpType TransferObjectType // 类型,TransferObjectTypeLocal-本地,TransferObjectTypeRemote-远程 + server Server // 服务器,cpType = TransferObjectTypeRemote 时为空 + path string // 从raw解析得到的文件路径,如 /root/example.txt +} + +type CpType int + +const ( + CpTypeUpload CpType = iota + CpTypeDownload +) + +type Cp struct { + isDir bool + cfg *Config + + cpType CpType + sources []*TransferObject + target *TransferObject +} + +// 复制 +func showCp(configFile string) { + var err error + cfg, err := loadConfig(configFile) + if err != nil { + utils.Errorln(err) + return + } + + cp := Cp{cfg: cfg} + if err := cp.parse(); err != nil { + utils.Errorln(err) + return + } + + if cp.sources[0].cpType == TransferObjectTypeLocal { + cp.cpType = CpTypeUpload + err = cp.upload() + } else { + cp.cpType = CpTypeDownload + err = cp.download() + } + + if err != nil { + utils.Errorln(err) + return + } +} + +// 解析参数 +func (cp *Cp) parse() error { + args := flag.Args()[1:] + if len(args) == 0 { + return errors.New("请输入完整参数") + } + + if args[0] == "-r" { + cp.isDir = true + args = args[1:] + } + + var err error + switch len(args) { + case 0: + return errors.New("请输入完整参数") + case 1: + // 默认取temp目录作为target + args = []string{args[0], os.TempDir()} + } + + length := len(args) + cp.target, err = newTransferObject(*cp.cfg, args[length-1]) + if err != nil { + return err + } + + cp.sources = make([]*TransferObject, 0) + for i, arg := range args[:length-1] { + s, err := newTransferObject(*cp.cfg, arg) + if err != nil { + return err + } + + if s.cpType == TransferObjectTypeLocal && s.cpType == cp.target.cpType { + return errors.New("源和目标不能同时为本地地址") + } + + if i > 0 && s.cpType != cp.sources[i-1].cpType { + return errors.New("source 类型不一致") + } + + cp.sources = append(cp.sources, s) + } + + return nil +} + +// 上传 +func (cp *Cp) upload() error { + sftpClient, err := cp.target.server.GetSftpClient() + if err != nil { + return err + } + + defer func() { + _ = sftpClient.Close() + }() + + var ioClient = IOClient{ClientType: IOClientSftp, SftpClient: sftpClient} + + for _, source := range cp.sources { + if file, err := cp.transfer(&ioClient, source.path, cp.target.path, ""); err != nil { + cp.printFileError(file, err) + } + } + + return nil +} + +// IO复制 src -> dst +func (cp *Cp) ioCopy(client *IOClient, srcFile FileLike, dst string, fSize int64) (string, error) { + var err error + + dst, err = cp.parseDstFilename(client, srcFile.Name(), dst) + if err != nil { + return dst, err + } + + dstFile, err := client.Create(dst) + if err != nil { + return dst, err + } + + defer func() { + _ = dstFile.Close() + }() + + bytes := [4096]byte{} + bytesCount := 0 + filename := path.Base(srcFile.Name()) + + for { + n, err := srcFile.Read(bytes[:]) + eof := err == io.EOF + if err != nil && err != io.EOF { + return srcFile.Name(), err + } + + bytesCount += n + process := float64(bytesCount) / float64(fSize) * 100 + cp.printProcess(filename, process) + _, err = dstFile.Write(bytes[:n]) + if err != nil { + return cp.target.path, err + } + + if eof { + cp.printProcess(filename, 100.0) + break + } + } + + fmt.Println("") + return "", nil +} + +// 下载 +func (cp *Cp) download() error { + for _, source := range cp.sources { + sftpClient, err := source.server.GetSftpClient() + if err != nil { + return err + } + + func() { + defer func() { + _ = sftpClient.Close() + }() + + var ioClient = IOClient{ClientType: IOClientLocal, SftpClient: sftpClient} + if file, err := cp.transfer(&ioClient, source.path, cp.target.path, ""); err != nil { + cp.printFileError(file, err) + } + }() + } + + return nil +} + +// 传输 +// 上传时,src = 本地,dst = 远程 +// 下载时,src = 远程,dst = 本地 +func (cp *Cp) transfer(client *IOClient, src string, dst string, vPath string) (string, error) { + srcFile, err := cp.openFile(client.SftpClient, src) + if err != nil { + return src, err + } + + defer func() { + _ = srcFile.Close() + }() + + srcFileInfo, err := srcFile.Stat() + if err != nil { + return srcFile.Name(), err + } + + if srcFileInfo.IsDir() { + if !cp.isDir { + return src, errors.New("是一个目录") + } + + childFiles, err := cp.readDir(client.SftpClient, srcFile.Name()) + if err != nil { + return srcFile.Name(), err + } + + if vPath == "" { + vPath = string(os.PathSeparator) + } else { + vPath = path.Join(vPath, srcFileInfo.Name()) + } + + for _, childFile := range childFiles { + childFilename := path.Join(src, childFile.Name()) + if str, err := cp.transfer(client, childFilename, dst, vPath); err != nil { + cp.printFileError(str, err) + } + } + } else { + newDst := path.Join(dst, vPath) + + if file, err := cp.ioCopy(client, srcFile, newDst, srcFileInfo.Size()); err != nil { + return file, err + } + } + + return "", nil +} + +// 解析dst文件名 +// src = /root/example.txt dst = /root/ => /root/example.txt +// src = /root/example.txt dst = /root => /root/example.txt +// src = /root/example.txt dst = /root/new-name.txt => /root/new-name.txt +func (cp *Cp) parseDstFilename(client *IOClient, src string, dst string) (string, error) { + dstFileInfo, err := client.Stat(dst) + if err != nil { + if !os.IsNotExist(err) { + return dst, err + } + + if cp.isDir { + if err := client.Mkdir(dst); err != nil { + return dst, err + } + + dst = path.Join(dst, path.Base(src)) + } else { + var p = path.Dir(dst) + if _, err = client.Stat(p); err != nil { + return dst, err + } + + dst = path.Join(path.Dir(dst), path.Base(dst)) + } + + } else { + if dstFileInfo.IsDir() { + dst = path.Join(dst, path.Base(src)) + } + } + + return dst, nil +} + +func (cp *Cp) printProcess(name string, process float64) { + // TODO 文件大小,执行时间 + fmt.Print("\r" + name + "\t\t\t" + fmt.Sprintf("%.2f", process) + "%") +} + +func (cp *Cp) printFileError(name string, err error) { + fmt.Println(name, ": ", err) +} + +// 根据上传/下载打开相应位置的文件 +func (cp *Cp) openFile(client *sftp.Client, file string) (FileLike, error) { + if cp.cpType == CpTypeUpload { + return os.Open(file) + } else { + return client.Open(file) + } +} + +// 根据上传/下载读取相应位置的目录,返回文件列表 +func (cp *Cp) readDir(client *sftp.Client, name string) ([]os.FileInfo, error) { + if cp.cpType == CpTypeUpload { + return ioutil.ReadDir(name) + } else { + return client.ReadDir(name) + } +} + +// 创建传输对象 +func newTransferObject(cfg Config, raw string) (*TransferObject, error) { + obj := TransferObject{ + raw: raw, + } + + args := strings.Split(raw, ":") + switch len(args) { + case 1: + obj.cpType = TransferObjectTypeLocal + obj.path = args[0] + case 2: + obj.path = strings.TrimSpace(args[1]) + serverIndex, exists := cfg.serverIndex[args[0]] + if !exists { + return nil, errors.New("服务器" + args[0] + "不存在") + } + obj.cpType = TransferObjectTypeRemote + obj.server = *serverIndex.server + + default: + return nil, errors.New(raw + " 格式错误") + } + + return &obj, nil +} diff --git a/src/app/show_help.go b/src/app/show_help.go new file mode 100644 index 0000000..46055a4 --- /dev/null +++ b/src/app/show_help.go @@ -0,0 +1,32 @@ +package app + +import ( + "autossh/src/utils" + "flag" +) + +func showHelp() { + flag.Usage = usage + flag.Usage() +} + +func usage() { + str := + `一个ssh远程客户端,可一键登录远程服务器,主要用来弥补Mac/Linux Terminal ssh无法保存密码的不足。 +Usage: + autossh [options] [commands] + +Options: + -c, -config 指定配置文件。 + (default: ./config.json) + -v, -version 显示版本信息。 + -h, -help 显示帮助信息。 + +Commands: + upgrade 检测升级。 + cp [-r] source target 复制传输。 + ${ServerNum} 使用编号登录指定服务器。 + ${ServerAlias} 使用别名登录指定服务器。 +` + utils.Logln(str) +} diff --git a/src/app/show_menu.go b/src/app/show_menu.go new file mode 100644 index 0000000..28f5813 --- /dev/null +++ b/src/app/show_menu.go @@ -0,0 +1,74 @@ +package app + +import ( + "autossh/src/utils" + "strings" +) + +type Operation struct { + Key string + Label string + End bool + Process func(cfg *Config, args []string) error +} + +var menuMap [][]Operation + +var operations = make(map[string]Operation) + +func init() { + menuMap = [][]Operation{ + { + {Key: "add", Label: "添加", Process: handleAdd}, + {Key: "edit", Label: "编辑", Process: handleEdit}, + {Key: "remove", Label: "删除", Process: handleRemove}, + }, + { + {Key: "exit", Label: "退出", End: true}, + }, + } +} + +func showMenu() { + var columnsMaxWidths = make(map[int]int) + + for i := 0; i < len(menuMap); i++ { + for j := 0; j < len(menuMap[i]); j++ { + operation := menuMap[i][j] + + // 计算每列最大长度 + maxLen := int(utils.ZhLen(operationFormat(operation))) + if _, exists := columnsMaxWidths[j]; !exists { + columnsMaxWidths[j] = maxLen + } + if columnsMaxWidths[j] < maxLen { + columnsMaxWidths[j] = maxLen + } + + operations[operation.Key] = operation + } + } + + for i := 0; i < len(menuMap); i++ { + var output = "" + for j := 0; j < len(menuMap[i]); j++ { + operation := menuMap[i][j] + output += stringPadding(operationFormat(operation), columnsMaxWidths[j]) + "\t" + } + + utils.Logln(strings.TrimSpace(output)) + output = "" + } +} + +func operationFormat(operation Operation) string { + return "[" + operation.Key + "] " + operation.Label +} + +func stringPadding(str string, paddingLen int) string { + if len(str) < paddingLen { + return stringPadding(str+" ", paddingLen) + } else { + return str + } +} diff --git a/src/app/show_servers.go b/src/app/show_servers.go new file mode 100644 index 0000000..e93aa2f --- /dev/null +++ b/src/app/show_servers.go @@ -0,0 +1,86 @@ +package app + +import ( + "autossh/src/utils" + "strconv" +) + +func showServers(configFile string) { + cfg, err := loadConfig(configFile) + if err != nil { + utils.Errorln(err) + return + } + + // 清屏 + _ = utils.Clear() + + show(cfg) + + for { + loop, clear, reload := scanInput(cfg) + // TODO 解决进入服务器之后第一次输入无效的问题(进入新增、编辑、删除没问题) + if !loop { + break + } + + if reload { + cfg, err = loadConfig(configFile) + } + + if clear { + _ = utils.Clear() + } + + show(cfg) + } +} + +// 显示服务 +func show(cfg *Config) { + maxlen := separatorLength(*cfg) + utils.Infoln(utils.FormatSeparator(" 欢迎使用 Auto SSH ", "=", maxlen)) + for i, server := range cfg.Servers { + utils.Logln(server.FormatPrint(strconv.Itoa(i+1), cfg.ShowDetail)) + } + + for _, group := range cfg.Groups { + if len(group.Servers) == 0 { + continue + } + + var collapseNotice = "" + if group.Collapse { + collapseNotice = "[" + group.Prefix + " ↓]" + } else { + collapseNotice = "[" + group.Prefix + " ↑]" + } + + utils.Infoln(utils.FormatSeparator(" "+group.GroupName+" "+collapseNotice+" ", "_", maxlen)) + if !group.Collapse { + for i, server := range group.Servers { + utils.Logln(server.FormatPrint(group.Prefix+strconv.Itoa(i+1), cfg.ShowDetail)) + } + } + } + + utils.Infoln(utils.FormatSeparator("", "=", maxlen)) + + showMenu() + + utils.Infoln(utils.FormatSeparator("", "=", maxlen)) + utils.Info("请输入序号或操作: ") +} + +// 计算分隔符长度 +func separatorLength(cfg Config) float64 { + maxlength := 60.0 + for _, group := range cfg.Groups { + length := utils.ZhLen(group.GroupName) + if length > maxlength { + maxlength = length + 10 + } + } + + return maxlength +} diff --git a/src/app/show_upgrade.go b/src/app/show_upgrade.go new file mode 100644 index 0000000..65025a9 --- /dev/null +++ b/src/app/show_upgrade.go @@ -0,0 +1,294 @@ +package app + +import ( + "archive/zip" + "autossh/src/utils" + "encoding/json" + "fmt" + "github.com/pkg/errors" + "io" + "io/ioutil" + "net/http" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "time" +) + +type Upgrade struct { + Version string + latest map[string]interface{} +} + +func showUpgrade() { + upgrade := Upgrade{Version: Version} + upgrade.exec() +} + +func (upgrade *Upgrade) exec() { + + // 使用协程异步查询最新版本 + var waitGroutp = sync.WaitGroup{} + islock := true + + go func() { + utils.Log("正在检测最新版本") + for { + if !islock { + utils.Logln("") + waitGroutp.Done() + break + } + utils.Log(".") + time.Sleep(time.Second) + } + }() + + go func() { + upgrade.loadLatestVersion() + islock = false + waitGroutp.Done() + }() + + waitGroutp.Add(2) + waitGroutp.Wait() + + utils.Logln("当前版本:" + upgrade.Version) + latestVersion := upgrade.latest["tag_name"].(string) + ret := upgrade.compareVersion(latestVersion, upgrade.Version) + if ret <= 0 { + utils.Logln("感谢您的支持,当前已是最新版本。") + return + } + + utils.Logln("检测到新版本:" + latestVersion) + url := upgrade.downloadUrl() + if url == "" { + utils.Errorln("暂不支持" + runtime.GOOS + "系统自动更新,请下载源码包手动编译。") + return + } + + filename := path.Base(url) + savePath := os.TempDir() + filename + err := upgrade.downloadFile(url, savePath, func(length, downLen int64) { + process := float64(downLen) / float64(length) * 100 + + fmt.Print("\rdownloading " + fmt.Sprintf("%.2f", process) + "%") + }) + if err != nil { + utils.Errorln("下载失败:" + err.Error()) + return + } + fmt.Print("\rdownloading 100% \n") + + fullpath, err := upgrade.unzip(savePath, os.TempDir()) + if err != nil { + utils.Errorln("解压缩失败:" + err.Error()) + return + } + + cmd := exec.Command(fullpath + "/install") + output, err := cmd.Output() + if err != nil { + utils.Errorln("安装失败") + return + } + utils.Logln(string(output)) +} + +// 解压缩 +func (Upgrade) unzip(zipFile string, destDir string) (string, error) { + zipReader, err := zip.OpenReader(zipFile) + fullpath := "" + if err != nil { + return fullpath, err + } + defer zipReader.Close() + + for _, f := range zipReader.File { + fpath := filepath.Join(destDir, f.Name) + if f.FileInfo().IsDir() { + err := os.MkdirAll(fpath, os.ModePerm) + if err != nil { + return fullpath, err + } + } else { + fullpath = filepath.Dir(fpath) + if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil { + return fullpath, err + } + + inFile, err := f.Open() + if err != nil { + return fullpath, err + } + defer inFile.Close() + + outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) + if err != nil { + return fullpath, err + } + defer outFile.Close() + + _, err = io.Copy(outFile, inFile) + if err != nil { + return fullpath, err + } + } + } + return fullpath, nil +} + +// 下载文件 +func (Upgrade) downloadFile(url string, downloadPath string, fb func(length, downLen int64)) error { + var ( + fsize int64 + buf = make([]byte, 32*1024) + written int64 + ) + //创建一个http client + client := new(http.Client) + //get方法获取资源 + resp, err := client.Get(url) + if err != nil { + return err + } + //读取服务器返回的文件大小 + fsize, err = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + if err != nil { + return err + } + //创建文件 + file, err := os.Create(downloadPath) + if err != nil { + return err + } + defer file.Close() + if resp.Body == nil { + return errors.New("body is null") + } + defer resp.Body.Close() + //下面是 io.copyBuffer() 的简化版本 + for { + //读取bytes + nr, er := resp.Body.Read(buf) + if nr > 0 { + //写入bytes + nw, ew := file.Write(buf[0:nr]) + //数据长度大于0 + if nw > 0 { + written += int64(nw) + } + //写入出错 + if ew != nil { + err = ew + break + } + //读取是数据长度不等于写入的数据长度 + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + //没有错误了快使用 callback + + fb(fsize, written) + } + return err +} + +// 获取下载地址 +func (upgrade *Upgrade) downloadUrl() string { + sysOS := runtime.GOOS + if sysOS == "darwin" { + sysOS = "macOS" + } + + filename := sysOS + "-" + runtime.GOARCH + for _, item := range upgrade.latest["assets"].([]interface{}) { + asset := item.(map[string]interface{}) + if strings.Index(asset["name"].(string), filename) != -1 { + return asset["browser_download_url"].(string) + } + } + + return "" +} + +// 读取最新版本信息 +func (upgrade *Upgrade) loadLatestVersion() { + // 使用github api获取最新版本信息 + resp, err := http.Get("https://api.github.com/repos/islenbo/autossh/releases/latest") + if err != nil { + panic(err) + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(err) + } + + var releaseInfo map[string]interface{} + err = json.Unmarshal(body, &releaseInfo) + if err != nil { + panic(err) + } + + upgrade.latest = releaseInfo +} + +// 版本比较 +// return int 1: src > other; 0 src == other; -1 src < other +func (Upgrade) compareVersion(src string, other string) int { + src = strings.Trim(src, "v") + other = strings.Trim(other, "v") + v1 := strings.Split(src, ".") + v2 := strings.Split(other, ".") + + var lim int + if len(v1) > len(v2) { + lim = len(v1) + } else { + lim = len(v2) + } + + for { + if len(v1) >= lim { + break + } + v1 = append(v1, "0") + } + + for { + if len(v2) >= lim { + break + } + v2 = append(v2, "0") + } + + for i := 0; i < lim; i++ { + num1, _ := strconv.Atoi(v1[i]) + num2, _ := strconv.Atoi(v2[i]) + + if num1 > num2 { + return 1 + } + if num1 < num2 { + return -1 + } + } + + return 0 +} diff --git a/src/app/show_version.go b/src/app/show_version.go new file mode 100644 index 0000000..6285ee6 --- /dev/null +++ b/src/app/show_version.go @@ -0,0 +1,10 @@ +package app + +import ( + "autossh/src/utils" +) + +func showVersion() { + utils.Logln("autossh " + Version + " Build " + Build + "。") + utils.Logln("由 Lenbo 编写,项目地址:https://github.com/islenbo/autossh。") +} diff --git a/src/utils/clean.go b/src/utils/clean.go new file mode 100644 index 0000000..b72b75b --- /dev/null +++ b/src/utils/clean.go @@ -0,0 +1,21 @@ +package utils + +import ( + "os" + "os/exec" + "runtime" +) + +// 清屏 +func Clear() error { + var cmd exec.Cmd + if "windows" == runtime.GOOS { + cmd = *exec.Command("cmd", "/c", "cls") + } else { + cmd = *exec.Command("clear") + } + + cmd.Stdout = os.Stdout + + return cmd.Run() +} diff --git a/src/utils/error_assert.go b/src/utils/error_assert.go new file mode 100644 index 0000000..9adfcb8 --- /dev/null +++ b/src/utils/error_assert.go @@ -0,0 +1,8 @@ +package utils + +import "strings" + +// 错误断言 +func ErrorAssert(err error, assert string) bool { + return strings.Contains(err.Error(), assert) +} diff --git a/core/util.go b/src/utils/file_path.go similarity index 72% rename from core/util.go rename to src/utils/file_path.go index 4148441..f84fd46 100644 --- a/core/util.go +++ b/src/utils/file_path.go @@ -1,49 +1,16 @@ -package core +package utils import ( - "runtime" - "os" "bytes" + "errors" + "os" "os/exec" - "strings" "os/user" - "errors" "path/filepath" - "unicode" + "runtime" + "strings" ) -// 错误断言 -func ErrorAssert(err error, assert string) bool { - return strings.Contains(err.Error(), assert) -} - -// 清屏 -func Clear() { - var cmd exec.Cmd - if "windows" == runtime.GOOS { - cmd = *exec.Command("cmd", "/c", "cls") - } else { - cmd = *exec.Command("clear") - } - - cmd.Stdout = os.Stdout - cmd.Run() -} - -// 计算字符宽度(中文) -func ZhLen(str string) float64 { - length := 0.0 - for _, c := range str { - if unicode.Is(unicode.Scripts["Han"], c) { - length += 2 - } else { - length += 1 - } - } - - return length -} - // 解析路径 func ParsePath(path string) (string, error) { str := []rune(path) @@ -56,11 +23,11 @@ func ParsePath(path string) (string, error) { } return home + string(str[1:]), nil - } else if firstKey == "." { + } else if firstKey == "/" { + return path, nil + } else { p, _ := filepath.Abs(filepath.Dir(os.Args[0])) return p + "/" + path, nil - } else { - return path, nil } } @@ -115,3 +82,23 @@ func homeWindows() (string, error) { return home, nil } + +// 判断文件是否存在 +func FileIsExists(file string) (bool, error) { + file, err := ParsePath(file) + if err != nil { + return false, err + } + + _, err = os.Stat(file) + if err != nil { + return false, err + //if os.IsNotExist(err) { + // return false, err + //} else { + // // unknown error + //} + } + + return true, nil +} diff --git a/core/log.go b/src/utils/logger.go similarity index 95% rename from core/log.go rename to src/utils/logger.go index 1706b47..6cdf2aa 100644 --- a/core/log.go +++ b/src/utils/logger.go @@ -1,8 +1,8 @@ -package core +package utils import ( - "os" "log" + "os" ) type logger struct { @@ -11,11 +11,11 @@ type logger struct { level string } -var Log logger +var Logger logger func init() { logFile, _ := ParsePath("./app.log") - Log = logger{ + Logger = logger{ File: logFile, } } diff --git a/src/utils/printer.go b/src/utils/printer.go new file mode 100644 index 0000000..b3091a0 --- /dev/null +++ b/src/utils/printer.go @@ -0,0 +1,64 @@ +package utils + +import "fmt" + +// 打印一行信息 +// 字体颜色为默色 +func Logln(a ...interface{}) { + fmt.Println(a...) +} + +// 打印(不换行) +// 字体颜色为默色 +func Log(a ...interface{}) { + fmt.Print(a...) +} + +// 打印一行信息 +// 字体颜色为绿色 +func Infoln(a ...interface{}) { + fmt.Print("\033[32m") + Logln(a...) + fmt.Print("\033[0m") +} + +// 打印信息(不换行) +// 字体颜色为绿色 +func Info(a ...interface{}) { + fmt.Print("\033[32m") + Logln(a...) + fmt.Print("\033[0m") +} + +// 打印一行错误 +// 字体颜色为红色 +func Errorln(a ...interface{}) { + fmt.Print("\033[31m") + Logln(a...) + fmt.Print("\033[0m") +} + +// 二维数组对齐 +//func Align(arr [][]string) [][]string { +// for column := 0; column < 2; column++ { +// columnWidth := getColumnWidth(arr, column) +// +// for index := range arr { +// arr[index][column] = AppendRight(arr[index][column], " ", columnWidth) +// } +// } +// +// return arr +//} +// +//func getColumnWidth(arr [][]string, column int) int { +// maxWidth := 0 +// for _, row := range arr { +// width := int(ZhLen(row[column])) +// if maxWidth < width { +// maxWidth = width +// } +// } +// +// return maxWidth +//} diff --git a/src/utils/scan.go b/src/utils/scan.go new file mode 100644 index 0000000..f708960 --- /dev/null +++ b/src/utils/scan.go @@ -0,0 +1,13 @@ +package utils + +import ( + "bufio" + "os" +) + +// GO自带的fmt.Scanln将空格也当作结束符,若需要读取含有空格的句子请使用该方法 +func Scanln(a *string) { + reader := bufio.NewReader(os.Stdin) + data, _, _ := reader.ReadLine() + *a = string(data) +} diff --git a/src/utils/str.go b/src/utils/str.go new file mode 100644 index 0000000..3982a18 --- /dev/null +++ b/src/utils/str.go @@ -0,0 +1,46 @@ +package utils + +import "unicode" + +// 计算字符宽度(中文) +func ZhLen(str string) float64 { + length := 0.0 + for _, c := range str { + if unicode.Is(unicode.Scripts["Han"], c) { + length += 2 + } else { + length += 1 + } + } + + return length +} + +// 左右填充 +// title 主体内容 +// c 填充符号 +// maxlength 总长度 +// 如: title = 测试 c=* maxlength = 10 返回 ** 返回 ** +func FormatSeparator(title string, c string, maxlength float64) string { + charslen := int((maxlength - ZhLen(title)) / 2) + chars := "" + for i := 0; i < charslen; i++ { + chars += c + } + + return chars + title + chars +} + +// 右填充 +func AppendRight(body string, char string, maxlength int) string { + length := int(ZhLen(body)) + if length >= maxlength { + return body + } + + for i := 0; i < maxlength-length; i++ { + body = body + char + } + + return body +}