Python ast to dot graph

萌比男神i 2021-02-03 12:24

I'm analyzing the AST that Python generates from code, "for fun and profit", and I would like something more graphical than ast.dump so I can actually see the generated AST.

2 Answers
  •  隐瞒了意图╮
    2021-02-03 12:29

    Fantastic, it works, and it's really simple:

    import ast
    from collections import defaultdict

    class AstGraphGenerator(object):

        def __init__(self):
            # parent node type -> list of child node types
            self.graph = defaultdict(list)

        def __str__(self):
            return str(self.graph)

        def visit(self, node):
            """Visit a node."""
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(node)

        def generic_visit(self, node):
            """Called if no explicit visitor function exists for a node."""
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            # record list children too (e.g. Module.body), not only single-node fields
                            self.graph[type(node)].append(type(item))
                            self.visit(item)

                elif isinstance(value, ast.AST):
                    self.graph[type(node)].append(type(value))
                    self.visit(value)
    

    So it works like a normal NodeVisitor, except that it keeps a defaultdict in which each node's type is mapped to the types of its children. Then I pass this dictionary to pygraphviz.AGraph and get a nice graph.
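    If pygraphviz isn't available, the same dict-of-lists can be written out as DOT text by hand. This is a minimal stdlib-only sketch, not the answer's code: `type_graph` and `to_dot` are hypothetical helper names, and `type_graph` uses `ast.walk`/`ast.iter_child_nodes` as a compact stand-in for the class above. It emits a `strict digraph` so repeated parent/child pairs (e.g. `BinOp -> Constant` twice) collapse into a single edge.

```python
import ast
from collections import defaultdict


def type_graph(tree):
    """Parent type name -> child type names, like the answer's type-keyed dict."""
    graph = defaultdict(list)
    for parent in ast.walk(tree):
        for child in ast.iter_child_nodes(parent):
            graph[type(parent).__name__].append(type(child).__name__)
    return graph


def to_dot(graph):
    """Render a dict-of-lists adjacency map as Graphviz DOT text."""

    def quote(label):
        # escape backslashes and double quotes for a quoted DOT identifier
        return '"%s"' % str(label).replace("\\", "\\\\").replace('"', '\\"')

    lines = ["strict digraph ast {"]
    for parent, children in graph.items():
        for child in children:
            lines.append("    %s -> %s;" % (quote(parent), quote(child)))
    lines.append("}")
    return "\n".join(lines)


print(to_dot(type_graph(ast.parse("x = 1 + 2"))))
```

    Escaping labels matters once the keys are source lines rather than type names, since those can contain quotes.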

    The only problem is that the type alone doesn't say much, while ast.dump() is far too verbose.

    The best thing would be to get the actual source code for each node. Is that possible?
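    On Python 3.8+, the standard library can answer this directly: `ast.get_source_segment(source, node)` returns the exact source text a node spans, or `None` when the node has no position information. A minimal sketch; `node_label` is a hypothetical helper name:

```python
import ast


def node_label(source, node):
    """Label a node with its type name and, where available, its exact source text."""
    segment = ast.get_source_segment(source, node)  # Python 3.8+
    if segment is None:
        # nodes without position info (Module, Load, Store, operators, ...)
        return type(node).__name__
    return "%s: %s" % (type(node).__name__, segment)


source = "total = price * qty"
tree = ast.parse(source)
for node in ast.walk(tree):
    print(node_label(source, node))
```

    Unlike taking the whole line via `lineno`, this labels the right-hand side as `BinOp: price * qty` instead of repeating the full assignment line for every sub-expression.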

    EDIT: now it's much better. I also pass the source code to the constructor and, where possible, label each node with its source line; otherwise I just print out the type.

    import ast
    from collections import defaultdict

    class AstGraphGenerator(object):

        def __init__(self, source):
            self.graph = defaultdict(list)
            self.source = source  # lines of the source code, e.g. text.splitlines()

        def __str__(self):
            return str(self.graph)

        def _getid(self, node):
            """Label a node with its type name plus its source line, if it has one."""
            try:
                lineno = node.lineno - 1
                return "%s: %s" % (type(node).__name__, self.source[lineno].strip())

            except AttributeError:
                # nodes such as Module or Load carry no line number
                return type(node).__name__

        def visit(self, node):
            """Visit a node."""
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(node)

        def generic_visit(self, node):
            """Called if no explicit visitor function exists for a node."""
            node_source = self._getid(node)
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            # record list children too (e.g. Module.body)
                            self.graph[node_source].append(self._getid(item))
                            self.visit(item)

                elif isinstance(value, ast.AST):
                    self.graph[node_source].append(self._getid(value))
                    self.visit(value)
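    Keying the graph by type or by source line merges distinct nodes that share a label: nodes of the same type on the same line (for example both `Name` nodes in `x = x + 1`) collapse into one vertex. A minimal alternative sketch, assuming Python 3.8+ for `ast.get_source_segment`: give every node its own numbered DOT vertex and put the readable text in the label. `ast_to_dot` and the `n0`, `n1`, ... ids are hypothetical choices, and this writes DOT text directly rather than going through pygraphviz.

```python
import ast
import itertools


def ast_to_dot(source):
    """Return Graphviz DOT text with one vertex per AST node of `source`."""
    tree = ast.parse(source)
    counter = itertools.count()
    lines = ["digraph ast {", "    node [shape=box];"]

    def label(node):
        segment = ast.get_source_segment(source, node)  # None if no position info
        name = type(node).__name__
        text = name if segment is None else "%s: %s" % (name, segment)
        # escape characters that are special inside a quoted DOT string
        return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")

    def add(node):
        # number the node, emit its vertex, then recurse and emit parent -> child edges
        node_id = "n%d" % next(counter)
        lines.append('    %s [label="%s"];' % (node_id, label(node)))
        for child in ast.iter_child_nodes(node):
            lines.append("    %s -> %s;" % (node_id, add(child)))
        return node_id

    add(tree)
    lines.append("}")
    return "\n".join(lines)


print(ast_to_dot("x = x + 1"))
```

    Save the output to a file and render it with the Graphviz CLI, e.g. `dot -Tpng ast.dot -o ast.png`. Both `x` names now get their own `Name: x` box.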
    
