fix: update SSH worker command and add timeout handling

This commit is contained in:
Harry
2026-02-09 19:52:13 +08:00
parent 9e10b73b54
commit 92e58aa624

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import contextlib
import logging
import shlex
import stat
import threading
@@ -27,6 +28,8 @@ from core.virtual_environment.channel.exec import TransportEOFError
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
from core.virtual_environment.channel.transport import TransportWriteCloser
logger = logging.getLogger(__name__)
class _SSHStdinTransport(TransportWriteCloser):
def __init__(self, channel: Any):
@@ -52,6 +55,10 @@ class SSHSandboxEnvironment(VirtualEnvironment):
_DEFAULT_SSH_HOST = "agentbox"
_DEFAULT_SSH_PORT = 22
_DEFAULT_BASE_WORKING_PATH = "/workspace/sandboxes"
_DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS = 10
_DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS = 30
_DEFAULT_COMMAND_MAX_RUNTIME_SECONDS = 60 * 60
_COMMAND_TIMEOUT_EXIT_CODE = 124
class OptionsKey(StrEnum):
SSH_HOST = "ssh_host"
@@ -141,6 +148,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
raise RuntimeError("SSH transport is not available")
channel = transport.open_session()
channel.settimeout(self._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
channel.set_combine_stderr(False)
execution_command = self._build_exec_command(command, environments, cwd)
@@ -156,7 +164,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
threading.Thread(
target=self._consume_channel_output,
args=(pid, channel, stdout_transport, stderr_transport),
args=(pid, channel, stdout_transport, stderr_transport, self._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS),
daemon=True,
).start()
@@ -174,6 +182,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
with self._client() as client:
sftp = client.open_sftp()
try:
self._set_sftp_operation_timeout(sftp)
self._sftp_mkdirs(sftp, str(PurePosixPath(destination_path).parent))
with sftp.file(destination_path, "wb") as remote_file:
remote_file.write(content.getvalue())
@@ -185,6 +194,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
with self._client() as client:
sftp = client.open_sftp()
try:
self._set_sftp_operation_timeout(sftp)
with sftp.file(source_path, "rb") as remote_file:
return BytesIO(remote_file.read())
finally:
@@ -200,6 +210,7 @@ class SSHSandboxEnvironment(VirtualEnvironment):
with self._client() as client:
sftp = client.open_sftp()
try:
self._set_sftp_operation_timeout(sftp)
pending = [root_directory]
while pending and len(files) < limit:
current_directory = pending.pop(0)
@@ -261,8 +272,13 @@ class SSHSandboxEnvironment(VirtualEnvironment):
password=password,
look_for_keys=False,
allow_agent=False,
timeout=10,
timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS,
banner_timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS,
auth_timeout=cls._DEFAULT_SSH_CONNECT_TIMEOUT_SECONDS,
)
transport = client.get_transport()
if transport is not None:
transport.set_keepalive(30)
except Exception as e:
with contextlib.suppress(Exception):
client.close()
@@ -341,9 +357,19 @@ class SSHSandboxEnvironment(VirtualEnvironment):
command_body += shlex.join(command)
return f"sh -lc {shlex.quote(command_body)}"
@staticmethod
def _run_command(client: Any, command: str) -> bytes:
_, stdout, stderr = client.exec_command(command)
@classmethod
def _run_command(cls, client: Any, command: str) -> bytes:
_, stdout, stderr = client.exec_command(command, timeout=cls._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
stdout.channel.settimeout(cls._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
deadline = time.monotonic() + cls._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS
while not stdout.channel.exit_status_ready():
if time.monotonic() >= deadline:
with contextlib.suppress(Exception):
stdout.channel.close()
raise TimeoutError(f"SSH command timed out after {cls._DEFAULT_COMMAND_MAX_RUNTIME_SECONDS}s")
time.sleep(0.05)
exit_code = stdout.channel.recv_exit_status()
stdout_data = stdout.read()
stderr_data = stderr.read()
@@ -360,13 +386,20 @@ class SSHSandboxEnvironment(VirtualEnvironment):
channel: Any,
stdout_transport: QueueTransportReadCloser,
stderr_transport: QueueTransportReadCloser,
max_runtime_seconds: int,
) -> None:
stdout_writer = stdout_transport.get_write_handler()
stderr_writer = stderr_transport.get_write_handler()
exit_code: int | None = None
started_at = time.monotonic()
try:
while True:
if time.monotonic() - started_at >= max_runtime_seconds:
exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
stderr_writer.write(f"Command timed out after {max_runtime_seconds}s".encode())
break
if channel.recv_ready():
stdout_writer.write(channel.recv(4096))
if channel.recv_stderr_ready():
@@ -377,6 +410,9 @@ class SSHSandboxEnvironment(VirtualEnvironment):
break
time.sleep(0.05)
except TimeoutError:
logger.warning("SSH channel read timed out for command %s", pid)
exit_code = self._COMMAND_TIMEOUT_EXIT_CODE
finally:
with contextlib.suppress(Exception):
stdout_transport.close()
@@ -388,6 +424,10 @@ class SSHSandboxEnvironment(VirtualEnvironment):
with self._lock:
self._commands[pid] = CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
def _set_sftp_operation_timeout(self, sftp: Any) -> None:
    """Apply the default SSH operation timeout to an SFTP session's channel.

    This is best-effort: a failure to set the timeout is never raised to
    the caller, so the SFTP operation that follows still runs. The failure
    is logged, because the timeout may then not be in effect.

    Args:
        sftp: SFTP client exposing ``get_channel()``; the returned channel
            is expected to provide ``settimeout(seconds)``.
    """
    try:
        sftp.get_channel().settimeout(self._DEFAULT_SSH_OPERATION_TIMEOUT_SECONDS)
    except Exception:
        # Deliberately broad: applying the timeout must never break the
        # transfer itself, but a silent failure would hide why a later
        # SFTP call blocks longer than expected.
        logger.warning("Failed to set SFTP operation timeout", exc_info=True)
@staticmethod
def _parse_arch(raw_arch: str) -> Arch:
arch = raw_arch.lower()