210 lines
6.3 KiB
Python
210 lines
6.3 KiB
Python
import enum
import errno
import json
import logging
import queue
import select
import socket
import struct
import threading
import time
from typing import Any, Callable, Dict, Optional, TypeVar
|
|
|
|
|
|
# Module-wide logger. The MessageQueue worker thread (process_messages) logs
# through a child of this logger named "worker.<side>".
LOG = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ProtocolError(Exception):
    """Signals a violation of the SocketWrapper length-prefixed wire protocol."""

    pass
|
|
|
|
|
|
# Generic type variable: lets _try_with_backoff declare that it returns
# whatever type its zero-argument callable returns.
T = TypeVar("T")
|
|
|
|
|
|
def _try_with_backoff(
    fn: Callable[[], T],
    error_callback: Callable[[Exception], Any],
    max_backoff: Optional[float] = None,
) -> T:
    """Call ``fn`` until it succeeds, sleeping with exponential backoff.

    The first retry waits 1 second and each subsequent wait doubles.

    Args:
        fn: zero-argument callable to invoke.
        error_callback: called with every exception ``fn`` raises; a truthy
            result means "retry", a falsy result re-raises the exception.
        max_backoff: optional cap, in seconds, on a single wait. ``None``
            (the default) lets the wait keep doubling without limit.

    Returns:
        The return value of the first successful call to ``fn``.

    Raises:
        Any exception from ``fn`` that ``error_callback`` declines to retry,
        and any exception raised by ``error_callback`` itself.
    """
    delay = 1
    while True:
        try:
            return fn()
        except Exception as e:
            if not error_callback(e):
                # Bare raise keeps the original traceback intact.
                raise
            if max_backoff is not None:
                delay = min(delay, max_backoff)
            LOG.exception(e)
            LOG.warning("Trying again in %s seconds", delay)
            time.sleep(delay)
            delay *= 2
|
|
|
|
|
|
def socket_create_server(addr):
    """Create an IPv4 TCP socket with SO_REUSEADDR set and bind it to ``addr``.

    Stand-in for ``socket.create_server``, which only exists from Python 3.8.
    The returned socket is bound but not yet listening.

    Args:
        addr: ``(host, port)`` tuple passed to ``socket.bind``.

    Returns:
        The bound socket.

    Raises:
        Whatever ``setsockopt``/``bind`` raise (typically ``OSError``, e.g.
        address already in use). The new socket is closed before the error
        propagates, so callers that retry in a loop do not leak descriptors.
    """
    soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        soc.bind(addr)
    except BaseException:
        soc.close()
        raise
    return soc
|
|
|
|
|
|
class SocketWrapper:
    """Wraps a network socket with simpler connection, send, and receive logic.

    Both ends of the connection must be using a SocketWrapper compatible
    protocol. The protocol is as follows:

    Messages are utf-8 encoded strings prefixed by their length in bytes as a
    4-byte big-endian signed integer. Example: 00 00 00 06 h e l l o \\n.
    """

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        # The connected socket (briefly the listening socket inside accept());
        # None before connect()/accept() and again after close().
        self.soc: Optional[socket.socket] = None

    def connect(self):
        """Connect to ``(host, port)`` as a client.

        Connection-refused errors (peer not listening yet) are retried with
        exponential backoff; any other error propagates.
        """
        LOG.debug("Connecting to %s:%s", self.host, self.port)
        self.soc = _try_with_backoff(
            lambda: socket.create_connection((self.host, self.port)),
            lambda e: isinstance(e, OSError) and e.errno == errno.ECONNREFUSED,
        )
        LOG.info("Socket Connected")

    def bind(self):
        """Create a server socket bound to ``(host, port)`` (not yet listening).

        Address-in-use errors are retried with exponential backoff; any other
        error propagates.
        """
        LOG.info("Server Binding to %s:%s", self.host, self.port)
        self.soc = _try_with_backoff(
            lambda: socket_create_server((self.host, self.port)),
            lambda e: isinstance(e, OSError) and e.errno == errno.EADDRINUSE,
        )
        LOG.info("Server Bound")

    def accept(self):
        """Bind, block until one client connects, then talk to that client.

        The listening socket is closed as soon as the client is accepted, so
        each call serves exactly one peer.
        """
        self.bind()
        self.soc.listen(1)
        LOG.info("Server listen to host")
        client_soc, addr = self.soc.accept()
        self.soc.close()
        self.soc = client_soc
        LOG.info("Server accepted connection: %s", addr)

    def _write_int(self, integer):
        """Write ``integer`` as a 4-byte big-endian signed int."""
        self._write(struct.pack(">i", integer))

    def _write(self, data: bytes):
        """Write all of ``data``, blocking until the kernel has accepted it."""
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before writing")
        self.soc.sendall(data)

    def send(self, message: str, timeout: Optional[float] = 1):
        """Send ``message`` as one length-prefixed UTF-8 frame.

        Args:
            message: text to send.
            timeout: seconds to wait for the socket to become writable;
                ``None`` waits indefinitely.

        Raises:
            RuntimeError: if not connected, or the socket does not become
                writable within ``timeout``.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before writing")
        _, writable, _ = select.select([], [self.soc], [], timeout)
        if not writable:
            # TODO Try to reacquire
            raise RuntimeError("Unable to write to socket")
        payload = message.encode("utf-8")
        # Prefix and payload go out in a single write rather than issuing a
        # separate tiny write for the 4-byte length.
        self._write(struct.pack(">i", len(payload)) + payload)
        LOG.debug("sent %d bytes", len(payload))

    def _read(self, size) -> bytes:
        """Read up to ``size`` bytes, stopping early only if the peer closes.

        A result shorter than ``size`` means the peer shut the connection
        down (or ``size`` was 0).
        """
        buf = bytearray()
        while len(buf) < size:
            chunk = self.soc.recv(size - len(buf))
            if not chunk:
                # Orderly shutdown by the peer.
                break
            buf += chunk
        return bytes(buf)

    def _read_int(self) -> Optional[int]:
        """Read a 4-byte big-endian signed int.

        Returns:
            The integer, or None if the peer closed before sending any of it.

        Raises:
            ProtocolError: if the peer closed part-way through the integer.
        """
        int_size = struct.calcsize(">i")
        intbuf = self._read(int_size)
        if not intbuf:
            return None
        if len(intbuf) != int_size:
            raise ProtocolError(
                "Connection closed after {} of {} length-prefix bytes".format(
                    len(intbuf), int_size
                )
            )
        return struct.unpack(">i", intbuf)[0]

    def receive(self, timeout: Optional[float] = 1) -> Optional[str]:
        """Receive one message, waiting up to ``timeout`` seconds for it to start.

        Returns:
            The decoded message. Returns None if nothing arrived within
            ``timeout``, or if the peer closed the connection. On a peer
            close the socket is also closed and released, so later calls
            raise RuntimeError instead of returning None forever on a dead
            socket.

        Raises:
            RuntimeError: if not connected.
            ProtocolError: on a negative length prefix, or a frame cut short
                by the peer closing mid-message.
            UnicodeDecodeError: if the payload is not valid UTF-8.
        """
        if self.soc is None:
            raise RuntimeError("Must call connect or accept before reading")
        readable, _, _ = select.select([self.soc], [], [], timeout)
        if not readable:
            return None
        message_size = self._read_int()
        if message_size is None:
            # Peer closed cleanly between frames: release the dead socket.
            self.close()
            return None
        if message_size < 0:
            raise ProtocolError("Negative message length: {}".format(message_size))
        data = self._read(message_size)
        if len(data) != message_size:
            raise ProtocolError(
                "Connection closed after {} of {} message bytes".format(
                    len(data), message_size
                )
            )
        return data.decode("utf-8")

    def close(self):
        """Close the socket if one is open; safe to call repeatedly."""
        if self.soc is not None:
            self.soc.close()
            self.soc = None
|
|
|
|
|
|
class Side(enum.Enum):
    """Which end of a MessageQueue connection the worker plays."""

    # Values match enum.auto()'s default numbering for a plain Enum.
    CLIENT = 1  # process_messages calls SocketWrapper.connect()
    SERVER = 2  # process_messages calls SocketWrapper.accept()
|
|
|
|
|
|
class MessageQueue:
    """A bidirectional queue of JSON messages over a network socket.

    Asynchronously sends and receives messages until closed. Consume messages
    with iteration and enqueue messages with add(). Iteration ends once the
    queue is closed and every message already received has been consumed.
    """

    # Seconds between checks of ``closed`` while a consumer waits in __next__.
    _POLL_INTERVAL = 0.1

    def __init__(self, host: str, port: int, side: Side):
        # JSON strings received from the peer; consumed by __next__.
        self.inbox: queue.Queue[str] = queue.Queue()
        # JSON strings queued by add(); drained by the worker thread.
        self.outbox: queue.Queue[str] = queue.Queue()
        # Set by close(); checked by the worker loop and by __next__.
        self.closed = False
        self.process_worker = threading.Thread(
            target=process_messages,
            args=(side, SocketWrapper(host, port), self),
            daemon=True,
            name="MessageQueue/" + str(side),
        )
        self.process_worker.start()

    def add(self, message: Dict[str, Any]):
        """Serialize ``message`` to JSON and queue it for sending."""
        self.outbox.put(json.dumps(message))

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next received message, decoded from JSON.

        Blocks until a message arrives. Raises StopIteration once the queue
        has been closed and the inbox is empty, so ``for`` loops over a
        closed queue terminate instead of hanging forever.
        """
        while True:
            try:
                m = self.inbox.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                if self.closed:
                    raise StopIteration from None
                continue
            LOG.debug(m)
            return json.loads(m)

    def close(self, timeout: Optional[float] = None):
        """Signal the worker to stop and wait for it to exit.

        Args:
            timeout: maximum seconds to wait for the worker; ``None`` (the
                default) waits indefinitely. The worker only checks
                ``closed`` between socket operations. It may never notice
                while blocked waiting for a peer. Because it is a daemon
                thread, giving up on the join does not keep the process
                alive.
        """
        self.closed = True
        self.process_worker.join(timeout)
|
|
|
|
|
|
def process_messages(
    side: Side,
    socket: SocketWrapper,
    queue: MessageQueue,
    retry_delay: float = 1.0,
):
    """Worker loop shuttling messages between ``queue`` and ``socket``.

    Each round establishes the connection: the server side accepts and the
    client side connects. It then alternates between receiving at most one
    message into ``queue.inbox`` and sending everything pending in
    ``queue.outbox``. Any error is logged and the socket is closed. After
    ``retry_delay`` seconds the connection is re-established. This repeats
    until ``queue.closed`` is set.

    Note: the ``socket`` and ``queue`` parameters shadow the stdlib modules of
    the same names inside this function; the names are kept for callers.

    Args:
        side: which end of the connection this worker plays.
        socket: wrapper used to talk to the peer.
        queue: owning MessageQueue; its inbox, outbox and closed flag are used.
        retry_delay: seconds to wait after an error before reconnecting, so a
            persistent failure does not become a hot log-flooding loop.
    """
    log = LOG.getChild("worker.{}".format(side))
    # A message taken from the outbox whose send failed. It is retried first
    # after reconnecting, so outgoing messages keep their order.
    unsent: Optional[str] = None
    while not queue.closed:
        try:
            if side == Side.SERVER:
                socket.accept()
            else:
                socket.connect()
            while not queue.closed:
                message = socket.receive()
                if message is not None:
                    queue.inbox.put(message)
                while unsent is not None or not queue.outbox.empty():
                    if unsent is None:
                        unsent = queue.outbox.get(block=False)
                    log.debug("Sending outbox item")
                    socket.send(unsent)
                    unsent = None
        except Exception as e:
            log.exception(e)
        finally:
            socket.close()
        if not queue.closed:
            # The inner loop only exits on its own once closed is set, so
            # reaching here means an error occurred: pause before reconnecting.
            time.sleep(retry_delay)
    if unsent is not None:
        # Shutting down: hand the undelivered message back to the outbox.
        queue.outbox.put(unsent)
|