I'm trying to write a unit test. I have a decorator that checks whether the argument passed to the function is actually an argument of the expected type. @check_types_method
Its usage looks like this:
class Calculator:
@check_types_method
def call(self, a: MegaNumber, b: int) -> int:
return a.value + b
If the variable a is not a MegaNumber, an exception will be thrown. It works great, but it creates problems in tests. For example, I need to replace the MegaNumber mock with an object, and if I do this, I get an error in the test.
TypeError: The argument 'a' must be of type <class 'mega_number.MegaNumber'>, received: MagicMock
For the test, you need to lock the decorator so that it does not do anything. And skipped Mock objects. However, I do not know how to do this and I do not understand what the reason is. My best attempt looks like this:
import unittest
from unittest.mock import MagicMock, patch
from calculator import Calculator
@patch('calculator.check_types_method', lambda x: x)
class TestCalculator(unittest.TestCase):
def test_call(self):
mock_mega_number = MagicMock()
mock_mega_number.value = 10
calculator = Calculator()
result = calculator.call(mock_mega_number, 5)
self.assertEqual(result, 15, "The call method should return the sum of MegaNumber's "
"value and the integer provided.")
but the decorator loads before the patch is applied - and I can't beat it. Can you tell me what to do?
the code for playback: check_types_method.py
from inspect import signature
from functools import wraps
from typing import get_origin, get_args, Union
def check_types_method(func):
@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
for name, value in bound.arguments.items():
expected_type = sig.parameters[name].annotation
if expected_type is not sig.empty:
if not is_of_generic_type(value, expected_type):
raise TypeError(
f"The argument '{name}' must be of type {expected_type}, received: {type(value).__name__}")
return func(*args, **kwargs)
def is_of_generic_type(obj, generic_type):
origin_type = get_origin(generic_type)
arg_types = get_args(generic_type)
if origin_type is Union:
return any(is_of_generic_type(obj, arg) for arg in arg_types)
if origin_type is None:
# Non-generic types
return isinstance(obj, generic_type)
if not isinstance(obj, origin_type):
return False
if origin_type in (list, set):
element_type = arg_types[0]
return all(is_of_generic_type(item, element_type) for item in obj)
elif origin_type is dict:
key_type, value_type = arg_types
return all(is_of_generic_type(k, key_type) and is_of_generic_type(v, value_type) for k, v in obj.items())
elif origin_type is tuple:
if len(arg_types) == 2 and arg_types[1] is Ellipsis:
return all(isinstance(item, arg_types[0]) for item in obj)
else:
return len(obj) == len(arg_types) and all(
is_of_generic_type(item, t) for item, t in zip(obj, arg_types))
return False
return wrapper
mega_number.py
class MegaNumber:
def __init__(self):
self.value = 10
calculator.py
from check_types_method import check_types_method
from mega_number import MegaNumber
class Calculator:
@check_types_method
def call(self, a: MegaNumber, b: int) -> int:
return a.value + b
test_calculator.py
import unittest
from unittest.mock import MagicMock, patch
from calculator import Calculator
@patch('calculator.check_types_method', lambda x: x)
class TestCalculator(unittest.TestCase):
def test_call(self):
mock_mega_number = MagicMock()
mock_mega_number.value = 10
calculator = Calculator()
result = calculator.call(mock_mega_number, 5)
self.assertEqual(result, 15, "The call method should return the sum of MegaNumber's "
"value and the integer provided.")