feat(Dawa): Seeds

parent 818b32b4
...@@ -130,3 +130,6 @@ dmypy.json ...@@ -130,3 +130,6 @@ dmypy.json
# IDE # IDE
.idea/ .idea/
# Outputs
output/
import os import os
import string import string
from pprint import pprint
from random import choice, randint
from numpy.random import seed from numpy.random import seed
from tensorflow_core.python.framework.random_seed import set_random_seed from tensorflow_core.python.framework.random_seed import set_random_seed
...@@ -11,7 +13,9 @@ def load_kawa(root="./"): ...@@ -11,7 +13,9 @@ def load_kawa(root="./"):
seed(1) seed(1)
data_dir = root + 'data/' data_dir = root + 'data/'
all_lines = [] all_lines = []
for filename in os.listdir(data_dir): files = os.listdir(data_dir)
print("%i files in data folder." % len(files))
for filename in files:
with open(data_dir + filename) as f: with open(data_dir + filename) as f:
content = f.readlines() content = f.readlines()
all_lines.extend(content) all_lines.extend(content)
...@@ -23,6 +27,19 @@ def load_kawa(root="./"): ...@@ -23,6 +27,19 @@ def load_kawa(root="./"):
return all_lines return all_lines
def load_seeds(kawa=None, nb_seeds=10):
if kawa is None:
kawa = load_kawa()
seeds = []
for i in range(nb_seeds):
plain_kawa = filter(lambda k: k != "\n", kawa)
chosen = choice(list(plain_kawa))
split = chosen.split(" ")
nb_words = randint(1, len(split))
seeds.append(split[:nb_words])
return seeds
def clean_text(lines): def clean_text(lines):
""" """
In dataset preparation step, we will first perform text cleaning of the data In dataset preparation step, we will first perform text cleaning of the data
...@@ -37,6 +54,8 @@ def main(): ...@@ -37,6 +54,8 @@ def main():
lines = load_kawa("../") lines = load_kawa("../")
clean = clean_text(lines) clean = clean_text(lines)
print(clean) print(clean)
print("Some seeds:\n\n")
pprint(load_seeds(lines))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -7,7 +7,7 @@ from keras.utils import to_categorical ...@@ -7,7 +7,7 @@ from keras.utils import to_categorical
from keras_preprocessing.sequence import pad_sequences from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.text import Tokenizer from keras_preprocessing.text import Tokenizer
from KoozDawa.dawa.loader import load_kawa, clean_text from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
from KoozDawa.dawa.tokens import get_sequence_of_tokens from KoozDawa.dawa.tokens import get_sequence_of_tokens
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
...@@ -61,35 +61,31 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0 ...@@ -61,35 +61,31 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
def main(): def main():
should_train = True should_train = True
nb_epoch = 100
max_sequence_len = 61 # TODO: Test different default
# model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
nb_epoch = 100
nb_words = 200
tokenizer = Tokenizer() tokenizer = Tokenizer()
if should_train: # if should_train:
lines = load_kawa() lines = load_kawa()
corpus = [clean_text(x) for x in lines] corpus = [clean_text(x) for x in lines]
print("Corpus:", corpus[:2]) print("Corpus:", corpus[:5])
inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer) inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words) predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
model = create_model(max_sequence_len, total_words) model = create_model(max_sequence_len, total_words)
model.summary() model.summary()
model.fit(predictors, label, epochs=nb_epoch, verbose=5) model.fit(predictors, label, epochs=nb_epoch, verbose=5)
# model.save(model_file) # model.save(model_file)
# else: # FIXME: Load and predict # else: # FIXME: Load and predict
# model = load_model(model_file) # model = load_model(model_file)
for sample in ["", for sample in load_seeds(lines):
"L'étoile du sol",
"Elle me l'a toujours dit",
"Les punchlines sont pour ceux"]:
nb_words = 200
print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len)) print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
with open("../output/lstm.txt", "a") as f: with open("./output/lstm.txt", "a+") as f:
while True: while True:
input_text = input("> ") input_text = input("> ")
text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len) text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment