Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle T message and prepare for adding -p option #80

Merged
merged 7 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 97 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,24 @@ func (a *Client) SSHClient() *ssh.Client {
}

// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error {
func (a *Client) CopyFromFile(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
) error {
return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil)
}

// CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFromFilePassThru(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
passThru PassThru,
) error {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
Expand All @@ -101,21 +112,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP

// CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error {
func (a *Client) CopyFile(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
) error {
return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil)
}

// CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFilePassThru(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
passThru PassThru,
) error {
contentsBytes, err := ioutil.ReadAll(fileReader)
if err != nil {
return fmt.Errorf("failed to read all data from reader: %w", err)
}
bytesReader := bytes.NewReader(contentsBytes)

return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru)
return a.CopyPassThru(
ctx,
bytesReader,
remotePath,
permissions,
int64(len(contentsBytes)),
passThru,
)
}

// wait waits for the waitgroup for the specified max timeout.
Expand Down Expand Up @@ -153,13 +182,26 @@ func checkResponse(r io.Reader) error {
}

// Copy copies the contents of an io.Reader to a remote location.
func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error {
func (a *Client) Copy(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
) error {
return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil)
}

// CopyPassThru copies the contents of an io.Reader to a remote location.
// Access copied bytes by providing a PassThru reader factory
func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error {
func (a *Client) CopyPassThru(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy to remote: %v", err)
Expand Down Expand Up @@ -272,7 +314,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s
// CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used
// to keep track of progress and how many bytes that were download from the remote.
// `passThru` can be set to nil to disable this behaviour.
func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error {
func (a *Client) CopyFromRemotePassThru(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
Expand Down Expand Up @@ -328,8 +375,50 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
errCh <- errors.New(res.GetMessage())
return
}
if res.NoStandardProtocolType() {
errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage()))
return
}

fileInfo := NewFileInfos()

if res.IsTime() {
timeInfo, err := res.ParseFileTime()

if err != nil {
errCh <- err
return
}

fileInfo.Update(timeInfo)

res, err = ParseResponse(r)
datadius marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
errCh <- err
return
}

if res.IsFailure() {
errCh <- errors.New(res.GetMessage())
return
}

if res.NoStandardProtocolType() {
errCh <- errors.New(fmt.Sprintf("Input from server doesn't follow protocol: %s", res.GetMessage()))
return
datadius marked this conversation as resolved.
Show resolved Hide resolved
}
}

// The CHMOD message always comes before the actual data is being sent
if !res.IsChmod() {
errCh <- errors.New(fmt.Sprintf("The data did not contain the expected CHMOD information: %s", res.GetMessage()))
return
}

infos, err := res.ParseFileInfos()

fileInfo.Update(infos)

if err != nil {
errCh <- err
return
Expand Down
83 changes: 79 additions & 4 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ const (
Error ResponseType = 2
)

type ProtocolType = rune

const (
Chmod ProtocolType = 'C'
datadius marked this conversation as resolved.
Show resolved Hide resolved
Time ProtocolType = 'T'
)

// Response represent a response from the SCP command.
// There are tree types of responses that the remote can send back:
// ok, warning and error
Expand All @@ -35,8 +42,9 @@ const (
// The remote sends a confirmation after every SCP command, because a failure can occur after every
// command, the response should be read and checked after sending them.
type Response struct {
Type ResponseType
Message string
Type ResponseType
Message string
ProtocolType rune
datadius marked this conversation as resolved.
Show resolved Hide resolved
}

// ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure.
Expand All @@ -48,6 +56,7 @@ func ParseResponse(reader io.Reader) (Response, error) {
}

responseType := buffer[0]
runeResponseType := rune(buffer[0])
message := ""
if responseType > 0 {
bufferedReader := bufio.NewReader(reader)
Expand All @@ -57,7 +66,11 @@ func ParseResponse(reader io.Reader) (Response, error) {
}
}

return Response{responseType, message}, nil
if len(message) > 0 {
return Response{responseType, message, runeResponseType}, nil
}

return Response{responseType, message, ' '}, nil
}

func (r *Response) IsOk() bool {
Expand All @@ -78,6 +91,18 @@ func (r *Response) IsFailure() bool {
return r.IsWarning() || r.IsError()
}

func (r *Response) IsChmod() bool {
return r.ProtocolType == Chmod
}

func (r *Response) IsTime() bool {
return r.ProtocolType == Time
}

func (r *Response) NoStandardProtocolType() bool {
return !(r.ProtocolType == Chmod || r.ProtocolType == Time)
}

// GetMessage returns the message the remote sent back.
func (r *Response) GetMessage() string {
return r.Message
Expand All @@ -88,13 +113,40 @@ type FileInfos struct {
Filename string
Permissions string
Size int64
Atime int64
Mtime int64
}

func NewFileInfos() *FileInfos {
return &FileInfos{}
}

func (fileInfos *FileInfos) Update(new *FileInfos) {
if new == nil {
return
}
if new.Filename != "" {
fileInfos.Filename = new.Filename
}
if new.Permissions != "" {
fileInfos.Permissions = new.Permissions
}
if new.Size != 0 {
fileInfos.Size = new.Size
}
if new.Atime != 0 {
fileInfos.Atime = new.Atime
}
if new.Mtime != 0 {
fileInfos.Mtime = new.Mtime
}
}

func (r *Response) ParseFileInfos() (*FileInfos, error) {
message := strings.ReplaceAll(r.Message, "\n", "")
parts := strings.Split(message, " ")
if len(parts) < 3 {
return nil, errors.New("unable to parse message as file infos")
return nil, errors.New("unable to parse Chmod protocol")
}

size, err := strconv.Atoi(parts[1])
Expand All @@ -110,6 +162,29 @@ func (r *Response) ParseFileInfos() (*FileInfos, error) {
}, nil
}

func (r *Response) ParseFileTime() (*FileInfos, error) {
message := strings.ReplaceAll(r.Message, "\n", "")
parts := strings.Split(message, " ")
if len(parts) < 3 {
return nil, errors.New("unable to parse Time protocol")
}

aTime, err := strconv.Atoi(string(parts[0][0:10]))
if err != nil {
return nil, errors.New("unable to parse ATime component of message")
}
mTime, err := strconv.Atoi(string(parts[2][0:10]))
if err != nil {
return nil, errors.New("unable to parse MTime component of message")
}

return &FileInfos{
Message: r.Message,
Atime: int64(aTime),
Mtime: int64(mTime),
}, nil
}

// Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is
// therefore required to check if the acknowledgement succeeded.
func Ack(writer io.Writer) error {
Expand Down