本文讲解了一个需求的解决方案,而这个奇葩需求你在 99.93% 场景下都不会遇到,就算遇到了,也一定有其它更简单的解决方案。
>>> print((lambda x:None).__code__.__doc__)
Create a code object.  Not for the faint of heart. 
 
mock.patch
1 2 3 @mock.patch('func'  def  test_something (mock, a, b ):    pass  
此处 mock 参数就是 patch 对应的 mock 对象。
而 pytest  有一个有点厉害的功能:它会读取测试函数的参数,并在 conftest.py 中寻找每个参数对应的同名 fixture  并加载。
前两天就遇到了这样的问题:我要做一个类似 patch 的装饰器,在 pytest 测试函数执行时向内部注入一个参数,而且要保证这个参数在被 pytest 解析时不能暴露在参数列表中 ,否则 pytest 会因为找不到参数的 fixture 而报错。
1 2 3 4 5 6 7 8 @my_patch('func'  def  test_something (mock, fixa, fixb ):         pass  import  inspectassert  str (inspect.signature(test_something)) == "(fixa, fixb)" 
如果想实现一个简单的偏函数装饰器,非常简单:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def  my_partial (*partial_args, **partial_kwargs ):    def  decorator (func ):         def  wrapped (*args, **kwargs ):             args = partial_args + args             partial_kwargs.update(kwargs)             return  func(*args, **partial_kwargs)         return  wrapped     return  decorator @my_partial(1 , 2 , d=5  def  func (a, b, c, d=4  ):    return  sum ([a, b, c, d]) assert  func(3 ) == 11 
但这样的话,func 的参数声明将会变成 wrapped 的参数声明,也就是 (*args, **kwargs),而在这个场景中,func 的参数声明应该是 (c, d=5),才能够真正满足 pytest 的要求。wrapped 函数的时候,对它的参数声明进行定制化。
获取完整的函数参数声明,要涉及到函数的几个私有属性:
__code__: 编译过的函数代码对象,类型为 types.CodeType__defaults__: 函数的序列参数默认值,类型为 None 或 tuple__kwdefaults__: 函数的关键字参数默认值,类型为 None 或 dict 
其中 code 对象中的几个属性也需要用到:
co_varnames: 函数声明中所有参数的变量名,其中 * 和 ** 可变参数的变量名会放在最后co_argcount: 序列参数的数量co_kwonlyargcount: 严格关键词参数的数量co_flags: 函数性质标记,0x04 位声明这个函数是否用到了 *args,而 0x08 位声明这个函数是否用到了 **kwargs 
有了这些属性,我们就可以获取到函数的参数声明信息。具体实现可以看下面的代码汇总 。
当然,Python 3 还支持函数注解 ,需要用到函数的 __annotations__ 属性,但在我的代码中没有对这部分进行解析。
当然,这些轮子,Python 自带库 inspect
inspect.getfullargspec(f) 可以获取到以上所有的参数信息:
1 2 3 4 5 6 >>>  def  func (a, b=1 , *args, c=2 , **kwargs ):...      pass ... >>>  inspect.getfullargspec(func)FullArgSpec(args=['a' , 'b' ], varargs='args' , varkw='kwargs' , defaults=(1 ,),             kwonlyargs=['c' ], kwonlydefaults={'c' : 2 }, annotations={}) 
而 inspect.signature 更加强大,它除了解析函数外,还支持“模拟调用”函数,对调用方式进行合法性验证,并展示调用之后函数中每个参数的值:
1 2 3 4 5 6 7 8 9 >>>  s = inspect.signature(func)     >>>  s<Signature (a, b=1 , *args, c=2 , **kwargs)> >>>  b = s.bind(0 )>>>  b<BoundArguments (a=0 )> >>>  b.apply_defaults()>>>  b<BoundArguments (a=0 , b=1 , args=(), c=2 , kwargs={})> 
实现了这些功能,我们就可以对给定函数进行参数解析并修改了。
这都算哪门子基础知识嘛… 
先来写一个简单的、只支持序列参数,且不支持参数默认值的偏函数实现。
1 2 3 4 5 6 7 8 9 10 def  func (a, b, c ):    return  sum ([a, b, c]) if  __name__ == '__main__' :    print (func(1 , 2 , 3 ,), func, signature(func))          partial = nb_partial(3 , 4 )(func)          print (partial(5 ), partial, signature(partial, follow_wrapped=False ))      
注:sigature(follow_wrapped=False) 会改变 signature 的默认行为:只查看函数本身的定义,而不会通过 __wrapped__ 链去查找原函数。
搭个架子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 from  functools import  wrapsdef  nb_partial (func, *partial_args ):    pre_arg_str = ', ' .join(map (repr , partial_args))          code = func.__code__     if  len (partial_args) > code.co_argcount:         raise  TypeError("Too many positional arguments" )     args = code.co_varnames[:code.co_argcount]     post_args = args[len (partial_args):]     post_args_str = ', ' .join(post_args)          def  wrapped (c ):                      return  func(3 , 4 , c)     return  wraps(func)(wrapped) 
现在我们有了需要的变量值和变量名信息。可是,我们还需要根据这些变量信息动态 定义 wrapped 函数,这个就有点麻烦了…
Python 毕竟是“万物皆对象”的动态语言。
compiletypes.FunctionType,我们就可以使用代码对象,并向其中注入一些信息(如 globals),即可生成一个可用的函数。
1 2 3 4 5 6 7 8 9 10 11 12 func_def = """  def f(a):     return a + 1 """ module_code = compile (func_def, '<>' , 'exec' ) function_code = next (code for  code in  module_code.co_consts                      if  isinstance (code, types.CodeType)) func = types.FunctionType(function_code, {}) print (func, func(2 ))
有了这个,我们就可以从一段字符串中动态生成一个 Python 函数了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 def  nb_partial (*partial_args ):    """      生成一个偏函数,且新函数的参数列表中会去除     """     def  decorator (func ):         pre_arg_str = ', ' .join(map (repr , partial_args))         code = func.__code__         args = code.co_varnames[:code.co_argcount]         post_args = args[len (partial_args):]         post_args_str = ', ' .join(post_args)         func_def = (f"def wrapped({post_args_str} ):\n"                      f"    return func({pre_arg_str} , {post_args_str} )" )         module_code = compile (func_def, '<>' , 'exec' )         function_code = next (code for  code in  module_code.co_consts                              if  isinstance (code, types.CodeType))         wrapped = types.FunctionType(function_code,                                      {'func' : func})                  return  wraps(func)(wrapped)     return  decorator 
这里有一个小缺陷没有解决掉:这个函数中的 func 是一个全局变量,而不是闭包中的自由变量。func 本身就不是自由变量啊…脑残了 Orz,继续继续。
所以,通过动态定义函数,我们就可以实现还原参数列表的的偏函数功能。大体步骤如下:
提取原函数参数 
处理新函数参数 
构造函数定义字符串 
动态定义新函数 
 
完整实现见下面 。
我们的原函数目前只支持了最简单的函数声明方式,而无法支持关键词参数,参数默认值等。
结合上面写的 FuncParser,我们还可以对动态函数的定义做进一步改进。下面见 。
当然,这个实现还有一些可以改进之处:
偏函数定义时可以支持关键字参数,这需要进行更多的判断,比如定义时传进去的关键字参数,有可能在原函数中是序列参数; 
有很多对象并没有很好地实现 __repr__ 魔术方法,导致 repr(arg) 后生成的字符串在函数定义中并不能做到完整还原。所以我们最好通过 globals 将它直接传进新的函数中,而不是使用字符串来进行参数传递。 
 
突然犯懒,就不再做进一步的实现了。
FuncParser1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 from  collections import  namedtupleFuncArgs = namedtuple('FuncArgs' , ('args' , 'defaults' )) class  FuncParser :    @classmethod     def  parse_vars (cls, func ):         code = func.__code__         argc = code.co_argcount         kwargc = code.co_kwonlyargcount         varnames = code.co_varnames         res = {             "arg" : varnames[:argc],             "kwarg" : varnames[argc:argc + kwargc],             "*args" : None ,             "**kwargs" : None          }         flag = code.co_flags         if  flag & 0x04 :                          res['*args' ] = varnames[argc + kwargc]         if  flag & 0x08 :                          res['**kwargs' ] = varnames[argc + kwargc + bool (flag & 0x04 )]                  defaults = func.__defaults__ or  ()         defaults_map = dict (zip (reversed (res['arg' ]), reversed (defaults)))         defaults_map.update(func.__kwdefaults__ or  {})         return  FuncArgs(res, defaults_map)     @classmethod     def  build_param_str (cls, func_args ):         func_vars, defaults_map = func_args         params = []                  for  arg in  func_vars['arg' ]:             if  arg in  defaults_map:                 default_val = defaults_map[arg]                 params.append(f'{arg} ={default_val} ' )             else :                 params.append(arg)                  if  func_vars['*args' ]:             params.append(f"*{func_vars['*args' ]} " )         else :             if  func_vars['kwarg' ]:                 params.append('*' )                  for  kwarg in  func_vars['kwarg' ]:             if  kwarg in  defaults_map:                 default_val = defaults_map[kwarg]                 params.append(f'{kwarg} ={default_val} ' )             else :                                  params.append(f'{kwarg} ={kwarg} ' )                  if  func_vars['**kwargs' ]:             params.append(f"**{func_vars['**kwargs' ]} " )         param = ', ' .join(params)         param_str = '({})' .format (param)         return  param_str     @classmethod     def  analyse_func_param (cls, func ):              func_args = cls.parse_vars(func)         name = func.__name__         return  cls.build_param_str(func_args) def  test_func (a, *args, b=1 , **kwargs ):    pass  if  __name__ == '__main__' :    import  inspect     import  types     funcs =  [f for  f in  locals ().values()               if  isinstance (f, types.FunctionType)]     for  f in  funcs:         fstr = FuncParser.analyse_func_param(f)         assert  fstr == str (inspect.signature(f))          
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 import  typesfrom  functools import  wrapsfrom  inspect import  signaturedef  nb_partial (*partial_args ):    """      生成一个偏函数,且新函数的参数列表中会去除     """     def  decorator (func ):         pre_arg_str = ', ' .join(map (repr , partial_args))         code = func.__code__         if  len (partial_args) > code.co_argcount:             raise  TypeError("Too many positional arguments" )         args = code.co_varnames[:code.co_argcount]         post_args = args[len (partial_args):]         post_args_str = ', ' .join(post_args)         func_def = (f"def wrapped({post_args_str} ):\n"                      f"    return func({pre_arg_str} , {post_args_str} )" )         module_code = compile (func_def, '<>' , 'exec' )         function_code = next (code for  code in  module_code.co_consts                              if  isinstance (code, types.CodeType))         wrapped = types.FunctionType(function_code,                                      {'func' : func})         return  wraps(func)(wrapped)     return  decorator def  func (a, b, c ):    return  sum ([a, b, c]) if  __name__ == '__main__' :    print (func(1 , 2 , 3 ,), func, signature(func))          partial = nb_partial(3 , 4 )(func)          print (partial(5 ), partial, signature(partial, follow_wrapped=False ))      
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 import  typesfrom  functools import  wrapsfrom  inspect import  signaturefrom  funcparser import  FuncParser, FuncArgsdef  nb_partial (*partial_args ):    """      生成一个偏函数,且新函数的参数列表中会去除     """     def  decorator (func ):         code = func.__code__         if  len (partial_args) > code.co_argcount:             raise  TypeError("Too many positional arguments" )         func_vars, defaults_map = FuncParser.parse_vars(func)                  def_vars = func_vars.copy()         def_vars["arg" ] = def_vars["arg" ][len (partial_args):]         def_str = FuncParser.build_param_str(FuncArgs(def_vars, defaults_map))                  call_vars = func_vars.copy()         call_vars["arg" ] = list (call_vars["arg" ])         call_vars["arg" ][:len (partial_args)] = list (map (repr , partial_args))         call_str = FuncParser.build_param_str(FuncArgs(call_vars, {}))         call_str = call_str.replace('*, ' , '' )         func_def = (f"def wrapped{def_str} :\n"                      f"    return func{call_str} " )         module_code = compile (func_def, '<>' , 'exec' )         function_code = next (code for  code in  module_code.co_consts                              if  isinstance (code, types.CodeType))                  argdefs = []         for  arg in  reversed (func_vars["arg" ]):             if  arg not  in  defaults_map:                 break              argdefs.append(defaults_map[arg])         argdefs.reverse()         wrapped = types.FunctionType(function_code,                                      {'func' : func},                                      argdefs=tuple (argdefs))                  wrapped.__kwdefaults__ = func.__kwdefaults__         return  wraps(func)(wrapped)     return  decorator def  func (a, b=2 , c=3 , *, d=5  ):    print (locals ())     return  sum ([a, b, c, d]) if  __name__ == '__main__' :    print (func(1 , 2 , 3 , d=6 ), func, signature(func))                    partial = nb_partial(3 )(func)     print (partial(6 , d=7 ), partial, signature(partial, follow_wrapped=False ))