diff --git a/ssh_utilities/abstract/_connection.py b/ssh_utilities/abstract/_connection.py index e7e5cde..c7be7fb 100644 --- a/ssh_utilities/abstract/_connection.py +++ b/ssh_utilities/abstract/_connection.py @@ -34,6 +34,11 @@ class ConnectionABC(ABC): __name__: str __abstractmethods__: FrozenSet[str] + password: Optional[str] + address: Optional[str] + username: str + pkey_file: Optional[Union[str, "Path"]] + allow_agent: Optional[bool] @abstractmethod def __str__(self): @@ -135,11 +140,11 @@ def to_dict(self): @staticmethod def _to_dict(connection_name: str, host_name: str, address: Optional[str], user_name: str, ssh_key: Optional[Union[Path, str]], - thread_safe: bool + thread_safe: bool, allow_agent: bool ) -> Dict[str, Optional[Union[str, bool, int]]]: if ssh_key is None: - key_path = ssh_key + key_path = None else: key_path = str(Path(ssh_key).resolve()) @@ -149,12 +154,14 @@ def _to_dict(connection_name: str, host_name: str, address: Optional[str], "user_name": user_name, "ssh_key": key_path, "address": address, - "thread_safe": thread_safe + "thread_safe": thread_safe, + "allow_agent": allow_agent, } def _to_str(self, connection_name: str, host_name: str, address: Optional[str], user_name: str, - ssh_key: Optional[Union[Path, str]], thread_safe: bool) -> str: + ssh_key: Optional[Union[Path, str]], thread_safe: bool, + allow_agent: bool) -> str: """Aims to ease persistance, returns string representation of instance. With this method all data needed to initialize class are saved to sting @@ -177,6 +184,9 @@ def _to_str(self, connection_name: str, host_name: str, make connection object thread safe so it can be safely accessed from any number of threads, it is disabled by default to avoid performance penalty of threading locks + allow_agent: bool + allows use of ssh agent for connection authentication, when this is + `True` key for the host does not have to be available. Returns ------- @@ -188,7 +198,8 @@ def _to_str(self, connection_name: str, host_name: str, :class:`ssh_utilities.conncection.Connection` """ return dumps(self._to_dict(connection_name, host_name, address, - user_name, ssh_key, thread_safe)) + user_name, ssh_key, thread_safe, + allow_agent)) def __del__(self): self.close(quiet=True) @@ -206,10 +217,11 @@ def __setstate__(self, state: dict): self.__init__(state["address"], state["user_name"], # type: ignore pkey_file=state["ssh_key"], server_name=state["server_name"], - quiet=True, thread_safe=state["thread_safe"]) + quiet=True, thread_safe=state["thread_safe"], + allow_agent=state["allow_agent"]) def __enter__(self: "CONN_TYPE") -> "CONN_TYPE": return self def __exit__(self, exc_type, exc_value, exc_traceback): - self.close(quiet=True) \ No newline at end of file + self.close(quiet=True) diff --git a/ssh_utilities/connection.py b/ssh_utilities/connection.py index 1241f7e..bade7aa 100644 --- a/ssh_utilities/connection.py +++ b/ssh_utilities/connection.py @@ -34,7 +34,7 @@ __all__ = ["Connection"] -logging.getLogger(__name__) +log = logging.getLogger(__name__) # guard for when readthedocs is building documentation or travis # is running CI build @@ -121,21 +121,22 @@ class Connection(metaclass=_ConnectionMeta): @overload def __new__(cls, ssh_server: str, local: Literal[False], quiet: bool, - thread_safe: bool) -> SSHConnection: + thread_safe: bool, allow_agent: bool) -> SSHConnection: ... @overload def __new__(cls, ssh_server: str, local: Literal[True], quiet: bool, - thread_safe: bool) -> LocalConnection: + thread_safe: bool, allow_agent: bool) -> LocalConnection: ... @overload def __new__(cls, ssh_server: str, local: bool, quiet: bool, - thread_safe: bool) -> Union[SSHConnection, LocalConnection]: + thread_safe: bool, allow_agent: bool + ) -> Union[SSHConnection, LocalConnection]: ... def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False, - thread_safe: bool = False): + thread_safe: bool = False, allow_agent: bool = True): """Get Connection based on one of names defined in .ssh/config file. If name of local PC is passed initilize LocalConnection @@ -152,11 +153,14 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False, make connection object thread safe so it can be safely accessed from any number of threads, it is disabled by default to avoid performance penalty of threading locks + allow_agent: bool + allows use of ssh agent for connection authentication, when this is + `True` key for the host does not have to be available. Raises ------ KeyError - if server name is not in config file + if server name is not in config file and allow agent is false Returns ------- @@ -173,14 +177,36 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False, raise KeyError(f"couldn't find login credentials for {ssh_server}:" f" {e}") else: + # get username and address try: - return cls.open(credentials["user"], credentials["hostname"], - credentials["identityfile"][0], - server_name=ssh_server, quiet=quiet, - thread_safe=thread_safe) + user = credentials["user"] + hostname = credentials["hostname"] except KeyError as e: - raise KeyError(f"{RED}missing key in config dictionary for " - f"{ssh_server}: {R}{e}") + raise KeyError( + "Cannot find username or hostname for specified host" + ) + + # get key or use agent + if allow_agent: + log.info(f"no private key supplied for {hostname}, will try " + f"to authenticate through ssh-agent") + pkey_file = None + else: + log.info(f"private key found for host: {hostname}") + try: + pkey_file = credentials["identityfile"][0] + except (KeyError, IndexError) as e: + raise KeyError(f"No private key found for specified host") + + return cls.open( + user, + hostname, + ssh_key_file=pkey_file, + allow_agent=allow_agent, + server_name=ssh_server, + quiet=quiet, + thread_safe=thread_safe + ) @classmethod def get_available_hosts(cls) -> List[str]: @@ -212,7 +238,8 @@ def get(cls, *args, **kwargs): get_connection = get @classmethod - def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]): + def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]], + allow_agent: Union[bool, List[bool]]): """Add or override availbale host read fron ssh config file. You can use supplied config parser to parse some externaf ssh config @@ -223,6 +250,9 @@ def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]): hosts : Union[_HOSTS, List[_HOSTS]] dictionary or a list of dictionaries containing keys: `user`, `hostname` and `identityfile` + allow_agent: Union[bool, List[bool]] + bool or a list of bools with corresponding length to list of hosts. + if only one bool is passed in, it will be used for all host entries See also -------- @@ -230,8 +260,12 @@ def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]): """ if not isinstance(hosts, list): hosts = [hosts] + if not isinstance(allow_agent, list): + allow_agent = [allow_agent] * len(hosts) - for h in hosts: + for h, a in zip(hosts, allow_agent): + if a: + h["identityfile"][0] = None if not isinstance(h["identityfile"], list): h["identityfile"] = [h["identityfile"]] h["identityfile"][0] = os.path.abspath( @@ -300,7 +334,7 @@ def open(ssh_username: str, ssh_server: None = None, ssh_password: Optional[str] = None, server_name: Optional[str] = None, quiet: bool = False, thread_safe: bool = False, - ssh_allow_agent: bool = False) -> LocalConnection: + allow_agent: bool = False) -> LocalConnection: ... @overload @@ -310,7 +344,7 @@ def open(ssh_username: str, ssh_server: str, ssh_password: Optional[str] = None, server_name: Optional[str] = None, quiet: bool = False, thread_safe: bool = False, - ssh_allow_agent: bool = False) -> SSHConnection: + allow_agent: bool = False) -> SSHConnection: ... @staticmethod @@ -319,7 +353,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "", ssh_password: Optional[str] = None, server_name: Optional[str] = None, quiet: bool = False, thread_safe: bool = False, - ssh_allow_agent: bool = False): + allow_agent: bool = False): """Initialize SSH or local connection. Local connection is only a wrapper around os and shutil module methods @@ -346,7 +380,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "", make connection object thread safe so it can be safely accessed from any number of threads, it is disabled by default to avoid performance penalty of threading locks - ssh_allow_agent: bool + allow_agent: bool allow the use of the ssh-agent to connect. Will disable ssh_key_file. Warnings @@ -355,27 +389,29 @@ def open(ssh_username: str, ssh_server: Optional[str] = "", risk! """ if not ssh_server: - return LocalConnection(ssh_server, ssh_username, - pkey_file=ssh_key_file, - server_name=server_name, quiet=quiet) - else: - if ssh_allow_agent: - c = SSHConnection(ssh_server, ssh_username, - allow_agent=ssh_allow_agent, line_rewrite=True, - server_name=server_name, quiet=quiet, - thread_safe=thread_safe) - elif ssh_key_file: - c = SSHConnection(ssh_server, ssh_username, - pkey_file=ssh_key_file, line_rewrite=True, - server_name=server_name, quiet=quiet, - thread_safe=thread_safe) - else: - if not ssh_password: - ssh_password = getpass.getpass(prompt="Enter password: ") - - c = SSHConnection(ssh_server, ssh_username, - password=ssh_password, line_rewrite=True, - server_name=server_name, quiet=quiet, - thread_safe=thread_safe) - - return c + return LocalConnection( + ssh_server, + ssh_username, + pkey_file=ssh_key_file, + server_name=server_name, + quiet=quiet + ) + elif allow_agent: + ssh_key_file = None + ssh_password = None + elif ssh_key_file: + ssh_password = None + elif not ssh_password: + ssh_password = getpass.getpass(prompt="Enter password: ") + + return SSHConnection( + ssh_server, + ssh_username, + allow_agent=allow_agent, + pkey_file=ssh_key_file, + password=ssh_password, + line_rewrite=True, + server_name=server_name, + quiet=quiet, + thread_safe=thread_safe + ) diff --git a/ssh_utilities/local/local.py b/ssh_utilities/local/local.py index a276c5e..e91d91d 100644 --- a/ssh_utilities/local/local.py +++ b/ssh_utilities/local/local.py @@ -30,13 +30,15 @@ def __init__(self, address: Optional[str], username: str, password: Optional[str] = None, pkey_file: Optional[Union[str, "Path"]] = None, line_rewrite: bool = True, server_name: Optional[str] = None, - quiet: bool = False, thread_safe: bool = False) -> None: + quiet: bool = False, thread_safe: bool = False, + allow_agent: Optional[bool] = False) -> None: # set login credentials self.password = password self.address = address self.username = username self.pkey_file = pkey_file + self.allow_agent = allow_agent self.server_name = server_name if server_name else gethostname() self.server_name = self.server_name.upper() @@ -92,11 +94,11 @@ def subprocess(self) -> "_SUBPROCESS_LOCAL": def __str__(self) -> str: return self._to_str("LocalConnection", self.server_name, None, - self.username, None, True) + self.username, None, True, False) def to_dict(self) -> Dict[str, Optional[Union[str, bool, int]]]: return self._to_dict("LocalConnection", self.server_name, None, - self.username, None, True) + self.username, None, True, False) @staticmethod def close(*, quiet: bool = True): diff --git a/ssh_utilities/multi_connection/_persistence.py b/ssh_utilities/multi_connection/_persistence.py index f33dc96..18ec836 100644 --- a/ssh_utilities/multi_connection/_persistence.py +++ b/ssh_utilities/multi_connection/_persistence.py @@ -55,12 +55,12 @@ def __getstate__(self): def __setstate__(self, state: dict): """Initializes the object after load from pickle.""" - ssh_servers, local, thread_safe = ( + ssh_servers, local, thread_safe, allow_agent = ( self._parse_persistence_dict(state) ) self.__init__(ssh_servers, local, quiet=True, # type: ignore - thread_safe=thread_safe) + thread_safe=thread_safe, allow_agent=allow_agent) def to_dict(self) -> Dict[int, Dict[str, Optional[Union[str, bool, int, None]]]]: @@ -96,12 +96,14 @@ def _parse_persistence_dict(d: dict) -> Tuple[List[str], List[int], ssh_servers = [] local = [] thread_safe = [] + allow_agent = [] for j in d.values(): ssh_servers.append(j.pop("server_name")) thread_safe.append(j.pop("thread_safe")) local.append(not bool(j.pop("address"))) + allow_agent.append(j.pop("allow_agent")) - return ssh_servers, local, thread_safe + return ssh_servers, local, thread_safe, allow_agent @classmethod def from_dict(cls, json: dict, quiet: bool = False @@ -129,12 +131,12 @@ def from_dict(cls, json: dict, quiet: bool = False KeyError if required key is missing from string """ - ssh_servers, local, thread_safe = ( + ssh_servers, local, thread_safe, allow_agent = ( cls._parse_persistence_dict(json) ) return cls(ssh_servers, local, quiet=quiet, # type: ignore - thread_safe=thread_safe) + thread_safe=thread_safe, allow_agent=allow_agent) @classmethod def from_str(cls, string: str, quiet: bool = False diff --git a/ssh_utilities/multi_connection/multi_connection.py b/ssh_utilities/multi_connection/multi_connection.py index de9039c..4e8ef9a 100644 --- a/ssh_utilities/multi_connection/multi_connection.py +++ b/ssh_utilities/multi_connection/multi_connection.py @@ -96,7 +96,8 @@ class MultiConnection(DictInterface, Pesistence, ConnectionABC): def __init__(self, ssh_servers: Union[List[str], str], local: Union[List[bool], bool] = False, quiet: bool = False, - thread_safe: Union[List[bool], bool] = False) -> None: + thread_safe: Union[List[bool], bool] = False, + allow_agent: Union[List[bool], bool] = True) -> None: # TODO somehow adjust number of workers if connection are deleted or # TODO added @@ -108,6 +109,8 @@ def __init__(self, ssh_servers: Union[List[str], str], local = [local] * len(ssh_servers) if not isinstance(thread_safe, list): thread_safe = [thread_safe] * len(ssh_servers) + if not isinstance(allow_agent, list): + allow_agent = [allow_agent] * len(ssh_servers) self._connections = defaultdict(deque) for ss, l, ts in zip(ssh_servers, local, thread_safe): diff --git a/ssh_utilities/remote/path.py b/ssh_utilities/remote/path.py index ba94ef6..3ea19cd 100644 --- a/ssh_utilities/remote/path.py +++ b/ssh_utilities/remote/path.py @@ -336,8 +336,9 @@ def touch(self, mode: int = 0o666, exist_ok: bool = True): self.chmod(mode=mode) def absolute(self): - """Return an absolute version of this path. This function works - even if the path doesn't point to anything. + """Return an absolute version of this path. + + This function works even if the path doesn't point to anything. No normalization is done, i.e. all '.' and '..' will be kept along. Use resolve() to get the canonical path to a file. @@ -354,8 +355,9 @@ def absolute(self): return obj def expanduser(self): - """ Return a new path with expanded ~ and ~user constructs - (as returned by os.path.expanduser) + """Return a new path with expanded ~ and ~user constructs. + + As returned by os.path.expanduser """ if (not (self._drv or self._root) and self._parts and self._parts[0][:1] == '~'): diff --git a/ssh_utilities/remote/remote.py b/ssh_utilities/remote/remote.py index fcc2a3a..92980e0 100644 --- a/ssh_utilities/remote/remote.py +++ b/ssh_utilities/remote/remote.py @@ -101,7 +101,7 @@ def __init__(self, address: str, username: str, lprnt = lprint(quiet) if allow_agent: - msg = f"Will login with ssh-agent" + msg = "Will login with ssh-agent" lprnt(msg) log.info(msg) if pkey_file: @@ -139,18 +139,18 @@ def __init__(self, address: str, username: str, # paramiko connection if allow_agent: - self.pkey = None + self._pkey = None self.password = None elif pkey_file: for key in _KEYS: try: - self.pkey = key.from_private_key_file( + self._pkey = key.from_private_key_file( self._path2str(pkey_file) ) except paramiko.SSHException: log.info(f"could not parse key with {key.__name__}") elif password: - self.pkey = None + self._pkey = None else: raise RuntimeError("Must input password or path to pkey") @@ -214,12 +214,13 @@ def subprocess(self) -> "_SUBPROCESS_REMOTE": def __str__(self) -> str: return self._to_str("SSHConnection", self.server_name, self.address, - self.username, self.pkey_file, self.thread_safe) + self.username, self.pkey_file, self.thread_safe, + self.allow_agent) def to_dict(self) -> Dict[str, Optional[Union[str, bool, int]]]: return self._to_dict("SSHConnection", self.server_name, self.address, - self.username, self.pkey_file, - self.thread_safe) + self.username, self.pkey_file, self.thread_safe, + self.allow_agent) @check_connections() def close(self, *, quiet: bool = True): @@ -258,10 +259,10 @@ def _get_ssh(self, authentication_attempts: int = 0): if self.allow_agent: # connect using ssh-agent self.c.connect(self.address, username=self.username, allow_agent=True) - if self.pkey: + if self._pkey: # connect with public key self.c.connect(self.address, username=self.username, - pkey=self.pkey) + pkey=self._pkey) else: # if password was passed try to connect with it self.c.connect(self.address, username=self.username, diff --git a/ssh_utilities/version.py b/ssh_utilities/version.py index a9b029e..a4219a8 100644 --- a/ssh_utilities/version.py +++ b/ssh_utilities/version.py @@ -1 +1 @@ -__version__ = "0.10.0" \ No newline at end of file +__version__ = "0.11.0" \ No newline at end of file