diff --git a/src/ogd/apis/utils/APIResponse.py b/src/ogd/apis/utils/APIResponse.py index f7bed18..328d0eb 100644 --- a/src/ogd/apis/utils/APIResponse.py +++ b/src/ogd/apis/utils/APIResponse.py @@ -8,7 +8,7 @@ # import standard libraries import json from enum import IntEnum -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Set # import 3rd-party libraries from flask import Response @@ -47,8 +47,17 @@ class ResponseStatus(IntEnum): NONE = 1 SUCCESS = 200 ERR_REQ = 400 + ERR_NOTFOUND = 404 ERR_SRV = 500 + @staticmethod + def ServerErrors() -> Set["ResponseStatus"]: + return {ResponseStatus.ERR_SRV} + + @staticmethod + def ClientErrors() -> Set["ResponseStatus"]: + return {ResponseStatus.ERR_REQ, ResponseStatus.ERR_NOTFOUND} + def __str__(self): """Stringify function for ResponseStatus objects. @@ -164,7 +173,6 @@ def AsDict(self): "type" : str(self._type), "val" : self._val, "msg" : self._msg, - "status" : str(self._status) } @property @@ -175,12 +183,12 @@ def AsJSON(self): def AsFlaskResponse(self) -> Response: return Response(response=self.AsJSON, status=self.Status.value, mimetype='application/json') - def RequestErrored(self, msg:str): - self._status = ResponseStatus.ERR_REQ + def RequestErrored(self, msg:str, status:Optional[ResponseStatus]=None): + self._status = status if status is not None and status in ResponseStatus.ClientErrors() else ResponseStatus.ERR_REQ self._msg = f"ERROR: {msg}" - def ServerErrored(self, msg:str): - self._status = ResponseStatus.ERR_SRV + def ServerErrored(self, msg:str, status:Optional[ResponseStatus]=None): + self._status = status if status is not None and status in ResponseStatus.ServerErrors() else ResponseStatus.ERR_SRV self._msg = f"SERVER ERROR: {msg}" def RequestSucceeded(self, msg:str, val:Any): diff --git a/tests/cases/apis/HelloAPI/local/t_Hello.py b/tests/cases/apis/HelloAPI/local/t_Hello.py index 835e140..53a881e 100644 --- a/tests/cases/apis/HelloAPI/local/t_Hello.py +++ b/tests/cases/apis/HelloAPI/local/t_Hello.py @@ -49,7 +49,6 @@ def test_get(self): self.assertEqual(body.get("type"), "GET") self.assertEqual(body.get("val"), None) self.assertEqual(body.get("msg"), "Hello! You GETted successfully!") - self.assertEqual(body.get("status"), "SUCCESS") def test_post(self): _url = f"/hello" @@ -64,7 +63,6 @@ def test_post(self): self.assertEqual(body.get("type"), "POST") self.assertEqual(body.get("val"), None) self.assertEqual(body.get("msg"), "Hello! You POSTed successfully!") - self.assertEqual(body.get("status"), "SUCCESS") def test_put(self): _url = f"/hello" @@ -79,4 +77,3 @@ def test_put(self): self.assertEqual(body.get("type"), "PUT") self.assertEqual(body.get("val"), None) self.assertEqual(body.get("msg"), "Hello! You PUTted successfully!") - self.assertEqual(body.get("status"), "SUCCESS") diff --git a/tests/cases/apis/HelloAPI/local/t_ParamHello.py b/tests/cases/apis/HelloAPI/local/t_ParamHello.py index 83495c7..c05e5b7 100644 --- a/tests/cases/apis/HelloAPI/local/t_ParamHello.py +++ b/tests/cases/apis/HelloAPI/local/t_ParamHello.py @@ -50,7 +50,6 @@ def test_get(self): self.assertEqual(body.get("type"), "GET") self.assertEqual(body.get("val"), None) self.assertEqual(body.get("msg"), f"Hello {param}! You GETted successfully!") - self.assertEqual(body.get("status"), "SUCCESS") def test_post(self): param = "Tester" @@ -66,7 +65,6 @@ def test_post(self): self.assertEqual(body.get("type"), "POST") self.assertEqual(body.get("val"), None) self.assertEqual(body.get("msg"), f"Hello {param}! You POSTed successfully!") - self.assertEqual(body.get("status"), "SUCCESS") def test_put(self): param = "Tester" @@ -82,4 +80,3 @@ def test_put(self): self.assertEqual(body.get("type"), "PUT") self.assertEqual(body.get("val"), None) self.assertEqual(body.get("msg"), f"Hello {param}! You PUTted successfully!") - self.assertEqual(body.get("status"), "SUCCESS") diff --git a/tests/cases/apis/HelloAPI/local/t_Version.py b/tests/cases/apis/HelloAPI/local/t_Version.py index 4d7a69c..2a8d964 100644 --- a/tests/cases/apis/HelloAPI/local/t_Version.py +++ b/tests/cases/apis/HelloAPI/local/t_Version.py @@ -49,4 +49,3 @@ def test_get(self): self.assertEqual(body.get("type"), "GET") self.assertEqual(body.get("val"), {"version": "0.0.0-Testing"}) self.assertEqual(body.get("msg"), "Successfully retrieved API version.") - self.assertEqual(body.get("status"), "SUCCESS") diff --git a/tests/cases/apis/HelloAPI/remote/t_Hello.py b/tests/cases/apis/HelloAPI/remote/t_Hello.py index b13db75..59532a9 100644 --- a/tests/cases/apis/HelloAPI/remote/t_Hello.py +++ b/tests/cases/apis/HelloAPI/remote/t_Hello.py @@ -38,7 +38,6 @@ def test_get(self): self.assertEqual(body.get("type"), "GET", f"Bad type from {_url}") self.assertEqual(body.get("val"), None, f"Bad val from {_url}") self.assertEqual(body.get("msg"), "Hello! You GETted successfully!", f"Bad msg from {_url}") - self.assertEqual(body.get("status"), "SUCCESS", f"Bad status from {_url}") def test_post(self): _url = f"{self.base_url}/hello" @@ -57,7 +56,6 @@ def test_post(self): self.assertEqual(body.get("type"), "POST", f"Bad type from {_url}") self.assertEqual(body.get("val"), None, f"Bad val from {_url}") self.assertEqual(body.get("msg"), "Hello! You POSTed successfully!", f"Bad msg from {_url}") - self.assertEqual(body.get("status"), "SUCCESS", f"Bad status from {_url}") def test_put(self): _url = f"{self.base_url}/hello" @@ -77,4 +75,3 @@ def test_put(self): self.assertEqual(body.get("type"), "PUT", f"Bad type from {_url}") self.assertEqual(body.get("val"), None, f"Bad val from {_url}") self.assertEqual(body.get("msg"), "Hello! You PUTted successfully!", f"Bad msg from {_url}") - self.assertEqual(body.get("status"), "SUCCESS", f"Bad status from {_url}") diff --git a/tests/cases/apis/HelloAPI/remote/t_ParamHello.py b/tests/cases/apis/HelloAPI/remote/t_ParamHello.py index 01f3cc1..221625c 100644 --- a/tests/cases/apis/HelloAPI/remote/t_ParamHello.py +++ b/tests/cases/apis/HelloAPI/remote/t_ParamHello.py @@ -39,7 +39,6 @@ def test_get(self): self.assertEqual(body.get("type"), "GET", f"Bad type from {_url}") self.assertEqual(body.get("val"), None, f"Bad val from {_url}") self.assertEqual(body.get("msg"), f"Hello {self.param}! You GETted successfully!", f"Bad msg from {_url}") - self.assertEqual(body.get("status"), "SUCCESS", f"Bad status from {_url}") def test_post(self): _url = f"{self.base_url}/p_hello/{self.param}" @@ -58,7 +57,6 @@ def test_post(self): self.assertEqual(body.get("type"), "POST", f"Bad type from {_url}") self.assertEqual(body.get("val"), None, f"Bad val from {_url}") self.assertEqual(body.get("msg"), f"Hello {self.param}! You POSTed successfully!", f"Bad msg from {_url}") - self.assertEqual(body.get("status"), "SUCCESS", f"Bad status from {_url}") def test_put(self): _url = f"{self.base_url}/p_hello/{self.param}" @@ -78,4 +76,3 @@ def test_put(self): self.assertEqual(body.get("type"), "PUT", f"Bad type from {_url}") self.assertEqual(body.get("val"), None, f"Bad val from {_url}") self.assertEqual(body.get("msg"), f"Hello {self.param}! You PUTted successfully!", f"Bad msg from {_url}") - self.assertEqual(body.get("status"), "SUCCESS", f"Bad status from {_url}") diff --git a/tests/cases/utils/t_APIResponse.py b/tests/cases/utils/t_APIResponse.py index 8a7b269..3d98ffb 100644 --- a/tests/cases/utils/t_APIResponse.py +++ b/tests/cases/utils/t_APIResponse.py @@ -1,21 +1,26 @@ # import libraries -import sys +import json +import logging import unittest from pathlib import Path from unittest import TestCase # import ogd libraries from ogd.common.configs.TestConfig import TestConfig +from ogd.common.utils.Logger import Logger # import locals -from src.ogd.apis.utils.APIResponse import APIResponse +from src.ogd.apis.utils.APIResponse import APIResponse, RESTType, ResponseStatus from tests.config.t_config import settings -_config = TestConfig.FromDict(name="APIResponseTestConfig", unparsed_elements=settings) -@unittest.skip("No tests implemented yet") class t_APIResponse(TestCase): - @staticmethod - def RunAll(): - pass + @classmethod + def setUpClass(cls) -> None: + _config = TestConfig.FromDict(name="APIResponseTestConfig", unparsed_elements=settings) + _level = logging.DEBUG if _config.Verbose else logging.INFO + Logger.InitializeLogger(level=_level, use_logfile=False) + + def setUp(self): + self.response = APIResponse(req_type=RESTType.GET, val={"foo":"bar"}, msg="Complete", status=ResponseStatus.SUCCESS) @unittest.skip("Not yet implemented") def test_FromRequestResult(self): @@ -25,45 +30,81 @@ def test_FromRequestResult(self): def test_FromFromDict(self): pass - @unittest.skip("Not yet implemented") def test_Type(self): - pass + self.assertEqual(self.response.Type, RESTType.GET) - @unittest.skip("Not yet implemented") def test_Value(self): - pass + self.assertEqual(self.response.Value, {"foo":"bar"}) - @unittest.skip("Not yet implemented") def test_Message(self): - pass + self.assertEqual(self.response.Message, "Complete") - @unittest.skip("Not yet implemented") def test_Status(self): - pass + self.assertEqual(self.response.Status, ResponseStatus.SUCCESS) - @unittest.skip("Not yet implemented") def test_AsDict(self): - pass + expected = { + "type": "GET", + "val": {"foo":"bar"}, + "msg": "Complete" + } + d = self.response.AsDict + for key in expected.keys(): + self.assertIn(key, d.keys(), f"Response is missing key {key}") + self.assertEqual(d[key], expected[key]) - @unittest.skip("Not yet implemented") def test_AsJSON(self): - pass + expected = { + "type": "GET", + "val": {"foo":"bar"}, + "msg": "Complete" + } + self.assertEqual(self.response.AsJSON, json.dumps(expected)) @unittest.skip("Not yet implemented") def test_AsFlaskResponse(self): pass - @unittest.skip("Not yet implemented") - def test_RequestErrored(self): - pass + def test_RequestErrored_default_status(self): + self.response.RequestErrored("Default request error") + self.assertEqual(self.response.Message, "ERROR: Default request error") + self.assertEqual(self.response.Status, ResponseStatus.ERR_REQ) - @unittest.skip("Not yet implemented") - def test_ServerErrored(self): - pass + def test_RequestErrored_general_status(self): + self.response.RequestErrored("General request error", ResponseStatus.ERR_REQ) + self.assertEqual(self.response.Message, "ERROR: General request error") + self.assertEqual(self.response.Status, ResponseStatus.ERR_REQ) + + def test_RequestErrored_notfound_status(self): + self.response.RequestErrored("404 request error", ResponseStatus.ERR_NOTFOUND) + self.assertEqual(self.response.Message, "ERROR: 404 request error") + self.assertEqual(self.response.Status, ResponseStatus.ERR_NOTFOUND) + + def test_RequestErrored_invalid_status(self): + self.response.RequestErrored("Invalid choice of code for request error, should give default error code", ResponseStatus.ERR_SRV) + self.assertEqual(self.response.Message, "ERROR: Invalid choice of code for request error, should give default error code") + self.assertEqual(self.response.Status, ResponseStatus.ERR_REQ) + + def test_ServerErrored_default_status(self): + self.response.ServerErrored("Default server error") + self.assertEqual(self.response.Message, "SERVER ERROR: Default server error") + self.assertEqual(self.response.Status, ResponseStatus.ERR_SRV) + + def test_ServerErrored_general_status(self): + self.response.ServerErrored("General server error") + self.assertEqual(self.response.Message, "SERVER ERROR: General server error", ResponseStatus.ERR_SRV) + self.assertEqual(self.response.Status, ResponseStatus.ERR_SRV) + + def test_ServerErrored_invalid_status(self): + self.response.ServerErrored("Invalid choice of code for server error, should give default error code", ResponseStatus.ERR_REQ) + self.assertEqual(self.response.Message, "SERVER ERROR: Invalid choice of code for server error, should give default error code") + self.assertEqual(self.response.Status, ResponseStatus.ERR_SRV) - @unittest.skip("Not yet implemented") def test_RequestSucceeded(self): - pass + self.response.RequestSucceeded(msg="Default server success", val={"foo":"bar"}) + self.assertEqual(self.response.Message, "SUCCESS: Default server success") + self.assertEqual(self.response.Status, ResponseStatus.SUCCESS) + self.assertEqual(self.response.Value, {"foo":"bar"}) if __name__ == '__main__': unittest.main()