diff --git a/Changelog.rst b/Changelog.rst index a76f97aa..42af672f 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -1,6 +1,14 @@ Change Log ============ +1.5.2 +++++++ + +Changes +-------- + +* Output generators automatically restarted on call to ``join`` so output can resume on any timeouts. + 1.5.1 ++++++ diff --git a/doc/advanced.rst b/doc/advanced.rst index b4a8b336..e4ce5ef5 100644 --- a/doc/advanced.rst +++ b/doc/advanced.rst @@ -186,8 +186,8 @@ In some cases, such as when the remote command never terminates unless interrupt client.join(output, timeout=1) # Closing channel which has PTY has the effect of terminating # any running processes started on that channel. - for host in client.hosts: - client.host_clients[host].close_channel(output[host].channel) + for host, host_out in output: + client.host_clients[host].close_channel(host_out.channel) client.join(output) Without a PTY, the ``join`` will complete but the remote process will be left running as per SSH protocol specifications. @@ -196,16 +196,14 @@ Furthermore, once reading output has timed out, it is necessary to restart the o .. code-block:: python - output = client.run_command(.., timeout=1) + output = client.run_command(<..>, timeout=1) for host, host_out in output.items(): try: stdout = list(host_out.stdout) except Timeout: - stdout_buf = client.host_clients[host].read_output_buffer( - client.host_clients[host].read_output( - output[host].channel, timeout=1)) - # Reset generator to be able to gather new output - host_out.stdout = stdout_buf + client.reset_output_generators(host_out) + +Generator reset shown above is also performed automatically by calls to ``join`` and does not need to be done manually ``join`` is used after output reading. .. note:: diff --git a/pssh/pssh2_client.py b/pssh/pssh2_client.py index bbebbbff..aebe4031 100644 --- a/pssh/pssh2_client.py +++ b/pssh/pssh2_client.py @@ -150,7 +150,7 @@ def run_command(self, command, sudo=False, user=None, stop_on_errors=True, raised otherwise :type host_args: tuple or list :param encoding: Encoding to use for output. Must be valid - `Python codec `_ + `Python codec `_ :type encoding: str :param timeout: (Optional) Timeout in seconds for reading from stdout or stderr. Defaults to no timeout. Reading from stdout/stderr will @@ -218,14 +218,14 @@ def join(self, output, consume_output=False, timeout=None): continue client = self.host_clients[host] channel = host_out.channel + stdout, stderr = self.reset_output_generators( + host_out, client=client, channel=channel, timeout=timeout) try: client.wait_finished(channel, timeout=timeout) except Timeout: raise Timeout( "Timeout of %s sec(s) reached on host %s with command " "still running", timeout, host) - stdout = host_out.stdout - stderr = host_out.stderr if timeout: # Must consume buffers prior to EOF check self._consume_output(stdout, stderr) @@ -237,6 +237,36 @@ def join(self, output, consume_output=False, timeout=None): self._consume_output(stdout, stderr) self.get_exit_codes(output) + def reset_output_generators(self, host_out, timeout=None, + client=None, channel=None, + encoding='utf-8'): + """Reset output generators for host output. + + :param host_out: Host output + :type host_out: :py:class:`pssh.output.HostOutput` + :param client: (Optional) SSH client + :type client: :py:class:`pssh.ssh2_client.SSHClient` + :param channel: (Optional) SSH channel + :type channel: :py:class:`ssh2.channel.Channel` + :param timeout: (Optional) Timeout setting + :type timeout: int + :param encoding: (Optional) Encoding to use for output. Must be valid + `Python codec `_ + :type encoding: str + + :rtype: tuple(stdout, stderr) + """ + channel = host_out.channel if channel is None else channel + client = self.host_clients[host_out.host] if client is None else client + stdout = client.read_output_buffer( + client.read_output(channel, timeout=timeout), encoding=encoding) + stderr = client.read_output_buffer( + client.read_stderr(channel, timeout=timeout), + prefix='\t[err]', encoding=encoding) + host_out.stdout = stdout + host_out.stderr = stderr + return stdout, stderr + def _consume_output(self, stdout, stderr): for line in stdout: pass diff --git a/tests/test_pssh_ssh2_client.py b/tests/test_pssh_ssh2_client.py index bb0cfd6e..45a94937 100644 --- a/tests/test_pssh_ssh2_client.py +++ b/tests/test_pssh_ssh2_client.py @@ -1193,18 +1193,15 @@ def test_join_timeout_set_no_timeout(self): def test_read_timeout(self): client = ParallelSSHClient([self.host], port=self.port, pkey=self.user_key) - output = client.run_command('sleep 2; echo me', timeout=1) + output = client.run_command('sleep 2; echo me; echo me; echo me', timeout=1) for host, host_out in output.items(): self.assertRaises(Timeout, list, host_out.stdout) self.assertFalse(output[self.host].channel.eof()) client.join(output) for host, host_out in output.items(): - stdout_buf = client.host_clients[self.host].read_output_buffer( - client.host_clients[self.host].read_output( - output[self.host].channel, timeout=1)) - host_out.stdout = stdout_buf stdout = list(output[self.host].stdout) - self.assertEqual(len(stdout), 1) + self.assertEqual(len(stdout), 3) + self.assertTrue(output[self.host].channel.eof()) def test_timeout_file_read(self): dir_name = os.path.dirname(__file__) @@ -1225,6 +1222,7 @@ def test_timeout_file_read(self): pass else: raise Exception("Timeout should have been raised") + self.assertRaises(Timeout, self.client.join, output, timeout=1) channel = output[self.host].channel self.client.host_clients[self.host].close_channel(channel) self.client.join(output)