预测GDP应用:Numpy 线性回归+Matplotlib 作图
需求
通过2000~2019年中美两国的GDP数据,预测后续几年GDP的发展趋势:
- 读取.csv文件,并将字符串调整为浮点型
- 进行二阶线性回归模拟
- 支持数据可视化
保命声明:用线性回归预测GDP发展并不合理,只是作为python学习参考。
如果想要了解更有意义的GDP对比可以参考b站翟老师的:https://b23.tv/6aYFVf
成品效果
原数据格式
.csv 文件(“testgdp.csv”),gdp数据每千位均被 “,” 隔开
需求拆解
1、csv文件读取
以测试文件"testgdp.csv"为例,目标将csv数据读取成适合进行线性回归的格式ndarray
方法一:pandas的 read_csv()
函数
import pandas as pd
import numpy as np
data = pd.read_csv("testgdp.csv")
df = pd.DataFrame(data)
print(df.head())
years = np.array(df.years) #可以转化为 ndarray
years
方法二: python自带的 open()
函数
import csv
import numpy as np
data_list = []
with open("testgdp.csv",encoding = 'utf-8') as csvfile:
csv_reader = csv.reader(csvfile)
for row in csv_reader:
data_list.append(row[0:3])#第3~7列为空数据,需要排除
data1 = np.array(data_list)
data2 = np.delete(data1,-1,axis=0)#删除最后一行空值行,axis=1时可删除列
data2
2、对“xxx,xxx,xxx”格式字符串转化为数字
split():用指定分隔符对 字符串 进行切片,变为 list
strr.split (str="", num=string.count(str))
- strr 为原字符串
- str 为分隔符号
- num – 分割次数。默认为 -1, 即分隔所有
def intt(list,exc_rate=1):#将"xxx,xxx,xxx,xxx,xxx"格式的str转化为 整型,exc_rate为汇率
list_new = []
for strr in list:
int_list = strr.split(',') # 分割str,转化为列表
lenth = len(int_list)
result = 0
for n in range(lenth):
ii = int(int_list[n])
result = result + ii*1000**(lenth-n-1)*exc_rate
list_new.append(result)
return list_new
list = ['11,061,552,790,044','14,342,902,842,915','234,322,342,111','123,212,231']
intt(list)
3、线性回归:np.polyfit()多项式拟合、np.polyval()多项式曲线求值
P = np.polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False)
x, y
:一般是array格式的数组,分别代表自变量和因变量deg
:阶数(需要整型),即需要进行几阶线性回归- 其他数据不太常用,可以不输入,即使用默认参数。如果需要了解可以参考:numpy.polyfit
输出参数 P
为拟合多项式
P ( 1 ) x n + P ( 2 ) x n − 1 + . . . + P ( n ) x + P ( n + 1 ) 的 系 数 组 合 P(1)x^n + P(2)x^{n-1} +...+ P(n)x + P(n+1) 的 系数组合 P(1)xn+P(2)xn−1+...+P(n)x+P(n+1)的系数组合
如 P
为[ 1, 2, 3]时,代表多项式线性回归的结果为
Y = x 2 + 2 x + 3 Y = x^2+2x+3 Y=x2+2x+3
可以用np.polyval()方法输出预测结果Y,即
Y = np.polyval(P, x)
4、模块输出可视化图表
要用到matplotlib.pyplot,这个模块内容非常非常多,现在根据需求选取几个易用的函数
官方文档:https://matplotlib.org/api/pyplot_api.html
功能一:绘制关系曲线
绘制一条x,y关系曲线,红色,宽度为2,标签为label
plt.plot(x, y, color="red”,linestyle="-", linewidth=2.0, label=‘label')
x, y
:与前面的x, y
相同,支持array格式的数组,分别代表自变量和因变量- 设置
label
标签有助于后续生成图例
import matplotlib.pyplot as plt
x=[1,2,3,5]
y=[2,3,5,9]
plt.plot(x, y,color="red",linestyle="-", linewidth=2.0,label='label1')
plt.show()
功能二:新增图例
plt.legend(loc=*'best'*,label=lable_list)
loc
=‘best’时图例自动‘安家’在一个坐标面内的数据图表最少的位置,可以设置为指定位置。
参考链接:https://zhuanlan.zhihu.com/p/111108841
功能三:箭头标注关键信息
对第三个坐标点用红色箭头标注,箭头离坐标相差0.05个单位。同时在(4,2)提醒’this is the annotate’.
plt.annotate('this is the annotate', xy=(x[2],y[2]), xycoords='data', xytext=(4,2),
arrowprops=dict(facecolor='red', shrink=0.05))
可以参考https://blog.csdn.net/wizardforcel/article/details/54782628
实例代码
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def intt(list,exc_rate=1):#将"xxx,xxx,xxx,xxx,xxx"格式的str转化为 整型,exc_rate为汇率
list_new = []
for strr in list:
int_list = strr.split(',') # 分割str,转化为列表
lenth = len(int_list)
result = 0
for n in range(lenth):
ii = int(int_list[n])
result = result + ii*1000**(lenth-n-1)*exc_rate
list_new.append(result)
return list_new
def pre(n):#n为预测时间(年)
data = pd.read_csv("testgdp.csv")
df = pd.DataFrame(data)
df = df.drop([19])#删除空行
years = np.array(df.years)
cn = intt(np.array(df.cn))
usa = intt(np.array(df.us))
model_cn = np.polyfit(years,cn,2)#阶线性回归cn
model_usa = np.polyfit(years,usa,2)#2阶维线性回归usa
overyear_list = []
overusa_list = []
overcn_list = []
for i in range(n):#预测n年后gdp数据表现
yy=2020+i
cn_gdp=np.polyval(model_cn,yy)
usa_gdp=np.polyval(model_usa,yy)
if cn_gdp>usa_gdp:#判断何时中国gdp超过美国,并记录下来
overyear_list.append(yy)
overusa_list.append(usa_gdp)
overcn_list.append(cn_gdp)
cn = np.append(cn,cn_gdp)
usa = np.append(usa,usa_gdp)
years=np.append(years,yy)
plt.plot(years, cn,color="red",linestyle="-", linewidth=2.0,label='CN')
plt.plot(overyear_list, overcn_list, color="red", linestyle="-", linewidth=4.0)#加粗超过美国的部分
plt.plot(years, usa,color="blue",
linestyle="-", linewidth=2.0,label='USA')
plt.plot(years[0:len(years)-len(overyear_list)+1],
usa[0:len(years)-len(overyear_list)+1],
color="blue", linestyle="-", linewidth=4.0)
plt.legend(loc='upper left')#图例,位置左上
plt.annotate(s=("%d:CN%.1ftrillion ,USA%.1ftrillion"%(overyear_list[0],overcn_list[0]/(10**12),overusa_list[0]/(10**12))),xy=(overyear_list[0],overcn_list[0]),
xytext=(overyear_list[0]+n/10,overcn_list[0]*0.6)
,arrowprops=dict(facecolor='red', shrink=0.05))#arrowprops箭头
plt.show()
pre(40)
后续进阶
- 增加爬虫功能(合法的那种!)
- 优化可视化图表(增加图表样式,增加图像交互能力,如调用Pyecharts)
- 增加更多纬度数据,采用逻辑回归
- 增加与数据库对接的功能
来源:oschina
链接:https://my.oschina.net/u/4403186/blog/4560835