Skip to content

Commit 73de9fa

Browse files
Enrico Usaienrico-usai
authored andcommitted
Add mechanism to permit to mock generic methods/functions of the python package
The new `ApiMocker.mockable` is a decorator that can be applied to mock any method. The decorator will automatically search in the `<function-name-path>.overrides.py` file (if it exists) a function with the same name of the function. If the function exists, the decorator will execute it in place of the original one. The same `slurm_plugin.overrides.py` file is already used by `run_instances` and `create_fleet` methods, but without the generic decorator mechanism. We can use this new mechanism for them in the future. I'm using it for `describe_capacity_reservations` to permit to mock it when running e2e tests. An example of bash script to create overrides file is: ``` node_virtualenv_path=$(sudo find / -iname "site-packages" | grep "node_virtualenv") # the overrides.py file must be in the same folder of the module of the function to be mocked cat << EOF | sudo tee -a "${node_virtualenv_path}/aws/overrides.py" from aws.ec2 import CapacityReservationInfo def describe_capacity_reservations(_, capacity_reservations_ids): return [ CapacityReservationInfo({ "CapacityReservationId": "cr-123456", "OwnerId": "123456789", "CapacityReservationArn": "arn:aws:ec2:us-east-2:123456789:capacity-reservation/cr-123456", ... }) ] EOF ``` ### Tests Tested with: * not existing overrides.py (`ImportError`) * empty overrides.py (`AttributeError`) * overrides.py with other functions defined on it (`AttributeError`) * overrides.py with the mocked `describe_capacity_reservations` (mocked output) Signed-off-by: Enrico Usai <[email protected]>
1 parent 0ed494f commit 73de9fa

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/aws/ec2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# limitations under the License.
1111
from typing import List
1212

13+
from common.utils import ApiMocker
14+
1315
from aws.common import AWSExceptionHandler, Boto3Client
1416

1517

@@ -73,6 +75,7 @@ def __init__(self, config=None, region=None):
7375
super().__init__("ec2", region=region, config=config)
7476

7577
@AWSExceptionHandler.handle_client_exception
78+
@ApiMocker.mockable
7679
def describe_capacity_reservations(self, capacity_reservation_ids: List[str]) -> List[CapacityReservationInfo]:
7780
"""Accept a space separated list of reservation ids. Return a list of CapacityReservationInfo."""
7881
result = []

src/common/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,3 +375,39 @@ def filter(self, record: logging.LogRecord) -> bool:
375375
finally:
376376
# Remove the custom log filter
377377
logger.removeFilter(custom_filter)
378+
379+
380+
class ApiMocker:
381+
"""API mocker."""
382+
383+
@staticmethod
384+
def mockable(func):
385+
"""
386+
Try to mock passed function by searching for an overrides.py file in the same path of the given func.
387+
388+
This function can be used a decorator and applied any method.
389+
390+
The function will check if a function called with the name of the given function exists
391+
in the <function-dir>/overrides.py, and if it does, the function will execute it.
392+
393+
E.g. if the method with ApiMocker.mockable decorator is defined in Ec2Client class
394+
of the ${node_virtualenv_path}/aws/ec2.py module, the mocked function should be defined
395+
in the ${node_virtualenv_path}/aws/overrides.py file.
396+
"""
397+
398+
def wrapper(*args, **kwargs):
399+
try:
400+
function_name = func.__name__
401+
# retrieve parent module of the given function that has the ApiMocker.mockable decorator
402+
func_module = func.__module__
403+
func_parent_module = func_module[: func_module.rindex(".")]
404+
# try to import overrides.py module in the same folder of the module to mock
405+
overrides_module = __import__(f"{func_parent_module}.overrides", fromlist=function_name)
406+
overrided_func = getattr(overrides_module, function_name)
407+
log.info("Calling %s override with args: %s and kwargs: %s", function_name, args, kwargs)
408+
result = overrided_func(*args, **kwargs)
409+
except (ImportError, AttributeError):
410+
result = func(*args, **kwargs)
411+
return result
412+
413+
return wrapper

0 commit comments

Comments
 (0)