diff --git a/message_queue.py b/message_queue.py index bd88d53..ae6dd17 100644 --- a/message_queue.py +++ b/message_queue.py @@ -140,7 +140,9 @@ class SocketWrapper: return data.decode("utf-8") if data is not None else None def close(self): - self.soc.close() + if self.soc is not None: + self.soc.close() + self.soc = None class Side(enum.Enum): @@ -160,12 +162,8 @@ class MessageQueue: self.outbox: queue.Queue[str] = queue.Queue() self.closed = False self.process_worker = threading.Thread( - target=( - process_messages_client - if side == Side.CLIENT - else process_messages_server - ), - args=(SocketWrapper(host, port), self), + target=process_messages, + args=(side, SocketWrapper(host, port), self), daemon=True, name="MessageQueue/" + str(side), ) @@ -185,36 +183,24 @@ class MessageQueue: self.process_worker.join() -def process_messages_client(socket: SocketWrapper, queue: MessageQueue): - log = LOG.getChild("client.worker") +def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue): + log = LOG.getChild("worker.{}".format(side)) while not queue.closed: try: - socket.connect() + socket.accept() if side == Side.SERVER else socket.connect() while not queue.closed: message = socket.receive() if message is not None: queue.inbox.put(message) while not queue.outbox.empty(): log.debug("Sending outbox item") - socket.send(queue.outbox.get(block=False)) - except Exception as e: - LOG.exception(e) - finally: - socket.close() - - -def process_messages_server(socket: SocketWrapper, queue: MessageQueue): - log = LOG.getChild("server.worker") - while not queue.closed: - try: - socket.accept() - while not queue.closed: - message = socket.receive() - if message is not None: - queue.inbox.put(message) - while not queue.outbox.empty(): - log.debug("Sending outbox item") - socket.send(queue.outbox.get(block=False)) + message = queue.outbox.get(block=False) + try: + socket.send(message) + except Exception as e: + # TODO(chr): Doesn't preserve order, use priority queue + queue.outbox.put(message) + raise e from e except Exception as e: LOG.exception(e) finally: diff --git a/message_queue_test.py b/message_queue_test.py index 14564d5..8d86a10 100644 --- a/message_queue_test.py +++ b/message_queue_test.py @@ -1,4 +1,3 @@ -import logging import random import threading import unittest @@ -7,10 +6,8 @@ import message_queue class MessageQueueTest(unittest.TestCase): - def setUp(self): - logging.basicConfig(level=logging.DEBUG) - def test_message_queue(self): + """Test basic functionality""" port = random.randint(10000, 65535) server_queue = message_queue.MessageQueue( host="localhost", port=port, side=message_queue.Side.SERVER @@ -37,6 +34,68 @@ class MessageQueueTest(unittest.TestCase): client_queue.close() self.assertEqual(threading.active_count(), 1) + def test_disconnect_reconnect_client(self): + """Test that the server can send messages while the client is offline""" + port = random.randint(10000, 65535) + server_queue = message_queue.MessageQueue( + host="localhost", port=port, side=message_queue.Side.SERVER + ) + client_queue = message_queue.MessageQueue( + host="localhost", port=port, side=message_queue.Side.CLIENT + ) + # Worker threads + main thread + self.assertEqual(threading.active_count(), 3) + + expected_message = {"index": "0"} + server_queue.add(expected_message) + received_message = next(client_queue) + self.assertEqual(expected_message, received_message) + + expected_message = {"index": "1"} + client_queue.close() + server_queue.add(expected_message) + self.assertEqual(threading.active_count(), 2) + client_queue = message_queue.MessageQueue( + host="localhost", port=port, side=message_queue.Side.CLIENT + ) + received_message = next(client_queue) + self.assertEqual(expected_message, received_message) + + server_queue.close() + client_queue.close() + self.assertEqual(threading.active_count(), 1) + + def test_disconnect_reconnect_server(self): + """Test that the client can send messages while the server is offline""" + port = random.randint(10000, 65535) + server_queue = message_queue.MessageQueue( + host="localhost", port=port, side=message_queue.Side.SERVER + ) + client_queue = message_queue.MessageQueue( + host="localhost", port=port, side=message_queue.Side.CLIENT + ) + # Worker threads + main thread + self.assertEqual(threading.active_count(), 3) + + expected_message = {"index": "0"} + client_queue.add(expected_message) + received_message = next(server_queue) + self.assertEqual(expected_message, received_message) + + expected_message = {"index": "1"} + server_queue.close() + client_queue.add(expected_message) + self.assertEqual(threading.active_count(), 2) + server_queue = message_queue.MessageQueue( + host="localhost", port=port, side=message_queue.Side.SERVER + ) + received_message = next(server_queue) + self.assertEqual(expected_message, received_message) + + server_queue.close() + client_queue.close() + self.assertEqual(threading.active_count(), 1) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) diff --git a/service.py b/service.py index 2141586..47a7348 100644 --- a/service.py +++ b/service.py @@ -20,6 +20,7 @@ LOG = logging.getLogger(__name__) USER_RE = re.compile(r"(?<=\@).*(?=\:)") +global_msg_queue = None app = flask.Flask(__name__) roomsync = set() @@ -42,7 +43,13 @@ def on_receive_events(transaction): m_user = USER_RE.search(event["user_id"]).group(0) m_cont = event["content"]["body"] m_user, m_cont - # minecraft.msglist.insert(0, "/tellraw @a {\"text\":\"<" + m_user + "> " + m_cont + "\",\"insertion\":\"/tellraw @p %s\"}") + if global_msg_queue is not None: + global_msg_queue.add({ + "command", + '/tellraw @a {"text":"<{}> {}","insertion":"/tellraw @p %s"}'.format( + m_user, m_cont + ), + }) return flask.jsonify({}) @@ -216,7 +223,7 @@ def main(): appservice_token=args.appservice_token, matrix_server_name=args.matrix_server_name, ) - queue = message_queue.MessageQueue( + global_msg_queue = message_queue.MessageQueue( host="0.0.0.0", port=args.minecraft_wrapper_port, side=message_queue.Side.SERVER, @@ -225,14 +232,14 @@ def main(): target=app.run, kwargs={"port": args.matrix_api_port}, daemon=True, ) receive_worker = threading.Thread( - target=receive_messages, args=(appservice, queue), daemon=True, + target=receive_messages, args=(appservice, global_msg_queue), daemon=True, ) flask_thread.start() receive_worker.start() LOG.info("All threads created") receive_worker.join() flask_thread.join() - queue.close() + global_msg_queue.close() LOG.info("All threads terminated") diff --git a/wrapper.py b/wrapper.py index 5797b4f..c5ce272 100644 --- a/wrapper.py +++ b/wrapper.py @@ -14,7 +14,7 @@ LOG = logging.getLogger(__name__) class ProcessWrapper: """Iterator that spawns a process and yields lines from its stdout.""" - def __init__(self, command: List[str], queue: message_queue.MessageQueue): + def __init__(self, command: List[str]): self.proc = subprocess.Popen( " ".join(command), shell=True, @@ -48,7 +48,7 @@ def send_process_output( ) ) msg_queue.add( - {"user": result.group(3), "msg": result.group(4).rstrip("\n"),}, + {"user": result.group(3), "msg": result.group(4).rstrip("\n")}, ) @@ -65,7 +65,7 @@ def relay_queue_input( def main(): logging.basicConfig(level=logging.DEBUG) parser = argparse.ArgumentParser() - parser.add_argument("--matrix_server") + parser.add_argument("--matrix_server", required=True) parser.add_argument("--matrix_server_port", type=int, default=5001) parser.add_argument("command", nargs=argparse.REMAINDER) args = parser.parse_args()