diff --git a/app.py b/app.py
index 6e41b8f..071cb3a 100644
--- a/app.py
+++ b/app.py
@@ -3,9 +3,8 @@
 import logging
 import os
 
-import jsonpickle
 import torch
-from flask import Flask, request
+from flask import Flask, request, abort, jsonify
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 logging.basicConfig(
@@ -71,28 +70,34 @@ def run_emoberta():
     """Receive everything in json!!!"""
     app.logger.debug("Receiving data ...")
     data = request.json
-    data = jsonpickle.decode(data)
-    text = data["text"]
+    if len(data) != 1:
+        abort(jsonify(message="Too many fields"), 400)
+    elif "text" not in data:
+        abort(jsonify(message="Missing 'text' field"), 400)
+    elif not isinstance(data['text'], str):
+        abort(jsonify(message="'text' field not a string"), 400)
+    else:
+        text = data["text"]
 
-    app.logger.info(f"raw text received: {text}")
+        app.logger.info(f"raw text received: {text}")
 
-    tokens = tokenizer(text, truncation=True)
+        tokens = tokenizer(text, truncation=True)
 
-    tokens["input_ids"] = torch.tensor(tokens["input_ids"]).view(1, -1).to(device)
-    tokens["attention_mask"] = (
-        torch.tensor(tokens["attention_mask"]).view(1, -1).to(device)
-    )
+        tokens["input_ids"] = torch.tensor(tokens["input_ids"]).view(1, -1).to(device)
+        tokens["attention_mask"] = (
+            torch.tensor(tokens["attention_mask"]).view(1, -1).to(device)
+        )
 
-    outputs = model(**tokens)
-    outputs = torch.softmax(outputs["logits"].detach().cpu(), dim=1).squeeze().numpy()
-    outputs = {id2emotion[idx]: prob.item() for idx, prob in enumerate(outputs)}
-    app.logger.info(f"prediction: {outputs}")
+        outputs = model(**tokens)
+        outputs = torch.softmax(outputs["logits"].detach().cpu(), dim=1).squeeze().numpy()
+        outputs = {id2emotion[idx]: prob.item() for idx, prob in enumerate(outputs)}
+        app.logger.info(f"prediction: {outputs}")
 
-    response = jsonpickle.encode(outputs)
-    app.logger.info("json-pickle is done.")
+        response = outputs
+        app.logger.info("json is done.")
 
-    return response
+        return response
 
 
 if __name__ == "__main__":
diff --git a/client.py b/client.py
index d584e41..4536502 100644
--- a/client.py
+++ b/client.py
@@ -4,7 +4,6 @@
 import argparse
 import logging
 
-import jsonpickle
 import requests
 
 logging.basicConfig(
@@ -26,10 +25,9 @@ def run_text(text: str, url_emoberta: str) -> None:
     data = {"text": text}
 
     logging.debug("sending text to server...")
-    data = jsonpickle.encode(data)
     response = requests.post(url_emoberta, json=data)
     logging.info(f"got {response} from server!...")
 
-    response = jsonpickle.decode(response.text)
+    response = response.json()
 
     logging.info(f"emoberta results: {response}")
diff --git a/requirements-client.txt b/requirements-client.txt
index 5094d05..e4f9df0 100644
--- a/requirements-client.txt
+++ b/requirements-client.txt
@@ -1,6 +1,5 @@
 certifi==2021.10.8
 charset-normalizer==2.0.12
 idna==3.3
-jsonpickle==2.1.0
 requests==2.27.1
 urllib3==1.26.8
diff --git a/requirements-deploy.txt b/requirements-deploy.txt
index 0ff2f1a..d76f459 100644
--- a/requirements-deploy.txt
+++ b/requirements-deploy.txt
@@ -8,7 +8,6 @@ idna==3.3
 itsdangerous==2.1.1
 Jinja2==3.0.3
 joblib==1.1.0
-jsonpickle==2.1.0
 MarkupSafe==2.1.0
 numpy==1.22.3
 packaging==21.3