diff --git a/README.md b/README.md index fdd1d8c6eabedc62f68e586631a22caf790b6824..58a69a395738b9b89b21118fa802ad5fe8dcd3aa 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,12 @@ pipenv run python api.py #### Create some clients -`POST` to `http://localhost:3000/clients` +`POST` to `http://localhost:3000/clients`. Example: +``` +{ + "client_id": "qwert123456789" +} +``` #### Event bus diff --git a/__init__.py b/__init__.py index e711e1f5a851354c94542a1be64b559569703507..ea46bbf6981559cc4813a8cbb195d81fd6cb745e 100644 --- a/__init__.py +++ b/__init__.py @@ -8,5 +8,5 @@ __all__ = [ # "workers", "fixtures", "tests", - "run", + # "run", ] diff --git a/api/__init__.py b/api/__init__.py index 9c8c4cedf9fb986f883c6ba2ce8912d1f59c3e13..2253a07952dc605ffa662b625265c4b68b841e53 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -5,4 +5,4 @@ __all__ = [ os.path.splitext(os.path.basename(x))[0] for x in os.listdir(os.path.dirname(__file__)) if x.endswith(".py") and x != "__init__.py" -] \ No newline at end of file +] diff --git a/api/auth_resource.py b/api/auth_resource.py index 5ebd2019ecabb8d29407de80b42a7f8a35c59176..6be622fe0bc15dabd29b63712c658223d85a9eab 100644 --- a/api/auth_resource.py +++ b/api/auth_resource.py @@ -2,24 +2,24 @@ import json from flask_restful import Resource, abort from models.token import TokenModel -class AuthResource(Resource): - def __init__(self): +class AuthResource(Resource): + def __init__(self): self.token = TokenModel() with open("clients.json", "r") as clients_file: self.clients = json.load(clients_file) - def auth(self, request): + def auth(self, request): allow = False - auth = request.headers.get('Authorization', False) + auth = request.headers.get("Authorization", False) if auth: - token = auth.split(' ').pop() + token = auth.split(" ").pop() parsed = self.token.validate(token) - if parsed['valid']: - client = self.clients.get(parsed['client_id']) - if client: + if parsed["valid"]: + client = self.clients.get(parsed["client_id"]) + if client: self.client = client allow = True if not allow: abort(403, message="Invalid token") - return allow \ No newline at end of file + return allow diff --git a/api/clients.py b/api/clients.py index a079e934d451dd14a85251c4f94d6d9faa669058..a70bd5033f5a5f47ef5d9a2234e2efa19f80469e 100644 --- a/api/clients.py +++ b/api/clients.py @@ -8,12 +8,12 @@ import string class ClientSchema(Schema): client_id = fields.Str(required=True) - client_name = fields.Str(required=True) + client_name = fields.Str(required=True) subscription = fields.Str(required=True) class ClientsFile: - file = "clients.json" + file = "fixtures.clients.json" mtime = 0 clients = {} parser = None @@ -27,7 +27,7 @@ class ClientsFile: if mtime > self.mtime: with open(self.file, "r") as client_file: self.clients = json.load(client_file) - except FileNotFoundError as error: + except FileNotFoundError: self.clients = {} self.save() @@ -42,7 +42,7 @@ class ClientsFile: return client def add(self, client): - client['secret'] = self.secret() + client["secret"] = self.secret() self.clients[client["client_id"]] = client self.save() return client @@ -70,28 +70,36 @@ class ClientsFile: def secret(self, chars=36): res = "".join( random.choices( - string.ascii_lowercase + string.ascii_uppercase + string.digits, k=chars + ( + string.ascii_lowercase, + +string.ascii_uppercase, + +string.digits, + ), + k=chars, ) ) return str(res) + clients_file = ClientsFile() + # Client class Client(Resource): clients_file = None - def __init__(self): + + def __init__(self): self.schema = ClientSchema() self.clients_file = ClientsFile() def get(self, client_id): client = self.clients_file.find(client_id) - del client['secret'] + del client["secret"] if not client: abort(404, message="No client with id: {}".format(client_id)) return client - def delete(self, todo_id): + def delete(self, client_id): client = self.clients_file.find(client_id) if not client: abort(404, message="No client with id: {}".format(client_id)) @@ -102,9 +110,9 @@ class Client(Resource): def put(self, client_id): args = request.get_json() errors = self.schema.validate(args) - if errors: + if errors: abort(400, message=str(errors)) - + client = self.clients_file.find(client_id) if not client: abort(404, message="No client with id: {}".format(client_id)) @@ -115,7 +123,7 @@ class Client(Resource): # ClientList class ClientList(Resource): - def __init__(self): + def __init__(self): self.schema = ClientSchema() self.clients_file = ClientsFile() @@ -127,13 +135,13 @@ class ClientList(Resource): def post(self): args = request.get_json() - + print("Args is %s" % args) errors = self.schema.validate(args) - if errors: + if errors: abort(400, message=str(errors)) - - client = clients_file.find(args["client_id"]) - if client: + print("errors is %s" % errors) + client_id = clients_file.find(args["client_id"]) + if client_id: abort(403, message="Duplicate client id: {}".format(client_id)) else: client = clients_file.add(args) diff --git a/api/notify.py b/api/notify.py index 2fdd20419b8bc7b9dc69a41ff13d0d96532acf55..8b0d3b8d95f3fc54c27d326ce423452e3f75dafa 100644 --- a/api/notify.py +++ b/api/notify.py @@ -8,6 +8,7 @@ from api.auth_resource import AuthResource class NotifySchema(Schema): body = fields.Str(required=True) + class Notify(AuthResource): clients = None schema = None @@ -15,7 +16,7 @@ class Notify(AuthResource): def __init__(self): super().__init__() self.schema = NotifySchema() - + def post(self): args = request.get_json() errors = self.schema.validate(args) @@ -25,16 +26,20 @@ class Notify(AuthResource): allow = False body = args.get("body") message = { - 'topic': 'broadcast', - 'message': body, + "topic": "broadcast", + "message": body, } allow = self.auth(request) - + if allow: - notify_queue = self.client['client_id'] + "-broadcast" - connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost")) + notify_queue = self.client["client_id"] + "-broadcast" + connection = pika.BlockingConnection( + pika.ConnectionParameters(host="localhost") + ) channel = connection.channel() channel.queue_declare(queue=notify_queue, durable=True) - channel.basic_publish(exchange="", routing_key=notify_queue, body=json.dumps(message)) + channel.basic_publish( + exchange="", routing_key=notify_queue, body=json.dumps(message) + ) connection.close() diff --git a/api/receive.py b/api/receive.py index be31ae7c6a9cfacfde32966c8ec823fbf2b8cc10..cb7e9a2362c5a7e0ae256fce56e5f98573268c08 100644 --- a/api/receive.py +++ b/api/receive.py @@ -2,7 +2,8 @@ from flask_restful import request, abort from marshmallow import Schema, fields import pika import json -from models.token import TokenModel + +# from models.token import TokenModel from api.auth_resource import AuthResource @@ -25,11 +26,11 @@ class Receive(AuthResource): messages = [] max_messages = request.args.get("max_messages", 10) - + allow = self.auth(request) if allow: - inbox_queue = self.client['client_id'] + "-inbox" - + inbox_queue = self.client["client_id"] + "-inbox" + if allow: connection = pika.BlockingConnection( pika.ConnectionParameters(host="localhost") @@ -37,7 +38,9 @@ class Receive(AuthResource): channel = connection.channel() channel.queue_declare(queue=inbox_queue, durable=True) while len(messages) < max_messages: - method_frame, header_frame, body = channel.basic_get(inbox_queue) + method_frame, header_frame, body = channel.basic_get( + inbox_queue, + ) if method_frame: print(method_frame, header_frame, body) channel.basic_ack(method_frame.delivery_tag) diff --git a/api/send.py b/api/send.py index 57b6647fb5ef79e28446d1a5635dd0c751257559..b25588bc58c76910df37e289b11e6f40b2d1877c 100644 --- a/api/send.py +++ b/api/send.py @@ -4,10 +4,12 @@ from marshmallow import Schema, fields import pika from api.auth_resource import AuthResource + class SendSchema(Schema): body = fields.Str(required=True) topic = fields.Str(required=True) + class Send(AuthResource): clients = None schema = None @@ -15,7 +17,7 @@ class Send(AuthResource): def __init__(self): super().__init__() self.schema = SendSchema() - + def post(self): args = request.get_json() errors = self.schema.validate(args) @@ -23,18 +25,22 @@ class Send(AuthResource): abort(400, message=str(errors)) allow = self.auth(request) - + if allow: body = args.get("body") topic = args.get("topic") - outbox_queue = self.client['client_id'] + "-outbox" - - connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost")) + outbox_queue = self.client["client_id"] + "-outbox" + + connection = pika.BlockingConnection( + pika.ConnectionParameters(host="localhost") + ) channel = connection.channel() channel.queue_declare(queue=outbox_queue, durable=True) message = { - 'topic': topic, - 'message': body, + "topic": topic, + "message": body, } - channel.basic_publish(exchange="", routing_key=outbox_queue, body=json.dumps(message)) + channel.basic_publish( + exchange="", routing_key=outbox_queue, body=json.dumps(message) + ) connection.close() diff --git a/api/token.py b/api/token.py index f245abc0e9d9d769ea3a89e21bd327ed5ff22632..51079b82f73739d7a87b7866bc5f84785b27c27a 100644 --- a/api/token.py +++ b/api/token.py @@ -1,5 +1,4 @@ -import json -import pika +import json from flask_restful import Resource, request, abort from marshmallow import Schema, fields from models.token import TokenModel @@ -20,7 +19,7 @@ class Token(Resource): self.model = TokenModel() with open("clients.json", "r") as clients_file: self.clients = json.load(clients_file) - + def get(self): errors = self.schema.validate(request.args) if errors: @@ -28,7 +27,7 @@ class Token(Resource): token = None allow = False - max_messages = request.args.get("max_messages", 10) + # max_messages = request.args.get("max_messages", 10) client_id = request.args.get("client_id") if client_id in self.clients: client = self.clients.get(client_id) @@ -40,4 +39,4 @@ class Token(Resource): else: abort(403, message="Invalid client credentials") - return token \ No newline at end of file + return token diff --git a/controllers/__init__.py b/controllers/__init__.py index 246e02a859d193a0a6dc8c223d0da12c153f9a74..308b2c084ed8d7295aa8474dfd7f2ea49b5c771c 100644 --- a/controllers/__init__.py +++ b/controllers/__init__.py @@ -1,9 +1,7 @@ import os - __all__ = [ os.path.splitext(os.path.basename(x))[0] for x in os.listdir(os.path.dirname(__file__)) if x.endswith(".py") and x != "__init__.py" ] - diff --git a/workers/client_read.py b/controllers/client_read.py similarity index 74% rename from workers/client_read.py rename to controllers/client_read.py index 67c791a9c2ccd3a40f2c19bddc8636f9b6d400ca..ed18e398458589ce7242567a32bf38c3264f3c27 100644 --- a/workers/client_read.py +++ b/controllers/client_read.py @@ -6,7 +6,9 @@ import json def main(): - connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost")) + connection = pika.BlockingConnection( + pika.ConnectionParameters(host="localhost"), + ) channel = connection.channel() queue_name = sys.argv[1] @@ -17,7 +19,11 @@ def main(): message = json.loads(body.decode()) print(" [x] Received %r" % message) - channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=True) + channel.basic_consume( + queue=queue_name, + on_message_callback=callback, + auto_ack=True, + ) print(" [*] Waiting for messages. To exit press CTRL+C") channel.start_consuming() diff --git a/controllers/client_send.py b/controllers/client_send.py index 42197be815668bbf2a21fe3d679e6accf8bfb521..90b4c54296efbccdba03b7c799aeea354c6937a7 100644 --- a/controllers/client_send.py +++ b/controllers/client_send.py @@ -4,7 +4,9 @@ import sys import json -connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost")) +connection = pika.BlockingConnection( + pika.ConnectionParameters(host="localhost"), +) channel = connection.channel() queue_name = sys.argv[1] diff --git a/controllers/core_broadcast.py b/controllers/core_broadcast.py index de5062db99981096a513d372b2acd6dcd0fa28d4..26753cb577055067351304760c1fb893c8d81b23 100644 --- a/controllers/core_broadcast.py +++ b/controllers/core_broadcast.py @@ -21,13 +21,16 @@ def deliver(body, broadcast_exchange): def listen(queue_name, broadcast_exchange): def bcast_callback(ch, method, properties, body): - delivered = deliver(body, broadcast_exchange) + # delivered = deliver(body, broadcast_exchange) ch.basic_ack(delivery_tag=method.delivery_tag) listen_connection = get_connection() listen_channel = listen_connection.channel() listen_channel.queue_declare(queue=queue_name, durable=True) - listen_channel.basic_consume(queue=queue_name, on_message_callback=bcast_callback) + listen_channel.basic_consume( + queue=queue_name, + on_message_callback=bcast_callback, + ) listen_channel.start_consuming() diff --git a/controllers/core_bus.py b/controllers/core_bus.py index d280bae2f148cdc018dac4bc041e82bfa7b1ad6d..5c76276f0acff72488f34595ad0beffb97e61714 100644 --- a/controllers/core_bus.py +++ b/controllers/core_bus.py @@ -4,10 +4,12 @@ # (per client) # [client_id]-inbox - receive subscriptions # [client_id]-outbox - send messages to the bus -# [client-id]-broadcast - send message to all subscribers? - eg notify of downtime +# [client-id]-broadcast - send message to all subscribers? +# - eg notify of downtime # soar-publisher - fan-in from client-outboxes # soar-dlq - undeliverables -# soar-broadcast - admin messages forwarded to all client-inboxes regardless of subscriptions +# soar-broadcast - admin messages forwarded to all client-inboxes +# regardless of subscriptions import concurrent.futures from api.clients import ClientsFile @@ -29,7 +31,11 @@ def main(): with concurrent.futures.ProcessPoolExecutor() as executor: # publish - thread = executor.submit(publish, "soar-publish", EXCHANGES.get("publish")) + thread = executor.submit( + publish, + "soar-publish", + EXCHANGES.get("publish"), + ) THREADS.append(thread) for (id, client) in clients.items(): @@ -51,8 +57,8 @@ def main(): ) THREADS.append(thread) # push - # TODO - add optional webhook target to client and post to webhook target - # if present + # TODO - add optional webhook target to client and + # post to webhook target if present if __name__ == "__main__": diff --git a/workers/core_forward.py b/controllers/core_forward.py similarity index 85% rename from workers/core_forward.py rename to controllers/core_forward.py index 07763f0e01dd8ffef197ee9a2ec53003fec07b8a..839671a7c32a79a39f9a7030d0724955860a13cc 100644 --- a/workers/core_forward.py +++ b/controllers/core_forward.py @@ -1,4 +1,3 @@ - #!/usr/bin/env python import pika @@ -12,13 +11,17 @@ def deliver(body, queue_name): deliver_connection = get_connection() deliver_channel = deliver_connection.channel() deliver_channel.queue_declare(queue=queue_name, durable=True) - deliver_channel.basic_publish(exchange="", routing_key=queue_name, body=body) + deliver_channel.basic_publish( + exchange="", + routing_key=queue_name, + body=body, + ) deliver_connection.close() def listen(from_queue_name, to_queue_name): def fwd_callback(ch, method, properties, body): - delivered = deliver(body, to_queue_name) + # delivered = deliver(body, to_queue_name) ch.basic_ack(delivery_tag=method.delivery_tag) listen_connection = get_connection() diff --git a/controllers/core_publish.py b/controllers/core_publish.py index 52121faae63eead99d410d4534e3888f11842eb0..7677085ade5fb5768cd26e818d61fd85a90d05c2 100644 --- a/controllers/core_publish.py +++ b/controllers/core_publish.py @@ -3,6 +3,8 @@ import pika import json import sys +# from wrappers.rabbit_mq import publish as publish_wrapper + def get_connection(): return pika.BlockingConnection(pika.ConnectionParameters(host="localhost")) @@ -12,7 +14,10 @@ def deliver(body, topic, publish_exchange): print("publish on topic: %s" % topic) deliver_connection = get_connection() deliver_channel = deliver_connection.channel() - deliver_channel.exchange_declare(exchange=publish_exchange, exchange_type="topic") + deliver_channel.exchange_declare( + exchange=publish_exchange, + exchange_type="topic", + ) deliver_channel.basic_publish( exchange=publish_exchange, routing_key=topic, body=body ) @@ -29,7 +34,10 @@ def listen(queue_name, publish_exchange): listen_connection = get_connection() listen_channel = listen_connection.channel() listen_channel.queue_declare(queue=queue_name, durable=True) - listen_channel.basic_consume(queue=queue_name, on_message_callback=pub_callback) + listen_channel.basic_consume( + queue=queue_name, + on_message_callback=pub_callback, + ) listen_channel.start_consuming() diff --git a/controllers/core_subscribe.py b/controllers/core_subscribe.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d4a733e209962c19496c8dac457cc49e8f7aa2 --- /dev/null +++ b/controllers/core_subscribe.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +from wrappers.rabbit_mq import subscribe as subscribe_wrapper + + +def subscribe( + queue_name, + topics_list, + publish_exchange="soar_publish", + broadcast_exchange="soar_broadcast", +): + subscribe_wrapper( + "localhost", # TODO: Update + queue_name, + publish_exchange, + broadcast_exchange, + topics_list, + ) + + +# Shifted to wrappers/ +# def old_subscribe( +# queue_name, +# topic, +# publish_exchange="soar_publish", +# broadcast_exchange="soar_broadcast", +# ): +# adm_connection = get_connection() +# admin_channel = adm_connection.channel() +# admin_channel.exchange_declare( +# exchange=broadcast_exchange, +# exchange_type="fanout" +# ) +# admin_channel.queue_bind(exchange=broadcast_exchange, queue=queue_name) +# sub_connection = get_connection() +# subscriber_channel = sub_connection.channel() +# subscriber_channel.exchange_declare( +# exchange=publish_exchange, exchange_type="topic" +# ) +# subscriber_channel.queue_bind( +# exchange=publish_exchange, queue=queue_name, routing_key=topic +# ) diff --git a/models/token.py b/models/token.py index 6721a9971abe59a20d249f52e5bddbe8d4974acc..522858677eed5c3f9d641884d1f59863e5183e07 100644 --- a/models/token.py +++ b/models/token.py @@ -1,4 +1,4 @@ -from cryptography.fernet import Fernet,InvalidToken +from cryptography.fernet import Fernet, InvalidToken import datetime import os import json @@ -7,84 +7,75 @@ import json TOKENS = {} -class TokenModel(): +class TokenModel: clients = None schema = None key = None fernet = None token_lifetime_hours = None - env_lifetime = 'SOAR_TOKEN_LIFETIME' - env_secret = 'SOAR_TOKEN_SECRET' + env_lifetime = "SOAR_TOKEN_LIFETIME" + env_secret = "SOAR_TOKEN_SECRET" def __init__(self): self.getFernet() self.token_lifetime_hours = os.getenv(self.env_lifetime, 24) - - def getFernet(self): - self.fernet = Fernet(self.getKey().encode()) - def getKey(self): + def getFernet(self): + self.fernet = Fernet(self.getKey().encode()) + + def getKey(self): key = os.getenv(self.env_secret) print(key) - if not key: + if not key: key = Fernet.generate_key().decode() os.environ[self.env_secret] = key - self.key = key + self.key = key return self.key - def setSecret(self): + def setSecret(self): if not os.getenv(self.env_secret): - os.environ[self.env_secret] = self.getKey() + os.environ[self.env_secret] = self.getKey() - def getExpiry(self): + def getExpiry(self): now = datetime.datetime.utcnow() expires = now + datetime.timedelta(hours=self.token_lifetime_hours) return expires.isoformat() - + def encrypt(self, client_id): try: expiry = self.getExpiry() - token_content = { - 'client_id': client_id, - 'expiry': expiry - } - token = self.fernet.encrypt(json.dumps(token_content).encode()).decode() - return { - 'token': token, - 'expiry': expiry - } - except KeyError as e: + token_content = {"client_id": client_id, "expiry": expiry} + token = self.fernet.encrypt( + json.dumps(token_content).encode(), + ).decode() + return {"token": token, "expiry": expiry} + except KeyError: # as e: return None - + def decrypt(self, token): - try: + try: content = json.loads(self.fernet.decrypt(token.encode()).decode()) - return content - except (InvalidToken,KeyError) as e: + return content + except (InvalidToken, KeyError): # as e: return None def get(self, client_id): response = self.encrypt(client_id) - TOKENS[response['token']] = client_id - return response + TOKENS[response["token"]] = client_id + return response def validate(self, token): - response = { - 'valid': False - } + response = {"valid": False} if token in TOKENS: content = self.decrypt(token) if content: now = datetime.datetime.utcnow() - expires = datetime.datetime.fromisoformat(content['expiry']) - response['valid'] = expires > now - if response['valid']: + expires = datetime.datetime.fromisoformat(content["expiry"]) + response["valid"] = expires > now + if response["valid"]: response.update(content) - else: + else: del TOKENS[token] else: del TOKENS[token] - return response - - - + return response diff --git a/run_backbone.py b/run_backbone.py index 1f9a768de91f1197d09722dc74d120afcb249670..003674872e2472e531ee5b5e5b1f2fae23a53c6c 100644 --- a/run_backbone.py +++ b/run_backbone.py @@ -24,4 +24,3 @@ api.add_resource(Token, "/token") if __name__ == "__main__": app.run(debug=True, port=8087) - diff --git a/workers/__init__.py b/workers/__init__.py index 9760bb04435ab87159dbf506c8cb44b131e4eac8..3e793635fcf65cf1b40e227cdc1944554d65fff9 100644 --- a/workers/__init__.py +++ b/workers/__init__.py @@ -6,5 +6,5 @@ __all__ = [ if x.endswith(".py") and x != "__init__.py" ] -__workers__ = ["core_forward", "core_subscribe", "client_read"] - +__workers__ = [] +# ["core_bus", "core_forward", "core_subscribe", "client_read"] diff --git a/workers/core_subscribe.py b/workers/core_subscribe.py deleted file mode 100644 index 9ce847738e26c9ccc4530c5b0476b98d0f0b9cff..0000000000000000000000000000000000000000 --- a/workers/core_subscribe.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python -import pika - - -def get_connection(): - return pika.BlockingConnection(pika.ConnectionParameters(host="localhost")) - - -def subscribe( - queue_name, - topic, - publish_exchange="soar_publish", - broadcast_exchange="soar_broadcast", -): - adm_connection = get_connection() - admin_channel = adm_connection.channel() - admin_channel.exchange_declare(exchange=broadcast_exchange, exchange_type="fanout") - admin_channel.queue_bind(exchange=broadcast_exchange, queue=queue_name) - sub_connection = get_connection() - subscriber_channel = sub_connection.channel() - subscriber_channel.exchange_declare( - exchange=publish_exchange, exchange_type="topic" - ) - subscriber_channel.queue_bind( - exchange=publish_exchange, queue=queue_name, routing_key=topic - ) diff --git a/wrappers/__init__.py b/wrappers/__init__.py index 585c3ea40adcc641ec4623f182ee22cba50080d1..56b09312b392c84af507ee4c4234835e4904413e 100644 --- a/wrappers/__init__.py +++ b/wrappers/__init__.py @@ -1,9 +1,12 @@ -import os - +# import os +# __all__ = [ +# os.path.splitext(os.path.basename(x))[0] +# for x in os.listdir(os.path.dirname(__file__)) +# if x.endswith(".py") and x != "__init__.py" +# ] __all__ = [ "rabbit_mq", - "zero_mq", - "kombu", -] - + # "zero_mq", + # "kombu", +] diff --git a/wrappers/rabbit_mq.py b/wrappers/rabbit_mq.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..d789eae869f30415e31820dec265b03bf4108769 100644 --- a/wrappers/rabbit_mq.py +++ b/wrappers/rabbit_mq.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +import pika + +# import sys +# import os +# import json + + +def get_connection(host="localhost"): + return pika.BlockingConnection(pika.ConnectionParameters(host=host)) + + +def callback(channel, method, properties, body): + # message = json.loads(body.decode()) + # topic = message["topic"] + # deliver(body, topic, publish_exchange) + # channel.basic_ack(delivery_tag=method.delivery_tag) + print(" [x] %r:%r" % (method.routing_key, body)) + + +def publish( + host="localhost", + queue_name=None, + exchange_name=None, + exchange_type="topic", + routing_key=None, + message="Testing Publish", +): + """ + Publishes message to specific channels only. + <> + Ref: https://www.rabbitmq.com/tutorials/tutorial-one-python.html + + Args: + host: + queue_name:, + exchange_name: + exchange_type(optional): + routing_key(optional): (applicable if not `fanout` exchange_type) + message: + + Returns: + """ + connection = get_connection(host) + channel = connection.channel() + channel.exchange_declare( + exchange=exchange_name, + exchange_type=exchange_type, + ) + + message = message.encode(encoding="UTF-8") + channel.basic_publish( + exchange=exchange_name, + routing_key=routing_key, + body=message, + ) + print("Published %r on exchange %r" % (message, exchange_name)) + connection.close() + + +def subscribe( + host, + queue_name="", + publish_exchange="soar_publish", + broadcast_exchange="soar_broadcast", + topics_list=[], +): + def enable_subscription( + exchange, + # queue_name, + current_exchange_type, + subscription_routing_key=None, + ): + connection = get_connection(host) + subscriber_channel = connection.channel() + + subscriber_channel.exchange_declare( + exchange=exchange, exchange_type=current_exchange_type + ) + + # TODO: Discuss> should the queue names be different? + # (do we randomize the naming?) + result = subscriber_channel.queue_declare(queue=queue_name) + queue = result.method.queue + subscriber_channel.queue_bind( + exchange=publish_exchange, + queue=queue, + routing_key=subscription_routing_key, + ) + + print(" [*] Waiting for logs. To exit press CTRL+C") + + subscriber_channel.basic_consume( + queue=queue, on_message_callback=callback, auto_ack=True + ) + subscriber_channel.start_consuming() + + # Subsribe to list of topics (publish type) + for topic in topics_list: + enable_subscription(publish_exchange, "topic", topic) + + # Subsribe to all messages (broadcast type) + enable_subscription(broadcast_exchange, "fanout")