Python ast to dot graph

萌比男神i 2021-02-03 12:24

I'm analyzing the AST generated by Python code for "fun and profit", and I would like to have something more graphical than "ast.dump" to actually see the AST generated.

2 answers
  • 2021-02-03 12:29

    Fantastic, it works and it's really simple

    import ast
    from collections import defaultdict

    class AstGraphGenerator(object):

        def __init__(self):
            # Maps a node type to the types of its direct children.
            self.graph = defaultdict(list)

        def __str__(self):
            return str(self.graph)

        def visit(self, node):
            """Visit a node."""
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(node)

        def generic_visit(self, node):
            """Called if no explicit visitor function exists for a node."""
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            # Also record children held in list fields (e.g. a body).
                            self.graph[type(node)].append(type(item))
                            self.visit(item)

                elif isinstance(value, ast.AST):
                    self.graph[type(node)].append(type(value))
                    self.visit(value)
    

    It works like a normal NodeVisitor, except that it keeps a defaultdict mapping each node's type to the types of its children. I then pass this dictionary to pygraphviz.AGraph and get a nice result.
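
    For readers without pygraphviz installed, the same kind of dictionary can be written out as DOT text using only the standard library. This is a minimal sketch, not the code from the answer above: `build_type_graph` and `to_dot` are hypothetical helper names, and the edges are collected with `ast.walk` and `ast.iter_child_nodes` rather than with the visitor class.

```python
import ast
from collections import defaultdict


def build_type_graph(tree):
    """Map each node type name to the type names of its direct children."""
    graph = defaultdict(list)
    for node in ast.walk(tree):
        for child in ast.iter_child_nodes(node):
            graph[type(node).__name__].append(type(child).__name__)
    return graph


def to_dot(graph):
    """Render a {parent: [children]} mapping as DOT source text."""
    lines = ["digraph ast {"]
    for parent, children in graph.items():
        for child in sorted(set(children)):  # drop duplicate edges
            lines.append('    "%s" -> "%s";' % (parent, child))
    lines.append("}")
    return "\n".join(lines)


dot_text = to_dot(build_type_graph(ast.parse("x = a + b")))
print(dot_text)
```

    The resulting text can be saved to a file and rendered with the Graphviz command-line tool, for example `dot -Tpng ast.dot -o ast.png`.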

    The only problem is that the type doesn't say much, but on the other hand using ast.dump() is way too verbose.

    The best thing would be to get the actual source code for each node. Is that possible?
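
    On Python 3.8 and later, the standard library's `ast.get_source_segment` may cover this: it returns the source text of a node, or `None` when the node carries no location information. The sketch below assumes that Python version; `node_label` is a hypothetical helper name, not part of the `ast` module.

```python
import ast

source = "x = a + b\n"
tree = ast.parse(source)
assign = tree.body[0]


def node_label(source, node):
    """Return 'Type: source text' when the text is available, else just 'Type'."""
    segment = ast.get_source_segment(source, node)
    name = type(node).__name__
    return name if segment is None else "%s: %s" % (name, segment)


# Module has no line information, so only its type name is used.
labels = [node_label(source, n) for n in (tree, assign, assign.value)]
print(labels)
```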

    EDIT: It's much better now. I also pass the source code to the constructor and try to get the code line for each node if possible; otherwise I just print the type.

    import ast
    from collections import defaultdict

    class AstGraphGenerator(object):

        def __init__(self, source):
            self.graph = defaultdict(list)
            self.source = source  # lines of the source code

        def __str__(self):
            return str(self.graph)

        def _getid(self, node):
            """Label a node with its type and, if available, its source line."""
            name = type(node).__name__
            try:
                lineno = node.lineno - 1
                return "%s: %s" % (name, self.source[lineno].strip())

            except AttributeError:
                # Nodes such as Module, Load or Add carry no line number.
                return name

        def visit(self, node):
            """Visit a node."""
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(node)

        def generic_visit(self, node):
            """Called if no explicit visitor function exists for a node."""
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            # Also record children held in list fields (e.g. a body).
                            self.graph[self._getid(node)].append(self._getid(item))
                            self.visit(item)

                elif isinstance(value, ast.AST):
                    node_source = self._getid(node)
                    value_source = self._getid(value)
                    self.graph[node_source].append(value_source)
                    self.visit(value)
    
  • 2021-02-03 12:50

    If you look at ast.NodeVisitor, it's a fairly trivial class. You can either subclass it or just reimplement its walking strategy to suit your needs. For instance, keeping a reference to the parent as nodes are visited is very simple to implement this way: add a visit method that also accepts the parent as an argument, and pass the parent from your own generic_visit.

    P.S. By the way, it appears that NodeVisitor.generic_visit implements DFS, so all you have to do is add the parent node passing.
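
    The parent-passing idea could be sketched roughly as follows. `ParentTrackingVisitor` and its `edges` list are hypothetical names chosen for illustration; the class reimplements the walk instead of subclassing `ast.NodeVisitor`, so that `visit` can take the extra `parent` argument.

```python
import ast


class ParentTrackingVisitor:
    """Walk an AST depth-first, handing each node's parent to the visit method."""

    def __init__(self):
        self.edges = []  # (parent type name, child type name) pairs

    def visit(self, node, parent=None):
        # Same dispatch scheme as ast.NodeVisitor, plus the parent argument.
        method = "visit_" + type(node).__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node, parent)

    def generic_visit(self, node, parent=None):
        if parent is not None:
            self.edges.append((type(parent).__name__, type(node).__name__))
        for child in ast.iter_child_nodes(node):
            self.visit(child, node)  # the current node becomes the parent


visitor = ParentTrackingVisitor()
visitor.visit(ast.parse("x = 1"))
print(visitor.edges)
```

    A custom `visit_<NodeType>` method added to such a class would need to accept the same `(node, parent)` arguments and call `generic_visit` itself to keep descending.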
