From 42ae40c0cd56d34746ebfc30fe37a3b0ae77a29c Mon Sep 17 00:00:00 2001 From: chr Date: Wed, 1 Apr 2020 23:33:34 -0700 Subject: [PATCH] big refactor, test message queue --- .gitignore | 1 + message_queue.py | 220 +++++++++++++++++++++++++++++++++++++ message_queue_test.py | 47 ++++++++ service.py | 244 ++++++++++++++++++++++++++++++++++++++++++ wrapper.py | 103 ++++++++++++++++++ 5 files changed, 615 insertions(+) create mode 100644 message_queue.py create mode 100644 message_queue_test.py create mode 100644 service.py create mode 100644 wrapper.py diff --git a/.gitignore b/.gitignore index d162113..64648b8 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ test/ world/ *.swp env/ +__pycache__/ diff --git a/message_queue.py b/message_queue.py new file mode 100644 index 0000000..5369726 --- /dev/null +++ b/message_queue.py @@ -0,0 +1,220 @@ +import enum +import logging +import json +import queue +import select +import socket +import struct +import time +import threading +from typing import Any, Callable, Dict, Optional, TypeVar + + +LOG = logging.getLogger(__name__) + + +class ProtocolError(Exception): + """Error thrown when the SocketWrapper protocol is violated.""" + + +T = TypeVar("T") + + +def _try_with_backoff(fn: Callable, error_callback: Callable) -> socket.socket: + backoff = 1 + while True: + try: + return fn() + except Exception as e: + if error_callback(e): + LOG.exception(e) + LOG.warning("Trying again in {} seconds".format(backoff)) + time.sleep(backoff) + backoff *= 2 + else: + raise e + else: + break + + +def socket_create_server(addr): + # socket.create_server doesn't exist until python 3.8 :( + soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + soc.bind(addr) + return soc + + +class SocketWrapper: + """Wraps a network socket with simpler connection, send, and receive logic. + + Both ends of the connection must be using a SocketWrapper compatible + protocol. The protocol is as follows: + + Messages are utf-8 encoded strings prefixed by their length as a 4-byte + big-endian integer. Example: 00 00 00 06 h e l l o \n. + """ + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + self.soc = None + + def connect(self): + LOG.debug("Connecting to {}:{}".format(self.host, self.port)) + self.soc = _try_with_backoff( + lambda: socket.create_connection((self.host, self.port)), + lambda e: e is OSError and e.errno == 111 + ) + LOG.info("Socket Connected") + + def bind(self): + LOG.info("Server Binding to {}:{}".format(self.host, self.port)) + self.soc = _try_with_backoff( + lambda: socket_create_server((self.host, self.port)), + lambda e: e is OSError and e.errno == 98 + ) + LOG.info("Server Bound") + + def accept(self): + self.bind() + self.soc.listen(1) + LOG.info("Server listen to host") + client_soc, addr = self.soc.accept() + self.soc.close() + self.soc = client_soc + self.soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # self.soc.send(b"\0\0\0\0") + LOG.info("Server accepted connection: {}".format(addr)) + + def _write_int(self, integer): + integer_buf = struct.pack(">i", integer) + self._write(integer_buf) + + def _write(self, data: bytes): + if self.soc is None: + raise RuntimeError("Must call connect or accept before writing") + data_len = len(data) + offset = 0 + while offset != data_len: + offset += self.soc.send(data[offset:]) + + def send(self, message: str): + _, writable, _ = select.select([], [self.soc], [], 1) + if writable == []: + # TODO Try to reacquire + raise RuntimeError("Unable to write to socket") + payload = message.encode("utf-8") + self._write_int(len(payload)) + self._write(payload) + LOG.debug("sent {} bytes".format(len(payload))) + + def _read(self, size) -> Optional[bytes]: + data = b"" + while len(data) != size: + newdata = self.soc.recv(size - len(data)) + if len(newdata) == 0: + # Orderly shutdown, or 0 bytes requested to read + return data + data = data + newdata + return data + + def _read_int(self) -> Optional[int]: + int_size = struct.calcsize(">i") + intbuf = self._read(int_size) + if len(intbuf) == 0: + return None + return struct.unpack(">i", intbuf)[0] + + def receive(self) -> Optional[str]: + if self.soc is None: + raise RuntimeError("Must call connect or accept before writing") + readable, _, _ = select.select([self.soc], [], [], 1) + if readable == []: + return None + message_size = self._read_int() + if message_size is None: + # socket closed + return None + data = self._read(message_size) + return data.decode("utf-8") if data is not None else None + + def close(self): + self.soc.close() + + +class Side(enum.Enum): + CLIENT = enum.auto() + SERVER = enum.auto() + + +class MessageQueue: + """A bidirectional queue of JSON messages over a network socket. + + Asynchronously sends and receives messages until closed. Consume messages + with iteration and enqueue messages with add(). + """ + + def __init__(self, host: str, port: int, side: Side): + self.inbox: queue.Queue[str] = queue.Queue() + self.outbox: queue.Queue[str] = queue.Queue() + self.closed = False + self.process_worker = threading.Thread( + target=( + process_messages_client if side == Side.CLIENT + else process_messages_server + ), + args=(SocketWrapper(host, port), self), + daemon=True, + name="MessageQueue/" + str(side) + ) + self.process_worker.start() + + def add(self, message: Dict[str, Any]): + self.outbox.put(json.dumps(message)) + + def __iter__(self): + return self + + def __next__(self): + return json.loads(self.inbox.get()) + + def close(self): + self.closed = True + self.process_worker.join() + + +def process_messages_client(socket: SocketWrapper, queue: MessageQueue): + log = LOG.getChild("client.worker") + while not queue.closed: + try: + socket.connect() + while not queue.closed: + message = socket.receive() + if message is not None: + queue.inbox.put(message) + while not queue.outbox.empty(): + log.debug("Sending outbox item") + socket.send(queue.outbox.get(block=False)) + except Exception as e: + LOG.exception(e) + finally: + socket.close() + + +def process_messages_server(socket: SocketWrapper, queue: MessageQueue): + log = LOG.getChild("server.worker") + while not queue.closed: + try: + socket.accept() + while not queue.closed: + message = socket.receive() + if message is not None: + queue.inbox.put(message) + while not queue.outbox.empty(): + log.debug("Sending outbox item") + socket.send(queue.outbox.get(block=False)) + except Exception as e: + LOG.exception(e) + finally: + socket.close() diff --git a/message_queue_test.py b/message_queue_test.py new file mode 100644 index 0000000..3f50e5c --- /dev/null +++ b/message_queue_test.py @@ -0,0 +1,47 @@ +import logging +import random +import threading +import unittest + +import message_queue + + +class MessageQueueTest(unittest.TestCase): + + def setUp(self): + logging.basicConfig(level=logging.DEBUG) + + def test_message_queue(self): + port = random.randint(10000, 65535) + server_queue = message_queue.MessageQueue( + host="localhost", + port=port, + side=message_queue.Side.SERVER + ) + client_queue = message_queue.MessageQueue( + host="localhost", + port=port, + side=message_queue.Side.CLIENT + ) + # Worker threads + main thread + self.assertEqual(threading.active_count(), 3) + + # Test sending server -> client + expected_message = {"hello": "client", "from": "server"} + server_queue.add(expected_message) + received_message = next(client_queue) + self.assertEqual(expected_message, received_message) + + # Test sending client-> server + expected_message = {"hello": "server", "from": "client"} + client_queue.add(expected_message) + received_message = next(server_queue) + self.assertEqual(expected_message, received_message) + + server_queue.close() + client_queue.close() + self.assertEqual(threading.active_count(), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/service.py b/service.py new file mode 100644 index 0000000..8870efa --- /dev/null +++ b/service.py @@ -0,0 +1,244 @@ +import argparse +import base64 +import io +import json +import logging +import re +import threading +import time +from typing import Dict +import urllib + +from matrix_client.api import MatrixHttpApi +import PIL +import requests +import flask + +import message_queue + +LOG = logging.getLogger(__name__) + +USER_RE = re.compile(r"(?<=\@).*(?=\:)") + +app = flask.Flask(__name__) +roomsync = set() + + +@app.route("/transactions/", methods=["PUT"]) +def on_receive_events(transaction): + LOG.info("got event") + events = flask.request.get_json()["events"] + for event in events: + LOG.info("User: %s Room: %s" % (event["user_id"], event["room_id"])) + LOG.info("Event Type: %s" % event["type"]) + LOG.info("Content: %s" % event["content"]) + roomsync.add(event["room_id"]) + if ( + event["type"] == "m.room.message" + and event["content"]["msgtype"] == "m.text" + and event["user_id"].find("@mc_") == -1 + ): + + m_user = USER_RE.search(event["user_id"]).group(0) + m_cont = event["content"]["body"] + m_user, m_cont + # minecraft.msglist.insert(0, "/tellraw @a {\"text\":\"<" + m_user + "> " + m_cont + "\",\"insertion\":\"/tellraw @p %s\"}") + + return flask.jsonify({}) + + +@app.route("/rooms/", methods=["GET"]) +def on_room(room): + LOG.info("returning: " + str(room)) + return flask.jsonify({}) + + +class Appservice: + def __init__(self, appservice_token: str, matrix_server_name: str): + self.api = MatrixHttpApi( + "http://localhost:8008", token=appservice_token + ) + self.avatar_update_log: Dict[str, float] = {} + self.matrix_server_name = matrix_server_name + + def process_message(self, msg): + # for msg, create user and post as user + # add minecraft user to minecraft channel, if this fails, no big deal + try: + new_user = "mc_" + msg["user"] + user_id = "@{}:{}".format(new_user, self.matrix_server_name) + LOG.info("trying to create user {}...".format(new_user)) + self.api.register( + {"type": "m.login.application_service", "username": new_user} + ) + except Exception as e: + LOG.exception(e) + # for each room we're aware of, post server chat inside. Eventually 1 room should equal 1 server + for room in roomsync: + # generate a unique transaction id based on the current time + txn_id = str(int(time.time() * 1000)) + # attempt to join room + LOG.info("trying to join room as user and as bridge manager") + self.api._send( + "POST", + "/rooms/" + room + "/join", + query_params={"user_id": user_id}, + headers={"Content-Type": "application/json"}, + ) + self.api._send( + "POST", + "/rooms/" + room + "/join", + headers={"Content-Type": "application/json"}, + ) + # set our display name to something nice + LOG.info("trying to set display name...") + self.api._send( + "PUT", + "/profile/" + user_id + "/displayname/", + content={"displayname": msg["user"]}, + query_params={"user_id": user_id}, + headers={"Content-Type": "application/json"}, + ) + + # get our mc skin!! + # backup: #avatar_url = "https://www.minecraftskinstealer.com/face.php?u="+msg['user'] + # only get this if the user hasn't updated in a long time + try: + LOG.info("Checking if we need to update avatar...") + if ( + msg["user"] not in self.avatar_update_log.keys() + or abs(self.avatar_update_log[msg["user"]] - time.time()) + > 180 + ): + self.avatar_update_log[msg["user"]] = time.time() + avatar_url = self.get_mc_skin(msg["user"], user_id) + if avatar_url: + LOG.debug("avatar_url is " + avatar_url) + self.api._send( + "PUT", + "/profile/" + user_id + "/avatar_url/", + content={"avatar_url": avatar_url}, + query_params={"user_id": user_id}, + headers={"Content-Type": "application/json"}, + ) + except Exception as e: + LOG.exception(e) + # Not the end of the world if it fails, send the message now. + + # attempt to post in room + LOG.info("Attempting to post in Room") + self.api._send( + "PUT", + "/rooms/" + room + "/send/m.room.message/" + txn_id, + content={"msgtype": "m.text", "body": msg["msg"]}, + query_params={"user_id": user_id}, + headers={"Content-Type": "application/json"}, + ) + + def get_mc_skin(self, user, user_id): + LOG.info("Getting Minecraft Avatar") + + mojang_info = requests.get( + "https://api.mojang.com/users/profiles/minecraft/" + user + ).json() # get uuid + mojang_info = requests.get( + "https://sessionserver.mojang.com/session/minecraft/profile/" + + mojang_info["id"] + ).json() # get more info from uuid + mojang_info = json.loads( + base64.b64decode(mojang_info["properties"][0]["value"]) + ) + mojang_url = mojang_info["textures"]["SKIN"]["url"] + # r = requests.get(mojang_url, stream=True) + # r.raw.decode_content = True # handle spurious Content-Encoding + file = io.BytesIO(urllib.request.urlopen(mojang_url).read()) + im = PIL.Image.open(file) + img_head = im.crop((8, 8, 16, 16)) + img_head = img_head.resize( + (im.width * 8, im.height * 8), resample=PIL.Image.NEAREST + ) # Resize with nearest neighbor to get pixels + image_buffer_head = io.BytesIO() + img_head.save(image_buffer_head, "PNG") + + # compare to user's current id so we're not uploading the same pic twice + # GET /_matrix/client/r0/profile/{userId}/avatar_url + LOG.info("Getting Current Avatar URL") + curr_url = self.api._send( + "GET", + "/profile/" + user_id + "/avatar_url/", + query_params={"user_id": user_id}, + headers={"Content-Type": "application/json"}, + ) + upload = True + if "avatar_url" in curr_url.keys(): + LOG.info("Checking Avatar...") + file = io.BytesIO( + urllib.request.urlopen( + self.api.get_download_url(curr_url["avatar_url"]) + ).read() + ) + im = PIL.Image.open(file) + image_buffer_curr = io.BytesIO() + im.save(image_buffer_curr, "PNG") + if (image_buffer_head.getvalue()) == (image_buffer_curr.getvalue()): + LOG.debug("Image Same") + upload = False + if upload: + # upload img + # POST /_matrix/media/r0/upload + LOG.debug("Returning updated avatar") + LOG.debug(image_buffer_head) + return self.api.media_upload( + image_buffer_head.getvalue(), "image/png" + )["content_uri"] + else: + return None + + +def receive_messages( + appservice: Appservice, msg_queue: message_queue.MessageQueue +): + for message in msg_queue: + appservice.process_message(message) + + +def main(): + logging.basicConfig(level=logging.DEBUG) + parser = argparse.ArgumentParser() + parser.add_argument("--matrix_server_name", required=True) + parser.add_argument("--appservice_token", required=True) + parser.add_argument("--matrix_api_port", type=int, default=5000) + parser.add_argument("--minecraft_wrapper_port", type=int, default=5001) + args = parser.parse_args() + + LOG.info("Running Minecraft Matrix Bridge") + appservice = Appservice( + appservice_token=args.appservice_token, + matrix_server_name=args.matrix_server_name, + ) + queue = message_queue.MessageQueue( + host="0.0.0.0", + port=args.minecraft_wrapper_port, + side=message_queue.Side.SERVER, + ) + flask_thread = threading.Thread( + target=app.run, + kwargs={"port": args.matrix_api_port}, + daemon=True, + ) + receive_worker = threading.Thread( + target=receive_messages, + args=(appservice, queue), + daemon=True, + ) + flask_thread.start() + receive_worker.start() + LOG.info("All threads created") + receive_worker.join() + flask_thread.join() + queue.close() + LOG.info("All threads terminated") + + +if __name__ == "__main__": + main() diff --git a/wrapper.py b/wrapper.py new file mode 100644 index 0000000..b585502 --- /dev/null +++ b/wrapper.py @@ -0,0 +1,103 @@ +import argparse +import logging +import re +import subprocess +import threading +from typing import List + +import message_queue + + +LOG = logging.getLogger(__name__) + + +class ProcessWrapper: + """Iterator that spawns a process and yields lines from its stdout.""" + + def __init__(self, command: List[str], queue: message_queue.MessageQueue): + self.proc = subprocess.Popen( + " ".join(command), + shell=True, + stdout=subprocess.PIPE, + stdin=subprocess.PIPE, + universal_newlines=True, + ) + + def __iter__(self): + return iter(self.proc.stdout.readline, "") + + def send(self, msg): + self.proc.stdin.write(msg) + + def wait(self): + return self.proc.wait() + + +def send_process_output( + process: ProcessWrapper, msg_queue: message_queue.MessageQueue +): + # "[07:36:28] [Server thread/INFO] [minecraft/DedicatedServer]: test" + prog = re.compile(r"\[.*\] \[(.*)\] \[(.*)\]: <(.*)> (.*)") + for line in process: + LOG.info(line.rstrip("\n")) + result = prog.search(line) + if result: + LOG.info("user: {} msg: {}".format( + result.group(3), + result.group(4).rstrip("\n"), + )) + msg_queue.add( + { + "user": result.group(3), + "msg": result.group(4).rstrip("\n"), + }, + ) + + +def relay_queue_input( + process: ProcessWrapper, msg_queue: message_queue.MessageQueue +): + for message in msg_queue: + if "command" in message: + process.send(message["command"]) + else: + LOG.debug(message) + + +def main(): + logging.basicConfig(level=logging.DEBUG) + parser = argparse.ArgumentParser() + parser.add_argument("--matrix_server") + parser.add_argument("--matrix_server_port", type=int, default=5001) + parser.add_argument("command", nargs=argparse.REMAINDER) + args = parser.parse_args() + + LOG.info("Running Minecraft Server Wrapper") + wrapper = ProcessWrapper(args.command) + queue = message_queue.MessageQueue( + host=args.matrix_server, + port=args.matrix_server_port, + side=message_queue.Side.CLIENT, + ) + send_worker = threading.Thread( + target=send_process_output, + args=(wrapper, queue), + daemon=True, + ) + receive_worker = threading.Thread( + target=relay_queue_input, + args=(wrapper, queue), + daemon=True, + ) + send_worker.start() + receive_worker.start() + LOG.info("All threads created") + send_worker.join() + receive_worker.join() + queue.close() + LOG.info("All threads terminated") + return wrapper.wait() + + +if __name__ == "__main__": + main()