feat(kawa): Save/generate every 10 epochs

parent d7694554
@@ -3,7 +3,7 @@ import warnings
 import numpy as np
 from keras import Sequential
 from keras.callbacks import ModelCheckpoint, EarlyStopping
-from keras.layers import Embedding, LSTM, Dropout, Dense
+from keras.layers import Embedding, LSTM, Dropout, Dense, Bidirectional
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
@@ -19,14 +19,14 @@ warnings.simplefilter(action='ignore', category=FutureWarning)
 def generate_padded_sequences(input_sequences, total_words):
     max_sequence_len = max([len(x) for x in input_sequences])
     input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
+    print("Max len:", max_sequence_len)
     predictors, label = input_sequences[:, :-1], input_sequences[:, -1]
     label = to_categorical(label, num_classes=total_words)
     return predictors, label, max_sequence_len


-def create_model(max_sequence_len, total_words, layers=128, dropout=0.2): # TODO finetune layers/dropout
+def create_model(max_sequence_len, total_words, layers=128, dropout=0.3): # TODO finetune layers/dropout
+    print("Creating model across %i words for %i-long seqs (%i layers, %.2f dropout):" %
+          (total_words, max_sequence_len, layers, dropout))
     input_len = max_sequence_len - 1
     model = Sequential()
@@ -35,6 +35,7 @@ def create_model(max_sequence_len, total_words, layers=128, dropout=0.2): # TOD
     # Add Hidden Layer 1 - LSTM Layer
     model.add(LSTM(layers))
+    # model.add(Bidirectional(LSTM(layers), input_shape=(max_sequence_len, total_words)))
     model.add(Dropout(dropout))

     # Add Output Layer
@@ -81,20 +82,21 @@ def main():
     model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
     model.summary()

-    file_path = "../models/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
-    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', save_best_only=True)
+    file_path = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
     # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
     early_stopping = EarlyStopping(monitor='accuracy', patience=5)
     callbacks_list = [checkpoint, early_stopping]

-    for i in range(nb_epoch):
-        model.fit(predictors, label, initial_epoch=i, epochs=i + 1, verbose=2, callbacks=callbacks_list)
+    for i in range(0, nb_epoch, 10):
+        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
+        print(generate_text(model, tokenizer, "", nb_words, max_sequence_len))
     # model.save(model_file)
     # else: # FIXME: Load and predict, maybe reuse checkpoints?
     #     model = load_model(model_file)

-    for i, seed in enumerate(load_seeds(lines, 3)):
+    for i, seed in enumerate(load_seeds(lines, 5)):
         output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
         print("%i %s -> %s" % (i, seed, output))