From cf29d9d4c20da83fd8a2120a70474b5b35fbf28d Mon Sep 17 00:00:00 2001 From: Panos Date: Sun, 20 Mar 2022 18:21:37 +0000 Subject: [PATCH] Scp send rc (#340) * Added SCP large file test. * Fix issue with scp_send - resolves #337 * Updated changelog * Updated embedded server, tests --- .environment.yml | 1 + Changelog.rst | 7 ++- pssh/clients/native/single.py | 6 ++- tests/embedded_server/openssh.py | 8 ++-- tests/native/test_parallel_client.py | 68 ++++++++++++++++++++++++++-- tests/native/test_single_client.py | 44 +++++++----------- 6 files changed, 96 insertions(+), 38 deletions(-) diff --git a/.environment.yml b/.environment.yml index b8fe6a7d..72a0f20f 100644 --- a/.environment.yml +++ b/.environment.yml @@ -5,3 +5,4 @@ dependencies: - setuptools - pip - toolchain3 + - cython diff --git a/Changelog.rst b/Changelog.rst index 1adf1f09..7f4ec9df 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -8,13 +8,18 @@ Changes -------- * ``pssh.exceptions.ConnectionError`` is now the same as built-in ``ConnectionError`` and deprecated - to be removed. -* Clients now continue connecting with all addresses in DNS list. In the case where an address refuses connection, +* Clients now attempt to connect with all addresses in DNS list. In the case where an address refuses connection, other available addresses are attempted without delay. For example where a host resolves to both IPv4 and v6 addresses while only one address is accepting connections, or multiple v4/v6 addresses where only some are accepting connections. * Connection actively refused error is no longer subject to retries. +Fixes +----- + +* ``scp_send`` in native clients would sometimes fail to send all data in a race condition with client going out of scope. + 2.8.0 +++++ diff --git a/pssh/clients/native/single.py b/pssh/clients/native/single.py index f17de711..ab8ee7af 100644 --- a/pssh/clients/native/single.py +++ b/pssh/clients/native/single.py @@ -598,6 +598,7 @@ def _scp_recv(self, remote_file, local_file): total += size local_fh.write(data) finally: + local_fh.flush() local_fh.close() file_chan.close() @@ -659,13 +660,16 @@ def _scp_send(self, local_file, remote_file): raise SCPError(msg, remote_file, self.host, ex) try: with open(local_file, 'rb', 2097152) as local_fh: - for data in local_fh: + data = local_fh.read(self._BUF_SIZE) + while data: self.eagain_write(chan.write, data) + data = local_fh.read(self._BUF_SIZE) except Exception as ex: msg = "Error writing to remote file %s on host %s - %s" logger.error(msg, remote_file, self.host, ex) raise SCPError(msg, remote_file, self.host, ex) finally: + self._eagain(chan.flush) chan.close() def _sftp_openfh(self, open_func, remote_file, *args): diff --git a/tests/embedded_server/openssh.py b/tests/embedded_server/openssh.py index c8e35f48..67cf0a10 100644 --- a/tests/embedded_server/openssh.py +++ b/tests/embedded_server/openssh.py @@ -86,12 +86,14 @@ def start_server(self): logger.debug("Starting server with %s" % (" ".join(cmd),)) self.server_proc = Popen(cmd) try: - self.server_proc.wait(.1) + self.server_proc.wait(.3) except TimeoutExpired: pass else: - logger.error(self.server_proc.stdout.read()) - logger.error(self.server_proc.stderr.read()) + if self.server_proc.stdout is not None: + logger.error(self.server_proc.stdout.read()) + if self.server_proc.stderr is not None: + logger.error(self.server_proc.stderr.read()) raise Exception("Server could not start") def stop(self): diff --git a/tests/native/test_parallel_client.py b/tests/native/test_parallel_client.py index 9f403224..82a70d96 100644 --- a/tests/native/test_parallel_client.py +++ b/tests/native/test_parallel_client.py @@ -133,6 +133,8 @@ def test_client_shells_read_timeout(self): def test_client_shells_timeout(self): client = ParallelSSHClient([self.host], pkey=self.user_key, port=self.port, timeout=0.01, num_retries=1) + client._make_ssh_client = MagicMock() + client._make_ssh_client.side_effect = Timeout self.assertRaises(Timeout, client.open_shell) def test_client_shells_join_timeout(self): @@ -1517,8 +1519,8 @@ def test_scp_send_dir_recurse(self): except OSError: pass - def test_scp_send_large_files_timeout(self): - hosts = ['127.0.0.1%s' % (i,) for i in range(1, 10)] + def test_scp_send_larger_files(self): + hosts = ['127.0.0.1%s' % (i,) for i in range(1, 3)] servers = [OpenSSHServer(host, port=self.port) for host in hosts] for server in servers: server.start_server() @@ -1535,7 +1537,7 @@ def test_scp_send_large_files_timeout(self): remote_file_names = [arg['remote_file'] for arg in copy_args] sha = sha256() with open(local_filename, 'wb') as file_h: - for _ in range(5000): + for _ in range(10000): data = os.urandom(1024) file_h.write(data) sha.update(data) @@ -1547,13 +1549,15 @@ def test_scp_send_large_files_timeout(self): except Exception: raise else: - sleep(.2) + del client for remote_file_name in remote_file_names: remote_file_abspath = os.path.expanduser('~/' + remote_file_name) self.assertTrue(os.path.isfile(remote_file_abspath)) with open(remote_file_abspath, 'rb') as remote_fh: - for data in remote_fh: + data = remote_fh.read(10240) + while data: sha.update(data) + data = remote_fh.read(10240) remote_file_sha = sha.hexdigest() sha = sha256() self.assertEqual(source_file_sha, remote_file_sha) @@ -1679,6 +1683,60 @@ def test_scp_recv(self): except Exception: pass + def test_scp_recv_larger_files(self): + hosts = ['127.0.0.1%s' % (i,) for i in range(1, 3)] + servers = [OpenSSHServer(host, port=self.port) for host in hosts] + for server in servers: + server.start_server() + client = ParallelSSHClient( + hosts, port=self.port, pkey=self.user_key, num_retries=1, timeout=1, + pool_size=len(hosts), + ) + dir_name = os.path.dirname(__file__) + remote_filename = 'test_file' + remote_filepath = os.path.join(dir_name, remote_filename) + local_filename = 'file_copy' + copy_args = [{ + 'remote_file': remote_filepath, + 'local_file': os.path.expanduser("~/" + 'host_%s_%s' % (n, local_filename))} + for n in range(len(hosts)) + ] + local_file_names = [ + arg['local_file'] for arg in copy_args] + sha = sha256() + with open(remote_filepath, 'wb') as file_h: + for _ in range(10000): + data = os.urandom(1024) + file_h.write(data) + sha.update(data) + file_h.flush() + source_file_sha = sha.hexdigest() + sha = sha256() + cmds = client.scp_recv('%(remote_file)s', '%(local_file)s', copy_args=copy_args) + try: + joinall(cmds, raise_error=True) + except Exception: + raise + else: + del client + for _local_file_name in local_file_names: + self.assertTrue(os.path.isfile(_local_file_name)) + with open(_local_file_name, 'rb') as fh: + data = fh.read(10240) + while data: + sha.update(data) + data = fh.read(10240) + local_file_sha = sha.hexdigest() + sha = sha256() + self.assertEqual(source_file_sha, local_file_sha) + finally: + try: + os.unlink(remote_filepath) + for _local_file_name in local_file_names: + os.unlink(_local_file_name) + except OSError: + pass + def test_bad_hosts_value(self): self.assertRaises(TypeError, ParallelSSHClient, 'a host') self.assertRaises(TypeError, ParallelSSHClient, b'a host') diff --git a/tests/native/test_single_client.py b/tests/native/test_single_client.py index 5aa3bd6f..33594b74 100644 --- a/tests/native/test_single_client.py +++ b/tests/native/test_single_client.py @@ -405,10 +405,10 @@ def test_auth_retry_failure(self): def test_connection_timeout(self): cmd = spawn(SSHClient, 'fakehost.com', port=12345, - num_retries=1, timeout=1, _auth_thread_pool=False) + num_retries=1, timeout=.1, _auth_thread_pool=False) # Should fail within greenlet timeout, otherwise greenlet will # raise timeout which will fail the test - self.assertRaises(ConnectionErrorException, cmd.get, timeout=2) + self.assertRaises(ConnectionErrorException, cmd.get, timeout=1) def test_client_read_timeout(self): client = SSHClient(self.host, port=self.port, @@ -657,27 +657,22 @@ def test_scp_recv_large_file(self): os.unlink(_path) except OSError: pass + sha = sha256() try: with open(file_path_from, 'wb') as fh: - # ~300MB - for _ in range(20000000): - fh.write(b"adsfasldkfjabafj") + for _ in range(10000): + data = os.urandom(1024) + fh.write(data) + sha.update(data) + source_file_sha = sha.hexdigest() self.client.scp_recv(file_path_from, file_copy_to_dirpath) self.assertTrue(os.path.isfile(file_copy_to_dirpath)) - read_file_size = os.stat(file_path_from).st_size - written_file_size = os.stat(file_copy_to_dirpath).st_size - self.assertEqual(read_file_size, written_file_size) - sha = sha256() - with open(file_path_from, 'rb') as fh: - for block in fh: - sha.update(block) - read_file_hash = sha.hexdigest() sha = sha256() with open(file_copy_to_dirpath, 'rb') as fh: for block in fh: sha.update(block) written_file_hash = sha.hexdigest() - self.assertEqual(read_file_hash, written_file_hash) + self.assertEqual(source_file_sha, written_file_hash) finally: for _path in (file_path_from, file_copy_to_dirpath): try: @@ -728,29 +723,22 @@ def test_scp_send_large_file(self): os.unlink(_path) except OSError: pass + sha = sha256() try: with open(file_path_from, 'wb') as fh: - # ~300MB - for _ in range(20000000): - fh.write(b"adsfasldkfjabafj") + for _ in range(10000): + data = os.urandom(1024) + fh.write(data) + sha.update(data) + source_file_sha = sha.hexdigest() self.client.scp_send(file_path_from, file_copy_to_dirpath) self.assertTrue(os.path.isfile(file_copy_to_dirpath)) - # OS file flush race condition - sleep(.1) - read_file_size = os.stat(file_path_from).st_size - written_file_size = os.stat(file_copy_to_dirpath).st_size - self.assertEqual(read_file_size, written_file_size) - sha = sha256() - with open(file_path_from, 'rb') as fh: - for block in fh: - sha.update(block) - read_file_hash = sha.hexdigest() sha = sha256() with open(file_copy_to_dirpath, 'rb') as fh: for block in fh: sha.update(block) written_file_hash = sha.hexdigest() - self.assertEqual(read_file_hash, written_file_hash) + self.assertEqual(source_file_sha, written_file_hash) finally: for _path in (file_path_from, file_copy_to_dirpath): try: