from keras.callbacks import ModelCheckpoint, EarlyStopping

from glossolalia.loader import load_seeds, load_text
from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
from glossolalia.tokens import PoemTokenizer


def main():
    # should_train = True
    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
    nb_words = 20  # number of words to generate per sample
    nb_epoch = 50  # total number of training epochs
    nb_layers = 64  # LSTM size, forwarded to create_model(layers=...)
    dropout = .2  # dropout rate, forwarded to create_model(dropout=...)
    tokenizer = PoemTokenizer()

    # if should_train:
    corpus = load_text()
    print("Corpus:", corpus[:10])

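    # Tokenize the corpus into n-gram sequences, then pad them to a uniform
    # length and split into (predictors, label) pairs for the LSTM.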
    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
    model.summary()

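    # Checkpoint filename template: the %-formatting fills in the static
    # hyperparameters now; Keras substitutes {epoch:02d} and {accuracy:.4f}
    # each time a checkpoint is saved.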
    file_path = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
    # Save every 10 epochs at most, and only on improvement. ('period' is
    # deprecated in newer Keras in favor of save_freq.)
    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
    # Stop once training accuracy stalls for 5 epochs; note that the patience
    # counter resets at each fit() call in the chunked loop below.
    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
    callbacks_list = [checkpoint, early_stopping]

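    # Train in 10-epoch chunks: initial_epoch/epochs keep the epoch counter
    # continuous across fit() calls, and between chunks we sample text from a
    # few French seed words to eyeball progress.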
    for i in range(0, nb_epoch, 10):
        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
        for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
            print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len))

    # model.save(model_file)
    # else: # FIXME: Load and predict, maybe reuse checkpoints?
    # model = load_model(model_file)

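    # Final sanity check: generate from a handful of seeds drawn from the corpus.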
    for i, seed in enumerate(load_seeds(corpus, 5)):
        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
        print("%i %s -> %s" % (i, seed, output))

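    # Interactive prompt: generate from user input and append each result to
    # the output file.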
    with open("./output/boulbi.txt", "a+") as f:
        while True:
            try:
                input_text = input("> ")
            except (EOFError, KeyboardInterrupt):
                break  # leave the prompt loop cleanly on Ctrl-D / Ctrl-C
            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)

            print(text)
            f.write("%s\n" % text)  # write() rather than writelines(): text is a single string


def debug_unrandomize():
    from numpy.random import seed
    from tensorflow.random import set_seed  # public TF2 API, replacing the private tensorflow_core path

    # set seeds for reproducibility
    set_seed(2)
    seed(1)


if __name__ == '__main__':
    debug_unrandomize()
    main()