合规国际互联网加速 OSASE为企业客户提供高速稳定SD-WAN国际加速解决方案。 广告
# Keras 中的 LSTM 文本生成 您可以在 Jupyter 笔记本`ch-08b_RNN_Text_Keras`中按照本节的代码进行操作。 我们在 Keras 实现文本生成 LSTM,步骤如下: 1. 首先,我们将所有数据转换为两个张量,张量`x`有五列,因为我们一次输入五个字,张量`y`只有一列输出。我们将`y`或标签张量转换为单热编码表示。 请记住,在大型数据集的实践中,您将使用 word2vec 嵌入而不是单热表示。 ```py # get the data x_train, y_train = text8.seq_to_xy(seq=text8.part['train'],n_tx=n_x,n_ty=n_y) # reshape input to be [samples, time steps, features] x_train = x_train.reshape(x_train.shape[0], x_train.shape[1],1) y_onehot = np.zeros(shape=[y_train.shape[0],text8.vocab_len],dtype=np.float32) for i in range(y_train.shape[0]): y_onehot[i,y_train[i]]=1 ``` 1. 接下来,仅使用一个隐藏的 LSTM 层定义 LSTM 模型。由于我们的输出不是序列,我们还将`return_sequences`设置为`False`: ```py n_epochs = 1000 batch_size=128 state_size=128 n_epochs_display=100 # create and fit the LSTM model model = Sequential() model.add(LSTM(units=state_size, input_shape=(x_train.shape[1], x_train.shape[2]), return_sequences=False ) ) model.add(Dense(text8.vocab_len)) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam') model.summary() ``` 该模型如下所示: ```py Layer (type) Output Shape Param # ================================================================= lstm_1 (LSTM) (None, 128) 66560 _________________________________________________________________ dense_1 (Dense) (None, 1457) 187953 _________________________________________________________________ activation_1 (Activation) (None, 1457) 0 ================================================================= Total params: 254,513 Trainable params: 254,513 Non-trainable params: 0 _________________________________________________________________ ``` 1. 对于 Keras,我们运行一个循环来运行 10 次,在每次迭代中训练 100 个周期的模型并打印文本生成的结果。以下是训练模型和生成文本的完整代码: ```py for j in range(n_epochs // n_epochs_display): model.fit(x_train, y_onehot, epochs=n_epochs_display, batch_size=batch_size,verbose=0) # generate text y_pred_r5 = np.empty([10]) y_pred_f5 = np.empty([10]) x_test_r5 = random5.copy() x_test_f5 = first5.copy() # let us generate text of 10 words after feeding 5 words for i in range(10): for x,y in zip([x_test_r5,x_test_f5], [y_pred_r5,y_pred_f5]): x_input = x.copy() x_input = x_input.reshape(-1, n_x, n_x_vars) y_pred = model.predict(x_input)[0] y_pred_id = np.argmax(y_pred) y[i]=y_pred_id x[:-1] = x[1:] x[-1] = y_pred_id print('Epoch: ',((j+1) * n_epochs_display)-1) print(' Random5 prediction:',id2string(y_pred_r5)) print(' First5 prediction:',id2string(y_pred_f5)) ``` 1. 输出并不奇怪,从重复单词开始,模型有所改进,但是可以通过更多 LSTM 层,更多数据,更多训练迭代和其他超参数调整来进一步提高。 ```py Random 5 words: free bolshevik be n another First 5 words: anarchism originated as a term ``` 预测的输出如下: ```py Epoch: 99 Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote First5 prediction: right philosophy than than than than than than than than Epoch: 199 Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote First5 prediction: term i revolutionary than war war french french french french Epoch: 299 Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote First5 prediction: term i revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary revolutionary Epoch: 399 Random5 prediction: anarchistic anarchistic wrote wrote wrote wrote wrote wrote wrote wrote First5 prediction: term i revolutionary labor had had french french french french Epoch: 499 Random5 prediction: anarchistic anarchistic amongst wrote wrote wrote wrote wrote wrote wrote First5 prediction: term i revolutionary labor individualist had had french french french Epoch: 599 Random5 prediction: tolstoy wrote tolstoy wrote wrote wrote wrote wrote wrote wrote First5 prediction: term i revolutionary labor individualist had had had had had Epoch: 699 Random5 prediction: tolstoy wrote tolstoy wrote wrote wrote wrote wrote wrote wrote First5 prediction: term i revolutionary labor individualist had had had had had Epoch: 799 Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy First5 prediction: term i revolutionary labor individualist had had had had had Epoch: 899 Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy First5 prediction: term i revolutionary labor should warren warren warren warren warren Epoch: 999 Random5 prediction: tolstoy wrote tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy First5 prediction: term i individualist labor should warren warren warren warren warren ``` 如果您注意到我们在 LSTM 模型的输出中有重复的单词用于文本生成。虽然超参数和网络调整可以消除一些重复,但还有其他方法可以解决这个问题。我们得到重复单词的原因是模型总是从单词的概率分布中选择具有最高概率的单词。这可以改变以选择诸如在连续单词之间引入更大可变性的单词。