feat(tokens): Lowercase made optional

parent e8ebdd6c
@@ -8,13 +8,13 @@ from glossolalia.lstm import LisSansTaMaman
 def train():
     # should_train = True
-    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
     nb_words = 20
     nb_epoch = 100
     nb_layers = 100
     dropout = .3  # TODO finetune layers/dropout
     validation_split = 0.2
     lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
     filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
     filename_output = "./output/dawa_%i-d%.1f_%s.txt" % (
         nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
@@ -31,7 +31,9 @@ def train():
              callbacks=callbacks_list,
              validation_split=validation_split)
-    print(lstm.predict_seeds(nb_words))
+    for output in lstm.predict_seeds(nb_words):
+        print(output)
+        f.writelines(output)
     for i, seed in enumerate(load_seeds(corpus, 5)):
         output = lstm.predict(seed, nb_words)
...
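Note for reviewers: `f.writelines(output)` works here only because `writelines` accepts any iterable of strings and a `str` iterates character by character, so the call degenerates to `f.write(output)` with no newline appended. If the intent is one generated sequence per line in the output file, a sketch like the following (hypothetical, not part of this diff) would be more explicit:

    for output in lstm.predict_seeds(nb_words):
        print(output)
        f.write(output + "\n")  # one generated sequence per line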
@@ -11,10 +11,10 @@ def train():
     nb_words = 20
     nb_epoch = 50
     nb_layers = 64
-    dropout = .2
-    # TODO finetune layers/dropout
+    dropout = .2  # TODO finetune layers/dropout
     validation_split = 0.2
     lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
+    #
     filename_model = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
         nb_layers, dropout, nb_epoch)
     filename_output = "./output/boulbi_%i-d%.1f_%s.txt" % (
@@ -32,7 +32,9 @@ def train():
              callbacks=callbacks_list,
              validation_split=validation_split)
-    print(lstm.predict_seeds(nb_words))
+    for output in lstm.predict_seeds(nb_words):
+        print(output)
+        f.writelines(output)
     for i, seed in enumerate(load_seeds(corpus, 5)):
         output = lstm.predict(seed, nb_words)
...
@@ -2,6 +2,7 @@ from glossolalia import loader
 def clean(text):
+    # TODO: Remove lines with ???
     # Replace literal newlines
     # Remove empty lines
     # Replace ’ by '
...
@@ -44,6 +44,7 @@ class LisSansTaMaman(object):
         model.summary()
         self.model = model
+        print("Max sequence length:", self.max_sequence_len)

     # TODO: Batch fit? splitting nb_epoch into N step
     def fit(self, epochs: int, initial_epoch: int = 0,
...
@@ -4,10 +4,10 @@ from glossolalia.loader import load_texts
 class PoemTokenizer(Tokenizer):
-    def __init__(self, **kwargs) -> None:
-        super().__init__(lower=True,  # TODO: Better generalization without?
-                         filters='$%&()*+/<=>@[\\]^_`{|}~\t\n', oov_token="😢",
-                         **kwargs)
+    def __init__(self, lower: bool = True, **kwargs) -> None:
+        super().__init__(lower=lower,  # TODO: Better generalization without?
+                         filters='$%&*+/<=>@[\\]^_`{|}~\t\n', oov_token="😢",
+                         **kwargs)  # TODO: keep newlines

     def get_sequence_of_tokens(self, corpus):
         self.fit_on_texts(corpus)
...
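For context, a minimal usage sketch of the new `lower` parameter (the `glossolalia.tokenizer` module path is an assumption, it is not visible in this diff). Note also that dropping `(` and `)` from `filters` means parentheses are now kept inside tokens instead of being stripped:

    from glossolalia.tokenizer import PoemTokenizer  # assumed module path

    tokenizer = PoemTokenizer()         # unchanged default: input is lowercased
    cased = PoemTokenizer(lower=False)  # new: preserve case, per the TODO about generalization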