added defaultTokenizer
Arnaud Bergeron authored and Arnaud Bergeron committed Sep 6, 2023
1 parent d110664 commit ac256a9
Showing 3 changed files with 47 additions and 10 deletions.
DatasetTokenizer.py (14 additions, 4 deletions)
@@ -1,12 +1,22 @@
-from EquationTokenizer import EquationTokenizer
+from EquationTokenizer import EquationTokenizer, defaultTokenizer
 
 class DatasetTokenizer(EquationTokenizer):
-    def __init__(self, x, y):
+    def __init__(self, x, y, char_set_=None):
         self.x = x
         self.y = y
         super().__init__()
-        self.char_set = self.create_set_char()
-        self.create_tokenizer(self.char_set)
+        if char_set_ is None:
+            self.char_set = self.create_set_char()
+        else:
+            self.char_set = char_set_
+
+        if self.char_set <= set(defaultTokenizer()[0].keys()):
+            print('Using default tokenizer.')
+            self.tokenize_dict, self.decode_dict, self.tokenize, self.decode = defaultTokenizer()
+            self.dict_size = len(self.tokenize_dict)
+
+        else:
+            self.create_tokenizer(self.char_set)
 
         self.x_tokenized = [self.tokenize(i) for i in self.x]
         self.y_tokenized = [self.tokenize(i) for i in self.y]
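To see how the new constructor path behaves end to end, here is a minimal usage sketch. The sample sequences and variable names are illustrative, not from the repository, and it assumes create_set_char simply collects the symbols appearing in x and y, all of which are covered by the default vocabulary here.

import sympy as sp
from DatasetTokenizer import DatasetTokenizer

# Toy symbol sequences whose symbols all exist in defaultTokenizer()'s
# vocabulary, so the constructor should print 'Using default tokenizer.'
x = [[sp.Symbol('x'), '(', '2', ')']]
y = [[sp.sin, '(', sp.Symbol('x'), ')']]

dt = DatasetTokenizer(x, y)
print(dt.dict_size)        # 56 entries in the shared default vocabulary
print(dt.x_tokenized[0])   # START/END-wrapped token ids for the first sequence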
EquationCreator.py (8 additions, 4 deletions)
@@ -1,6 +1,6 @@
 import sympy as sp
 import random
-
+import numba
 
 
 class EquationCreator:
@@ -21,7 +21,7 @@ def __init__(self, order=3):
             sp.acoth, sp.asech, sp.acsch
         ]
 
-
+    @numba.jit()
     def generate_random_function(self):
         #Generate a random function that will be the solution to the DE
 
@@ -39,11 +39,15 @@ def generate_random_function(self):
         function = function**exponent
 
         # Generate a random coefficient for the function
-        coefficient = random.randint(1, 10)
+        coefficient_p = random.randint(1, 10)
+        coefficient_q = random.randint(1, 10)
+        coefficient = sp.Rational(coefficient_p, coefficient_q)
+
         function = coefficient * function
 
         return function
 
-
+    @numba.jit()
     def generate_random_differential_equation(self, function, num_op = 3):
         # Generate a random differential equation with solution: function
+
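The coefficient change can be checked in isolation. The following standalone sketch reproduces the new draw outside the class, using the same 1 to 10 bounds:

import random
import sympy as sp

# Rational coefficient p/q instead of a bare integer; sympy reduces the
# fraction automatically, e.g. Rational(4, 8) becomes 1/2.
coefficient_p = random.randint(1, 10)
coefficient_q = random.randint(1, 10)
coefficient = sp.Rational(coefficient_p, coefficient_q)
print(coefficient)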
EquationTokenizer.py (25 additions, 2 deletions)
@@ -150,12 +150,16 @@ def sympy_to_tokens(self, sympy_eq):
 
     def tokens_to_sympy(self, tokens):
         """Takes in a tokenized list and outputs a sympy equation."""
+        decoded_seq = self.tokens_to_list(tokens)
+        seq = self.list_to_sympy(decoded_seq)
+        return seq
+
+    def tokens_to_list(self, tokens):
         if self.tokenize is None:
             raise('Tokenizer not created yet.')
         decoded_seq = self.decode(tokens)
         decoded_seq = [i for i in decoded_seq if i not in ['START','END','PAD']]
-        seq = self.list_to_sympy(decoded_seq)
-        return seq
+        return decoded_seq
 
     def create_tokenizer(self, symbol_set):
         """Takes a set of symbols and creates a tokenizer for them."""
@@ -198,3 +202,22 @@ def tensorize_and_pad_by_len(self, list_of_token_list, max_len):
         output = pad_sequence(list_of_token_list, batch_first=True, padding_value=pad_val)
 
         return output[:-1]
+
+def defaultTokenizer():
+    """Returns a default tokenizer. Because of issues with pickling."""
+    tokenize_dict = {')': 0, sp.acsc: 1, sp.acot: 2, sp.asech: 3, sp.core.containers.Tuple: 4, '/': 5, sp.sech: 6,
+                     'END': 7, sp.exp: 8, '7': 9, '0': 10, sp.asin: 11, '5': 12, sp.core.function.Derivative: 13,
+                     '8': 14, sp.asec: 15, sp.core.add.Add: 16, sp.core.power.Pow: 17, sp.csch: 18, 'START': 19, sp.csc: 20, 'PAD': 21, sp.sin: 22,
+                     ',': 23, sp.acsch: 24, sp.core.relational.Equality: 25, '(': 26, '2': 27, sp.Symbol('x'): 28, sp.coth: 29, sp.Symbol('y'): 30, sp.log: 31, sp.cos: 32,
+                     '6': 33, sp.core.mul.Mul: 34, sp.acos: 35, '9': 36, sp.Function('f'): 37, '-': 38, sp.sqrt: 39, sp.cosh: 40, sp.tan: 41, sp.tanh: 42, sp.Symbol('z'): 43,
+                     '4': 44, '3': 45, sp.cot: 46, sp.asinh: 47, sp.atan: 48, sp.acosh: 49, '1': 50, sp.atanh: 51, '.': 52, sp.sinh: 53, sp.acoth: 54, sp.sec: 55}
+
+    #invert tokenizer_dict into decode_dict
+    decode_dict = {v: k for k, v in tokenize_dict.items()}
+
+    tokenize = lambda x: [tokenize_dict['START']] + [tokenize_dict[i] for i in x] + [tokenize_dict['END']]
+    decode = lambda x: [decode_dict[i] for i in x]
+
+    dict_size = len(tokenize_dict)
+
+    return tokenize_dict, decode_dict, tokenize, decode
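A round-trip sketch of the new module-level helper follows; it assumes EquationTokenizer.py already imports sympy as sp, which the dictionary keys require:

import sympy as sp
from EquationTokenizer import defaultTokenizer

tokenize_dict, decode_dict, tokenize, decode = defaultTokenizer()

# Encode a prefix-style symbol list and decode it back.
seq = [sp.sin, '(', sp.Symbol('x'), ')']
ids = tokenize(seq)        # [19, 22, 26, 28, 0, 7] with START=19 and END=7
print(decode(ids))         # ['START', sin, '(', x, ')', 'END']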
