refactor(dawa): Generalize LSTM/Tweeper

parent fbbea615
from datetime import datetime

from keras.callbacks import ModelCheckpoint, EarlyStopping

from glossolalia.loader import load_seeds, load_text
from glossolalia.lstm import LisSansTaMaman


def train():
    # should_train = True
    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
    nb_words = 20
    nb_epoch = 100
    nb_layers = 100
    dropout = .3  # TODO finetune layers/dropout
    validation_split = 0.2
    lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)

    filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
        nb_layers, dropout, nb_epoch)
    filename_output = "./output/dawa_%i-d%.1f_%s.txt" % (
        nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
    callbacks_list = [
        ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True),
        EarlyStopping(monitor='val_accuracy', patience=5)]

    corpus = load_text()
    print("Corpus:", corpus[:10])
    lstm.create_model(corpus[:1000])

    with open(filename_output, "a+") as f:
        for i in range(0, nb_epoch, 10):
            lstm.fit(epochs=min(i + 10, nb_epoch), initial_epoch=i,
                     callbacks=callbacks_list,
                     validation_split=validation_split)
            for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
                print(lstm.predict(seed, nb_words))
            # model.save(model_file)
            # else:  # FIXME: Load and predict, maybe reuse checkpoints?
            # model = load_model(model_file)

            for i, seed in enumerate(load_seeds(corpus, 5)):
                output = lstm.predict(seed, nb_words)
                print("%i %s -> %s" % (i, seed, output))
                f.writelines(output)

        while True:
            input_text = input("> ")
            text = lstm.predict(input_text, nb_words)
            print(text)
            f.writelines("%s\n" % text)


if __name__ == '__main__':
    train()
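The commented-out FIXME above asks about loading a saved checkpoint instead of retraining. A minimal, hypothetical sketch of how that could look with the refactored class, assuming keras.models.load_model can read the .hdf5 checkpoints written by ModelCheckpoint (the helper name and the prediction seed below are illustrative, not part of the project):

from keras.models import load_model

from glossolalia.loader import load_text
from glossolalia.lstm import LisSansTaMaman


def predict_from_checkpoint(checkpoint_path: str, nb_words: int = 20):
    # Rebuild the tokenizer state (sequences, max_sequence_len) that predict() relies on...
    lstm = LisSansTaMaman(nb_layers=100, dropout=.3)
    lstm.create_model(load_text()[:1000])
    # ...then swap in the trained weights from the saved checkpoint.
    lstm.model = load_model(checkpoint_path)
    return lstm.predict("Je", nb_words)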
from glossolalia.tweeper import Tweeper


def tweet():
    # La nuit est belle, ma chérie salue sur la capuche
    # grands brûlés de la chine
    # Femme qui crame strasbourg
    # le soleil est triste
    # on a pas un martyr parce qu't'es la
    # des neiges d'insuline
    # une hypothèse qu'engendre la haine n'est qu'une prison vide
    # Un jour de l'an commencé sur les autres
    # Relater l'passionnel dans les casseroles d'eau de marécages
    # sniff de Caravage rapide
    # La nuit c'est le soleil
    # Les rues d'ma vie se terminent par des partouzes de ciel
    # des glaçons pour les yeux brisées
    Tweeper("KoozDawa").tweet("tassepés en panel")


if __name__ == '__main__':
    tweet()
# glossolalia/lstm.py
import warnings
from typing import List

import numpy as np
from keras import Sequential, Model
from keras.callbacks import Callback, History
from keras.layers import Embedding, LSTM, Dropout, Dense
from keras.utils import to_categorical
from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.text import Tokenizer

from glossolalia.tokens import PoemTokenizer

warnings.filterwarnings("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)


def debug_unrandomize():
    from numpy.random import seed
    from tensorflow_core.python.framework.random_seed import set_random_seed

    # set seeds for reproducibility
    set_random_seed(2)
    seed(1)


class LisSansTaMaman(object):
    """An LSTM model adapted for French lyrical texts."""

    def __init__(self, nb_layers: int = 100,
                 dropout: float = 0.1, validation_split: float = 0.0,
                 tokenizer=PoemTokenizer(),
                 debug: bool = False):
        self.validation_split = validation_split
        self.dropout = dropout
        self.nb_layers = nb_layers
        self.tokenizer = tokenizer

        # Model state
        self.model: Model = None
        self.predictors = None
        self.label = None
        self.max_sequence_len = None

        if debug:
            debug_unrandomize()

    def create_model(self, corpus: List[str]):
        inp_sequences, total_words = self.tokenizer.get_sequence_of_tokens(corpus)
        self.predictors, self.label, self.max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
        model = create_model(self.max_sequence_len, total_words, layers=self.nb_layers, dropout=self.dropout)
        model.summary()
        self.model = model

    # TODO: Batch fit? Splitting nb_epoch into N steps
    def fit(self, epochs: int, initial_epoch: int = 0,
            callbacks: List[Callback] = None,
            validation_split: float = 0
            ) -> History:
        return self.model.fit(self.predictors, self.label,
                              verbose=2,
                              callbacks=callbacks,
                              validation_split=validation_split,
                              epochs=epochs, initial_epoch=initial_epoch)

    def predict(self, seed="", nb_words=None):
        if nb_words is None:
            nb_words = 20  # TODO: Guess a good number of words based on the model
        return generate_text(self.model, self.tokenizer, seed, nb_words, self.max_sequence_len)


# 3.3 Padding the sequences to obtain the variables: predictors and target
def generate_padded_sequences(input_sequences, total_words: int):
    max_sequence_len = max([len(x) for x in input_sequences])
    input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
    predictors, label = input_sequences[:, :-1], input_sequences[:, -1]
    # ... (collapsed in the diff: @@ -20,7 +79,7 @@)
    return predictors, label, max_sequence_len


def create_model(max_sequence_len: int, total_words: int, layers: int, dropout: float):
    print("Creating model across %i words for %i-long seqs (%i layers, %.2f dropout):" %
          (total_words, max_sequence_len, layers, dropout))
    input_len = max_sequence_len - 1
    # ... (rest of the file collapsed in the diff)
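The rest of create_model (and the generate_text helper used by predict) is collapsed in this diff. Going only by the imports above (Sequential, Embedding, LSTM, Dropout, Dense), a typical word-level architecture for such a function might look like the following sketch; the layer sizes and compile settings are assumptions, not the project's actual code:

from keras import Sequential
from keras.layers import Embedding, LSTM, Dropout, Dense


def create_model_sketch(max_sequence_len: int, total_words: int, layers: int, dropout: float):
    # Hypothetical reconstruction: embed each token, run one LSTM layer,
    # apply dropout, and predict the next word with a softmax over the vocabulary.
    input_len = max_sequence_len - 1
    model = Sequential()
    model.add(Embedding(total_words, 10, input_length=input_len))
    model.add(LSTM(layers))
    model.add(Dropout(dropout))
    model.add(Dense(total_words, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model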
#! /usr/bin/env python
# glossolalia/tweeper.py
import os
import time

import tweepy
from didyoumean3.didyoumean import did_you_mean
from tweepy import Cursor


class Tweeper(object):
    def __init__(self, name: str):
        auth = tweepy.OAuthHandler(
            os.environ["ZOO_DAWA_KEY"],
            os.environ["ZOO_DAWA_KEY_SECRET"])
        # ... (collapsed in the diff: @@ -15,24 +16,18 @@)
            os.environ["ZOO_DAWA_TOKEN"],
            os.environ["ZOO_DAWA_TOKEN_SECRET"])
        self.api = tweepy.API(auth)
        self.name = name

    @property
    def all_tweets(self):
        return [t.text for t in Cursor(self.api.user_timeline, id=self.name).items()]

    def tweet(self, message, wait_delay=5, prevent_duplicate=True):
        """Tweets a message after spellchecking it."""
        if prevent_duplicate and message in self.all_tweets:
            print("Was already tweeted: %s." % message)
        else:
            message = did_you_mean(message)
            print("About to tweet:", message)
            time.sleep(wait_delay)
            self.api.update_status(message)
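A short usage sketch of the generalized Tweeper, assuming the four ZOO_DAWA_* environment variables are set; the handle and message below are only examples taken from elsewhere in this commit:

from glossolalia.tweeper import Tweeper

bot = Tweeper("KoozDawa")  # the handle whose timeline is scanned for duplicates
# The call is skipped if the exact same text already appears on that timeline.
bot.tweet("des neiges d'insuline", wait_delay=0, prevent_duplicate=True)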