安装
pip install catboost
数据集
分类:MNIST(60000 条数据,784 个特征),已上传 CSDN
代码
import random import numpy as np import pandas as pd import matplotlib.pyplot as plt from catboost import CatBoostClassifier from sklearn.model_selection import train_test_split
# Load the MNIST training table; in a notebook the bare head() call
# previews the first rows.
train = pd.read_csv('./input/mnist/train.csv')
train.head()

# Labels come from the 'label' column; every column after the first one
# is used as a feature.
y = train['label']
X = train.iloc[:, 1:]

# Split into training and test sets (20% held out, fixed seed).
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)
def plot_digits(instances, images_per_row=10, size=28):
    """Draw a set of flattened square images as one tiled grid.

    :param instances: subset of the dataset, one flattened image per entry
    :type instances: numpy.ndarray
    :param images_per_row: maximum number of images shown in each row
    :param size: side length, in pixels, of each square image
    """
    n_images = len(instances)
    if n_images == 0:
        # Nothing to draw. Without this guard images_per_row becomes 0 and
        # the row-count computation below divides by zero.
        return

    images_per_row = min(n_images, images_per_row)
    images = [np.reshape(instance, (size, size)) for instance in instances]
    n_rows = (n_images - 1) // images_per_row + 1

    # Pad the last row with a single blank strip so that every row has the
    # same width and the rows can be stacked vertically.
    n_empty = n_rows * images_per_row - n_images
    images.append(np.zeros((size, size * n_empty)))

    row_images = [
        np.concatenate(
            images[row * images_per_row:(row + 1) * images_per_row], axis=1
        )
        for row in range(n_rows)
    ]
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap='gray_r')
    plt.axis("off")


# Show the first 100 training samples, ten per row.
plt.figure()
plot_digits(X_train[:100].values, images_per_row=10)
plt.show()
# Define the model (constructor defaults only).
clf = CatBoostClassifier()

# Train on the training split; the value returned by fit() is kept as
# `model` and used by the later steps.
model = clf.fit(X_train, y_train)
0: learn: 2.2139620 total: 975ms remaining: 16m 13s 1: learn: 2.1344069 total: 1.95s remaining: 16m 15s 2: learn: 2.0559619 total: 2.92s remaining: 16m 10s 3: learn: 1.9850790 total: 3.89s remaining: 16m 7s ...... 996: learn: 0.1231917 total: 16m 35s remaining: 3s 997: learn: 0.1231500 total: 16m 36s remaining: 2s 998: learn: 0.1231068 total: 16m 37s remaining: 999ms 999: learn: 0.1230654 total: 16m 38s remaining: 0us
# Evaluate: print the trained model's score on the held-out test split.
print(
    'accuracy:',
    model.score(X_test, y_test),
)
# Save the trained model to disk.
model.save_model('mnist.model')

# Load the saved model back into a fresh classifier instance.
ccc = CatBoostClassifier()
ccc.load_model('mnist.model')
# Predict: pick one random test sample and compare the reloaded model's
# prediction with the ground truth.
# randrange() excludes the upper bound. randint(0, len(X_test)) is inclusive
# on both ends, so it could return len(X_test) and raise IndexError below.
index = random.randrange(len(X_test))
_X = X_test.values[index]
_y = y_test.values[index]  # ground truth
predict = ccc.predict(_X)[0]  # predicted value

# Reshape the flat pixel vector into a 28x28 image for display.
_X = _X.reshape(28, 28)
plt.imshow(_X, cmap='gray_r')
plt.title('original {}'.format(_y))
plt.show()

print('index:', index)
print('original:', _y)
print('predicted:', predict)
index: 7534 original: 6 predicted: 6
index: 6510 original: 4 predicted: 4
index: 7311 original: 6 predicted: 6
ipynb
参考文献
- Battle of the Boosting Algos: LGB, XGB, Catboost
- CatBoost - open-source gradient boosting library
- Quick start - CatBoost. Documentation
- CatBoost tutorials
- 机器学习算法之Catboost
来源:https://www.cnblogs.com/XerCis/p/12366255.html