# verne.py
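"""Train a word-level LSTM on the Verne corpus and generate text from
seed words, both in scripted batches and interactively."""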
from keras.callbacks import ModelCheckpoint, EarlyStopping

from glossolalia.loader import load_seeds, load_text, load_texts
from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
from glossolalia.tokens import PoemTokenizer


def main():
    # should_train = True
    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
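    # Hyperparameters: words per generated sample, training epochs,
    # LSTM layer width, and dropout rate.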
    nb_words = 20
    nb_epoch = 100
    nb_layers = 128
    dropout = .3
    tokenizer = PoemTokenizer()

    # if should_train:
    corpus = load_texts()
    print("Corpus:", corpus[:10])

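    # Turn the corpus into token sequences, pad them to a common length,
    # and build the LSTM from the padded predictors.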
    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
    model.summary()

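    # In the checkpoint path template, {epoch:02d} and {accuracy:.4f} are
    # filled in by ModelCheckpoint at save time; the %-placeholders
    # (layers, dropout, epochs) are resolved now.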
    file_path = "../models/verne/verne_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
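    # NOTE: `period` (checkpoint every N epochs) still works with standalone
    # Keras imports, but is deprecated in tf.keras in favor of `save_freq`
    # (counted in batches).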
    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
    callbacks_list = [checkpoint, early_stopping]

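    # Train in 10-epoch chunks, printing a sample for each seed word after
    # every chunk to track how generation quality evolves.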
    for i in range(0, nb_epoch, 10):
        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
        for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
            print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len))

    # model.save(model_file)
    # else: # FIXME: Load and predict, maybe reuse checkpoints?
    # model = load_model(model_file)

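    # Draw 5 seeds from the corpus and print one completion for each.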
    for i, seed in enumerate(load_seeds(corpus, 5)):
        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
        print("%i %s -> %s" % (i, seed, output))

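    # Interactive mode: generate from user input until interrupted, logging
    # every generated line to the output file.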
    with open("./output/verne.txt", "a+") as f:
        while True:
            input_text = input("> ")
            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)

            print(text)
            f.write("%s\n" % text)


def debug_unrandomize():
    from numpy.random import seed
    import tensorflow as tf

    # Set the TensorFlow and NumPy seeds for reproducibility,
    # using the public API rather than the private
    # tensorflow_core.python.framework.random_seed module.
    tf.random.set_seed(2)
    seed(1)


if __name__ == '__main__':
    debug_unrandomize()
    main()