首页 热点专区 小学知识 中学知识 出国留学 考研考公
您的当前位置:首页正文

自然语言处理N天-Day1103从0搭建一个RNN神经网络作诗(

2024-12-20 来源:要发发知识网

第十一课 使用RNN生成古诗

4.模型的训练

获取数据 batch 的代码位于poem.py的generate_batch方法,作用是用来获取每一个batch的数值。作为接下来模型训练的输入数据集。
传入参数有batch_size:batch的大小;poems_vec:前面生成的诗文中字ID;word_to_int:前面生成的每一个字ID。

def generate_batch(batch_size, poems_vec, word_to_int):
    # 每次取batch_size首诗进行训练
    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    #使用for循环,生成n_chunk个batch。
    for i in range(n_chunk):
        #每一个batch开始和结束的index
        start_index = i * batch_size
        end_index = start_index + batch_size
        batches = poems_vec[start_index:end_index]
        
        # 找到这个batch中所有poem最长的poem的长度,以这个长度为最大值生成batch中每一行的长度。
        length = max(map(len, batches))
        # 填充一个空batch,空的地方放空格对应的index标号
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)

        for row, batch in enumerate(batches):
            # 每一行就是一首诗,在原本的长度上把诗还原上去
            x_data[row, :len(batch)] = batch
        y_data = np.copy(x_data)
        # y就是x向左边移动一个,最后一位使用倒数第二位的数值填充
        y_data[:, :-1] = x_data[:, 1:]
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches

模型的训练代码位于train.py的run_training方法

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import poems
import models

tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs')

FLAGS = tf.app.flags.FLAGS


def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    # 读取诗集文件
    # 依次得到数字ID表示的诗句、汉字-ID的映射map、所有的汉字的列表
    poems_vector, word_to_int, vocabularies = poems.process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = poems.generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # 通过rnn模型得到结果状态集
    end_points = models.rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                                  vocab_size=len(vocabularies), rnn_size=128, num_layers=2, batch_size=64,
                                  learning_rate=FLAGS.learning_rate)

    # 初始化saver和session
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print('## restore from the checkpointt {0}'.format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## strat training...')

        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epoches):
                n = 0
                for batch in range(n_chunk):
                    # 训练并计算loss
                    # batches_inputs[n]: 第n个batch的输入数据
                    # batches_outputs[n]: 第n个batch的输出数据
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']],
                        feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]}
                    )
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                    # 每训练6个epoch进行一次模型保存
                    if epoch % 6 == 0:
                        saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

要发发知识网还为您提供以下相关内容希望对您有帮助:

循环神经网络(RNN)浅析

RNN也实现了类似于人脑的这一机制,对所处理过的信息留存有一定的记忆,而不像其他类型的神经网络并不能对处理过的信息留存记忆。 循环神经网络的原理并不十分复杂,本节主要从原理上分析RNN的结构和功能,不涉及RNN的数学推导和证明,整个网络只有简单的输入输出和网络状态参数。一个典型的RNN神经网络如图所示:由上图可...

用Pytorch实现循环神经网络RNN

RNN主要用于处理时间序列数据、自然语言处理(NLP)等序列数据,根据输入输出所含时间序列的步长,RNNs大体可以分为以下几种。多层RNNs(Stacked RNNs)的结构 多层RNNs一般用于提高性能。双向RNNs(Bidirectional RNNs)的结构 双向RNNs使用了两个RNNs网络结构,输入序列按照正序输入其中一个RNNs,按照逆序...

深度学习-循环神经网络RNN

循环神经网络(Recurrent Neural Networks,RNN)在处理具有时间序列或顺序特征的数据时表现出色,广泛应用于自然语言处理、语音识别、时间序列预测等领域。与卷积神经网络相比,RNN能够考虑时序数据的特征,通过隐藏层特征的时序传递实现序列数据的融合与决策输出。在应用RNN进行预测时,如预测球在下一时刻的位置...

一文读懂循环神经网络(RNN)

总的来说,RNN及其变体通过引入循环连接,使得神经网络能够在处理时序数据时捕获动态特征,显著扩展了神经网络的适用范围。从简单的语音识别到复杂的自然语言处理任务,RNN和其变体都是不可或缺的工具。

Recurrent Memory Networks for Language Modeling

在自然语言处理领域,递归神经网络(RNNs)特别是长短期记忆(LSTM)在语言建模和众多任务中表现出色,但其内部功能理解仍是难题。本文提出了递归记忆网络(RMN),一种创新的RNN架构,它通过结合LSTM和记忆网络的优点,增强了RNN的功能,并提供了对内部操作的洞察,有助于揭示数据中的模式。RMN在德语、意大利语...

课堂深度学习的四个步骤

CS224N/Ling284课程是一个很好的起点,CS224d:DeepLearningforNaturalLanguageProcessing由David Socher教授,涵盖了自然语言处理的最新深度学习研究。最后,记忆网络(RNN-LSTM)是一个有趣的领域,LSTM复发神经网络与外部可写内存结合,使得存储和以问答方式检索信息成为可能。这个研究领域起源于Dr.Yann LeCun...

【模型解读】浅析RNN到LSTM

处理非固定长度或大小的信息,如视频和语音,时间序列模型如RNN和LSTM尤为适用。让我们首先来看看RNN,它有两种类型:Recurrent Neural Networks(循环神经网络)和Recursive Neural Networks(递归神经网络),后者主要用于自然语言处理中的序列学习,但这里主要讲解循环神经网络。RNN的结构包括当前时刻输入xt和上...

零基础学Python应该学习哪些入门知识

第一步至关重要,关系到初学者从入门到精通还是从入门到放弃。选一条合适的入门道路,并坚持走下去。2.2 配置 Python 学习环境。选Python2 还是 Python3?入门时很多人都会纠结。二者只是程序不兼容,思想上并无大差别,语法变动也并不多。选择任何一个入手,都没有大影响。如果你仍然无法抉择,那请...

显示全文

猜你还关注