refact(lstm): Fix generate_text, extract params, use PoemTok

parent d06fd636
@@ -2,13 +2,14 @@ import warnings
 import numpy as np
 from keras import Sequential
+from keras.callbacks import ModelCheckpoint, EarlyStopping
 from keras.layers import Embedding, LSTM, Dropout, Dense
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
 
 from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
-from KoozDawa.dawa.tokens import get_sequence_of_tokens
+from KoozDawa.dawa.tokens import PoemTokenizer
 
 warnings.filterwarnings("ignore")
 warnings.simplefilter(action='ignore', category=FutureWarning)
@@ -25,7 +26,7 @@ def generate_padded_sequences(input_sequences, total_words):
     return predictors, label, max_sequence_len
 
 
-def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TODO finetune
+def create_model(max_sequence_len, total_words, layers=128, dropout=0.2):  # TODO finetune layers/dropout
     input_len = max_sequence_len - 1
     model = Sequential()
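Note: the hidden layers of create_model sit between this hunk and the next, so the diff does not show them. For orientation only, here is a rough sketch of what the function presumably builds, given the imports in the first hunk and the new defaults; the embedding size and exact layer order are assumptions, not part of this commit:

def create_model_sketch(max_sequence_len, total_words, layers=128, dropout=0.2):
    # Uses Sequential, Embedding, LSTM, Dropout, Dense from the imports above.
    # The last token of each sequence becomes the label, so inputs are one shorter.
    input_len = max_sequence_len - 1
    model = Sequential()
    model.add(Embedding(total_words, 10, input_length=input_len))  # 10 is a placeholder size
    model.add(LSTM(layers))
    model.add(Dropout(dropout))
    model.add(Dense(total_words, activation='softmax'))
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model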
@@ -39,7 +40,9 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TODO finetune
     # Add Output Layer
     model.add(Dense(total_words, activation='softmax'))
-    model.compile(loss='categorical_crossentropy', optimizer='adam')
+    model.compile(optimizer='adam',  # TODO: Try RMSprop(learning_rate=0.01)
+                  loss='categorical_crossentropy',  # TODO: Try sparse_categorical_crossentropy for faster training
+                  metrics=['accuracy'])
 
     # TODO: Try alternative architectures
     # https://medium.com/coinmonks/word-level-lstm-text-generator-creating-automatic-song-lyrics-with-neural-networks-b8a1617104fb#35f4
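Note on the sparse_categorical_crossentropy TODO above: that loss takes integer word indices as targets, so the one-hot step (presumably the to_categorical call behind generate_padded_sequences, which this diff does not touch) would be dropped. A minimal sketch of the difference, with made-up numbers:

import numpy as np
from keras.utils import to_categorical

word_ids = np.array([3, 1, 4])                     # integer labels straight from the tokenizer
one_hot = to_categorical(word_ids, num_classes=5)  # what categorical_crossentropy expects

# With loss='sparse_categorical_crossentropy' the model is fit on word_ids directly,
# skipping the vocabulary-sized one-hot vector per sample:
# model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# model.fit(predictors, word_ids, epochs=nb_epoch)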
@@ -47,53 +50,61 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TODO finetune
 def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words=5, max_sequence_len=0) -> str:
+    word_indices = {v: k for k, v in tokenizer.word_index.items()}
+    output = seed_text
     for _ in range(nb_words):
-        token_list = tokenizer.texts_to_sequences([seed_text])[0]
+        token_list = tokenizer.texts_to_sequences([output])[0]
         token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-        predicted = model.predict_classes(token_list, verbose=2)
-        output_word = ""
-        for word, index in tokenizer.word_index.items():
-            if index == predicted:
-                output_word = word
-                break
-        seed_text += " " + output_word
-    return seed_text.capitalize()
+        predicted = model.predict_classes(token_list, verbose=2)[0]
+        output += " " + word_indices[predicted]
+    return output.capitalize()
 def main():
-    should_train = True
+    # should_train = True
     # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    nb_words = 20
     nb_epoch = 100
-    nb_words = 200
-    tokenizer = Tokenizer()
+    nb_layers = 128
+    dropout = .2
+    tokenizer = PoemTokenizer()
 
     # if should_train:
     lines = load_kawa()
     corpus = [clean_text(x) for x in lines]
-    print("Corpus:", corpus[:5])
+    print("Corpus:", corpus[:10])
 
-    inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
+    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
     predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words)
+    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
     model.summary()
 
-    model.fit(predictors, label, epochs=nb_epoch, verbose=5)
+    file_path = "../models/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', save_best_only=True)
+    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
+    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
+    callbacks_list = [checkpoint, early_stopping]
+    for i in range(nb_epoch):
+        model.fit(predictors, label, initial_epoch=i, epochs=i + 1, verbose=2, callbacks=callbacks_list)
     # model.save(model_file)
-    # else:  # FIXME: Load and predict
+    # else:  # FIXME: Load and predict, maybe reuse checkpoints?
     # model = load_model(model_file)
 
-    for sample in load_seeds(lines):
-        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
+    for i, seed in enumerate(load_seeds(lines, 3)):
+        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
+        print("%i %s -> %s" % (i, seed, output))
 
-    with open("./output/lstm.txt", "a+") as f:
+    with open("./output/dawa.txt", "a+") as f:
         while True:
             input_text = input("> ")
             text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
             print(text)
-            f.writelines(text)
+            f.writelines("%s\n" % text)
 
 
 if __name__ == '__main__':
...
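A caveat on the generate_text change above: Sequential.predict_classes only exists in older Keras/TensorFlow releases (it was removed from tf.keras as of 2.6), so the new [0] indexing is tied to the version this project currently pins. If the dependency is ever upgraded, an equivalent lookup would be the following sketch, which is not part of this commit; model, token_list, word_indices and output are the names from generate_text:

import numpy as np

probabilities = model.predict(token_list, verbose=0)   # shape (1, total_words)
predicted = int(np.argmax(probabilities, axis=-1)[0])  # same index predict_classes returned
output += " " + word_indices[predicted]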
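On the new training loop in main(): because fit is called once per epoch, EarlyStopping resets its patience counter at the start of every call and can never actually stop training. If the goal is just checkpointing plus early stopping, a single call drives the same callbacks; a sketch using the names introduced above:

model.fit(predictors, label,
          epochs=nb_epoch,
          verbose=2,
          callbacks=[checkpoint, early_stopping])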