refactor(boulbi): leverage the LisSansTaMaman refactoring

parent 8e6f7594
@@ -19,6 +19,8 @@ def tweet():
     # ice cubes for broken eyes
     # I'm not just a verb that you observe
     # if you know nothing, no soul of life
     Tweeper("KoozDawa").tweet("tassepés en panel")
 from datetime import datetime

 from keras.callbacks import ModelCheckpoint, EarlyStopping

 from glossolalia.loader import load_seeds, load_text
-from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
-from glossolalia.tokens import PoemTokenizer
+from glossolalia.lstm import LisSansTaMaman


-def main():
+def train():
     # should_train = True
     # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
     nb_words = 20
     nb_epoch = 50
     nb_layers = 64
     dropout = .2

-    tokenizer = PoemTokenizer()
+    # TODO finetune layers/dropout
+    validation_split = 0.2
+    lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)

+    filename_model = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
+        nb_layers, dropout, nb_epoch)
+    filename_output = "./output/boulbi_%i-d%.1f_%s.txt" % (
+        nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
+    callbacks_list = [
+        ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True),
+        EarlyStopping(monitor='val_accuracy', patience=5)]

     # if should_train:
     corpus = load_text()
     print("Corpus:", corpus[:10])
+    lstm.create_model(corpus)

+    with open(filename_output, "a+") as f:
+        for i in range(0, nb_epoch, 10):
+            lstm.fit(epochs=min(i + 10, nb_epoch), initial_epoch=i,
+                     callbacks=callbacks_list,
+                     validation_split=validation_split)
+            print(lstm.predict_seeds(nb_words))
-    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
-    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
-    model.summary()
-    file_path = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
-    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
-    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
-    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
-    callbacks_list = [checkpoint, early_stopping]
-    for i in range(0, nb_epoch, 10):
-        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
-        for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
-            print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len))
     # model.save(model_file)
     # else:  # FIXME: Load and predict, maybe reuse checkpoints?
     #     model = load_model(model_file)

-    for i, seed in enumerate(load_seeds(corpus, 5)):
-        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
-        print("%i %s -> %s" % (i, seed, output))
+        for i, seed in enumerate(load_seeds(corpus, 5)):
+            output = lstm.predict(seed, nb_words)
+            print("%i %s -> %s" % (i, seed, output))
+            f.writelines(output)

-    with open("./output/boulbi.txt", "a+") as f:
         while True:
             input_text = input("> ")
-            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
+            text = lstm.predict(input_text, nb_words)
             print(text)
             f.writelines("%s\n" % text)


 def debug_unrandomize():
     from numpy.random import seed
     from tensorflow_core.python.framework.random_seed import set_random_seed

     # set seeds for reproducibility
     set_random_seed(2)
     seed(1)


 if __name__ == '__main__':
     debug_unrandomize()
-    main()
+    train()
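
The driver now trains in 10-epoch slices (initial_epoch/epochs) so it can print sample generations between slices, while ModelCheckpoint(..., period=10, save_best_only=True) snapshots only improving models and EarlyStopping(patience=5) aborts a stalled run. Everything the old main() wired together by hand (tokenizer, padded sequences, model) moves behind the LisSansTaMaman object. Its real implementation lives in glossolalia; the skeleton below is only an inference from the call sites above, reusing the helpers the old imports named:

    # Hypothetical skeleton of LisSansTaMaman, inferred from the call sites
    # in this diff -- not the actual glossolalia source.
    from typing import List

    from glossolalia.lstm import create_model, generate_padded_sequences, generate_text
    from glossolalia.tokens import PoemTokenizer


    class LisSansTaMaman(object):
        """Bundles tokenizer, padded sequences and the Keras model behind one object."""

        def __init__(self, nb_layers, dropout, validation_split, debug=False):
            self.nb_layers = nb_layers
            self.dropout = dropout
            self.validation_split = validation_split
            self.debug = debug
            self.tokenizer = PoemTokenizer()
            self.model = None

        def create_model(self, corpus):
            # The same pipeline the old main() ran inline: tokenize, pad, build.
            inp_sequences, total_words = self.tokenizer.get_sequence_of_tokens(corpus)
            self.predictors, self.label, self.max_sequence_len = generate_padded_sequences(
                inp_sequences, total_words)
            self.model = create_model(self.max_sequence_len, total_words,
                                      layers=self.nb_layers, dropout=self.dropout)

        def fit(self, epochs, initial_epoch=0, callbacks=None, validation_split=0.0):
            self.model.fit(self.predictors, self.label,
                           initial_epoch=initial_epoch, epochs=epochs,
                           validation_split=validation_split,
                           callbacks=callbacks, verbose=2)

        def predict(self, seed, nb_words):
            return generate_text(self.model, self.tokenizer, seed,
                                 nb_words, self.max_sequence_len)

        def predict_seeds(self, nb_words=None, seeds: List[str] = None):
            if seeds is None:
                seeds = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]
            return [self.predict(seed, nb_words) for seed in seeds]
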
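debug_unrandomize reaches into tensorflow_core.python.framework, a private module path from the TF 2.0 package split that breaks on other versions. A more portable equivalent, assuming TensorFlow 2.x and seeding Python's own RNG as well, would be:

    import random

    import numpy as np
    import tensorflow as tf


    def debug_unrandomize():
        # Seed the three RNGs a Keras run typically draws from.
        random.seed(0)
        np.random.seed(1)
        tf.random.set_seed(2)
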
@@ -56,7 +56,7 @@ class LisSansTaMaman(object):
                           validation_split=validation_split,
                           epochs=epochs, initial_epoch=initial_epoch)

-    def predict_seeds(self, seeds: List[str] = None, nb_words=None):
+    def predict_seeds(self, nb_words=None, seeds: List[str] = None):
         if seeds is None:
             seeds = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]
         return [self.predict(seed, nb_words) for seed in seeds]
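
Swapping the first two parameters of predict_seeds is what lets the new call site above pass the word count positionally as lstm.predict_seeds(nb_words). Note that such a reorder silently rebinds any existing positional callers, so the keyword form is the safe one; a small illustration with hypothetical stand-ins:

    def predict_seeds_old(seeds=None, nb_words=None): ...
    def predict_seeds_new(nb_words=None, seeds=None): ...

    predict_seeds_new(20)           # binds nb_words=20, matching the new call site
    predict_seeds_old(20)           # would have bound seeds=20 instead
    predict_seeds_new(nb_words=20)  # reads the same under either order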