Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle T message and prepare for adding -p option #80

Merged
merged 7 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 59 additions & 27 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package scp
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -85,13 +84,24 @@ func (a *Client) SSHClient() *ssh.Client {
}

// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error {
func (a *Client) CopyFromFile(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
) error {
return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil)
}

// CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFromFilePassThru(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
passThru PassThru,
) error {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
Expand All @@ -101,21 +111,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP

// CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error {
func (a *Client) CopyFile(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
) error {
return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil)
}

// CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFilePassThru(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
passThru PassThru,
) error {
contentsBytes, err := ioutil.ReadAll(fileReader)
if err != nil {
return fmt.Errorf("failed to read all data from reader: %w", err)
}
bytesReader := bytes.NewReader(contentsBytes)

return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru)
return a.CopyPassThru(
ctx,
bytesReader,
remotePath,
permissions,
int64(len(contentsBytes)),
passThru,
)
}

// wait waits for the waitgroup for the specified max timeout.
Expand All @@ -139,27 +167,36 @@ func wait(wg *sync.WaitGroup, ctx context.Context) error {
// checkResponse checks the response it reads from the remote, and will return a single error in case
// of failure.
func checkResponse(r io.Reader) error {
response, err := ParseResponse(r)
_, err := ParseResponse(r, nil)
if err != nil {
return err
}

if response.IsFailure() {
return errors.New(response.GetMessage())
}

return nil

}

// Copy copies the contents of an io.Reader to a remote location.
func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error {
func (a *Client) Copy(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
) error {
return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil)
}

// CopyPassThru copies the contents of an io.Reader to a remote location.
// Access copied bytes by providing a PassThru reader factory
func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error {
func (a *Client) CopyPassThru(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy to remote: %v", err)
Expand Down Expand Up @@ -272,7 +309,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s
// CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used
// to keep track of progress and how many bytes that were download from the remote.
// `passThru` can be set to nil to disable this behaviour.
func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error {
func (a *Client) CopyFromRemotePassThru(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
Expand Down Expand Up @@ -319,17 +361,7 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
return
}

res, err := ParseResponse(r)
if err != nil {
errCh <- err
return
}
if res.IsFailure() {
errCh <- errors.New(res.GetMessage())
return
}

infos, err := res.ParseFileInfos()
fileInfo, err := ParseResponse(r, in)
if err != nil {
errCh <- err
return
Expand All @@ -342,10 +374,10 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
}

if passThru != nil {
r = passThru(r, infos.Size)
r = passThru(r, fileInfo.Size)
}

_, err = CopyN(w, r, infos.Size)
_, err = CopyN(w, r, fileInfo.Size)
if err != nil {
errCh <- err
return
Expand Down
170 changes: 120 additions & 50 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,30 @@ package scp
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
)

type ResponseType = uint8
type ResponseType = byte

const (
Ok ResponseType = 0
Warning ResponseType = 1
Error ResponseType = 2
Create ResponseType = 'C'
Time ResponseType = 'T'
)

// Response represent a response from the SCP command.
// There are tree types of responses that the remote can send back:
// ok, warning and error
//
// The difference between warning and error is that the connection is not closed by the remote,
// however, a warning can indicate a file transfer failure (such as invalid destination directory)
// and such be handled as such.
//
// All responses except for the `Ok` type always have a message (although these can be empty)
//
// The remote sends a confirmation after every SCP command, because a failure can occur after every
// command, the response should be read and checked after sending them.
type Response struct {
Type ResponseType
Message string
}

// ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure.
func ParseResponse(reader io.Reader) (Response, error) {
func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
fileInfos := NewFileInfos()

buffer := make([]uint8, 1)
_, err := reader.Read(buffer)
if err != nil {
return Response{}, err
return fileInfos, err
}

responseType := buffer[0]
Expand All @@ -53,61 +41,143 @@ func ParseResponse(reader io.Reader) (Response, error) {
bufferedReader := bufio.NewReader(reader)
message, err = bufferedReader.ReadString('\n')
if err != nil {
return Response{}, err
return fileInfos, err
}
}

return Response{responseType, message}, nil
}
if responseType == Warning || responseType == Error {
return fileInfos, errors.New(message)
}

func (r *Response) IsOk() bool {
return r.Type == Ok
}
// Exit early because we're only interested in the ok response
if responseType == Ok {
return fileInfos, nil
}

func (r *Response) IsWarning() bool {
return r.Type == Warning
}
if !(responseType == Create || responseType == Time) {
return fileInfos, errors.New(
fmt.Sprintf(
"Message does not follow scp protocol: %s\n Cmmmm <length> <filename> or T<mtime> 0 <atime> 0",
message,
),
)
}

// IsError returns true when the remote responded with an error.
func (r *Response) IsError() bool {
return r.Type == Error
}
if responseType == Time {
err = ParseFileTime(message, fileInfos)
if err != nil {
return nil, err
}

message, err = bufferedReader.ReadString('\n')
if err == io.EOF {
err = Ack(writer)
if err != nil {
return fileInfos, err
}
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}
}

if err != nil && err != io.EOF {
return fileInfos, err
}

responseType = message[0]
}

// IsFailure returns true when the remote answered with a warning or an error.
func (r *Response) IsFailure() bool {
return r.IsWarning() || r.IsError()
}
if responseType == Create {
err = ParseFileInfos(message, fileInfos)
if err != nil {
return nil, err
}
}
}

// GetMessage returns the message the remote sent back.
func (r *Response) GetMessage() string {
return r.Message
return fileInfos, nil
}

type FileInfos struct {
Message string
Filename string
Permissions string
Size int64
Atime int64
Mtime int64
}

func NewFileInfos() *FileInfos {
return &FileInfos{}
}

func (r *Response) ParseFileInfos() (*FileInfos, error) {
message := strings.ReplaceAll(r.Message, "\n", "")
parts := strings.Split(message, " ")
func (fileInfos *FileInfos) Update(new *FileInfos) {
if new == nil {
return
}
if new.Filename != "" {
fileInfos.Filename = new.Filename
}
if new.Permissions != "" {
fileInfos.Permissions = new.Permissions
}
if new.Size != 0 {
fileInfos.Size = new.Size
}
if new.Atime != 0 {
fileInfos.Atime = new.Atime
}
if new.Mtime != 0 {
fileInfos.Mtime = new.Mtime
}
}

func ParseFileInfos(message string, fileInfos *FileInfos) error {
processMessage := strings.ReplaceAll(message, "\n", "")
parts := strings.Split(processMessage, " ")
if len(parts) < 3 {
return nil, errors.New("unable to parse message as file infos")
return errors.New("unable to parse Chmod protocol")
}

size, err := strconv.Atoi(parts[1])
if err != nil {
return nil, err
return err
}

return &FileInfos{
Message: r.Message,
fileInfos.Update(&FileInfos{
Filename: parts[2],
Permissions: parts[0],
Size: int64(size),
Filename: parts[2],
}, nil
})

return nil
}

func ParseFileTime(
message string,
fileInfos *FileInfos,
) error {
processMessage := strings.ReplaceAll(message, "\n", "")
parts := strings.Split(processMessage, " ")
if len(parts) < 3 {
return errors.New("unable to parse Time protocol")
}

aTime, err := strconv.Atoi(string(parts[0][0:10]))
if err != nil {
return errors.New("unable to parse ATime component of message")
}
mTime, err := strconv.Atoi(string(parts[2][0:10]))
if err != nil {
return errors.New("unable to parse MTime component of message")
}

fileInfos.Update(&FileInfos{
Atime: int64(aTime),
Mtime: int64(mTime),
})
return nil
}

// Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is
Expand Down