diff --git a/pyobs/vfs/__init__.py b/pyobs/vfs/__init__.py index ee17167e..f24dacd5 100644 --- a/pyobs/vfs/__init__.py +++ b/pyobs/vfs/__init__.py @@ -61,6 +61,7 @@ from .httpfile import HttpFile from .memfile import MemoryFile from .smbfile import SMBFile +from .sftpfile import SFTPFile from .sshfile import SSHFile from .tempfile import TempFile from .archivefile import ArchiveFile @@ -73,6 +74,7 @@ "HttpFile", "MemoryFile", "SMBFile", + "SFTPFile", "SSHFile", "TempFile", "ArchiveFile", diff --git a/pyobs/vfs/sftpfile.py b/pyobs/vfs/sftpfile.py new file mode 100644 index 00000000..6be5e816 --- /dev/null +++ b/pyobs/vfs/sftpfile.py @@ -0,0 +1,124 @@ +import asyncio +import os +from typing import Optional, Any, AnyStr, List +import paramiko +import paramiko.sftp + +from .file import VFSFile + + +class SFTPFile(VFSFile): + """VFS wrapper for a file that can be accessed over a SFTP connection.""" + + __module__ = "pyobs.vfs" + + def __init__( + self, + name: str, + mode: str = "r", + bufsize: int = -1, + hostname: Optional[str] = None, + port: int = 22, + username: Optional[str] = None, + password: Optional[str] = None, + keyfile: Optional[str] = None, + root: Optional[str] = None, + mkdir: bool = True, + **kwargs: Any, + ): + """Open/create a file over a SSH connection. + + Args: + name: Name of file. + mode: Open mode. + bufsize: Size of buffer size for SFTP connection. + hostname: Name of host to connect to. + port: Port on host to connect to. + username: Username to log in on host. + password: Password for username. + keyfile: Path to SSH key on local machine. + root: Root directory on host. + mkdir: Whether or not to automatically create directories. + """ + + # no root given? + if root is None: + raise ValueError("No root directory given.") + + # filename is not allowed to start with a / or contain .. + if name.startswith("/") or ".." in name: + raise ValueError("Only files within root directory are allowed.") + + # build filename + self.filename = name + full_path = os.path.join(root, name) + + # check + if hostname is None: + raise ValueError("No hostname given.") + + # connect + self._ssh = paramiko.SSHClient() + self._ssh.load_system_host_keys() + self._ssh.connect(hostname, port=port, username=username, password=password, key_filename=keyfile) + self._sftp = self._ssh.open_sftp() + + # need to create directory? + path = os.path.dirname(full_path) + try: + self._sftp.chdir(path) + except IOError: + if mkdir: + self._sftp.mkdir(path) + else: + raise ValueError("Cannot write into sub-directory with disabled mkdir option.") + + # open file + self._fd = self._sftp.file(full_path, mode) + + async def close(self) -> None: + """Close file.""" + self._sftp.close() + self._ssh.close() + + async def read(self, n: int = -1) -> AnyStr: + return self._fd.read(n) + + async def write(self, s: AnyStr) -> None: + self._fd.write(s) + + @staticmethod + async def listdir(path: str, **kwargs: Any) -> List[str]: + """Returns content of given path. + + Args: + path: Path to list. + kwargs: Parameters for specific file implementation (same as __init__). + + Returns: + List of files in path. + """ + + # connect + ssh = paramiko.SSHClient() + ssh.load_system_host_keys() + ssh.connect( + kwargs["hostname"], + port=kwargs["port"] if "port" in kwargs else 22, + username=kwargs["username"], + password=kwargs["password"] if "password" in kwargs else None, + key_filename=kwargs["keyfile"] if "keyfile" in kwargs else None, + ) + sftp = ssh.open_sftp() + + # list files in path + loop = asyncio.get_running_loop() + files = await loop.run_in_executor(None, sftp.listdir, os.path.join(kwargs["root"], path)) + + # disconnect and return list + sftp.close() + ssh.close() + return files + + +__all__ = ["SFTPFile"] diff --git a/pyobs/vfs/sshfile.py b/pyobs/vfs/sshfile.py index 8f683f7c..3fbd9210 100644 --- a/pyobs/vfs/sshfile.py +++ b/pyobs/vfs/sshfile.py @@ -51,7 +51,7 @@ def __init__( # build filename self.filename = name - full_path = os.path.join(root, name) + self._full_path = os.path.join(root, name) # check if hostname is None: @@ -61,31 +61,80 @@ def __init__( self._ssh = paramiko.SSHClient() self._ssh.load_system_host_keys() self._ssh.connect(hostname, port=port, username=username, password=password, key_filename=keyfile) - self._sftp = self._ssh.open_sftp() # need to create directory? - path = os.path.dirname(full_path) - try: - self._sftp.chdir(path) - except IOError: - if mkdir: - self._sftp.mkdir(path) - else: - raise ValueError("Cannot write into sub-directory with disabled mkdir option.") - - # open file - self._fd = self._sftp.file(full_path, mode) + if mkdir: + path = os.path.dirname(self._full_path) + self._ssh.exec_command(f"mkdir -p {path}") - async def close(self) -> None: - """Close file.""" - self._sftp.close() - self._ssh.close() + # build filename + self.filename = name + self.mode = mode + self._buffer = b"" if "b" in self.mode else "" + self._pos = 0 + self._open = True + + async def _download(self) -> None: + """For read access, download the file into a local buffer. + + Raises: + FileNotFoundError: If file could not be found. + """ + + _, stdout, stderr = self._ssh.exec_command(f"cat {self._full_path}") + self._buffer = stdout.read() async def read(self, n: int = -1) -> AnyStr: - return self._fd.read(n) + """Read number of bytes from stream. + + Args: + n: Number of bytes to read. Read until end, if -1. + + Returns: + Read bytes. + """ + + # load file + if len(self._buffer) == 0 and "r" in self.mode: + await self._download() + + # check size + if n == -1: + data = self._buffer + self._pos = len(self._buffer) - 1 + else: + # extract data to read + data = self._buffer[self._pos : self._pos + n] + self._pos += n + + # return data + return data async def write(self, s: AnyStr) -> None: - self._fd.write(s) + """Write data into the stream. + + Args: + b: Bytes of data to write. + """ + self._buffer += s + + async def close(self) -> None: + """Close stream.""" + + # write it? + if "w" in self.mode and self._open: + await self._upload() + + # set flag + self._open = False + + async def _upload(self) -> None: + """If in write mode, actually send the file to the SSH server.""" + + transport = self._ssh.get_transport() + with transport.open_channel(kind="session") as channel: + channel.exec_command(f"cat > {self._full_path}") + channel.sendall(self._buffer) @staticmethod async def listdir(path: str, **kwargs: Any) -> List[str]: @@ -109,16 +158,14 @@ async def listdir(path: str, **kwargs: Any) -> List[str]: password=kwargs["password"] if "password" in kwargs else None, key_filename=kwargs["keyfile"] if "keyfile" in kwargs else None, ) - sftp = ssh.open_sftp() - # list files in path - loop = asyncio.get_running_loop() - files = await loop.run_in_executor(None, sftp.listdir, os.path.join(kwargs["root"], path)) + p = os.path.join(kwargs["root"], path) + _, stdout, stderr = ssh.exec_command(f"ls -1 {p}") + files = stdout.readlines() # disconnect and return list - sftp.close() ssh.close() - return files + return [f.strip() for f in files] __all__ = ["SSHFile"]