Skip to content

Commit

Permalink
full implementation of ssh-agent authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
marian-code committed Nov 19, 2021
1 parent db69cdc commit 8190a28
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 72 deletions.
26 changes: 19 additions & 7 deletions ssh_utilities/abstract/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class ConnectionABC(ABC):

__name__: str
__abstractmethods__: FrozenSet[str]
password: Optional[str]
address: Optional[str]
username: str
pkey_file: Optional[Union[str, "Path"]]
allow_agent: Optional[bool]

@abstractmethod
def __str__(self):
Expand Down Expand Up @@ -135,11 +140,11 @@ def to_dict(self):
@staticmethod
def _to_dict(connection_name: str, host_name: str, address: Optional[str],
user_name: str, ssh_key: Optional[Union[Path, str]],
thread_safe: bool
thread_safe: bool, allow_agent: bool
) -> Dict[str, Optional[Union[str, bool, int]]]:

if ssh_key is None:
key_path = ssh_key
key_path = None
else:
key_path = str(Path(ssh_key).resolve())

Expand All @@ -149,12 +154,14 @@ def _to_dict(connection_name: str, host_name: str, address: Optional[str],
"user_name": user_name,
"ssh_key": key_path,
"address": address,
"thread_safe": thread_safe
"thread_safe": thread_safe,
"allow_agent": allow_agent,
}

def _to_str(self, connection_name: str, host_name: str,
address: Optional[str], user_name: str,
ssh_key: Optional[Union[Path, str]], thread_safe: bool) -> str:
ssh_key: Optional[Union[Path, str]], thread_safe: bool,
allow_agent: bool) -> str:
"""Aims to ease persistance, returns string representation of instance.
With this method all data needed to initialize class are saved to sting
Expand All @@ -177,6 +184,9 @@ def _to_str(self, connection_name: str, host_name: str,
make connection object thread safe so it can be safely accessed
from any number of threads, it is disabled by default to avoid
performance penalty of threading locks
allow_agent: bool
allows use of ssh agent for connection authentication, when this is
`True` key for the host does not have to be available.
Returns
-------
Expand All @@ -188,7 +198,8 @@ def _to_str(self, connection_name: str, host_name: str,
:class:`ssh_utilities.conncection.Connection`
"""
return dumps(self._to_dict(connection_name, host_name, address,
user_name, ssh_key, thread_safe))
user_name, ssh_key, thread_safe,
allow_agent))

def __del__(self):
self.close(quiet=True)
Expand All @@ -206,10 +217,11 @@ def __setstate__(self, state: dict):
self.__init__(state["address"], state["user_name"], # type: ignore
pkey_file=state["ssh_key"],
server_name=state["server_name"],
quiet=True, thread_safe=state["thread_safe"])
quiet=True, thread_safe=state["thread_safe"],
allow_agent=state["allow_agent"])

def __enter__(self: "CONN_TYPE") -> "CONN_TYPE":
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.close(quiet=True)
self.close(quiet=True)
120 changes: 78 additions & 42 deletions ssh_utilities/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

__all__ = ["Connection"]

logging.getLogger(__name__)
log = logging.getLogger(__name__)

# guard for when readthedocs is building documentation or travis
# is running CI build
Expand Down Expand Up @@ -121,21 +121,22 @@ class Connection(metaclass=_ConnectionMeta):

@overload
def __new__(cls, ssh_server: str, local: Literal[False], quiet: bool,
thread_safe: bool) -> SSHConnection:
thread_safe: bool, allow_agent: bool) -> SSHConnection:
...

@overload
def __new__(cls, ssh_server: str, local: Literal[True], quiet: bool,
thread_safe: bool) -> LocalConnection:
thread_safe: bool, allow_agent: bool) -> LocalConnection:
...

@overload
def __new__(cls, ssh_server: str, local: bool, quiet: bool,
thread_safe: bool) -> Union[SSHConnection, LocalConnection]:
thread_safe: bool, allow_agent: bool
) -> Union[SSHConnection, LocalConnection]:
...

def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
thread_safe: bool = False):
thread_safe: bool = False, allow_agent: bool = True):
"""Get Connection based on one of names defined in .ssh/config file.
If name of local PC is passed initilize LocalConnection
Expand All @@ -152,11 +153,14 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
make connection object thread safe so it can be safely accessed
from any number of threads, it is disabled by default to avoid
performance penalty of threading locks
allow_agent: bool
allows use of ssh agent for connection authentication, when this is
`True` key for the host does not have to be available.
Raises
------
KeyError
if server name is not in config file
if server name is not in config file and allow agent is false
Returns
-------
Expand All @@ -173,14 +177,36 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
raise KeyError(f"couldn't find login credentials for {ssh_server}:"
f" {e}")
else:
# get username and address
try:
return cls.open(credentials["user"], credentials["hostname"],
credentials["identityfile"][0],
server_name=ssh_server, quiet=quiet,
thread_safe=thread_safe)
user = credentials["user"]
hostname = credentials["hostname"]
except KeyError as e:
raise KeyError(f"{RED}missing key in config dictionary for "
f"{ssh_server}: {R}{e}")
raise KeyError(
"Cannot find username or hostname for specified host"
)

# get key or use agent
if allow_agent:
log.info(f"no private key supplied for {hostname}, will try "
f"to authenticate through ssh-agent")
pkey_file = None
else:
log.info(f"private key found for host: {hostname}")
try:
pkey_file = credentials["identityfile"][0]
except (KeyError, IndexError) as e:
raise KeyError(f"No private key found for specified host")

return cls.open(
user,
hostname,
ssh_key_file=pkey_file,
allow_agent=allow_agent,
server_name=ssh_server,
quiet=quiet,
thread_safe=thread_safe
)

@classmethod
def get_available_hosts(cls) -> List[str]:
Expand Down Expand Up @@ -212,7 +238,8 @@ def get(cls, *args, **kwargs):
get_connection = get

@classmethod
def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]):
def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]],
allow_agent: Union[bool, List[bool]]):
"""Add or override availbale host read fron ssh config file.
You can use supplied config parser to parse some externaf ssh config
Expand All @@ -223,15 +250,22 @@ def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]):
hosts : Union[_HOSTS, List[_HOSTS]]
dictionary or a list of dictionaries containing keys: `user`,
`hostname` and `identityfile`
allow_agent: Union[bool, List[bool]]
bool or a list of bools with corresponding length to list of hosts.
if only one bool is passed in, it will be used for all host entries
See also
--------
:func:ssh_utilities.config_parser
"""
if not isinstance(hosts, list):
hosts = [hosts]
if not isinstance(allow_agent, list):
allow_agent = [allow_agent] * len(hosts)

for h in hosts:
for h, a in zip(hosts, allow_agent):
if a:
h["identityfile"][0] = None
if not isinstance(h["identityfile"], list):
h["identityfile"] = [h["identityfile"]]
h["identityfile"][0] = os.path.abspath(
Expand Down Expand Up @@ -300,7 +334,7 @@ def open(ssh_username: str, ssh_server: None = None,
ssh_password: Optional[str] = None,
server_name: Optional[str] = None, quiet: bool = False,
thread_safe: bool = False,
ssh_allow_agent: bool = False) -> LocalConnection:
allow_agent: bool = False) -> LocalConnection:
...

@overload
Expand All @@ -310,7 +344,7 @@ def open(ssh_username: str, ssh_server: str,
ssh_password: Optional[str] = None,
server_name: Optional[str] = None, quiet: bool = False,
thread_safe: bool = False,
ssh_allow_agent: bool = False) -> SSHConnection:
allow_agent: bool = False) -> SSHConnection:
...

@staticmethod
Expand All @@ -319,7 +353,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
ssh_password: Optional[str] = None,
server_name: Optional[str] = None, quiet: bool = False,
thread_safe: bool = False,
ssh_allow_agent: bool = False):
allow_agent: bool = False):
"""Initialize SSH or local connection.
Local connection is only a wrapper around os and shutil module methods
Expand All @@ -346,7 +380,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
make connection object thread safe so it can be safely accessed
from any number of threads, it is disabled by default to avoid
performance penalty of threading locks
ssh_allow_agent: bool
allow_agent: bool
allow the use of the ssh-agent to connect. Will disable ssh_key_file.
Warnings
Expand All @@ -355,27 +389,29 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
risk!
"""
if not ssh_server:
return LocalConnection(ssh_server, ssh_username,
pkey_file=ssh_key_file,
server_name=server_name, quiet=quiet)
else:
if ssh_allow_agent:
c = SSHConnection(ssh_server, ssh_username,
allow_agent=ssh_allow_agent, line_rewrite=True,
server_name=server_name, quiet=quiet,
thread_safe=thread_safe)
elif ssh_key_file:
c = SSHConnection(ssh_server, ssh_username,
pkey_file=ssh_key_file, line_rewrite=True,
server_name=server_name, quiet=quiet,
thread_safe=thread_safe)
else:
if not ssh_password:
ssh_password = getpass.getpass(prompt="Enter password: ")

c = SSHConnection(ssh_server, ssh_username,
password=ssh_password, line_rewrite=True,
server_name=server_name, quiet=quiet,
thread_safe=thread_safe)

return c
return LocalConnection(
ssh_server,
ssh_username,
pkey_file=ssh_key_file,
server_name=server_name,
quiet=quiet
)
elif allow_agent:
ssh_key_file = None
ssh_password = None
elif ssh_key_file:
ssh_password = None
elif not ssh_password:
ssh_password = getpass.getpass(prompt="Enter password: ")

return SSHConnection(
ssh_server,
ssh_username,
allow_agent=allow_agent,
pkey_file=ssh_key_file,
password=ssh_password,
line_rewrite=True,
server_name=server_name,
quiet=quiet,
thread_safe=thread_safe
)
8 changes: 5 additions & 3 deletions ssh_utilities/local/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(self, address: Optional[str], username: str,
password: Optional[str] = None,
pkey_file: Optional[Union[str, "Path"]] = None,
line_rewrite: bool = True, server_name: Optional[str] = None,
quiet: bool = False, thread_safe: bool = False) -> None:
quiet: bool = False, thread_safe: bool = False,
allow_agent: Optional[bool] = False) -> None:

# set login credentials
self.password = password
self.address = address
self.username = username
self.pkey_file = pkey_file
self.allow_agent = allow_agent

self.server_name = server_name if server_name else gethostname()
self.server_name = self.server_name.upper()
Expand Down Expand Up @@ -92,11 +94,11 @@ def subprocess(self) -> "_SUBPROCESS_LOCAL":

def __str__(self) -> str:
return self._to_str("LocalConnection", self.server_name, None,
self.username, None, True)
self.username, None, True, False)

def to_dict(self) -> Dict[str, Optional[Union[str, bool, int]]]:
return self._to_dict("LocalConnection", self.server_name, None,
self.username, None, True)
self.username, None, True, False)

@staticmethod
def close(*, quiet: bool = True):
Expand Down
12 changes: 7 additions & 5 deletions ssh_utilities/multi_connection/_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def __getstate__(self):

def __setstate__(self, state: dict):
"""Initializes the object after load from pickle."""
ssh_servers, local, thread_safe = (
ssh_servers, local, thread_safe, allow_agent = (
self._parse_persistence_dict(state)
)

self.__init__(ssh_servers, local, quiet=True, # type: ignore
thread_safe=thread_safe)
thread_safe=thread_safe, allow_agent=allow_agent)

def to_dict(self) -> Dict[int, Dict[str, Optional[Union[str, bool,
int, None]]]]:
Expand Down Expand Up @@ -96,12 +96,14 @@ def _parse_persistence_dict(d: dict) -> Tuple[List[str], List[int],
ssh_servers = []
local = []
thread_safe = []
allow_agent = []
for j in d.values():
ssh_servers.append(j.pop("server_name"))
thread_safe.append(j.pop("thread_safe"))
local.append(not bool(j.pop("address")))
allow_agent.append(j.pop("allow_agent"))

return ssh_servers, local, thread_safe
return ssh_servers, local, thread_safe, allow_agent

@classmethod
def from_dict(cls, json: dict, quiet: bool = False
Expand Down Expand Up @@ -129,12 +131,12 @@ def from_dict(cls, json: dict, quiet: bool = False
KeyError
if required key is missing from string
"""
ssh_servers, local, thread_safe = (
ssh_servers, local, thread_safe, allow_agent = (
cls._parse_persistence_dict(json)
)

return cls(ssh_servers, local, quiet=quiet, # type: ignore
thread_safe=thread_safe)
thread_safe=thread_safe, allow_agent=allow_agent)

@classmethod
def from_str(cls, string: str, quiet: bool = False
Expand Down
5 changes: 4 additions & 1 deletion ssh_utilities/multi_connection/multi_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class MultiConnection(DictInterface, Pesistence, ConnectionABC):

def __init__(self, ssh_servers: Union[List[str], str],
local: Union[List[bool], bool] = False, quiet: bool = False,
thread_safe: Union[List[bool], bool] = False) -> None:
thread_safe: Union[List[bool], bool] = False,
allow_agent: Union[List[bool], bool] = True) -> None:

# TODO somehow adjust number of workers if connection are deleted or
# TODO added
Expand All @@ -108,6 +109,8 @@ def __init__(self, ssh_servers: Union[List[str], str],
local = [local] * len(ssh_servers)
if not isinstance(thread_safe, list):
thread_safe = [thread_safe] * len(ssh_servers)
if not isinstance(allow_agent, list):
allow_agent = [allow_agent] * len(ssh_servers)

self._connections = defaultdict(deque)
for ss, l, ts in zip(ssh_servers, local, thread_safe):
Expand Down
Loading

0 comments on commit 8190a28

Please sign in to comment.