from datetime import datetime

from keras.callbacks import ModelCheckpoint, EarlyStopping

from glossolalia.loader import load_seeds, load_texts
from glossolalia.lstm import LisSansTaMaman
from glossolalia.tokens import PoemTokenizer


def train():
    """Train the LSTM text generator on the Verne corpus and sample from it.

    Trains in 10-epoch rounds, printing and appending generated samples to a
    timestamped output file after each round, then generates from corpus
    seeds, and finally enters an interactive prompt loop (blocks forever on
    stdin).  No parameters; no return value.  Side effects: writes model
    checkpoints under ../models/verne/ and text under ./output/.
    """
    # should_train = True
    nb_words = 200  # words to generate per sample
    nb_epoch = 100  # total training epochs
    nb_layers = 100
    dropout = .3  # TODO fine-tune layers/dropout
    validation_split = 0.1
    lstm = LisSansTaMaman(nb_layers, dropout, validation_split,
                          tokenizer=PoemTokenizer(lower=False),
                          debug=True)
    # Checkpoint path encodes the hyper-parameters; {epoch}/{accuracy} are
    # template fields filled in by Keras at save time, not by the % operator.
    filename_model = "../models/verne/verne_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
        nb_layers, dropout, nb_epoch)
    filename_output = "./output/verne_%i-d%.1f_%s.txt" % (
        nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
    callbacks_list = [
        ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True),
        EarlyStopping(monitor='val_accuracy', patience=5)]

    corpus = load_texts()
    print("Corpus:", corpus[:10])
    lstm.create_model(corpus)

    with open(filename_output, "a+") as f:
        # Train in 10-epoch rounds, sampling generated text between rounds.
        for i in range(0, nb_epoch, 10):
            lstm.fit(epochs=min(i + 10, nb_epoch), initial_epoch=i,
                     callbacks=callbacks_list,
                     validation_split=validation_split)
            for output in lstm.predict_seeds(nb_words):
                print(output)
                # BUG FIX: was f.writelines(output) — writelines() given a
                # single string writes it char-by-char with no newline, so
                # consecutive samples ran together in the output file.
                f.write("%s\n" % output)

        # Generate once from seeds drawn out of the corpus.
        for i, seed in enumerate(load_seeds(corpus, 5)):
            output = lstm.predict(seed, nb_words)
            print("%i %s -> %s" % (i, seed, output))
            f.write("%s\n" % output)

        # Interactive loop: generate from user-supplied prompts until killed.
        while True:
            input_text = input("> ")
            text = lstm.predict(input_text, nb_words)
            print(text)
            f.write("%s\n" % text)


if __name__ == '__main__':
    train()