Can you patch *just* a nested function with closure, or must the whole outer function be repeated?

人盡茶涼 submitted on 2019-11-27 00:53:57

Yes, you can replace an inner function, even if it is using a closure. You'll have to jump through a few hoops though. Please take into account:

  1. You need to create the replacement function as a nested function too, to ensure that Python creates the same closure. If the original function has a closure over the names foo and bar, you need to define your replacement as a nested function with the same names closed over. More importantly, you need to use those names in the same order; closures are referenced by index.

  2. Monkey patching is always fragile and can break when the implementation of the patched code changes. This is no exception. Retest your monkey patch whenever you change versions of the patched library.

To understand how this will work, I'll first explain how Python handles nested functions. Python uses code objects to produce function objects as needed. Each code object has an associated constants sequence, and the code objects for nested functions are stored in that sequence:

>>> def outerfunction(*args):
...     def innerfunction(val):
...         return someformat.format(val)
...     someformat = 'Foo: {}'
...     for arg in args:
...         yield innerfunction(arg)
... 
>>> outerfunction.__code__
<code object outerfunction at 0x105b27ab0, file "<stdin>", line 1>
>>> outerfunction.__code__.co_consts
(None, <code object innerfunction at 0x10f136ed0, file "<stdin>", line 2>, 'outerfunction.<locals>.innerfunction', 'Foo: {}')

The co_consts sequence is an immutable object, a tuple, so we cannot just swap out the inner code object. I'll show later on how we'll produce a new function object with just that code object replaced.
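As a minimal sketch of how those nested code objects can be located, the snippet below scans co_consts for entries of type types.CodeType. The helper name nested_code_objects is made up for this illustration and is not part of any library.

```python
import types


def outerfunction(*args):
    def innerfunction(val):
        return someformat.format(val)
    someformat = 'Foo: {}'
    for arg in args:
        yield innerfunction(arg)


def nested_code_objects(func):
    """Map the name of each nested code object in func to the code object.

    Hypothetical helper for illustration: it only looks one level deep,
    at the constants of func's own code object.
    """
    return {
        const.co_name: const
        for const in func.__code__.co_consts
        if isinstance(const, types.CodeType)
    }


nested = nested_code_objects(outerfunction)
print(sorted(nested))  # ['innerfunction']
```

The dictionary gives you the original inner code object by name, which is the same lookup the patching function later in this answer performs with next().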

Next, we need to cover closures. At compile time, Python determines that a) someformat is not a local name in innerfunction and that b) it is closing over the same name in outerfunction. Python then not only generates the bytecode to produce the correct name lookups, it also annotates the code objects for both the nested and the outer functions to record that someformat is to be closed over:

>>> outerfunction.__code__.co_cellvars
('someformat',)
>>> outerfunction.__code__.co_consts[1].co_freevars
('someformat',)

You want to make sure that the replacement inner code object only ever lists those same names as free variables, and does so in the same order.
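A small sketch of that check, assuming you already have both code objects in hand. The names freevars_compatible, make_good and make_bad are invented for this example; the comparison itself mirrors the prefix test used by the patching function later in this answer.

```python
def freevars_compatible(orig_code, new_code):
    """Return True if new_code's free variables are a prefix of orig_code's."""
    new_free = new_code.co_freevars
    return orig_code.co_freevars[:len(new_free)] == new_free


def outer():
    fmt = 'Foo: {}'
    def inner(val):
        return fmt.format(val)
    return inner


def make_good():
    fmt = None  # same closed-over name as the original
    def inner(val):
        return fmt.format(val + 2)
    return inner


def make_bad():
    other = None  # a different closed-over name
    def inner(val):
        return other.format(val)
    return inner


orig_code = outer().__code__
print(freevars_compatible(orig_code, make_good().__code__))  # True
print(freevars_compatible(orig_code, make_bad().__code__))   # False
```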

Closures are created at run-time; the byte-code to produce them is part of the outer function:

>>> import dis
>>> dis.dis(outerfunction)
  2           0 LOAD_CLOSURE             0 (someformat)
              2 BUILD_TUPLE              1
              4 LOAD_CONST               1 (<code object innerfunction at 0x10f136ed0, file "<stdin>", line 2>)
              6 LOAD_CONST               2 ('outerfunction.<locals>.innerfunction')
              8 MAKE_FUNCTION            8
             10 STORE_FAST               1 (innerfunction)

# ... rest of disassembly omitted ...

The LOAD_CLOSURE bytecode there creates a closure for the someformat variable; Python creates as many closures as used by the function in the order they are first used in the inner function. This is an important fact to remember for later. The function itself looks up these closures by position:

>>> dis.dis(outerfunction.__code__.co_consts[1])
  3           0 LOAD_DEREF               0 (someformat)
              2 LOAD_METHOD              0 (format)
              4 LOAD_FAST                0 (val)
              6 CALL_METHOD              1
              8 RETURN_VALUE

The LOAD_DEREF opcode picked the closure at position 0 here to gain access to the someformat closure.

In theory this also means you can use entirely different names for the closures in your inner function, but for debugging purposes it makes much more sense to stick to the same names. It also makes verifying that the replacement function will slot in properly easier, as you can just compare the co_freevars tuples if you use the same names.
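To see the positional pairing from the Python side, here is a short sketch that lines up co_freevars with the cells in __closure__. It assumes a variant of outerfunction that returns the inner function (the original is a generator and does not), purely so the inner function object can be inspected.

```python
def outerfunction():
    someformat = 'Foo: {}'
    def innerfunction(val):
        return someformat.format(val)
    # returned here only so the closure can be inspected from outside
    return innerfunction


inner = outerfunction()

# co_freevars (names) and __closure__ (cells) correspond index by index
for name, cell in zip(inner.__code__.co_freevars, inner.__closure__):
    print(name, '->', repr(cell.cell_contents))
```

The names are only labels; the cell at index 0 is what the inner bytecode actually dereferences.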

Now for the swapping trick. Functions are objects like any other in Python, instances of a specific type. The type isn't exposed normally, but the type() call still returns it. The same applies to code objects, and both types even have documentation:

>>> type(outerfunction)
<class 'function'>
>>> print(type(outerfunction).__doc__)
Create a function object.

  code
    a code object
  globals
    the globals dictionary
  name
    a string that overrides the name from the code object
  argdefs
    a tuple that specifies the default argument values
  closure
    a tuple that supplies the bindings for free variables
>>> type(outerfunction.__code__)
<class 'code'>
>>> print(type(outerfunction.__code__).__doc__)
code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring,
      constants, names, varnames, filename, name, firstlineno,
      lnotab[, freevars[, cellvars]])

Create a code object.  Not for the faint of heart.

We'll use these type objects to produce a new code object with updated constants, and then a new function object with updated code object:

def replace_inner_function(outer, new_inner):
    """Replace a nested function code object used by outer with new_inner

    The replacement new_inner must use the same name and must at most use the
    same closures as the original.

    """
    if hasattr(new_inner, '__code__'):
        # support both functions and code objects
        new_inner = new_inner.__code__

    # find original code object so we can validate the closures match
    ocode = outer.__code__
    function, code = type(outer), type(ocode)
    iname = new_inner.co_name
    orig_inner = next(
        const for const in ocode.co_consts
        if isinstance(const, code) and const.co_name == iname)
    # you can ignore later closures, but since they are matched by position
    # the new sequence must match the start of the old.
    assert (orig_inner.co_freevars[:len(new_inner.co_freevars)] ==
            new_inner.co_freevars), 'New closures must match originals'
    # replace the code object for the inner function
    new_consts = tuple(
        new_inner if const is orig_inner else const
        for const in outer.__code__.co_consts)

    # create a new function object with the new constants
    return function(
        code(ocode.co_argcount, ocode.co_kwonlyargcount, ocode.co_nlocals,
             ocode.co_stacksize, ocode.co_flags, ocode.co_code, new_consts,
             ocode.co_names, ocode.co_varnames, ocode.co_filename,
             ocode.co_name, ocode.co_firstlineno, ocode.co_lnotab,
             ocode.co_freevars,
             ocode.co_cellvars),
        outer.__globals__, outer.__name__, outer.__defaults__,
        outer.__closure__)

The above function validates that the new inner function (which can be passed in as either a code object or as a function) will indeed use the same closures as the original. It then creates new code and function objects to match the old outer function object, but with the nested function (located by name) replaced with your monkey patch.
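The positional code(...) call above matches the constructor signature printed earlier, which applies to Python 3.7 and older. On Python 3.8 and newer the constructor takes additional arguments, but code objects have a replace() method that copies the object with selected fields changed. The following is a sketch of the same idea under that assumption; it is a variation for newer interpreters, not the original author's code, and copying __kwdefaults__ is an extra step added here.

```python
import types


def replace_inner_function(outer, new_inner):
    """Return a copy of outer whose nested function code is new_inner.

    Sketch for Python 3.8+, relying on code.replace().
    """
    # accept either a function or a bare code object
    new_code = getattr(new_inner, '__code__', new_inner)
    ocode = outer.__code__
    orig_inner = next(
        const for const in ocode.co_consts
        if isinstance(const, types.CodeType)
        and const.co_name == new_code.co_name)
    # closures are matched by position, so the new names must be a prefix
    if orig_inner.co_freevars[:len(new_code.co_freevars)] != new_code.co_freevars:
        raise ValueError('New closures must match originals')
    new_consts = tuple(
        new_code if const is orig_inner else const
        for const in ocode.co_consts)
    patched = types.FunctionType(
        ocode.replace(co_consts=new_consts),
        outer.__globals__, outer.__name__, outer.__defaults__,
        outer.__closure__)
    # keyword-only defaults are not a constructor argument; copy them over
    patched.__kwdefaults__ = outer.__kwdefaults__
    return patched


def outerfunction(*args):
    def innerfunction(val):
        return someformat.format(val)
    someformat = 'Foo: {}'
    for arg in args:
        yield innerfunction(arg)


def create_inner():
    someformat = None  # the actual value doesn't matter
    def innerfunction(val):
        return someformat.format(val + 2)
    return innerfunction


new_outer = replace_inner_function(outerfunction, create_inner())
print(list(new_outer(6, 7, 8)))  # ['Foo: 8', 'Foo: 9', 'Foo: 10']
```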

To demonstrate that the above all works, let's replace innerfunction with one that increments each formatted value by 2:

>>> def create_inner():
...     someformat = None  # the actual value doesn't matter
...     def innerfunction(val):
...         return someformat.format(val + 2)
...     return innerfunction
... 
>>> new_inner = create_inner()

The new inner function is created as a nested function too; this is important as it ensures that Python will use the correct bytecode to look up the someformat closure. I used a return statement to extract the function object, but you could also look at create_inner.__code__.co_consts to grab the code object.

Now we can patch the original outer function, swapping out just the inner function:

>>> new_outer = replace_inner_function(outerfunction, new_inner)
>>> list(outerfunction(6, 7, 8))
['Foo: 6', 'Foo: 7', 'Foo: 8']
>>> list(new_outer(6, 7, 8))
['Foo: 8', 'Foo: 9', 'Foo: 10']

The original function echoes the values it was given, while the patched function returns them incremented by 2.

You can even create new replacement inner functions that use fewer closures:

>>> def demo_outer():
...     closure1 = 'foo'
...     closure2 = 'bar'
...     def demo_inner():
...         print(closure1, closure2)
...     demo_inner()
...
>>> def create_demo_inner():
...     closure1 = None
...     def demo_inner():
...         print(closure1)
...
>>> replace_inner_function(demo_outer, create_demo_inner.__code__.co_consts[1])()
foo

So, to complete the picture:

  1. Create your monkey-patch inner function as a nested function with the same closures
  2. Use replace_inner_function() to produce a new outer function
  3. Monkey patch the original outer function to use the new outer function produced in step 2.

Martijn's answer is good, but there is one drawback that would be nice to remove:

"You want to make sure that the replacement inner code object only ever lists those same names as free variables, and does so in the same order."

This isn't a particularly difficult constraint for the normal case, but it isn't pleasant to depend on undefined behaviour such as name ordering. When things go wrong, the errors can be really nasty, possibly even hard crashes.

My approach has its own drawbacks, but in most cases I believe the drawback above would motivate using it. As far as I can tell, it should also be more portable.

The basic approach is to load the source with inspect.getsource, change it and then evaluate it. The change is made at the AST level, which keeps the edit structured rather than relying on text manipulation.

Here is the code:

import ast
import inspect
import sys

class AstReplaceInner(ast.NodeTransformer):
    def __init__(self, replacement):
        self.replacement = replacement

    def visit_FunctionDef(self, node):
        if node.name == self.replacement.name:
            # Prevent the replacement AST from messing
            # with the outer AST's line numbers
            return ast.copy_location(self.replacement, node)

        self.generic_visit(node)
        return node

def ast_replace_inner(outer, inner, name=None):
    if name is None:
        name = inner.__name__

    outer_ast = ast.parse(inspect.getsource(outer))
    inner_ast = ast.parse(inspect.getsource(inner))

    # Fix the source lines for the outer AST
    outer_ast = ast.increment_lineno(outer_ast, inspect.getsourcelines(outer)[1] - 1)

    # outer_ast should be a module so it can be evaluated;
    # inner_ast should be a function so we strip the module node
    inner_ast = inner_ast.body[0]

    # Replace the function
    inner_ast.name = name
    modified_ast = AstReplaceInner(inner_ast).visit(outer_ast)

    # Evaluate the modified AST in the original module's scope
    compiled = compile(modified_ast, inspect.getsourcefile(outer), "exec")
    outer_globals = outer.__globals__ if sys.version_info >= (3,) else outer.func_globals
    exec_scope = {}

    exec(compiled, outer_globals, exec_scope)
    return exec_scope.popitem()[1]

A quick walkthrough. AstReplaceInner is an ast.NodeTransformer, a class that lets you modify an AST by mapping certain nodes to other nodes. Here it takes a replacement node and substitutes it for any ast.FunctionDef node whose name matches.
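To try the transformer in isolation, here is a sketch that works on source strings instead of inspect.getsource, so it runs anywhere. The class name ReplaceFunctionDef, the string constants and the '<patched>' filename are all made up for this illustration.

```python
import ast

OUTER_SRC = '''
def outerfunction():
    numerator = 10.0

    def innerfunction(denominator):
        return denominator / numerator

    return innerfunction
'''

INNER_SRC = '''
def innerfunction(denominator):
    return numerator / denominator
'''


class ReplaceFunctionDef(ast.NodeTransformer):
    """Swap any FunctionDef with a matching name for the replacement node."""

    def __init__(self, replacement):
        self.replacement = replacement

    def visit_FunctionDef(self, node):
        if node.name == self.replacement.name:
            return ast.copy_location(self.replacement, node)
        # keep walking so nested definitions are reached
        self.generic_visit(node)
        return node


outer_ast = ast.parse(OUTER_SRC)
inner_def = ast.parse(INNER_SRC).body[0]  # strip the Module wrapper
patched_ast = ReplaceFunctionDef(inner_def).visit(outer_ast)
ast.fix_missing_locations(patched_ast)

scope = {}
exec(compile(patched_ast, '<patched>', 'exec'), scope)
print(scope['outerfunction']()(4.0))  # 2.5
```

Because the whole outer function is recompiled, the compiler works out the closure over numerator again, which is why no name-ordering constraint applies here.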

ast_replace_inner is the function we really care about, which takes two functions and optionally a name. The name is used to allow replacing the inner function with another function of a different name.

The ASTs are parsed:

    outer_ast = ast.parse(inspect.getsource(outer))
    inner_ast = ast.parse(inspect.getsource(inner))

The transformation is made:

    modified_ast = AstReplaceInner(inner_ast).visit(outer_ast)

The code is evaluated and the function extracted:

    exec(compiled, outer_globals, exec_scope)
    return exec_scope.popitem()[1]

Here's an example of use. Assume this old code is in buggy.py:

def outerfunction():
    numerator = 10.0

    def innerfunction(denominator):
        return denominator / numerator

    return innerfunction

You want to replace innerfunction with

def innerfunction(denominator):
    return numerator / denominator

You write:

import buggy

def innerfunction(denominator):
    return numerator / denominator

buggy.outerfunction = ast_replace_inner(buggy.outerfunction, innerfunction)

Alternatively, you could write:

def divide(denominator):
    return numerator / denominator

buggy.outerfunction = ast_replace_inner(buggy.outerfunction, divide, "innerfunction")

The main disadvantage of this technique is that it requires inspect.getsource to work on both the target and the replacement. This will fail if the target is "built-in" (written in C) or was compiled to bytecode before distribution. Note that if it's built-in, Martijn's technique will not work either.
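One way to fail early is to probe for source availability before attempting the patch. This is only a sketch; has_source is a hypothetical helper name, and it relies on inspect.getsource raising TypeError for built-in objects and OSError when the source cannot be retrieved.

```python
import inspect


def has_source(obj):
    """Return True if inspect.getsource() can retrieve source for obj."""
    try:
        inspect.getsource(obj)
    except (OSError, TypeError):
        # TypeError: built-in object; OSError: source not available
        return False
    return True


print(has_source(len))  # False, len is implemented in C
```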

Another major disadvantage is that the line numbers from the inner function are completely screwy. This is not a big problem if the inner function is small, but if you have a large inner function this is worth thinking about.

Other disadvantages arise when the inner function is not defined with a def statement. For example, you could not patch

def outerfunction():
    numerator = 10.0

    innerfunction = lambda denominator: denominator / numerator

    return innerfunction

the same way; a different AST transformation would be needed.
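As a rough idea of what such a transformation might look like, the sketch below replaces the lambda assigned to a given name. It works on source strings so that it is self-contained; ReplaceLambdaAssign and the string constants are invented for this example, and it only handles a plain `name = lambda ...` assignment.

```python
import ast

OUTER_SRC = '''
def outerfunction():
    numerator = 10.0

    innerfunction = lambda denominator: denominator / numerator

    return innerfunction
'''


class ReplaceLambdaAssign(ast.NodeTransformer):
    """Swap the lambda assigned to `name` for a replacement expression."""

    def __init__(self, name, replacement_expr):
        self.name = name
        self.replacement_expr = replacement_expr

    def visit_Assign(self, node):
        assigns_name = any(
            isinstance(target, ast.Name) and target.id == self.name
            for target in node.targets)
        if assigns_name and isinstance(node.value, ast.Lambda):
            node.value = ast.copy_location(self.replacement_expr, node.value)
        return node


# mode='eval' parses a single expression; .body is the Lambda node itself
new_lambda = ast.parse(
    'lambda denominator: numerator / denominator', mode='eval').body
tree = ReplaceLambdaAssign('innerfunction', new_lambda).visit(
    ast.parse(OUTER_SRC))
ast.fix_missing_locations(tree)

scope = {}
exec(compile(tree, '<patched>', 'exec'), scope)
print(scope['outerfunction']()(4.0))  # 2.5
```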

You should decide which tradeoff makes the most sense for your particular circumstance.

I needed this, but in a class and for Python 2/3, so I extended @MartijnPieters's solution a bit:

import types

import six

def replace_inner_function(outer, new_inner, class_class=None):
    """Replace a nested function code object used by outer with new_inner

    The replacement new_inner must use the same name and must at most use the
    same closures as the original.

    """
    if hasattr(new_inner, '__code__'):
        # support both functions and code objects
        new_inner = new_inner.__code__

    # find original code object so we can validate the closures match
    ocode = outer.__code__

    iname = new_inner.co_name
    orig_inner = next(
        const for const in ocode.co_consts
        if isinstance(const, types.CodeType) and const.co_name == iname)
    # you can ignore later closures, but since they are matched by position
    # the new sequence must match the start of the old.
    assert (orig_inner.co_freevars[:len(new_inner.co_freevars)] ==
            new_inner.co_freevars), 'New closures must match originals'
    # replace the code object for the inner function
    new_consts = tuple(
        new_inner if const is orig_inner else const
        for const in outer.__code__.co_consts)

    # create a new code object with the new constants
    if hasattr(ocode, 'replace'):
        # Python 3.8+: the CodeType constructor gained extra arguments, so
        # let the code object copy itself with just co_consts swapped out
        new_code = ocode.replace(co_consts=new_consts)
    elif six.PY3:
        new_code = types.CodeType(ocode.co_argcount, ocode.co_kwonlyargcount, ocode.co_nlocals, ocode.co_stacksize,
             ocode.co_flags, ocode.co_code, new_consts, ocode.co_names,
             ocode.co_varnames, ocode.co_filename, ocode.co_name,
             ocode.co_firstlineno, ocode.co_lnotab, ocode.co_freevars,
             ocode.co_cellvars)
    else:
        new_code = types.CodeType(ocode.co_argcount, ocode.co_nlocals, ocode.co_stacksize,
             ocode.co_flags, ocode.co_code, new_consts, ocode.co_names,
             ocode.co_varnames, ocode.co_filename, ocode.co_name,
             ocode.co_firstlineno, ocode.co_lnotab, ocode.co_freevars,
             ocode.co_cellvars)

    # create a new function object that uses the new code object
    new_function = types.FunctionType(new_code, outer.__globals__,
                                      outer.__name__, outer.__defaults__,
                                      outer.__closure__)

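    if hasattr(outer, '__self__'):
        if six.PY3:
            # Python 3: MethodType takes (function, instance) only, and
            # unbound methods are plain functions that never get here
            return types.MethodType(new_function, outer.__self__)
        if outer.__self__ is None:
            # Python 2 unbound method: rebind to its class
            return types.MethodType(new_function, outer.__self__, outer.im_class)
        return types.MethodType(new_function, outer.__self__, outer.__self__.__class__)

    return new_function

This should now work for functions, bound methods (including classmethods) and, on Python 2, unbound methods. On Python 3 an unbound method is just a plain function and types.MethodType accepts only a function and an instance, so the class_class argument is not needed there; it is kept in the signature only for compatibility. Thanks @MartijnPieters for doing most of the work! I never would have figured this out ;)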
