From 18d5aebb577300cd7e2fd3c16aba41d3caaff222 Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Wed, 24 Apr 2024 14:07:54 +0200 Subject: [PATCH 1/7] add time information to the protocol --- protocol.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 4 deletions(-) diff --git a/protocol.go b/protocol.go index ca19d71..532a3fa 100644 --- a/protocol.go +++ b/protocol.go @@ -22,6 +22,13 @@ const ( Error ResponseType = 2 ) +type ProtocolType = rune + +const ( + Chmod ProtocolType = 'C' + Time ProtocolType = 'T' +) + // Response represent a response from the SCP command. // There are tree types of responses that the remote can send back: // ok, warning and error @@ -35,8 +42,9 @@ const ( // The remote sends a confirmation after every SCP command, because a failure can occur after every // command, the response should be read and checked after sending them. type Response struct { - Type ResponseType - Message string + Type ResponseType + Message string + ProtocolType rune } // ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure. @@ -57,7 +65,11 @@ func ParseResponse(reader io.Reader) (Response, error) { } } - return Response{responseType, message}, nil + if len(message) > 0 { + return Response{responseType, message, rune(message[0])}, nil + } + + return Response{responseType, message, ' '}, nil } func (r *Response) IsOk() bool { @@ -78,6 +90,18 @@ func (r *Response) IsFailure() bool { return r.IsWarning() || r.IsError() } +func (r *Response) IsChmod() bool { + return r.ProtocolType == Chmod +} + +func (r *Response) IsTime() bool { + return r.ProtocolType == Time +} + +func (r *Response) NoStandardProtocolType() bool { + return !(r.ProtocolType == Chmod || r.ProtocolType == Time) +} + // GetMessage returns the message the remote sent back. func (r *Response) GetMessage() string { return r.Message @@ -88,13 +112,36 @@ type FileInfos struct { Filename string Permissions string Size int64 + Atime int64 + Mtime int64 +} + +func (fileInfos *FileInfos) Update(new *FileInfos) { + if new == nil { + return + } + if new.Filename != "" { + fileInfos.Filename = new.Filename + } + if new.Permissions != "" { + fileInfos.Permissions = new.Permissions + } + if new.Size != 0 { + fileInfos.Size = new.Size + } + if new.Atime != 0 { + fileInfos.Atime = new.Atime + } + if new.Mtime != 0 { + fileInfos.Mtime = new.Mtime + } } func (r *Response) ParseFileInfos() (*FileInfos, error) { message := strings.ReplaceAll(r.Message, "\n", "") parts := strings.Split(message, " ") if len(parts) < 3 { - return nil, errors.New("unable to parse message as file infos") + return nil, errors.New("unable to parse Chmod protocol") } size, err := strconv.Atoi(parts[1]) @@ -110,6 +157,29 @@ func (r *Response) ParseFileInfos() (*FileInfos, error) { }, nil } +func (r *Response) ParseFileTime() (*FileInfos, error) { + message := strings.ReplaceAll(r.Message, "\n", "") + parts := strings.Split(message, " ") + if len(parts) < 3 { + return nil, errors.New("unable to parse Time protocol") + } + + aTime, err := strconv.Atoi(string(parts[0][1:10])) + if err != nil { + return nil, errors.New("unable to parse ATime component of message") + } + mTime, err := strconv.Atoi(string(parts[2][0:10])) + if err != nil { + return nil, errors.New("unable to parse MTime component of message") + } + + return &FileInfos{ + Message: r.Message, + Atime: int64(aTime), + Mtime: int64(mTime), + }, nil +} + // Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is // therefore required to check if the acknowledgement succeeded. func Ack(writer io.Writer) error { From 8007c8b1cb0a452f90a311b2d321f07c6ffb84ed Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:16:16 +0200 Subject: [PATCH 2/7] CopyFromRemote now handles unrequested T message --- client.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++++----- protocol.go | 5 ++-- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index ee7d6f8..919faf6 100644 --- a/client.go +++ b/client.go @@ -85,13 +85,24 @@ func (a *Client) SSHClient() *ssh.Client { } // CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem. -func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error { +func (a *Client) CopyFromFile( + ctx context.Context, + file os.File, + remotePath string, + permissions string, +) error { return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil) } // CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem. // Access copied bytes by providing a PassThru reader factory. -func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error { +func (a *Client) CopyFromFilePassThru( + ctx context.Context, + file os.File, + remotePath string, + permissions string, + passThru PassThru, +) error { stat, err := file.Stat() if err != nil { return fmt.Errorf("failed to stat file: %w", err) @@ -101,21 +112,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP // CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF // if the file length in know in advance please use "Copy" instead. -func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error { +func (a *Client) CopyFile( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, +) error { return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil) } // CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF // if the file length in know in advance please use "Copy" instead. // Access copied bytes by providing a PassThru reader factory. -func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error { +func (a *Client) CopyFilePassThru( + ctx context.Context, + fileReader io.Reader, + remotePath string, + permissions string, + passThru PassThru, +) error { contentsBytes, err := ioutil.ReadAll(fileReader) if err != nil { return fmt.Errorf("failed to read all data from reader: %w", err) } bytesReader := bytes.NewReader(contentsBytes) - return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru) + return a.CopyPassThru( + ctx, + bytesReader, + remotePath, + permissions, + int64(len(contentsBytes)), + passThru, + ) } // wait waits for the waitgroup for the specified max timeout. @@ -153,13 +182,26 @@ func checkResponse(r io.Reader) error { } // Copy copies the contents of an io.Reader to a remote location. -func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error { +func (a *Client) Copy( + ctx context.Context, + r io.Reader, + remotePath string, + permissions string, + size int64, +) error { return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil) } // CopyPassThru copies the contents of an io.Reader to a remote location. // Access copied bytes by providing a PassThru reader factory -func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error { +func (a *Client) CopyPassThru( + ctx context.Context, + r io.Reader, + remotePath string, + permissions string, + size int64, + passThru PassThru, +) error { session, err := a.sshClient.NewSession() if err != nil { return fmt.Errorf("Error creating ssh session in copy to remote: %v", err) @@ -272,7 +314,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s // CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used // to keep track of progress and how many bytes that were download from the remote. // `passThru` can be set to nil to disable this behaviour. -func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error { +func (a *Client) CopyFromRemotePassThru( + ctx context.Context, + w io.Writer, + remotePath string, + passThru PassThru, +) error { session, err := a.sshClient.NewSession() if err != nil { return fmt.Errorf("Error creating ssh session in copy from remote: %v", err) @@ -328,6 +375,29 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote errCh <- errors.New(res.GetMessage()) return } + if res.NoStandardProtocolType() { + errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage())) + return + } + + if res.IsTime() { + res, err = ParseResponse(r) + if err != nil { + errCh <- err + return + } + + if res.IsFailure() || res.NoStandardProtocolType() { + errCh <- errors.New(res.GetMessage()) + return + } + } + + // The CHMOD message always comes before the actual data is being sent + if !res.IsChmod() { + errCh <- errors.New(fmt.Sprintf("The data did not contain the expected CHMOD information: %s", res.GetMessage())) + return + } infos, err := res.ParseFileInfos() if err != nil { diff --git a/protocol.go b/protocol.go index 532a3fa..df61324 100644 --- a/protocol.go +++ b/protocol.go @@ -56,8 +56,9 @@ func ParseResponse(reader io.Reader) (Response, error) { } responseType := buffer[0] + runeResponseType := rune(buffer[0]) message := "" - if responseType > 0 { + if responseType > 0 && (runeResponseType == Chmod || runeResponseType == Time) { bufferedReader := bufio.NewReader(reader) message, err = bufferedReader.ReadString('\n') if err != nil { @@ -66,7 +67,7 @@ func ParseResponse(reader io.Reader) (Response, error) { } if len(message) > 0 { - return Response{responseType, message, rune(message[0])}, nil + return Response{responseType, message, runeResponseType}, nil } return Response{responseType, message, ' '}, nil From 4f72b511928ca3171208b602fa4eae985e1f3c62 Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:37:00 +0200 Subject: [PATCH 3/7] Added logic to prepare for returning information about the file --- client.go | 21 ++++++++++++++++++++- protocol.go | 4 ++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 919faf6..bd44c58 100644 --- a/client.go +++ b/client.go @@ -380,17 +380,33 @@ func (a *Client) CopyFromRemotePassThru( return } + fileInfo := NewFileInfos() + if res.IsTime() { + timeInfo, err := res.ParseFileTime() + + if err != nil { + errCh <- err + return + } + + fileInfo.Update(timeInfo) + res, err = ParseResponse(r) if err != nil { errCh <- err return } - if res.IsFailure() || res.NoStandardProtocolType() { + if res.IsFailure() { errCh <- errors.New(res.GetMessage()) return } + + if res.NoStandardProtocolType() { + errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage())) + return + } } // The CHMOD message always comes before the actual data is being sent @@ -400,6 +416,9 @@ func (a *Client) CopyFromRemotePassThru( } infos, err := res.ParseFileInfos() + + fileInfo.Update(infos) + if err != nil { errCh <- err return diff --git a/protocol.go b/protocol.go index df61324..b2a0dea 100644 --- a/protocol.go +++ b/protocol.go @@ -117,6 +117,10 @@ type FileInfos struct { Mtime int64 } +func NewFileInfos() *FileInfos { + return &FileInfos{} +} + func (fileInfos *FileInfos) Update(new *FileInfos) { if new == nil { return From ea410af592a2940117124fc4088d584e4f5cd87a Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:47:14 +0200 Subject: [PATCH 4/7] fixed minor bug not reading the whole number --- protocol.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol.go b/protocol.go index b2a0dea..f46d7ac 100644 --- a/protocol.go +++ b/protocol.go @@ -169,7 +169,7 @@ func (r *Response) ParseFileTime() (*FileInfos, error) { return nil, errors.New("unable to parse Time protocol") } - aTime, err := strconv.Atoi(string(parts[0][1:10])) + aTime, err := strconv.Atoi(string(parts[0][0:10])) if err != nil { return nil, errors.New("unable to parse ATime component of message") } From 2393b738f095d733fc6d2ec79af328ffc759bc0d Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Wed, 24 Apr 2024 21:24:01 +0200 Subject: [PATCH 5/7] removed a check for protocol type in parse response --- protocol.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol.go b/protocol.go index f46d7ac..30dc260 100644 --- a/protocol.go +++ b/protocol.go @@ -58,7 +58,7 @@ func ParseResponse(reader io.Reader) (Response, error) { responseType := buffer[0] runeResponseType := rune(buffer[0]) message := "" - if responseType > 0 && (runeResponseType == Chmod || runeResponseType == Time) { + if responseType > 0 { bufferedReader := bufio.NewReader(reader) message, err = bufferedReader.ReadString('\n') if err != nil { From adc0026b0a40ec01eb15ddd8321398e7fd02c56e Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Sat, 27 Apr 2024 22:36:04 +0200 Subject: [PATCH 6/7] rewrite ParseResponse method to reduce some of the code duplication --- client.go | 65 ++------------------ protocol.go | 169 ++++++++++++++++++++++++++-------------------------- 2 files changed, 87 insertions(+), 147 deletions(-) diff --git a/client.go b/client.go index bd44c58..b068b6a 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,6 @@ package scp import ( "bytes" "context" - "errors" "fmt" "io" "io/ioutil" @@ -168,15 +167,11 @@ func wait(wg *sync.WaitGroup, ctx context.Context) error { // checkResponse checks the response it reads from the remote, and will return a single error in case // of failure. func checkResponse(r io.Reader) error { - response, err := ParseResponse(r) + _, err := ParseResponse(r, nil) if err != nil { return err } - if response.IsFailure() { - return errors.New(response.GetMessage()) - } - return nil } @@ -366,59 +361,7 @@ func (a *Client) CopyFromRemotePassThru( return } - res, err := ParseResponse(r) - if err != nil { - errCh <- err - return - } - if res.IsFailure() { - errCh <- errors.New(res.GetMessage()) - return - } - if res.NoStandardProtocolType() { - errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage())) - return - } - - fileInfo := NewFileInfos() - - if res.IsTime() { - timeInfo, err := res.ParseFileTime() - - if err != nil { - errCh <- err - return - } - - fileInfo.Update(timeInfo) - - res, err = ParseResponse(r) - if err != nil { - errCh <- err - return - } - - if res.IsFailure() { - errCh <- errors.New(res.GetMessage()) - return - } - - if res.NoStandardProtocolType() { - errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage())) - return - } - } - - // The CHMOD message always comes before the actual data is being sent - if !res.IsChmod() { - errCh <- errors.New(fmt.Sprintf("The data did not contain the expected CHMOD information: %s", res.GetMessage())) - return - } - - infos, err := res.ParseFileInfos() - - fileInfo.Update(infos) - + fileInfo, err := ParseResponse(r, in) if err != nil { errCh <- err return @@ -431,10 +374,10 @@ func (a *Client) CopyFromRemotePassThru( } if passThru != nil { - r = passThru(r, infos.Size) + r = passThru(r, fileInfo.Size) } - _, err = CopyN(w, r, infos.Size) + _, err = CopyN(w, r, fileInfo.Size) if err != nil { errCh <- err return diff --git a/protocol.go b/protocol.go index 30dc260..2a615ea 100644 --- a/protocol.go +++ b/protocol.go @@ -9,103 +9,96 @@ package scp import ( "bufio" "errors" + "fmt" "io" "strconv" "strings" ) -type ResponseType = uint8 +type ResponseType = byte const ( Ok ResponseType = 0 Warning ResponseType = 1 Error ResponseType = 2 + Create ResponseType = 'C' + Time ResponseType = 'T' ) -type ProtocolType = rune - -const ( - Chmod ProtocolType = 'C' - Time ProtocolType = 'T' -) - -// Response represent a response from the SCP command. -// There are tree types of responses that the remote can send back: -// ok, warning and error -// -// The difference between warning and error is that the connection is not closed by the remote, -// however, a warning can indicate a file transfer failure (such as invalid destination directory) -// and such be handled as such. -// -// All responses except for the `Ok` type always have a message (although these can be empty) -// -// The remote sends a confirmation after every SCP command, because a failure can occur after every -// command, the response should be read and checked after sending them. -type Response struct { - Type ResponseType - Message string - ProtocolType rune -} - // ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure. -func ParseResponse(reader io.Reader) (Response, error) { +func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) { + fileInfos := NewFileInfos() + buffer := make([]uint8, 1) _, err := reader.Read(buffer) if err != nil { - return Response{}, err + return fileInfos, err } responseType := buffer[0] - runeResponseType := rune(buffer[0]) message := "" if responseType > 0 { bufferedReader := bufio.NewReader(reader) message, err = bufferedReader.ReadString('\n') if err != nil { - return Response{}, err + return fileInfos, err } - } - if len(message) > 0 { - return Response{responseType, message, runeResponseType}, nil - } - - return Response{responseType, message, ' '}, nil -} - -func (r *Response) IsOk() bool { - return r.Type == Ok -} - -func (r *Response) IsWarning() bool { - return r.Type == Warning -} - -// IsError returns true when the remote responded with an error. -func (r *Response) IsError() bool { - return r.Type == Error -} + if responseType == Warning || responseType == Error { + return fileInfos, errors.New( + fmt.Sprintf("Failed to execute command with a warning or error: %s", message), + ) + } -// IsFailure returns true when the remote answered with a warning or an error. -func (r *Response) IsFailure() bool { - return r.IsWarning() || r.IsError() -} + // Exit early because we're only interested in the ok response + if responseType == Ok { + return fileInfos, nil + } -func (r *Response) IsChmod() bool { - return r.ProtocolType == Chmod -} + if !(responseType == Create || responseType == Time) { + return fileInfos, errors.New( + fmt.Sprintf( + "Message does not follow scp protocol: %s\n Cmmmm or T 0 0", + message, + ), + ) + } -func (r *Response) IsTime() bool { - return r.ProtocolType == Time -} + if responseType == Time { + err = ParseFileTime(message, fileInfos) + if err != nil { + return nil, err + } + + message, err = bufferedReader.ReadString('\n') + if err == io.EOF { + err = Ack(writer) + if err != nil { + return fileInfos, err + } + message, err = bufferedReader.ReadString('\n') + + if err != nil { + return fileInfos, err + } + } + + if err != nil && err != io.EOF { + return fileInfos, err + } + + responseType = message[0] + } -func (r *Response) NoStandardProtocolType() bool { - return !(r.ProtocolType == Chmod || r.ProtocolType == Time) -} + if responseType == Create { + err = ParseFileInfos(message, fileInfos) + if err != nil { + return nil, err + } + } + } -// GetMessage returns the message the remote sent back. -func (r *Response) GetMessage() string { - return r.Message + return fileInfos, nil } type FileInfos struct { @@ -142,47 +135,51 @@ func (fileInfos *FileInfos) Update(new *FileInfos) { } } -func (r *Response) ParseFileInfos() (*FileInfos, error) { - message := strings.ReplaceAll(r.Message, "\n", "") - parts := strings.Split(message, " ") +func ParseFileInfos(message string, fileInfos *FileInfos) error { + processMessage := strings.ReplaceAll(message, "\n", "") + parts := strings.Split(processMessage, " ") if len(parts) < 3 { - return nil, errors.New("unable to parse Chmod protocol") + return errors.New("unable to parse Chmod protocol") } size, err := strconv.Atoi(parts[1]) if err != nil { - return nil, err + return err } - return &FileInfos{ - Message: r.Message, + fileInfos.Update(&FileInfos{ + Filename: parts[2], Permissions: parts[0], Size: int64(size), - Filename: parts[2], - }, nil + }) + + return nil } -func (r *Response) ParseFileTime() (*FileInfos, error) { - message := strings.ReplaceAll(r.Message, "\n", "") - parts := strings.Split(message, " ") +func ParseFileTime( + message string, + fileInfos *FileInfos, +) error { + processMessage := strings.ReplaceAll(message, "\n", "") + parts := strings.Split(processMessage, " ") if len(parts) < 3 { - return nil, errors.New("unable to parse Time protocol") + return errors.New("unable to parse Time protocol") } aTime, err := strconv.Atoi(string(parts[0][0:10])) if err != nil { - return nil, errors.New("unable to parse ATime component of message") + return errors.New("unable to parse ATime component of message") } mTime, err := strconv.Atoi(string(parts[2][0:10])) if err != nil { - return nil, errors.New("unable to parse MTime component of message") + return errors.New("unable to parse MTime component of message") } - return &FileInfos{ - Message: r.Message, - Atime: int64(aTime), - Mtime: int64(mTime), - }, nil + fileInfos.Update(&FileInfos{ + Atime: int64(aTime), + Mtime: int64(mTime), + }) + return nil } // Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is From b4cd115a970fb4a5431ba89b087f5cd740ca55f3 Mon Sep 17 00:00:00 2001 From: Data Dius <16731794+datadius@users.noreply.github.com> Date: Sun, 28 Apr 2024 16:03:57 +0200 Subject: [PATCH 7/7] adjust text to allow test to pass --- protocol.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/protocol.go b/protocol.go index 2a615ea..4a98b89 100644 --- a/protocol.go +++ b/protocol.go @@ -45,9 +45,7 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) { } if responseType == Warning || responseType == Error { - return fileInfos, errors.New( - fmt.Sprintf("Failed to execute command with a warning or error: %s", message), - ) + return fileInfos, errors.New(message) } // Exit early because we're only interested in the ok response