feat(lstm): Capitalize, vary tokenization

parent 7a97532c
@@ -28,8 +28,8 @@ def clean_text(lines):
     In dataset preparation step, we will first perform text cleaning of the data
     which includes removal of punctuations and lower casing all the words.
     """
-    lines = "".join(v for v in lines if v not in string.punctuation).lower()
-    lines = lines.encode("utf8").decode("ascii", 'ignore')
+    lines = "".join(v for v in lines if v not in string.punctuation)
+    # lines = lines.encode("utf8").decode("ascii", 'ignore')
     return lines
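Note: a minimal sketch of the cleaning step after this change, assuming clean_text keeps exactly the shape shown in the hunk. Punctuation is stripped, while case and accented characters are now preserved (the ASCII round-trip is commented out).

import string

def clean_text(lines):
    # Strip punctuation only; no .lower(), no ASCII transliteration.
    lines = "".join(v for v in lines if v not in string.punctuation)
    return lines

print(clean_text("L'étoile du sol, Elle me l'a toujours dit!"))
# -> Létoile du sol Elle me la toujours dit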
@@ -2,7 +2,6 @@ import warnings
 import numpy as np
 from keras import Sequential
-from keras.engine.saving import load_model
 from keras.layers import Embedding, LSTM, Dropout, Dense
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
@@ -57,14 +56,14 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
             output_word = word
             break
         seed_text += " " + output_word
-    return seed_text.title()
+    return seed_text.capitalize()


 def main():
     should_train = True
     nb_epoch = 100
     max_sequence_len = 61  # TODO: Test different default
-    model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch

     tokenizer = Tokenizer()
     if should_train:
@@ -79,21 +78,24 @@ def main():
         model.summary()
         model.fit(predictors, label, epochs=nb_epoch, verbose=5)
-        model.save(model_file)
-    else:
-        model = load_model(model_file)
+        # model.save(model_file)
+    # else:  # FIXME: Load and predict
+    #     model = load_model(model_file)

     for sample in ["",
                    "L'étoile du sol",
                    "Elle me l'a toujours dit",
                    "Les punchlines sont pour ceux"]:
-        nb_words = 50
+        nb_words = 200
         print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))

-    while True:
-        input_text = input("> ")
-        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
+    with open("../output/lstm.txt", "a") as f:
+        while True:
+            input_text = input("> ")
+            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
+            print(text)
+            f.writelines(text)


 if __name__ == '__main__':
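Note on the new logging block: f.writelines(text) receives a single string, so it writes it character by character with no trailing newline, and successive generations end up concatenated in ../output/lstm.txt. A minimal alternative sketch using the same names as the diff; the flush call is an added assumption so an interrupted session still keeps its output.

with open("../output/lstm.txt", "a") as f:
    while True:
        input_text = input("> ")
        text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
        print(text)
        f.write(text + "\n")  # one generation per line in the log
        f.flush()             # assumption: persist output even if the session is interrupted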