代码地址:
https://github.com/xiaoxu1025/rcnn-keras
这里只是做了一个简单的实现,方便对论文有更深的理解。
代码参考:
https://blog.csdn.net/u014796085/article/details/83931150
https://github.com/yangxue0827/RCNN
R-CNN分为三个部分:selective search、特征抽取、SVM。
对于模型,我这里用的是Keras自带的VGG16来做特征抽取:
from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.layers import Dense, Flatten, Input
def create_model(num_classes):
    """Build a VGG16-based classifier with a fresh fully connected head.

    The VGG16 network is instantiated without its top (``include_top=False``)
    on a fixed 224x224x3 input, and a new head is stacked on its output:
    ``flatten`` -> ``fc1`` (4096, relu) -> ``fc2`` (4096, relu) ->
    softmax over ``num_classes``.

    Args:
        num_classes: Number of output classes of the softmax layer.

    Returns:
        An uncompiled ``keras.models.Model``. Every layer belonging to the
        VGG16 base has ``trainable = False``; only the new head is trainable.
    """
    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = Input(shape=(224, 224, 3))
    # include_top=False: the original VGG16 classifier layers are never
    # built, so there is nothing to "pop" afterwards.
    vgg16_model = VGG16(input_tensor=inputs, include_top=False)

    x = vgg16_model.output
    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    # The class count is part of the layer name. NOTE(review): presumably so
    # that load_weights(..., by_name=True) skips this layer when the weights
    # were saved with a different number of classes -- confirm.
    x = Dense(num_classes, activation='softmax',
              name='predictions%s' % num_classes)(x)
    model = Model(inputs=vgg16_model.input, outputs=x)

    # Freeze the VGG16 base so that only the new head is trained.
    for layer in vgg16_model.layers:
        layer.trainable = False
    return model
# model = create_model(21)
# model.summary()
然后进行fine-tune
import os
import config as cfg
from model import create_model
import keras
from preprocess import load_data
def main():
    """Fine-tune the classifier on the samples listed in fine_tune_list.txt.

    Resumes from ``./weights/fine_tune_weights.h5`` when that file already
    exists, trains for 5 epochs and writes the weights back to the same path.
    """
    weights_path = './weights/fine_tune_weights.h5'

    # Load data; the third value returned by load_data is not needed here.
    x, y, _ = load_data('./fine_tune_list.txt')

    # Build the model and resume from earlier weights if there are any.
    model = create_model(cfg.FINE_TUNE_CLASSES)
    if os.path.exists(weights_path):
        model.load_weights(weights_path)

    # Train.
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=keras.optimizers.SGD(0.001),
                  metrics=['accuracy'])
    model.fit(x, y, epochs=5, batch_size=128)

    # Make sure the output directory exists; otherwise saving fails on a
    # fresh checkout after the whole training run has already completed.
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    model.save_weights(weights_path)


if __name__ == '__main__':
    main()
fine-tune之后进行训练
import os
import config as cfg
from model import create_model
import keras
from preprocess import load_data
def main():
    """Train the classifier on train_list.txt, starting from fine-tuned weights.

    Weights from ``./weights/fine_tune_weights.h5`` are loaded by layer name,
    the model is trained for 10 epochs, and the result is written to
    ``./weights/train_weights.h5``.
    """
    output_path = './weights/train_weights.h5'

    # Load data; the third value returned by load_data is not needed here.
    x, y, _ = load_data('./train_list.txt')

    # Build the model.
    model = create_model(cfg.TRAIN_CLASSES)
    # Load the fine-tuned weights. by_name=True matches layers by name, so
    # layers whose names do not occur in the file keep their initial weights.
    model.load_weights('./weights/fine_tune_weights.h5', by_name=True)

    # Train.
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=keras.optimizers.SGD(0.001),
                  metrics=['accuracy'])
    model.fit(x, y, epochs=10, batch_size=128)

    # Make sure the output directory exists before saving.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    model.save_weights(output_path)


if __name__ == '__main__':
    main()
然后从训练后的模型得到4096维的特征向量,进行SVM分类:
from preprocess import load_svm_data
from keras.models import Model
from sklearn.svm import SVC
import numpy as np
import joblib
from model import create_model
import config as cfg
from bbox import train_bbox
def train_svm(model, data_file):
    """Train an SVM on ``fc2`` features with iterative hard-negative mining.

    Features are taken from the ``fc2`` layer of ``model``. Samples whose
    label is ``>= 0`` form the initial training set; samples whose label is
    ``< 0`` form the hard-negative pool. In every round an ``SVC`` is fitted,
    the pool is classified, and the top 10% of the pool samples that were
    predicted as a non-zero class (ranked by the probability in column 1 of
    ``predict_proba``) are moved into the training set with label 0.

    Mining stops when fewer than 10 pool samples are misclassified, when the
    pool is empty, or when the fraction of pool samples predicted as class 0
    no longer improves.

    Args:
        model: Keras model that contains a layer named ``'fc2'``.
        data_file: Path passed through to ``load_svm_data``.

    Returns:
        A list holding the single trained classifier, which is also dumped
        to ``./svm/svm0.pkl``.
    """
    import os  # local import: this script's import list does not include os

    svms = []
    # The third value returned by load_svm_data is not needed here.
    x, y, _ = load_svm_data(data_file)
    y = np.asarray(y)

    # Truncate the network at fc2 to obtain the features used by the SVM.
    feature_model = Model(inputs=model.input,
                          outputs=model.get_layer('fc2').output)
    features = np.asarray(feature_model.predict(x))

    features_ = features[y >= 0]
    y_ = y[y >= 0]
    features_hard = features[y < 0]

    pred_last = -1
    pred_now = 0
    index = 0
    while pred_now > pred_last:
        clf = SVC(probability=True)
        clf.fit(features_, y_)

        # Nothing (left) to mine: keep the classifier that was just fitted.
        if features_hard.shape[0] == 0:
            break

        pred_ = clf.predict(features_hard)
        pred_prob = clf.predict_proba(features_hard)

        # Pool samples predicted as a non-zero class are the false positives.
        wrong_mask = pred_ > 0
        # Row positions inside features_hard, so that exactly the selected
        # rows are removed from the pool further down.
        wrong_rows = np.flatnonzero(wrong_mask)
        # NOTE(review): column 1 is the probability of clf.classes_[1]; this
        # assumes a background-vs-object setup -- confirm for multi-class.
        wrong_scores = pred_prob[wrong_mask][:, 1]

        # Stop when there are too few hard negatives to take 10% of.
        n_select = wrong_scores.shape[0] // 10
        if n_select < 1:
            break

        # Fraction of the pool that is already predicted as class 0.
        pred_last = pred_now
        pred_now = np.count_nonzero(pred_ == 0) / features_hard.shape[0]

        # Move the highest-scoring 10% of the false positives into the
        # training set (label 0) and drop them from the pool.
        top = np.argsort(wrong_scores)[::-1][:n_select]
        selected_rows = wrong_rows[top]
        features_ = np.concatenate(
            [features_, features_hard[selected_rows]], axis=0)
        y_ = np.concatenate(
            [y_, np.zeros(n_select, dtype=np.int32)], axis=0)
        features_hard = np.delete(features_hard, selected_rows, axis=0)

    # NOTE(review): the original indentation was lost; the classifier is
    # assumed to be stored once, after mining has finished -- confirm.
    svms.append(clf)
    # Serialize the classifier so the detection script can load it.
    os.makedirs('./svm', exist_ok=True)
    joblib.dump(clf, './svm/svm%s.pkl' % index)
    return svms
if __name__ == '__main__':
    # Only two classes are used here as a quick test. The commented-out
    # alternative in the original was:
    #   model = create_model(cfg.TRAIN_CLASSES)   (with train_list.txt)
    model = create_model(cfg.FINE_TUNE_CLASSES)
    # Load the fine-tuned weights, matching layers by name.
    model.load_weights('./weights/fine_tune_weights.h5', by_name=True)
    train_svm(model, './fine_tune_list.txt')
最后进行测试,代码如下:
from PIL import Image
import numpy as np
from model import create_model
import config as cfg
from ss.selectivesearch import selective_search
from utils import resize, crop_image
import joblib
import os
from keras.models import Model
from utils import cal_iou
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import skimage
def show_rect(img_path, regions):
    """Show the image at ``img_path`` with every region outlined in red.

    Args:
        img_path: Path of the image file to display.
        regions: Iterable of ``(x, y, w, h)`` boxes; ``(x, y)`` is passed to
            matplotlib as the rectangle anchor, ``w``/``h`` as its size.
    """
    # `import skimage` alone does not guarantee that the `io` submodule has
    # been loaded, so import it explicitly before using skimage.io.imread.
    import skimage.io

    _, ax = plt.subplots(ncols=1, nrows=1, figsize=(6, 6))
    ax.imshow(skimage.io.imread(img_path))
    for x, y, w, h in regions:
        ax.add_patch(mpatches.Rectangle(
            (x, y), w, h, fill=False, edgecolor='red', linewidth=1))
    plt.show()
def generate_image_proposal(img):
    """Generate resized region proposals for an image via selective search.

    Args:
        img: The image to process. It is converted with ``np.asarray`` for
            selective search and passed unchanged to ``crop_image``.

    Returns:
        A tuple ``(images, vertices)`` of numpy arrays: the float32 crops
        resized to ``cfg.IMAGE_SIZE`` and the ``rect`` each crop came from,
        in matching order.
    """
    # The first value returned by selective_search is not needed here.
    _, regions = selective_search(np.asarray(img))

    seen = set()
    images = []
    vertices = []
    for region in regions:
        rect = region['rect']
        # Skip rectangles that were already accepted (the same rectangle can
        # be reported for different segments).
        if rect in seen:
            continue
        # Skip small regions.
        if region['size'] < 220:
            continue
        # rect[2] * rect[3] is presumably width * height -- TODO confirm
        # against the selective_search output format.
        if rect[2] * rect[3] < 500:
            continue

        # Crop the proposal and resize it to the configured input size.
        # NOTE(review): cfg.IMAGE_SIZE should match the 224x224 model input
        # (the original comment said 227x227) -- confirm.
        proposal_img = crop_image(img, rect)
        resized_proposal_img = resize(proposal_img, cfg.IMAGE_SIZE)

        seen.add(rect)
        images.append(np.asarray(resized_proposal_img, dtype=np.float32))
        vertices.append(rect)
    return np.asarray(images), np.asarray(vertices)
def _clip_box(gx, gy, gw, gh, im_width, im_height):
    """Clip an (x, y, w, h) box so that it lies inside the image."""
    if gx < 0:
        # Shrink the width by the part that sticks out on the left.
        gw = gw + gx
        gx = 0
    if gx + gw > im_width:
        gw = im_width - gx
    if gy < 0:
        # Shrink the height by the part that sticks out at the top.
        # (The original subtracted gh from itself here instead of using gy.)
        gh = gh + gy
        gy = 0
    if gy + gh > im_height:
        gh = im_height - gy
    return [gx, gy, gw, gh]


def _non_max_suppression(boxes, scores, labels,
                         score_thresh=0.5, iou_thresh=0.5):
    """Drop low-scoring boxes, then greedily keep the best non-overlapping ones.

    Returns the kept boxes and their labels, highest score first.
    """
    # Discard candidates whose score is below the threshold.
    remaining = [(box, score, label)
                 for box, score, label in zip(boxes, scores, labels)
                 if not score < score_thresh]

    kept_boxes = []
    kept_labels = []
    while remaining:
        # Take the highest-scoring candidate (the first one on ties).
        best = max(range(len(remaining)), key=lambda k: remaining[k][1])
        box, _, label = remaining.pop(best)
        kept_boxes.append(box)
        kept_labels.append(label)

        # cal_iou receives the kept box as [x1, y1, x2, y2, w, h] and each
        # candidate as (x, y, w, h), exactly as in the original code.
        x, y, w, h = box
        anchor = [x, y, x + w, y + h, w, h]
        # Remove every candidate that overlaps the kept box too much.
        remaining = [cand for cand in remaining
                     if not cal_iou(cand[0], anchor) > iou_thresh]
    return kept_boxes, kept_labels


def main():
    """Run detection on one sample image and display the resulting boxes.

    Steps: generate region proposals, extract ``fc2`` features, classify each
    proposal with every saved SVM, refine accepted proposals with the saved
    bounding-box regressor, clip them to the image, apply score filtering and
    non-maximum suppression, then print and plot the final boxes.
    """
    img_path = './data/2flowers/jpg/0/image_0561.jpg'
    img = Image.open(img_path)
    img_data = np.asarray(img)
    im_width = img_data.shape[1]
    im_height = img_data.shape[0]

    imgs, verts = generate_image_proposal(img)
    if len(imgs) == 0:
        # Without proposals there is nothing to feed the network.
        print("no region proposals found")
        return

    model = create_model(cfg.FINE_TUNE_CLASSES)
    model.load_weights('./weights/fine_tune_weights.h5', by_name=True)
    # Truncate the network at fc2 to get the features.
    new_model = Model(inputs=model.input,
                      outputs=model.get_layer('fc2').output)
    features = new_model.predict(imgs)

    # Load the SVM classifiers and the bounding-box regressor. The listing is
    # sorted so that the classifiers are applied in a deterministic order.
    svms = [joblib.load('./svm/%s' % name)
            for name in sorted(os.listdir('./svm'))
            if name.startswith('svm')]
    bbox_fit = joblib.load('./svm/bbox_train.pkl')

    results = []
    results_label = []
    results_score = []
    # zip keeps each feature paired with the proposal it was computed from.
    for feature, (px, py, pw, ph) in zip(features, verts):
        sample = [feature.tolist()]
        for svm in svms:
            pred = svm.predict(sample)
            # Label 0 is treated as background.
            if pred[0] == 0:
                continue

            # Apply the predicted offsets (tx, ty, tw, th) to the proposal.
            tx, ty, tw, th = bbox_fit.predict(sample)[0][:4]
            gx = tx * pw + px
            gy = ty * ph + py
            gw = np.exp(tw) * pw
            gh = np.exp(th) * ph

            results.append(_clip_box(gx, gy, gw, gh, im_width, im_height))
            results_label.append(pred[0])
            # NOTE(review): column 1 is the probability of svm.classes_[1];
            # assumes a background-vs-object classifier -- confirm.
            results_score.append(svm.predict_proba(sample)[0][1])

    # Non-maximum suppression: drop boxes scoring below 0.5, then remove
    # boxes whose IoU with a higher-scoring kept box exceeds 0.5.
    results_final, results_final_label = _non_max_suppression(
        results, results_score, results_label)

    print("result:")
    print(results_final)
    print("result label:")
    print(results_final_label)
    show_rect(img_path, results_final)


if __name__ == '__main__':
    main()
上述代码仅仅是为了理解论文所写,并没有经过严格测试,如果有问题还请谅解。
来源:CSDN
作者:梦中又说人间梦
链接:https://blog.csdn.net/xiaoxu1025/article/details/104080965