import asyncio
import base64
import enum
import inspect
import io
import logging
import os
import pickle
import re
import struct
import subprocess
import sys
import textwrap
import threading
from collections.abc import Callable, Coroutine
from functools import cache
from lzma import compress, decompress
from pathlib import Path
from types import FunctionType
from typing import Any, ParamSpec, Self, TypedDict, TypeVar, cast, overload
[docs]
class Template:
"""A Mako-like template pre-compiled to a reusable render function.
Picklable - the instance stores only the original template string, so it
can be passed as an argument to remote tool calls over the protocol. The
class lives in ``protocol.py`` and is therefore available on the remote
side without any extra sync step.
Usage::
tmpl = Template("Hello, ${name}!")
tmpl.render(name="Alice") # local
await protocol(MyTool.run, tmpl) # pass to remote tool
Syntax (Mako-like):
* ``${expr}`` - evaluate *expr* and insert the string result;
nested braces are handled correctly (e.g. ``${{'k': 1}['k']}``)
* ``\\${`` - literal ``${`` (escape, no interpolation)
* ``% stmt`` - Python control-flow line (for / if / while / …)
* ``% endfor`` / ``% endif`` / ``% end`` - block terminators
* ``%%`` - literal ``%`` at the start of an output line
* ``## comment`` - ignored
"""
BLOCK_OPEN = frozenset({"for", "if", "while", "with", "try", "def", "class"})
BLOCK_CONT = frozenset({"else", "elif", "except", "finally"})
BLOCK_END = frozenset({"endfor", "endif", "endwhile", "endwith", "end"})
[docs]
def __init__(self, template: str) -> None:
self._template = template
self._fn: Callable[..., str] = Template.compile(template)
@staticmethod
def _split_exprs(line: str) -> list[tuple[bool, str]]:
"""Split *line* into ``(is_expr, fragment)`` pairs.
Handles ``\\${`` escape (→ literal ``${``), bare ``$`` (literal),
and nested ``{}`` inside expressions via brace-depth counting.
"""
result: list[tuple[bool, str]] = []
buf: list[str] = []
i = 0
n = len(line)
while i < n:
# \${ → literal ${
if line[i] == "\\" and line[i + 1 : i + 3] == "${":
buf.append("${")
i += 3
# ${ → start of expression
elif line[i] == "$" and i + 1 < n and line[i + 1] == "{":
if buf:
result.append((False, "".join(buf)))
buf = []
i += 2 # consume '${'
depth = 1
expr: list[str] = []
while i < n and depth > 0:
ch = line[i]
if ch == "{":
depth += 1
expr.append(ch)
elif ch == "}":
depth -= 1
if depth > 0:
expr.append(ch)
else:
expr.append(ch)
i += 1
result.append((True, "".join(expr)))
else:
buf.append(line[i])
i += 1
if buf:
result.append((False, "".join(buf)))
return result
[docs]
@staticmethod
@cache
def compile(template: str) -> "Callable[..., str]":
"""Compile a Mako-like *template* string into a reusable render function.
Returns a callable that accepts ``**ctx`` keyword arguments and returns
the rendered string. Results are cached so repeated
calls with the same template string are free.
"""
indent_level = 0
indent_unit = " "
lines_out = ["_out_ = []", "_pending_nl_ = False"]
def cur_indent() -> str:
return indent_unit * indent_level
def emit_text_line(raw_line: str) -> None:
lines_out.append(f"{cur_indent()}if _pending_nl_: _out_.append('\\n')")
for is_expr, text in Template._split_exprs(raw_line):
if is_expr:
lines_out.append(f"{cur_indent()}_out_.append(str({text}))")
elif text:
lines_out.append(f"{cur_indent()}_out_.append({text!r})")
lines_out.append(f"{cur_indent()}_pending_nl_ = True")
for raw_line in template.splitlines():
stripped = raw_line.strip()
if stripped.startswith("##"):
continue
if stripped.startswith("%%"):
# %% → literal % line (still processes ${} expressions)
idx = raw_line.index("%%")
emit_text_line(raw_line[:idx] + "%" + raw_line[idx + 2 :])
continue
if stripped.startswith("%"):
code = stripped[1:].strip()
if not code:
indent_level = max(0, indent_level - 1)
continue
keyword = code.split()[0].rstrip(":(")
if keyword in Template.BLOCK_END:
indent_level = max(0, indent_level - 1)
elif keyword in Template.BLOCK_CONT:
indent_level = max(0, indent_level - 1)
py_line = code if code.endswith(":") else code + ":"
lines_out.append(f"{cur_indent()}{py_line}")
indent_level += 1
elif keyword in Template.BLOCK_OPEN:
py_line = code if code.endswith(":") else code + ":"
lines_out.append(f"{cur_indent()}{py_line}")
indent_level += 1
else:
lines_out.append(f"{cur_indent()}{code}")
continue
emit_text_line(raw_line)
lines_out.append("_result_ = ''.join(_out_)")
source = "\n".join(lines_out)
code_obj = compile(source, "<template>", "exec") # noqa: PLC0415
def _render(**ctx: object) -> str:
ns: dict[str, object] = dict(ctx)
exec(code_obj, ns)
return str(ns["_result_"])
return _render
[docs]
def render(self, **ctx: object) -> str:
"""Render this template with *ctx* as the variable namespace."""
return self._fn(**ctx)
[docs]
def __reduce__(self) -> tuple[type, tuple[str]]:
return (Template, (self._template,))
def __repr__(self) -> str:
preview = self._template[:40].replace("\n", "\\n")
return f"Template({preview!r})"
[docs]
def render_template(template: str, **ctx: object) -> str:
"""Compile and render a Mako-like *template* string in one step."""
return Template(template).render(**ctx)
class RPCRequest(TypedDict):
method: str
args: Any # Can be tuple[Any, ...] or P.args
kwargs: dict[str, Any]
class LogRecord(TypedDict):
name: str
levelno: int
levelname: str
pathname: str
lineno: int
msg: str
args: Any
exc_info: Any
[docs]
def bootstrap_packer(code: bytes) -> bytes:
with io.BytesIO() as output:
output.write(b"from lzma import decompress\n")
output.write(b"from base64 import b64decode\n")
output.write(b"\n")
output.write(b"exec(decompress(b64decode('''")
output.write(base64.b64encode(compress(code)))
output.write(b"''')))\n")
return output.getvalue()
[docs]
def process(
*cmd_and_args: str,
stdin: None | bytes | str = None,
capture_output: bool = False,
text: bool = False,
env: dict[str, str] | None = None,
shell: bool = False,
check: bool = False,
cwd: None | str | Path = None,
) -> subprocess.CompletedProcess[Any]:
"""
function for execute a subprocess on the remote side,
must be safe, do not share stdout/stderr of the child process,
because it's protocol pipes.
Tools must be use only this function for execute subprocesses,
to avoid conflicts with protocol communication.
"""
logging.debug("Executing subprocess: %r", cmd_and_args)
kwargs: dict[str, Any] = {
"stdin": subprocess.DEVNULL,
"capture_output": capture_output,
"text": text,
"env": env,
"shell": shell,
"check": check,
"cwd": cwd,
}
if stdin is not None:
if isinstance(stdin, str):
stdin = stdin.encode()
# Use input= (not stdin=) so subprocess uses PIPE internally;
# remove stdin=DEVNULL to avoid the "stdin and input may not both be used" error.
del kwargs["stdin"]
kwargs["input"] = stdin
if not capture_output:
kwargs["stdout"] = subprocess.DEVNULL
kwargs["stderr"] = subprocess.DEVNULL
return subprocess.run(cmd_and_args, **kwargs)
# Strip rmote imports except rmote.protocol, which is registered in sys.modules on the remote
_RMOTE_IMPORT_RE = re.compile(r"^(?:from|import) rmote(?!\.protocol\b)[^\n]*\n", re.MULTILINE)
class ToolMeta(type):
def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> type:
cls = super().__new__(mcs, name, bases, namespace)
if "__init__" in namespace:
raise TypeError("__init__ cannot be defined in a Tool")
# __source__ kept for inline-tool fallback (qualname contains <locals>)
source: str | None = None
try:
source = textwrap.dedent(inspect.getsource(cls))
except (OSError, TypeError):
pass
cls.__source__ = source # type: ignore[attr-defined]
cls.__class_name__ = name # type: ignore[attr-defined]
cls.__base_names__ = [b.__name__ for b in bases if b is not object] # type: ignore[attr-defined]
for attr_name in namespace:
val = cls.__dict__.get(attr_name)
if isinstance(val, FunctionType):
val.__tool_class__ = cls # type: ignore[attr-defined]
elif isinstance(val, (staticmethod, classmethod)):
val.__func__.__tool_class__ = cls # type: ignore[union-attr]
elif isinstance(val, property):
for accessor in ("fget", "fset", "fdel"):
f = getattr(val, accessor, None)
if f:
f.__tool_class__ = cls
return cls
[docs]
class Flags(enum.IntFlag):
COMPRESSED = 1
REQUEST = 1 << 1
RESPONSE = 1 << 2
SYNC = 1 << 3
RPC = 1 << 4
EXCEPTION = 1 << 5
LOG = 1 << 6
[docs]
class BaseProtocol:
# 4 byte magic RMOTE, uint32 payload length
MAGIC = b"RMOTE"
BOUNDARY = b"PROTOCOL READY\n"
PACKET_HEADER = struct.Struct(">5sIIQ")
COMPRESSION_THRESHOLD = 1024
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self.reader = reader
self.writer = writer
self.read_lock = asyncio.Lock()
self.write_lock = asyncio.Lock()
[docs]
async def receive(self) -> tuple[Any, Flags, int]:
async with self.read_lock:
header = await self.reader.readexactly(self.PACKET_HEADER.size)
magic, flags, length, packet_id = self.PACKET_HEADER.unpack(header)
flags = Flags(flags)
if magic != self.MAGIC:
raise ValueError("Invalid magic number")
payload = await self.reader.readexactly(length)
if flags & Flags.COMPRESSED:
payload = await asyncio.to_thread(decompress, payload)
return pickle.loads(payload), flags, packet_id
[docs]
async def send(self, packet: Any, flags: Flags, packet_id: int) -> None:
flags = Flags(flags)
if flags & Flags.COMPRESSED:
raise ValueError("Compression flag must not be set for send()")
async with self.write_lock:
payload = pickle.dumps(packet)
if len(payload) > self.COMPRESSION_THRESHOLD:
payload = await asyncio.to_thread(compress, payload)
flags |= Flags.COMPRESSED
header = self.PACKET_HEADER.pack(self.MAGIC, flags, len(payload), packet_id)
self.writer.write(header + payload)
await self.writer.drain()
[docs]
async def write_boundary(self) -> None:
async with self.write_lock:
self.writer.write(self.BOUNDARY)
await self.writer.drain()
[docs]
async def read_boundary(self) -> None:
async with self.read_lock:
while True:
try:
line = await self.reader.readline()
except (asyncio.LimitOverrunError, ValueError):
# Line exceeded the stream reader limit - readline() consumed
# the oversized chunk and re-raised as ValueError (3.14+) or
# LimitOverrunError (older). Either way, skip and keep scanning.
continue
if line == self.BOUNDARY:
return
if not line:
raise ConnectionError("Remote process closed the connection before PROTOCOL READY")
[docs]
@classmethod
async def from_subprocess(cls, process: asyncio.subprocess.Process) -> Self:
assert process.stdin is not None, "Process stdin must not be None"
assert process.stdout is not None, "Process stdout must not be None"
process.stdin.write(bootstrap_packer(open(__file__, "rb").read()))
process.stdin.write(b"asyncio.run(run())\n")
return cls(reader=process.stdout, writer=process.stdin)
[docs]
@classmethod
async def from_ssh(
cls,
host: str,
*,
user: str | None = None,
port: int | None = None,
identity: str | None = None,
python: str = "python3",
ssh_options: list[str] | None = None,
stderr: int = asyncio.subprocess.DEVNULL,
) -> Self:
"""Bootstrap a Protocol over SSH.
Args:
host: Remote host, optionally in ``user@host`` form.
user: Remote username (``-l``). Overrides any user embedded in *host*.
port: SSH port (``-p``).
identity: Path to an SSH identity file (``-i``).
python: Python executable on the remote host.
ssh_options: Extra arguments inserted before the host in the ``ssh`` command
(e.g. ``["-o", "StrictHostKeyChecking=no"]``).
stderr: Where to redirect remote stderr. Defaults to ``DEVNULL``.
"""
cmd: list[str] = ["ssh", "-T"]
if user is not None:
cmd += ["-l", user]
if port is not None:
cmd += ["-p", str(port)]
if identity is not None:
cmd += ["-i", identity]
if ssh_options:
cmd += ssh_options
cmd += [host, python, "-qui"]
proc = await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=stderr,
)
instance = await cls.from_subprocess(proc)
instance._owned_process = proc # type: ignore[attr-defined]
return instance
[docs]
@classmethod
async def from_stdio(cls, logging_level: int = logging.INFO) -> Self:
loop = asyncio.get_running_loop()
# Protect the protocol channel from accidental corruption.
#
# The remote process talks to the local side over its own fd 0 (stdin)
# and fd 1 (stdout). Any child process that inherits those descriptors,
# or any code that calls print() / os.write(1, ...) / subprocess.run()
# with default stdio, can silently inject bytes into the binary packet
# stream and break the connection permanently.
#
# Fix: dup stdin/stdout to new fds, redirect 0/1/2 to /dev/null, then
# hand the duped fds to asyncio. The new fds are marked close-on-exec
# so they are never passed to child processes at all. From this point
# on, the protocol pipe is completely unreachable from user code and
# from any subprocess spawned by tool methods.
proto_in_fd = os.dup(sys.stdin.fileno())
proto_out_fd = os.dup(sys.stdout.fileno())
os.set_inheritable(proto_in_fd, False)
os.set_inheritable(proto_out_fd, False)
devnull_r = os.open(os.devnull, os.O_RDONLY)
devnull_w = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull_r, 0)
os.dup2(devnull_w, 1)
os.dup2(devnull_w, 2)
os.close(devnull_r)
os.close(devnull_w)
proto_in = os.fdopen(proto_in_fd, "rb", buffering=0)
proto_out = os.fdopen(proto_out_fd, "wb", buffering=0)
async def wrap_reader() -> asyncio.StreamReader:
reader = asyncio.StreamReader()
proto = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: proto, proto_in)
return reader
async def wrap_writer() -> asyncio.StreamWriter:
transport, proto = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, proto_out)
writer = asyncio.StreamWriter(transport, proto, None, loop)
return writer
protocol = cls(await wrap_reader(), await wrap_writer())
log_handler = RemoteLogHandler(protocol, loop) # type: ignore[arg-type]
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.addHandler(log_handler)
root_logger.setLevel(logging_level)
return protocol
P = ParamSpec("P")
R = TypeVar("R")
[docs]
class Protocol(BaseProtocol):
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
super().__init__(reader, writer)
self._tools_cache: set[type[Tool]] = set()
self.futures: dict[int, asyncio.Future[Any]] = {}
self.loop = asyncio.get_running_loop()
self._tools_cache = set()
self._loop_task: asyncio.Task[None] | None = None
self._closed = asyncio.Event()
self._tasks: set[asyncio.Task[Any]] = set()
self.tools: dict[str, Tool] = dict()
self._owned_process: asyncio.subprocess.Process | None = None
# Sync ID generation for RPC/SYNC packets
self.__last_id = 0
self.__last_id_lock = threading.Lock()
# Separate ID generation for LOG packets (use negative IDs to avoid conflicts)
self.__last_log_id = 0
self.__last_log_id_lock = threading.Lock()
[docs]
def get_id(self) -> int:
with self.__last_id_lock:
self.__last_id += 1
return self.__last_id
[docs]
def get_log_id(self) -> int:
with self.__last_log_id_lock:
self.__last_log_id -= 1
return self.__last_log_id
[docs]
async def __aenter__(self) -> Self:
await self.write_boundary()
await self.read_boundary()
self._loop_task = asyncio.create_task(self._loop())
return self
[docs]
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self._closed.set()
if self._loop_task:
self._loop_task.cancel()
for task in self._tasks:
task.cancel()
await asyncio.gather(self._loop_task, *self._tasks, return_exceptions=True)
self._loop_task = None
self.writer.close()
try:
await self.writer.wait_closed()
except Exception:
pass
if self._owned_process is not None:
try:
self._owned_process.terminate()
except ProcessLookupError:
pass
await self._owned_process.wait()
self._owned_process = None
def _load_tool(self, tool_definition: dict[str, Any], _: int) -> None:
tool_cls = tool_from_dict(tool_definition)
self.tools[tool_cls.__name__] = tool_cls()
logging.debug("Loaded tool %s", tool_definition)
async def _handle_rpc_request(self, request: RPCRequest, _: int) -> Any:
if "." not in request["method"]:
raise ValueError("Invalid method name")
tool_name, method_name = request["method"].split(".", 1)
tool = self.tools.get(tool_name)
if tool is None:
raise ValueError(f"Tool {tool_name} not found")
method = getattr(tool, method_name, None)
if method is None:
raise ValueError(f"Method {method_name} not found in tool {tool_name}")
if not callable(method):
raise ValueError(f"{method_name} is not callable in tool {tool_name}")
if inspect.iscoroutinefunction(method):
return await method(*request["args"], **request["kwargs"])
return await asyncio.to_thread(method, *request["args"], **request["kwargs"])
async def _handle_rpc_response(self, response: Any, packet_id: int) -> None:
if packet_id not in self.futures:
logging.warning("RPC response %r packet not found in futures", packet_id)
return
future = self.futures.pop(packet_id)
future.set_result(response)
async def _handle_exception(self, exception: Exception, packet_id: int) -> None:
if packet_id not in self.futures:
logging.warning("Exception response %r packet not found in futures: %s", packet_id, exception)
return
future = self.futures.pop(packet_id)
future.set_exception(exception)
@staticmethod
async def _handle_log(record: LogRecord, _: int) -> None:
logger = logging.getLogger(f"rmote.remote.{record['name']}")
log_record = logging.LogRecord(
name=record["name"],
level=record["levelno"],
pathname=record["pathname"],
lineno=record["lineno"],
msg=record["msg"],
args=record["args"],
exc_info=record["exc_info"],
)
logger.handle(log_record)
def _execute(
self, packet_id: int, flags: Flags, handler: Callable[..., Any], payload: Any, need_response: bool = False
) -> None:
async def wrapper() -> None:
nonlocal flags
coro: asyncio.Task[Any] | asyncio.Task[Any]
if inspect.iscoroutinefunction(handler):
coro = asyncio.create_task(handler(payload, packet_id))
else:
coro = asyncio.create_task(asyncio.to_thread(handler, payload, packet_id))
try:
resp = await coro
flags |= Flags.RESPONSE
except Exception as e:
resp = e
flags = Flags.EXCEPTION | Flags.RESPONSE
if need_response:
await self.send(resp, flags, packet_id)
task = asyncio.create_task(wrapper())
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
[docs]
async def wait_closed(self) -> None:
await self._closed.wait()
async def _loop(self) -> None:
error: Exception | None = None
try:
while not self._closed.is_set():
payload, flags, packet_id = await self.receive()
logging.debug("Received packet %d with flags %r: %r", packet_id, flags, payload)
need_response = bool(flags & Flags.REQUEST)
if flags & Flags.SYNC and flags & Flags.REQUEST:
self._execute(packet_id, Flags.SYNC, self._load_tool, payload, need_response)
elif flags & Flags.RPC and flags & Flags.REQUEST:
self._execute(
packet_id,
Flags.RPC,
self._handle_rpc_request,
RPCRequest(**payload), # type: ignore[typeddict-item]
need_response,
)
elif (flags & Flags.RPC and flags & Flags.RESPONSE) or (flags & Flags.SYNC and flags & Flags.RESPONSE):
self._execute(packet_id, Flags.RPC, self._handle_rpc_response, payload, need_response)
elif flags & Flags.EXCEPTION and flags & Flags.RESPONSE:
self._execute(packet_id, Flags.EXCEPTION, self._handle_exception, payload, need_response)
elif flags & Flags.LOG:
self._execute(packet_id, Flags.LOG, self._handle_log, LogRecord(**payload), False) # type: ignore[typeddict-item]
except Exception as e:
error = e
finally:
self._closed.set()
exc = error or ConnectionError("Remote process closed the connection")
for future in self.futures.values():
if not future.done():
future.set_exception(exc)
self.futures.clear()
async def _call(self, payload: Any, flags: Flags) -> Any:
logging.debug("Call %r %s", flags, payload)
packet_id = self.get_id()
future = self.loop.create_future()
self.futures[packet_id] = future
await self.send(payload, flags, packet_id)
return await future
@overload
async def __call__(self, tool: Callable[P, Coroutine[Any, Any, R]], *args: P.args, **kwargs: P.kwargs) -> R: ...
@overload
async def __call__(self, tool: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: ...
[docs]
async def __call__(self, tool: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
tool_class = getattr(tool, "__tool_class__", None)
# For classmethods/staticmethods, check __func__ if __tool_class__ not found on the method itself
if tool_class is None and hasattr(tool, "__func__"):
tool_class = getattr(tool.__func__, "__tool_class__", None)
if tool_class is None:
raise ValueError("Only methods of Tool classes can be called with call_tool()")
if tool_class not in self._tools_cache:
await self._call(tool_to_dict(tool_class), Flags.SYNC | Flags.REQUEST)
self._tools_cache.add(tool_class)
# Use ClassName.method_name rather than __qualname__ so that inline tools
# (whose qualname includes the enclosing scope, e.g. "fn.<locals>.T.m")
# are dispatched correctly on the remote side.
method_id = f"{tool_class.__name__}.{tool.__name__}"
result: Any = await self._call(
RPCRequest(method=method_id, args=args, kwargs=kwargs), Flags.RPC | Flags.REQUEST
)
return result
class RemoteLogHandler(logging.Handler):
def __init__(self, protocol: Protocol, loop: asyncio.AbstractEventLoop, level: int = logging.NOTSET) -> None:
super().__init__(level)
self.protocol = protocol
self.loop = loop
def emit(self, record: logging.LogRecord) -> None:
record_dict = LogRecord(
name=record.name,
levelno=record.levelno,
levelname=record.levelname,
pathname=record.pathname,
lineno=record.lineno,
msg=record.msg,
args=record.args,
exc_info=record.exc_info,
)
# LOG packets use negative IDs to avoid conflicts with RPC packet_ids
self.loop.create_task(self.protocol.send(record_dict, Flags.LOG, self.protocol.get_log_id()))
async def run() -> None:
"""remote endpoint entry point, do not call directly"""
import types as _types
# Register this module as rmote.protocol so Tool files can do
# `from rmote.protocol import Tool, process` on the remote side.
_mod = _types.ModuleType("rmote.protocol")
_mod.__dict__.update(globals())
sys.modules.setdefault("rmote", _types.ModuleType("rmote"))
sys.modules["rmote.protocol"] = _mod
proto = await Protocol.from_stdio()
async with proto:
await proto.wait_closed()