refactor(boulbi): leverage the LisSansTaMaman refactoring

parent 8e6f7594
@@ -19,6 +19,8 @@ def tweet():
     # des glaçons pour les yeux brisées
     # je suis pas juste un verbe que t'observe
+    # si tu sais rien pas d'âme de la vie
     Tweeper("KoozDawa").tweet("tassepés en panel")
...
+from datetime import datetime
 from keras.callbacks import ModelCheckpoint, EarlyStopping
 from glossolalia.loader import load_seeds, load_text
-from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
-from glossolalia.tokens import PoemTokenizer


-def main():
+def train():
     # should_train = True
-    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
     nb_words = 20
     nb_epoch = 50
     nb_layers = 64
     dropout = .2
-    tokenizer = PoemTokenizer()
+    # TODO finetune layers/dropout
+    validation_split = 0.2
+    lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
+
+    filename_model = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
+        nb_layers, dropout, nb_epoch)
+    filename_output = "./output/boulbi_%i-d%.1f_%s.txt" % (
+        nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
+    callbacks_list = [
+        ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True),
+        EarlyStopping(monitor='val_accuracy', patience=5)]

-    # if should_train:
     corpus = load_text()
     print("Corpus:", corpus[:10])
+    lstm.create_model(corpus)

-    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
-    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
-    model.summary()
-
-    file_path = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
-    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
-    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
-    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
-    callbacks_list = [checkpoint, early_stopping]
-
-    for i in range(0, nb_epoch, 10):
-        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
-        for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
-            print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len))
-    # model.save(model_file)
-    # else: # FIXME: Load and predict, maybe reuse checkpoints?
-    #     model = load_model(model_file)
+    with open(filename_output, "a+") as f:
+        for i in range(0, nb_epoch, 10):
+            lstm.fit(epochs=min(i + 10, nb_epoch), initial_epoch=i,
+                     callbacks=callbacks_list,
+                     validation_split=validation_split)
+            print(lstm.predict_seeds(nb_words))

-    for i, seed in enumerate(load_seeds(corpus, 5)):
-        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
-        print("%i %s -> %s" % (i, seed, output))
-
-    with open("./output/boulbi.txt", "a+") as f:
-        while True:
-            input_text = input("> ")
-            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
-            print(text)
-            f.writelines("%s\n" % text)
-
-
-def debug_unrandomize():
-    from numpy.random import seed
-    from tensorflow_core.python.framework.random_seed import set_random_seed
-
-    # set seeds for reproducibility
-    set_random_seed(2)
-    seed(1)
+        for i, seed in enumerate(load_seeds(corpus, 5)):
+            output = lstm.predict(seed, nb_words)
+            print("%i %s -> %s" % (i, seed, output))
+            f.writelines(output)
+
+        while True:
+            input_text = input("> ")
+            text = lstm.predict(input_text, nb_words)
+            print(text)
+            f.writelines("%s\n" % text)


 if __name__ == '__main__':
-    debug_unrandomize()
-    main()
+    train()
...
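The refactored script leans entirely on the LisSansTaMaman wrapper, whose definition is not part of this commit (only its predict_seeds signature changes below). As context, a minimal sketch of the interface train() relies on, inferred purely from the call sites above; the stubbed bodies and attribute names are assumptions, not the actual glossolalia/lstm.py code:

    # Interface sketch only: inferred from the call sites in train(),
    # not copied from glossolalia/lstm.py. Bodies are intentionally stubbed.
    from typing import List

    class LisSansTaMaman(object):
        def __init__(self, nb_layers: int, dropout: float,
                     validation_split: float, debug: bool = False):
            # Hyperparameters kept for create_model(); field names assumed.
            self.nb_layers = nb_layers
            self.dropout = dropout
            self.validation_split = validation_split
            self.debug = debug
            self.model = None

        def create_model(self, corpus) -> None:
            # Tokenizes the corpus, pads sequences and builds the LSTM
            # (the work generate_padded_sequences/create_model did before).
            ...

        def fit(self, epochs: int, initial_epoch: int = 0,
                callbacks=None, validation_split: float = 0.0) -> None:
            # Delegates to keras' Model.fit on the prepared sequences.
            ...

        def predict(self, seed: str, nb_words: int) -> str:
            # Generates nb_words of text continuing the seed
            # (what generate_text did before).
            ...

        def predict_seeds(self, nb_words=None, seeds: List[str] = None):
            # Actual implementation shown in the hunk below.
            if seeds is None:
                seeds = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]
            return [self.predict(seed, nb_words) for seed in seeds]

Note that train() keeps the old design of fitting in 10-epoch slices (initial_epoch=i, epochs=min(i + 10, nb_epoch)) and sampling from the model between slices, so generation quality can be inspected in filename_output as training progresses.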
@@ -56,7 +56,7 @@ class LisSansTaMaman(object):
                        validation_split=validation_split,
                        epochs=epochs, initial_epoch=initial_epoch)

-    def predict_seeds(self, seeds: List[str] = None, nb_words=None):
+    def predict_seeds(self, nb_words=None, seeds: List[str] = None):
         if seeds is None:
             seeds = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]
         return [self.predict(seed, nb_words) for seed in seeds]
...
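The argument swap in predict_seeds is a real fix, not churn: train() above passes the word count positionally as lstm.predict_seeds(nb_words). Worked through from the hunk, that call behaves as follows under each signature:

    lstm.predict_seeds(20)
    # old order, predict_seeds(seeds=20, nb_words=None):
    #   "if seeds is None" is False, so the list comprehension iterates
    #   over 20 and raises TypeError: 'int' object is not iterable
    # new order, predict_seeds(nb_words=20, seeds=None):
    #   the default seed list ["", "Je", "Tu", ...] is used, as intended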