fix some bugs, another test, possibly enable the sending back to the server
This commit is contained in:
parent
d362a6c922
commit
2cbfc101f3
|
@ -140,7 +140,9 @@ class SocketWrapper:
|
||||||
return data.decode("utf-8") if data is not None else None
|
return data.decode("utf-8") if data is not None else None
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
if self.soc is not None:
|
||||||
self.soc.close()
|
self.soc.close()
|
||||||
|
self.soc = None
|
||||||
|
|
||||||
|
|
||||||
class Side(enum.Enum):
|
class Side(enum.Enum):
|
||||||
|
@ -160,12 +162,8 @@ class MessageQueue:
|
||||||
self.outbox: queue.Queue[str] = queue.Queue()
|
self.outbox: queue.Queue[str] = queue.Queue()
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.process_worker = threading.Thread(
|
self.process_worker = threading.Thread(
|
||||||
target=(
|
target=process_messages,
|
||||||
process_messages_client
|
args=(side, SocketWrapper(host, port), self),
|
||||||
if side == Side.CLIENT
|
|
||||||
else process_messages_server
|
|
||||||
),
|
|
||||||
args=(SocketWrapper(host, port), self),
|
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name="MessageQueue/" + str(side),
|
name="MessageQueue/" + str(side),
|
||||||
)
|
)
|
||||||
|
@ -185,36 +183,24 @@ class MessageQueue:
|
||||||
self.process_worker.join()
|
self.process_worker.join()
|
||||||
|
|
||||||
|
|
||||||
def process_messages_client(socket: SocketWrapper, queue: MessageQueue):
|
def process_messages(side: Side, socket: SocketWrapper, queue: MessageQueue):
|
||||||
log = LOG.getChild("client.worker")
|
log = LOG.getChild("worker.{}".format(side))
|
||||||
while not queue.closed:
|
while not queue.closed:
|
||||||
try:
|
try:
|
||||||
socket.connect()
|
socket.accept() if side == Side.SERVER else socket.connect()
|
||||||
while not queue.closed:
|
while not queue.closed:
|
||||||
message = socket.receive()
|
message = socket.receive()
|
||||||
if message is not None:
|
if message is not None:
|
||||||
queue.inbox.put(message)
|
queue.inbox.put(message)
|
||||||
while not queue.outbox.empty():
|
while not queue.outbox.empty():
|
||||||
log.debug("Sending outbox item")
|
log.debug("Sending outbox item")
|
||||||
socket.send(queue.outbox.get(block=False))
|
message = queue.outbox.get(block=False)
|
||||||
except Exception as e:
|
|
||||||
LOG.exception(e)
|
|
||||||
finally:
|
|
||||||
socket.close()
|
|
||||||
|
|
||||||
|
|
||||||
def process_messages_server(socket: SocketWrapper, queue: MessageQueue):
|
|
||||||
log = LOG.getChild("server.worker")
|
|
||||||
while not queue.closed:
|
|
||||||
try:
|
try:
|
||||||
socket.accept()
|
socket.send(message)
|
||||||
while not queue.closed:
|
except Exception as e:
|
||||||
message = socket.receive()
|
# TODO(chr): Doesn't preserve order, use priority queue
|
||||||
if message is not None:
|
queue.outbox.put(message)
|
||||||
queue.inbox.put(message)
|
raise e from e
|
||||||
while not queue.outbox.empty():
|
|
||||||
log.debug("Sending outbox item")
|
|
||||||
socket.send(queue.outbox.get(block=False))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.exception(e)
|
LOG.exception(e)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
|
@ -7,10 +6,8 @@ import message_queue
|
||||||
|
|
||||||
|
|
||||||
class MessageQueueTest(unittest.TestCase):
|
class MessageQueueTest(unittest.TestCase):
|
||||||
def setUp(self):
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
|
|
||||||
def test_message_queue(self):
|
def test_message_queue(self):
|
||||||
|
"""Test basic functionality"""
|
||||||
port = random.randint(10000, 65535)
|
port = random.randint(10000, 65535)
|
||||||
server_queue = message_queue.MessageQueue(
|
server_queue = message_queue.MessageQueue(
|
||||||
host="localhost", port=port, side=message_queue.Side.SERVER
|
host="localhost", port=port, side=message_queue.Side.SERVER
|
||||||
|
@ -37,6 +34,68 @@ class MessageQueueTest(unittest.TestCase):
|
||||||
client_queue.close()
|
client_queue.close()
|
||||||
self.assertEqual(threading.active_count(), 1)
|
self.assertEqual(threading.active_count(), 1)
|
||||||
|
|
||||||
|
def test_disconnect_reconnect_client(self):
|
||||||
|
"""Test that the server can send messages while the client is offline"""
|
||||||
|
port = random.randint(10000, 65535)
|
||||||
|
server_queue = message_queue.MessageQueue(
|
||||||
|
host="localhost", port=port, side=message_queue.Side.SERVER
|
||||||
|
)
|
||||||
|
client_queue = message_queue.MessageQueue(
|
||||||
|
host="localhost", port=port, side=message_queue.Side.CLIENT
|
||||||
|
)
|
||||||
|
# Worker threads + main thread
|
||||||
|
self.assertEqual(threading.active_count(), 3)
|
||||||
|
|
||||||
|
expected_message = {"index": "0"}
|
||||||
|
server_queue.add(expected_message)
|
||||||
|
received_message = next(client_queue)
|
||||||
|
self.assertEqual(expected_message, received_message)
|
||||||
|
|
||||||
|
expected_message = {"index": "1"}
|
||||||
|
client_queue.close()
|
||||||
|
server_queue.add(expected_message)
|
||||||
|
self.assertEqual(threading.active_count(), 2)
|
||||||
|
client_queue = message_queue.MessageQueue(
|
||||||
|
host="localhost", port=port, side=message_queue.Side.CLIENT
|
||||||
|
)
|
||||||
|
received_message = next(client_queue)
|
||||||
|
self.assertEqual(expected_message, received_message)
|
||||||
|
|
||||||
|
server_queue.close()
|
||||||
|
client_queue.close()
|
||||||
|
self.assertEqual(threading.active_count(), 1)
|
||||||
|
|
||||||
|
def test_disconnect_reconnect_server(self):
|
||||||
|
"""Test that the client can send messages while the server is offline"""
|
||||||
|
port = random.randint(10000, 65535)
|
||||||
|
server_queue = message_queue.MessageQueue(
|
||||||
|
host="localhost", port=port, side=message_queue.Side.SERVER
|
||||||
|
)
|
||||||
|
client_queue = message_queue.MessageQueue(
|
||||||
|
host="localhost", port=port, side=message_queue.Side.CLIENT
|
||||||
|
)
|
||||||
|
# Worker threads + main thread
|
||||||
|
self.assertEqual(threading.active_count(), 3)
|
||||||
|
|
||||||
|
expected_message = {"index": "0"}
|
||||||
|
client_queue.add(expected_message)
|
||||||
|
received_message = next(server_queue)
|
||||||
|
self.assertEqual(expected_message, received_message)
|
||||||
|
|
||||||
|
expected_message = {"index": "1"}
|
||||||
|
server_queue.close()
|
||||||
|
client_queue.add(expected_message)
|
||||||
|
self.assertEqual(threading.active_count(), 2)
|
||||||
|
server_queue = message_queue.MessageQueue(
|
||||||
|
host="localhost", port=port, side=message_queue.Side.SERVER
|
||||||
|
)
|
||||||
|
received_message = next(server_queue)
|
||||||
|
self.assertEqual(expected_message, received_message)
|
||||||
|
|
||||||
|
server_queue.close()
|
||||||
|
client_queue.close()
|
||||||
|
self.assertEqual(threading.active_count(), 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main(verbosity=2)
|
||||||
|
|
15
service.py
15
service.py
|
@ -20,6 +20,7 @@ LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
USER_RE = re.compile(r"(?<=\@).*(?=\:)")
|
USER_RE = re.compile(r"(?<=\@).*(?=\:)")
|
||||||
|
|
||||||
|
global_msg_queue = None
|
||||||
app = flask.Flask(__name__)
|
app = flask.Flask(__name__)
|
||||||
roomsync = set()
|
roomsync = set()
|
||||||
|
|
||||||
|
@ -42,7 +43,13 @@ def on_receive_events(transaction):
|
||||||
m_user = USER_RE.search(event["user_id"]).group(0)
|
m_user = USER_RE.search(event["user_id"]).group(0)
|
||||||
m_cont = event["content"]["body"]
|
m_cont = event["content"]["body"]
|
||||||
m_user, m_cont
|
m_user, m_cont
|
||||||
# minecraft.msglist.insert(0, "/tellraw @a {\"text\":\"<" + m_user + "> " + m_cont + "\",\"insertion\":\"/tellraw @p %s\"}")
|
if global_msg_queue is not None:
|
||||||
|
global_msg_queue.add({
|
||||||
|
"command",
|
||||||
|
'/tellraw @a {"text":"<{}> {}","insertion":"/tellraw @p %s"}'.format(
|
||||||
|
m_user, m_cont
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
return flask.jsonify({})
|
return flask.jsonify({})
|
||||||
|
|
||||||
|
@ -216,7 +223,7 @@ def main():
|
||||||
appservice_token=args.appservice_token,
|
appservice_token=args.appservice_token,
|
||||||
matrix_server_name=args.matrix_server_name,
|
matrix_server_name=args.matrix_server_name,
|
||||||
)
|
)
|
||||||
queue = message_queue.MessageQueue(
|
global_msg_queue = message_queue.MessageQueue(
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=args.minecraft_wrapper_port,
|
port=args.minecraft_wrapper_port,
|
||||||
side=message_queue.Side.SERVER,
|
side=message_queue.Side.SERVER,
|
||||||
|
@ -225,14 +232,14 @@ def main():
|
||||||
target=app.run, kwargs={"port": args.matrix_api_port}, daemon=True,
|
target=app.run, kwargs={"port": args.matrix_api_port}, daemon=True,
|
||||||
)
|
)
|
||||||
receive_worker = threading.Thread(
|
receive_worker = threading.Thread(
|
||||||
target=receive_messages, args=(appservice, queue), daemon=True,
|
target=receive_messages, args=(appservice, global_msg_queue), daemon=True,
|
||||||
)
|
)
|
||||||
flask_thread.start()
|
flask_thread.start()
|
||||||
receive_worker.start()
|
receive_worker.start()
|
||||||
LOG.info("All threads created")
|
LOG.info("All threads created")
|
||||||
receive_worker.join()
|
receive_worker.join()
|
||||||
flask_thread.join()
|
flask_thread.join()
|
||||||
queue.close()
|
global_msg_queue.close()
|
||||||
LOG.info("All threads terminated")
|
LOG.info("All threads terminated")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ LOG = logging.getLogger(__name__)
|
||||||
class ProcessWrapper:
|
class ProcessWrapper:
|
||||||
"""Iterator that spawns a process and yields lines from its stdout."""
|
"""Iterator that spawns a process and yields lines from its stdout."""
|
||||||
|
|
||||||
def __init__(self, command: List[str], queue: message_queue.MessageQueue):
|
def __init__(self, command: List[str]):
|
||||||
self.proc = subprocess.Popen(
|
self.proc = subprocess.Popen(
|
||||||
" ".join(command),
|
" ".join(command),
|
||||||
shell=True,
|
shell=True,
|
||||||
|
@ -48,7 +48,7 @@ def send_process_output(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
msg_queue.add(
|
msg_queue.add(
|
||||||
{"user": result.group(3), "msg": result.group(4).rstrip("\n"),},
|
{"user": result.group(3), "msg": result.group(4).rstrip("\n")},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ def relay_queue_input(
|
||||||
def main():
|
def main():
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--matrix_server")
|
parser.add_argument("--matrix_server", required=True)
|
||||||
parser.add_argument("--matrix_server_port", type=int, default=5001)
|
parser.add_argument("--matrix_server_port", type=int, default=5001)
|
||||||
parser.add_argument("command", nargs=argparse.REMAINDER)
|
parser.add_argument("command", nargs=argparse.REMAINDER)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue