Commit e247a1a5 authored by Dan Jones's avatar Dan Jones
Browse files

Merge branch '23-unit-tests-for-http-api-endpoints' into 'dev'

Resolve "Unit tests for HTTP API endpoints"

Closes #23

See merge request !11
parents 61d88a37 e66069b0
clients.json
__pycache__/
data/clients.json
examples/
rmq.log
Pipfile
......
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
pubsubpy = "*"
pika = "*"
pyrabbit = "*"
flask = "*"
flask-restful = "*"
marshmallow = "*"
bson = "*"
flask-cors = "*"
cryptography = "*"
[dev-packages]
pytest = "*"
pytest-rabbitmq = "*"
pytest-mock = "*"
black = "*"
[requires]
python_version = "3.8"
This diff is collapsed.
......@@ -23,6 +23,17 @@ which queues it reads from.
Subsequent requests to the client endpoint return the client_id but not the secret.
### Testing
Current coverage:
- API: yes
- Pika RabbitMQ implementation: no
```
pytest
```
### Running via docker-compose
Using `docker-compose` will mean that everything is setup automatically, this includes the `rabbitmq` container, the backbone API, and the backbone bus. The `run-compose.sh` script has been provided to simplify this even further - all you have to do is set whatever env vars you need in the `.env` file and then run `./run-compose.sh` (the defaults in `.env` are fine for local dev work, but ones not labelled `optional` will need setting in a production setting). The env vars are:
......@@ -35,8 +46,10 @@ Using `docker-compose` will mean that everything is setup automatically, this in
#### Setup
In a virtual environment
```
pipenv install
pip install -r requirements-dev.txt
```
#### RabbitMQ
......@@ -46,29 +59,38 @@ pipenv install
#### API
```
pipenv run python api.py
python api.py
```
#### Create some clients
`POST` to `http://localhost:3000/clients`
#### Event bus
```
pipenv run python soar_bus.py
python soar_bus.py
```
### Usage
To use the backbone you have to create a client.
The client id and name should be unique but human
readable. A client secret is returned by the API
and client credential grants using this client_id
and secret are used to authenticate client
application connections.
#### Create some clients
`POST` to `http://localhost:3000/clients`
#### Send / Receive directly
```
# Send a message
pipenv run python client_send.py noc-c2-outbox 'soar.noc.slocum.something' from noc-c2
python client_send.py noc-c2-outbox 'soar.noc.slocum.something' from noc-c2
```
```
# Receive messages
pipenv run python client_read.py noc-sfmc-inbox
python client_read.py noc-sfmc-inbox
```
#### Receive via API
......@@ -80,14 +102,3 @@ secret matches before allowing the request.
This should be replaced with a proper auth layer.
`GET http://localhost:5000/receive?client_id=[client_id]&secret=[secret]`
### Components
- `soar_bus.py` - Run all the components threaded based on existing clients
- `soar_forward.py` - Listen for messages on queue A and forward messages to queue B
- `soar_publish.py` - Listen for messages on queue A and publish on exchange B
- `soar_broadcast.py` - Listen for messages on queue A and broadcast on exchange B
- `soar_subscribe.py` - Create subscriptions to both the publish and broadcast exchange - deliver to queue A
(I think this should probably be 2 separate functions to keep things nice and simple)
- `soar_push.py` - Not yet implemented - Listen for messages on queue A and POST to the client's webhook URL
......@@ -2,30 +2,38 @@ from flask import Flask
from flask_cors import CORS
from flask_restful import Api
from endpoints.clients import Client, ClientList
from endpoints.client import Client, ClientList
from endpoints.notify import Notify
from endpoints.receive import Receive
from endpoints.send import Send
from endpoints.token import Token
from models.token import TokenModel
from models.token_model import TokenModel
import os
token = TokenModel()
token.setSecret()
app = Flask(__name__)
api = Api(app)
CORS(app, resources={r"*": {"origins": "http://localhost:8086"}})
api.add_resource(ClientList, "/client")
api.add_resource(Client, "/client/<client_id>")
api.add_resource(Receive, "/receive")
api.add_resource(Send, "/send")
api.add_resource(Notify, "/notify")
api.add_resource(Token, "/token")
def create_app():
app = Flask(__name__)
api = Api(app)
CORS(app, resources={r"*": {"origins": "http://localhost:8086"}})
api.add_resource(ClientList, "/client")
api.add_resource(Client, "/client/<client_id>")
api.add_resource(Receive, "/receive")
api.add_resource(Send, "/send")
api.add_resource(Notify, "/notify")
api.add_resource(Token, "/token")
return app
flask_host = os.getenv("FLASK_HOST", "localhost") # Sets to whatever MQ_HOST is, or defaults to localhost
flask_host = os.getenv(
"FLASK_HOST", "localhost"
) # Sets to whatever MQ_HOST is, or defaults to localhost
if __name__ == "__main__":
app = create_app()
app.run(debug=False, port=8087, host=flask_host)
import pytest
def clients():
return {
"client-1": {
"client_id": "client-1",
"client_name": "Client 1",
"subscription": "soar.#",
"secret": "abc123",
},
"client-2": {
"client_id": "client-2",
"client_name": "Client 2",
"subscription": "soar.client-2.#",
"secret": "xyz789",
},
}
def get_auth_header(client, credentials):
token_response = client.get("/token", query_string=credentials)
if token_response.status_code == 200:
token = token_response.json["token"]
return {"Authorization": f"Bearer {token}"}
else:
return None
@pytest.fixture
def mock_clients():
return clients()
@pytest.fixture
def mock_new_client():
return {
"client_id": "client-3",
"client_name": "Client 3",
"subscription": "soar.client-3.#",
}
@pytest.fixture
def mock_client_credentials():
mock_clients = clients()
return {
"client_id": mock_clients["client-1"]["client_id"],
"secret": mock_clients["client-1"]["secret"],
}
@pytest.fixture
def mock_invalid_credentials():
return {"client_id": "client-invalid", "secret": "fake-secret"}
@pytest.fixture
def mock_token_secret():
return "2UrRyeb9c6hq8Gj9nmI5safPz9LpPeUFtifeMNx4GQo="
def posts():
return {
"send": {
"topic": "soar.client-1.message",
"body": "this is a pub/sub message from client-1",
},
"notify": {"body": "this is a broadcast message from client-1"},
}
@pytest.fixture
def mock_post_send():
return posts()["send"]
@pytest.fixture
def mock_message_send():
post = posts()["send"]
return {"topic": post["topic"], "message": post["body"]}
@pytest.fixture
def mock_post_notify():
return posts()["notify"]
@pytest.fixture
def mock_message_notify():
post = posts()["notify"]
return {"topic": "broadcast", "message": post["body"]}
@pytest.fixture
def mock_read_from_queue_return():
return [
{
"topic": "soar.client-1.something",
"message": "this is a pub/sub message from client-1",
}
]
......@@ -2,28 +2,26 @@ import json
import os
from flask_restful import Resource, abort
from models.token import TokenModel
from models.token_model import TokenModel
class AuthResource(Resource):
def __init__(self):
def __init__(self):
self.token = TokenModel()
with open("./data/clients.json", "r") as clients_file:
self.clients = json.load(clients_file)
def auth(self, request):
def auth(self, request):
allow = False
auth = request.headers.get('Authorization', False)
auth = request.headers.get("Authorization", False)
if auth:
token = auth.split(' ').pop()
token = auth.split(" ").pop()
parsed = self.token.validate(token)
if parsed['valid']:
client = self.clients.get(parsed['client_id'])
if client:
if parsed["valid"]:
client = self.clients.get(parsed["client_id"])
if client:
self.client = client
allow = True
if not allow:
abort(403, message="Invalid token")
return allow
\ No newline at end of file
return allow
......@@ -4,94 +4,31 @@ import json
import os
import random
import string
from models.client_model import ClientModel
class ClientSchema(Schema):
client_id = fields.Str(required=True)
client_name = fields.Str(required=True)
client_name = fields.Str(required=True)
subscription = fields.Str(required=True)
class ClientsFile:
file = "./data/clients.json"
mtime = 0
clients = {}
parser = None
def __init__(self):
self.get()
def get(self):
try:
mtime = os.path.getmtime(self.file)
if mtime > self.mtime:
with open(self.file, "r") as client_file:
self.clients = json.load(client_file)
except FileNotFoundError as error:
self.clients = {}
self.save()
return self.clients
def find(self, client_id):
self.get()
if client_id in self.clients:
client = self.clients[client_id]
else:
client = None
return client
def add(self, client):
client['secret'] = self.secret()
self.clients[client["client_id"]] = client
self.save()
return client
def remove(self, client):
del self.clients[client["client_id"]]
self.save()
def update(self, client_updates):
client = self.find(client_updates["client_id"])
client.update(client_updates)
self.clients[client["client_id"]] = client
self.save()
return client
def save(self):
try:
with open(self.file, "w") as client_file:
client_file.write(json.dumps(self.clients, indent=2))
return True
except OSError as error:
print(str(error))
return False
def secret(self, chars=36):
res = "".join(
random.choices(
string.ascii_lowercase + string.ascii_uppercase + string.digits, k=chars
)
)
return str(res)
clients_file = ClientsFile()
# Client
class Client(Resource):
clients_file = None
def __init__(self):
def __init__(self):
self.schema = ClientSchema()
self.clients_file = ClientsFile()
self.clients_file = ClientModel()
def get(self, client_id):
client = self.clients_file.find(client_id)
del client['secret']
del client["secret"]
if not client:
abort(404, message="No client with id: {}".format(client_id))
return client
def delete(self, todo_id):
def delete(self, client_id):
client = self.clients_file.find(client_id)
if not client:
abort(404, message="No client with id: {}".format(client_id))
......@@ -102,9 +39,9 @@ class Client(Resource):
def put(self, client_id):
args = request.get_json()
errors = self.schema.validate(args)
if errors:
if errors:
abort(400, message=str(errors))
client = self.clients_file.find(client_id)
if not client:
abort(404, message="No client with id: {}".format(client_id))
......@@ -115,9 +52,9 @@ class Client(Resource):
# ClientList
class ClientList(Resource):
def __init__(self):
def __init__(self):
self.schema = ClientSchema()
self.clients_file = ClientsFile()
self.clients_file = ClientModel()
def get(self):
return {
......@@ -129,9 +66,9 @@ class ClientList(Resource):
args = request.get_json()
errors = self.schema.validate(args)
if errors:
if errors:
abort(400, message=str(errors))
client = self.clients_file.find(args["client_id"])
if client:
abort(403, message="Duplicate client id: {}".format(client_id))
......
......@@ -10,6 +10,7 @@ from rmq import write_to_queue
class NotifySchema(Schema):
body = fields.Str(required=True)
class Notify(AuthResource):
clients = None
schema = None
......@@ -17,7 +18,7 @@ class Notify(AuthResource):
def __init__(self):
super().__init__()
self.schema = NotifySchema()
def post(self):
args = request.get_json()
errors = self.schema.validate(args)
......@@ -27,12 +28,12 @@ class Notify(AuthResource):
allow = False
body = args.get("body")
message = {
'topic': 'broadcast',
'message': body,
"topic": "broadcast",
"message": body,
}
allow = self.auth(request)
if allow:
notify_queue = self.client['client_id'] + "-broadcast"
write_to_queue(queue_name=notify_queue, msg=json.dumps(message))
\ No newline at end of file
notify_queue = self.client["client_id"] + "-broadcast"
write_to_queue(queue_name=notify_queue, msg=json.dumps(message))
......@@ -22,9 +22,10 @@ class Receive(AuthResource):
if errors:
abort(400, message=str(errors))
max_messages = request.args.get("max_messages", 10)
# force query string parameter value into int
max_messages = int(request.args.get("max_messages", 10))
allow = self.auth(request)
if allow:
inbox_queue = self.client['client_id'] + "-inbox"
return read_from_queue(queue_name=inbox_queue, max_msgs=max_messages)
\ No newline at end of file
inbox_queue = self.client["client_id"] + "-inbox"
return read_from_queue(queue_name=inbox_queue, max_msgs=max_messages)
......@@ -11,6 +11,7 @@ class SendSchema(Schema):
body = fields.Str(required=True)
topic = fields.Str(required=True)
class Send(AuthResource):
clients = None
schema = None
......@@ -18,7 +19,7 @@ class Send(AuthResource):
def __init__(self):
super().__init__()
self.schema = SendSchema()
def post(self):
args = request.get_json()
errors = self.schema.validate(args)
......@@ -26,14 +27,14 @@ class Send(AuthResource):
abort(400, message=str(errors))
allow = self.auth(request)
if allow:
body = args.get("body")
topic = args.get("topic")
outbox_queue = self.client['client_id'] + "-outbox"
outbox_queue = self.client["client_id"] + "-outbox"
message = {
'topic': topic,
'message': body,
"topic": topic,
"message": body,
}
write_to_queue(queue_name=outbox_queue, msg=json.dumps(message))
\ No newline at end of file
write_to_queue(queue_name=outbox_queue, msg=json.dumps(message))
import json
import json
from flask_restful import Resource, request, abort
from marshmallow import Schema, fields
from models.token import TokenModel
from models.token_model import TokenModel
class TokenQuerySchema(Schema):
client_id = fields.Str(required=True)
......@@ -18,7 +19,7 @@ class Token(Resource):
self.model = TokenModel()
with open("./data/clients.json", "r") as clients_file:
self.clients = json.load(clients_file)
def get(self):
errors = self.schema.validate(request.args)
if errors:
......@@ -38,4 +39,4 @@ class Token(Resource):
else:
abort(403, message="Invalid client credentials")
return token
\ No newline at end of file
return token
"""
The backbone doesn't have a relational database or other backend state provider
It just saves its configuration to the local filesystem
In docker data is saved to a mounted volume for persistence
POSTS to /client create entries in clients.json. The return from the post
contains a generated secret but subsequent calls to GET /client or
GET /client/{id} do not return the secret.
Each time .find is called the .get method is called which checks the
mtime of the file and reloads it if newer. This means that new clients
should be returned
"""
import json
import os
import random
import string
class ClientModel:
file = "./data/clients.json"
mtime = 0
clients = {}
parser = None
def __init__(self):
self.get()
def get(self):
try:
mtime = os.path.getmtime(self.file)
if mtime > self.mtime:
with open(self.file, "r") as client_file:
self.clients = json.load(client_file)
self.mtime = mtime
except FileNotFoundError as error:
self.clients = {}
self.save()
self.mtime = os.path.getmtime(self.file)
return self.clients
def find(self, client_id):
self.get()
if client_id in self.clients:
client = self.clients[client_id]
else:
client = None
return client
def add(self, client):
client["secret"] = self.secret()
self.clients[client["client_id"]] = client
self.save()
return client
def remove(self, client):
del self.clients[client["client_id"]]
self.save()
def update(self, client_updates):
client = self.find(client_updates["client_id"])
client.update(client_updates)
self.clients[client["client_id"]] = client
self.save()
return client
def save(self):
try:
with open(self.file, "w") as client_file:
client_file.write(json.dumps(self.clients, indent=2))
return True
except OSError as error:
print(str(error))
return False
def secret(self, chars=36):
res = "".join(
random.choices(
string.ascii_lowercase + string.ascii_uppercase + string.digits, k=chars
)
)
return str(res)
from cryptography.fernet import Fernet,InvalidToken
from cryptography.fernet import Fernet, InvalidToken
import datetime
import os
import json
......@@ -7,84 +7,73 @@ import json
TOKENS = {}
class TokenModel():
class TokenModel:
clients = None
schema = None
key = None
fernet = None
token_lifetime_hours = None
env_lifetime = 'SOAR_TOKEN_LIFETIME'
env_secret = 'SOAR_TOKEN_SECRET'
env_lifetime = "SOAR_TOKEN_LIFETIME"
env_secret = "SOAR_TOKEN_SECRET"
def __init__(self):
self.getFernet()
self.token_lifetime_hours = os.getenv(self.env_lifetime, 24)
def getFernet(self):
self.fernet = Fernet(self.getKey().encode())
self.token_lifetime_hours = int(os.getenv(self.env_lifetime, 24))
def getKey(self):
def getFernet(self):
self.fernet = Fernet(self.getKey().encode())
def getKey(self):
key = os.getenv(self.env_secret)
print(key)
if not key:
if not key:
key = Fernet.generate_key().decode()
os.environ[self.env_secret] = key
self.key = key
self.key = key
return self.key
def setSecret(self):
def setSecret(self):
if not os.getenv(self.env_secret):
os.environ[self.env_secret] = self.getKey()
os.environ[self.env_secret] = self.getKey()
def getExpiry(self):
def getExpiry(self):
now = datetime.datetime.utcnow()
expires = now + datetime.timedelta(hours=self.token_lifetime_hours)
return expires.isoformat()
def encrypt(self, client_id):
try:
expiry = self.getExpiry()
token_content = {
'client_id': client_id,
'expiry': expiry
}
token_content = {"client_id": client_id, "expiry": expiry}
token = self.fernet.encrypt(json.dumps(token_content).encode()).decode()
return {
'token': token,
'expiry': expiry
}
except KeyError as e:
return {"token": token, "expiry": expiry}
except KeyError as e:
return None
def decrypt(self, token):
try:
try:
content = json.loads(self.fernet.decrypt(token.encode()).decode())
return content
except (InvalidToken,KeyError) as e:
return content
except (InvalidToken, KeyError) as e:
return None
def get(self, client_id):
response = self.encrypt(client_id)
TOKENS[response['token']] = client_id
return response
TOKENS[response["token"]] = client_id
return response
def validate(self, token):
response = {
'valid': False
}
response = {"valid": False}
if token in TOKENS:
content = self.decrypt(token)
if content:
now = datetime.datetime.utcnow()
expires = datetime.datetime.fromisoformat(content['expiry'])
response['valid'] = expires > now
if response['valid']:
expires = datetime.datetime.fromisoformat(content["expiry"])
response["valid"] = expires > now
if response["valid"]:
response.update(content)
else:
else:
del TOKENS[token]
else:
del TOKENS[token]
return response
return response
-r requirements.txt
black==23.1.0
pytest==7.2.1
pytest-mock==3.10.0
pytest-rabbitmq==2.2.1
\ No newline at end of file
......@@ -16,7 +16,7 @@ jinja2==3.1.2 ; python_version >= '3.7'
kombu==5.2.4 ; python_version >= '3.7'
markupsafe==2.1.1 ; python_version >= '3.7'
marshmallow==3.19.0
packaging==21.3 ; python_version >= '3.6'
packaging>=22.0 ; python_version >= '3.6'
pika==1.3.1
pubsubpy==2.3.0
pycparser==2.21
......
......@@ -3,10 +3,13 @@ import os
import pika
host = os.getenv("MQ_HOST", "localhost") # Sets to whatever MQ_HOST is, or defaults to localhost
host = os.getenv(
"MQ_HOST", "localhost"
) # Sets to whatever MQ_HOST is, or defaults to localhost
# -------------------------------------------------------------------------------------------------------------------------------------------------------------
def pika_connect(host):
try:
connection = pika.BlockingConnection(pika.ConnectionParameters(host))
......@@ -16,22 +19,34 @@ def pika_connect(host):
if connection is not None:
channel = connection.channel()
else:
print("ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?" % host)
raise Exception("ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?" % host)
print(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
raise Exception(
"ERROR: Pika has been unable to connect to host '%s'. Is RabbitMQ running?"
% host
)
return connection, channel
def setup_queue(channel, queue_name=''):
channel.queue_declare(queue=queue_name, exclusive=False, durable=True) # exclusive means the queue can only be used by the connection that created it
def setup_queue(channel, queue_name=""):
channel.queue_declare(
queue=queue_name, exclusive=False, durable=True
) # exclusive means the queue can only be used by the connection that created it
def fanout_exchange(channel, exchange_name):
channel.exchange_declare(exchange=exchange_name, exchange_type='fanout', durable=True)
channel.exchange_declare(
exchange=exchange_name, exchange_type="fanout", durable=True
)
def topic_exchange(channel, exchange_name):
channel.exchange_declare(exchange=exchange_name, exchange_type='topic', durable=True)
channel.exchange_declare(
exchange=exchange_name, exchange_type="topic", durable=True
)
def deliver_to_exchange(channel, body, exchange_name, topic=None):
......@@ -39,37 +54,39 @@ def deliver_to_exchange(channel, body, exchange_name, topic=None):
fanout_exchange(channel=channel, exchange_name=exchange_name)
channel.basic_publish(
exchange=exchange_name,
routing_key='',
body=body,
routing_key="",
body=body,
properties=pika.BasicProperties(
delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
)
),
)
else:
topic_exchange(channel=channel, exchange_name=exchange_name)
channel.basic_publish(
exchange=exchange_name,
routing_key=topic,
body=body,
routing_key=topic,
body=body,
properties=pika.BasicProperties(
delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
)
),
)
# -------------------------------------------------------------------------------------------------------------------------------------------------------------
def write_to_queue(queue_name, msg):
# write a single message to a queue
connection, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
channel.basic_publish(
exchange='',
routing_key=queue_name,
exchange="",
routing_key=queue_name,
body=msg,
properties=pika.BasicProperties(
delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
)
),
)
connection.close()
......@@ -125,12 +142,12 @@ def forward(from_queue, to_queue):
def forward_callback(ch, method, properties, body):
channel.basic_publish(
exchange='',
routing_key=to_queue,
exchange="",
routing_key=to_queue,
body=body,
properties=pika.BasicProperties(
delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
)
),
)
ch.basic_ack(delivery_tag=method.delivery_tag)
......@@ -150,7 +167,9 @@ def publish(queue_name, exchange_name):
def publish_callback(ch, method, properties, body):
message = json.loads(body.decode())
topic = message["topic"]
deliver_to_exchange(channel=ch, body=body, exchange_name=exchange_name, topic=topic)
deliver_to_exchange(
channel=ch, body=body, exchange_name=exchange_name, topic=topic
)
ch.basic_ack(delivery_tag=method.delivery_tag)
try:
......@@ -161,7 +180,7 @@ def publish(queue_name, exchange_name):
def subscribe(queue_name, exchange_name, topic=None):
# setup bindings between queue and exchange,
# setup bindings between queue and exchange,
# exchange_type is either 'fanout' or 'topic' based on if the topic arg is passed
connection, channel = pika_connect(host=host)
setup_queue(channel=channel, queue_name=queue_name)
......@@ -183,4 +202,4 @@ def listen(queue_name, callback):
setup_queue(channel=channel, queue_name=queue_name)
channel.basic_consume(queue=queue_name, on_message_callback=callback)
channel.start_consuming()
\ No newline at end of file
channel.start_consuming()
......@@ -11,8 +11,8 @@
import concurrent.futures
from endpoints.clients import ClientsFile
from rmq import broadcast, forward, publish, subscribe
from models.client_model import ClientModel
THREADS = []
EXCHANGES = {
......@@ -23,7 +23,7 @@ EXCHANGES = {
def main():
print("Starting SOAR bus...")
clients_file = ClientsFile()
clients_file = ClientModel()
clients = clients_file.get()
with concurrent.futures.ProcessPoolExecutor() as executor:
......@@ -31,7 +31,7 @@ def main():
thread = executor.submit(publish, "soar-publish", EXCHANGES.get("publish"))
THREADS.append(thread)
for (id, client) in clients.items():
for id, client in clients.items():
# forward
thread = executor.submit(forward, f"{id}-outbox", "soar-publish")
THREADS.append(thread)
......@@ -42,16 +42,14 @@ def main():
THREADS.append(thread)
# subscribe
thread = executor.submit(
subscribe,
subscribe,
f"{id}-inbox",
EXCHANGES.get("publish"),
client["subscription"] # topic
client["subscription"], # topic
)
THREADS.append(thread)
thread = executor.submit(
subscribe,
f"{id}-inbox",
EXCHANGES.get("broadcast")
subscribe, f"{id}-inbox", EXCHANGES.get("broadcast")
)
THREADS.append(thread)
# push
......@@ -67,5 +65,6 @@ def main():
except Exception as e:
print(e)
if __name__ == "__main__":
main()
import flask
import json
import pytest
from unittest.mock import patch, mock_open, call
from werkzeug.exceptions import HTTPException
from api import create_app
@pytest.mark.usefixtures("mock_clients")
def test_get_client(mock_clients):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch(
"os.path.getmtime", return_value=3
) as mock_getmtime, app.test_client() as app_test_client:
response = app_test_client.get("/client")
assert response.status_code == 200
assert len(response.json.keys()) == 2
assert "client-1" in response.json
@pytest.mark.usefixtures("mock_clients")
def test_get_client_1(mock_clients):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch(
"os.path.getmtime", return_value=3
) as mock_getmtime, app.test_client() as app_test_client:
response = app_test_client.get("/client/client-1")
assert response.status_code == 200
assert "client_id" in response.json
assert response.json["client_id"] == "client-1"
assert "client_name" in response.json
assert "secret" not in response.json
@pytest.mark.usefixtures("mock_clients", "mock_new_client")
def test_post_client(mock_clients, mock_new_client):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch(
"os.path.getmtime", return_value=3
) as mock_getmtime, app.test_client() as app_test_client:
response = app_test_client.post("/client", json=mock_new_client)
assert response.status_code == 201
assert "client_id" in response.json
assert response.json["client_id"] == "client-3"
assert "client_name" in response.json
assert "secret" in response.json
@pytest.mark.usefixtures("mock_clients")
def test_put_client_1(mock_clients):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch(
"os.path.getmtime", return_value=3
) as mock_getmtime, app.test_client() as app_test_client:
response = app_test_client.get("/client/client-1")
client_1 = response.json
print(client_1)
client_1["subscription"] = "soar.client-1.#"
response = app_test_client.put("/client/client-1", json=client_1)
print(response.data)
assert response.status_code == 201
assert "client_id" in response.json
assert response.json["client_id"] == "client-1"
assert response.json["subscription"] == "soar.client-1.#"
assert "client_name" in response.json
assert "secret" in response.json
@pytest.mark.usefixtures("mock_clients")
def test_delete_client_1(mock_clients):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch(
"os.path.getmtime", return_value=3
) as mock_getmtime, app.test_client() as app_test_client:
response = app_test_client.delete("/client/client-1")
assert response.status_code == 204
import flask
import json
import pytest
from unittest.mock import patch, mock_open, call
from werkzeug.exceptions import HTTPException
from api import create_app
from endpoints import notify
from conftest import get_auth_header
@pytest.mark.usefixtures(
"mock_clients", "mock_client_credentials", "mock_post_notify", "mock_message_notify"
)
def test_post_notify(
mock_clients, mock_client_credentials, mock_post_notify, mock_message_notify
):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch.object(
notify, "write_to_queue"
) as mock_write_to_queue, app.test_client() as app_test_client:
auth_header = get_auth_header(app_test_client, mock_client_credentials)
response = app_test_client.post(
"/notify", json=mock_post_notify, headers=auth_header
)
assert response.status_code == 200
mock_write_to_queue.assert_called_once_with(
queue_name="client-1-broadcast", msg=json.dumps(mock_message_notify)
)
@pytest.mark.usefixtures("mock_clients", "mock_post_notify")
def test_post_send_no_token(mock_clients, mock_post_notify):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch.object(
notify, "write_to_queue"
) as mock_write_to_queue, app.test_client() as app_test_client:
response = app_test_client.post("/notify", json=mock_post_notify)
assert response.status_code == 403
mock_write_to_queue.assert_not_called()
@pytest.mark.usefixtures("mock_clients", "mock_post_notify")
def test_post_send_invalid_token(mock_clients, mock_post_notify):
app = create_app()
with patch(
"builtins.open", mock_open(read_data=json.dumps(mock_clients))
) as mock_file_open, patch.object(
notify, "write_to_queue"
) as mock_write_to_queue, app.test_client() as app_test_client:
auth_header = {"Authorization": "made-up-token"}
response = app_test_client.post(
"/notify", json=mock_post_notify, headers=auth_header
)
assert response.status_code == 403
mock_write_to_queue.assert_not_called()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment