Commit 94956fd7 authored by James Kirk's avatar James Kirk
Browse files

refactor: changed to an async connection method, sharing a channel as much as we can

2 merge requests!26Resolve "Release v1.0.0",!25refactor: changed to an async connection method, sharing a channel as much as we can
Pipeline #133687 failed with stages
in 23 seconds
File added
File added
......@@ -11,28 +11,6 @@ host = os.getenv(
# -------------------------------------------------------------------------------------------------------------------------------------------------------------
def pika_connect(host):
try:
connection = pika.BlockingConnection(pika.ConnectionParameters(host))
except Exception:
connection = None
if connection is not None:
channel = connection.channel()
else:
logging.error(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
raise Exception(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
return connection, channel
def setup_queue(channel, queue_name=""):
channel.queue_declare(
queue=queue_name, exclusive=False, durable=True
......@@ -75,6 +53,25 @@ def deliver_to_exchange(channel, body, exchange_name, topic=None):
# -------------------------------------------------------------------------------------------------------------------------------------------------------------
def pika_connect(host):
try:
connection = pika.BlockingConnection(pika.ConnectionParameters(host))
except Exception:
connection = None
if connection is not None:
channel = connection.channel()
else:
logging.error(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
raise Exception(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
return connection, channel
def write_to_queue(queue_name, msg):
......@@ -122,10 +119,8 @@ def read_from_queue(queue_name, max_msgs):
return messages
def broadcast(queue_name, exchange_name):
def broadcast(channel, queue_name, exchange_name):
# read from a queue, forward onto a 'fanout' exchange
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
def broadcast_callback(ch, method, properties, body):
......@@ -136,15 +131,13 @@ def broadcast(queue_name, exchange_name):
try:
channel.basic_consume(queue=queue_name, on_message_callback=broadcast_callback)
channel.start_consuming()
# channel.start_consuming()
except pika.exceptions.AMQPChannelError as err:
print("Caught a channel error: {}, stopping...".format(err))
def forward(from_queue, to_queue):
def forward(channel, from_queue, to_queue):
# read from a queue, forward onto a different queue
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=from_queue)
setup_queue(channel=channel, queue_name=to_queue)
......@@ -163,15 +156,13 @@ def forward(from_queue, to_queue):
try:
channel.basic_consume(queue=from_queue, on_message_callback=forward_callback)
channel.start_consuming()
# channel.start_consuming()
except pika.exceptions.AMQPChannelError as err:
logging.error("Caught a channel error: {}, stopping...".format(err))
def publish(queue_name, exchange_name):
def publish(channel, queue_name, exchange_name):
# read from a queue, forward onto a 'topic' exchange
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
def publish_callback(ch, method, properties, body):
......@@ -188,18 +179,17 @@ def publish(queue_name, exchange_name):
try:
channel.basic_consume(queue=queue_name, on_message_callback=publish_callback)
channel.start_consuming()
# channel.start_consuming()
except pika.exceptions.AMQPChannelError as err:
print("Caught a channel error: {}, stopping...".format(err))
def subscribe(queue_name, exchange_name, topic=None):
def subscribe(channel, queue_name, exchange_name, topic=None):
logging.debug(
f"Subscribe queue: {queue_name} to {exchange_name} with topic {topic}"
)
# setup bindings between queue and exchange,
# exchange_type is either 'fanout' or 'topic' based on if the topic arg is passed
connection, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
if topic is None:
......@@ -209,22 +199,17 @@ def subscribe(queue_name, exchange_name, topic=None):
topic_exchange(channel=channel, exchange_name=exchange_name)
channel.queue_bind(exchange=exchange_name, queue=queue_name, routing_key=topic)
connection.close()
def listen(queue_name, callback):
def listen(channel, queue_name, callback):
logging.debug(f"Listen to queue: {queue_name}")
# subscribe client to a queue, using the callback arg
_, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
channel.basic_consume(queue=queue_name, on_message_callback=callback)
channel.start_consuming()
# channel.start_consuming()
def get_queue_status(queue_name):
_, channel = pika_connect(host=host)
def get_queue_status(channel, queue_name):
response = channel.queue_declare(queue=queue_name, passive=True)
queue_status = {
"size": response.method.message_count,
......@@ -234,8 +219,7 @@ def get_queue_status(queue_name):
return queue_status
def empty_queue(queue_name):
_, channel = pika_connect(host=host)
def empty_queue(channel, queue_name):
emptied = True
try:
channel.queue_purge(queue_name)
......
......@@ -21,7 +21,7 @@ from rmq import broadcast, forward, publish, subscribe
from models.client_model import ClientModel
from logger import setup_logging
import pika
setup_logging()
......@@ -100,65 +100,70 @@ def stop_client(client_id):
return stopping
def run_client(client_id, executor):
def run_client(client_id, channel):
global RUNNING_CLIENTS
client_model = ClientModel()
client = client_model.find(client_id)
running = False
try:
client_id = client["client_id"]
logging.info(f"Running client: {client_id}")
# forward
thread = executor.submit(forward, f"{client_id}-outbox", "soar-publish")
THREADS[f"{client_id}-outbox"] = thread
# broadcast
thread = executor.submit(
broadcast, f"{client_id}-broadcast", EXCHANGES.get("broadcast")
forward(
channel,
f"{client_id}-outbox",
"soar-publish"
)
broadcast(
channel,
f"{client_id}-broadcast",
EXCHANGES.get("broadcast")
)
THREADS[f"{client_id}-broadcast"] = thread
subscribe(
channel,
f"{client_id}-inbox",
EXCHANGES.get("publish"),
client["subscription"],
)
subscribe(f"{client_id}-inbox", EXCHANGES.get("broadcast"))
if client_id not in RUNNING_CLIENTS:
RUNNING_CLIENTS.append(client_id)
running = True
subscribe(
channel,
f"{client_id}-inbox",
EXCHANGES.get("broadcast")
)
except Exception as error:
logging.error(str(error))
return running
def main(clients, executor):
def run_bus(clients, channel):
global RUNNING_CLIENTS
logging.info("Starting SOAR bus...")
# publish
thread = executor.submit(publish, "soar-publish", EXCHANGES.get("publish"))
THREADS["soar-publish"] = thread
publish(
channel,
"soar-publish",
EXCHANGES.get("publish")
)
for id in clients.keys():
run_client(id, executor)
run_client(id, channel)
# Global vars are not shared across threads so you
# have to pass the global var into the thread
thread = executor.submit(watch_config, RUNNING_CLIENTS)
THREADS["config-watcher"] = thread
# thread = executor.submit(watch_config, RUNNING_CLIENTS) # TODO: Sort this out
# THREADS["config-watcher"] = thread
# Make sure the threads are actually running, error if not,
# this allows the SOAR Bus to actually wait for RMQ to start running
for thread_name, thread in THREADS.items():
thread.result()
try:
logging.debug(thread_name)
logging.debug(thread.result())
except Exception as e:
logging.error(e)
raise e
def on_channel_open(channel):
print("in on_channel_open")
# Invoked when the channel is open
client_model = ClientModel()
clients = client_model.get()
client_count = len(clients.keys())
run_bus(clients, channel)
def on_connection_open(connection):
print("in on_open")
# Invoked when the connection is open
connection.channel(on_open_callback=on_channel_open)
if __name__ == "__main__":
pingcounter = 0
......@@ -176,16 +181,12 @@ if __name__ == "__main__":
pingcounter += 1
s.close()
host = os.getenv(
"MQ_HOST", "localhost"
)
connection = pika.SelectConnection(pika.ConnectionParameters(host), on_open_callback=on_connection_open)
try:
client_model = ClientModel()
clients = client_model.get()
client_count = len(clients.keys())
thread_count = (client_count * 2) + 2
logging.debug(f"Running {thread_count} workers for {client_count} clients")
with concurrent.futures.ThreadPoolExecutor(
max_workers=thread_count
) as executor:
main(clients, executor)
connection.ioloop.start()
except KeyboardInterrupt:
executor.shutdown(wait=False)
connection.close()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment