refactor(lstm): reduce debug output

parent 3da1e4f1
@@ -49,7 +49,7 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
     for _ in range(nb_words):
         token_list = tokenizer.texts_to_sequences([seed_text])[0]
         token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-        predicted = model.predict_classes(token_list, verbose=0)
+        predicted = model.predict_classes(token_list, verbose=2)
         output_word = ""
         for word, index in tokenizer.word_index.items():
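For context, the loop above tokenizes the seed text, left-pads it to the model's input length, and picks the most likely next word. Below is a minimal sketch of that step, assuming the usual TensorFlow/Keras setup this script appears to use; `predict_next_word` is an illustrative helper, and `np.argmax(model.predict(...))` is shown as the portable equivalent of `predict_classes`, which only exists on older Keras `Sequential` models:

```python
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

def predict_next_word(model, tokenizer, seed_text, max_sequence_len):
    # Encode the seed text and left-pad it to the length the model expects.
    token_list = tokenizer.texts_to_sequences([seed_text])[0]
    token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
    # Equivalent of predict_classes(): argmax over the softmax output.
    predicted = int(np.argmax(model.predict(token_list, verbose=0), axis=-1)[0])
    # Map the predicted index back to a word (index_word is built by fit_on_texts).
    return tokenizer.index_word.get(predicted, "")
```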
@@ -63,22 +63,18 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
 def main():
     should_train = True
     nb_epoch = 100
-    max_sequence_len = 61 # TODO: Test different default
     model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    max_sequence_len = 5 # TODO: Test different default
     tokenizer = Tokenizer()
     if should_train:
         lines = load_kawa()
         corpus = [clean_text(x) for x in lines]
-        print(corpus[:10])
+        print("Corpus:", corpus[:2])
         inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
-        print(inp_sequences[:10])
         predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-        print(predictors, label, max_sequence_len)
         model = create_model(max_sequence_len, total_words)
         model.summary()
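The `generate_padded_sequences` helper is not part of this diff; the sketch below is an assumption of what it typically does in this kind of Keras LSTM setup, inferred from how `predictors`, `label` and `max_sequence_len` are used above (pad the n-gram sequences, split off the last token as the label, one-hot encode it):

```python
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

def generate_padded_sequences(input_sequences, total_words):
    # Left-pad every n-gram sequence to a common length.
    max_sequence_len = max(len(seq) for seq in input_sequences)
    padded = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
    # All tokens but the last are the input; the last token is the target word.
    predictors, label = padded[:, :-1], padded[:, -1]
    label = to_categorical(label, num_classes=total_words)
    return predictors, label, max_sequence_len
```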
@@ -87,12 +83,17 @@ def main():
     else:
         model = load_model(model_file)
-    for sample in ["", "L'étoile ", "Elle ", "Les punchlines "]:
-        print(generate_text(model, tokenizer, sample, 100, max_sequence_len))
+    for sample in ["",
+                   "L'étoile du sol",
+                   "Elle me l'a toujours dit",
+                   "Les punchlines sont pour ceux"]:
+        nb_words = 50
+        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
     while True:
         input_text = input("> ")
-        print(generate_text(model, tokenizer, input_text, 100, max_sequence_len))
+        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
+        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
 if __name__ == '__main__':
@@ -11,8 +11,6 @@ def get_sequence_of_tokens(corpus, tokenizer=Tokenizer()):
     # convert data to sequence of tokens
     input_sequences = []
-    # FIXME Debug: truncate corpus
-    corpus = corpus[:50]
     for line in corpus:
         token_list = tokenizer.texts_to_sequences([line])[0]
         for i in range(1, len(token_list)):
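A small worked example of the n-gram expansion this loop performs, with made-up token ids, which shows why dropping the `corpus[:50]` debug truncation matters for the amount of training data:

```python
# One tokenized line (ids are made up for illustration).
token_list = [12, 7, 31, 45, 9]
input_sequences = [token_list[:i + 1] for i in range(1, len(token_list))]
# -> [[12, 7], [12, 7, 31], [12, 7, 31, 45], [12, 7, 31, 45, 9]]
```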