1
0
Fork 0
matrix-appservice-minecraft/message_queue.py

210 lines
6.3 KiB
Python

import enum
import errno
import json
import logging
import queue
import select
import socket
import struct
import threading
import time
from typing import Any, Callable, Dict, Optional, TypeVar
# Module-wide logger; the worker thread derives a child logger from it.
LOG: logging.Logger = logging.getLogger(__name__)


class ProtocolError(Exception):
    """Raised when a peer violates the SocketWrapper wire protocol."""

    pass


# Generic result type of the callable retried by _try_with_backoff.
T = TypeVar("T")
def _try_with_backoff(
    fn: Callable[[], T],
    error_callback: Callable[[Exception], bool],
    *,
    initial_backoff: float = 1,
    max_backoff: Optional[float] = None,
) -> T:
    """Call ``fn`` until it succeeds, sleeping with exponential backoff.

    Args:
        fn: Zero-argument callable to attempt.
        error_callback: Receives each exception raised by ``fn``. Return a
            truthy value to log it, sleep and retry, or a falsy value to
            re-raise it immediately.
        initial_backoff: Seconds to sleep after the first retryable failure.
            The delay doubles after every further retryable failure.
        max_backoff: Optional upper bound on a single sleep. ``None`` (the
            default) means the delay grows without limit.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        The exception raised by ``fn`` when ``error_callback`` rejects it.
    """
    backoff = initial_backoff
    while True:
        try:
            return fn()
        except Exception as e:
            if not error_callback(e):
                # Bare raise keeps the original traceback untouched.
                raise
            delay = backoff if max_backoff is None else min(backoff, max_backoff)
            LOG.exception(e)
            LOG.warning("Trying again in {} seconds".format(delay))
            time.sleep(delay)
            backoff *= 2
def socket_create_server(addr, family=socket.AF_INET):
    """Create a TCP socket with SO_REUSEADDR set and bind it to ``addr``.

    Stand-in for ``socket.create_server``, which only exists from Python 3.8.
    The returned socket is bound but not listening; callers call ``listen()``.

    Args:
        addr: Address accepted by ``socket.bind`` for ``family``, e.g.
            ``(host, port)`` for the default AF_INET.
        family: Address family of the new socket (default AF_INET).

    Returns:
        The bound socket.

    Raises:
        Whatever ``setsockopt``/``bind`` raise (e.g. OSError). The socket is
        closed before re-raising, so callers that retry (such as a bind retry
        loop on EADDRINUSE) do not leak a file descriptor per attempt.
    """
    soc = socket.socket(family, socket.SOCK_STREAM)
    try:
        soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        soc.bind(addr)
    except BaseException:
        soc.close()
        raise
    return soc
class SocketWrapper:
    """Wraps a network socket with simpler connection, send, and receive logic.

    Both ends of the connection must be using a SocketWrapper compatible
    protocol. The protocol is as follows:

    Messages are utf-8 encoded strings prefixed by their length as a 4-byte
    big-endian signed integer. Example: 00 00 00 06 h e l l o \\n.

    Call connect() (client side) or accept() (server side) before send() or
    receive(), and close() when finished.
    """

    # struct format of the length prefix that precedes every payload.
    _LENGTH_FORMAT = ">i"

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        # The connected socket. bind() stores the listening socket here until
        # accept() replaces it. None before connecting and after close().
        self.soc: Optional[socket.socket] = None

    def connect(self):
        """Connect to (host, port), retrying with backoff while refused."""
        LOG.debug("Connecting to {}:{}".format(self.host, self.port))
        self.soc = _try_with_backoff(
            lambda: socket.create_connection((self.host, self.port)),
            # ECONNREFUSED: nothing is listening yet, so keep retrying.
            lambda e: isinstance(e, OSError) and e.errno == errno.ECONNREFUSED,
        )
        LOG.info("Socket Connected")

    def bind(self):
        """Bind a server socket to (host, port).

        Retries with backoff while the address is still in use.
        """
        LOG.info("Server Binding to {}:{}".format(self.host, self.port))
        self.soc = _try_with_backoff(
            lambda: socket_create_server((self.host, self.port)),
            lambda e: isinstance(e, OSError) and e.errno == errno.EADDRINUSE,
        )
        LOG.info("Server Bound")

    def accept(self):
        """Bind, wait for a single client, and switch to that client's socket.

        The listening socket is closed once a client connects, or if listening
        or accepting fails, so only one peer is served per accept() call.
        """
        self.bind()
        listener, self.soc = self.soc, None
        try:
            listener.listen(1)
            LOG.info("Server listen to host")
            client_soc, addr = listener.accept()
        finally:
            listener.close()
        self.soc = client_soc
        LOG.info("Server accepted connection: {}".format(addr))

    def _write_int(self, integer: int):
        """Write ``integer`` as a length prefix."""
        self._write(struct.pack(self._LENGTH_FORMAT, integer))

    def _write(self, data: bytes):
        """Write all of ``data`` to the socket."""
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before writing")
        # sendall keeps sending until every byte has been written.
        self.soc.sendall(data)

    def send(self, message: str, timeout: float = 1):
        """Send ``message`` as one length-prefixed utf-8 frame.

        Raises:
            RuntimeError: if not connected, or if the socket does not become
                writable within ``timeout`` seconds.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before writing")
        _, writable, _ = select.select([], [self.soc], [], timeout)
        if not writable:
            # TODO Try to reacquire
            raise RuntimeError("Unable to write to socket")
        payload = message.encode("utf-8")
        # Header and payload go out in one call, so this side never leaves a
        # frame half-written between two separate writes.
        self._write(struct.pack(self._LENGTH_FORMAT, len(payload)) + payload)
        LOG.debug("sent {} bytes".format(len(payload)))

    def _read(self, size: int) -> bytes:
        """Read ``size`` bytes. Return fewer only if the peer shut down."""
        buf = bytearray()
        while len(buf) < size:
            chunk = self.soc.recv(size - len(buf))
            if not chunk:
                # Orderly shutdown by the peer.
                break
            buf += chunk
        return bytes(buf)

    def _read_int(self) -> Optional[int]:
        """Read one length prefix.

        Returns:
            The decoded integer, or None if the peer closed the connection
            before sending any byte of the prefix.

        Raises:
            ProtocolError: if the connection closed part-way through the prefix.
        """
        int_size = struct.calcsize(self._LENGTH_FORMAT)
        intbuf = self._read(int_size)
        if not intbuf:
            return None
        if len(intbuf) != int_size:
            raise ProtocolError(
                "Connection closed inside a length prefix ({} of {} bytes)".format(
                    len(intbuf), int_size
                )
            )
        return struct.unpack(self._LENGTH_FORMAT, intbuf)[0]

    def receive(self, timeout: float = 1) -> Optional[str]:
        """Receive one message.

        Returns:
            The decoded message, or None if nothing became readable within
            ``timeout`` seconds.

        Raises:
            RuntimeError: if not connected.
            ConnectionError: if the peer closed the connection.
            ProtocolError: if the frame is truncated, has a negative length,
                or is not valid utf-8.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before reading")
        readable, _, _ = select.select([self.soc], [], [], timeout)
        if not readable:
            return None
        message_size = self._read_int()
        if message_size is None:
            # A readable socket that yields no data means the peer shut down.
            # Raise rather than return None. Otherwise the caller cannot tell
            # this apart from a timeout and would poll a dead socket forever.
            raise ConnectionError("Connection closed by peer")
        if message_size < 0:
            raise ProtocolError("Negative message length: {}".format(message_size))
        data = self._read(message_size)
        if len(data) != message_size:
            raise ProtocolError(
                "Connection closed inside a message ({} of {} bytes)".format(
                    len(data), message_size
                )
            )
        try:
            return data.decode("utf-8")
        except UnicodeDecodeError as e:
            raise ProtocolError("Message is not valid utf-8") from e

    def close(self):
        """Close the current socket, if any. Safe to call repeatedly."""
        soc, self.soc = self.soc, None
        if soc is not None:
            soc.close()
class Side(enum.Enum):
    """Role a MessageQueue plays on its connection.

    SERVER binds and accepts a peer; any other side connects out to the peer.
    """

    # Explicit values match what enum.auto() assigns (starting at 1).
    CLIENT = 1
    SERVER = 2
class MessageQueue:
    """A bidirectional queue of JSON messages over a network socket.

    A daemon worker thread running ``process_messages`` sends and receives
    messages asynchronously until closed. Enqueue outgoing messages with
    add(). Consume incoming messages by iterating; iteration stops once
    close() has been called and every already-received message is consumed.
    """

    # Seconds __next__ blocks on the inbox before re-checking self.closed.
    _POLL_INTERVAL = 0.5

    def __init__(self, host: str, port: int, side: Side):
        # Raw JSON strings received from the peer, awaiting decoding.
        self.inbox: queue.Queue[str] = queue.Queue()
        # JSON strings awaiting delivery to the peer by the worker thread.
        self.outbox: queue.Queue[str] = queue.Queue()
        # Set by close(); polled by the worker thread and by __next__.
        self.closed = False
        self.process_worker = threading.Thread(
            target=process_messages,
            args=(side, SocketWrapper(host, port), self),
            daemon=True,
            name="MessageQueue/" + str(side),
        )
        self.process_worker.start()

    def add(self, message: Dict[str, Any]):
        """Serialize ``message`` to JSON and queue it for sending."""
        self.outbox.put(json.dumps(message))

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        """Return the next received message, decoded from JSON.

        Blocks until a message arrives. Raises StopIteration once the queue
        is closed and the inbox is empty, so a ``for`` loop over the queue
        ends instead of hanging forever after close().
        """
        while True:
            if self.closed and self.inbox.empty():
                raise StopIteration
            try:
                m = self.inbox.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                continue
            LOG.debug(m)
            return json.loads(m)

    def close(self, timeout: Optional[float] = None):
        """Ask the worker thread to stop and wait for it.

        Args:
            timeout: Maximum seconds to wait for the worker. None (default)
                waits indefinitely. A finite timeout can be useful because
                the worker may be blocked waiting to accept or connect.
        """
        self.closed = True
        self.process_worker.join(timeout)
def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue):
    """Worker loop that moves messages between a MessageQueue and a socket.

    Runs until ``queue.closed`` is set. Each pass of the outer loop does three
    things:

    1. Establish a connection: accept one client when ``side`` is
       Side.SERVER, otherwise connect.
    2. Alternate between polling the socket for one incoming message (put on
       ``queue.inbox``) and flushing ``queue.outbox`` to the socket.
    3. On any exception, log it, close the socket, and reconnect.

    NOTE: the ``socket`` and ``queue`` parameters shadow the modules of the
    same name inside this function. They are kept for compatibility with any
    keyword callers.
    """
    log = LOG.getChild("worker.{}".format(side))
    # Message whose send failed. It is retried first after reconnecting, so
    # outgoing messages keep their original order.
    unsent: Optional[str] = None
    while not queue.closed:
        try:
            if side == Side.SERVER:
                socket.accept()
            else:
                socket.connect()
            while not queue.closed:
                message = socket.receive()
                if message is not None:
                    queue.inbox.put(message)
                while unsent is not None or not queue.outbox.empty():
                    if unsent is None:
                        # Single consumer, so a non-empty outbox cannot be
                        # drained by anyone else between empty() and get().
                        unsent = queue.outbox.get(block=False)
                    log.debug("Sending outbox item")
                    socket.send(unsent)
                    unsent = None
        except Exception as e:
            log.exception(e)
        finally:
            socket.close()