"""Train the text-generation model on the corpus, then generate text.

Running this script trains the model returned by
``glossolalia.lstm.create_model`` in chunks of epochs, prints sample
generations along the way, and finally opens an interactive prompt whose
outputs are appended to a text file.
"""
import os

from keras.callbacks import ModelCheckpoint, EarlyStopping

from glossolalia.loader import load_seeds, load_text
from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
from glossolalia.tokens import PoemTokenizer

# Seeds used to print sample generations after every training chunk.
_PREVIEW_SEEDS = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]

# Number of epochs per `model.fit` call; also the checkpoint period.
_EPOCHS_PER_CHUNK = 10

# File the interactive generations are appended to.
_OUTPUT_PATH = "./output/boulbi.txt"


def _train(model, tokenizer, predictors, label, nb_epoch, nb_words,
           max_sequence_len, callbacks_list):
    """Fit `model` in chunks of epochs, printing sample text after each chunk.

    Training stops as soon as a callback (e.g. `EarlyStopping`) has requested
    it, instead of starting the next chunk.
    """
    for first_epoch in range(0, nb_epoch, _EPOCHS_PER_CHUNK):
        model.fit(
            predictors,
            label,
            initial_epoch=first_epoch,
            epochs=min(first_epoch + _EPOCHS_PER_CHUNK, nb_epoch),
            verbose=2,
            callbacks=callbacks_list,
        )
        for seed in _PREVIEW_SEEDS:
            print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len))

        # Keras callbacks signal "stop" through `model.stop_training`; that
        # flag only ends the current `fit` call, so without this check the
        # loop would resume training at a later `initial_epoch`.
        if getattr(model, "stop_training", False):
            break


def _interactive_session(model, tokenizer, nb_words, max_sequence_len,
                         output_path=_OUTPUT_PATH):
    """Prompt for seed text until EOF/Ctrl-C, appending outputs to a file."""
    # Create the target directory up front so a missing folder does not
    # crash the script right after a long training run.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Explicit UTF-8: the generated text may contain accented characters and
    # must not depend on the platform's default encoding.
    with open(output_path, "a+", encoding="utf-8") as f:
        while True:
            try:
                input_text = input("> ")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C ends the session cleanly.
                print()
                break
            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
            print(text)
            f.write("%s\n" % text)
            # Flush per line so nothing is lost if the process is killed.
            f.flush()


def main():
    """Train the model on the loaded corpus and start the generation prompt."""
    # TODO: add a load-and-predict path (e.g. `load_model` on a saved
    # checkpoint) so that training is not mandatory on every run.
    nb_words = 20
    nb_epoch = 50
    nb_layers = 64
    dropout = .2
    tokenizer = PoemTokenizer()

    corpus = load_text()
    print("Corpus:", corpus[:10])
    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
    model.summary()

    # `{epoch}` and `{accuracy}` are left for ModelCheckpoint to fill in;
    # only the %-placeholders are resolved here.
    file_path = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
        nb_layers, dropout, nb_epoch)
    checkpoint = ModelCheckpoint(file_path, monitor='accuracy',
                                 period=_EPOCHS_PER_CHUNK, save_best_only=True)
    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
    callbacks_list = [checkpoint, early_stopping]

    _train(model, tokenizer, predictors, label, nb_epoch, nb_words,
           max_sequence_len, callbacks_list)

    # Sample generations from seeds drawn from the corpus itself.
    for i, seed in enumerate(load_seeds(corpus, 5)):
        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
        print("%i %s -> %s" % (i, seed, output))

    _interactive_session(model, tokenizer, nb_words, max_sequence_len)


def debug_unrandomize():
    """Seed the NumPy and TensorFlow random generators with fixed values."""
    from numpy.random import seed
    # NOTE(review): `tensorflow_core` is a version-specific private path;
    # confirm it matches the installed TensorFlow release.
    from tensorflow_core.python.framework.random_seed import set_random_seed

    # set seeds for reproducibility
    set_random_seed(2)
    seed(1)


if __name__ == '__main__':
    debug_unrandomize()
    main()