Commit 4408fd96 authored by Trishna Saeharaseelan's avatar Trishna Saeharaseelan
Browse files

feat(wrappers): add wrappers for rabbitmq

parent 1352b3a1
......@@ -51,7 +51,12 @@ pipenv run python api.py
#### Create some clients
`POST` to `http://localhost:3000/clients`
`POST` to `http://localhost:3000/clients`. Example:
```
{
"client_id": "qwert123456789"
}
```
#### Event bus
......
......@@ -8,5 +8,5 @@ __all__ = [
# "workers",
"fixtures",
"tests",
"run",
# "run",
]
......@@ -5,4 +5,4 @@ __all__ = [
os.path.splitext(os.path.basename(x))[0]
for x in os.listdir(os.path.dirname(__file__))
if x.endswith(".py") and x != "__init__.py"
]
\ No newline at end of file
]
......@@ -2,24 +2,24 @@ import json
from flask_restful import Resource, abort
from models.token import TokenModel
class AuthResource(Resource):
def __init__(self):
class AuthResource(Resource):
def __init__(self):
self.token = TokenModel()
with open("clients.json", "r") as clients_file:
self.clients = json.load(clients_file)
def auth(self, request):
def auth(self, request):
allow = False
auth = request.headers.get('Authorization', False)
auth = request.headers.get("Authorization", False)
if auth:
token = auth.split(' ').pop()
token = auth.split(" ").pop()
parsed = self.token.validate(token)
if parsed['valid']:
client = self.clients.get(parsed['client_id'])
if client:
if parsed["valid"]:
client = self.clients.get(parsed["client_id"])
if client:
self.client = client
allow = True
if not allow:
abort(403, message="Invalid token")
return allow
\ No newline at end of file
return allow
......@@ -8,12 +8,12 @@ import string
class ClientSchema(Schema):
client_id = fields.Str(required=True)
client_name = fields.Str(required=True)
client_name = fields.Str(required=True)
subscription = fields.Str(required=True)
class ClientsFile:
file = "clients.json"
file = "fixtures.clients.json"
mtime = 0
clients = {}
parser = None
......@@ -27,7 +27,7 @@ class ClientsFile:
if mtime > self.mtime:
with open(self.file, "r") as client_file:
self.clients = json.load(client_file)
except FileNotFoundError as error:
except FileNotFoundError:
self.clients = {}
self.save()
......@@ -42,7 +42,7 @@ class ClientsFile:
return client
def add(self, client):
client['secret'] = self.secret()
client["secret"] = self.secret()
self.clients[client["client_id"]] = client
self.save()
return client
......@@ -70,28 +70,36 @@ class ClientsFile:
def secret(self, chars=36):
res = "".join(
random.choices(
string.ascii_lowercase + string.ascii_uppercase + string.digits, k=chars
(
string.ascii_lowercase,
+string.ascii_uppercase,
+string.digits,
),
k=chars,
)
)
return str(res)
clients_file = ClientsFile()
# Client
class Client(Resource):
clients_file = None
def __init__(self):
def __init__(self):
self.schema = ClientSchema()
self.clients_file = ClientsFile()
def get(self, client_id):
client = self.clients_file.find(client_id)
del client['secret']
del client["secret"]
if not client:
abort(404, message="No client with id: {}".format(client_id))
return client
def delete(self, todo_id):
def delete(self, client_id):
client = self.clients_file.find(client_id)
if not client:
abort(404, message="No client with id: {}".format(client_id))
......@@ -102,9 +110,9 @@ class Client(Resource):
def put(self, client_id):
args = request.get_json()
errors = self.schema.validate(args)
if errors:
if errors:
abort(400, message=str(errors))
client = self.clients_file.find(client_id)
if not client:
abort(404, message="No client with id: {}".format(client_id))
......@@ -115,7 +123,7 @@ class Client(Resource):
# ClientList
class ClientList(Resource):
def __init__(self):
def __init__(self):
self.schema = ClientSchema()
self.clients_file = ClientsFile()
......@@ -127,13 +135,13 @@ class ClientList(Resource):
def post(self):
args = request.get_json()
print("Args is %s" % args)
errors = self.schema.validate(args)
if errors:
if errors:
abort(400, message=str(errors))
client = clients_file.find(args["client_id"])
if client:
print("errors is %s" % errors)
client_id = clients_file.find(args["client_id"])
if client_id:
abort(403, message="Duplicate client id: {}".format(client_id))
else:
client = clients_file.add(args)
......
......@@ -8,6 +8,7 @@ from api.auth_resource import AuthResource
class NotifySchema(Schema):
body = fields.Str(required=True)
class Notify(AuthResource):
clients = None
schema = None
......@@ -15,7 +16,7 @@ class Notify(AuthResource):
def __init__(self):
super().__init__()
self.schema = NotifySchema()
def post(self):
args = request.get_json()
errors = self.schema.validate(args)
......@@ -25,16 +26,20 @@ class Notify(AuthResource):
allow = False
body = args.get("body")
message = {
'topic': 'broadcast',
'message': body,
"topic": "broadcast",
"message": body,
}
allow = self.auth(request)
if allow:
notify_queue = self.client['client_id'] + "-broadcast"
connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
notify_queue = self.client["client_id"] + "-broadcast"
connection = pika.BlockingConnection(
pika.ConnectionParameters(host="localhost")
)
channel = connection.channel()
channel.queue_declare(queue=notify_queue, durable=True)
channel.basic_publish(exchange="", routing_key=notify_queue, body=json.dumps(message))
channel.basic_publish(
exchange="", routing_key=notify_queue, body=json.dumps(message)
)
connection.close()
......@@ -2,7 +2,8 @@ from flask_restful import request, abort
from marshmallow import Schema, fields
import pika
import json
from models.token import TokenModel
# from models.token import TokenModel
from api.auth_resource import AuthResource
......@@ -25,11 +26,11 @@ class Receive(AuthResource):
messages = []
max_messages = request.args.get("max_messages", 10)
allow = self.auth(request)
if allow:
inbox_queue = self.client['client_id'] + "-inbox"
inbox_queue = self.client["client_id"] + "-inbox"
if allow:
connection = pika.BlockingConnection(
pika.ConnectionParameters(host="localhost")
......@@ -37,7 +38,9 @@ class Receive(AuthResource):
channel = connection.channel()
channel.queue_declare(queue=inbox_queue, durable=True)
while len(messages) < max_messages:
method_frame, header_frame, body = channel.basic_get(inbox_queue)
method_frame, header_frame, body = channel.basic_get(
inbox_queue,
)
if method_frame:
print(method_frame, header_frame, body)
channel.basic_ack(method_frame.delivery_tag)
......
......@@ -4,10 +4,12 @@ from marshmallow import Schema, fields
import pika
from api.auth_resource import AuthResource
class SendSchema(Schema):
body = fields.Str(required=True)
topic = fields.Str(required=True)
class Send(AuthResource):
clients = None
schema = None
......@@ -15,7 +17,7 @@ class Send(AuthResource):
def __init__(self):
super().__init__()
self.schema = SendSchema()
def post(self):
args = request.get_json()
errors = self.schema.validate(args)
......@@ -23,18 +25,22 @@ class Send(AuthResource):
abort(400, message=str(errors))
allow = self.auth(request)
if allow:
body = args.get("body")
topic = args.get("topic")
outbox_queue = self.client['client_id'] + "-outbox"
connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
outbox_queue = self.client["client_id"] + "-outbox"
connection = pika.BlockingConnection(
pika.ConnectionParameters(host="localhost")
)
channel = connection.channel()
channel.queue_declare(queue=outbox_queue, durable=True)
message = {
'topic': topic,
'message': body,
"topic": topic,
"message": body,
}
channel.basic_publish(exchange="", routing_key=outbox_queue, body=json.dumps(message))
channel.basic_publish(
exchange="", routing_key=outbox_queue, body=json.dumps(message)
)
connection.close()
import json
import pika
import json
from flask_restful import Resource, request, abort
from marshmallow import Schema, fields
from models.token import TokenModel
......@@ -20,7 +19,7 @@ class Token(Resource):
self.model = TokenModel()
with open("clients.json", "r") as clients_file:
self.clients = json.load(clients_file)
def get(self):
errors = self.schema.validate(request.args)
if errors:
......@@ -28,7 +27,7 @@ class Token(Resource):
token = None
allow = False
max_messages = request.args.get("max_messages", 10)
# max_messages = request.args.get("max_messages", 10)
client_id = request.args.get("client_id")
if client_id in self.clients:
client = self.clients.get(client_id)
......@@ -40,4 +39,4 @@ class Token(Resource):
else:
abort(403, message="Invalid client credentials")
return token
\ No newline at end of file
return token
import os
__all__ = [
os.path.splitext(os.path.basename(x))[0]
for x in os.listdir(os.path.dirname(__file__))
if x.endswith(".py") and x != "__init__.py"
]
......@@ -6,7 +6,9 @@ import json
def main():
connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
connection = pika.BlockingConnection(
pika.ConnectionParameters(host="localhost"),
)
channel = connection.channel()
queue_name = sys.argv[1]
......@@ -17,7 +19,11 @@ def main():
message = json.loads(body.decode())
print(" [x] Received %r" % message)
channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=True)
channel.basic_consume(
queue=queue_name,
on_message_callback=callback,
auto_ack=True,
)
print(" [*] Waiting for messages. To exit press CTRL+C")
channel.start_consuming()
......
......@@ -4,7 +4,9 @@ import sys
import json
connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
connection = pika.BlockingConnection(
pika.ConnectionParameters(host="localhost"),
)
channel = connection.channel()
queue_name = sys.argv[1]
......
......@@ -21,13 +21,16 @@ def deliver(body, broadcast_exchange):
def listen(queue_name, broadcast_exchange):
def bcast_callback(ch, method, properties, body):
delivered = deliver(body, broadcast_exchange)
# delivered = deliver(body, broadcast_exchange)
ch.basic_ack(delivery_tag=method.delivery_tag)
listen_connection = get_connection()
listen_channel = listen_connection.channel()
listen_channel.queue_declare(queue=queue_name, durable=True)
listen_channel.basic_consume(queue=queue_name, on_message_callback=bcast_callback)
listen_channel.basic_consume(
queue=queue_name,
on_message_callback=bcast_callback,
)
listen_channel.start_consuming()
......
......@@ -4,10 +4,12 @@
# (per client)
# [client_id]-inbox - receive subscriptions
# [client_id]-outbox - send messages to the bus
# [client-id]-broadcast - send message to all subscribers? - eg notify of downtime
# [client-id]-broadcast - send message to all subscribers?
# - eg notify of downtime
# soar-publisher - fan-in from client-outboxes
# soar-dlq - undeliverables
# soar-broadcast - admin messages forwarded to all client-inboxes regardless of subscriptions
# soar-broadcast - admin messages forwarded to all client-inboxes
# regardless of subscriptions
import concurrent.futures
from api.clients import ClientsFile
......@@ -29,7 +31,11 @@ def main():
with concurrent.futures.ProcessPoolExecutor() as executor:
# publish
thread = executor.submit(publish, "soar-publish", EXCHANGES.get("publish"))
thread = executor.submit(
publish,
"soar-publish",
EXCHANGES.get("publish"),
)
THREADS.append(thread)
for (id, client) in clients.items():
......@@ -51,8 +57,8 @@ def main():
)
THREADS.append(thread)
# push
# TODO - add optional webhook target to client and post to webhook target
# if present
# TODO - add optional webhook target to client and
# post to webhook target if present
if __name__ == "__main__":
......
#!/usr/bin/env python
import pika
......@@ -12,13 +11,17 @@ def deliver(body, queue_name):
deliver_connection = get_connection()
deliver_channel = deliver_connection.channel()
deliver_channel.queue_declare(queue=queue_name, durable=True)
deliver_channel.basic_publish(exchange="", routing_key=queue_name, body=body)
deliver_channel.basic_publish(
exchange="",
routing_key=queue_name,
body=body,
)
deliver_connection.close()
def listen(from_queue_name, to_queue_name):
def fwd_callback(ch, method, properties, body):
delivered = deliver(body, to_queue_name)
# delivered = deliver(body, to_queue_name)
ch.basic_ack(delivery_tag=method.delivery_tag)
listen_connection = get_connection()
......
......@@ -3,6 +3,8 @@ import pika
import json
import sys
# from wrappers.rabbit_mq import publish as publish_wrapper
def get_connection():
return pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
......@@ -12,7 +14,10 @@ def deliver(body, topic, publish_exchange):
print("publish on topic: %s" % topic)
deliver_connection = get_connection()
deliver_channel = deliver_connection.channel()
deliver_channel.exchange_declare(exchange=publish_exchange, exchange_type="topic")
deliver_channel.exchange_declare(
exchange=publish_exchange,
exchange_type="topic",
)
deliver_channel.basic_publish(
exchange=publish_exchange, routing_key=topic, body=body
)
......@@ -29,7 +34,10 @@ def listen(queue_name, publish_exchange):
listen_connection = get_connection()
listen_channel = listen_connection.channel()
listen_channel.queue_declare(queue=queue_name, durable=True)
listen_channel.basic_consume(queue=queue_name, on_message_callback=pub_callback)
listen_channel.basic_consume(
queue=queue_name,
on_message_callback=pub_callback,
)
listen_channel.start_consuming()
......
#!/usr/bin/env python
from wrappers.rabbit_mq import subscribe as subscribe_wrapper
def subscribe(
queue_name,
topics_list,
publish_exchange="soar_publish",
broadcast_exchange="soar_broadcast",
):
subscribe_wrapper(
"localhost", # TODO: Update
queue_name,
publish_exchange,
broadcast_exchange,
topics_list,
)
# Shifted to wrappers/
# def old_subscribe(
# queue_name,
# topic,
# publish_exchange="soar_publish",
# broadcast_exchange="soar_broadcast",
# ):
# adm_connection = get_connection()
# admin_channel = adm_connection.channel()
# admin_channel.exchange_declare(
# exchange=broadcast_exchange,
# exchange_type="fanout"
# )
# admin_channel.queue_bind(exchange=broadcast_exchange, queue=queue_name)
# sub_connection = get_connection()
# subscriber_channel = sub_connection.channel()
# subscriber_channel.exchange_declare(
# exchange=publish_exchange, exchange_type="topic"
# )
# subscriber_channel.queue_bind(
# exchange=publish_exchange, queue=queue_name, routing_key=topic
# )
from cryptography.fernet import Fernet,InvalidToken
from cryptography.fernet import Fernet, InvalidToken
import datetime
import os
import json
......@@ -7,84 +7,75 @@ import json
TOKENS = {}
class TokenModel():
class TokenModel:
clients = None
schema = None
key = None
fernet = None
token_lifetime_hours = None
env_lifetime = 'SOAR_TOKEN_LIFETIME'
env_secret = 'SOAR_TOKEN_SECRET'
env_lifetime = "SOAR_TOKEN_LIFETIME"
env_secret = "SOAR_TOKEN_SECRET"
def __init__(self):
self.getFernet()
self.token_lifetime_hours = os.getenv(self.env_lifetime, 24)
def getFernet(self):
self.fernet = Fernet(self.getKey().encode())
def getKey(self):
def getFernet(self):
self.fernet = Fernet(self.getKey().encode())
def getKey(self):
key = os.getenv(self.env_secret)
print(key)
if not key:
if not key:
key = Fernet.generate_key().decode()
os.environ[self.env_secret] = key
self.key = key
self.key = key
return self.key
def setSecret(self):
def setSecret(self):
if not os.getenv(self.env_secret):
os.environ[self.env_secret] = self.getKey()
os.environ[self.env_secret] = self.getKey()
def getExpiry(self):
def getExpiry(self):
now = datetime.datetime.utcnow()
expires = now + datetime.timedelta(hours=self.token_lifetime_hours)
return expires.isoformat()
def encrypt(self, client_id):
try:
expiry = self.getExpiry()
token_content = {
'client_id': client_id,
'expiry': expiry
}
token = self.fernet.encrypt(json.dumps(token_content).encode()).decode()
return {
'token': token,
'expiry': expiry
}
except KeyError as e:
token_content = {"client_id": client_id, "expiry": expiry}
token = self.fernet.encrypt(
json.dumps(token_content).encode(),
).decode()
return {"token": token, "expiry": expiry}
except KeyError: # as e:
return None
def decrypt(self, token):
try:
try:
content = json.loads(self.fernet.decrypt(token.encode()).decode())
return content
except (InvalidToken,KeyError) as e:
return content
except (InvalidToken, KeyError): # as e:
return None
def get(self, client_id):
response = self.encrypt(client_id)
TOKENS[response['token']] = client_id
return response
TOKENS[response["token"]] = client_id
return response
def validate(self, token):
response = {
'valid': False
}
response = {"valid": False}
if token in TOKENS:
content = self.decrypt(token)
if content:
now = datetime.datetime.utcnow()
expires = datetime.datetime.fromisoformat(content['expiry'])
response['valid'] = expires > now
if response['valid']:
expires = datetime.datetime.fromisoformat(content["expiry"])
response["valid"] = expires > now
if response["valid"]:
response.update(content)
else:
else:
del TOKENS[token]
else:
del TOKENS[token]
return response
return response
......@@ -24,4 +24,3 @@ api.add_resource(Token, "/token")
if __name__ == "__main__":
app.run(debug=True, port=8087)
......@@ -6,5 +6,5 @@ __all__ = [
if x.endswith(".py") and x != "__init__.py"
]
__workers__ = ["core_forward", "core_subscribe", "client_read"]
__workers__ = []
# ["core_bus", "core_forward", "core_subscribe", "client_read"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment