Skip to content

Commit

Permalink
CopyFromRemote now handles unrequested T message
Browse files Browse the repository at this point in the history
  • Loading branch information
datadius committed Apr 24, 2024
1 parent 18d5aeb commit 8007c8b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 10 deletions.
86 changes: 78 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,24 @@ func (a *Client) SSHClient() *ssh.Client {
}

// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error {
func (a *Client) CopyFromFile(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
) error {
return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil)
}

// CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFromFilePassThru(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
passThru PassThru,
) error {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
Expand All @@ -101,21 +112,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP

// CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error {
func (a *Client) CopyFile(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
) error {
return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil)
}

// CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFilePassThru(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
passThru PassThru,
) error {
contentsBytes, err := ioutil.ReadAll(fileReader)
if err != nil {
return fmt.Errorf("failed to read all data from reader: %w", err)
}
bytesReader := bytes.NewReader(contentsBytes)

return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru)
return a.CopyPassThru(
ctx,
bytesReader,
remotePath,
permissions,
int64(len(contentsBytes)),
passThru,
)
}

// wait waits for the waitgroup for the specified max timeout.
Expand Down Expand Up @@ -153,13 +182,26 @@ func checkResponse(r io.Reader) error {
}

// Copy copies the contents of an io.Reader to a remote location.
func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error {
func (a *Client) Copy(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
) error {
return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil)
}

// CopyPassThru copies the contents of an io.Reader to a remote location.
// Access copied bytes by providing a PassThru reader factory
func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error {
func (a *Client) CopyPassThru(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy to remote: %v", err)
Expand Down Expand Up @@ -272,7 +314,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s
// CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used
// to keep track of progress and how many bytes that were download from the remote.
// `passThru` can be set to nil to disable this behaviour.
func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error {
func (a *Client) CopyFromRemotePassThru(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
Expand Down Expand Up @@ -328,6 +375,29 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
errCh <- errors.New(res.GetMessage())
return
}
if res.NoStandardProtocolType() {
errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage()))
return
}

if res.IsTime() {
res, err = ParseResponse(r)
if err != nil {
errCh <- err
return
}

if res.IsFailure() || res.NoStandardProtocolType() {
errCh <- errors.New(res.GetMessage())
return
}
}

// The CHMOD message always comes before the actual data is being sent
if !res.IsChmod() {
errCh <- errors.New(fmt.Sprintf("The data did not contain the expected CHMOD information: %s", res.GetMessage()))
return
}

infos, err := res.ParseFileInfos()
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ func ParseResponse(reader io.Reader) (Response, error) {
}

responseType := buffer[0]
runeResponseType := rune(buffer[0])
message := ""
if responseType > 0 {
if responseType > 0 && (runeResponseType == Chmod || runeResponseType == Time) {
bufferedReader := bufio.NewReader(reader)
message, err = bufferedReader.ReadString('\n')
if err != nil {
Expand All @@ -66,7 +67,7 @@ func ParseResponse(reader io.Reader) (Response, error) {
}

if len(message) > 0 {
return Response{responseType, message, rune(message[0])}, nil
return Response{responseType, message, runeResponseType}, nil
}

return Response{responseType, message, ' '}, nil
Expand Down

0 comments on commit 8007c8b

Please sign in to comment.