"""Length-prefixed socket messaging and a threaded JSON message queue."""

import enum
import errno
import json
import logging
import queue
import select
import socket
import struct
import threading
import time
from typing import Any, Callable, Dict, Optional, TypeVar

LOG = logging.getLogger(__name__)

# struct format of the length prefix that precedes every message.
_LENGTH_FORMAT = ">i"


class ProtocolError(Exception):
    """Error thrown when the SocketWrapper protocol is violated."""


T = TypeVar("T")


def _try_with_backoff(
    fn: Callable[[], T], error_callback: Callable[[Exception], Any]
) -> T:
    """Call ``fn`` until it succeeds, doubling the sleep between attempts.

    Args:
        fn: Zero-argument callable to attempt.
        error_callback: Called with the exception raised by ``fn``. A truthy
            result means "retry after sleeping"; a falsy result re-raises.

    Returns:
        Whatever ``fn`` returns on its first successful call.

    Raises:
        Exception: Any exception from ``fn`` that ``error_callback`` rejects.
    """
    backoff = 1
    while True:
        try:
            return fn()
        except Exception as e:
            if not error_callback(e):
                # Bare raise keeps the original traceback intact.
                raise
            LOG.exception(e)
            LOG.warning("Trying again in {} seconds".format(backoff))
            time.sleep(backoff)
            backoff *= 2


def socket_create_server(addr):
    """Create a TCP socket with SO_REUSEADDR set and bind it to ``addr``.

    The returned socket is bound but not yet listening.

    Args:
        addr: ``(host, port)`` tuple passed to ``socket.bind``.

    Returns:
        The bound ``socket.socket``.

    Raises:
        OSError: If configuring or binding the socket fails. The socket is
            closed before the error propagates.
    """
    # socket.create_server doesn't exist until python 3.8 :(
    soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        soc.bind(addr)
    except Exception:
        # Callers retry with backoff; without this every failed attempt
        # would leak a file descriptor.
        soc.close()
        raise
    return soc


class SocketWrapper:
    """Wraps a network socket with simpler connection, send, and receive logic.

    Both ends of the connection must be using a SocketWrapper compatible
    protocol. The protocol is as follows:

    Messages are utf-8 encoded strings prefixed by their length as a 4-byte
    big-endian integer.

    Example: 00 00 00 06 h e l l o \\n.
    """

    def __init__(self, host: str, port: int):
        """Store the endpoint; no socket is opened until connect/accept."""
        self.host = host
        self.port = port
        self.soc = None

    def connect(self):
        """Connect to ``host:port``, retrying while the connection is refused."""
        LOG.debug("Connecting to {}:{}".format(self.host, self.port))
        self.soc = _try_with_backoff(
            lambda: socket.create_connection((self.host, self.port)),
            # ECONNREFUSED is 111 on Linux; the symbolic name is portable.
            lambda e: isinstance(e, OSError) and e.errno == errno.ECONNREFUSED,
        )
        LOG.info("Socket Connected")

    def bind(self):
        """Bind a server socket to ``host:port``, retrying while it is in use."""
        LOG.info("Server Binding to {}:{}".format(self.host, self.port))
        self.soc = _try_with_backoff(
            lambda: socket_create_server((self.host, self.port)),
            # EADDRINUSE is 98 on Linux; the symbolic name is portable.
            lambda e: isinstance(e, OSError) and e.errno == errno.EADDRINUSE,
        )
        LOG.info("Server Bound")

    def accept(self):
        """Bind, listen, and block until one client connects.

        The listening socket is closed once a client is accepted, and
        ``self.soc`` is replaced by the client connection.
        """
        self.bind()
        self.soc.listen(1)
        LOG.info("Server listen to host")
        client_soc, addr = self.soc.accept()
        self.soc.close()
        self.soc = client_soc
        self.soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        LOG.info("Server accepted connection: {}".format(addr))

    def _write_int(self, integer):
        """Write ``integer`` as a 4-byte big-endian signed value."""
        self._write(struct.pack(_LENGTH_FORMAT, integer))

    def _write(self, data: bytes):
        """Write all of ``data`` to the socket.

        Raises:
            RuntimeError: If no socket has been established yet.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before writing")
        # sendall loops internally until every byte is written, without
        # re-slicing the buffer after each partial send.
        self.soc.sendall(data)

    def send(self, message: str):
        """Send ``message`` as one length-prefixed utf-8 frame.

        Raises:
            RuntimeError: If no socket has been established, or the socket
                does not become writable within one second.
        """
        if self.soc is None:
            # Without this guard select() fails with an opaque TypeError.
            raise RuntimeError("Must call connect or accept before writing")
        _, writable, _ = select.select([], [self.soc], [], 1)
        if not writable:
            # TODO Try to reacquire
            raise RuntimeError("Unable to write to socket")
        payload = message.encode("utf-8")
        self._write_int(len(payload))
        self._write(payload)
        LOG.debug("sent {} bytes".format(len(payload)))

    def _read(self, size) -> bytes:
        """Read up to ``size`` bytes, stopping early if the peer shuts down.

        Returns:
            The bytes read. Fewer than ``size`` bytes are returned only when
            the peer closed the connection; callers must check the length.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.soc.recv(remaining)
            if not chunk:
                # Orderly shutdown by the peer.
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _read_int(self) -> Optional[int]:
        """Read the 4-byte length prefix.

        Returns:
            The decoded integer, or None if the peer closed the connection
            before sending any byte of the prefix.

        Raises:
            ProtocolError: If the connection closed part-way through the
                prefix.
        """
        int_size = struct.calcsize(_LENGTH_FORMAT)
        intbuf = self._read(int_size)
        if not intbuf:
            return None
        if len(intbuf) != int_size:
            raise ProtocolError(
                "Connection closed after {} of {} length-prefix bytes".format(
                    len(intbuf), int_size
                )
            )
        return struct.unpack(_LENGTH_FORMAT, intbuf)[0]

    def receive(self) -> Optional[str]:
        """Receive one message, waiting up to one second for data.

        Returns:
            The decoded message, or None if nothing arrived within one
            second or the peer closed the connection between messages.

        Raises:
            RuntimeError: If no socket has been established yet.
            ProtocolError: If the length prefix is negative or the connection
                closed before the whole message arrived.
        """
        # NOTE(review): None is returned for both "timed out" and "peer
        # closed", so a caller polling in a loop cannot tell the two apart.
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before reading")
        readable, _, _ = select.select([self.soc], [], [], 1)
        if not readable:
            return None
        message_size = self._read_int()
        if message_size is None:
            # socket closed
            return None
        if message_size < 0:
            raise ProtocolError(
                "Negative message length: {}".format(message_size)
            )
        data = self._read(message_size)
        if len(data) != message_size:
            # Never hand a truncated payload to the caller as a full message.
            raise ProtocolError(
                "Connection closed after {} of {} message bytes".format(
                    len(data), message_size
                )
            )
        return data.decode("utf-8")

    def close(self):
        """Close the socket if one is open; safe to call repeatedly."""
        if self.soc is not None:
            self.soc.close()
            self.soc = None


class Side(enum.Enum):
    """Which end of the connection a MessageQueue plays."""

    CLIENT = enum.auto()
    SERVER = enum.auto()


class MessageQueue:
    """A bidirectional queue of JSON messages over a network socket.

    Asynchronously sends and receives messages until closed. Consume messages
    with iteration and enqueue messages with add().
    """

    def __init__(self, host: str, port: int, side: Side):
        """Create the queues and start the daemon worker thread."""
        self.inbox: queue.Queue[str] = queue.Queue()
        self.outbox: queue.Queue[str] = queue.Queue()
        self.closed = False
        self.process_worker = threading.Thread(
            target=process_messages,
            args=(side, SocketWrapper(host, port), self),
            daemon=True,
            name="MessageQueue/" + str(side),
        )
        self.process_worker.start()

    def add(self, message: Dict[str, Any]):
        """Serialize ``message`` to JSON and enqueue it for sending."""
        self.outbox.put(json.dumps(message))

    def __iter__(self):
        return self

    def __next__(self):
        """Block until a message arrives and return it parsed from JSON."""
        m = self.inbox.get()
        LOG.debug(m)
        return json.loads(m)

    def close(self):
        """Signal the worker to stop and wait for it to exit."""
        # NOTE(review): the worker only checks ``closed`` between socket
        # operations, so this join can wait as long as the worker is inside
        # a blocking accept() or a connect backoff sleep.
        self.closed = True
        self.process_worker.join()


def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue):
    """Worker loop: pump messages between ``socket`` and ``queue``.

    Establishes the connection (accepting for SERVER, connecting otherwise),
    then alternates between receiving into ``queue.inbox`` and draining
    ``queue.outbox``. Any exception is logged, the socket is closed, and the
    connection is re-established, until ``queue.closed`` is set.

    Args:
        side: Whether to accept or initiate the connection.
        socket: The wrapper to communicate over (shadows the ``socket``
            module inside this function).
        queue: The MessageQueue whose inbox/outbox are serviced (shadows the
            ``queue`` module inside this function).
    """
    log = LOG.getChild("worker.{}".format(side))
    while not queue.closed:
        try:
            if side == Side.SERVER:
                socket.accept()
            else:
                socket.connect()
            while not queue.closed:
                message = socket.receive()
                if message is not None:
                    queue.inbox.put(message)
                while not queue.outbox.empty():
                    log.debug("Sending outbox item")
                    message = queue.outbox.get(block=False)
                    try:
                        socket.send(message)
                    except Exception:
                        # TODO(chr): Doesn't preserve order, use priority queue
                        queue.outbox.put(message)
                        # Bare raise: re-raise without chaining the exception
                        # to itself.
                        raise
        except Exception as e:
            LOG.exception(e)
        finally:
            socket.close()