1
0
Fork 0

fix some bugs, another test, possibly enable the sending back to the server

This commit is contained in:
khr 2020-04-02 19:31:53 -07:00
parent d362a6c922
commit 2cbfc101f3
4 changed files with 93 additions and 41 deletions

View File

@ -140,7 +140,9 @@ class SocketWrapper:
return data.decode("utf-8") if data is not None else None
def close(self):
self.soc.close()
if self.soc is not None:
self.soc.close()
self.soc = None
class Side(enum.Enum):
@ -160,12 +162,8 @@ class MessageQueue:
self.outbox: queue.Queue[str] = queue.Queue()
self.closed = False
self.process_worker = threading.Thread(
target=(
process_messages_client
if side == Side.CLIENT
else process_messages_server
),
args=(SocketWrapper(host, port), self),
target=process_messages,
args=(side, SocketWrapper(host, port), self),
daemon=True,
name="MessageQueue/" + str(side),
)
@ -185,36 +183,24 @@ class MessageQueue:
self.process_worker.join()
def process_messages_client(socket: SocketWrapper, queue: MessageQueue):
log = LOG.getChild("client.worker")
def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue):
log = LOG.getChild("worker.{}".format(side))
while not queue.closed:
try:
socket.connect()
socket.accept() if side == Side.SERVER else socket.connect()
while not queue.closed:
message = socket.receive()
if message is not None:
queue.inbox.put(message)
while not queue.outbox.empty():
log.debug("Sending outbox item")
socket.send(queue.outbox.get(block=False))
except Exception as e:
LOG.exception(e)
finally:
socket.close()
def process_messages_server(socket: SocketWrapper, queue: MessageQueue):
log = LOG.getChild("server.worker")
while not queue.closed:
try:
socket.accept()
while not queue.closed:
message = socket.receive()
if message is not None:
queue.inbox.put(message)
while not queue.outbox.empty():
log.debug("Sending outbox item")
socket.send(queue.outbox.get(block=False))
message = queue.outbox.get(block=False)
try:
socket.send(message)
except Exception as e:
# TODO(chr): Doesn't preserve order, use priority queue
queue.outbox.put(message)
raise e from e
except Exception as e:
LOG.exception(e)
finally:

View File

@ -1,4 +1,3 @@
import logging
import random
import threading
import unittest
@ -7,10 +6,8 @@ import message_queue
class MessageQueueTest(unittest.TestCase):
def setUp(self):
logging.basicConfig(level=logging.DEBUG)
def test_message_queue(self):
"""Test basic functionality"""
port = random.randint(10000, 65535)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
@ -37,6 +34,68 @@ class MessageQueueTest(unittest.TestCase):
client_queue.close()
self.assertEqual(threading.active_count(), 1)
def test_disconnect_reconnect_client(self):
"""Test that the server can send messages while the client is offline"""
port = random.randint(10000, 65535)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
)
client_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.CLIENT
)
# Worker threads + main thread
self.assertEqual(threading.active_count(), 3)
expected_message = {"index": "0"}
server_queue.add(expected_message)
received_message = next(client_queue)
self.assertEqual(expected_message, received_message)
expected_message = {"index": "1"}
client_queue.close()
server_queue.add(expected_message)
self.assertEqual(threading.active_count(), 2)
client_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.CLIENT
)
received_message = next(client_queue)
self.assertEqual(expected_message, received_message)
server_queue.close()
client_queue.close()
self.assertEqual(threading.active_count(), 1)
def test_disconnect_reconnect_server(self):
"""Test that the client can send messages while the server is offline"""
port = random.randint(10000, 65535)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
)
client_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.CLIENT
)
# Worker threads + main thread
self.assertEqual(threading.active_count(), 3)
expected_message = {"index": "0"}
client_queue.add(expected_message)
received_message = next(server_queue)
self.assertEqual(expected_message, received_message)
expected_message = {"index": "1"}
server_queue.close()
client_queue.add(expected_message)
self.assertEqual(threading.active_count(), 2)
server_queue = message_queue.MessageQueue(
host="localhost", port=port, side=message_queue.Side.SERVER
)
received_message = next(server_queue)
self.assertEqual(expected_message, received_message)
server_queue.close()
client_queue.close()
self.assertEqual(threading.active_count(), 1)
if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)

View File

@ -20,6 +20,7 @@ LOG = logging.getLogger(__name__)
USER_RE = re.compile(r"(?<=\@).*(?=\:)")
global_msg_queue = None
app = flask.Flask(__name__)
roomsync = set()
@ -42,7 +43,13 @@ def on_receive_events(transaction):
m_user = USER_RE.search(event["user_id"]).group(0)
m_cont = event["content"]["body"]
m_user, m_cont
# minecraft.msglist.insert(0, "/tellraw @a {\"text\":\"<" + m_user + "> " + m_cont + "\",\"insertion\":\"/tellraw @p %s\"}")
if global_msg_queue is not None:
global_msg_queue.add({
"command",
'/tellraw @a {"text":"<{}> {}","insertion":"/tellraw @p %s"}'.format(
m_user, m_cont
),
})
return flask.jsonify({})
@ -216,7 +223,7 @@ def main():
appservice_token=args.appservice_token,
matrix_server_name=args.matrix_server_name,
)
queue = message_queue.MessageQueue(
global_msg_queue = message_queue.MessageQueue(
host="0.0.0.0",
port=args.minecraft_wrapper_port,
side=message_queue.Side.SERVER,
@ -225,14 +232,14 @@ def main():
target=app.run, kwargs={"port": args.matrix_api_port}, daemon=True,
)
receive_worker = threading.Thread(
target=receive_messages, args=(appservice, queue), daemon=True,
target=receive_messages, args=(appservice, global_msg_queue), daemon=True,
)
flask_thread.start()
receive_worker.start()
LOG.info("All threads created")
receive_worker.join()
flask_thread.join()
queue.close()
global_msg_queue.close()
LOG.info("All threads terminated")

View File

@ -14,7 +14,7 @@ LOG = logging.getLogger(__name__)
class ProcessWrapper:
"""Iterator that spawns a process and yields lines from its stdout."""
def __init__(self, command: List[str], queue: message_queue.MessageQueue):
def __init__(self, command: List[str]):
self.proc = subprocess.Popen(
" ".join(command),
shell=True,
@ -48,7 +48,7 @@ def send_process_output(
)
)
msg_queue.add(
{"user": result.group(3), "msg": result.group(4).rstrip("\n"),},
{"user": result.group(3), "msg": result.group(4).rstrip("\n")},
)
@ -65,7 +65,7 @@ def relay_queue_input(
def main():
logging.basicConfig(level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument("--matrix_server")
parser.add_argument("--matrix_server", required=True)
parser.add_argument("--matrix_server_port", type=int, default=5001)
parser.add_argument("command", nargs=argparse.REMAINDER)
args = parser.parse_args()