refactor(dawa): Move unrandomize to lstm

parent 42b38e3e
@@ -15,8 +15,7 @@ def train():
     dropout = .3  # TODO finetune layers/dropout
     validation_split = 0.2
     lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
-    filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
-        nb_layers, dropout, nb_epoch)
+    filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
     filename_output = "./output/dawa_%i-d%.1f_%s.txt" % (
         nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
     callbacks_list = [
@@ -32,12 +31,7 @@ def train():
                  callbacks=callbacks_list,
                  validation_split=validation_split)
-        for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
-            print(lstm.predict(seed, nb_words))
-
-        # model.save(model_file)
-        # else:  # FIXME: Load and predict, maybe reuse checkpoints?
-        #     model = load_model(model_file)
+        print(lstm.predict_seeds(nb_words))
     for i, seed in enumerate(load_seeds(corpus, 5)):
         output = lstm.predict(seed, nb_words)
...
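The checkpoint filename above goes through two formatting passes: the `%` interpolation fills `nb_layers`, `dropout` and `nb_epoch` when the string is built, while the brace placeholders `{epoch:02d}` and `{accuracy:.4f}` are left intact for Keras' `ModelCheckpoint` callback to fill each time a checkpoint is written. A minimal sketch of that interplay (the concrete values and the `monitor`/`save_best_only` options are illustrative assumptions, not taken from this commit):

```python
from tensorflow.keras.callbacks import ModelCheckpoint

nb_layers, dropout, nb_epoch = 3, 0.3, 100  # hypothetical values

# Pass 1: %-interpolation happens now; the {...} fields contain no % and are left untouched.
filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
    nb_layers, dropout, nb_epoch)
# -> "../models/dawa/dawa_lstm3-d0.3-{epoch:02d}_100-{accuracy:.4f}.hdf5"

# Pass 2: ModelCheckpoint substitutes the epoch number and logged metrics at save time,
# e.g. "../models/dawa/dawa_lstm3-d0.3-07_100-0.4213.hdf5" after epoch 7.
checkpoint = ModelCheckpoint(filename_model, monitor="accuracy", save_best_only=True)
```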
@@ -8,14 +8,16 @@ def tweet():
     # le soleil est triste
     # on a pas un martyr parce qu't'es la
     # des neiges d'insuline
-    # une hypothèse qu'engendre la haine n'est qu'une prison vide
     # Un jour de l'an commencé sur les autres
-    # Relater l'passionnel dans les casseroles d'eau de marécages
+    # une hypothèse qu'engendre la haine n'est qu'une prison vide
     # sniff de Caravage rapide
+    # Relater l'passionnel dans les casseroles d'eau de marécages
     # La nuit c'est le soleil
+    # Les rues d'ma vie se terminent par la cannelle
     # Les rues d'ma vie se terminent par des partouzes de ciel
     # des glaçons pour les yeux brisées
+    # je suis pas juste un verbe que t'observe
     Tweeper("KoozDawa").tweet("tassepés en panel")
...
@@ -15,15 +15,6 @@
 warnings.filterwarnings("ignore")
 warnings.simplefilter(action='ignore', category=FutureWarning)
 
-
-def debug_unrandomize():
-    from numpy.random import seed
-    from tensorflow_core.python.framework.random_seed import set_random_seed
-
-    # set seeds for reproducibility
-    set_random_seed(2)
-    seed(1)
-
 
 class LisSansTaMaman(object):
     """ A LSTM model adapted for french lyrical texts."""
@@ -65,6 +56,11 @@ class LisSansTaMaman(object):
                        validation_split=validation_split,
                        epochs=epochs, initial_epoch=initial_epoch)
 
+    def predict_seeds(self, seeds: List[str] = None, nb_words=None):
+        if seeds is None:
+            seeds = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]
+        return [self.predict(seed, nb_words) for seed in seeds]
+
     def predict(self, seed="", nb_words=None):
         if nb_words is None:
             nb_words = 20  # TODO: Guess based on model a good number of words
@@ -115,3 +111,12 @@ def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words
         predicted = model.predict_classes(token_list, verbose=2)[0]
         output += " " + word_indices[predicted]
     return output.capitalize()
+
+
+def debug_unrandomize():
+    from numpy.random import seed
+    from tensorflow_core.python.framework.random_seed import set_random_seed
+
+    # set seeds for reproducibility
+    set_random_seed(2)
+    seed(1)
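With this refactor, the seed loop in train() collapses into a single `predict_seeds` call, and `debug_unrandomize` now lives at the bottom of lstm.py so callers can pin the RNGs before building a model. A rough usage sketch under assumptions: the module is importable as `lstm`, the constructor arguments mirror the positional call in train(), and since `seeds` is the first parameter of `predict_seeds`, `nb_words` is passed by keyword here. On current TensorFlow 2.x, the public counterpart of the private `tensorflow_core` import is `tf.random.set_seed`.

```python
from lstm import LisSansTaMaman, debug_unrandomize  # module path assumed

# Pin numpy/tensorflow seeds before any weights are created, for reproducible runs.
debug_unrandomize()

lstm = LisSansTaMaman(3, 0.3, 0.2, debug=True)  # nb_layers, dropout, validation_split
# ... corpus loading and lstm.fit(...) as in train() ...

nb_words = 20
# `seeds` is the first positional parameter, so pass nb_words by keyword;
# a bare predict_seeds(nb_words) would bind the integer to `seeds`.
print(lstm.predict_seeds(nb_words=nb_words))

# A custom seed list works the same way:
print(lstm.predict_seeds(["Paris", "La nuit"], nb_words=nb_words))
```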