Skip to content

Commit

Permalink
add preserve functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
datadius committed Apr 29, 2024
1 parent db7cf4f commit c3d1d4d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
33 changes: 29 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type Client struct {
// Handler called when calling `Close` to clean up any remaining
// resources managed by `Client`.
closeHandler ICloseHandler

// Preserve the remote file permissions, modification time and access time
preserve bool
}

// Connect connects to the remote SSH server, returns error if it couldn't establish a session to the SSH server.
Expand Down Expand Up @@ -315,14 +318,28 @@ func (a *Client) CopyFromRemotePassThru(
remotePath string,
passThru PassThru,
) error {
_, err := a.CopyFromRemoteFileInfos(ctx, w, remotePath, passThru)

return err
}

// CopyFroRemoteFileInfos copies a file from the remote to a given writer and return a FileInfos struct
// containing information about the file such as permissions, the file size, modification time and access time
func (a *Client) CopyFromRemoteFileInfos(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) (*FileInfos, error) {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
return nil, fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
}
defer session.Close()

wg := sync.WaitGroup{}
errCh := make(chan error, 4)
fileInfosCh := make(chan *FileInfos, 1)

wg.Add(1)
go func() {
Expand All @@ -349,7 +366,11 @@ func (a *Client) CopyFromRemotePassThru(
}
defer in.Close()

err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
if a.preserve {
err = session.Start(fmt.Sprintf("%s -pf %q", a.RemoteBinary, remotePath))
} else {
err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
}
if err != nil {
errCh <- err
return
Expand All @@ -367,6 +388,8 @@ func (a *Client) CopyFromRemotePassThru(
return
}

fileInfosCh <- fileInfo

err = Ack(in)
if err != nil {
errCh <- err
Expand Down Expand Up @@ -403,11 +426,13 @@ func (a *Client) CopyFromRemotePassThru(
}

if err := wait(&wg, ctx); err != nil {
return err
return nil, err
}

finalErr := <-errCh
fileInfos := <-fileInfosCh
close(errCh)
return finalErr
return fileInfos, finalErr
}

func (a *Client) Close() {
Expand Down
12 changes: 11 additions & 1 deletion configurer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ClientConfigurer struct {
timeout time.Duration
remoteBinary string
sshClient *ssh.Client
preserve bool
}

// NewConfigurer creates a new client configurer.
Expand All @@ -36,6 +37,7 @@ func NewConfigurer(host string, config *ssh.ClientConfig) *ClientConfigurer {
clientConfig: config,
timeout: 0, // no timeout by default
remoteBinary: "scp",
preserve: false,
}
}

Expand Down Expand Up @@ -70,6 +72,13 @@ func (c *ClientConfigurer) SSHClient(sshClient *ssh.Client) *ClientConfigurer {
return c
}

// Preserve alters the preserve flag
// Defaults to false
func (c *ClientConfigurer) Preserve(preserve bool) *ClientConfigurer {
c.preserve = preserve
return c
}

// Create builds a client with the configuration stored within the ClientConfigurer.
func (c *ClientConfigurer) Create() Client {
return Client{
Expand All @@ -78,6 +87,7 @@ func (c *ClientConfigurer) Create() Client {
Timeout: c.timeout,
RemoteBinary: c.remoteBinary,
sshClient: c.sshClient,
closeHandler: EmptyHandler{},
preserve: c.preserve,
closeHandler: EmptyHandler{},
}
}
12 changes: 4 additions & 8 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,16 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
return nil, err
}

message, err = bufferedReader.ReadString('\n')
if err == io.EOF {
if bufferedReader.Buffered() == 0 {
err = Ack(writer)
if err != nil {
return fileInfos, err
}
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}
}

if err != nil && err != io.EOF {
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}

Expand Down

0 comments on commit c3d1d4d

Please sign in to comment.