I'm analyzing the AST generated by Python code for "fun and profit", and I would like something more graphical than ast.dump to actually see the generated AST.
Fantastic, it works and it's really simple:
    import ast
    from collections import defaultdict

    class AstGraphGenerator(object):
        def __init__(self):
            # Maps a parent node type to the types of its children.
            self.graph = defaultdict(list)

        def __str__(self):
            return str(self.graph)

        def visit(self, node):
            """Visit a node."""
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(node)

        def generic_visit(self, node):
            """Called if no explicit visitor function exists for a node."""
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            self.visit(item)
                elif isinstance(value, ast.AST):
                    self.graph[type(node)].append(type(value))
                    self.visit(value)
So it's the same as a normal NodeVisitor, but I have a defaultdict where I record, under each node's type, the type of each of its children. Then I pass this dictionary to pygraphviz.AGraph and I get my nice result.
The only problem is that the type doesn't say much, but on the other hand ast.dump() is way too verbose.
The best thing would be to get the actual source code for each node. Is that possible?
EDIT: now it's much better. I also pass the source code to the constructor and try to get the code line where possible; otherwise I just print the type.
    import ast
    from collections import defaultdict

    class AstGraphGenerator(object):
        def __init__(self, source):
            self.graph = defaultdict(list)
            self.source = source  # lines of the source code

        def __str__(self):
            return str(self.graph)

        def _getid(self, node):
            # Nodes without position info (e.g. Module) have no lineno.
            try:
                lineno = node.lineno - 1
                return "%s: %s" % (type(node), self.source[lineno].strip())
            except AttributeError:
                return type(node)

        def visit(self, node):
            """Visit a node."""
            method = 'visit_' + node.__class__.__name__
            visitor = getattr(self, method, self.generic_visit)
            return visitor(node)

        def generic_visit(self, node):
            """Called if no explicit visitor function exists for a node."""
            for _, value in ast.iter_fields(node):
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, ast.AST):
                            self.visit(item)
                elif isinstance(value, ast.AST):
                    node_source = self._getid(node)
                    value_source = self._getid(value)
                    self.graph[node_source].append(value_source)
                    # self.graph[type(node)].append(type(value))
                    self.visit(value)
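On the question of getting the actual source for each node rather than the whole line: on Python 3.8 and later, the standard library's ast.get_source_segment may help, since it returns the exact text a node spans, or None when the node carries no position info. A minimal sketch, assuming the full source string is available; the sample source and the node_label helper are made up for illustration.

```python
import ast

# Hypothetical sample source, used only for this illustration.
source = "total = price * quantity\n"
tree = ast.parse(source)


def node_label(node):
    """Return the node's source text, or its type name if it has no location."""
    segment = ast.get_source_segment(source, node)
    return segment if segment is not None else type(node).__name__


assign = tree.body[0]
# Direct children of the assignment: the target name and the expression.
labels = [node_label(child) for child in ast.iter_child_nodes(assign)]
```

Nodes such as Module or the Store context have no line information, so the helper falls back to the type name for them, much like _getid does above.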
If you look at ast.NodeVisitor, it's a fairly trivial class. You can either subclass it or just reimplement its walking strategy to whatever you need. For instance, keeping references to the parent when nodes are visited is very simple to implement this way: just add a visit method that also accepts the parent as an argument, and pass that from your own generic_visit.
P.S. By the way, it appears that NodeVisitor.generic_visit implements DFS, so all you have to do is add the parent node passing.
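A minimal sketch of that parent-passing idea, reimplementing the walk rather than subclassing ast.NodeVisitor. The class name ParentTrackingVisitor and the choice to record type names as edges are assumptions for illustration; ast.iter_child_nodes is used so that children held in list fields (such as a body) get an edge too.

```python
import ast
from collections import defaultdict


class ParentTrackingVisitor:
    """Walk an AST depth-first, handing each node its parent."""

    def __init__(self):
        # Maps a parent's type name to the type names of its children.
        self.graph = defaultdict(list)

    def visit(self, node, parent=None):
        """Record the parent -> node edge, then dispatch like NodeVisitor."""
        if parent is not None:
            self.graph[type(parent).__name__].append(type(node).__name__)
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node, parent)

    def generic_visit(self, node, parent=None):
        """Visit every direct child, passing the current node as its parent."""
        for child in ast.iter_child_nodes(node):
            self.visit(child, node)


visitor = ParentTrackingVisitor()
visitor.visit(ast.parse("x = 1"))
```

Any custom visit_* method defined on such a class would need to accept the extra parent argument as well.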