Python随笔(四)抽象语法树AST

六眼飞鱼酱① 提交于 2019-11-29 07:24:50

什么是抽象语法树嘞?

在计算机科学中,抽象语法和抽象语法树其实是源代码的抽象语法结构的树状表现形式 我们可以用一个在线的AST编辑器来观察AST的构建
Python语言的执行过程就是通过将Python字节码转化为抽象语法树来进行下一步的分析等其他操作,所以将Python转化为抽象语法树更利于程序的分析
一般来说,我们早期的学习当中固然会用到一种叫做表达式树的东西,我们用Python来实现一下表达式树

class StackEmptyException(Exception): pass


class StackFullException(Exception): pass


class Node:
    def __init__(self, val=None, nxt=None):
        self.value = val
        self.next = nxt

    def __str__(self):
        return str(self.value)


class Stack:

    def __init__(self, max=0):
        self._top = None
        self._max = 0
        self.max = max

    @property
    def max(self):
        return self._max

    @max.setter
    def max(self, m):
        m = int(m)
        if m < self.length:
            raise Exception('Resize stack failed, please pop some elements first.')
        self._max = m
        if self._max < 0:
            self._max = 0

    def init(self, iterable=()):
        if not iterable:
            return
        self._top = Node(iterable[0])
        for i in iterable[1:]:
            node = self._top
            self._top = Node(i)
            self._top.next = node

    def show(self):
        def _traversal(self):
            node = self._top
            while node and node.next:
                yield node
                node = node.next
            yield node

        print('\n'.join(map(lambda x: '|{:^7}|'.format(str(x)), _traversal(self))) + '\n ' + 7 * '-')

    @property
    def length(self):
        if self._top is None:
            return 0
        node = self._top
        i = 1
        while node.next:
            node = node.next
            i += 1
        return i

    @property
    def is_empty(self):
        return self._top is None

    @property
    def is_full(self):
        return bool(self._max and self.length == self._max)

    def push(self, item):
        if self.is_full:
            raise StackFullException('Error: trying to push element into a full stack!')
        if not self._top:
            self._top = Node(item)
            return
        node = self._top
        self._top = Node(item)
        self._top.next = node

    def pop(self):
        if self.is_empty:
            raise StackEmptyException('Error: trying to pop element from an empty stack!')
        node = self._top
        self._top = self._top.next
        return node.value

    def top(self):
        return self._top.value if self._top else self._top

    def clear(self):
        while self._top:
            self.pop()


def test(stack):
    print('\nShow stack:')
    stack.show()

    print('\nInit linked list:')
    stack.init([1, 2, 3, 4, 5])
    stack.show()

    print('\nPush element to stack:')
    stack.push(6)
    stack.push(7)
    stack.push('like')
    stack.show()

    print('\nCheck top element:')
    print(stack.top())

    print('\nPop element from stack:')
    e = stack.pop()
    print('Element %s popped,' % e)
    stack.show()

    print('\nSet stack max size:')
    try:
        stack.max = 1
    except Exception as e:
        print(e)

    print('\nSet stack max size:')
    stack.max = 7
    print(stack.max)

    print('\nPush full stack:')
    try:
        stack.push(7)
    except StackFullException as e:
        print(e)

    print('\nClear stack:')
    stack.clear()
    stack.show()

    print('\nStack is empty:')
    print(stack.is_empty)

    print('\nPop empty stack:')
    try:
        stack.pop()
    except StackEmptyException as e:
        print(e)


class TreeNode:
    def __init__(self, val=None, lef=None, rgt=None):
        self.value = val
        self.left = lef
        self.right = rgt

    def __str__(self):
        return str(self.value)


class BinaryTree:
    def __init__(self, root=None):
        self._root = root

    def __str__(self):
        return '\n'.join(map(lambda x: x[1]*4*' '+str(x[0]), self.pre_traversal()))

    def pre_traversal(self, root=None):
        if not root:
            root = self._root
        x = []
        depth = -1

        def _traversal(node):
            nonlocal depth
            depth += 1
            x.append((node, depth))
            if node and node.left is not None:
                _traversal(node.left)
            if node and node.right is not None:
                _traversal(node.right)
            depth -= 1
            return x
        return _traversal(root)

    def in_traversal(self, root=None):
        if not root:
            root = self._root
        x = []
        depth = -1

        def _traversal(node):
            nonlocal depth
            depth += 1
            if node and node.left is not None:
                _traversal(node.left)
            x.append((node, depth))
            if node and node.right is not None:
                _traversal(node.right)
            depth -= 1
            return x
        return _traversal(root)

    def post_traversal(self, root=None):
        if not root:
            root = self._root
        x = []
        depth = -1

        def _traversal(node):
            nonlocal depth
            depth += 1
            if node and node.left is not None:
                _traversal(node.left)
            if node and node.right is not None:
                _traversal(node.right)
            x.append((node, depth))
            depth -= 1
            return x
        return _traversal(root)

    @property
    def max_depth(self):
        return sorted(self.pre_traversal(), key=lambda x: x[1])[-1][1]

    def show(self, tl=None):
        if not tl:
            tl = self.pre_traversal()
        print('\n'.join(map(lambda x: x[1]*4*' '+str(x[0]), tl)))

    def make_empty(self):
        self.__init__()

    def insert(self, item):
        if self._root is None:
            self._root = TreeNode(item)
            return

        def _insert(item, node):
            if not node:
                return TreeNode(item)
            if node.left is None:
                node.left = _insert(item, node.left)
            elif node.right is None:
                node.right = _insert(item, node.right)
            else:
                if len(self.pre_traversal(node.left)) <= len(self.pre_traversal(node.right)):
                    node.left = _insert(item, node.left)
                else:
                    node.right = _insert(item, node.right)
            return node
        self._root = _insert(item, self._root)


class ExpressionTree(BinaryTree):
    SIGN = {'+': 1, '-': 1, '*': 2, '/': 2, '(': 3}

    def gene_tree_by_postfix(self, expr):
        s =Stack()
        for i in expr:
            if i in self.SIGN.keys():
                right = s.pop()
                left = s.pop()
                node = TreeNode(i, left, right)
                s.push(node)
            else:
                s.push(TreeNode(i))
        self._root = s.pop()

class ExpressionTree(BinaryTree):
    SIGN = {'+': 1, '-': 1, '*': 2, '/': 2, '(': 3}

    def gene_tree_by_postfix(self, expr):
        s = Stack()
        for i in expr:
            if i in self.SIGN.keys():
                right = s.pop()
                left = s.pop()
                node = TreeNode(i, left, right)
                s.push(node)
            else:
                s.push(TreeNode(i))
        self._root = s.pop()


def test_expression_tree(ep):
    t = ExpressionTree()
    t.gene_tree_by_postfix(ep)
    print('\n------先序遍历-------')
    print(t)
    print('\n------后序遍历------')
    t.show(t.post_traversal())
    print('\n-------中序遍历-------')
    t.show(t.in_traversal())

if __name__ == '__main__':
    ep = 'a b + c d e + * *'
    test_expression_tree(ep.split(' '))

输出:
回到AST
AST主要作用有三步:

1. 解析(PARSE):将代码字符串解析成抽象语法树。
2. 转换(TRANSFORM):对抽象语法树进行转换操作。
3. 生成(GENERATE): 根据变换后的抽象语法树再生成代码字符串。  

Python官方对于CPython解释器对python源码的处理过程如下:

1. Parse source code into a parse tree (Parser/pgen.c)
2. Transform parse tree into an Abstract Syntax Tree (Python/ast.c)
3. Transform AST into a Control Flow Graph (Python/compile.c)
4. Emit bytecode based on the Control Flow Graph (Python/compile.c)

但是只知道上面还不够我们去理解,因为在Python中,以控制台为例,我们的输入都是些字符串例如a=2b=[1,2,3,4,5]之类我们要如何让计算机去理解并且执行这些东西呢?
这就是解释器的解释过程,负责把关键字,变量,空格,特殊字符进行处理处理的过程大概有下面两个步骤

1. 将整个代码字符串分割成 语法单元数组。
2. 在分词结果的基础之上分析 语法单元之间的关系。  

一个抽象语法树的基本构成

type:描述该语句的类型 --变量声明语句
kind:变量声明的关键字 -- var
declaration: 声明的内容数组,里面的每一项也是一个对象
    type: 描述该语句的类型 
    id: 描述变量名称的对象
        type:定义
        name: 是变量的名字
    init: 初始化变量值得对象
        type: 类型
        value: 值 "is tree" 不带引号
        row: "\"is tree"\" 带引号

一般来说我们在可以Python的pythonrun里面找到

PyObject *type;

定义了语法树的类型
一般来说,研究抽象语法树有哪些用途呢?

在一种语言的IDE中,语法的检查、风格的检查、格式化、高亮、错误提示,代码自动补全等等
通过搭建一个Python的语法树去理解表达式是如何被解析的,我们来看一个(3+2-5*0)/3的例子:

#首先定义四则运算
Num = lambda env, n: n
Var = lambda env, x: env[x]
Add = lambda env, a, b:_eval(env, a) + _eval(env, b)
Mul = lambda env, a, b:_eval(env, a) * _eval(env, b)
Sub = lambda env, a, b:_eval(env, a) - _eval(env, b)
Div = lambda env, a, b:_eval(env, a) / _eval(env, b)
#定义表达式计算
 _eval = lambda env, expr:expr[0](env, *expr[1:])
#定义环境中的自变量
env = {'i':5, 'j':2, 'k':3}
#定义语法树结构(我寻思这玩意怎么那么像Clojure呢。。。。。)
tree=(Div,(Sub,(Add,(Var,'k'),(Var,'j')),(Mul,(Var,'i'),(Num,0))),(Var,'k'))
print(_eval(env, tree))

输出:

承接前一篇虚拟机的运行机制,我们来看看Python的AST解析过程
首先来看Python虚拟机的循环执行框架
位于pythonrun.c文件中

PyObject *
PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
{
  ......
  // 获取当前活动线程的线程状态对象(PyThreadState)
  PyThreadState *tstate = PyThreadState_GET();
  // 设置线程状态对象中的frame
  tstate->frame = f;
  co = f->f_code;
  names = co->co_names;
  consts = co->co_consts;
 
  why = WHY_NOT;
  ......
  for (;;) {
    fast_next_opcode:
        f->f_lasti = INSTR_OFFSET();
        // 获取字节码指令
        opcode = NEXTOP();
        oparg = 0; 
        // 如果指令有参数,获取参数
        if (HAS_ARG(opcode))
            oparg = NEXTARG();
    dispatch_opcode:
      ......
  }
}

现在来调试一下PYVM


在我们对PYVM进行调试的过程中可以看到Py把stdin的字符串一个个“吃掉”了吃的过程是为了把字符串转换和解释为字节码,通过字节码构建抽象语法树,字节码的遍历是通过几个宏来实现:

#define INSTR_OFFSET()  ((int)(next_instr - first_instr))
#define NEXTOP()        (*next_instr++)
#define NEXTARG()       (next_instr += 2, (next_instr[-1]<<8) + next_instr[-2])
#define PEEKARG()       ((next_instr[2]<<8) + next_instr[1])
#define JUMPTO(x)       (next_instr = first_instr + (x))
#define JUMPBY(x)       (next_instr += (x))

在程序内部通过PyTokenizer_Get来获取输入字符串中是否存在关键字,构建好语法树以后通过PyRun_InteractiveOneObjectEx执行。
Python中AST的节点定义
pythoncore/Parser/node.c

PyNode_New(int type)
{
    node *n = (node *) PyObject_MALLOC(1 * sizeof(node));
    if (n == NULL)
        return NULL;
    n->n_type = type;
    n->n_str = NULL;
    n->n_lineno = 0;
    n->n_nchildren = 0;
    n->n_child = NULL;
    return n;
}

下面给出Python自带的AST例子,去观察构建出来的树

import ast
Monster ="""
class Monster:
    def __init__(self):
        self.level=0
        self.hp=1000
        self.boom=[x for x in range(10)]
    def eat(self,frut):
        self.hp+=1
    def howl(self):
        print("Ao uuuuuuuuuuuuuu")
monster=Monster()
monster.howl()
"""
if __name__=="__main__":
    # cm = compile(Monster, '<string>', 'exec')
    # exec (cm)
    r_node = ast.parse(Monster)
    print(ast.dump(r_node))

通过compile我们可以编译Python字符串执行字串的内容

同时,我们也可以用Python自带的AST库解析我们的字符串为语法树

参考文档:
[Abstract Syntax Trees]https://docs.python.org/3/library/ast.html
[轮子哥博客]http://www.cppblog.com/vczh/archive/2008/06/15/53373.html
[表达式树]http://www.cnblogs.com/stacklike/p/8284691.html
[AST库的使用]https://www.cnblogs.com/yssjun/p/10069199.html

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!