diff --git a/integrations/client/test_delphi_epidata.py b/integrations/client/test_delphi_epidata.py index c63a8e1be..af8b384ab 100644 --- a/integrations/client/test_delphi_epidata.py +++ b/integrations/client/test_delphi_epidata.py @@ -2,7 +2,7 @@ # standard library import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, call # third party from aiohttp.client_exceptions import ClientResponseError @@ -284,21 +284,31 @@ def test_covidcast(self): # check result self.assertEqual(response_1, {'message': 'no results', 'result': -2}) - @patch('requests.post') - @patch('requests.get') - def test_request_method(self, get, post): + @patch('requests_cache.CachedSession') + @patch('requests.Session') + def test_request_method(self, _Session, _CachedSession): """Test that a GET request is default and POST is used if a 414 is returned.""" - with self.subTest(name='get request'): + with self.subTest(name='get request, no cache'): + Session = MagicMock() + _Session.return_value = Session Epidata.covidcast('src', 'sig', 'day', 'county', 20200414, '01234') - get.assert_called_once() - post.assert_not_called() + assert call('get') in Session.request.call_args_list + assert call('post') not in Session.request.call_args_list + with self.subTest(name='get request, cache'): + CachedSession = MagicMock() + _CachedSession.return_value = CachedSession + Epidata.covidcast('src', 'sig', 'day', 'county', 20200414, '01234', cache_timeout=5) + assert call('get') in CachedSession.request.call_args_list + assert call('post') not in CachedSession.request.call_args_list with self.subTest(name='post request'): mock_response = MagicMock() mock_response.status_code = 414 - get.return_value = mock_response + Session = MagicMock() + Session.request.return_value = mock_response + _Session.return_value = Session Epidata.covidcast('src', 'sig', 'day', 'county', 20200414, '01234') - self.assertEqual(get.call_count, 2) # one from post test and one from get test - post.assert_called_once() + assert call('get') in Session.request.call_args_list + assert call('post') in Session.request.call_args_list def test_geo_value(self): """test different variants of geo types: single, *, multi.""" diff --git a/requirements.txt b/requirements.txt index 9cae4fa02..d9805f186 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ Flask==1.1.2 SQLAlchemy==1.3.22 mysqlclient==2.0.2 +newrelic python-dotenv==0.15.0 orjson==3.4.7 pandas==1.2.3 +requests-cache scipy==1.6.2 -newrelic diff --git a/src/client/delphi_epidata.py b/src/client/delphi_epidata.py index b6f5f684d..66397f06e 100644 --- a/src/client/delphi_epidata.py +++ b/src/client/delphi_epidata.py @@ -9,13 +9,17 @@ """ # External modules -import requests +from typing import Union, Optional +from datetime import timedelta, datetime +from requests import Session +from requests_cache import CachedSession import asyncio -import warnings from aiohttp import ClientSession, TCPConnector from pkg_resources import get_distribution, DistributionNotFound +CacheTime = Union[int, datetime, timedelta] + # Obtain package version for the user-agent. Uses the installed version by # preference, even if you've installed it and then use this script independently # by accident. @@ -54,7 +58,7 @@ def _list(values): # Helper function to request and parse epidata @staticmethod - def _request(params): + def _request(params, cache_timeout: Optional[CacheTime] = None): """Request and parse epidata. We default to GET since it has better caching and logging @@ -63,9 +67,14 @@ def _request(params): """ try: # API call - req = requests.get(Epidata.BASE_URL, params, headers=_HEADERS) + session = Session() if cache_timeout is None else CachedSession( + 'covidcast_cache', expire_after=cache_timeout + ) + req = session.request('get', Epidata.BASE_URL, params, headers=_HEADERS) + # Fallback to requests if we have to use POST if req.status_code == 414: - req = requests.post(Epidata.BASE_URL, params, headers=_HEADERS) + req = session.request('post', Epidata.BASE_URL, params, headers=_HEADERS) + session.close() return req.json() except Exception as e: # Something broke @@ -567,7 +576,7 @@ def meta(): @staticmethod def covidcast( data_source, signals, time_type, geo_type, - time_values, geo_value, as_of=None, issues=None, lag=None, **kwargs): + time_values, geo_value, as_of=None, issues=None, lag=None, cache_timeout=None, **kwargs): """Fetch Delphi's COVID-19 Surveillance Streams""" # also support old parameter name if signals is None and 'signal' in kwargs: @@ -602,7 +611,7 @@ def covidcast( params['format'] = kwargs['format'] # Make the API call - return Epidata._request(params) + return Epidata._request(params, cache_timeout) # Fetch Delphi's COVID-19 Surveillance Streams metadata @staticmethod