Lesson 11: Generating Classical Chinese Poetry with an RNN
4. Training the Model
The code for producing data batches lives in the generate_batch function of poems.py. It builds the numeric arrays for each batch, which serve as the input data for the model training that follows.
Its parameters are batch_size, the number of poems per batch; poems_vec, the poems converted to word IDs in the previous step; and word_to_int, the character-to-ID mapping built earlier.
import numpy as np

def generate_batch(batch_size, poems_vec, word_to_int):
    # Each training step consumes batch_size poems.
    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    # Build n_chunk batches in total.
    for i in range(n_chunk):
        # Start and end indices of this batch.
        start_index = i * batch_size
        end_index = start_index + batch_size
        batches = poems_vec[start_index:end_index]
        # Find the longest poem in this batch; every row of the batch is padded to that length.
        length = max(map(len, batches))
        # Pre-fill an empty batch with the ID of the space character.
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
        for row, batch in enumerate(batches):
            # Each row is one poem; copy the poem's IDs over the padding.
            x_data[row, :len(batch)] = batch
        y_data = np.copy(x_data)
        # y is x shifted left by one position; the last column keeps x's last value,
        # so the final two elements of each row of y are identical.
        y_data[:, :-1] = x_data[:, 1:]
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches
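To make the relationship between x and y concrete, here is a small usage sketch. The two toy poems and the space ID of 0 are made-up values, and generate_batch is assumed to be importable from poems.py:

from poems import generate_batch

# Hypothetical toy input: two "poems" already converted to word IDs,
# with the space character mapped to ID 0.
poems_vec = [[5, 8, 3], [7, 2, 9, 4]]
word_to_int = {' ': 0}

x_batches, y_batches = generate_batch(2, poems_vec, word_to_int)
print(x_batches[0])
# [[5 8 3 0]
#  [7 2 9 4]]
print(y_batches[0])
# [[8 3 0 0]
#  [2 9 4 4]]   <- each row of y is x shifted left by one; the last value is repeated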
The model training code lives in the run_training function of train.py:
# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import poems
import models
tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs')
FLAGS = tf.app.flags.FLAGS
def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    # Read the poetry corpus and obtain, in order: the poems as lists of word IDs,
    # the character-to-ID mapping, and the list of all characters (the vocabulary).
    poems_vector, word_to_int, vocabularies = poems.process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = poems.generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    # Build the RNN model and obtain its end points (total loss, last state, train op).
    end_points = models.rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                                  vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                                  batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
    # Initialize the saver and the session.
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print('## restored from the checkpoint {0}'.format(checkpoint))
            # The checkpoint file name ends with the epoch it was saved at (global_step),
            # which is used to resume training from that epoch.
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                for batch in range(n_chunk):
                    # Run one training step and compute the loss.
                    # batches_inputs[n]: input data of the n-th batch
                    # batches_outputs[n]: target data of the n-th batch
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']],
                        feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]}
                    )
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                # Save a checkpoint every 6 epochs.
                if epoch % 6 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually, trying to save a checkpoint now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## The last epoch was saved; next time training will start from epoch {}.'.format(epoch))
def main(_):
    run_training()

if __name__ == '__main__':
    tf.app.run()
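Since the hyper-parameters are exposed through tf.app.flags, they can be overridden on the command line when launching the script; for example (assuming the file is saved as train.py, as stated above):

python train.py --batch_size 32 --learning_rate 0.005 --epochs 100

Training then runs for the requested number of epochs, saving a checkpoint to ./model every 6 epochs and resuming from the latest checkpoint if one already exists.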