How can I get the calling expression of a function in Python?

南旧 2021-01-06 09:58

For educational purposes, I would like to print the complete calling expression of the current function, not necessarily from within an exception handler.

1 Answer
  • 2021-01-06 10:22

    For the curious, here is my final working code for such an unproductive purpose. Fun is everywhere! (almost)

    I am not marking this as the accepted answer right away, in the hope that someone can enlighten us with a better option in the near future...

    It extracts the entire calling expression as expected. The code assumes the calling expression is a bare function call, without any magic, special tricks, or nested/recursive calls. Those special cases would obviously have made the detection part less trivial, and they are off-topic anyway.
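    The "bare function call" assumption matters because the detection in the code only matches an `ast.Name` callee. A call like `obj.print_callexp(...)` has an `ast.Attribute` callee instead, and `getattr(obj, "f")(...)` has another `ast.Call`. Here is a small sketch with a hypothetical `callee_name()` helper (not part of the answer) that shows the three shapes:

```python
import ast

def callee_name(call):
    """Return the simple name a Call node invokes, or None if there is none."""
    func = call.func
    if isinstance(func, ast.Name):
        return func.id          # bare call: f(...)
    if isinstance(func, ast.Attribute):
        return func.attr        # attribute call: obj.f(...)
    return None                 # anything else, e.g. getattr(obj, "f")(...)

for src in ("f(1)", "obj.f(1)", "getattr(obj, 'f')(1)"):
    call = ast.parse(src, mode="eval").body   # Expression.body is the Call node
    print(src, type(call.func).__name__, callee_name(call))
```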

    In detail, I used the current function's name to locate the AST node of the calling expression, and the caller's line number provided by inspect to bound the search.
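    The same search can be written more compactly with `ast.walk()`. This is a sketch working on a source string and assuming a plain `name(...)` call; `find_last_call()` is a hypothetical helper. `ast.walk()` visits nodes breadth-first rather than in source order, so the "latest" call is picked by comparing `(lineno, col_offset)`:

```python
import ast

def find_last_call(source, func_name, last_lineno):
    """Return the last bare call `func_name(...)` that starts at or before last_lineno."""
    best = None
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id == func_name
                and node.lineno <= last_lineno):
            # ast.walk is breadth-first, so compare source positions explicitly
            if best is None or (node.lineno, node.col_offset) > (best.lineno, best.col_offset):
                best = node
    return best

SOURCE = """\
show(1)
x = 2
show(
    x, 3)
show(4)
"""

node = find_last_call(SOURCE, "show", 4)
print(node.lineno)  # 3
```

    Comparing the call's *starting* line against the limit works whether the frame reports the first or the last line of a multi-line call.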

    I couldn't use inspect.getsource() to isolate just the caller's block, which would have been more efficient. I found a case where it returned incomplete source code: when the caller's code sat directly in the main module's top-level scope. I don't know whether that is supposed to be a bug or a feature, though...
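    This looks like expected behaviour rather than a bug. For a code object, inspect.getsource() extracts a single block starting at `co_firstlineno`, which suits functions, classes and lambdas. A module-level code object starts at line 1, so only the first block comes back. The sketch below reproduces this with a throwaway file; the file contents are purely illustrative:

```python
import inspect
import os
import tempfile

# Write a tiny module-level script to a temporary .py file.
text = "x = 1\ny = 2\nprint(x + y)\n"
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
    tmp.write(text)
    path = tmp.name

try:
    code = compile(text, path, "exec")   # a module-level code object
    partial = inspect.getsource(code)    # block extraction from co_firstlineno
    with open(path) as f:
        full = f.read()                  # the answer's workaround: read the whole file
    print(repr(partial))
    print(repr(full))
finally:
    os.remove(path)
```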

    Once we have the source code, we just have to feed it to ast.parse() to get the root AST node, then walk the tree to find the latest call to the current function, and voilà!
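    Once the Call node is found, Python 3.8+ can also return its exact text with `ast.get_source_segment()`. That function uses the node's column offsets as well as its start and end lines. Here is a sketch on an inline source string; the `show` name is illustrative:

```python
import ast

SOURCE = """\
def show(*args, **kwargs):
    pass

a_var = "but"
show(
    a_var, "why?!",
    345, (1, 2, 3), hello="world")
"""

tree = ast.parse(SOURCE)
# keep the last bare call to show(...) in the module
calls = [n for n in ast.walk(tree)
         if isinstance(n, ast.Call)
         and isinstance(n.func, ast.Name) and n.func.id == "show"]
call = max(calls, key=lambda n: (n.lineno, n.col_offset))

print(call.lineno, call.end_lineno)          # 5 7
print(ast.get_source_segment(SOURCE, call))  # the full multi-line call text
```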

    #!/usr/bin/env python3
    
    import ast
    import inspect
    import tokenize
    
    def print_callexp(*args, **kwargs):
    
        def _find_caller_node(root_node, func_name, last_lineno):
            # init search state
            found_node = None
            lineno = 0
    
            def _luke_astwalker(parent):
                nonlocal found_node
                nonlocal lineno
                for child in ast.iter_child_nodes(parent):
                    # nodes without a position (e.g. ast.Load) keep the last
                    # line seen; stop once we have passed the caller's line
                    if hasattr(child, "lineno"):
                        lineno = child.lineno
                    if lineno > last_lineno:
                        break
    
                    # is it our candidate?
                    if (isinstance(child, ast.Name)
                            and isinstance(parent, ast.Call)
                            and child.id == func_name):
                        # keep the latest candidate; the outer walk goes on in
                        # case another call follows. Breaking here only skips
                        # the remaining children (the arguments) of this Call
                        found_node = parent
                        break
    
                    # walk through children nodes, if any
                    _luke_astwalker(child)
    
            # dig recursively to find caller's node
            _luke_astwalker(root_node)
            return found_node
    
        # get some info from 'inspect', then drop the frame reference
        # (the inspect docs recommend this to avoid reference cycles)
        frame = inspect.currentframe()
        this_func_name = frame.f_code.co_name
        caller_file = frame.f_back.f_code.co_filename
        caller_lineno = frame.f_back.f_lineno
        del frame
    
        # read the whole source file of the caller's module, because
        # inspect.getsource(code_object) returned incomplete source for
        # module-level code; tokenize.open() honours the PEP 263 encoding
        with tokenize.open(caller_file) as f:
            source = f.read()
    
        # parse the whole module, so AST line numbers match the file's
        # (no need for ast.increment_lineno())
        ast_root = ast.parse(source, caller_file)
    
        # find caller's ast node
        caller_node = _find_caller_node(ast_root, this_func_name, caller_lineno)
    
        # if caller's node has been found, print the call's source lines
        if caller_node:
            print("Hoooray! Found it!")
            # on Python 3.8+ f_lineno is the first line of a multi-line call
            # (older versions report the last one), so prefer the node's own
            # end_lineno (3.8+) and fall back to f_lineno otherwise
            end_lineno = getattr(caller_node, "end_lineno", None) or caller_lineno
            lines = source.splitlines()[caller_node.lineno - 1:end_lineno]
            for lineno, ln in enumerate(lines, start=caller_node.lineno):
                print("  {:04d} {}".format(lineno, ln))
    
    def main():
        a_var = "but"
        print_callexp(
            a_var, "why?!",
            345, (1, 2, 3), hello="world")
    
    if __name__ == "__main__":
        main()
    

    You should get something like this (the line numbers depend on where the call sits in the file):

    Hoooray! Found it!
      0079     print_callexp(
      0080         a_var, "why?!",
      0081         345, (1, 2, 3), hello="world")
    

    It still feels a bit messy, but on the other hand it is quite an unusual goal, at least in Python it seems. At first glance, I was hoping to find a way to directly access an already-loaded AST node, served by inspect through a frame object or similar, instead of having to build a new AST manually. (That doesn't seem possible: CPython discards the AST once the source has been compiled to bytecode.)

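    Note that I have absolutely no idea whether this code is CPython-specific. It shouldn't be, at least from what I've read in the docs. One caveat from the inspect docs, though: inspect.currentframe() relies on Python stack frame support in the interpreter and returns None on implementations that lack it.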
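    Below is a defensive sketch of the frame lookup; `where_called()` is a hypothetical helper name. It copes with `inspect.currentframe()` returning None and deletes the frame reference afterwards, as the inspect docs recommend:

```python
import inspect

def where_called():
    """Return (filename, lineno) of the line that called this function, or None."""
    frame = inspect.currentframe()
    if frame is None:              # interpreter without Python stack frame support
        return None
    try:
        caller = frame.f_back
        if caller is None:
            return None
        return caller.f_code.co_filename, caller.f_lineno
    finally:
        # break the frame -> local variable -> frame reference cycle
        del frame

def demo():
    return where_called()

print(demo())
```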

    Also, I wonder why there's no official pretty-print function in the ast module (or in a companion module). ast.dump() could probably do the job with an additional indent argument to format the output and make debugging the AST easier.
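    Python 3.9 added this: `ast.dump()` takes an `indent` argument that pretty-prints the tree. Here is a quick sketch on a one-line call similar to the one above (the exact field layout varies slightly between Python versions):

```python
import ast

tree = ast.parse('print_callexp(a_var, hello="world")')
print(ast.dump(tree, indent=4))   # one field per line, nested by 4 spaces
```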

    As a side note, I found this neat little function to help work with the AST.
