本文讲解了一个需求的解决方案,而这个奇葩需求你在 99.93% 场景下都不会遇到,就算遇到了,也一定有其它更简单的解决方案。
0. 引言
>>> print((lambda x:None).__code__.__doc__)
code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring,
constants, names, varnames, filename, name, firstlineno,
lnotab[, freevars[, cellvars]])
Create a code object. Not for the faint of heart.
1. 需求
mock.patch
对象在用做装饰器时,会生成一个偏函数,来将原始函数的第一个位置参数覆盖为一个 mock 对象,如:
1 2 3 @mock.patch('func' ) def test_something (mock, a, b ): pass
此处 mock
参数就是 patch 对应的 mock 对象。
而 pytest 有一个有点厉害的功能:它会读取测试函数的参数,并在 conftest.py
中寻找每个参数对应的同名 fixture 并加载。
前两天就遇到了这样的问题:我要做一个类似 patch
的装饰器,在 pytest 测试函数执行时向内部注入一个参数,而且要保证这个参数在被 pytest 解析时不能暴露在参数列表中 ,否则 pytest 会因为找不到参数的 fixture 而报错。
1 2 3 4 5 6 7 8 @my_patch('func' ) def test_something (mock, fixa, fixb ): pass import inspectassert str (inspect.signature(test_something)) == "(fixa, fixb)"
2. 铺垫一下:简单的偏函数装饰器实现
如果想实现一个简单的偏函数装饰器,非常简单:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def my_partial (*partial_args, **partial_kwargs ): def decorator (func ): def wrapped (*args, **kwargs ): args = partial_args + args partial_kwargs.update(kwargs) return func(*args, **partial_kwargs) return wrapped return decorator @my_partial(1 , 2 , d=5 ) def func (a, b, c, d=4 ): return sum ([a, b, c, d]) assert func(3 ) == 11
但这样的话,func
的参数声明将会变成 wrapped
的参数声明,也就是 (*args, **kwargs)
,而在这个场景中,func
的参数声明应该是 (c, d=5)
,才能够真正满足 pytest 的要求。
于是,我们就需要在定义 wrapped
函数的时候,对它的参数声明进行定制化。
3. 来点基础知识
3.1 获取函数参数声明
获取完整的函数参数声明,要涉及到函数的几个私有属性:
__code__
: 编译过的函数代码对象,类型为 types.CodeType
__defaults__
: 函数的序列参数默认值,类型为 None
或 tuple
__kwdefaults__
: 函数的关键字参数默认值,类型为 None
或 dict
其中 code
对象中的几个属性也需要用到:
co_varnames
: 函数声明中所有参数的变量名,其中 *
和 **
可变参数的变量名会放在最后
co_argcount
: 序列参数的数量
co_kwonlyargcount
: 严格关键词参数的数量
co_flags
: 函数性质标记,0x04
位声明这个函数是否用到了 *args
,而 0x08
位声明这个函数是否用到了 **kwargs
有了这些属性,我们就可以获取到函数的参数声明信息。具体实现可以看下面的代码汇总 。
当然,Python 3 还支持函数注解 ,需要用到函数的 __annotations__
属性,但在我的代码中没有对这部分进行解析。
3.2 inspect 库
当然,这些轮子,Python 自带库 inspect
都已经帮我们造好了。
inspect.getfullargspec(f)
可以获取到以上所有的参数信息:
1 2 3 4 5 6 >>> def func (a, b=1 , *args, c=2 , **kwargs ):... pass ... >>> inspect.getfullargspec(func)FullArgSpec(args=['a' , 'b' ], varargs='args' , varkw='kwargs' , defaults=(1 ,), kwonlyargs=['c' ], kwonlydefaults={'c' : 2 }, annotations={})
而 inspect.signature
更加强大,它除了解析函数外,还支持“模拟调用”函数,对调用方式进行合法性验证,并展示调用之后函数中每个参数的值:
1 2 3 4 5 6 7 8 9 >>> s = inspect.signature(func) >>> s<Signature (a, b=1 , *args, c=2 , **kwargs)> >>> b = s.bind(0 )>>> b<BoundArguments (a=0 )> >>> b.apply_defaults()>>> b<BoundArguments (a=0 , b=1 , args=(), c=2 , kwargs={})>
实现了这些功能,我们就可以对给定函数进行参数解析并修改了。
这都算哪门子基础知识嘛…
4. 开始动手
先来写一个简单的、只支持序列参数,且不支持参数默认值的偏函数实现。
1 2 3 4 5 6 7 8 9 10 def func (a, b, c ): return sum ([a, b, c]) if __name__ == '__main__' : print (func(1 , 2 , 3 ,), func, signature(func)) partial = nb_partial(3 , 4 )(func) print (partial(5 ), partial, signature(partial, follow_wrapped=False ))
注:sigature(follow_wrapped=False)
会改变 signature
的默认行为:只查看函数本身的定义,而不会通过 __wrapped__
链去查找原函数。
4.1 大体框架
搭个架子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 from functools import wrapsdef nb_partial (func, *partial_args ): pre_arg_str = ', ' .join(map (repr , partial_args)) code = func.__code__ if len (partial_args) > code.co_argcount: raise TypeError("Too many positional arguments" ) args = code.co_varnames[:code.co_argcount] post_args = args[len (partial_args):] post_args_str = ', ' .join(post_args) def wrapped (c ): return func(3 , 4 , c) return wraps(func)(wrapped)
现在我们有了需要的变量值和变量名信息。可是,我们还需要根据这些变量信息动态 定义 wrapped
函数,这个就有点麻烦了…
4.2 黑科技:code object & FunctionType
Python 毕竟是“万物皆对象”的动态语言。
从字符串编译一段代码?没问题!用函数类定义一个函数对象?也没什么问题!
牛(就)逼(是)了(慢)。
compile
内置函数可以让我们动态地通过字符串来编译代码,来生成一个代码对象。这段代码可以是一个完整的程序,一两个定义,也可以是几个表达式。
而通过 types.FunctionType
,我们就可以使用代码对象,并向其中注入一些信息(如 globals
),即可生成一个可用的函数。
1 2 3 4 5 6 7 8 9 10 11 12 func_def = """ def f(a): return a + 1 """ module_code = compile (func_def, '<>' , 'exec' ) function_code = next (code for code in module_code.co_consts if isinstance (code, types.CodeType)) func = types.FunctionType(function_code, {}) print (func, func(2 ))
有了这个,我们就可以从一段字符串中动态生成一个 Python 函数了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 def nb_partial (*partial_args ): """ 生成一个偏函数,且新函数的参数列表中会去除 """ def decorator (func ): pre_arg_str = ', ' .join(map (repr , partial_args)) code = func.__code__ args = code.co_varnames[:code.co_argcount] post_args = args[len (partial_args):] post_args_str = ', ' .join(post_args) func_def = (f"def wrapped({post_args_str} ):\n" f" return func({pre_arg_str} , {post_args_str} )" ) module_code = compile (func_def, '<>' , 'exec' ) function_code = next (code for code in module_code.co_consts if isinstance (code, types.CodeType)) wrapped = types.FunctionType(function_code, {'func' : func}) return wraps(func)(wrapped) return decorator
这里有一个小缺陷没有解决掉:这个函数中的 func
是一个全局变量,而不是闭包中的自由变量。
转念一想,装饰器中的原始 func
本身就不是自由变量啊…脑残了 Orz,继续继续。
所以,通过动态定义函数,我们就可以实现还原参数列表的的偏函数功能。大体步骤如下:
提取原函数参数
处理新函数参数
构造函数定义字符串
动态定义新函数
完整实现见下面 。
5. 再改进一下?
我们的原函数目前只支持了最简单的函数声明方式,而无法支持关键词参数,参数默认值等。
结合上面写的 FuncParser
,我们还可以对动态函数的定义做进一步改进。
完整实现有点长,还是下面见 。
当然,这个实现还有一些可以改进之处:
偏函数定义时可以支持关键字参数,这需要进行更多的判断,比如定义时传进去的关键字参数,有可能在原函数中是序列参数;
有很多对象并没有很好地实现 __repr__
魔术方法,导致 repr(arg)
后生成的字符串在函数定义中并不能做到完整还原。所以我们最好通过 globals
将它直接传进新的函数中,而不是使用字符串来进行参数传递。
突然犯懒,就不再做进一步的实现了。
-2. 代码汇总
-2.1 获取函数参数声明的 FuncParser
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 from collections import namedtupleFuncArgs = namedtuple('FuncArgs' , ('args' , 'defaults' )) class FuncParser : @classmethod def parse_vars (cls, func ): code = func.__code__ argc = code.co_argcount kwargc = code.co_kwonlyargcount varnames = code.co_varnames res = { "arg" : varnames[:argc], "kwarg" : varnames[argc:argc + kwargc], "*args" : None , "**kwargs" : None } flag = code.co_flags if flag & 0x04 : res['*args' ] = varnames[argc + kwargc] if flag & 0x08 : res['**kwargs' ] = varnames[argc + kwargc + bool (flag & 0x04 )] defaults = func.__defaults__ or () defaults_map = dict (zip (reversed (res['arg' ]), reversed (defaults))) defaults_map.update(func.__kwdefaults__ or {}) return FuncArgs(res, defaults_map) @classmethod def build_param_str (cls, func_args ): func_vars, defaults_map = func_args params = [] for arg in func_vars['arg' ]: if arg in defaults_map: default_val = defaults_map[arg] params.append(f'{arg} ={default_val} ' ) else : params.append(arg) if func_vars['*args' ]: params.append(f"*{func_vars['*args' ]} " ) else : if func_vars['kwarg' ]: params.append('*' ) for kwarg in func_vars['kwarg' ]: if kwarg in defaults_map: default_val = defaults_map[kwarg] params.append(f'{kwarg} ={default_val} ' ) else : params.append(f'{kwarg} ={kwarg} ' ) if func_vars['**kwargs' ]: params.append(f"**{func_vars['**kwargs' ]} " ) param = ', ' .join(params) param_str = '({})' .format (param) return param_str @classmethod def analyse_func_param (cls, func ): func_args = cls.parse_vars(func) name = func.__name__ return cls.build_param_str(func_args) def test_func (a, *args, b=1 , **kwargs ): pass if __name__ == '__main__' : import inspect import types funcs = [f for f in locals ().values() if isinstance (f, types.FunctionType)] for f in funcs: fstr = FuncParser.analyse_func_param(f) assert fstr == str (inspect.signature(f))
-2.2 “高级”偏函数装饰器的初步实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 import typesfrom functools import wrapsfrom inspect import signaturedef nb_partial (*partial_args ): """ 生成一个偏函数,且新函数的参数列表中会去除 """ def decorator (func ): pre_arg_str = ', ' .join(map (repr , partial_args)) code = func.__code__ if len (partial_args) > code.co_argcount: raise TypeError("Too many positional arguments" ) args = code.co_varnames[:code.co_argcount] post_args = args[len (partial_args):] post_args_str = ', ' .join(post_args) func_def = (f"def wrapped({post_args_str} ):\n" f" return func({pre_arg_str} , {post_args_str} )" ) module_code = compile (func_def, '<>' , 'exec' ) function_code = next (code for code in module_code.co_consts if isinstance (code, types.CodeType)) wrapped = types.FunctionType(function_code, {'func' : func}) return wraps(func)(wrapped) return decorator def func (a, b, c ): return sum ([a, b, c]) if __name__ == '__main__' : print (func(1 , 2 , 3 ,), func, signature(func)) partial = nb_partial(3 , 4 )(func) print (partial(5 ), partial, signature(partial, follow_wrapped=False ))
-2.3 “高级”偏函数装饰器的进阶实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 import typesfrom functools import wrapsfrom inspect import signaturefrom funcparser import FuncParser, FuncArgsdef nb_partial (*partial_args ): """ 生成一个偏函数,且新函数的参数列表中会去除 """ def decorator (func ): code = func.__code__ if len (partial_args) > code.co_argcount: raise TypeError("Too many positional arguments" ) func_vars, defaults_map = FuncParser.parse_vars(func) def_vars = func_vars.copy() def_vars["arg" ] = def_vars["arg" ][len (partial_args):] def_str = FuncParser.build_param_str(FuncArgs(def_vars, defaults_map)) call_vars = func_vars.copy() call_vars["arg" ] = list (call_vars["arg" ]) call_vars["arg" ][:len (partial_args)] = list (map (repr , partial_args)) call_str = FuncParser.build_param_str(FuncArgs(call_vars, {})) call_str = call_str.replace('*, ' , '' ) func_def = (f"def wrapped{def_str} :\n" f" return func{call_str} " ) module_code = compile (func_def, '<>' , 'exec' ) function_code = next (code for code in module_code.co_consts if isinstance (code, types.CodeType)) argdefs = [] for arg in reversed (func_vars["arg" ]): if arg not in defaults_map: break argdefs.append(defaults_map[arg]) argdefs.reverse() wrapped = types.FunctionType(function_code, {'func' : func}, argdefs=tuple (argdefs)) wrapped.__kwdefaults__ = func.__kwdefaults__ return wraps(func)(wrapped) return decorator def func (a, b=2 , c=3 , *, d=5 ): print (locals ()) return sum ([a, b, c, d]) if __name__ == '__main__' : print (func(1 , 2 , 3 , d=6 ), func, signature(func)) partial = nb_partial(3 )(func) print (partial(6 , d=7 ), partial, signature(partial, follow_wrapped=False ))
-1. Reference & 延伸阅读