refactor(dawa): Generalize LSTM/Tweeper

parent fbbea615
from datetime import datetime
from keras.callbacks import ModelCheckpoint, EarlyStopping
from glossolalia.loader import load_seeds, load_text
from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
from glossolalia.tokens import PoemTokenizer
from glossolalia.lstm import LisSansTaMaman
def main():
def train():
# should_train = True
# model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
nb_words = 20
nb_epoch = 50
nb_layers = 64
dropout = .2
tokenizer = PoemTokenizer()
nb_epoch = 100
nb_layers = 100
dropout = .3 # TODO finetune layers/dropout
validation_split = 0.2
lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
nb_layers, dropout, nb_epoch)
filename_output = "./output/dawa_%i-d%.1f_%s.txt" % (
nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
callbacks_list = [
ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True),
EarlyStopping(monitor='val_accuracy', patience=5)]
# if should_train:
corpus = load_text()
print("Corpus:", corpus[:10])
lstm.create_model(corpus[:1000])
with open(filename_output, "a+") as f:
for i in range(0, nb_epoch, 10):
lstm.fit(epochs=min(i + 10, nb_epoch), initial_epoch=i,
callbacks=callbacks_list,
validation_split=validation_split)
inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
model.summary()
file_path = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
# print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
early_stopping = EarlyStopping(monitor='accuracy', patience=5)
callbacks_list = [checkpoint, early_stopping]
for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
print(lstm.predict(seed, nb_words))
for i in range(0, nb_epoch, 10):
model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]:
print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len))
# model.save(model_file)
# else: # FIXME: Load and predict, maybe reuse checkpoints?
# model = load_model(model_file)
# model.save(model_file)
# else: # FIXME: Load and predict, maybe reuse checkpoints?
# model = load_model(model_file)
for i, seed in enumerate(load_seeds(corpus, 5)):
output = lstm.predict(seed, nb_words)
print("%i %s -> %s" % (i, seed, output))
f.writelines(output)
for i, seed in enumerate(load_seeds(corpus, 5)):
output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
print("%i %s -> %s" % (i, seed, output))
with open("./output/dawa.txt", "a+") as f:
while True:
input_text = input("> ")
text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
text = lstm.predict(input_text, nb_words)
print(text)
f.writelines("%s\n" % text)
def debug_unrandomize():
from numpy.random import seed
from tensorflow_core.python.framework.random_seed import set_random_seed
# set seeds for reproducibility
set_random_seed(2)
seed(1)
if __name__ == '__main__':
debug_unrandomize()
main()
train()
from glossolalia.tweeper import Tweeper
def tweet():
# La nuit est belle, ma chérie salue sur la capuche
# grands brûlés de la chine
# Femme qui crame strasbourg
# le soleil est triste
# on a pas un martyr parce qu't'es la
# des neiges d'insuline
# une hypothèse qu'engendre la haine n'est qu'une prison vide
# Un jour de l'an commencé sur les autres
# Relater l'passionnel dans les casseroles d'eau de marécages
# sniff de Caravage rapide
# La nuit c'est le soleil
# Les rues d'ma vie se terminent par des partouzes de ciel
# des glaçons pour les yeux brisées
Tweeper("KoozDawa").tweet("tassepés en panel")
if __name__ == '__main__':
tweet()
import warnings
from typing import List
import numpy as np
from keras import Sequential
from keras import Sequential, Model
from keras.callbacks import Callback, History
from keras.layers import Embedding, LSTM, Dropout, Dense
from keras.utils import to_categorical
from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.text import Tokenizer
from glossolalia.tokens import PoemTokenizer
warnings.filterwarnings("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)
# 3.3 Padding the Sequences and obtain Variables : Predictors and Target
def generate_padded_sequences(input_sequences, total_words):
def debug_unrandomize():
from numpy.random import seed
from tensorflow_core.python.framework.random_seed import set_random_seed
# set seeds for reproducibility
set_random_seed(2)
seed(1)
class LisSansTaMaman(object):
""" A LSTM model adapted for french lyrical texts."""
def __init__(self, nb_layers: int = 100,
dropout: float = 0.1, validation_split: float = 0.0,
tokenizer=PoemTokenizer(),
debug: bool = False):
self.validation_split = validation_split
self.dropout = dropout
self.nb_layers = nb_layers
self.tokenizer = tokenizer
# Model state
self.model: Model = None
self.predictors = None
self.label = None
self.max_sequence_len = None
if debug:
debug_unrandomize()
def create_model(self, corpus: List[str]):
inp_sequences, total_words = self.tokenizer.get_sequence_of_tokens(corpus)
self.predictors, self.label, self.max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
model = create_model(self.max_sequence_len, total_words, layers=self.nb_layers, dropout=self.dropout)
model.summary()
self.model = model
# TODO: Batch fit? splitting nb_epoch into N step
def fit(self, epochs: int, initial_epoch: int = 0,
callbacks: List[Callback] = None,
validation_split: float = 0
) -> History:
return self.model.fit(self.predictors, self.label,
verbose=2,
callbacks=callbacks,
validation_split=validation_split,
epochs=epochs, initial_epoch=initial_epoch)
def predict(self, seed="", nb_words=None):
if nb_words is None:
nb_words = 20 # TODO: Guess based on model a good number of words
return generate_text(self.model, self.tokenizer, seed, nb_words, self.max_sequence_len)
def generate_padded_sequences(input_sequences, total_words: int):
max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
predictors, label = input_sequences[:, :-1], input_sequences[:, -1]
......@@ -20,7 +79,7 @@ def generate_padded_sequences(input_sequences, total_words):
return predictors, label, max_sequence_len
def create_model(max_sequence_len, total_words, layers=128, dropout=0.3): # TODO finetune layers/dropout
def create_model(max_sequence_len: int, total_words: int, layers: int, dropout: float):
print("Creating model across %i words for %i-long seqs (%i layers, %.2f dropout):" %
(total_words, max_sequence_len, layers, dropout))
input_len = max_sequence_len - 1
......
#! /usr/bin/env python
import os
import time
import tweepy
from didyoumean3.didyoumean import did_you_mean
from tweepy import Cursor
class Tweeper(object):
def __init__(self):
def __init__(self, name: str):
auth = tweepy.OAuthHandler(
os.environ["ZOO_DAWA_KEY"],
os.environ["ZOO_DAWA_KEY_SECRET"])
......@@ -15,24 +16,18 @@ class Tweeper(object):
os.environ["ZOO_DAWA_TOKEN"],
os.environ["ZOO_DAWA_TOKEN_SECRET"])
self.api = tweepy.API(auth)
self.name = name
def tweet(self, message):
"""Tweets a message after spellchecking it."""
message = did_you_mean(message)
print("About to tweet:", message)
time.sleep(5)
self.api.update_status(message)
def main():
Tweeper().tweet("le business réel de la saint-valentin")
# Nous la nuit de la renaissance j’étais la tête
@property
def all_tweets(self):
return [t.text for t in Cursor(self.api.user_timeline, id=self.name).items()]
# Authenticate to Twitter
# tassepés en panel
# grands brûlés de la chine
# La nuit est belle, ma chérie salue sur la capuche
# Je suis pas étonné de dire pétrin
# Femme qui crame strasbourg
if __name__ == '__main__':
main()
def tweet(self, message, wait_delay=5, prevent_duplicate=True):
"""Tweets a message after spellchecking it."""
if prevent_duplicate and message in self.all_tweets:
print("Was already tweeted: %s." % message)
else:
message = did_you_mean(message)
print("About to tweet:", message)
time.sleep(wait_delay)
self.api.update_status(message)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment