Skip to content

Commit

Permalink
Pkey from memory (#329)
Browse files Browse the repository at this point in the history
* Added private key auth from memory implementation for native client
* Updated gitignore
* Updated parallel clients to use in-memory pkey data
* Added tests
* Bump requirements
* Updated changelog
* Updated documentation
  • Loading branch information
pkittenis authored Nov 28, 2021
1 parent 3215d05 commit 9c9b678
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ doc/_build

tests/unit_test_cert_key-cert.pub
tests/embedded_server/principals
tests/embedded_server/sshd_config_*
tests/embedded_server/*.pid
13 changes: 13 additions & 0 deletions Changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
Change Log
============

2.8.0
+++++

Changes
--------

* All clients now support private key data as bytes in ``pkey`` parameter for authentication from in-memory private key
data - #317. See `documentation <https://parallel-ssh.readthedocs.io/en/latest/advanced.html#in-memory-private-keys>`_
for examples.
* Parallel clients now read a provided private key path only once and use in-memory data for authentication to avoid
reading same file multiple times, if a path is provided.


2.7.1
+++++

Expand Down
18 changes: 18 additions & 0 deletions doc/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ A private key can also be provided programmatically.
Where ``my_key`` is a private key file under `.ssh` in the user's home directory.


In-Memory Private Keys
========================

Private key data can also be provided as bytes for authentication from in-memory private keys.

.. code-block:: python
from pssh.clients import ParallelSSHClient
pkey_data = b"""-----BEGIN RSA PRIVATE KEY-----
<key data>
-----END RSA PRIVATE KEY-----
"""
client = ParallelSSHClient(hosts, pkey=pkey_data)
Private key data provided this way *must* be in bytes. This is supported by all parallel and single host clients.


Native Clients
***************

Expand Down
25 changes: 19 additions & 6 deletions pssh/clients/base/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ssh2.exceptions import AgentConnectionError, AgentListIdentitiesError, \
AgentAuthenticationError, AgentGetIdentityError

from ..common import _validate_pkey_path
from ..common import _validate_pkey
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
from ..reader import ConcurrentRWBuffer
from ...exceptions import UnknownHostError, AuthenticationError, \
Expand Down Expand Up @@ -182,12 +182,15 @@ def __init__(self, host,
self.session = None
self._host = proxy_host if proxy_host else host
self._port = proxy_port if proxy_port else self.port
self.pkey = _validate_pkey_path(pkey, self.host)
self.pkey = _validate_pkey(pkey)
self.identity_auth = identity_auth
self._keepalive_greenlet = None
self.ipv6_only = ipv6_only
self._init()

def _pkey_from_memory(self, pkey_data):
raise NotImplementedError

def _init(self):
self._connect(self._host, self._port)
self._init_session()
Expand Down Expand Up @@ -309,7 +312,7 @@ def _identity_auth(self):
"Trying to authenticate with identity file %s",
identity_file)
try:
self._pkey_auth(identity_file, password=self.password)
self._pkey_file_auth(identity_file, password=self.password)
except Exception as ex:
logger.debug(
"Authentication with identity file %s failed with %s, "
Expand All @@ -331,8 +334,8 @@ def _keepalive(self):
def auth(self):
if self.pkey is not None:
logger.debug(
"Proceeding with private key file authentication")
return self._pkey_auth(self.pkey, password=self.password)
"Proceeding with private key authentication")
return self._pkey_auth(self.pkey)
if self.allow_agent:
try:
self._agent_auth()
Expand Down Expand Up @@ -364,7 +367,17 @@ def _agent_auth(self):
def _password_auth(self):
raise NotImplementedError

def _pkey_auth(self, pkey_file, password=None):
def _pkey_auth(self, pkey):
_pkey = pkey
if isinstance(pkey, str):
logger.debug("Private key is provided as str, loading from private key file path")
with open(pkey, 'rb') as fh:
_pkey = fh.read()
elif isinstance(pkey, bytes):
logger.debug("Private key is provided in bytes, using as private key data")
return self._pkey_from_memory(_pkey)

def _pkey_file_auth(self, pkey_file, password=None):
raise NotImplementedError

def _open_session(self):
Expand Down
10 changes: 9 additions & 1 deletion pssh/clients/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..exceptions import PKeyFileError


def _validate_pkey_path(pkey, host=None):
def _validate_pkey_path(pkey):
if pkey is None:
return
pkey = os.path.normpath(os.path.expanduser(pkey))
Expand All @@ -31,3 +31,11 @@ def _validate_pkey_path(pkey, host=None):
ex = PKeyFileError(msg, pkey)
raise ex
return pkey


def _validate_pkey(pkey):
if pkey is None:
return
if isinstance(pkey, str):
return _validate_pkey_path(pkey)
return pkey
19 changes: 13 additions & 6 deletions pssh/clients/native/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging

from .single import SSHClient
from ..common import _validate_pkey_path
from ..common import _validate_pkey
from ..base.parallel import BaseParallelSSHClient
from ...constants import DEFAULT_RETRIES, RETRY_DELAY
from ...exceptions import HostArgumentError
Expand Down Expand Up @@ -50,9 +50,11 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None,
:param port: (Optional) Port number to use for SSH connection. Defaults
to 22.
:type port: int
:param pkey: Private key file path to use. Path must be either absolute
:param pkey: Private key file path or private key data to use.
Paths must be str type and either absolute
path or relative to user home directory like ``~/<path>``.
:type pkey: str
Bytes type input is used as private key data for authentication.
:type pkey: str or bytes
:param num_retries: (Optional) Number of connection and authentication
attempts before the client gives up. Defaults to 3.
:type num_retries: int
Expand Down Expand Up @@ -127,10 +129,10 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None,
identity_auth=identity_auth,
ipv6_only=ipv6_only,
)
self.pkey = _validate_pkey_path(pkey)
self.pkey = _validate_pkey(pkey)
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_pkey = _validate_pkey_path(proxy_pkey)
self.proxy_pkey = _validate_pkey(proxy_pkey)
self.proxy_user = proxy_user
self.proxy_password = proxy_password
self.forward_ssh_agent = forward_ssh_agent
Expand Down Expand Up @@ -235,9 +237,14 @@ def _make_ssh_client(self, host_i, host):
or self._host_clients[(host_i, host)] is None:
_user, _port, _password, _pkey, proxy_host, proxy_port, proxy_user, \
proxy_password, proxy_pkey = self._get_host_config_values(host_i, host)
if isinstance(self.pkey, str):
with open(_pkey, 'rb') as fh:
_pkey_data = fh.read()
else:
_pkey_data = _pkey
_client = SSHClient(
host, user=_user, password=_password, port=_port,
pkey=_pkey, num_retries=self.num_retries,
pkey=_pkey_data, num_retries=self.num_retries,
timeout=self.timeout,
allow_agent=self.allow_agent, retry_delay=self.retry_delay,
proxy_host=proxy_host,
Expand Down
12 changes: 10 additions & 2 deletions pssh/clients/native/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def __init__(self, host,
:param pkey: Private key file path to use for authentication. Path must
be either absolute path or relative to user home directory
like ``~/<path>``.
:type pkey: str
Bytes type input is used as private key data for authentication.
:type pkey: str or bytes
:param num_retries: (Optional) Number of connection and authentication
attempts before the client gives up. Defaults to 3.
:type num_retries: int
Expand Down Expand Up @@ -239,12 +240,19 @@ def _keepalive(self):
def _agent_auth(self):
self.session.agent_auth(self.user)

def _pkey_auth(self, pkey_file, password=None):
def _pkey_file_auth(self, pkey_file, password=None):
self.session.userauth_publickey_fromfile(
self.user,
pkey_file,
passphrase=password if password is not None else b'')

def _pkey_from_memory(self, pkey_data):
self.session.userauth_publickey_frommemory(
self.user,
pkey_data,
passphrase=self.password if self.password is not None else b'',
)

def _password_auth(self):
self.session.userauth_password(self.user, self.password)

Expand Down
14 changes: 10 additions & 4 deletions pssh/clients/ssh/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging

from .single import SSHClient
from ..common import _validate_pkey_path
from ..common import _validate_pkey_path, _validate_pkey
from ..base.parallel import BaseParallelSSHClient
from ...constants import DEFAULT_RETRIES, RETRY_DELAY

Expand Down Expand Up @@ -54,7 +54,8 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None,
:type port: int
:param pkey: Private key file path to use. Path must be either absolute
path or relative to user home directory like ``~/<path>``.
:type pkey: str
Bytes type input is used as private key data for authentication.
:type pkey: str or bytes
:param cert_file: Public key signed certificate file to use for
authentication. The corresponding private key must also be provided
via ``pkey`` parameter.
Expand Down Expand Up @@ -141,7 +142,7 @@ def __init__(self, hosts, user=None, password=None, port=22, pkey=None,
identity_auth=identity_auth,
ipv6_only=ipv6_only,
)
self.pkey = _validate_pkey_path(pkey)
self.pkey = _validate_pkey(pkey)
self.cert_file = _validate_pkey_path(cert_file)
self.forward_ssh_agent = forward_ssh_agent
self.gssapi_auth = gssapi_auth
Expand Down Expand Up @@ -235,9 +236,14 @@ def _make_ssh_client(self, host_i, host):
or self._host_clients[(host_i, host)] is None:
_user, _port, _password, _pkey, _, _, _, _, _ = \
self._get_host_config_values(host_i, host)
if isinstance(self.pkey, str):
with open(_pkey, 'rb') as fh:
_pkey_data = fh.read()
else:
_pkey_data = _pkey
_client = SSHClient(
host, user=_user, password=_password, port=_port,
pkey=_pkey,
pkey=_pkey_data,
cert_file=self.cert_file,
num_retries=self.num_retries,
timeout=self.timeout,
Expand Down
19 changes: 15 additions & 4 deletions pssh/clients/ssh/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from gevent import sleep, spawn, Timeout as GTimeout, joinall
from ssh import options
from ssh.session import Session, SSH_READ_PENDING, SSH_WRITE_PENDING
from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey
from ssh.key import import_privkey_file, import_cert_file, copy_cert_to_privkey,\
import_privkey_base64
from ssh.exceptions import EOF
from ssh.error_codes import SSH_AGAIN

Expand Down Expand Up @@ -62,7 +63,8 @@ def __init__(self, host,
:param pkey: Private key file path to use for authentication. Path must
be either absolute path or relative to user home directory
like ``~/<path>``.
:type pkey: str
Bytes type input is used as private key data for authentication.
:type pkey: str or bytes
:param cert_file: Public key signed certificate file to use for
authentication. The corresponding private key must also be provided
via ``pkey`` parameter.
Expand Down Expand Up @@ -106,7 +108,7 @@ def __init__(self, host,
:raises: :py:class:`pssh.exceptions.PKeyFileError` on errors finding
provided private key.
"""
self.cert_file = _validate_pkey_path(cert_file, host)
self.cert_file = _validate_pkey_path(cert_file)
self.gssapi_auth = gssapi_auth
self.gssapi_server_identity = gssapi_server_identity
self.gssapi_client_identity = gssapi_client_identity
Expand Down Expand Up @@ -175,13 +177,22 @@ def auth(self):
def _password_auth(self):
self.session.userauth_password(self.user, self.password)

def _pkey_auth(self, pkey_file, password=None):
def _pkey_file_auth(self, pkey_file, password=None):
pkey = import_privkey_file(pkey_file, passphrase=password if password is not None else '')
return self._pkey_obj_auth(pkey)

def _pkey_obj_auth(self, pkey):
if self.cert_file is not None:
logger.debug("Certificate file set - trying certificate authentication")
self._import_cert_file(pkey)
self.session.userauth_publickey(pkey)

def _pkey_from_memory(self, pkey_data):
_pkey = import_privkey_base64(
pkey_data,
passphrase=self.password if self.password is not None else b'')
return self._pkey_obj_auth(_pkey)

def _import_cert_file(self, pkey):
cert_key = import_cert_file(self.cert_file)
self.session.userauth_try_publickey(cert_key)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
gevent>=1.3.0
ssh2-python>=0.22.0
ssh2-python>=0.27.0
ssh-python>=0.9.0
6 changes: 6 additions & 0 deletions tests/native/test_parallel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def test_connect_auth(self):
client = ParallelSSHClient([self.host], pkey=self.user_key, port=self.port, num_retries=1)
joinall(client.connect_auth(), raise_error=True)

def test_pkey_from_memory(self):
with open(self.user_key, 'rb') as fh:
key = fh.read()
client = ParallelSSHClient([self.host], pkey=key, port=self.port, num_retries=1)
joinall(client.connect_auth(), raise_error=True)

def test_client_shells(self):
shells = self.client.open_shell()
self.client.run_shell_commands(shells, self.cmd)
Expand Down
6 changes: 6 additions & 0 deletions tests/native/test_single_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def test_scp_fail(self):
finally:
os.rmdir('adir')

def test_pkey_from_memory(self):
with open(self.user_key, 'rb') as fh:
key_data = fh.read()
SSHClient(self.host, port=self.port,
pkey=key_data, num_retries=1, timeout=1)

def test_execute(self):
host_out = self.client.run_command(self.cmd)
output = list(host_out.stdout)
Expand Down
6 changes: 6 additions & 0 deletions tests/ssh/test_parallel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def _session(timeout=1):
client._host_clients[(0, self.host)].open_session = _session
self.assertRaises(Timeout, client.run_command, self.cmd)

def test_pkey_from_memory(self):
with open(self.user_key, 'rb') as fh:
key = fh.read()
client = ParallelSSHClient([self.host], pkey=key, port=self.port, num_retries=1)
joinall(client.connect_auth(), raise_error=True)

def test_join_timeout(self):
client = ParallelSSHClient([self.host], port=self.port,
pkey=self.user_key)
Expand Down
6 changes: 6 additions & 0 deletions tests/ssh/test_single_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def _session(timeout=2):
client.open_session = _session
self.assertRaises(GTimeout, client.run_command, self.cmd)

def test_pkey_from_memory(self):
with open(self.user_key, 'rb') as fh:
key_data = fh.read()
SSHClient(self.host, port=self.port,
pkey=key_data, num_retries=1, timeout=1)

def test_execute(self):
host_out = self.client.run_command(self.cmd)
output = list(host_out.stdout)
Expand Down

0 comments on commit 9c9b678

Please sign in to comment.