MTCNN(Tensorflow)学习记录(RNet的训练)
来源:锐游网
上一篇已经为RNet网络生成了TFRecord文件,现在开始对RNet进行训练。
1 RNet的训练
进入 train_models 文件夹,打开 train_RNet.py,代码如下:
#coding:utf-8
from train_models.mtcnn_model import R_Net
from train_models.train import train
def train_RNet(base_dir, prefix, end_epoch, display, lr):
    """Train RNet.

    Thin wrapper that forwards the RNet network factory and the training
    configuration to the generic ``train`` loop.

    :param base_dir: directory containing the RNet tfrecord files and the
        label list (e.g. '../../DATA/imglists_noLM/RNet')
    :param prefix: checkpoint path prefix; its last path component ('RNet')
        also selects the data-loading branch inside ``train``
    :param end_epoch: total number of epochs to train
    :param display: print training statistics every ``display`` steps
    :param lr: base learning rate
    :return: None
    """
    train(R_Net, prefix, end_epoch, base_dir, display=display, base_lr=lr)
if __name__ == '__main__':
    # Training configuration for RNet *without* the landmark branch
    # (the "_noLM" image lists and the "No_Landmark" checkpoint dir).
    model_name = 'MTCNN'
    base_dir = '../../DATA/imglists_noLM/RNet'  # tfrecords + label list
    model_path = '../data/%s_model/RNet_No_Landmark/RNet' % model_name
    prefix = model_path  # checkpoint path prefix
    end_epoch = 22       # train for 22 epochs
    display = 100        # log every 100 steps
    lr = 0.001           # base learning rate
    train_RNet(base_dir=base_dir, prefix=prefix, end_epoch=end_epoch,
               display=display, lr=lr)
由上可以看出,train_RNet 调用了 R_Net 和 train 这两个函数。下面把这两个函数的代码贴出来,train 函数的代码如下:
def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    Train one of the MTCNN sub-networks (PNet/RNet/ONet).

    Which net is trained is derived from the last path component of
    ``prefix`` ('PNet'/'RNet'/'ONet'); that name selects the input size,
    the loss weights and the tfrecord-reading branch.

    :param net_factory: graph-building function (P_Net/R_Net/O_Net); when
        called with training=True it returns the ops
        (cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy)
    :param prefix: checkpoint path prefix,
        e.g. '../data/MTCNN_model/RNet_No_Landmark/RNet'
    :param end_epoch: total number of epochs to train
    :param base_dir: directory holding the tfrecord files and the label list
    :param display: print training statistics every ``display`` steps
    :param base_lr: initial learning rate passed to train_model()
    :return: None
    """
    net = prefix.split('/')[-1]  # net name, e.g. 'RNet'
    # Label list file — used only to count the number of training samples.
    label_file = os.path.join(base_dir,'train_%s_landmark.txt' % net)
    print(label_file)
    # NOTE(review): file handle is never closed — `with open(...)` would be safer.
    f = open(label_file, 'r')
    num = len(f.readlines())  # number of training samples
    print("Total size of the dataset is: ", num)
    print(prefix)
    # PNet reads one combined tfrecord; RNet/ONet read four separate ones.
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:',dataset_dir)
        image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)
    else:
        # RNet/ONet: one tfrecord per sample type (pos / part / neg / landmark).
        pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')
        # NOTE(review): hardcoded path ignores `base_dir` — possibly deliberate
        # (the *_noLM dirs have no landmark tfrecord), but confirm against the repo.
        landmark_dir = os.path.join('../../DATA/imglists/RNet','landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]
        # Per-batch sampling ratio: pos:part:landmark:neg = 1:1:1:3
        pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0/6;neg_radio=3.0/6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE*pos_radio))
        assert pos_batch_size != 0,"Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))
        assert part_batch_size != 0,"Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE*neg_radio))
        assert neg_batch_size != 0,"Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))
        assert landmark_batch_size != 0,"Batch Size Error "
        batch_sizes = [pos_batch_size,part_batch_size,neg_batch_size,landmark_batch_size]
        # e.g. with BATCH_SIZE=384: [64, 64, 192, 64] — TODO confirm BATCH_SIZE
        # Reads and interleaves the four tfrecords; analogous to
        # read_single_tfrecord() used for PNet.
        image_batch, label_batch, bbox_batch,landmark_batch = read_multi_tfrecords(dataset_dirs,batch_sizes, net)

    # Per-net input size and loss weights (cls : bbox : landmark).
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    elif net == 'RNet':
        # RNet: loss weights cls:bbox:landmark = 1 : 0.5 : 0.5
        image_size = 24
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    else:
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 1;
        image_size = 48
    # Placeholders for one batch of images / labels / regression targets.
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,shape=[config.BATCH_SIZE,10],name='landmark_target')
    # Random color distortion, then build loss/accuracy ops via the net factory.
    input_image = image_color_distort(input_image)
    cls_loss_op,bbox_loss_op,landmark_loss_op,L2_loss_op,accuracy_op = net_factory(input_image, label, bbox_target,landmark_target,training=True)
    # Weighted sum of the three task losses plus L2 regularization.
    total_loss_op = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr,
                                  total_loss_op,
                                  num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # save model — max_to_keep=0 keeps every checkpoint
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # TensorBoard summaries for the individual losses and accuracy
    tf.summary.scalar("cls_loss",cls_loss_op)  # classification loss
    tf.summary.scalar("bbox_loss",bbox_loss_op)  # bounding-box regression loss
    tf.summary.scalar("landmark_loss",landmark_loss_op)  # landmark regression loss
    tf.summary.scalar("cls_accuracy",accuracy_op)  # classification accuracy
    tf.summary.scalar("total_loss",total_loss_op)  # cls + bbox + landmark + L2
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" %(net)
    if os.path.exists(logs_dir) == False:
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir,sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer,projector_config)
    # begin: coordinator + enqueue threads for the tfrecord input queues
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0  # step counter since the last checkpoint
    # total steps (steps per epoch * epochs)
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    sess.graph.finalize()  # no ops may be added to the graph after this point
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array,landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch,landmark_batch])
            # random horizontal flip — images and landmarks flipped consistently
            image_batch_array,landmark_batch_array = random_flip_images(image_batch_array,label_batch_array,landmark_batch_array)
            '''
            print('im here')
            print(image_batch_array.shape)
            print(label_batch_array.shape)
            print(bbox_batch_array.shape)
            print(landmark_batch_array.shape)
            print(label_batch_array[0])
            print(bbox_batch_array[0])
            print(landmark_batch_array[0])
            '''
            _,_,summary = sess.run([train_op, lr_op ,summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array,landmark_target:landmark_batch_array})
            if (step+1) % display == 0:
                # re-evaluate the individual losses on this batch for logging
                cls_loss, bbox_loss,landmark_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,landmark_loss_op,L2_loss_op,lr_op,accuracy_op],
                                                                            feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                total_loss = radio_cls_loss*cls_loss + radio_bbox_loss*bbox_loss + radio_landmark_loss*landmark_loss + L2_loss
                print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step+1,MAX_STEP, acc, cls_loss, bbox_loss,landmark_loss, L2_loss,total_loss, lr))
            # save a checkpoint every two epochs' worth of samples
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch*2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary,global_step=step)
    except tf.errors.OutOfRangeError:
        # input queues exhausted — training data fully consumed
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
RNet的训练和PNet的训练差不多,大家可以互相对照,大多数相同的部分我就省略了。
RNet的训练数据来源是:原图经过图像金字塔获得不同size的图像,再由PNet检测并通过NMS(非极大值抑制)算法筛选出合格的人脸框,将得到的训练数据resize成24*24。RNet网络最后的输出经过全连接层,这一点和PNet有差异,但损失的计算方法是相同的。
RNet的训练中,每循环完两个数据周期,就记录一次模型的参数。
因篇幅问题不能全部显示,请点此查看更多更全内容