From 0a785cc5e0fb2ff56a0efd10274995c1bd6872f6 Mon Sep 17 00:00:00 2001
From: James Kirk <james.kirk@noc.ac.uk>
Date: Thu, 19 Jan 2023 17:04:47 +0000
Subject: [PATCH] fix: broadcast had a bug that definitely caused memory leaks
 refactor: spot of tidying up notes fix: subscribe wasnt handling closing
 connections

---
 rmq.py | 38 +++++++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 15 deletions(-)

diff --git a/rmq.py b/rmq.py
index 90cbda6..3008d0f 100644
--- a/rmq.py
+++ b/rmq.py
@@ -2,7 +2,7 @@ import json
 
 import pika
 
-host='localhost'
+host='localhost' # TODO Handle host being passed in
 
 # -------------------------------------------------------------------------------------------------------------------------------------------------------------
 
@@ -12,7 +12,7 @@ def pika_connect(host):
     return connection, channel
 
 
-def setup_queue(channel, queue_name=''): # TODO: Decide if this is too little or is ok. Maybe on setting up exchanges I need to expand this out
+def setup_queue(channel, queue_name=''):
     channel.queue_declare(queue=queue_name, exclusive=False, durable=True) # exclusive means the queue can only be used by the connection that created it
 
 
@@ -33,15 +33,13 @@ def topic_exchange(channel, exchange_name, topic=None, queue_name=None):
         channel.queue_bind(exchange=exchange_name, queue=queue_name, routing_key=topic)
 
 
-def deliver_to_exchange(channel, body, exchange_name, topic=None): # TODO: Definitely verify I can use channel/'ch' like this in the callback
-    # connection, channel = pika_connect(host=host)
+def deliver_to_exchange(channel, body, exchange_name, topic=None):
     if topic is None:
         fanout_exchange(channel=channel, exchange_name=exchange_name)
         channel.basic_publish(exchange=exchange_name, routing_key='', body=body)
     else:
         topic_exchange(channel=channel, exchange_name=exchange_name, topic=topic)
         channel.basic_publish(exchange=exchange_name, routing_key=topic, body=body)
-    # connection.close()
 
 # -------------------------------------------------------------------------------------------------------------------------------------------------------------
 
@@ -51,7 +49,6 @@ def write_to_queue(queue_name, msg):
     setup_queue(channel=channel, queue_name=queue_name)
 
     channel.basic_publish(exchange='', routing_key=queue_name, body=msg)
-    
     connection.close()
 
 
@@ -80,14 +77,17 @@ def broadcast(queue_name, exchange_name):
     # read from a queue, forward onto a 'fanout' exchange
     _, channel = pika_connect(host=host)
 
-    fanout_exchange(channel=channel, exchange_name=exchange_name)
+    setup_queue(channel=channel, queue_name=queue_name)
 
     def broadcast_callback(ch, method, properties, body):
         deliver_to_exchange(channel=ch, body=body, exchange_name=exchange_name)
         ch.basic_ack(delivery_tag=method.delivery_tag)
 
-    channel.basic_consume(queue=queue_name, on_message_callback=broadcast_callback)
-    channel.start_consuming()
+    try:
+        channel.basic_consume(queue=queue_name, on_message_callback=broadcast_callback)
+        channel.start_consuming()
+    except pika.exceptions.AMQPChannelError as err:
+        print("Caught a channel error: {}, stopping...".format(err))
 
 
 def forward(queue_name_one, queue_name_two):
@@ -98,11 +98,14 @@ def forward(queue_name_one, queue_name_two):
     setup_queue(channel=channel, queue_name=queue_name_two)
 
     def forward_callback(ch, method, properties, body):
-        write_to_queue(queue_name=queue_name_two, msg=body)
+        channel.basic_publish(exchange='', routing_key=queue_name_two, body=body)
         ch.basic_ack(delivery_tag=method.delivery_tag)
 
-    channel.basic_consume(queue=queue_name_one, on_message_callback=forward_callback)
-    channel.start_consuming()
+    try:
+        channel.basic_consume(queue=queue_name_one, on_message_callback=forward_callback)
+        channel.start_consuming()
+    except pika.exceptions.AMQPChannelError as err:
+        print("Caught a channel error: {}, stopping...".format(err))
 
 
 def publish(queue_name, exchange_name):
@@ -117,20 +120,25 @@ def publish(queue_name, exchange_name):
         deliver_to_exchange(channel=ch, body=body, exchange_name=exchange_name, topic=topic)
         ch.basic_ack(delivery_tag=method.delivery_tag)
 
-    channel.basic_consume(queue=queue_name, on_message_callback=publish_callback)
-    channel.start_consuming()
+    try:
+        channel.basic_consume(queue=queue_name, on_message_callback=publish_callback)
+        channel.start_consuming()
+    except pika.exceptions.AMQPChannelError as err:
+        print("Caught a channel error: {}, stopping...".format(err))
 
 
 def subscribe(queue_name, exchange_name, topic=None):
     # setup bindings between queue and exchange, 
     # exchange_type is either 'fanout' or 'topic' based on if the topic arg is passed
-    _, channel = pika_connect(host=host)
+    connection, channel = pika_connect(host=host)
 
     if topic is None:
         fanout_exchange(channel=channel, queue_name=queue_name, exchange_name=exchange_name)
     else:
         topic_exchange(channel=channel, queue_name=queue_name, exchange_name=exchange_name, topic=topic)
 
+    connection.close()
+
 
 def listen(queue_name, callback):
     # subscribe client to a queue, using the callback arg
-- 
GitLab