Merge pull request #35 from arnaudbergeron/l2_loss
L2 loss
arnaudbergeron authored Jan 31, 2024
2 parents af761ae + e5aa1f8 commit 01b6ab0
Showing 18 changed files with 479 additions and 166 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -52,5 +52,9 @@ coverage.xml
*.log
*.pot

.vscode

# Sphinx documentation
docs/_build/

.vscode
149 changes: 149 additions & 0 deletions NumGI/ConstantDictionaries.py
@@ -0,0 +1,149 @@
from __future__ import annotations

import sympy as sp
import torch

SP_TO_TORCH = {
sp.sin: torch.sin,
sp.cos: torch.cos,
sp.tan: torch.tan,
sp.exp: torch.exp,
sp.log: torch.log,
sp.asin: torch.asin,
sp.acos: torch.acos,
sp.atan: torch.atan,
sp.sinh: torch.sinh,
sp.cosh: torch.cosh,
sp.tanh: torch.tanh,
sp.asinh: torch.asinh,
sp.acosh: torch.acosh,
sp.atanh: torch.atanh,
sp.Mul: torch.mul,
sp.Add: torch.add,
sp.Pow: torch.pow,
sp.Abs: torch.abs,
sp.cot: lambda x: torch.divide(1, torch.tan(x)),
sp.acot: lambda x: torch.atan(torch.divide(1, x)),
sp.sec: lambda x: torch.divide(1, torch.cos(x)),
sp.asec: lambda x: torch.acos(torch.divide(1, x)),
sp.csc: lambda x: torch.divide(1, torch.sin(x)),
sp.acsc: lambda x: torch.asin(torch.divide(1, x)),
sp.coth: lambda x: torch.divide(1, torch.tanh(x)),
sp.acoth: lambda x: torch.atanh(torch.divide(1, x)),
sp.sech: lambda x: torch.divide(1, torch.cosh(x)),
sp.asech: lambda x: torch.log(
torch.add(torch.divide(1, x), torch.sqrt(torch.sub(torch.pow(torch.divide(1, x), 2), 1)))
),
sp.csch: lambda x: torch.divide(1, torch.sinh(x)),
sp.acsch: lambda x: torch.log(
torch.add(torch.divide(1, x), torch.sqrt(torch.add(torch.pow(torch.divide(1, x), 2), 1)))
),
}

DIFFERENTIAL_FUNCTIONS = [
sp.sin,
sp.cos,
sp.tan,
sp.exp,
sp.log,
sp.asin,
sp.acos,
sp.atan,
sp.sinh,
sp.cosh,
sp.tanh,
sp.asinh,
sp.acosh,
sp.atanh,
sp.cot,
sp.acot,
sp.sec,
sp.asec,
sp.csc,
sp.acsc,
sp.coth,
sp.acoth,
sp.sech,
sp.asech,
sp.csch,
sp.acsch,
]

OPERATIONS = [
("multiplication", "arithmetic"),
("addition", "arithmetic"),
("subtraction", "arithmetic"),
("division", "arithmetic"),
("differential", "differential"),
# ("integration", "integration"),
("exponent", "exponent"),
]

VARIABLES = ["x", "y", "z", "beta", "gamma"]

DEFAULT_DICT = {
")": 0,
sp.acsc: 1,
sp.acot: 2,
sp.asech: 3,
sp.core.containers.Tuple: 4,
"/": 5,
sp.sech: 6,
"END": 7,
sp.exp: 8,
"7": 9,
"0": 10,
sp.asin: 11,
"5": 12,
sp.core.function.Derivative: 13,
"8": 14,
sp.asec: 15,
sp.core.add.Add: 16,
sp.core.power.Pow: 17,
sp.csch: 18,
"START": 19,
sp.csc: 20,
"PAD": 21,
sp.sin: 22,
",": 23,
sp.acsch: 24,
sp.core.relational.Equality: 25,
"(": 26,
"2": 27,
sp.Symbol("x"): 28,
sp.coth: 29,
sp.Symbol("y"): 30,
sp.log: 31,
sp.cos: 32,
"6": 33,
sp.core.mul.Mul: 34,
sp.acos: 35,
"9": 36,
sp.Function("f"): 37,
"-": 38,
sp.sqrt: 39,
sp.cosh: 40,
sp.tan: 41,
sp.tanh: 42,
sp.Symbol("z"): 43,
"4": 44,
"3": 45,
sp.cot: 46,
sp.asinh: 47,
sp.atan: 48,
sp.acosh: 49,
"1": 50,
sp.atanh: 51,
".": 52,
sp.sinh: 53,
sp.acoth: 54,
sp.sec: 55,
sp.Symbol("beta"): 56,
sp.Symbol("gamma"): 57,
sp.Symbol("delta"): 58,
sp.Symbol("a"): 59,
sp.Symbol("b"): 60,
sp.Symbol("c"): 61,
sp.Symbol("d"): 62,
sp.Symbol("epsilon"): 63,
}
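
A quick way to sanity-check the SP_TO_TORCH mapping introduced above is to compare one of its reciprocal-based entries against the SymPy original. The snippet below is a minimal sketch and is not part of this commit; it only assumes that NumGI.ConstantDictionaries is importable as added here.

import sympy as sp
import torch

from NumGI.ConstantDictionaries import SP_TO_TORCH

# torch has no sec(), so the dictionary maps sp.sec to 1 / cos(x).
x = torch.tensor(0.5)
torch_val = SP_TO_TORCH[sp.sec](x)
sympy_val = float(sp.sec(sp.Float("0.5")))
assert torch.isclose(torch_val, torch.tensor(sympy_val))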
167 changes: 91 additions & 76 deletions NumGI/EquationTokenizer.py
@@ -7,6 +7,9 @@
from sympy.core.numbers import Rational
from torch.nn.utils.rnn import pad_sequence

from NumGI.ConstantDictionaries import DEFAULT_DICT
from NumGI.ConstantDictionaries import SP_TO_TORCH


class EquationTokenizer:
"""Tokenizer for equations.
@@ -30,6 +33,13 @@ def __init__(self, useDefaultTokenizer=False):
self.dict_size = len(self.tokenize_dict)
self.char_set = set(self.tokenize_dict.keys())

if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"

def sympy_to_list(self, sympy_equation) -> list:
"""Converts a sympy equation to a list that will be tokenized.
@@ -53,11 +63,7 @@ def sympy_to_list(self, sympy_equation) -> list:
for ind, arg in enumerate(eq_args):
sub_arg_list = self.sympy_to_list(arg)
for _sub in sub_arg_list:
if (
isinstance(_sub, Float)
or isinstance(_sub, Integer)
or isinstance(_sub, Rational)
):
if self.is_number(_sub):
# perhaps not general enough should allow for more types
                    # the idea is we want to tokenize '12.2' as '1','2','.','2' and not '12.2'
for i in str(_sub):
@@ -72,6 +78,65 @@

return eq_list

def sympy_to_numpy(self, sympy_equation):
"""Converts a sympy equation to a numpy function.
This is a util func.
"""
symbols = list(sympy_equation.free_symbols)
return sp.lambdify(symbols, sympy_equation, "numpy"), symbols

def sympy_to_torch(self, sympy_equation):
"""Converts a sympy equation to a pytorch function.
This is a util func.
"""
# simplified_eq = sympy_equation.simplify()
simplified_eq = sympy_equation
sympy_list = self.sympy_to_list(simplified_eq)
grouped_num_list = self._regroup_numbers(sympy_list)
parsed_list = self._parantheses_to_list(grouped_num_list)[0][0]

variables = list(simplified_eq.free_symbols)
variables = [str(i) for i in variables]

def torch_func(**kwargs):
return self._utils_exec_torch(parsed_list, **kwargs)

return torch_func, variables

def _utils_exec_torch(self, sympy_list, **kwargs):
"""Converts a sympy list to a torch function.
This is a util func.
"""
function = sympy_list[0]
torch_function = SP_TO_TORCH[function]
args_list = []
for i in sympy_list[1:]:
if isinstance(i, list):
args_list.append(self._utils_exec_torch(i, **kwargs))
elif isinstance(i, sp.Symbol):
args_list.append(kwargs[str(i)])
elif self.is_number(i):
args_list.append(torch.tensor(float(i), device=self.device))
elif i == sp.core.numbers.Pi:
args_list.append(torch.tensor(torch.pi, device=self.device))
else:
raise ValueError(f"Unknown type: {type(i)}, for {i}")

if len(args_list) > 2 and (function == sp.Add or function == sp.Mul):
return self.call_multi_input_torch(torch_function, args_list)

else:
return torch_function(*args_list)

def call_multi_input_torch(self, func, args):
if len(args) > 2:
return func(args[0], self.call_multi_input_torch(func, args[1:]))
else:
return func(args[0], args[1])

def _parantheses_to_list(self, eq_list):
"""Converts a list with parentheses to a list of lists according to parentheses.
@@ -208,93 +273,43 @@ def tensorize_and_pad(self, list_of_token_list):
"""Takes in a list of tokenized lists and outputs a padded tensor of tensors."""
pad_val = self.tokenize_dict["PAD"]

list_of_token_list = [torch.tensor(i) for i in list_of_token_list]

list_of_token_list = [torch.tensor(i, device=self.device) for i in list_of_token_list]
output = pad_sequence(list_of_token_list, batch_first=True, padding_value=pad_val)

return output

def tensorize_and_pad_by_len(self, list_of_token_list, max_len):
"""Takes in a list of tokenized lists and outputs a padded tensor of defined length."""
pad_val = self.tokenize_dict["PAD"]
list_of_token_list = [torch.tensor(i, device=self.device) for i in list_of_token_list]

return self._pad_tensors(list_of_token_list, max_len, pad_val)

list_of_token_list = [torch.tensor(i) for i in list_of_token_list]
_extra = torch.zeros(max_len)
def pad_by_len(self, list_of_token_list, max_len):
"""Takes in a list of tokenized lists and outputs a padded tensor of defined length."""
pad_val = self.tokenize_dict["PAD"]
list_of_token_list = [i.to(self.device) for i in list_of_token_list]

return self._pad_tensors(list_of_token_list, max_len, pad_val)

def _pad_tensors(self, list_of_token_list, max_len, pad_val):
_extra = torch.zeros(max_len, device=self.device)
list_of_token_list.append(_extra)

output = pad_sequence(list_of_token_list, batch_first=True, padding_value=pad_val)
return output[torch.max((output != _extra), axis=1).values]

return output[:-1]
def is_number(self, sp_class):
return (
isinstance(sp_class, Float)
or isinstance(sp_class, Integer)
or isinstance(sp_class, Rational)
)


def defaultTokenizer():
"""Returns a default tokenizer. Because of issues with pickling."""
tokenize_dict = {
")": 0,
sp.acsc: 1,
sp.acot: 2,
sp.asech: 3,
sp.core.containers.Tuple: 4,
"/": 5,
sp.sech: 6,
"END": 7,
sp.exp: 8,
"7": 9,
"0": 10,
sp.asin: 11,
"5": 12,
sp.core.function.Derivative: 13,
"8": 14,
sp.asec: 15,
sp.core.add.Add: 16,
sp.core.power.Pow: 17,
sp.csch: 18,
"START": 19,
sp.csc: 20,
"PAD": 21,
sp.sin: 22,
",": 23,
sp.acsch: 24,
sp.core.relational.Equality: 25,
"(": 26,
"2": 27,
sp.Symbol("x"): 28,
sp.coth: 29,
sp.Symbol("y"): 30,
sp.log: 31,
sp.cos: 32,
"6": 33,
sp.core.mul.Mul: 34,
sp.acos: 35,
"9": 36,
sp.Function("f"): 37,
"-": 38,
sp.sqrt: 39,
sp.cosh: 40,
sp.tan: 41,
sp.tanh: 42,
sp.Symbol("z"): 43,
"4": 44,
"3": 45,
sp.cot: 46,
sp.asinh: 47,
sp.atan: 48,
sp.acosh: 49,
"1": 50,
sp.atanh: 51,
".": 52,
sp.sinh: 53,
sp.acoth: 54,
sp.sec: 55,
sp.Symbol("beta"): 56,
sp.Symbol("gamma"): 57,
sp.Symbol("delta"): 58,
sp.Symbol("a"): 59,
sp.Symbol("b"): 60,
sp.Symbol("c"): 61,
sp.Symbol("d"): 62,
sp.Symbol("epsilon"): 63,
}
tokenize_dict = DEFAULT_DICT

# invert tokenizer_dict into decode_dict
decode_dict = {v: k for k, v in tokenize_dict.items()}
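
Taken together, the changes to EquationTokenizer add device selection (cuda, mps, or cpu) and a sympy_to_torch path that turns a SymPy expression into a callable built from SP_TO_TORCH. The snippet below is a rough usage sketch, not part of this commit; the expression and tensor shape are illustrative, and it assumes the default token dictionary (useDefaultTokenizer=True) covers the symbols used.

import sympy as sp
import torch

from NumGI.EquationTokenizer import EquationTokenizer

tokenizer = EquationTokenizer(useDefaultTokenizer=True)

# Compile a SymPy expression into a torch callable keyed by variable name.
x, y = sp.symbols("x y")
torch_func, variables = tokenizer.sympy_to_torch(sp.sin(x) + sp.exp(y))

# Evaluate on tensors placed on the device the tokenizer selected.
inputs = {name: torch.rand(8, device=tokenizer.device) for name in variables}
values = torch_func(**inputs)
print(values.shape)  # expected: torch.Size([8])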