from keras_preprocessing.text import Tokenizer

from glossolalia.loader import load_texts


class PoemTokenizer(Tokenizer):
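    """Keras Tokenizer preconfigured for poem corpora.

    Compared to the Keras default filter list, this one keeps the
    characters !"#,-.:;? so sentence punctuation survives inside the
    tokens, and out-of-vocabulary words are mapped to the "😢" token.
    """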
    def __init__(self, **kwargs) -> None:
        super().__init__(lower=True,  # TODO: Better generalization without?
                         filters='$%&()*+/<=>@[\\]^_`{|}~\t\n', oov_token="😢",
                         **kwargs)

    def get_sequence_of_tokens(self, corpus):
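        """Fit the tokenizer on the corpus and build n-gram training sequences.

        Every line [t1, ..., tn] contributes the incremental prefixes
        [t1, t2], [t1, t2, t3], ..., [t1, ..., tn] -- the usual inputs
        for training a next-word language model.

        Returns the list of sequences and the vocabulary size.
        """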
        self.fit_on_texts(corpus)
        total_words = len(self.word_index) + 1  # +1 for the reserved index 0

        # Convert each line into its incremental n-gram sequences
        input_sequences = []

        for line in corpus:
            token_list = self.texts_to_sequences([line])[0]
            for i in range(1, len(token_list)):
                n_gram_sequence = token_list[:i + 1]
                input_sequences.append(n_gram_sequence)

        # Debug aid: decode only a few sequences to check the round trip,
        # instead of decoding the entire corpus just for this print.
        texts = self.sequences_to_texts(input_sequences[:5])
        print("Tokenized:", texts)

        return input_sequences, total_words

    def get_text(self, sequences):
        """Decode a list of token sequences back into whitespace-joined text."""
        return self.sequences_to_texts(sequences)


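# A minimal sketch (not part of the original class; the name is illustrative)
# of the typical next step: pad the variable-length n-gram sequences to a
# uniform length and split them into predictors and next-word labels for a
# language model. pad_sequences ships in the same keras_preprocessing package
# as Tokenizer.
def to_padded_training_data(input_sequences):
    from keras_preprocessing.sequence import pad_sequences

    max_sequence_len = max(len(seq) for seq in input_sequences)
    # Pre-padding keeps the label (the last token) at the end of each row.
    padded = pad_sequences(input_sequences, maxlen=max_sequence_len,
                           padding='pre')
    predictors, labels = padded[:, :-1], padded[:, -1]
    return predictors, labels, max_sequence_len

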
if __name__ == '__main__':
    kawa = load_texts("../")
    tokenizer = PoemTokenizer()
    seqs, total_words = tokenizer.get_sequence_of_tokens(kawa)
    texts = tokenizer.get_text(seqs)
    print("Decoded sample:", texts[:3])
    print("%i words." % total_words)
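
    # Illustrative usage of the sketch above, not part of the original script:
    predictors, labels, max_len = to_padded_training_data(seqs)
    print("Padded %i sequences to length %i." % (len(labels), max_len))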