refactor(dawa): Move unrandomize to lstm

parent 42b38e3e
......@@ -15,8 +15,7 @@ def train():
dropout = .3 # TODO finetune layers/dropout
validation_split = 0.2
lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
nb_layers, dropout, nb_epoch)
filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
filename_output = "./output/dawa_%i-d%.1f_%s.txt" % (
nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
callbacks_list = [
......@@ -32,12 +31,7 @@ def train():
callbacks=callbacks_list,
validation_split=validation_split)
for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
print(lstm.predict(seed, nb_words))
# model.save(model_file)
# else: # FIXME: Load and predict, maybe reuse checkpoints?
# model = load_model(model_file)
print(lstm.predict_seeds(nb_words))
for i, seed in enumerate(load_seeds(corpus, 5)):
output = lstm.predict(seed, nb_words)
......
......@@ -8,14 +8,16 @@ def tweet():
# le soleil est triste
# on a pas un martyr parce qu't'es la
# des neiges d'insuline
# une hypothèse qu'engendre la haine n'est qu'une prison vide
# Un jour de l'an commencé sur les autres
# Relater l'passionnel dans les casseroles d'eau de marécages
# une hypothèse qu'engendre la haine n'est qu'une prison vide
# sniff de Caravage rapide
# Relater l'passionnel dans les casseroles d'eau de marécages
# La nuit c'est le soleil
# Les rues d'ma vie se terminent par la cannelle
# Les rues d'ma vie se terminent par des partouzes de ciel
# des glaçons pour les yeux brisées
# je suis pas juste un verbe que t'observe
Tweeper("KoozDawa").tweet("tassepés en panel")
......
......@@ -15,15 +15,6 @@ warnings.filterwarnings("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)
def debug_unrandomize():
    """Seed the TensorFlow and NumPy global random generators.

    Intended for debugging: fixed seeds (2 for TensorFlow, 1 for NumPy) are
    installed so that successive runs can be compared. Takes no arguments
    and returns nothing; its only effect is on the libraries' global state.
    """
    # Imported lazily so both libraries are only required when this debug
    # helper is actually called.
    import numpy as np
    import tensorflow as tf

    # ``tf.compat.v1.set_random_seed`` is the public export of the function
    # that used to be imported from the private ``tensorflow_core`` package
    # path; the public name does not depend on TensorFlow's internal layout.
    tf.compat.v1.set_random_seed(2)
    np.random.seed(1)
class LisSansTaMaman(object):
""" A LSTM model adapted for french lyrical texts."""
......@@ -65,6 +56,11 @@ class LisSansTaMaman(object):
validation_split=validation_split,
epochs=epochs, initial_epoch=initial_epoch)
def predict_seeds(self, seeds: List[str] = None, nb_words=None):
    """Generate one text per seed and return the results as a list.

    :param seeds: starting strings handed to ``predict`` one by one. When
        omitted (``None``), a built-in list of nine French sentence openers
        (including the empty string) is used.
    :param nb_words: forwarded unchanged to ``predict`` for every seed.
        Because ``seeds`` is the first positional parameter, pass this one
        by keyword when relying on the default seeds.
    :return: list of ``predict`` results, in the same order as the seeds.
    """
    prompts = seeds if seeds is not None else [
        "", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous",
    ]
    generated = []
    for prompt in prompts:
        generated.append(self.predict(prompt, nb_words))
    return generated
def predict(self, seed="", nb_words=None):
if nb_words is None:
nb_words = 20 # TODO: Guess based on model a good number of words
......@@ -115,3 +111,12 @@ def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_word
predicted = model.predict_classes(token_list, verbose=2)[0]
output += " " + word_indices[predicted]
return output.capitalize()
def debug_unrandomize():
    """Seed the TensorFlow and NumPy global random generators.

    Intended for debugging: fixed seeds (2 for TensorFlow, 1 for NumPy) are
    installed so that successive runs can be compared. Takes no arguments
    and returns nothing; its only effect is on the libraries' global state.
    """
    # Imported lazily so both libraries are only required when this debug
    # helper is actually called.
    import numpy as np
    import tensorflow as tf

    # ``tf.compat.v1.set_random_seed`` is the public export of the function
    # that used to be imported from the private ``tensorflow_core`` package
    # path; the public name does not depend on TensorFlow's internal layout.
    tf.compat.v1.set_random_seed(2)
    np.random.seed(1)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment