diff --git a/.gitignore b/.gitignore index e566e8f..d648c5f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ servers.json config.json releases/ .idea -main test.go app.log vendor diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bc0de05 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 lenbo + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 3634929..a568f23 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # autossh -一个ssh远程客户端,可一键登录远程服务器,主要用来弥补Mac/Linux Terminal ssh无法保存密码的不足。 +一个SSH远程客户端,可一键登录远程服务器,主要用来弥补Mac/Linux Terminal SSH无法保存密码的不足。 -![演示](https://raw.githubusercontent.com/islenbo/autossh/b3e18c35ebced882ace59be7843d9a58d1ac74d7/doc/images/ezgif-1-a4ddae192f.gif) +![演示](https://raw.githubusercontent.com/islenbo/autossh/c9b52688dabbba8ef6403c2f83f8d758ae0e4dfe/doc/images/ezgif-5-42b5117192fc.gif) ## 功能说明 - 核心代码重构,使用go.mod管理依赖 @@ -13,6 +13,7 @@ - 删除功能支持ctrl+d退出 - 优化帮助菜单显示 - 修复若干Bug +- 支持SOCKS5代理 ## 下载 [https://github.com/islenbo/autossh/releases](https://github.com/islenbo/autossh/releases) @@ -22,19 +23,66 @@ - Windows用户可手动编译,参考编译章节。 ## config.json字段说明 -- TODO 字段说明 +```yaml +show_detail: bool <是否显示服务器详情(用户、IP)> +options: + ServerAliveInterval: int <是否定时发送心跳包,与SSH ServerAliveInterval属性用法相同> +servers: + - name: string <显示名称> + ip: string <服务器IP> + port: int <端口> + user: string <用户名> + password: string <密码> + method: string <鉴权方式,password-密码 key-密钥> + key: string <私钥路径> + options: + ServerAliveInterval: int <与根节点ServerAliveInterval用法相同,该值会覆盖根节点的值> + alias: string <别名> + log: + enable: bool <是否启用日志> + filename: string <日志路径, 如 /tmp/%n-%u-%dt.log 支持变量请参考下方《日志变量》章节> + mode: string <遇到同名日志的处理模式,cover-覆盖 append-追加,默认为cover> +groups: + - group_name: string <组名> + prefix: string <组前缀> + servers: array <服务器列表,配置与servers相同,配置说明略> + collapse: bool <是否折叠> + proxy: + type: string <代理方式,目前仅支持SOCKS5> + server: string <代理服务器地址> + port: int <端口号> + user: string <用户,若无留空> + password: string <密码,若无留空> +``` + +## 日志变量 +变量 | 说明 | 示例 +--- | --- | --- +%g | 组名(group_name) | MyGroup1 +%n | 显示名称(name) | vagrant +%u | 用户名(user) | root +%a | 别名(alias) | vagrant +%dt | 日期时间 | 20190821223010 +%d | 日期 | 20190821 ## Q&A -- Q: Downloads中为什么没有Windows的包? -- A: Windows下有很多ssh工具,autossh主要是面向Mac/Linux群体。 + +### Q: 为什么没有Windows的包? +A: Windows下有很多优秀的SSH工具,autossh主要是面向Mac/Linux群体,Windows用户可自行编译。 + +### Q: cp 命令出现报错: ssh: subsystem request failed +A: 修改服务器 /etc/ssh/sshd_config 将 `Subsystem sftp /usr/libexec/openssh/sftp-server` 的注释打开,重启 sshd 服务。 ## 编译 -export GO111MODULE="on" -export GOFLAGS=" -mod=vendor" -go build main.go +```bash +sh build.sh +``` ## 依赖 - 查阅 go.mod ## 注意 -v0.X版本配置文件无法与v1.X版本兼容,请勿使用! \ No newline at end of file +v0.X版本配置文件无法与v1.X版本兼容,请勿使用! + +## License +MIT \ No newline at end of file diff --git a/build.sh b/build.sh index 11477b6..7cdddf6 100644 --- a/build.sh +++ b/build.sh @@ -1,5 +1,8 @@ #!/bin/bash +export GO111MODULE="on" +go mod tidy + PROJECT="autossh" VERSION="v1.1.0" BUILD=`date +%FT%T%z` diff --git a/doc/images/ezgif-5-42b5117192fc.gif b/doc/images/ezgif-5-42b5117192fc.gif new file mode 100644 index 0000000..9daa799 Binary files /dev/null and b/doc/images/ezgif-5-42b5117192fc.gif differ diff --git a/go.mod b/go.mod index eaae545..0d10262 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,16 @@ go 1.12 require ( github.com/kr/fs v0.1.0 // indirect - github.com/pkg/errors v0.8.1 // indirect + github.com/pkg/errors v0.8.1 github.com/pkg/sftp v1.10.0 - golang.org/x/crypto v0.0.0-20181024132630-e84da0312774c21d64ee2317962ef669b27ffb41 + github.com/stretchr/testify v1.3.0 // indirect + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 + golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862 // indirect ) + +replace ( + golang.org/x/crypto => github.com/golang/crypto v0.0.0-20190701094942-4def268fd1a4 + golang.org/x/net => github.com/golang/net v0.0.0-20190813141303-74dc4d7220e7 + golang.org/x/sys => github.com/golang/sys v0.0.0-20190813064441-fde4db37ae7a +) diff --git a/go.sum b/go.sum index a3bf533..bee506c 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,20 @@ -github.com/islenbo/scp v0.0.0-20170824162307-f7b48647feef3e30991f2ecc16b2b9a977d9a7c3 h1:GZQQKQUN0Bp02sJNGN6/FoqXeKRVqftdsv6WFhwhcR4= -github.com/islenbo/scp v0.0.0-20170824162307-f7b48647feef3e30991f2ecc16b2b9a977d9a7c3/go.mod h1:h2o9ndCCKJY6lGA8rN7Da+3RVv2h5d36KKTyLGuS/T8= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/crypto v0.0.0-20190701094942-4def268fd1a4 h1:SqpWDZAu6UkmbvUTCtyNpBZLY8110TJ7bgxIki3pZw0= +github.com/golang/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +github.com/golang/net v0.0.0-20190813141303-74dc4d7220e7 h1:0oy3VQWim3zJeCPQgw9ka5X1odfKEgRUxblrM6z/rCY= +github.com/golang/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +github.com/golang/sys v0.0.0-20190813064441-fde4db37ae7a h1:hVLU4+cxX4r89gounKarktyMqZ2cx/5Y2jeGLtWqzUE= +github.com/golang/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.0 h1:DGA1KlA9esU6WcicH+P8PxFZOl15O6GYtab1cIJdOlE= github.com/pkg/sftp v1.10.0/go.mod h1:NxmoDg/QLVWluQDUYG7XBZTLUpKeFa8e3aMf1BfjyHk= -github.com/tmc/scp v0.0.0-20170824174625-f7b48647feef h1:7D6Nm4D6f0ci9yttWaKjM1TMAXrH5Su72dojqYGntFY= -github.com/tmc/scp v0.0.0-20170824174625-f7b48647feef/go.mod h1:WLFStEdnJXpjK8kd4qKLwQKX/1vrDzp5BcDyiZJBHJM= -golang.org/x/crypto v0.0.0-20181024132630-e84da0312774c21d64ee2317962ef669b27ffb41 h1:IlSGaL0SqFiWKa+4DitAJUo/USS7vA2+5VhRfkkF048= -golang.org/x/crypto v0.0.0-20181024132630-e84da0312774c21d64ee2317962ef669b27ffb41/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862 h1:rM0ROo5vb9AdYJi1110yjWGMej9ITfKddS89P3Fkhug= -golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/install b/install index bdd5daf..7b946b2 100755 --- a/install +++ b/install @@ -3,17 +3,16 @@ cd $(dirname $0) mkdir -p ~/autossh/ -cp ./autossh ~/autossh/ +cp -f ./autossh ~/autossh/ CONFIG_FILE=`cd ~/autossh/ && pwd`/config.json if [[ ! -f ${CONFIG_FILE} ]]; then - echo "not exists" cp ./config.json ~/autossh/ fi HAS_ALIAS=`cat ~/.bash_profile | grep autossh | wc -l` if [[ ${HAS_ALIAS} -eq 0 ]]; then - echo "alias autossh='~/autossh/autossh'" >> ~/.bash_profile + echo "export PATH=$PATH:~/autossh/" >> ~/.bash_profile fi source ~/.bash_profile diff --git a/src/app/app.go b/src/app/app.go index 526f523..df14587 100644 --- a/src/app/app.go +++ b/src/app/app.go @@ -1,38 +1,46 @@ package app import ( - "autossh/src/utils" "flag" + "os" + "path/filepath" ) var ( Version string Build string - varVersion bool - varHelp bool - varUpgrade bool - varCp bool - varConfig = "./config.json" + c string + v bool + h bool + upgrade bool + cp bool ) func init() { - flag.BoolVar(&varVersion, "v", varVersion, "版本信息") - flag.BoolVar(&varVersion, "version", varVersion, "版本信息") - flag.BoolVar(&varHelp, "h", varHelp, "帮助信息") - flag.BoolVar(&varHelp, "help", varHelp, "帮助信息") - flag.StringVar(&varConfig, "c", varConfig, "指定配置文件路径") - flag.StringVar(&varConfig, "config", varConfig, "指定配置文件路径") + // 取执行文件所在目录下的config.json + dir, _ := os.Executable() + c = filepath.Dir(dir) + "/config.json" + flag.StringVar(&c, "c", c, "指定配置文件路径") + flag.StringVar(&c, "config", c, "指定配置文件路径") + + flag.BoolVar(&v, "v", v, "版本信息") + flag.BoolVar(&v, "version", v, "版本信息") + + flag.BoolVar(&h, "h", h, "帮助信息") + flag.BoolVar(&h, "help", h, "帮助信息") + + flag.Usage = usage flag.Parse() if len(flag.Args()) > 0 { arg := flag.Arg(0) switch arg { case "upgrade": - varUpgrade = true + upgrade = true case "cp": - varCp = true + cp = true default: defaultServer = arg } @@ -40,20 +48,15 @@ func init() { } func Run() { - if exists, _ := utils.FileIsExists(varConfig); !exists { - utils.Errorln("Can't read config file", varConfig) - return - } - - if varVersion { + if v { showVersion() - } else if varHelp { + } else if h { showHelp() - } else if varUpgrade { + } else if upgrade { showUpgrade() - } else if varCp { - showCp(varConfig) + } else if cp { + showCp(c) } else { - showServers(varConfig) + showServers(c) } } diff --git a/src/app/config.go b/src/app/config.go index c07bc63..4a13d08 100644 --- a/src/app/config.go +++ b/src/app/config.go @@ -14,8 +14,8 @@ import ( type Config struct { ShowDetail bool `json:"show_detail"` - Servers []Server `json:"servers"` - Groups []Group `json:"groups"` + Servers []*Server `json:"servers"` + Groups []*Group `json:"groups"` Options map[string]interface{} `json:"options"` // 服务器map索引,可通过编号、别名快速定位到某一个服务器 @@ -28,6 +28,21 @@ type Group struct { Prefix string `json:"prefix"` Servers []Server `json:"servers"` Collapse bool `json:"collapse"` + Proxy *Proxy `json:"proxy"` +} + +type ProxyType string + +const ( + ProxyTypeSocks5 ProxyType = "SOCKS5" +) + +type Proxy struct { + Type ProxyType `json:"type"` + Server string `json:"server"` + Port int `json:"port"` + User string `json:"user"` + Password string `json:"password"` } type LogMode string @@ -60,7 +75,7 @@ type ServerIndex struct { func (cfg *Config) createServerIndex() { cfg.serverIndex = make(map[string]ServerIndex) for i := range cfg.Servers { - server := &cfg.Servers[i] + server := cfg.Servers[i] server.Format() index := strconv.Itoa(i + 1) @@ -81,11 +96,12 @@ func (cfg *Config) createServerIndex() { } for i := range cfg.Groups { - group := &cfg.Groups[i] + group := cfg.Groups[i] for j := range group.Servers { server := &group.Servers[j] server.Format() server.groupName = group.GroupName + server.group = group index := group.Prefix + strconv.Itoa(j+1) if _, ok := cfg.serverIndex[index]; ok { @@ -136,7 +152,9 @@ func (cfg *Config) backup() error { return err } - defer srcFile.Close() + defer func() { + _ = srcFile.Close() + }() path, _ := filepath.Abs(filepath.Dir(cfg.file)) backupFile := path + "/config-" + time.Now().Format("20060102150405") + ".json" @@ -144,7 +162,9 @@ func (cfg *Config) backup() error { if err != nil { return err } - defer desFile.Close() + defer func() { + _ = desFile.Close() + }() _, err = io.Copy(desFile, srcFile) if err != nil { diff --git a/src/app/handle_add.go b/src/app/handle_add.go index 44e2539..635d918 100644 --- a/src/app/handle_add.go +++ b/src/app/handle_add.go @@ -9,7 +9,7 @@ import ( func handleAdd(cfg *Config, _ []string) error { groups := make(map[string]*Group) for i := range cfg.Groups { - group := &cfg.Groups[i] + group := cfg.Groups[i] groups[group.Prefix] = group utils.Info("["+group.Prefix+"]"+group.GroupName, "\t") } @@ -34,7 +34,7 @@ func handleAdd(cfg *Config, _ []string) error { group.Servers = append(group.Servers, server) server.groupName = group.GroupName } else { - cfg.Servers = append(cfg.Servers, server) + cfg.Servers = append(cfg.Servers, &server) } return cfg.saveConfig(true) diff --git a/src/app/io_client.go b/src/app/io_client.go index f55f89c..88b50cd 100644 --- a/src/app/io_client.go +++ b/src/app/io_client.go @@ -2,6 +2,7 @@ package app import ( "github.com/pkg/sftp" + "io/ioutil" "os" ) @@ -15,48 +16,59 @@ type FileLike interface { Write(p []byte) (n int, err error) } -const ( - IOClientLocal IOClientType = iota - IOClientSftp -) +type IOClient interface { + Stat(file string) (os.FileInfo, error) + Mkdir(path string) error + Create(file string) (FileLike, error) + Open(file string) (FileLike, error) + ReadDir(file string) ([]os.FileInfo, error) +} + +// Local +type LocalIOClient struct { +} + +func (client *LocalIOClient) Stat(file string) (os.FileInfo, error) { + return os.Stat(file) +} + +func (client *LocalIOClient) Mkdir(path string) error { + return os.Mkdir(path, 0755) +} + +func (client *LocalIOClient) Create(file string) (FileLike, error) { + return os.Create(file) +} + +func (client *LocalIOClient) Open(file string) (FileLike, error) { + return os.Open(file) +} + +func (client *LocalIOClient) ReadDir(file string) ([]os.FileInfo, error) { + return ioutil.ReadDir(file) +} -type IOClient struct { - ClientType IOClientType +// SFTP(Remote) +type SftpIOClient struct { SftpClient *sftp.Client } -// io stat -func (client *IOClient) Stat(file string) (os.FileInfo, error) { - switch client.ClientType { - case IOClientLocal: - return os.Stat(file) - case IOClientSftp: - return client.SftpClient.Stat(file) - default: - return os.Stat(file) - } -} - -// io mkdir -func (client *IOClient) Mkdir(path string) error { - switch client.ClientType { - case IOClientLocal: - return os.Mkdir(path, 0755) - case IOClientSftp: - return client.SftpClient.Mkdir(path) - default: - return os.Mkdir(path, 0755) - } -} - -// io create -func (client *IOClient) Create(file string) (FileLike, error) { - switch client.ClientType { - case IOClientLocal: - return os.Create(file) - case IOClientSftp: - return client.SftpClient.Create(file) - default: - return os.Create(file) - } +func (client *SftpIOClient) Stat(file string) (os.FileInfo, error) { + return client.SftpClient.Stat(file) +} + +func (client *SftpIOClient) Mkdir(path string) error { + return client.SftpClient.Mkdir(path) +} + +func (client *SftpIOClient) Create(file string) (FileLike, error) { + return client.SftpClient.Create(file) +} + +func (client *SftpIOClient) Open(file string) (FileLike, error) { + return client.SftpClient.Open(file) +} + +func (client *SftpIOClient) ReadDir(file string) ([]os.FileInfo, error) { + return client.SftpClient.ReadDir(file) } diff --git a/src/app/scan.go b/src/app/scan.go index 85a6dcf..47cfcfe 100644 --- a/src/app/scan.go +++ b/src/app/scan.go @@ -42,7 +42,7 @@ func scanInput(cfg *Config) (loop bool, clear bool, reload bool) { utils.Logger.Error("server connect error ", err) utils.Errorln(err) } - return true, true, false + return false, true, false } case InputCmdGroupPrefix: { diff --git a/src/app/server.go b/src/app/server.go index 8ecf460..afc56f7 100644 --- a/src/app/server.go +++ b/src/app/server.go @@ -3,14 +3,18 @@ package app import ( "autossh/src/utils" "errors" + "fmt" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" + "golang.org/x/net/proxy" "io/ioutil" "net" "os" + "os/signal" "strconv" "strings" + "syscall" "time" ) @@ -29,6 +33,7 @@ type Server struct { termWidth int termHeight int groupName string + group *Group } // 格式化,赋予默认值 @@ -96,12 +101,44 @@ func (server *Server) GetSshClient() (*ssh.Client, error) { addr := server.Ip + ":" + strconv.Itoa(server.Port) - client, err := ssh.Dial("tcp", addr, config) + if server.group != nil && server.group.Proxy != nil { + return server.proxySshClient(server.group.Proxy, addr, config) + } else { + return ssh.Dial("tcp", addr, config) + } +} + +func (server *Server) proxySshClient(p *Proxy, sshServerAddr string, sshConfig *ssh.ClientConfig) (client *ssh.Client, err error) { + var dialer proxy.Dialer + switch p.Type { + case ProxyTypeSocks5: + var auth proxy.Auth + if p.User != "" { + auth = proxy.Auth{ + User: p.User, + Password: p.Password, + } + } + + dialer, err = proxy.SOCKS5("tcp", p.Server+":"+strconv.Itoa(p.Port), &auth, proxy.Direct) + if err != nil { + return nil, err + } + default: + return nil, errors.New(fmt.Sprintf("unknown proxy type: %s", p.Type)) + } + + conn, err := dialer.Dial("tcp", sshServerAddr) + if err != nil { + return nil, err + } + + c, chans, reqs, err := ssh.NewClientConn(conn, sshServerAddr, sshConfig) if err != nil { return nil, err } - return client, nil + return ssh.NewClient(c, chans, reqs), nil } // 生成Sftp Client @@ -138,13 +175,15 @@ func (server *Server) Connect() error { if err != nil { return errors.New("创建文件描述符出错:" + err.Error()) } + defer terminal.Restore(fd, oldState) stopKeepAliveLoop := server.startKeepAliveLoop(session) defer close(stopKeepAliveLoop) - server.stdIO(session) - - defer terminal.Restore(fd, oldState) + err = server.stdIO(session) + if err != nil { + return err + } modes := ssh.TerminalModes{ ssh.ECHO: 1, @@ -153,12 +192,15 @@ func (server *Server) Connect() error { } server.termWidth, server.termHeight, _ = terminal.GetSize(fd) - if err := session.RequestPty("xterm-256color", server.termHeight, server.termWidth, modes); err != nil { + termType := os.Getenv("TERM") + if termType == "" { + termType = "xterm-256color" + } + if err := session.RequestPty(termType, server.termHeight, server.termWidth, modes); err != nil { return errors.New("创建终端出错:" + err.Error()) } - winChange := server.listenWindowChange(session, fd) - defer close(winChange) + server.listenWindowChange(session, fd) err = session.Shell() if err != nil { @@ -174,12 +216,15 @@ func (server *Server) Connect() error { } // 重定向标准输入输出 -func (server *Server) stdIO(session *ssh.Session) { +func (server *Server) stdIO(session *ssh.Session) error { session.Stderr = os.Stderr session.Stdin = os.Stdin if server.Log.Enable { - ch, _ := session.StdoutPipe() + ch, err := session.StdoutPipe() + if err != nil { + return err + } go func() { flag := os.O_RDWR | os.O_CREATE @@ -189,7 +234,11 @@ func (server *Server) stdIO(session *ssh.Session) { case LogModeCover: } - f, _ := os.OpenFile(server.formatLogFilename(server.Log.Filename), flag, 0644) + f, err := os.OpenFile(server.formatLogFilename(server.Log.Filename), flag, 0644) + if err != nil { + utils.Logger.Error("Open file fail ", err) + return + } for { buff := [4096]byte{} @@ -208,21 +257,25 @@ func (server *Server) stdIO(session *ssh.Session) { } else { session.Stdout = os.Stdout } + + return nil } // 格式化日志文件名 func (server *Server) formatLogFilename(filename string) string { - kvs := map[string]string{ - "%g": server.groupName, - "%n": server.Name, - "%dt": time.Now().Format("2006-01-02-15-04-05"), - "%d": time.Now().Format("2006-01-02"), - "%u": server.User, - "%a": server.Alias, + kvs := []map[string]string{ + {"%g": server.groupName}, + {"%n": server.Name}, + {"%dt": time.Now().Format("20060102150405")}, + {"%d": time.Now().Format("20060102")}, + {"%u": server.User}, + {"%a": server.Alias}, } - for k, v := range kvs { - filename = strings.ReplaceAll(filename, k, v) + for _, kv := range kvs { + for k, v := range kv { + filename = strings.ReplaceAll(filename, k, v) + } } return filename @@ -306,26 +359,40 @@ func (server *Server) startKeepAliveLoop(session *ssh.Session) chan struct{} { } // 监听终端窗口变化 -func (server *Server) listenWindowChange(session *ssh.Session, fd int) chan struct{} { - terminate := make(chan struct{}) +func (server *Server) listenWindowChange(session *ssh.Session, fd int) { go func() { + sigwinchCh := make(chan os.Signal, 1) + defer close(sigwinchCh) + + signal.Notify(sigwinchCh, syscall.SIGWINCH) + termWidth, termHeight, err := terminal.GetSize(fd) + if err != nil { + utils.Logger.Error(err) + } + for { select { - case <-terminate: - return - default: - termWidth, termHeight, _ := terminal.GetSize(fd) + // 阻塞读取 + case sigwinch := <-sigwinchCh: + if sigwinch == nil { + return + } + currTermWidth, currTermHeight, err := terminal.GetSize(fd) + + // 判断一下窗口尺寸是否有改变 + if currTermHeight == termHeight && currTermWidth == termWidth { + continue + } - if server.termWidth != termWidth || server.termHeight != termHeight { - server.termHeight = termHeight - server.termWidth = termWidth - session.WindowChange(termHeight, termWidth) + // 更新远端大小 + session.WindowChange(currTermHeight, currTermWidth) + if err != nil { + utils.Logger.Error(err) + continue } - time.Sleep(time.Millisecond * 3) + termWidth, termHeight = currTermWidth, currTermHeight } } }() - - return terminate } diff --git a/src/app/server_test.go b/src/app/server_test.go new file mode 100644 index 0000000..8635aef --- /dev/null +++ b/src/app/server_test.go @@ -0,0 +1,57 @@ +package app + +import ( + "fmt" + "golang.org/x/crypto/ssh" + "golang.org/x/net/proxy" + "log" + "net" + "testing" +) + +func TestServer_Connect(t *testing.T) { + var server = Server{ + Ip: "172.18.36.217", + Method: "key", + } + auth, err := parseAuthMethods(&server) + sshConfig := &ssh.ClientConfig{ + User: "work", + Auth: auth, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + // Auth: .... fill out with keys etc as normal + } + + client, err := proxiedSSHClient("127.0.0.1:1080", "172.18.36.217:22", sshConfig) + if err != nil { + log.Fatal(err) + } + + session, _ := client.NewSession() + output, _ := session.CombinedOutput("ls") + fmt.Println(string(output)) + + // get a session etc... + +} + +func proxiedSSHClient(proxyAddress, sshServerAddress string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { + dialer, err := proxy.SOCKS5("tcp", proxyAddress, nil, proxy.Direct) + if err != nil { + return nil, err + } + + conn, err := dialer.Dial("tcp", sshServerAddress) + if err != nil { + return nil, err + } + + c, chans, reqs, err := ssh.NewClientConn(conn, sshServerAddress, sshConfig) + if err != nil { + return nil, err + } + + return ssh.NewClient(c, chans, reqs), nil +} diff --git a/src/app/show_cp.go b/src/app/show_cp.go index d638aa3..2c8f138 100644 --- a/src/app/show_cp.go +++ b/src/app/show_cp.go @@ -7,38 +7,33 @@ import ( "github.com/pkg/errors" "github.com/pkg/sftp" "io" - "io/ioutil" "os" "path" + "strconv" "strings" + "syscall" + "time" + "unsafe" ) -type TransferObjectType int +type ResType int const ( - TransferObjectTypeLocal TransferObjectType = iota - TransferObjectTypeRemote + ResTypeSrc ResType = iota + ResTypeDst ) type TransferObject struct { - raw string // 原始数据,如 vagrant:/root/example.txt - cpType TransferObjectType // 类型,TransferObjectTypeLocal-本地,TransferObjectTypeRemote-远程 - server Server // 服务器,cpType = TransferObjectTypeRemote 时为空 - path string // 从raw解析得到的文件路径,如 /root/example.txt + raw string // 原始数据,如 vagrant:/root/example.txt + resType ResType // 类型,ResTypeSrc-源,ResTypeDst-目的 + server *Server // 服务器,当raw为本地址地时,该字段为空 + path string // 从raw解析得到的文件路径,如 /root/example.txt } -type CpType int - -const ( - CpTypeUpload CpType = iota - CpTypeDownload -) - type Cp struct { isDir bool cfg *Config - cpType CpType sources []*TransferObject target *TransferObject } @@ -58,100 +53,100 @@ func showCp(configFile string) { return } - if cp.sources[0].cpType == TransferObjectTypeLocal { - cp.cpType = CpTypeUpload - err = cp.upload() + var dstIoClient IOClient + if cp.target.server == nil { + dstIoClient = new(LocalIOClient) } else { - cp.cpType = CpTypeDownload - err = cp.download() + sftpClient, err := cp.target.server.GetSftpClient() + if err != nil { + utils.Errorln(err) + return + } + + defer func() { + _ = sftpClient.Close() + }() + + c := SftpIOClient{SftpClient: sftpClient} + dstIoClient = &c } - if err != nil { - utils.Errorln(err) - return + for _, source := range cp.sources { + var srcIoClient IOClient + var sftpClient *sftp.Client + + if source.server == nil { + srcIoClient = new(LocalIOClient) + } else { + sftpClient, err := source.server.GetSftpClient() + if err != nil { + cp.printFileError(source.path, err) + continue + } + + srcIoClient = &SftpIOClient{SftpClient: sftpClient} + } + + func() { + defer func() { + if sftpClient != nil { + _ = sftpClient.Close() + } + }() + + if file, err := cp.transferNew(srcIoClient, dstIoClient, source.path, cp.target.path, ""); err != nil { + cp.printFileError(file, err) + } + }() } } // 解析参数 func (cp *Cp) parse() error { - args := flag.Args()[1:] - if len(args) == 0 { - return errors.New("请输入完整参数") - } - - if args[0] == "-r" { - cp.isDir = true - args = args[1:] - } + os.Args = flag.Args() + flag.BoolVar(&cp.isDir, "r", false, "文件夹") + flag.Parse() + var args = flag.Args() + var length = len(args) var err error - switch len(args) { - case 0: + + if len(args) < 1 { return errors.New("请输入完整参数") - case 1: - // 默认取temp目录作为target - args = []string{args[0], os.TempDir()} } - length := len(args) cp.target, err = newTransferObject(*cp.cfg, args[length-1]) if err != nil { return err } cp.sources = make([]*TransferObject, 0) - for i, arg := range args[:length-1] { + for _, arg := range args[:length-1] { s, err := newTransferObject(*cp.cfg, arg) if err != nil { return err } - if s.cpType == TransferObjectTypeLocal && s.cpType == cp.target.cpType { + if s.resType == ResTypeSrc && s.resType == cp.target.resType { return errors.New("源和目标不能同时为本地地址") } - if i > 0 && s.cpType != cp.sources[i-1].cpType { - return errors.New("source 类型不一致") - } - cp.sources = append(cp.sources, s) } return nil } -// 上传 -func (cp *Cp) upload() error { - sftpClient, err := cp.target.server.GetSftpClient() - if err != nil { - return err - } - - defer func() { - _ = sftpClient.Close() - }() - - var ioClient = IOClient{ClientType: IOClientSftp, SftpClient: sftpClient} - - for _, source := range cp.sources { - if file, err := cp.transfer(&ioClient, source.path, cp.target.path, ""); err != nil { - cp.printFileError(file, err) - } - } - - return nil -} - // IO复制 src -> dst -func (cp *Cp) ioCopy(client *IOClient, srcFile FileLike, dst string, fSize int64) (string, error) { +func (cp *Cp) ioCopy(srcIO IOClient, dstIO IOClient, srcFile FileLike, dst string) (string, error) { var err error - dst, err = cp.parseDstFilename(client, srcFile.Name(), dst) + dst, err = cp.parseDstFilename(dstIO, srcFile.Name(), dst) if err != nil { return dst, err } - dstFile, err := client.Create(dst) + dstFile, err := dstIO.Create(dst) if err != nil { return dst, err } @@ -160,10 +155,28 @@ func (cp *Cp) ioCopy(client *IOClient, srcFile FileLike, dst string, fSize int64 _ = dstFile.Close() }() - bytes := [4096]byte{} bytesCount := 0 filename := path.Base(srcFile.Name()) + startTime := time.Now() + speed := 0.0 + var process = 0.0 + + go func() { + for { + cp.printProcess(filename, process, startTime, speed) + time.Sleep(time.Second) + if process >= 100 { + break + } + } + }() + + srcFileInfo, err := srcFile.Stat() + if err != nil { + return srcFile.Name(), err + } + bytes := make([]byte, 64*1024) for { n, err := srcFile.Read(bytes[:]) eof := err == io.EOF @@ -171,16 +184,16 @@ func (cp *Cp) ioCopy(client *IOClient, srcFile FileLike, dst string, fSize int64 return srcFile.Name(), err } - bytesCount += n - process := float64(bytesCount) / float64(fSize) * 100 - cp.printProcess(filename, process) - _, err = dstFile.Write(bytes[:n]) + wn, err := dstFile.Write(bytes[:n]) if err != nil { return cp.target.path, err } + bytesCount += wn + process = float64(bytesCount) / float64(srcFileInfo.Size()) * 100 + speed = float64(bytesCount) / time.Now().Sub(startTime).Seconds() if eof { - cp.printProcess(filename, 100.0) + cp.printProcess(filename, 100.0, startTime, speed) break } } @@ -189,34 +202,11 @@ func (cp *Cp) ioCopy(client *IOClient, srcFile FileLike, dst string, fSize int64 return "", nil } -// 下载 -func (cp *Cp) download() error { - for _, source := range cp.sources { - sftpClient, err := source.server.GetSftpClient() - if err != nil { - return err - } - - func() { - defer func() { - _ = sftpClient.Close() - }() - - var ioClient = IOClient{ClientType: IOClientLocal, SftpClient: sftpClient} - if file, err := cp.transfer(&ioClient, source.path, cp.target.path, ""); err != nil { - cp.printFileError(file, err) - } - }() - } - - return nil -} - // 传输 // 上传时,src = 本地,dst = 远程 // 下载时,src = 远程,dst = 本地 -func (cp *Cp) transfer(client *IOClient, src string, dst string, vPath string) (string, error) { - srcFile, err := cp.openFile(client.SftpClient, src) +func (cp *Cp) transferNew(srcIO IOClient, dstIO IOClient, src string, dst string, vPath string) (string, error) { + srcFile, err := srcIO.Open(src) if err != nil { return src, err } @@ -235,7 +225,7 @@ func (cp *Cp) transfer(client *IOClient, src string, dst string, vPath string) ( return src, errors.New("是一个目录") } - childFiles, err := cp.readDir(client.SftpClient, srcFile.Name()) + childFiles, err := srcIO.ReadDir(srcFile.Name()) if err != nil { return srcFile.Name(), err } @@ -248,14 +238,13 @@ func (cp *Cp) transfer(client *IOClient, src string, dst string, vPath string) ( for _, childFile := range childFiles { childFilename := path.Join(src, childFile.Name()) - if str, err := cp.transfer(client, childFilename, dst, vPath); err != nil { + if str, err := cp.transferNew(srcIO, dstIO, childFilename, dst, vPath); err != nil { cp.printFileError(str, err) } } } else { newDst := path.Join(dst, vPath) - - if file, err := cp.ioCopy(client, srcFile, newDst, srcFileInfo.Size()); err != nil { + if file, err := cp.ioCopy(srcIO, dstIO, srcFile, newDst); err != nil { return file, err } } @@ -267,7 +256,7 @@ func (cp *Cp) transfer(client *IOClient, src string, dst string, vPath string) ( // src = /root/example.txt dst = /root/ => /root/example.txt // src = /root/example.txt dst = /root => /root/example.txt // src = /root/example.txt dst = /root/new-name.txt => /root/new-name.txt -func (cp *Cp) parseDstFilename(client *IOClient, src string, dst string) (string, error) { +func (cp *Cp) parseDstFilename(client IOClient, src string, dst string) (string, error) { dstFileInfo, err := client.Stat(dst) if err != nil { if !os.IsNotExist(err) { @@ -298,31 +287,40 @@ func (cp *Cp) parseDstFilename(client *IOClient, src string, dst string) (string return dst, nil } -func (cp *Cp) printProcess(name string, process float64) { - // TODO 文件大小,执行时间 - fmt.Print("\r" + name + "\t\t\t" + fmt.Sprintf("%.2f", process) + "%") -} - -func (cp *Cp) printFileError(name string, err error) { - fmt.Println(name, ": ", err) -} +func (cp *Cp) printProcess(name string, process float64, startTime time.Time, speed float64) { + // TODO 文件大小 + execTime := time.Now().Sub(startTime) -// 根据上传/下载打开相应位置的文件 -func (cp *Cp) openFile(client *sftp.Client, file string) (FileLike, error) { - if cp.cpType == CpTypeUpload { - return os.Open(file) - } else { - return client.Open(file) + type winSize struct { + Row uint16 + Col uint16 + Xpixel uint16 + Ypixel uint16 + } + ws := &winSize{} + retCode, _, _ := syscall.Syscall(syscall.SYS_IOCTL, + uintptr(syscall.Stdin), + uintptr(syscall.TIOCGWINSZ), + uintptr(unsafe.Pointer(ws))) + + padding := 0 + if int(retCode) != -1 { + padding = int(ws.Col) - utils.ZhLen(name) - 40 } + + extInfo := fmt.Sprintf("%.2f%% %10s/s %02.0f:%02.0f:%02.0f", + process, + utils.SizeFormat(speed), + execTime.Hours(), + execTime.Minutes(), + execTime.Seconds()) + + format := "\r%s%-" + strconv.Itoa(padding) + "s%40s" + fmt.Printf(format, name, "", extInfo) } -// 根据上传/下载读取相应位置的目录,返回文件列表 -func (cp *Cp) readDir(client *sftp.Client, name string) ([]os.FileInfo, error) { - if cp.cpType == CpTypeUpload { - return ioutil.ReadDir(name) - } else { - return client.ReadDir(name) - } +func (cp *Cp) printFileError(name string, err error) { + fmt.Println(name, ": ", err) } // 创建传输对象 @@ -334,7 +332,7 @@ func newTransferObject(cfg Config, raw string) (*TransferObject, error) { args := strings.Split(raw, ":") switch len(args) { case 1: - obj.cpType = TransferObjectTypeLocal + obj.resType = ResTypeSrc obj.path = args[0] case 2: obj.path = strings.TrimSpace(args[1]) @@ -342,8 +340,8 @@ func newTransferObject(cfg Config, raw string) (*TransferObject, error) { if !exists { return nil, errors.New("服务器" + args[0] + "不存在") } - obj.cpType = TransferObjectTypeRemote - obj.server = *serverIndex.server + obj.resType = ResTypeDst + obj.server = serverIndex.server default: return nil, errors.New(raw + " 格式错误") diff --git a/src/app/show_help.go b/src/app/show_help.go index 46055a4..6fa5edd 100644 --- a/src/app/show_help.go +++ b/src/app/show_help.go @@ -6,7 +6,6 @@ import ( ) func showHelp() { - flag.Usage = usage flag.Usage() } @@ -17,16 +16,15 @@ Usage: autossh [options] [commands] Options: - -c, -config 指定配置文件。 - (default: ./config.json) - -v, -version 显示版本信息。 - -h, -help 显示帮助信息。 + -c, -config string 指定配置文件(default: ./config.json)。 + -v, -version 显示版本信息。 + -h, -help 显示帮助信息。 Commands: - upgrade 检测升级。 - cp [-r] source target 复制传输。 - ${ServerNum} 使用编号登录指定服务器。 - ${ServerAlias} 使用别名登录指定服务器。 + cp [-r] source target 复制传输。 + ${ServerNum} 使用编号登录指定服务器。 + ${ServerAlias} 使用别名登录指定服务器。 + upgrade 检测并更新到最新版本。 ` utils.Logln(str) } diff --git a/src/app/show_servers.go b/src/app/show_servers.go index e93aa2f..fb48cab 100644 --- a/src/app/show_servers.go +++ b/src/app/show_servers.go @@ -19,7 +19,6 @@ func showServers(configFile string) { for { loop, clear, reload := scanInput(cfg) - // TODO 解决进入服务器之后第一次输入无效的问题(进入新增、编辑、删除没问题) if !loop { break } @@ -73,8 +72,8 @@ func show(cfg *Config) { } // 计算分隔符长度 -func separatorLength(cfg Config) float64 { - maxlength := 60.0 +func separatorLength(cfg Config) int { + maxlength := 60 for _, group := range cfg.Groups { length := utils.ZhLen(group.GroupName) if length > maxlength { diff --git a/src/app/show_upgrade.go b/src/app/show_upgrade.go index 65025a9..20129d4 100644 --- a/src/app/show_upgrade.go +++ b/src/app/show_upgrade.go @@ -108,7 +108,9 @@ func (Upgrade) unzip(zipFile string, destDir string) (string, error) { if err != nil { return fullpath, err } - defer zipReader.Close() + defer func() { + _ = zipReader.Close() + }() for _, f := range zipReader.File { fpath := filepath.Join(destDir, f.Name) @@ -127,13 +129,17 @@ func (Upgrade) unzip(zipFile string, destDir string) (string, error) { if err != nil { return fullpath, err } - defer inFile.Close() + defer func() { + _ = inFile.Close() + }() outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { return fullpath, err } - defer outFile.Close() + defer func() { + _ = outFile.Close() + }() _, err = io.Copy(outFile, inFile) if err != nil { @@ -168,11 +174,15 @@ func (Upgrade) downloadFile(url string, downloadPath string, fb func(length, dow if err != nil { return err } - defer file.Close() + defer func() { + _ = file.Close() + }() if resp.Body == nil { return errors.New("body is null") } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() //下面是 io.copyBuffer() 的简化版本 for { //读取bytes @@ -234,7 +244,9 @@ func (upgrade *Upgrade) loadLatestVersion() { panic(err) } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() body, err := ioutil.ReadAll(resp.Body) if err != nil { panic(err) diff --git a/src/main/main.go b/src/main/main.go new file mode 100644 index 0000000..ec78eba --- /dev/null +++ b/src/main/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "autossh/src/app" +) + +var ( + Version = "unknown" + Build = "unknown" +) + +func main() { + app.Version = Version + app.Build = Build + app.Run() +} diff --git a/src/utils/size_format.go b/src/utils/size_format.go new file mode 100644 index 0000000..43b1e5d --- /dev/null +++ b/src/utils/size_format.go @@ -0,0 +1,17 @@ +package utils + +import ( + "math" + "strconv" +) + +func SizeFormat(size float64) string { + var k = 1024 // or 1024 + var sizes = []string{"B", "KB", "MB", "GB", "TB"} + if size == 0 { + return "0 B" + } + i := math.Floor(math.Log(size) / math.Log(float64(k))) + r := size / math.Pow(float64(k), i) + return strconv.FormatFloat(r, 'f', 2, 64) + " " + sizes[int(i)] +} diff --git a/src/utils/str.go b/src/utils/str.go index 3982a18..10a85dc 100644 --- a/src/utils/str.go +++ b/src/utils/str.go @@ -1,10 +1,12 @@ package utils -import "unicode" +import ( + "unicode" +) // 计算字符宽度(中文) -func ZhLen(str string) float64 { - length := 0.0 +func ZhLen(str string) int { + length := 0 for _, c := range str { if unicode.Is(unicode.Scripts["Han"], c) { length += 2 @@ -21,8 +23,8 @@ func ZhLen(str string) float64 { // c 填充符号 // maxlength 总长度 // 如: title = 测试 c=* maxlength = 10 返回 ** 返回 ** -func FormatSeparator(title string, c string, maxlength float64) string { - charslen := int((maxlength - ZhLen(title)) / 2) +func FormatSeparator(title string, c string, maxlength int) string { + charslen := (maxlength - ZhLen(title)) / 2 chars := "" for i := 0; i < charslen; i++ { chars += c @@ -32,15 +34,15 @@ func FormatSeparator(title string, c string, maxlength float64) string { } // 右填充 -func AppendRight(body string, char string, maxlength int) string { - length := int(ZhLen(body)) - if length >= maxlength { - return body - } - - for i := 0; i < maxlength-length; i++ { - body = body + char - } - - return body -} +//func AppendRight(body string, char string, maxlength int) string { +// length := ZhLen(body) +// if length >= maxlength { +// return body +// } +// +// for i := 0; i < maxlength-length; i++ { +// body = body + char +// } +// +// return body +//}