Skip to content
This repository has been archived by the owner on Jan 8, 2025. It is now read-only.

Commit

Permalink
Merge branch 'ml4ai:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
YohannParis authored Jul 14, 2023
2 parents a8987b9 + abef48f commit 3de834d
Show file tree
Hide file tree
Showing 28 changed files with 1,874 additions and 675 deletions.
5 changes: 5 additions & 0 deletions data/mml2pn_inputs/testing_eqns/mml_list.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<math><mfrac><mrow><mi>d</mi><mi>E</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B2;</mi><mi>I</mi><mi>S</mi><mo>&#x2212;</mo><mi>&#x03B4;</mi><mi>E</mi></math>
<math><mfrac><mrow><mi>d</mi><mi>R</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>(1&#x2212;&#x03B1;)</mi><mi>&#x03B3;</mi><mi>I</mi></math>
<math><mfrac><mrow><mi>d</mi><mi>I</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B4;</mi><mi>E</mi><mo>&#x2212;</mo><mi>(1&#x2212;&#x03B1;)</mi><mi>&#x03B3;</mi><mi>I</mi><mo>&#x2212;</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi></math>
<math><mfrac><mrow><mi>d</mi><mi>D</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi></math>
<math><mfrac><mrow><mi>d</mi><mi>S</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>&#x2212;</mo><mi>&#x03B2;</mi><mi>I</mi><mi>S</mi></math>
5 changes: 5 additions & 0 deletions data/mml2pn_inputs/testing_eqns/mml_list2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<math display="block"><mfrac><mrow><mi>d</mi><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>&#x2212;</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>D</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>R</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>(1&#x2212;&#x03B1;)</mi><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>(1&#x2212;&#x03B1;)</mi><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
5 changes: 5 additions & 0 deletions data/mml2pn_inputs/testing_eqns/mml_list3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<math display="block"><mfrac><mrow><mi>d</mi><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>&#x2212;</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mfrac><mrow><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mi>N</mi></mfrac></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mfrac><mrow><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mi>N</mi></mfrac><mo>&#x2212;</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>R</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
<math display="block"><mfrac><mrow><mi>d</mi><mi>D</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
4 changes: 2 additions & 2 deletions skema/img2mml/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def retrieve_model(model_path=None) -> str:
cwd = Path(__file__).parents[0]
MODEL_BASE_ADDRESS = "https://artifacts.askem.lum.ai/skema/img2mml/models"
MODEL_NAME = "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"

if model_path is None:
# If the model path is none or doesn't exist, the default model will be downloaded from server.
if model_path is None or not os.path.exists(model_path):
model_path = cwd / "trained_models" / MODEL_NAME

# Check if the model file already exists
Expand Down
145 changes: 77 additions & 68 deletions skema/program_analysis/TS2CAST/node_helper.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,53 @@
from typing import List, Dict
from skema.program_analysis.CAST2FN.model.cast import SourceRef


class NodeHelper(object):
def __init__(self, source_file_name: str, source: str):
self.source_file_name = source_file_name
from tree_sitter import Node

CONTROL_CHARACTERS = [
",",
"=",
"==",
"(",
")",
"(/",
"/)",
":",
"::",
"+",
"-",
"*",
"**",
"/",
">",
"<",
"<=",
">=",
"only",
]

class NodeHelper():
def __init__(self, source: str, source_file_name: str):
self.source = source
self.source_file_name = source_file_name

def parse_tree_to_dict(self, node) -> Dict:
node_dict = {
"type": self.get_node_type(node),
"source_refs": [self.get_node_source_ref(node)],
"identifier": self.get_node_identifier(node),
"original_children_order": [],
"children": [],
"comments": [],
"control": [],
}

for child in node.children:
child_dict = self.parse_tree_to_dict(child)
node_dict["original_children_order"].append(child_dict)
if self.is_comment_node(child):
node_dict["comments"].append(child_dict)
elif self.is_control_character_node(child):
node_dict["control"].append(child_dict)
else:
node_dict["children"].append(child_dict)

return node_dict

def is_comment_node(self, node):
if node.type == "comment":
return True
return False

def is_control_character_node(self, node):
control_characters = [
",",
"=",
"(",
")",
":",
"::",
"+",
"-",
"*",
"**",
"/",
">",
"<",
"<=",
">=",
]
return node.type in control_characters

def get_node_source_ref(self, node) -> SourceRef:
def get_source_ref(self, node: Node) -> SourceRef:
"""Given a node and file name, return a CAST SourceRef object."""
row_start, col_start = node.start_point
row_end, col_end = node.end_point
return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)

def get_node_identifier(self, node) -> str:
source_ref = self.get_node_source_ref(node)

def get_identifier(self, node: Node) -> str:
"""Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point"""
line_num = 0
column_num = 0
in_identifier = False
identifier = ""
for i, char in enumerate(self.source):
if line_num == source_ref.row_start and column_num == source_ref.col_start:
if line_num == node.start_point[0] and column_num == node.start_point[1]:
in_identifier = True
elif line_num == source_ref.row_end and column_num == source_ref.col_end:
elif line_num == node.end_point[0] and column_num == node.end_point[1]:
break

if char == "\n":
Expand All @@ -84,19 +61,51 @@ def get_node_identifier(self, node) -> str:

return identifier

def get_node_type(self, node) -> str:
return node.type
def get_first_child_by_type(node: Node, type: str, recurse=False):
"""Takes in a node and a type string as inputs and returns the first child matching that type. Otherwise, return None
When the recurse argument is set, it will also recursivly search children nodes as well.
"""
for child in node.children:
if child.type == type:
return child

if recurse:
for child in node.children:
out = get_first_child_by_type(child, type, True)
if out:
return out
return None


def get_children_by_types(node: Node, types: List):
"""Takes in a node and a list of types as inputs and returns all children matching those types. Otherwise, return an empty list"""
return [child for child in node.children if child.type in types]


def get_first_child_index(node, type: str):
"""Get the index of the first child of node with type type."""
for i, child in enumerate(node.children):
if child.type == type:
return i


def get_last_child_index(node, type: str):
"""Get the index of the last child of node with type type."""
last = None
for i, child in enumerate(node.children):
if child.type == type:
last = child
return last


def get_first_child_by_type(self, node: Dict, node_type: str) -> Dict:
children = self.get_children_by_type(node, node_type)
if len(children) >= 1:
return children[0]
def get_control_children(node: Node):
return get_children_by_types(node, CONTROL_CHARACTERS)

def get_children_by_type(self, node: Dict, node_type: str) -> List:
children = []

for child in node["children"]:
if child["type"] == node_type:
children.append(child)
def get_non_control_children(node: Node):
children = []
for child in node.children:
if child.type not in CONTROL_CHARACTERS:
children.append(child)

return children
return children
Loading

0 comments on commit 3de834d

Please sign in to comment.