mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 15:10:13 -05:00
fix: update SSH worker command and add timeout handling
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import shlex
|
||||
import stat
|
||||
import threading
|
||||
@@ -27,6 +28,8 @@ from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.transport import TransportWriteCloser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SSHStdinTransport(TransportWriteCloser):
|
||||
def __init__(self, channel: Any):
|
||||
@@ -52,6 +55,10 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
_DEFAULT_SSH_HOST = "agentbox"
|
||||
_DEFAULT_SSH_PORT = 22
|
||||
_DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes"
|
||||
_DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS = 10
|
||||
_DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS = 30
|
||||
_DEFAULT_COMMAND_MAX_RUNTIME_SECONDS = 60 * 60
|
||||
_COMMAND_TIMEOUT_EXIT_CODE = 124
|
||||
|
||||
class OptionsKey(StrEnum):
|
||||
SSH_HOST = "ssh_host"
|
||||
@@ -141,6 +148,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
raise RuntimeError("SSH transport is not available")
|
||||
|
||||
channel = transport.open_session()
|
||||
channel.settimeout(self._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
|
||||
channel.set_combine_stderr(False)
|
||||
|
||||
execution_command = self._build_exec_command(command, environments, cwd)
|
||||
@@ -156,7 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
|
||||
threading.Thread(
|
||||
target=self._consume_channel_output,
|
||||
args=(pid, channel, stdout_transport, stderr_transport),
|
||||
args=(pid, channel, stdout_transport, stderr_transport, self._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
@@ -174,6 +182,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
with self._client() as client:
|
||||
sftp = client.open_sftp()
|
||||
try:
|
||||
self._set_sftp_operation_timeout(sftp)
|
||||
self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent))
|
||||
with sftp.file(destination_path, "wb") as remote_file:
|
||||
remote_file.write(content.getvalue())
|
||||
@@ -185,6 +194,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
with self._client() as client:
|
||||
sftp = client.open_sftp()
|
||||
try:
|
||||
self._set_sftp_operation_timeout(sftp)
|
||||
with sftp.file(source_path, "rb") as remote_file:
|
||||
return BytesIO(remote_file.read())
|
||||
finally:
|
||||
@@ -200,6 +210,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
with self._client() as client:
|
||||
sftp = client.open_sftp()
|
||||
try:
|
||||
self._set_sftp_operation_timeout(sftp)
|
||||
pending = [root_directory]
|
||||
while pending and len(files) < limit:
|
||||
current_directory = pending.pop(0)
|
||||
@@ -261,8 +272,13 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
password=password,
|
||||
look_for_keys=False,
|
||||
allow_agent=False,
|
||||
timeout=10,
|
||||
timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS,
|
||||
banner_timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS,
|
||||
auth_timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS,
|
||||
)
|
||||
transport = client.get_transport()
|
||||
if transport is not None:
|
||||
transport.set_keepalive(30)
|
||||
except Exception as e:
|
||||
with contextlib.suppress(Exception):
|
||||
client.close()
|
||||
@@ -341,9 +357,19 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
command_body += shlex.join(command)
|
||||
return f"sh -lc {shlex.quote(command_body)}"
|
||||
|
||||
@staticmethod
|
||||
def _run_command(client: Any, command: str) -> bytes:
|
||||
_, stdout, stderr = client.exec_command(command)
|
||||
@classmethod
|
||||
def _run_command(cls, client: Any, command: str) -> bytes:
|
||||
_, stdout, stderr = client.exec_command(command, timeout=cls._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
|
||||
stdout.channel.settimeout(cls._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
|
||||
|
||||
deadline = time.monotonic() + cls._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS
|
||||
while not stdout.channel.exit_status_ready():
|
||||
if time.monotonic() >= deadline:
|
||||
with contextlib.suppress(Exception):
|
||||
stdout.channel.close()
|
||||
raise TimeoutError(f"SSH command timed out after {cls._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS}s")
|
||||
time.sleep(0.05)
|
||||
|
||||
exit_code = stdout.channel.recv_exit_status()
|
||||
stdout_data = stdout.read()
|
||||
stderr_data = stderr.read()
|
||||
@@ -360,13 +386,20 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
channel: Any,
|
||||
stdout_transport: QueueTransportReadCloser,
|
||||
stderr_transport: QueueTransportReadCloser,
|
||||
max_runtime_seconds: int,
|
||||
) -> None:
|
||||
stdout_writer = stdout_transport.get_write_handler()
|
||||
stderr_writer = stderr_transport.get_write_handler()
|
||||
exit_code: int | None = None
|
||||
started_at = time.monotonic()
|
||||
|
||||
try:
|
||||
while True:
|
||||
if time.monotonic() - started_at >= max_runtime_seconds:
|
||||
exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
|
||||
stderr_writer.write(f"Command timed out after {max_runtime_seconds}s".encode())
|
||||
break
|
||||
|
||||
if channel.recv_ready():
|
||||
stdout_writer.write(channel.recv(4096))
|
||||
if channel.recv_stderr_ready():
|
||||
@@ -377,6 +410,9 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
break
|
||||
|
||||
time.sleep(0.05)
|
||||
except TimeoutError:
|
||||
logger.warning("SSH channel read timed out for command %s", pid)
|
||||
exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
stdout_transport.close()
|
||||
@@ -388,6 +424,10 @@ class SSHSandboxEnvironment(VirtualEnvironment):
|
||||
with self._lock:
|
||||
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
|
||||
def _set_sftp_operation_timeout(self, sftp: Any) -> None:
|
||||
with contextlib.suppress(Exception):
|
||||
sftp.get_channel().settimeout(self._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
|
||||
|
||||
@staticmethod
|
||||
def _parse_arch(raw_arch: str) -> Arch:
|
||||
arch = raw_arch.lower()
|
||||
|
||||
Reference in New Issue
Block a user