feat(lstm): Capitalize, vary tokenization

parent 7a97532c
......@@ -28,8 +28,8 @@ def clean_text(lines):
In dataset preparation step, we will first perform text cleaning of the data
which includes removal of punctuations and lower casing all the words.
"""
lines = "".join(v for v in lines if v not in string.punctuation).lower()
lines = lines.encode("utf8").decode("ascii", 'ignore')
lines = "".join(v for v in lines if v not in string.punctuation)
# lines = lines.encode("utf8").decode("ascii", 'ignore')
return lines
......
......@@ -2,7 +2,6 @@ import warnings
import numpy as np
from keras import Sequential
from keras.engine.saving import load_model
from keras.layers import Embedding, LSTM, Dropout, Dense
from keras.utils import to_categorical
from keras_preprocessing.sequence import pad_sequences
......@@ -57,14 +56,14 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
output_word = word
break
seed_text += " " + output_word
return seed_text.title()
return seed_text.capitalize()
def main():
should_train = True
nb_epoch = 100
max_sequence_len = 61 # TODO: Test different default
model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
# model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
tokenizer = Tokenizer()
if should_train:
......@@ -79,21 +78,24 @@ def main():
model.summary()
model.fit(predictors, label, epochs=nb_epoch, verbose=5)
model.save(model_file)
else:
model = load_model(model_file)
# model.save(model_file)
# else: # FIXME: Load and predict
# model = load_model(model_file)
for sample in ["",
"L'étoile du sol",
"Elle me l'a toujours dit",
"Les punchlines sont pour ceux"]:
nb_words = 50
nb_words = 200
print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
with open("../output/lstm.txt", "a") as f:
while True:
input_text = input("> ")
print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
print(text)
f.writelines(text)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment