refactor(lstm): Less debug

parent 3da1e4f1
@@ -49,7 +49,7 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
     for _ in range(nb_words):
        token_list = tokenizer.texts_to_sequences([seed_text])[0]
        token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-       predicted = model.predict_classes(token_list, verbose=0)
+       predicted = model.predict_classes(token_list, verbose=2)
        output_word = ""
        for word, index in tokenizer.word_index.items():
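
Worth noting for anyone running this on a recent stack: predict_classes only exists on Sequential models in older Keras/TensorFlow releases and was removed in TensorFlow 2.6. A minimal sketch of the usual replacement, assuming model ends in a softmax layer as it does in this project:

    import numpy as np

    # same effect as the removed Sequential.predict_classes():
    # take the most probable class index from the softmax output
    probabilities = model.predict(token_list, verbose=0)
    predicted = np.argmax(probabilities, axis=-1)
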
@@ -63,22 +63,18 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
 def main():
     should_train = True
     nb_epoch = 100
-    max_sequence_len = 61 # TODO: Test different default
     model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    max_sequence_len = 5 # TODO: Test different default
     tokenizer = Tokenizer()
     if should_train:
         lines = load_kawa()
         corpus = [clean_text(x) for x in lines]
-        print(corpus[:10])
+        print("Corpus:", corpus[:2])
         inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
-        print(inp_sequences[:10])
         predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-        print(predictors, label, max_sequence_len)
         model = create_model(max_sequence_len, total_words)
         model.summary()
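
Note that the hard-coded max_sequence_len default only matters in the load-model branch: when should_train is True it is overwritten by the value returned from generate_padded_sequences. That helper is not part of this diff; as a rough sketch of what this step usually looks like in a Keras LSTM text generator of this kind (the names, imports, and shapes below are assumptions, not the repository's code):

    import numpy as np
    from keras.preprocessing.sequence import pad_sequences
    from keras.utils import to_categorical

    def generate_padded_sequences(input_sequences, total_words):
        # pad every n-gram prefix to the longest length, then split off the
        # last token of each sequence as the word the model must predict
        max_sequence_len = max(len(seq) for seq in input_sequences)
        padded = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
        predictors, label = padded[:, :-1], padded[:, -1]
        label = to_categorical(label, num_classes=total_words)
        return predictors, label, max_sequence_len
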
@@ -87,12 +83,17 @@ def main():
     else:
         model = load_model(model_file)
-    for sample in ["", "L'étoile ", "Elle ", "Les punchlines "]:
-        print(generate_text(model, tokenizer, sample, 100, max_sequence_len))
+    for sample in ["",
+                   "L'étoile du sol",
+                   "Elle me l'a toujours dit",
+                   "Les punchlines sont pour ceux"]:
+        nb_words = 50
+        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
     while True:
         input_text = input("> ")
-        print(generate_text(model, tokenizer, input_text, 100, max_sequence_len))
+        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
+        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
 if __name__ == '__main__':
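
Two small things stand out in the new sample loop: nb_words = 50 is assigned inside the for body but reused by the interactive while loop (this works because a Python loop body is not a separate scope, though it would raise a NameError if the sample list were ever empty), and the generate_text call inside the while loop appears twice in this hunk, so each prompt is answered twice. A sketch of the tidier shape, based only on the code above:

    nb_words = 50  # words to generate per prompt

    for sample in ["", "L'étoile du sol", "Elle me l'a toujours dit", "Les punchlines sont pour ceux"]:
        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))

    while True:
        input_text = input("> ")
        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
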
@@ -11,8 +11,6 @@ def get_sequence_of_tokens(corpus, tokenizer=Tokenizer()):
     # convert data to sequence of tokens
     input_sequences = []
-    # FIXME Debug: truncate corpus
-    corpus = corpus[:50]
     for line in corpus:
         token_list = tokenizer.texts_to_sequences([line])[0]
         for i in range(1, len(token_list)):
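
With the debug truncation gone, the sequences are now built from the full corpus. The body of the inner loop is cut off in this view; in this kind of n-gram language-model preprocessing it typically collects every prefix of the tokenized line. A sketch under that assumption (not the repository's verbatim code):

    for line in corpus:
        token_list = tokenizer.texts_to_sequences([line])[0]
        for i in range(1, len(token_list)):
            # every prefix of length i+1 becomes one training example;
            # its last token is what the model will learn to predict
            n_gram_sequence = token_list[:i + 1]
            input_sequences.append(n_gram_sequence)
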