From 405901584d11225253ce0fcf9ecfb826053a114d Mon Sep 17 00:00:00 2001 From: Dan Jones <dan.jones@noc.ac.uk> Date: Wed, 16 Nov 2022 17:16:27 +0000 Subject: [PATCH] refactor: change clients endpoint to marshmallow --- endpoints/clients.py | 71 ++++++++++++++++++++++---------------------- endpoints/notify.py | 4 +-- endpoints/receive.py | 2 +- endpoints/send.py | 4 +-- 4 files changed, 40 insertions(+), 41 deletions(-) diff --git a/endpoints/clients.py b/endpoints/clients.py index cc4018b..a079e93 100644 --- a/endpoints/clients.py +++ b/endpoints/clients.py @@ -1,10 +1,17 @@ -from flask_restful import Resource, reqparse, abort, fields, marshal_with +from flask_restful import Resource, request, abort +from marshmallow import Schema, fields import json import os import random import string +class ClientSchema(Schema): + client_id = fields.Str(required=True) + client_name = fields.Str(required=True) + subscription = fields.Str(required=True) + + class ClientsFile: file = "clients.json" mtime = 0 @@ -13,7 +20,6 @@ class ClientsFile: def __init__(self): self.get() - self.setup_request_parser() def get(self): try: @@ -36,7 +42,7 @@ class ClientsFile: return client def add(self, client): - client.secret = self.secret() + client['secret'] = self.secret() self.clients[client["client_id"]] = client self.save() return client @@ -69,70 +75,63 @@ class ClientsFile: ) return str(res) - def setup_request_parser(self): - parser = reqparse.RequestParser() - parser.add_argument( - "client_id", type=str, help="A unique name to identify the client" - ) - parser.add_argument( - "client_name", type=str, help="A human friendly name to identify the client" - ) - parser.add_argument( - "subscription", - type=str, - help="A dot delimited string identify topics to subscribe to", - ) - self.parser = parser - - def parse(self): - return self.parser.parse_args() - - -resource_fields = { - "client_id": fields.String, - "client_name": fields.String, - "subscription": fields.String, -} - clients_file = ClientsFile() # Client class Client(Resource): + clients_file = None + def __init__(self): + self.schema = ClientSchema() + self.clients_file = ClientsFile() + def get(self, client_id): - client = clients_file.find(client_id) + client = self.clients_file.find(client_id) del client['secret'] if not client: abort(404, message="No client with id: {}".format(client_id)) return client def delete(self, todo_id): - client = clients_file.find(client_id) + client = self.clients_file.find(client_id) if not client: abort(404, message="No client with id: {}".format(client_id)) else: - clients_file.remove(client) + self.clients_file.remove(client) return client, 204 def put(self, client_id): - args = clients_file.parse() - client = clients_file.find(client_id) + args = request.get_json() + errors = self.schema.validate(args) + if errors: + abort(400, message=str(errors)) + + client = self.clients_file.find(client_id) if not client: abort(404, message="No client with id: {}".format(client_id)) else: - client = clients_file.update(args) + client = self.clients_file.update(args) return client, 201 # ClientList class ClientList(Resource): + def __init__(self): + self.schema = ClientSchema() + self.clients_file = ClientsFile() + def get(self): return { client_id: (client, client.pop("secret", None))[0] - for client_id, client in clients_file.get().items() + for client_id, client in self.clients_file.get().items() } def post(self): - args = clients_file.parse() + args = request.get_json() + + errors = self.schema.validate(args) + if errors: + abort(400, message=str(errors)) + client = clients_file.find(args["client_id"]) if client: abort(403, message="Duplicate client id: {}".format(client_id)) diff --git a/endpoints/notify.py b/endpoints/notify.py index 96c54e3..0ebb0d1 100644 --- a/endpoints/notify.py +++ b/endpoints/notify.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse, abort +from flask_restful import Resource, request, abort from marshmallow import Schema, fields import json @@ -19,7 +19,7 @@ class Notify(Resource): def post(self): errors = self.schema.validate(request.args) if errors: - abort(400, str(errors)) + abort(400, message=str(errors)) messages = [] allow = False diff --git a/endpoints/receive.py b/endpoints/receive.py index 2f86d7a..86d2487 100644 --- a/endpoints/receive.py +++ b/endpoints/receive.py @@ -22,7 +22,7 @@ class Receive(Resource): def get(self): errors = self.schema.validate(request.args) if errors: - abort(400, str(errors)) + abort(400, message=str(errors)) messages = [] allow = False diff --git a/endpoints/send.py b/endpoints/send.py index 8cbfda1..119c67b 100644 --- a/endpoints/send.py +++ b/endpoints/send.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse, abort +from flask_restful import Resource, request, abort from marshmallow import Schema, fields import json @@ -20,7 +20,7 @@ class Send(Resource): def post(self): errors = self.schema.validate(request.args) if errors: - abort(400, str(errors)) + abort(400, message=str(errors)) messages = [] allow = False -- GitLab