feat(lstm): refact, predict, nocomment

parent 63e2e5b7
@@ -4,6 +4,7 @@ import warnings
 import numpy as np
 from keras import Sequential
+from keras.engine.saving import load_model
 from keras.layers import Embedding, LSTM, Dropout, Dense
 from keras.preprocessing.text import Tokenizer
 from keras.utils import to_categorical
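Side note on the new import: `keras.engine.saving` is an internal Keras module; on Keras 2.x the supported public path is `keras.models.load_model` (assuming this project runs on a Keras 2 release, where both resolve to the same function):

```python
# Public-API equivalent of the internal import added above (Keras 2.x):
from keras.models import load_model
```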
@@ -26,10 +27,10 @@ def load():
         content = f.readlines()
         all_lines.extend(content)
-    all_lines = [h for h in all_lines if
-                 h[0] != "["]
+    all_lines = [h for h in all_lines if h[0] not in ["[", "#"]
+                 ]
-    len(all_lines)
-    print("Loaded data:", all_lines[0])
+    print("Loaded %i lines of data: %s." % (len(all_lines), all_lines[0]))
     return all_lines
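The tightened filter now drops both "[...]" section headers and "#" comment lines in one pass. A tiny standalone check (the sample lines are hypothetical); note that `h[0]` still raises IndexError on an empty line, where `h[:1]` would not:

```python
# Hypothetical sample of what load() collects via readlines():
all_lines = ["[Refrain]\n", "# commentaire\n", "Je marche seul\n"]

# Same filter as the new code: drop section headers and comment lines.
kept = [h for h in all_lines if h[0] not in ["[", "#"]]
print(kept)  # ['Je marche seul\n']

# h[0] fails on "" (IndexError); h[:1] not in ("[", "#") would be safe.
```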
@@ -78,7 +79,7 @@ def generate_padded_sequences(input_sequences, total_words):
     return predictors, label, max_sequence_len


-def create_model(max_sequence_len, total_words):
+def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TODO finetune
     input_len = max_sequence_len - 1
     model = Sequential()
@@ -86,8 +87,8 @@ def create_model(max_sequence_len, total_words):
     model.add(Embedding(total_words, 10, input_length=input_len))

     # Add Hidden Layer 1 - LSTM Layer
-    model.add(LSTM(100))  # TODO finetune
-    model.add(Dropout(0.1))  # TODO finetune
+    model.add(LSTM(layers))
+    model.add(Dropout(dropout))

     # Add Output Layer
     model.add(Dense(total_words, activation='softmax'))
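For readability, this is how `create_model` reads with both hunks applied; the final compile step is not visible in this diff, so the loss/optimizer shown here are an assumption (categorical cross-entropy matches the softmax output and the `to_categorical` import):

```python
from keras import Sequential
from keras.layers import Embedding, LSTM, Dropout, Dense

def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):
    input_len = max_sequence_len - 1
    model = Sequential()

    # Input: one 10-dimensional embedding per token position.
    model.add(Embedding(total_words, 10, input_length=input_len))

    # Hidden layer: LSTM width and dropout rate are now arguments.
    model.add(LSTM(layers))
    model.add(Dropout(dropout))

    # Output: one softmax unit per vocabulary word.
    model.add(Dense(total_words, activation='softmax'))

    # Assumed, not shown in the diff:
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    return model
```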
@@ -113,12 +114,18 @@ def generate_text(seed_text, nb_words, model, max_sequence_len):
 def main():
     should_train = True
+    nb_epoch = 20
+    model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    max_sequence_len = 5  # TODO: Test different default

     if should_train:
         lines = load()
         corpus = [clean_text(x) for x in lines]
         print(corpus[:10])

-        inp_sequences, total_words = get_sequence_of_tokens(corpus[:10])  # Fixme: Corpus cliff for debug
+        inp_sequences, total_words = get_sequence_of_tokens(corpus)
         print(inp_sequences[:10])
         predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
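One subtlety: the new `max_sequence_len = 5` default only matters on the load path, since the training path overwrites it with the value returned by `generate_padded_sequences`. Persisting that value next to the model would keep the two paths consistent; a sketch (the `*_meta` helpers are hypothetical, not part of this commit):

```python
import json

# Hypothetical helpers: store max_sequence_len beside the .hd5 file so the
# load path replays the exact padding length the network was trained with.
def save_meta(path, max_sequence_len):
    with open(path, "w") as f:
        json.dump({"max_sequence_len": max_sequence_len}, f)

def load_meta(path):
    with open(path) as f:
        return json.load(f)["max_sequence_len"]
```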
@@ -127,11 +134,18 @@ def main():
         model = create_model(max_sequence_len, total_words)
         model.summary()

-        model.fit(predictors, label, epochs=10, verbose=5)
+        model.fit(predictors, label, epochs=nb_epoch, verbose=5)
+        model.save(model_file)
+    else:
+        model = load_model(model_file)

-    print(generate_text("", 10, model, max_sequence_len))
+    print(generate_text("L'étoile", 10, model, max_sequence_len))
+    while True:
+        input_text = input("> ")
+        print(generate_text(input_text, 10, model, max_sequence_len))


 if __name__ == '__main__':
     main()
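The new REPL-style loop only stops when the process is killed; a variant that exits cleanly on Ctrl-D or a blank line (a sketch using the names from `main()`, not part of the commit):

```python
while True:
    try:
        input_text = input("> ")
    except EOFError:       # Ctrl-D ends the session
        break
    if not input_text:     # a blank line also quits
        break
    print(generate_text(input_text, 10, model, max_sequence_len))
```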