Commit f40d1b8f authored by James Kirk's avatar James Kirk
Browse files

Merge branch 'rmq-cpu-usage-issues' into 'dev'

refactor: changed to an async connection method, sharing a channel as much as we can

Closes #48

See merge request !25
Pipeline #232851 failed with stages
in 28 seconds
File added
File added
......@@ -12,27 +12,6 @@ host = os.getenv(
# -------------------------------------------------------------------------------------------------------------------------------------------------------------
def pika_connect(host):
try:
connection = pika.BlockingConnection(pika.ConnectionParameters(host))
except Exception:
connection = None
if connection is not None:
channel = connection.channel()
else:
logging.error(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
raise Exception(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
return connection, channel
def setup_queue(channel, queue_name=""):
channel.queue_declare(
queue=queue_name, exclusive=False, durable=True
......@@ -75,6 +54,25 @@ def deliver_to_exchange(channel, body, exchange_name, topic=None):
# -------------------------------------------------------------------------------------------------------------------------------------------------------------
def pika_connect(host):
try:
connection = pika.BlockingConnection(pika.ConnectionParameters(host))
except Exception:
connection = None
if connection is not None:
channel = connection.channel()
else:
logging.error(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
raise Exception(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
return connection, channel
def write_to_queue(queue_name, msg):
......@@ -122,10 +120,8 @@ def read_from_queue(queue_name, max_msgs):
return messages
def broadcast(queue_name, exchange_name):
def broadcast(channel, queue_name, exchange_name):
# read from a queue, forward onto a 'fanout' exchange
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
def broadcast_callback(ch, method, properties, body):
......@@ -135,16 +131,15 @@ def broadcast(queue_name, exchange_name):
ch.basic_ack(delivery_tag=method.delivery_tag)
try:
channel.basic_consume(queue=queue_name, on_message_callback=broadcast_callback)
channel.start_consuming()
return channel.basic_consume(
queue=queue_name, on_message_callback=broadcast_callback
)
except pika.exceptions.AMQPChannelError as err:
print("Caught a channel error: {}, stopping...".format(err))
def forward(from_queue, to_queue):
def forward(channel, from_queue, to_queue):
# read from a queue, forward onto a different queue
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=from_queue)
setup_queue(channel=channel, queue_name=to_queue)
......@@ -162,16 +157,15 @@ def forward(from_queue, to_queue):
ch.basic_ack(delivery_tag=method.delivery_tag)
try:
channel.basic_consume(queue=from_queue, on_message_callback=forward_callback)
channel.start_consuming()
return channel.basic_consume(
queue=from_queue, on_message_callback=forward_callback
)
except pika.exceptions.AMQPChannelError as err:
logging.error("Caught a channel error: {}, stopping...".format(err))
def publish(queue_name, exchange_name):
def publish(channel, queue_name, exchange_name):
# read from a queue, forward onto a 'topic' exchange
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
def publish_callback(ch, method, properties, body):
......@@ -188,18 +182,16 @@ def publish(queue_name, exchange_name):
try:
channel.basic_consume(queue=queue_name, on_message_callback=publish_callback)
channel.start_consuming()
except pika.exceptions.AMQPChannelError as err:
print("Caught a channel error: {}, stopping...".format(err))
def subscribe(queue_name, exchange_name, topic=None):
def subscribe(channel, queue_name, exchange_name, topic=None):
logging.debug(
f"Subscribe queue: {queue_name} to {exchange_name} with topic {topic}"
)
# setup bindings between queue and exchange,
# exchange_type is either 'fanout' or 'topic' based on if the topic arg is passed
connection, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
if topic is None:
......@@ -209,36 +201,34 @@ def subscribe(queue_name, exchange_name, topic=None):
topic_exchange(channel=channel, exchange_name=exchange_name)
channel.queue_bind(exchange=exchange_name, queue=queue_name, routing_key=topic)
connection.close()
def listen(queue_name, callback):
def listen(channel, queue_name, callback):
logging.debug(f"Listen to queue: {queue_name}")
# subscribe client to a queue, using the callback arg
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
channel.basic_consume(queue=queue_name, on_message_callback=callback)
channel.start_consuming()
def get_queue_status(queue_name):
_, channel = pika_connect(host=host)
connection, channel = pika_connect(host=host)
response = channel.queue_declare(queue=queue_name, passive=True)
queue_status = {
"size": response.method.message_count,
"has_consumer": response.method.consumer_count > 0,
}
logging.debug(f"Queue: {queue_name} contains {queue_status['size']} messages")
connection.close()
return queue_status
def empty_queue(queue_name):
_, channel = pika_connect(host=host)
connection, channel = pika_connect(host=host)
emptied = True
try:
channel.queue_purge(queue_name)
except ValueError:
emptied = False
connection.close()
return emptied
......@@ -11,22 +11,21 @@
import logging
import os
import socket
import threading
import time
import concurrent.futures
from watchdog.observers import Observer
import pika
from watchdog.events import FileSystemEventHandler
from rmq import broadcast, forward, publish, subscribe
from models.client_model import ClientModel
from watchdog.observers import Observer
from logger import setup_logging
from models.client_model import ClientModel
from rmq import broadcast, forward, publish, subscribe
setup_logging()
THREADS = {}
RUNNING_CLIENTS = []
CONSUMER_TAGS = {}
EXCHANGES = {
"publish": "soar_publish",
"broadcast": "soar_broadcast",
......@@ -34,8 +33,9 @@ EXCHANGES = {
class ConfigHandler(FileSystemEventHandler):
def __init__(self):
def __init__(self, channel):
self.client_model = ClientModel()
self.channel = channel
super().__init__()
def on_modified(self, event):
......@@ -44,120 +44,148 @@ class ConfigHandler(FileSystemEventHandler):
logging.debug("Reloading client config...")
clients = self.client_model.get()
updated_client_ids = list(clients.keys())
update_clients(updated_client_ids)
self.update_clients(updated_client_ids)
def update_clients(updated_client_ids):
global RUNNING_CLIENTS
with concurrent.futures.ThreadPoolExecutor() as executor:
def update_clients(self, updated_client_ids):
global RUNNING_CLIENTS
logging.debug("Old: " + str(RUNNING_CLIENTS))
logging.debug("New: " + str(updated_client_ids))
for client_id in updated_client_ids:
if client_id not in RUNNING_CLIENTS:
run_client(client_id, executor)
run_client(client_id, self.channel)
logging.info(f"Started client: {client_id}")
for client_id in RUNNING_CLIENTS:
if client_id not in updated_client_ids:
stop_client(client_id)
self.stop_client(client_id)
logging.info(f"Shutdown client: {client_id}")
def watch_config(running_clients):
# Set global RUNNING_CLIENTS inside thread
global RUNNING_CLIENTS
RUNNING_CLIENTS = running_clients
logging.info("Starting config watcher...")
event_handler = ConfigHandler()
observer = Observer()
observer.schedule(event_handler, path="./data", recursive=False)
observer.start()
while True:
def stop_client(self, client_id):
global RUNNING_CLIENTS
global CONSUMER_TAGS
stopping = False
try:
pass
except KeyboardInterrupt as interrupt:
observer.stop()
raise interrupt
def stop_client(client_id):
global RUNNING_CLIENTS
stopping = False
try:
logging.info(f"Stopping client: {client_id}")
client_threads = ["outbox", "broadcast", "inbox-published", "inbox-broadcast"]
for thread in client_threads:
thread_name = f"{client_id}-{thread}"
if thread_name in THREADS:
THREADS[thread_name].cancel()
if client_id in RUNNING_CLIENTS:
RUNNING_CLIENTS.remove(client_id)
stopping = True
except Exception as error:
logging.error(str(error))
return stopping
def run_client(client_id, executor):
logging.info(f"Stopping client: {client_id}")
client_tags = CONSUMER_TAGS[client_id]
self.channel.basic_cancel(client_tags[f"{client_id}-broadcast"])
self.channel.basic_cancel(client_tags[f"{client_id}-outbox"])
self.channel.queue_unbind(
queue=f"{client_id}-inbox",
exchange=client_tags[f"{client_id}-inbox-publish"],
)
self.channel.queue_unbind(
queue=f"{client_id}-inbox",
exchange=client_tags[f"{client_id}-inbox-broadcast"],
)
if client_id in RUNNING_CLIENTS:
RUNNING_CLIENTS.remove(client_id)
stopping = True
except Exception as error:
logging.error(str(error))
return stopping
class WatchConfigThread(threading.Thread):
def __init__(self, running_clients, channel):
threading.Thread.__init__(self)
self.daemon = True
self.running_clients = running_clients
self.channel = channel
self.start()
def run(self):
logging.info("Starting config watcher...")
event_handler = ConfigHandler(self.channel)
observer = Observer()
observer.schedule(event_handler, path="./data", recursive=False)
observer.start()
while True:
try:
time.sleep(1)
except KeyboardInterrupt as interrupt:
observer.stop()
raise interrupt
class SoarBusThread(threading.Thread):
def __init__(self, clients, channel):
threading.Thread.__init__(self)
self.daemon = True
self.clients = clients
self.channel = channel
self.start()
def run(self):
logging.info("Starting SOAR bus...")
publish(self.channel, "soar-publish", EXCHANGES.get("publish"))
for id in self.clients.keys():
run_client(id, self.channel)
def run_client(client_id, channel):
global RUNNING_CLIENTS
global CONSUMER_TAGS
client_model = ClientModel()
client = client_model.find(client_id)
running = False
try:
client_id = client["client_id"]
logging.info(f"Running client: {client_id}")
# forward
thread = executor.submit(forward, f"{client_id}-outbox", "soar-publish")
THREADS[f"{client_id}-outbox"] = thread
# broadcast
thread = executor.submit(
broadcast, f"{client_id}-broadcast", EXCHANGES.get("broadcast")
forward_consumer_tag = forward(channel, f"{client_id}-outbox", "soar-publish")
broadcast_consumer_tag = broadcast(
channel, f"{client_id}-broadcast", EXCHANGES.get("broadcast")
)
THREADS[f"{client_id}-broadcast"] = thread
subscribe(
channel,
f"{client_id}-inbox",
EXCHANGES.get("publish"),
client["subscription"],
)
subscribe(f"{client_id}-inbox", EXCHANGES.get("broadcast"))
subscribe(channel, f"{client_id}-inbox", EXCHANGES.get("broadcast"))
CONSUMER_TAGS[client_id] = {
f"{client_id}-broadcast": broadcast_consumer_tag,
f"{client_id}-outbox": forward_consumer_tag,
f"{client_id}-inbox-publish": EXCHANGES.get("publish"),
f"{client_id}-inbox-broadcast": EXCHANGES.get("broadcast"),
}
if client_id not in RUNNING_CLIENTS:
logging.debug(f"Appending client_id '{client_id}'")
RUNNING_CLIENTS.append(client_id)
running = True
except Exception as error:
logging.error(str(error))
return running
def main(clients, executor):
def on_channel_open(channel):
# Invoked when the channel is open
global RUNNING_CLIENTS
logging.info("Starting SOAR bus...")
# publish
thread = executor.submit(publish, "soar-publish", EXCHANGES.get("publish"))
THREADS["soar-publish"] = thread
client_model = ClientModel()
clients = client_model.get()
client_count = len(clients.keys())
logging.debug(f"Running {client_count} clients")
for id in clients.keys():
run_client(id, executor)
SoarBusThread(clients, channel)
WatchConfigThread(RUNNING_CLIENTS, channel)
# Global vars are not shared across threads so you
# have to pass the global var into the thread
thread = executor.submit(watch_config, RUNNING_CLIENTS)
THREADS["config-watcher"] = thread
# Make sure the threads are actually running, error if not,
# this allows the SOAR Bus to actually wait for RMQ to start running
for thread_name, thread in THREADS.items():
thread.result()
try:
logging.debug(thread_name)
logging.debug(thread.result())
except Exception as e:
logging.error(e)
raise e
def on_connection_open(connection):
# Invoked when the connection is open
connection.channel(on_open_callback=on_channel_open)
def on_connection_close(connection, exception):
# Invoked when the connection is closed
connection.ioloop.stop()
if __name__ == "__main__":
......@@ -176,16 +204,16 @@ if __name__ == "__main__":
pingcounter += 1
s.close()
host = os.getenv("MQ_HOST", "localhost")
connection = pika.SelectConnection(
pika.ConnectionParameters(host),
on_open_callback=on_connection_open,
on_close_callback=on_connection_close,
)
try:
client_model = ClientModel()
clients = client_model.get()
client_count = len(clients.keys())
thread_count = (client_count * 2) + 2
logging.debug(f"Running {thread_count} workers for {client_count} clients")
with concurrent.futures.ThreadPoolExecutor(
max_workers=thread_count
) as executor:
main(clients, executor)
connection.ioloop.start()
except KeyboardInterrupt:
executor.shutdown(wait=False)
# Loop until fully closed
connection.close()
connection.ioloop.start()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment