refact(lstm): Fix generate_text, extract params, use PoemTok

parent d06fd636
@@ -2,13 +2,14 @@ import warnings
import numpy as np
from keras import Sequential
+from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Embedding, LSTM, Dropout, Dense
from keras.utils import to_categorical
from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.text import Tokenizer
from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
-from KoozDawa.dawa.tokens import get_sequence_of_tokens
+from KoozDawa.dawa.tokens import PoemTokenizer
warnings.filterwarnings("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)
@@ -25,7 +26,7 @@ def generate_padded_sequences(input_sequences, total_words):
    return predictors, label, max_sequence_len
-def create_model(max_sequence_len, total_words, layers=100, dropout=0.1): # TODO finetune
+def create_model(max_sequence_len, total_words, layers=128, dropout=0.2): # TODO finetune layers/dropout
    input_len = max_sequence_len - 1
    model = Sequential()
@@ -39,7 +40,9 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1): # TOD
    # Add Output Layer
    model.add(Dense(total_words, activation='softmax'))
-    model.compile(loss='categorical_crossentropy', optimizer='adam')
+    model.compile(optimizer='adam', # TODO: Try RMSprop(learning_rate=0.01)
+                  loss='categorical_crossentropy', # TODO: Try sparse_categorical_crossentropy for faster training
+                  metrics=['accuracy'])
    # TODO: Try alternative architectures
    # https://medium.com/coinmonks/word-level-lstm-text-generator-creating-automatic-song-lyrics-with-neural-networks-b8a1617104fb#35f4
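A sketch of the two compile TODOs above, assuming Keras 2.3+ where RMSprop accepts learning_rate; note that sparse_categorical_crossentropy expects integer class ids, so the to_categorical step in generate_padded_sequences would have to be skipped for it to apply:

from keras.optimizers import RMSprop

# hypothetical variant of the compile step suggested by the TODOs
# (labels stay as integer word indices, no to_categorical)
model.compile(optimizer=RMSprop(learning_rate=0.01),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])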
@@ -47,53 +50,61 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1): # TOD
def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words=5, max_sequence_len=0) -> str:
+    word_indices = {v: k for k, v in tokenizer.word_index.items()}
+    output = seed_text
    for _ in range(nb_words):
-        token_list = tokenizer.texts_to_sequences([seed_text])[0]
+        token_list = tokenizer.texts_to_sequences([output])[0]
        token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-        predicted = model.predict_classes(token_list, verbose=2)
-        output_word = ""
-        for word, index in tokenizer.word_index.items():
-            if index == predicted:
-                output_word = word
-                break
-        seed_text += " " + output_word
-    return seed_text.capitalize()
+        predicted = model.predict_classes(token_list, verbose=2)[0]
+        output += " " + word_indices[predicted]
+    return output.capitalize()
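Sequential.predict_classes was removed in later TensorFlow/Keras releases; a minimal sketch of an equivalent lookup with the same word_indices map, assuming only predict is available:

# hypothetical equivalent of predict_classes on newer Keras versions
probs = model.predict(token_list, verbose=0)
predicted = int(np.argmax(probs, axis=-1)[0])
output += " " + word_indices[predicted]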
def main():
-    should_train = True
+    # should_train = True
    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
-    nb_words = 20
    nb_epoch = 100
+    nb_words = 200
-    tokenizer = Tokenizer()
+    nb_layers = 128
+    dropout = .2
+    tokenizer = PoemTokenizer()
    # if should_train:
    lines = load_kawa()
    corpus = [clean_text(x) for x in lines]
-    print("Corpus:", corpus[:5])
+    print("Corpus:", corpus[:10])
-    inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
+    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words)
+    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
    model.summary()
-    model.fit(predictors, label, epochs=nb_epoch, verbose=5)
+    file_path = "../models/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', save_best_only=True)
+    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
+    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
+    callbacks_list = [checkpoint, early_stopping]
+    for i in range(nb_epoch):
+        model.fit(predictors, label, initial_epoch=i, epochs=i + 1, verbose=2, callbacks=callbacks_list)
    # model.save(model_file)
-    # else: # FIXME: Load and predict
+    # else: # FIXME: Load and predict, maybe reuse checkpoints?
    # model = load_model(model_file)
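A minimal sketch of the load-and-predict branch the FIXME describes, assuming one of the checkpoints written by ModelCheckpoint is reused; the filename below is only illustrative of what the file_path template produces:

from keras.models import load_model

# hypothetical: reload a saved checkpoint instead of retraining
# (tokenizer and max_sequence_len still have to be rebuilt from the corpus)
model = load_model("../models/dawa_lstm128-d0.2-42_100-0.9123.hdf5")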
-    for sample in load_seeds(lines):
-        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
+    for i, seed in enumerate(load_seeds(lines, 3)):
+        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
+        print("%i %s -> %s" % (i, seed, output))
-    with open("./output/lstm.txt", "a+") as f:
+    with open("./output/dawa.txt", "a+") as f:
        while True:
            input_text = input("> ")
            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
            print(text)
-            f.writelines(text)
+            f.writelines("%s\n" % text)
if __name__ == '__main__':
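For comparison, a single fit call would let Keras drive the epoch loop in main itself, so the EarlyStopping patience counter is kept across epochs instead of being reset by every one-epoch call; a sketch using the same variables:

# alternative to the manual epoch loop: one call, callbacks handle the rest
model.fit(predictors, label, epochs=nb_epoch, verbose=2, callbacks=callbacks_list)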