import os
from datetime import datetime

from keras.callbacks import ModelCheckpoint, EarlyStopping

from glossolalia.loader import load_seeds, load_texts
from glossolalia.lstm import LisSansTaMaman
from glossolalia.tokens import PoemTokenizer


def train(nb_layers=100, dropout=.3, nb_epoch=100, nb_words=200,
          validation_split=0.1, epochs_per_round=10):
    """Train the Verne LSTM and log generated samples as training progresses.

    Trains in rounds of ``epochs_per_round`` epochs; after each round, text
    generated from the model's seeds is printed and appended to an output
    file. After training, 5 corpus-derived seeds are sampled, then the
    function drops into an interactive prompt generating text from user input.

    :param nb_layers: number of LSTM layers/units passed to LisSansTaMaman.
    :param dropout: dropout rate.  # TODO fine-tune layers/dropout
    :param nb_epoch: total number of training epochs.
    :param nb_words: number of words to generate per prediction.
    :param validation_split: fraction of data held out for validation.
    :param epochs_per_round: epochs trained between sampling rounds.
    """
    lstm = LisSansTaMaman(nb_layers, dropout, validation_split,
                          tokenizer=PoemTokenizer(lower=False), debug=True)

    # Checkpoint path embeds hyper-parameters plus Keras epoch/accuracy fields.
    filename_model = "../models/verne/verne_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
        nb_layers, dropout, nb_epoch)
    filename_output = "./output/verne_%i-d%.1f_%s.txt" % (
        nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
    # `open(..., "a+")` fails if the directory is missing; create it up front.
    os.makedirs(os.path.dirname(filename_output), exist_ok=True)

    callbacks_list = [
        # NOTE(review): `period` is deprecated in newer Keras in favor of
        # `save_freq`, but their semantics differ (epochs vs batches) —
        # left as-is; confirm against the installed Keras version.
        ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True),
        EarlyStopping(monitor='val_accuracy', patience=5)]

    corpus = load_texts()
    print("Corpus:", corpus[:10])
    lstm.create_model(corpus)

    with open(filename_output, "a+") as f:
        # Train in rounds so we can sample generated text mid-training.
        for i in range(0, nb_epoch, epochs_per_round):
            lstm.fit(epochs=min(i + epochs_per_round, nb_epoch), initial_epoch=i,
                     callbacks=callbacks_list,
                     validation_split=validation_split)

            for output in lstm.predict_seeds(nb_words):
                print(output)
                # was f.writelines(output): wrote the string with no newline
                f.write("%s\n" % output)

        # Sample from 5 seeds drawn from the corpus.
        for i, seed in enumerate(load_seeds(corpus, 5)):
            output = lstm.predict(seed, nb_words)
            print("%i %s -> %s" % (i, seed, output))
            f.write("%s\n" % output)

        # Interactive loop; exit cleanly on Ctrl-D / Ctrl-C instead of
        # crashing with an unhandled EOFError/KeyboardInterrupt.
        try:
            while True:
                input_text = input("> ")
                text = lstm.predict(input_text, nb_words)
                print(text)
                f.write("%s\n" % text)
        except (EOFError, KeyboardInterrupt):
            pass


# Script entry point: run training with the default hyper-parameters.
if __name__ == "__main__":
    train()