预测GDP应用:Numpy 线性回归+Matplotlib 作图

房东的猫 提交于 2020-10-25 03:32:21

预测GDP应用:Numpy 线性回归+Matplotlib 作图

需求

通过2000~2019年中美两国的GDP数据,预测后续几年GDP的发展趋势:

  • 读取.csv文件,并将字符串调整为浮点型
  • 进行二阶线性回归模拟
  • 支持数据可视化

保命声明:用线性回归预测GDP发展并不合理,只是作为python学习参考。

如果想要了解更有意义的GDP对比可以参考b站翟老师的:https://b23.tv/6aYFVf

成品效果

在这里插入图片描述

原数据格式

.csv 文件(“testgdp.csv”),gdp数据每千位均被 “,” 隔开
在这里插入图片描述

需求拆解

1、csv文件读取

以测试文件"testgdp.csv"为例,目标将csv数据读取成适合进行线性回归的格式ndarray

方法一pandasread_csv()函数

import pandas as pd
import numpy as np
data = pd.read_csv("testgdp.csv")
df = pd.DataFrame(data)
print(df.head())

在这里插入图片描述

years = np.array(df.years) #可以转化为 ndarray
years

可以转化为 ndarray

方法二: python自带的 open()函数

import csv
import numpy as np
data_list = []
with open("testgdp.csv",encoding = 'utf-8') as csvfile:
    csv_reader = csv.reader(csvfile)
    for row in csv_reader:
        data_list.append(row[0:3])#第3~7列为空数据,需要排除
    data1 = np.array(data_list)
    data2 = np.delete(data1,-1,axis=0)#删除最后一行空值行,axis=1时可删除列
data2

在这里插入图片描述

2、对“xxx,xxx,xxx”格式字符串转化为数字

split():用指定分隔符对 字符串 进行切片,变为 list

strr.split (str="", num=string.count(str))

  • strr 为原字符串
  • str 为分隔符号
  • num – 分割次数。默认为 -1, 即分隔所有
def intt(list,exc_rate=1):#将"xxx,xxx,xxx,xxx,xxx"格式的str转化为 整型,exc_rate为汇率
    list_new = []
    for strr in list:
        int_list = strr.split(',') # 分割str,转化为列表
        lenth = len(int_list)
        result = 0
        for n in range(lenth):
            ii = int(int_list[n])
            result = result + ii*1000**(lenth-n-1)*exc_rate
        list_new.append(result)
    return list_new

list = ['11,061,552,790,044','14,342,902,842,915','234,322,342,111','123,212,231']
intt(list)

在这里插入图片描述

3、线性回归:np.polyfit()多项式拟合、np.polyval()多项式曲线求值

P = np.polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False)

  • x, y:一般是array格式的数组,分别代表自变量和因变量
  • deg:阶数(需要整型),即需要进行几阶线性回归
  • 其他数据不太常用,可以不输入,即使用默认参数。如果需要了解可以参考:numpy.polyfit

输出参数 P为拟合多项式
P ( 1 ) x n + P ( 2 ) x n − 1 + . . . + P ( n ) x + P ( n + 1 ) 的 系 数 组 合 P(1)x^n + P(2)x^{n-1} +...+ P(n)x + P(n+1) 的 系数组合 P(1)xn+P(2)xn1+...+P(n)x+P(n+1)
P 为[ 1, 2, 3]时,代表多项式线性回归的结果为
Y = x 2 + 2 x + 3 Y = x^2+2x+3 Y=x2+2x+3
可以用np.polyval()方法输出预测结果Y,即



Y = np.polyval(P, x)

4、模块输出可视化图表

要用到matplotlib.pyplot,这个模块内容非常非常多,现在根据需求选取几个易用的函数

官方文档:https://matplotlib.org/api/pyplot_api.html

功能一:绘制关系曲线

绘制一条x,y关系曲线,红色,宽度为2,标签为label

plt.plot(x, y, color="red”,linestyle="-", linewidth=2.0, label=‘label')
  • x, y:与前面的x, y相同,支持array格式的数组,分别代表自变量和因变量
  • 设置label标签有助于后续生成图例
import matplotlib.pyplot as plt 
x=[1,2,3,5] 
y=[2,3,5,9] 
plt.plot(x, y,color="red",linestyle="-", linewidth=2.0,label='label1') 
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JbJa664r-1599808354505)(/Users/zhangning/Library/Application Support/typora-user-images/image-20200911145804341.png)]

功能二:新增图例

plt.legend(loc=*'best'*,label=lable_list)

loc=‘best’时图例自动‘安家’在一个坐标面内的数据图表最少的位置,可以设置为指定位置。

参考链接:https://zhuanlan.zhihu.com/p/111108841

在这里插入图片描述

功能三:箭头标注关键信息

对第三个坐标点用红色箭头标注,箭头离坐标相差0.05个单位。同时在(4,2)提醒’this is the annotate’.

plt.annotate('this is the annotate', xy=(x[2],y[2]), xycoords='data', xytext=(4,2),
arrowprops=dict(facecolor='red', shrink=0.05))

可以参考https://blog.csdn.net/wizardforcel/article/details/54782628
在这里插入图片描述

实例代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def intt(list,exc_rate=1):#将"xxx,xxx,xxx,xxx,xxx"格式的str转化为 整型,exc_rate为汇率
    list_new = []
    for strr in list:
        int_list = strr.split(',') # 分割str,转化为列表
        lenth = len(int_list)
        result = 0
        for n in range(lenth):
            ii = int(int_list[n])
            result = result + ii*1000**(lenth-n-1)*exc_rate
        list_new.append(result)
    return list_new

  def pre(n):#n为预测时间(年)
      data = pd.read_csv("testgdp.csv")
      df = pd.DataFrame(data)
      df = df.drop([19])#删除空行
      years = np.array(df.years)
      cn = intt(np.array(df.cn))
      usa = intt(np.array(df.us))
      model_cn = np.polyfit(years,cn,2)#阶线性回归cn
      model_usa = np.polyfit(years,usa,2)#2阶维线性回归usa
      overyear_list = []
      overusa_list = []
      overcn_list = []
    for i in range(n):#预测n年后gdp数据表现
        yy=2020+i
        cn_gdp=np.polyval(model_cn,yy)
        usa_gdp=np.polyval(model_usa,yy)
        if cn_gdp>usa_gdp:#判断何时中国gdp超过美国,并记录下来
            overyear_list.append(yy)
            overusa_list.append(usa_gdp)
            overcn_list.append(cn_gdp)
        cn = np.append(cn,cn_gdp)
        usa = np.append(usa,usa_gdp)
        years=np.append(years,yy)
    plt.plot(years, cn,color="red",linestyle="-", linewidth=2.0,label='CN')
    plt.plot(overyear_list, overcn_list, color="red", linestyle="-", linewidth=4.0)#加粗超过美国的部分
    plt.plot(years, usa,color="blue",
             linestyle="-", linewidth=2.0,label='USA')
    plt.plot(years[0:len(years)-len(overyear_list)+1],
             usa[0:len(years)-len(overyear_list)+1],
             color="blue", linestyle="-", linewidth=4.0)
    plt.legend(loc='upper left')#图例,位置左上
    plt.annotate(s=("%d:CN%.1ftrillion ,USA%.1ftrillion"%(overyear_list[0],overcn_list[0]/(10**12),overusa_list[0]/(10**12))),xy=(overyear_list[0],overcn_list[0]),
                 xytext=(overyear_list[0]+n/10,overcn_list[0]*0.6)
                 ,arrowprops=dict(facecolor='red', shrink=0.05))#arrowprops箭头
    plt.show()
pre(40)

后续进阶

  • 增加爬虫功能(合法的那种!)
  • 优化可视化图表(增加图表样式,增加图像交互能力,如调用Pyecharts
  • 增加更多纬度数据,采用逻辑回归
  • 增加与数据库对接的功能
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!