1
0
Fork 0

fix some bugs, another test, possibly enable the sending back to the server

This commit is contained in:
khr 2020-04-02 19:31:53 -07:00
parent d362a6c922
commit 2cbfc101f3
4 changed files with 93 additions and 41 deletions

View File

@ -140,7 +140,9 @@ class SocketWrapper:
return data.decode("utf-8") if data is not None else None return data.decode("utf-8") if data is not None else None
def close(self): def close(self):
if self.soc is not None:
self.soc.close() self.soc.close()
self.soc = None
class Side(enum.Enum): class Side(enum.Enum):
@ -160,12 +162,8 @@ class MessageQueue:
self.outbox: queue.Queue[str] = queue.Queue() self.outbox: queue.Queue[str] = queue.Queue()
self.closed = False self.closed = False
self.process_worker = threading.Thread( self.process_worker = threading.Thread(
target=( target=process_messages,
process_messages_client args=(side, SocketWrapper(host, port), self),
if side == Side.CLIENT
else process_messages_server
),
args=(SocketWrapper(host, port), self),
daemon=True, daemon=True,
name="MessageQueue/" + str(side), name="MessageQueue/" + str(side),
) )
@ -185,36 +183,24 @@ class MessageQueue:
self.process_worker.join() self.process_worker.join()
def process_messages_client(socket: SocketWrapper, queue: MessageQueue): def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue):
log = LOG.getChild("client.worker") log = LOG.getChild("worker.{}".format(side))
while not queue.closed: while not queue.closed:
try: try:
socket.connect() socket.accept() if side == Side.SERVER else socket.connect()
while not queue.closed: while not queue.closed:
message = socket.receive() message = socket.receive()
if message is not None: if message is not None:
queue.inbox.put(message) queue.inbox.put(message)
while not queue.outbox.empty(): while not queue.outbox.empty():
log.debug("Sending outbox item") log.debug("Sending outbox item")
socket.send(queue.outbox.get(block=False)) message = queue.outbox.get(block=False)
except Exception as e:
LOG.exception(e)
finally:
socket.close()
def process_messages_server(socket: SocketWrapper, queue: MessageQueue):
log = LOG.getChild("server.worker")
while not queue.closed:
try: try:
socket.accept() socket.send(message)
while not queue.closed: except Exception as e:
message = socket.receive() # TODO(chr): Doesn't preserve order, use priority queue
if message is not None: queue.outbox.put(message)
queue.inbox.put(message) raise e from e
while not queue.outbox.empty():
log.debug("Sending outbox item")
socket.send(queue.outbox.get(block=False))
except Exception as e: except Exception as e:
LOG.exception(e) LOG.exception(e)
finally: finally:

View File

@ -1,4 +1,3 @@
import logging
import random import random
import threading import threading
import unittest import unittest
@ -7,10 +6,8 @@ import message_queue
class MessageQueueTest(unittest.TestCase): class MessageQueueTest(unittest.TestCase):
def setUp(self):
logging.basicConfig(level=logging.DEBUG)
def test_message_queue(self): def test_message_queue(self):
"""Test basic functionality"""
port = random.randint(10000, 65535) port = random.randint(10000, 65535)
server_queue = message_queue.MessageQueue( server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER host="localhost", port=port, side=message_queue.Side.SERVER
@ -37,6 +34,68 @@ class MessageQueueTest(unittest.TestCase):
client_queue.close() client_queue.close()
self.assertEqual(threading.active_count(), 1) self.assertEqual(threading.active_count(), 1)
def test_disconnect_reconnect_client(self):
"""Test that the server can send messages while the client is offline"""
port = random.randint(10000, 65535)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
)
client_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.CLIENT
)
# Worker threads + main thread
self.assertEqual(threading.active_count(), 3)
expected_message = {"index": "0"}
server_queue.add(expected_message)
received_message = next(client_queue)
self.assertEqual(expected_message, received_message)
expected_message = {"index": "1"}
client_queue.close()
server_queue.add(expected_message)
self.assertEqual(threading.active_count(), 2)
client_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.CLIENT
)
received_message = next(client_queue)
self.assertEqual(expected_message, received_message)
server_queue.close()
client_queue.close()
self.assertEqual(threading.active_count(), 1)
def test_disconnect_reconnect_server(self):
"""Test that the client can send messages while the server is offline"""
port = random.randint(10000, 65535)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
)
client_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.CLIENT
)
# Worker threads + main thread
self.assertEqual(threading.active_count(), 3)
expected_message = {"index": "0"}
client_queue.add(expected_message)
received_message = next(server_queue)
self.assertEqual(expected_message, received_message)
expected_message = {"index": "1"}
server_queue.close()
client_queue.add(expected_message)
self.assertEqual(threading.active_count(), 2)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
)
received_message = next(server_queue)
self.assertEqual(expected_message, received_message)
server_queue.close()
client_queue.close()
self.assertEqual(threading.active_count(), 1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main(verbosity=2)

View File

@ -20,6 +20,7 @@ LOG = logging.getLogger(__name__)
USER_RE = re.compile(r"(?<=\@).*(?=\:)") USER_RE = re.compile(r"(?<=\@).*(?=\:)")
global_msg_queue = None
app = flask.Flask(__name__) app = flask.Flask(__name__)
roomsync = set() roomsync = set()
@ -42,7 +43,13 @@ def on_receive_events(transaction):
m_user = USER_RE.search(event["user_id"]).group(0) m_user = USER_RE.search(event["user_id"]).group(0)
m_cont = event["content"]["body"] m_cont = event["content"]["body"]
m_user, m_cont m_user, m_cont
# minecraft.msglist.insert(0, "/tellraw @a {\"text\":\"<" + m_user + "> " + m_cont + "\",\"insertion\":\"/tellraw @p %s\"}") if global_msg_queue is not None:
global_msg_queue.add({
"command",
'/tellraw @a {"text":"<{}> {}","insertion":"/tellraw @p %s"}'.format(
m_user, m_cont
),
})
return flask.jsonify({}) return flask.jsonify({})
@ -216,7 +223,7 @@ def main():
appservice_token=args.appservice_token, appservice_token=args.appservice_token,
matrix_server_name=args.matrix_server_name, matrix_server_name=args.matrix_server_name,
) )
queue = message_queue.MessageQueue( global_msg_queue = message_queue.MessageQueue(
host="0.0.0.0", host="0.0.0.0",
port=args.minecraft_wrapper_port, port=args.minecraft_wrapper_port,
side=message_queue.Side.SERVER, side=message_queue.Side.SERVER,
@ -225,14 +232,14 @@ def main():
target=app.run, kwargs={"port": args.matrix_api_port}, daemon=True, target=app.run, kwargs={"port": args.matrix_api_port}, daemon=True,
) )
receive_worker = threading.Thread( receive_worker = threading.Thread(
target=receive_messages, args=(appservice, queue), daemon=True, target=receive_messages, args=(appservice, global_msg_queue), daemon=True,
) )
flask_thread.start() flask_thread.start()
receive_worker.start() receive_worker.start()
LOG.info("All threads created") LOG.info("All threads created")
receive_worker.join() receive_worker.join()
flask_thread.join() flask_thread.join()
queue.close() global_msg_queue.close()
LOG.info("All threads terminated") LOG.info("All threads terminated")

View File

@ -14,7 +14,7 @@ LOG = logging.getLogger(__name__)
class ProcessWrapper: class ProcessWrapper:
"""Iterator that spawns a process and yields lines from its stdout.""" """Iterator that spawns a process and yields lines from its stdout."""
def __init__(self, command: List[str], queue: message_queue.MessageQueue): def __init__(self, command: List[str]):
self.proc = subprocess.Popen( self.proc = subprocess.Popen(
" ".join(command), " ".join(command),
shell=True, shell=True,
@ -48,7 +48,7 @@ def send_process_output(
) )
) )
msg_queue.add( msg_queue.add(
{"user": result.group(3), "msg": result.group(4).rstrip("\n"),}, {"user": result.group(3), "msg": result.group(4).rstrip("\n")},
) )
@ -65,7 +65,7 @@ def relay_queue_input(
def main(): def main():
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--matrix_server") parser.add_argument("--matrix_server", required=True)
parser.add_argument("--matrix_server_port", type=int, default=5001) parser.add_argument("--matrix_server_port", type=int, default=5001)
parser.add_argument("command", nargs=argparse.REMAINDER) parser.add_argument("command", nargs=argparse.REMAINDER)
args = parser.parse_args() args = parser.parse_args()