383 lines
11 KiB
Python
383 lines
11 KiB
Python
# ruff: noqa: UP007
|
|
|
|
import io
import logging
import os
import time
from typing import Optional, Union
|
|
|
|
_logger = logging.getLogger(__name__)

try:
    from paramiko import (
        AutoAddPolicy,
        ECDSAKey,
        Ed25519Key,
        MissingHostKeyPolicy,
        RSAKey,
        SFTPClient,
        SSHClient,
        SSHException,
    )
except ImportError:
    _logger.error(
        "Looks like 'paramiko' is not installed, please try to "
        "install it using 'pip install paramiko'"
    )
    # Fallbacks so this module can still be imported without paramiko.
    # Every name used at class-definition time must exist: annotations
    # reference the key/client classes, CustomHostKeyPolicy subclasses
    # MissingHostKeyPolicy (so it must be a real class), and SSHException
    # appears in ``except`` clauses (so it must be an exception type).
    AutoAddPolicy = ECDSAKey = Ed25519Key = RSAKey = None
    SFTPClient = SSHClient = None
    MissingHostKeyPolicy = object
    SSHException = Exception

# DSSKey is imported on its own: paramiko builds without DSA key support
# must not make the whole import above fail and disable this module.
try:
    from paramiko import DSSKey
except ImportError:
    DSSKey = None
|
|
|
|
|
|
class KeyLoader:
    """
    Utility for loading private SSH key in supported formats.
    """

    @staticmethod
    def load_private_key(
        ssh_key: str, password: Optional[str] = None
    ) -> Union[RSAKey, DSSKey, ECDSAKey, Ed25519Key]:
        """
        Load a private SSH key from a string.

        The paramiko key classes are tried in order (RSA, DSS, ECDSA,
        Ed25519); the first one that parses the text is returned.

        Args:
            ssh_key: The private key text.
            password: Optional passphrase for an encrypted private key.

        Returns:
            The loaded paramiko key object.

        Raises:
            ValueError: If none of the available key classes can parse the
                key. The last parser error is attached as ``__cause__``.
        """
        key_file = io.StringIO(ssh_key)
        last_error = None
        for key_class in (RSAKey, DSSKey, ECDSAKey, Ed25519Key):
            # A key class is None when paramiko (or that key type, e.g.
            # DSSKey) could not be imported; skip it instead of crashing.
            if key_class is None:
                continue
            # Each attempt must parse the text from the beginning.
            key_file.seek(0)
            try:
                return key_class.from_private_key(key_file, password=password)
            except SSHException as exc:
                last_error = exc
                # Failing for all but one class is expected, so keep it quiet.
                _logger.debug(
                    "KeyLoader: failed to load key through %s: %s",
                    key_class.__name__,
                    exc,
                )
        _logger.error(
            "KeyLoader: unable to load private key. "
            "Unsupported format or invalid SSH key."
        )
        raise ValueError("Unsupported format or invalid SSH key.") from last_error
|
|
|
|
|
|
class SSHConnection:
    """
    Class for managing SSH connection.

    The paramiko ``SSHClient`` is created lazily by :meth:`connect` and is
    reused for as long as its transport stays active.
    """

    def __init__(
        self,
        host: str,
        port: int,
        username: str,
        password: Optional[str] = None,
        ssh_key: Optional[str] = None,
        host_key: Optional[str] = None,
        mode: str = "p",  # "p" for password, "k" for key
        allow_agent: bool = False,
        timeout: int = 5000,
    ):
        """
        Initialize the SSHConnection instance. No network I/O happens here.

        Args:
            host: Server hostname or address.
            port: Server SSH port.
            username: Login user name.
            password: Password, required when ``mode == "p"``.
            ssh_key: Private key text, required when ``mode == "k"``.
            host_key: Expected server host key (Base64). When empty, unknown
                host keys are accepted via paramiko's ``AutoAddPolicy``.
            mode: ``"p"`` for password auth, ``"k"`` for key auth.
            allow_agent: Forwarded to ``SSHClient.connect``.
            timeout: Forwarded unchanged to ``SSHClient.connect``.
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.ssh_key = ssh_key
        self.host_key = host_key
        self.mode = mode
        self.allow_agent = allow_agent
        # NOTE(review): paramiko's SSHClient.connect takes ``timeout`` in
        # seconds, so the 5000 default is ~83 minutes; milliseconds may have
        # been intended. Left unchanged for compatibility -- confirm.
        self.timeout = timeout
        self._ssh_client: Optional[SSHClient] = None

    def _auth_params(self) -> dict:
        """
        Return the credential kwargs for ``SSHClient.connect`` for ``mode``.

        Raises:
            ValueError: If the credential required by ``mode`` is missing or
                ``mode`` is unsupported.
        """
        if self.mode == "p":
            if not self.password:
                raise ValueError("For password mode, you need to pass a password.")
            return {"password": self.password}
        if self.mode == "k":
            if not self.ssh_key:
                raise ValueError("For key mode, you need to pass an SSH key.")
            return {"pkey": KeyLoader.load_private_key(self.ssh_key)}
        raise ValueError(f"Unsupported connection mode: {self.mode}")

    def connect(self) -> SSHClient:
        """
        Connect to the SSH server, reusing the cached client while it is live.

        Returns:
            The connected paramiko ``SSHClient``.

        Raises:
            ValueError: On missing credentials or an unsupported ``mode``.
            Any exception from paramiko while connecting is propagated; in
            every failure case no client is cached.
        """
        if self._ssh_client is not None:
            transport = self._ssh_client.get_transport()
            if transport is not None and transport.is_active():
                return self._ssh_client
            # The cached connection has dropped: discard it and reconnect.
            self.disconnect()

        connect_params = {
            "hostname": self.host,
            "port": self.port,
            "username": self.username,
            "allow_agent": self.allow_agent,
            "timeout": self.timeout,
        }
        # Validate credentials before creating a client so that a bad
        # configuration never leaves a half-initialised client behind.
        connect_params.update(self._auth_params())

        client = SSHClient()
        try:
            client.load_system_host_keys()
            if self.host_key:
                client.set_missing_host_key_policy(
                    CustomHostKeyPolicy(self.host_key)
                )
            else:
                # Without a pinned host key any unknown server key is trusted.
                client.set_missing_host_key_policy(AutoAddPolicy())
            client.connect(**connect_params)
        except Exception:
            client.close()
            raise

        # Cache only a client whose connect() succeeded.
        self._ssh_client = client
        return client

    def disconnect(self) -> None:
        """
        Close the SSH connection, if one is open, and forget the client.
        """
        if self._ssh_client:
            _logger.info("SSHConnection: closing SSH connection.")
            self._ssh_client.close()
            self._ssh_client = None

    def get_transport(self):
        """
        Get the SSH transport, connecting (or reconnecting) first if needed.
        """
        return self.connect().get_transport()
|
|
|
|
|
|
class CustomHostKeyPolicy(MissingHostKeyPolicy):
    """
    Custom SSH host key policy for validating the server's host key.

    This policy compares the server's host key (in Base64 format) with the
    expected key. If they do not match, an SSHException is raised to prevent
    connecting to an untrusted server. If they match, the key is added to
    the client's host keys.
    """

    def __init__(self, expected_host_key: str):
        """
        Initialize the policy with the expected host key.

        Args:
            expected_host_key (str): The expected host key, either the bare
                Base64 blob or an OpenSSH-style line
                ``"<key-type> <base64> [comment]"``. Surrounding whitespace
                is ignored; when a key type is given it must match as well.
        """
        self.expected_host_key = expected_host_key

    def _expected_parts(self) -> tuple[Optional[str], str]:
        """Split the expected key into ``(key_type or None, base64_blob)``."""
        fields = (self.expected_host_key or "").split()
        if len(fields) >= 2:
            return fields[0], fields[1]
        if fields:
            return None, fields[0]
        return None, ""

    def missing_host_key(self, client, hostname, key):
        """
        Called when the SSH client receives a host key from the server
        that is not in its known hosts.

        Args:
            client: The SSH client instance.
            hostname: The hostname of the server.
            key: The host key received from the server.

        Raises:
            SSHException: If the received host key (or its type, when the
                expected key specifies one) does not match.
        """
        expected_type, expected_b64 = self._expected_parts()
        type_matches = expected_type is None or key.get_name() == expected_type
        if not type_matches or key.get_base64() != expected_b64:
            raise SSHException(f"Host key mismatch for {hostname}. ")
        # If the key matches, add it to the client's known hosts.
        # ``_host_keys`` is a private paramiko attribute.
        client._host_keys.add(hostname, key.get_name(), key)
|
|
|
|
|
|
class SftpService:
    """
    Service for working with SFTP, using SSH connection.

    The SFTP session is opened lazily on first use and reopened when its
    channel has been closed.
    """

    def __init__(self, connection: SSHConnection):
        """
        Initialize the SftpService instance.

        Args:
            connection: The SSH connection whose transport carries SFTP.
        """
        self.connection = connection
        self._sftp_client: Optional[SFTPClient] = None

    def get_client(self) -> SFTPClient:
        """
        Get the SFTP client, opening a new session when needed.

        Raises:
            SSHException: If no SFTP session could be opened.
        """
        if self._sftp_client is not None:
            channel = self._sftp_client.get_channel()
            if channel is None or channel.closed:
                # The session died (e.g. the SSH connection was closed); a
                # stale client would fail on every subsequent call.
                self._sftp_client = None
        if self._sftp_client is None:
            transport = self.connection.get_transport()
            sftp_client = SFTPClient.from_transport(transport)
            if sftp_client is None:
                # from_transport may return None when the session channel
                # cannot be opened; don't cache that.
                raise SSHException("Failed to open an SFTP session.")
            self._sftp_client = sftp_client
        return self._sftp_client

    def upload_file(
        self,
        file: Union[str, os.PathLike, bytes, bytearray, io.IOBase],
        remote_path: str,
    ) -> None:
        """
        Upload a file to the remote server.

        Args:
            file: A local path (``str`` or ``os.PathLike``), raw
                ``bytes``/``bytearray``, or a binary file-like object (any
                object with ``read``, e.g. ``io.BytesIO``). A file-like
                object is read from its current position.
            remote_path: Destination path on the server.

        Raises:
            TypeError: If ``file`` is none of the supported kinds. This is
                checked before any connection is opened.
        """
        if isinstance(file, (bytes, bytearray)):
            file = io.BytesIO(file)
        if isinstance(file, (str, os.PathLike)):
            self.get_client().put(os.fspath(file), remote_path)
        elif hasattr(file, "read"):
            self.get_client().putfo(file, remote_path)
        else:
            raise TypeError(f"File type {type(file).__name__} is not supported.")

    def download_file(self, remote_path: str) -> bytes:
        """
        Download a file from the remote server and return its whole content.
        """
        client = self.get_client()
        with client.open(remote_path, "rb") as remote_file:
            return remote_file.read()

    def delete_file(self, remote_path: str) -> None:
        """
        Delete a file from the remote server.
        """
        client = self.get_client()
        client.remove(remote_path)

    def disconnect(self) -> None:
        """
        Close the SFTP session, if one is open. The SSH connection is kept.
        """
        if self._sftp_client:
            _logger.info("SftpService: closing SFTP connection.")
            self._sftp_client.close()
            self._sftp_client = None
|
|
|
|
|
|
class CommandExecutor:
    """
    Class for executing commands on a remote server.
    """

    def __init__(self, connection: SSHConnection):
        """
        Initialize the CommandExecutor instance.

        Args:
            connection: The SSH connection used to run commands.
        """
        self.connection = connection

    def exec_command(
        self, command: str, sudo: Optional[str] = None
    ) -> tuple[int, list[str], list[str]]:
        """
        Run a command on the remote server.

        Args:
            command (str): The command to execute.
            sudo (Optional[str]): Sudo mode. With ``"p"`` (and a non-root
                user) the connection password plus a newline is written to
                the command's stdin. This method does not prefix ``sudo``
                itself; presumably ``command`` reads the password, e.g. via
                ``sudo -S`` -- confirm with callers.

        Returns:
            tuple:
                - exit_status (int): 255 if the command could not be run.
                - stdout (list[str])
                - stderr (list[str]): holds the error text on failure.

        Connection errors raised by ``SSHConnection.connect`` propagate.
        """
        use_sudo_with_password = sudo == "p" and self.connection.username != "root"

        # Cheap precondition first: no need to connect if we must bail out.
        if use_sudo_with_password and not self.connection.password:
            return 255, [], ["Sudo password not provided!"]

        ssh_client = self.connection.connect()
        try:
            stdin, stdout, stderr = ssh_client.exec_command(command)
            if use_sudo_with_password:
                stdin.write(self.connection.password + "\n")
                stdin.flush()
            # Drain the output BEFORE waiting for the exit status. paramiko
            # warns recv_exit_status() can hang forever if the remote output
            # exceeds the channel window and nobody is reading it.
            # NOTE(review): stdout is read fully before stderr, so a command
            # producing very large stderr could still stall -- confirm.
            response = stdout.readlines()
            error = stderr.readlines()
            exit_status = stdout.channel.recv_exit_status()
            return exit_status, response, error
        except Exception as e:
            # Errors are reported via the return value (existing contract),
            # but log the traceback so failures are not silently lost.
            _logger.exception("CommandExecutor: failed to execute command.")
            return 255, [], [str(e)]
|
|
|
|
|
|
class SSHManager:
    """
    Facade for working with SSH connection, SFTP and command execution.

    Instances are cached per connection identity (see ``_cache_key``):
    constructing an SSHManager for an equivalent SSHConnection returns the
    already-created instance, unless the requested ``timeout`` differs.
    """

    # key -> (instance, creation time from time.time(), timeout at creation).
    # Shared by all SSHManager constructions.
    _connection_cache = {}

    @staticmethod
    def _cache_key(connection: SSHConnection) -> tuple:
        """Build the cache key identifying ``connection``'s target and credentials."""
        return (
            connection.host,
            connection.port,
            connection.username,
            connection.mode,
            connection.allow_agent,
            connection.password or "",
            connection.ssh_key or "",
            connection.host_key or "",
        )

    def __new__(cls, connection: SSHConnection):
        """
        Return the cached SSHManager for ``connection`` or create a new one.
        """
        key = cls._cache_key(connection)
        cached = cls._connection_cache.get(key)
        if cached is not None:
            instance, _created_at, cached_timeout = cached
            if connection.timeout == cached_timeout:
                _logger.info(
                    "Using cached SSH connection for "
                    "host=%s, port=%s, user=%s, mode=%s",
                    connection.host,
                    connection.port,
                    connection.username,
                    connection.mode,
                )
                return instance
            # The timeout changed: evict the cached instance and build a new
            # one below. NOTE(review): the evicted instance is not
            # disconnected because callers may still hold it; its SSH client
            # stays open until they call disconnect() -- confirm intended.
            cls.delete_cache(key)

        _logger.info(
            "Creating new SSH connection for host=%s, port=%s, user=%s, mode=%s",
            connection.host,
            connection.port,
            connection.username,
            connection.mode,
        )
        instance = super().__new__(cls)
        cls._connection_cache[key] = (instance, time.time(), connection.timeout)
        return instance

    def __init__(self, connection: SSHConnection):
        """
        Initialize the SSHManager instance.

        Python calls ``__init__`` even when ``__new__`` returned a cached
        instance, so initialisation happens only once per instance.
        """
        if getattr(self, "_initialized", False):
            return
        self.connection = connection
        self.command_executor = CommandExecutor(connection)
        self.sftp_service = SftpService(connection)
        self._initialized = True

    @classmethod
    def delete_cache(cls, key):
        """
        Remove ``key`` from the connection cache; a missing key is ignored.
        """
        cls._connection_cache.pop(key, None)

    def disconnect(self) -> None:
        """
        Close the SFTP session and SSH connection, then drop this instance
        from the cache.
        """
        # Both disconnect() methods are no-ops when nothing is open.
        self.sftp_service.disconnect()
        self.connection.disconnect()

        # Evict only entries that hold *this* instance: a newer instance may
        # be cached under the same key (e.g. after a timeout change) and must
        # not be dropped by disconnecting a stale one.
        own_keys = [
            key
            for key, (instance, _created_at, _timeout) in self._connection_cache.items()
            if instance is self
        ]
        for key in own_keys:
            self.delete_cache(key)

    @classmethod
    def get_connection_cache(cls):
        """
        Get the connection cache (the live dict, not a copy).
        """
        return cls._connection_cache
|