1
0
Fork 0
matrix-appservice-minecraft/message_queue.py

210 lines
6.3 KiB
Python
Raw Normal View History

2020-04-02 08:33:34 +02:00
import enum
import errno
import json
import logging
import queue
import select
import socket
import struct
import threading
import time
from typing import Any, Callable, Dict, Optional, TypeVar
# Module-level logger, named after this module via __name__.
LOG = logging.getLogger(__name__)
class ProtocolError(Exception):
    """Signals that the SocketWrapper wire protocol was violated."""
# Return type of the callable retried by _try_with_backoff.
T = TypeVar("T")


def _try_with_backoff(
    fn: Callable[[], T],
    error_callback: Callable[[Exception], bool],
    max_backoff: float = 64,
) -> T:
    """Call ``fn`` until it succeeds, sleeping between retryable failures.

    Args:
        fn: Zero-argument callable to invoke.
        error_callback: Called with the exception raised by ``fn``. A truthy
            result means "log it, wait, and try again"; a falsy result means
            the exception is re-raised to the caller.
        max_backoff: Upper bound, in seconds, for the delay between attempts.
            The delay starts at 1 second and doubles after every failure
            until it reaches this bound.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        Exception: Any exception from ``fn`` that ``error_callback`` rejects.
    """
    backoff = min(1, max_backoff)
    while True:
        try:
            return fn()
        except Exception as e:
            if not error_callback(e):
                # Not retryable: a bare raise keeps the original traceback.
                raise
            LOG.exception(e)
            LOG.warning("Trying again in %s seconds", backoff)
            time.sleep(backoff)
            # Cap the delay so a long outage does not produce hour-long sleeps.
            backoff = min(backoff * 2, max_backoff)
def socket_create_server(addr):
    """Create a TCP/IPv4 socket bound to ``addr`` with SO_REUSEADDR set.

    The returned socket is bound but not yet listening; the caller is
    responsible for calling ``listen()`` and for closing it.

    Args:
        addr: ``(host, port)`` tuple passed to ``socket.bind``.

    Returns:
        The bound socket.

    Raises:
        OSError: If the socket options cannot be set or the bind fails. The
            socket is closed before the error propagates so that repeated
            bind attempts do not leak file descriptors.
    """
    # socket.create_server doesn't exist until python 3.8 :(
    soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        soc.bind(addr)
    except Exception:
        soc.close()
        raise
    return soc
class SocketWrapper:
    """Wraps a network socket with simpler connection, send, and receive logic.

    Both ends of the connection must be using a SocketWrapper compatible
    protocol. The protocol is as follows:

    Messages are utf-8 encoded strings prefixed by their length as a 4-byte
    big-endian integer. Example: 00 00 00 06 h e l l o \\n.
    """

    # Seconds select() waits for the socket to become readable or writable.
    POLL_TIMEOUT = 1
    # struct format of the length prefix: 4-byte big-endian signed integer.
    _LENGTH_FORMAT = ">i"

    def __init__(self, host: str, port: int):
        """Remember the endpoint; no socket is opened until connect/accept."""
        self.host = host
        self.port = port
        self.soc: Optional[socket.socket] = None

    def connect(self):
        """Connect to ``host:port`` as a client, retrying while refused."""
        LOG.debug("Connecting to %s:%s", self.host, self.port)
        self.soc = _try_with_backoff(
            lambda: socket.create_connection((self.host, self.port)),
            # ConnectionRefusedError is the portable form of ECONNREFUSED
            # (the raw errno is 111 only on Linux).
            lambda e: isinstance(e, ConnectionRefusedError),
        )
        LOG.info("Socket Connected")

    def bind(self):
        """Bind a server socket to ``host:port``, retrying while in use."""
        LOG.info("Server Binding to %s:%s", self.host, self.port)
        self.soc = _try_with_backoff(
            lambda: socket_create_server((self.host, self.port)),
            # errno.EADDRINUSE instead of the Linux-only literal 98.
            lambda e: isinstance(e, OSError) and e.errno == errno.EADDRINUSE,
        )
        LOG.info("Server Bound")

    def accept(self):
        """Bind, wait for a single client, and switch to its connection.

        The listening socket is closed as soon as one client has connected;
        afterwards ``self.soc`` is the accepted client socket.
        """
        self.bind()
        listener = self.soc
        listener.listen(1)
        LOG.info("Server listen to host")
        client_soc, addr = listener.accept()
        listener.close()
        self.soc = client_soc
        self.soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        LOG.info("Server accepted connection: %s", addr)

    def _write_int(self, integer):
        """Write ``integer`` using the length-prefix wire format."""
        self._write(struct.pack(self._LENGTH_FORMAT, integer))

    def _write(self, data: bytes):
        """Write all of ``data`` to the socket.

        Raises:
            RuntimeError: If no socket has been opened yet.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before writing")
        # sendall loops until every byte is written or an error occurs.
        self.soc.sendall(data)

    def send(self, message: str):
        """Send one length-prefixed, utf-8 encoded message.

        Raises:
            RuntimeError: If no socket has been opened yet, or the socket
                does not become writable within ``POLL_TIMEOUT`` seconds.
        """
        if self.soc is None:
            # Checked up front: select() on None would raise a TypeError.
            raise RuntimeError("Must call connect or accept before writing")
        _, writable, _ = select.select([], [self.soc], [], self.POLL_TIMEOUT)
        if not writable:
            # TODO Try to reacquire
            raise RuntimeError("Unable to write to socket")
        payload = message.encode("utf-8")
        self._write_int(len(payload))
        self._write(payload)
        LOG.debug("sent %s bytes", len(payload))

    def _read(self, size) -> bytes:
        """Read up to ``size`` bytes, stopping early only at end of stream.

        Returns:
            Exactly ``size`` bytes, or fewer if the peer closed the
            connection first. Callers must check the length.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.soc.recv(remaining)
            if not chunk:
                # Orderly shutdown by the peer.
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _read_int(self) -> Optional[int]:
        """Read one length prefix.

        Returns:
            The decoded integer, or None if the peer closed the connection
            before sending any byte of the prefix.

        Raises:
            ProtocolError: If the connection closed part-way through the
                prefix.
        """
        int_size = struct.calcsize(self._LENGTH_FORMAT)
        intbuf = self._read(int_size)
        if not intbuf:
            return None
        if len(intbuf) != int_size:
            raise ProtocolError(
                "Connection closed after {} of {} length bytes".format(
                    len(intbuf), int_size
                )
            )
        return struct.unpack(self._LENGTH_FORMAT, intbuf)[0]

    def receive(self) -> Optional[str]:
        """Receive one message, waiting at most ``POLL_TIMEOUT`` seconds.

        Returns:
            The decoded message, or None if nothing became readable within
            the timeout.

        Raises:
            RuntimeError: If no socket has been opened yet.
            ConnectionError: If the peer closed the connection. Raising here
                (rather than returning None like a timeout) lets the caller
                notice the disconnect instead of polling a dead socket.
            ProtocolError: If the frame is truncated or its length prefix is
                negative.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before reading")
        readable, _, _ = select.select([self.soc], [], [], self.POLL_TIMEOUT)
        if not readable:
            return None
        message_size = self._read_int()
        if message_size is None:
            raise ConnectionError("Socket closed by peer")
        if message_size < 0:
            raise ProtocolError(
                "Negative message length: {}".format(message_size)
            )
        data = self._read(message_size)
        if len(data) != message_size:
            raise ProtocolError(
                "Connection closed after {} of {} payload bytes".format(
                    len(data), message_size
                )
            )
        return data.decode("utf-8")

    def close(self):
        """Close the current socket, if any. Safe to call repeatedly."""
        if self.soc is not None:
            self.soc.close()
            self.soc = None
2020-04-02 08:33:34 +02:00
class Side(enum.Enum):
    """Which end of the connection a MessageQueue plays."""

    # Explicit values; identical to what enum.auto() assigns in this order.
    CLIENT = 1
    SERVER = 2
class MessageQueue:
"""A bidirectional queue of JSON messages over a network socket.
Asynchronously sends and receives messages until closed. Consume messages
with iteration and enqueue messages with add().
"""
def __init__(self, host: str, port: int, side: Side):
self.inbox: queue.Queue[str] = queue.Queue()
self.outbox: queue.Queue[str] = queue.Queue()
self.closed = False
self.process_worker = threading.Thread(
target=process_messages,
args=(side, SocketWrapper(host, port), self),
2020-04-02 08:33:34 +02:00
daemon=True,
2020-04-02 08:33:49 +02:00
name="MessageQueue/" + str(side),
2020-04-02 08:33:34 +02:00
)
self.process_worker.start()
def add(self, message: Dict[str, Any]):
self.outbox.put(json.dumps(message))
def __iter__(self):
return self
def __next__(self):
2020-04-03 06:52:36 +02:00
m = self.inbox.get()
LOG.debug(m)
return json.loads(m)
2020-04-02 08:33:34 +02:00
def close(self):
self.closed = True
self.process_worker.join()
def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue):
    """Worker loop: keep a connection up and shuttle messages through it.

    Runs until ``queue.closed`` is set. Each pass of the outer loop
    establishes a connection (accepting for the server side, connecting
    otherwise), then alternates between receiving into ``queue.inbox`` and
    sending everything in ``queue.outbox``. Any exception is logged, the
    socket is closed, and the connection is re-established.

    Args:
        side: Which end of the connection this worker plays.
        socket: Wrapper used for all network I/O. Note that this parameter
            shadows the ``socket`` module inside this function.
        queue: Owner of the inbox/outbox and the ``closed`` flag. Note that
            this parameter shadows the ``queue`` module inside this function.
    """
    log = LOG.getChild("worker.{}".format(side))
    # Pause after a failure so a persistent, non-retryable error does not
    # turn the reconnect loop into a busy loop that floods the log.
    reconnect_delay = 1
    while not queue.closed:
        try:
            if side == Side.SERVER:
                socket.accept()
            else:
                socket.connect()
            while not queue.closed:
                message = socket.receive()
                if message is not None:
                    queue.inbox.put(message)
                # This worker is the only consumer of the outbox, so the
                # empty() check followed by a non-blocking get is safe.
                while not queue.outbox.empty():
                    log.debug("Sending outbox item")
                    message = queue.outbox.get(block=False)
                    try:
                        socket.send(message)
                    except Exception:
                        # Put the message back so it is retried after the
                        # reconnect, then let the outer handler take over.
                        # TODO(chr): Doesn't preserve order, use priority queue
                        queue.outbox.put(message)
                        raise
        except Exception as e:
            log.exception(e)
            if not queue.closed:
                time.sleep(reconnect_delay)
        finally:
            socket.close()