diff --git a/client.go b/client.go index ee7d6f8..919faf6 100644 --- a/client.go +++ b/client.go @@ -85,13 +85,24 @@ func (a *Client) SSHClient() *ssh.Client { } // CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem. -func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error { +func (a *Client) CopyFromFile( + ctx context.Context, + file os.File, + remotePath string, + permissions string, +) error { return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil) } // CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem. // Access copied bytes by providing a PassThru reader factory. -func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error { +func (a *Client) CopyFromFilePassThru( + ctx context.Context, + file os.File, + remotePath string, + permissions string, + passThru PassThru, +) error { stat, err := file.Stat() if err != nil { return fmt.Errorf("failed to stat file: %w", err) @@ -101,21 +112,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP // CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF // if the file length in know in advance please use "Copy" instead. -func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error { +func (a *Client) CopyFile( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, +) error { return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil) } // CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF // if the file length in know in advance please use "Copy" instead. // Access copied bytes by providing a PassThru reader factory. -func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error { +func (a *Client) CopyFilePassThru( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, + passThru PassThru, +) error { contentsBytes, err := ioutil.ReadAll(fileReader) if err != nil { return fmt.Errorf("failed to read all data from reader: %w", err) } bytesReader := bytes.NewReader(contentsBytes) - return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru) + return a.CopyPassThru( + ctx, + bytesReader, + remotePath, + permissions, + int64(len(contentsBytes)), + passThru, + ) } // wait waits for the waitgroup for the specified max timeout. @@ -153,13 +182,26 @@ func checkResponse(r io.Reader) error { } // Copy copies the contents of an io.Reader to a remote location. -func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error { +func (a *Client) Copy( + ctx context.Context, + r io.Reader, + remotePath string, + permissions string, + size int64, +) error { return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil) } // CopyPassThru copies the contents of an io.Reader to a remote location. // Access copied bytes by providing a PassThru reader factory -func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error { +func (a *Client) CopyPassThru( + ctx context.Context, + r io.Reader, + remotePath string, + permissions string, + size int64, + passThru PassThru, +) error { session, err := a.sshClient.NewSession() if err != nil { return fmt.Errorf("Error creating ssh session in copy to remote: %v", err) @@ -272,7 +314,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s // CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used // to keep track of progress and how many bytes that were download from the remote. // `passThru` can be set to nil to disable this behaviour. -func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error { +func (a *Client) CopyFromRemotePassThru( + ctx context.Context, + w io.Writer, + remotePath string, + passThru PassThru, +) error { session, err := a.sshClient.NewSession() if err != nil { return fmt.Errorf("Error creating ssh session in copy from remote: %v", err) @@ -328,6 +375,29 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote errCh <- errors.New(res.GetMessage()) return } + if res.NoStandardProtocolType() { + errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage())) + return + } + + if res.IsTime() { + res, err = ParseResponse(r) + if err != nil { + errCh <- err + return + } + + if res.IsFailure() || res.NoStandardProtocolType() { + errCh <- errors.New(res.GetMessage()) + return + } + } + + // The CHMOD message always comes before the actual data is being sent + if !res.IsChmod() { + errCh <- errors.New(fmt.Sprintf("The data did not contain the expected CHMOD information: %s", res.GetMessage())) + return + } infos, err := res.ParseFileInfos() if err != nil { diff --git a/protocol.go b/protocol.go index 532a3fa..df61324 100644 --- a/protocol.go +++ b/protocol.go @@ -56,8 +56,9 @@ func ParseResponse(reader io.Reader) (Response, error) { } responseType := buffer[0] + runeResponseType := rune(buffer[0]) message := "" - if responseType > 0 { + if responseType > 0 && (runeResponseType == Chmod || runeResponseType == Time) { bufferedReader := bufio.NewReader(reader) message, err = bufferedReader.ReadString('\n') if err != nil { @@ -66,7 +67,7 @@ func ParseResponse(reader io.Reader) (Response, error) { } if len(message) > 0 { - return Response{responseType, message, rune(message[0])}, nil + return Response{responseType, message, runeResponseType}, nil } return Response{responseType, message, ' '}, nil