"""Test-case helpers that check return types and recursive structure.

The module exposes:

* ``recursion_variant`` -- a decorator that evaluates a user supplied
  variant on every call of a recursive function and logs a warning on the
  ``'recursion_variant'`` logger when the variant is not an ``int`` or grows
  from one nested call to the next;
* ``get_return_type`` -- returns the ``return`` annotation of a function;
* ``AD2TestCase`` -- a ``unittest.TestCase`` whose assertions trace the
  execution of a function (via ``sys.settrace``) and inspect which functions
  of the file returned by ``source_code_path()`` were called and whether
  they recursed.
"""
import types
from functools import wraps
from inspect import Parameter, signature
from logging import getLogger
from sys import gettrace, settrace
from typing import Any, Callable, Dict, get_origin, List, Optional, Set, \
    Tuple, Union
from unittest import TestCase

__all__ = ['AD2TestCase', 'get_return_type', 'recursion_variant']

# Origins reported by get_origin() for union annotations: typing.Union for
# Union[...] / Optional[...], and types.UnionType for PEP 604 ``X | Y``
# (Python 3.10+; on older versions the second entry is simply Union again).
_UNION_ORIGINS = (Union, getattr(types, 'UnionType', Union))


def recursion_variant(variant_fun: Callable[..., int]):
    '''
    Wrap each recursive function fun (even inner functions) with the
    '@recursion_variant(lambda a, b, ..., z: expr)' wrapper, where the set
    {a, b, ..., z} is a subset of the parameter names of fun and expr
    evaluates to an integer.

    The lambda must return an integer that never increases from a call to
    the recursive call nested inside it. A warning is logged on the
    'recursion_variant' logger when the value is not an int, or when it is
    larger than the value of the enclosing call.

    example:

        @recursion_variant(lambda n: n)
        def fibonacci(n: int) -> int:
            if n <= 0:
                return 0
            if n <= 2:
                return 1
            return fibonacci(n - 1) + fibonacci(n - 2)

    note that n decreases (monotonically) with each recursive call.

    Arguments of fun may be passed positionally or by keyword, and
    parameters left at their default value are supported as well.

    Raises AssertionError when the decorator is applied, if variant_fun
    names a parameter that fun does not have.
    '''
    variant_logger = getLogger('recursion_variant')

    def recursion_variant_outer(wrapped_fun: Callable[..., Any]):
        fun_signature = signature(wrapped_fun)
        fun_params = list(fun_signature.parameters.values())
        fun_par_pos = {p.name: i for i, p in enumerate(fun_params)}
        variant_par_names = list(signature(variant_fun).parameters)
        for name in variant_par_names:
            if name not in fun_par_pos:
                # Explicit raise (not `assert`) so the check survives -O.
                raise AssertionError(
                    f"recursion variant parameter '{name}' is not a "
                    f"parameter of {wrapped_fun.__name__}")
        par_translator = [fun_par_pos[name] for name in variant_par_names]

        # Reading the variant arguments straight out of *args is the cheap
        # common case. It is only correct when every variant parameter can be
        # passed positionally and the caller actually passed it that way.
        positional_kinds = (Parameter.POSITIONAL_ONLY,
                            Parameter.POSITIONAL_OR_KEYWORD)
        use_positional = all(fun_params[i].kind in positional_kinds
                             for i in par_translator)
        min_positional = max(par_translator, default=-1) + 1

        # Variant values of the calls of this decorated function that are
        # currently active (innermost call last).
        variant_stack: List[Any] = []

        def variant_arguments(args, kwargs) -> List[Any]:
            '''Return the values of the variant's parameters for one call.'''
            if use_positional and len(args) >= min_positional:
                return [args[i] for i in par_translator]
            # Keyword arguments or defaults are involved: let inspect resolve
            # the call the same way Python will.
            bound = fun_signature.bind(*args, **kwargs)
            bound.apply_defaults()
            return [bound.arguments[name] for name in variant_par_names]

        @wraps(wrapped_fun)
        def recursion_variant_inner(*args, **kwargs):
            # NOTE: create_tracer recognises this wrapper by its name
            # 'recursion_variant_inner'; do not rename it.
            variant_ret = variant_fun(*variant_arguments(args, kwargs))
            if not isinstance(variant_ret, int):
                variant_logger.warning(
                    "the recursion variant for function "
                    f"{wrapped_fun.__name__} does not return an integer ("
                    f"return type {type(variant_ret).__name__} with value "
                    f"{variant_ret}).")
            elif (variant_stack and isinstance(variant_stack[-1], int)
                  and variant_stack[-1] < variant_ret):
                # The previous value is compared only when it is an int, so a
                # bad (e.g. None) value that was already reported above cannot
                # raise TypeError here.
                variant_logger.warning(
                    f"the recursion variant for function "
                    f"{wrapped_fun.__name__} does not monotonically decrease ("
                    f"previous value was {variant_stack[-1]} while the "
                    f"current value is {variant_ret}).")
            variant_stack.append(variant_ret)
            try:
                return wrapped_fun(*args, **kwargs)
            finally:
                # Pop even if wrapped_fun raises; otherwise a stale value
                # would be compared against the next top-level call.
                variant_stack.pop()

        return recursion_variant_inner

    return recursion_variant_outer


def has_args(annotation: Any, min_len: int) -> bool:
    '''
    Return True if annotation has a sized __args__ tuple with at least
    min_len entries (min_len=0 just tests that the arguments exist).
    '''
    if not hasattr(annotation, '__args__'):
        return False
    if not hasattr(annotation.__args__, '__len__'):
        return False
    return len(annotation.__args__) >= min_len


def _tuple_matches(args: Tuple[Any, ...], value: tuple) -> bool:
    '''Check the elements of a tuple value against Tuple[...] arguments.'''
    if len(args) == 2 and args[1] is Ellipsis:
        # Tuple[T, ...]: a homogeneous tuple of any length.
        return all(has_type(args[0], v) for v in value)
    if args == ((),):
        # Older Python versions spell Tuple[()] with a single () argument.
        return len(value) == 0
    return (len(args) == len(value)
            and all(has_type(a, v) for a, v in zip(args, value)))


def has_type(annotation: Any, value: Any) -> bool:
    '''
    Return True if value conforms to the type annotation.

    Supported annotations:
      * None (only the value None matches) and typing.Any (anything matches);
      * Dict[K, V] / dict[K, V], List[T] / list[T] and Set[T] / set[T]: the
        container type and every key, value or element are checked;
      * Tuple[A, B, ...] (fixed length, checked position by position) and
        Tuple[T, ...] (any length, every element must have type T);
      * Union[...], Optional[...] and PEP 604 ``X | Y`` unions;
      * bare aliases such as typing.List, and other parameterised generics
        such as Sequence[int]: only the container type is checked;
      * plain classes, which are checked with isinstance().

    Annotations that isinstance() cannot handle (e.g. Literal[...]) still
    raise TypeError.
    '''
    if annotation is None:
        return value is None
    if annotation is Any:
        return True
    origin = get_origin(annotation)
    if origin is None:
        return isinstance(value, annotation)
    if not has_args(annotation, 0):
        # A bare alias such as typing.List carries no element types, so only
        # the container type itself can be checked.
        return isinstance(origin, type) and isinstance(value, origin)
    args = annotation.__args__
    if origin in _UNION_ORIGINS:
        return any(has_type(u, value) for u in args)
    if origin is dict:
        return (isinstance(value, dict) and has_args(annotation, 2)
                and all(has_type(args[0], k) and has_type(args[1], v)
                        for k, v in value.items()))
    if origin in (list, set):
        return (isinstance(value, origin) and has_args(annotation, 1)
                and all(has_type(args[0], v) for v in value))
    if origin is tuple:
        return isinstance(value, tuple) and _tuple_matches(args, value)
    if isinstance(origin, type):
        # Other parameterised generics (Sequence[int], Callable[..., int],
        # ...): element types are not inspected, only the container type.
        return isinstance(value, origin)
    return isinstance(value, annotation)


def get_return_type(fun: Callable) -> Any:
    '''
    Return the 'return' entry of fun.__annotations__, or None when fun has
    no annotations dict or no return annotation.
    '''
    annotations = getattr(fun, '__annotations__', None)
    if not isinstance(annotations, dict):
        return None
    return annotations.get('return', None)


class TraceData:
    '''Per-function statistics gathered by the tracer from create_tracer.'''
    name: str = ''                      # the code object's co_name
    num_calls: int = 0                  # number of traced calls
    is_recursive: bool = False          # seen while already on the stack
    has_recursion_variant: bool = False  # called from recursion_variant

    def __init__(self, name: str, num_calls: int = 0,
                 is_recursive: bool = False,
                 has_recursion_variant: bool = False):
        self.name = name
        self.num_calls = num_calls
        self.is_recursive = is_recursive
        self.has_recursion_variant = has_recursion_variant

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"num_calls={self.num_calls}, "
                f"is_recursive={self.is_recursive}, "
                f"has_recursion_variant={self.has_recursion_variant})")


def _is_user_function(code: Any, file_name: str) -> bool:
    '''
    True for code objects defined in file_name whose name starts with a
    letter. This excludes lambdas, comprehensions and module code
    ('<lambda>', '<listcomp>', '<module>', ...).
    '''
    # NOTE(review): names starting with '_' are skipped as well, because
    # '_'.isalpha() is False -- confirm that this is intended.
    name = code.co_name
    return (code.co_filename == file_name and len(name) > 0
            and name[0].isalpha())


def _code_id(code: Any) -> str:
    '''Identifier of a function within one file: "<first line>:<name>".'''
    return f"{code.co_firstlineno}:{code.co_name}"


def create_tracer(file_name: str) -> Tuple[Dict[str, TraceData], Any]:
    '''
    Return (frame_data, tracer), where tracer is meant to be installed with
    sys.settrace.

    For each traced call of a function that passes _is_user_function, the
    tracer updates frame_data (keyed by "<first line>:<name>"):
      * num_calls is incremented;
      * has_recursion_variant is set if the direct caller is the
        recursion_variant wrapper defined in this module;
      * is_recursive is set if the same function is already on the stack.
    '''
    frame_data: Dict[str, TraceData] = dict()

    def tracer(frame, event, arg):
        code = frame.f_code
        if not _is_user_function(code, file_name):
            return None
        key = _code_id(code)
        data = frame_data.get(key)
        if data is None:
            data = frame_data[key] = TraceData(code.co_name)
        data.num_calls += 1
        if data.is_recursive:
            # Already established that the function is recursive.
            return None

        caller = frame.f_back
        if (caller is not None
                and caller.f_code.co_filename == __file__
                and caller.f_code.co_name == 'recursion_variant_inner'):
            data.has_recursion_variant = True

        # Walk the enclosing frames looking for another activation of the
        # same function.
        f = caller
        while f is not None:
            if (_is_user_function(f.f_code, file_name)
                    and _code_id(f.f_code) == key):
                data.is_recursive = True
                break
            f = f.f_back
        # No local trace function: only call events are needed.
        return None

    return frame_data, tracer


class AD2TestCase(TestCase):
    '''
    TestCase with assertions that trace the execution of a function.

    Subclasses must provide a ``source_code_path()`` method returning the
    file name whose functions are traced. trace_exec calls it, but this
    class does not define it.
    '''
    # NOTE: the public assert* methods take a parameter named `signature`
    # (the expected return annotation). It shadows inspect.signature inside
    # those methods only and is kept for backward compatibility.

    def assertSignature(self, fun_annotation: Any, ret_value: Any):
        '''Fail unless ret_value conforms to fun_annotation (see has_type).'''
        self.assertTrue(
            has_type(fun_annotation, ret_value),
            f"expected type: {fun_annotation} but {type(ret_value)} (value: "
            f"{ret_value}) was returned.")

    def trace_exec(self, fun: Callable,
                   *args) -> Tuple[Dict[str, TraceData], Any]:
        '''
        Execute the callable fun with args as arguments while tracing calls.

        The tuple (d, res) is returned, where d is a dictionary with
        TraceData values describing the traced functions that were called
        during execution, and res is the return value of fun. The previous
        trace function is restored even if fun raises.
        '''
        frame_data, tracer = create_tracer(self.source_code_path())
        prev_tracer = gettrace()
        settrace(tracer)
        try:
            ret = fun(*args)
        finally:
            settrace(prev_tracer)
        return frame_data, ret

    def assertRecursiveVariant(self,
                               frame_data: Dict[str, TraceData]) -> None:
        '''Fail if any recursive function was not called via the wrapper.'''
        for frame in frame_data.values():
            if frame.is_recursive:
                self.assertTrue(
                    frame.has_recursion_variant,
                    f"function with name '{frame.name}' does not have a "
                    "recursion variant. You need to wrap each recursive "
                    "functions with the @recursion_variant(f) wrapper, where "
                    "f is a function returning an int.")

    @staticmethod
    def _find_frame(frame_data: Dict[str, TraceData],
                    fun: Callable) -> Optional[TraceData]:
        '''Return the first TraceData whose name equals fun.__name__.'''
        return next((f for f in frame_data.values()
                     if f.name == fun.__name__), None)

    def _trace_single_function(self, fun: Callable, expected_type: Any,
                               args: Tuple[Any, ...]):
        '''
        Trace fun(*args) and check the return type. Also check that exactly
        one traced function ran and that it is fun.
        Returns (frame_data, fun_frame, return value).
        '''
        frame_data, ret = self.trace_exec(fun, *args)
        self.assertSignature(expected_type, ret)
        self.assertEqual(len(frame_data), 1)
        fun_frame = self._find_frame(frame_data, fun)
        self.assertIsNotNone(fun_frame)
        return frame_data, fun_frame, ret

    def _assert_recursive(self, fun: Callable, expected_type: Any,
                          min_recursions: int,
                          args: Tuple[Any, ...]) -> Any:
        '''Shared body of the (self-)recursive algorithm assertions.'''
        frame_data, fun_frame, ret = self._trace_single_function(
            fun, expected_type, args)
        if min_recursions > 0:
            self.assertTrue(fun_frame.is_recursive,
                            f'{fun.__name__} must be recursive')
            # min_recursions recursive calls plus the initial call.
            self.assertGreaterEqual(fun_frame.num_calls, min_recursions + 1,
                                    f'{fun.__name__} too few recursive calls')
        self.assertRecursiveVariant(frame_data)
        return ret

    def assertAlgorithm(self, fun: Callable, signature: Any, *args) -> Any:
        '''
        Run fun(*args) and check its return type. Also check that fun itself
        was traced and that every recursive function called uses
        @recursion_variant. Returns fun's return value.
        '''
        frame_data, ret = self.trace_exec(fun, *args)
        self.assertSignature(signature, ret)
        fun_frame = self._find_frame(frame_data, fun)
        self.assertIsNotNone(fun_frame)
        self.assertRecursiveVariant(frame_data)
        return ret

    def assertRecursiveAlgorithm(self, fun: Callable, signature: Any,
                                 min_recursions: int, *args) -> Any:
        '''
        Run fun(*args) and check the following:
          * the return type;
          * fun is the only traced function;
          * when min_recursions > 0, fun recursed at least min_recursions
            times;
          * it uses @recursion_variant.

        Returns fun's return value.
        '''
        return self._assert_recursive(fun, signature, min_recursions, args)

    def assertSelfRecursiveAlgorithm(self, fun: Callable, signature: Any,
                                     min_recursions: int, *args) -> Any:
        '''Same checks as assertRecursiveAlgorithm.'''
        return self._assert_recursive(fun, signature, min_recursions, args)

    def assertIterativeAlgorithm(self, fun: Callable, signature: Any,
                                 *args) -> Any:
        '''
        Run fun(*args) and check the return type. Also check that fun is the
        only traced function and that it did not recurse. Returns fun's
        return value.
        '''
        frame_data, fun_frame, ret = self._trace_single_function(
            fun, signature, args)
        self.assertFalse(fun_frame.is_recursive,
                         f'{fun.__name__} must be iterative')
        self.assertRecursiveVariant(frame_data)
        return ret