Skip to content

Commit

Permalink
feat: change MIR representation to protobuf
Browse files Browse the repository at this point in the history
  • Loading branch information
lumasepa committed Dec 16, 2024
1 parent d8a0021 commit 1deb69d
Show file tree
Hide file tree
Showing 23 changed files with 1,285 additions and 742 deletions.
359 changes: 173 additions & 186 deletions nada_dsl/ast_util.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions nada_dsl/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class CompilerOutput:
"""Compiler Output"""

mir: str
mir: bytes


@add_timer(timer_name="nada_dsl.compile.compile")
Expand Down Expand Up @@ -82,7 +82,7 @@ def print_output(out: CompilerOutput):
"""
output_json = {
"result": "Success",
"mir": out.mir,
"mir": list(out.mir),
}
print(json.dumps(output_json))

Expand Down
121 changes: 57 additions & 64 deletions nada_dsl/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
"""

from dataclasses import dataclass
import json
import os
from json import JSONEncoder
import inspect
from typing import List, Dict, Any, Optional, Tuple
from typing import List, Dict, Optional, Tuple
from sortedcontainers import SortedDict

from nada_mir_proto.nillion.nada.mir import v1 as proto_mir
from nada_mir_proto.nillion.nada.operations import v1 as proto_op
from nada_mir_proto.nillion.nada.types import v1 as proto_ty

from nada_dsl.ast_util import (
AST_OPERATIONS,
ASTOperation,
Expand All @@ -24,7 +25,6 @@
NTupleAccessorASTOperation,
NadaFunctionASTOperation,
NadaFunctionArgASTOperation,
NadaFunctionCallASTOperation,
NewASTOperation,
ObjectAccessorASTOperation,
RandomASTOperation,
Expand All @@ -38,16 +38,7 @@
INPUTS = SortedDict()
PARTIES = SortedDict()
FUNCTIONS: Dict[int, NadaFunctionASTOperation] = {}
LITERALS: Dict[str, Tuple[str, object]] = {}


class ClassEncoder(JSONEncoder):
"""Custom JSON encoder for classes."""

def default(self, o):
if inspect.isclass(o):
return o.__name__
return {type(o).__name__: o.__dict__}
LITERALS: Dict[str, Tuple[str, proto_ty.NadaType]] = {}


def get_target_dir() -> str:
Expand All @@ -66,19 +57,19 @@ def get_target_dir() -> str:
return os.path.join(cwd, "target")


def nada_compile(outputs: List[Output]) -> str:
def nada_compile(outputs: List[Output]) -> bytes:
"""Compile Nada to MIR and dump it as JSON."""
compiled = nada_dsl_to_nada_mir(outputs)
return json.dumps(compiled)
return bytes(compiled)


def nada_dsl_to_nada_mir(outputs: List[Output]) -> Dict[str, Any]:
def nada_dsl_to_nada_mir(outputs: List[Output]) -> proto_mir.ProgramMir:
"""Convert Nada DSL to Nada MIR."""
new_outputs = []
PARTIES.clear()
INPUTS.clear()
LITERALS.clear()
operations: Dict[int, Dict] = {}
operations: Dict[int, proto_op.Operation] = {}
# Process outputs
for output in outputs:
timer.start(
Expand All @@ -96,72 +87,75 @@ def nada_dsl_to_nada_mir(outputs: List[Output]) -> Dict[str, Any]:
party = output.party
PARTIES[party.name] = party
new_outputs.append(
{
"operation_id": out_operation_id,
"name": output.name,
"party": party.name,
"type": AST_OPERATIONS[out_operation_id].ty,
"source_ref_index": output.source_ref.to_index(),
}
proto_mir.Output(
operation_id=out_operation_id,
name=output.name,
party=party.name,
type=AST_OPERATIONS[out_operation_id].ty,
source_ref_index=output.source_ref.to_index(),
)
)
# Now we go through all the discovered functions and see if they are
# invoking other functions, which we will need to process and add to the FUNCTIONS dictionary

return {
"functions": to_mir_function_list(FUNCTIONS),
"parties": to_party_list(PARTIES),
"inputs": to_input_list(INPUTS),
"literals": to_literal_list(LITERALS),
"outputs": new_outputs,
"operations": operations,
"source_files": SourceRef.get_sources(),
"source_refs": SourceRef.get_refs(),
}


def to_party_list(parties) -> List[Dict]:
return proto_mir.ProgramMir(
functions=to_mir_function_list(FUNCTIONS),
parties=to_party_list(PARTIES),
inputs=to_input_list(INPUTS),
literals=to_literal_list(LITERALS),
outputs=new_outputs,
operations=operations,
source_files=SourceRef.get_sources(),
source_refs=SourceRef.get_refs(),
)


def to_party_list(parties) -> List[proto_mir.Party]:
"""Convert parties to a list in MIR format."""
return [
{
"name": party.name,
"source_ref_index": party.source_ref.to_index(),
}
proto_mir.Party(
name=party.name,
source_ref_index=party.source_ref.to_index(),
)
for party in parties.values()
]


def to_input_list(inputs) -> List[Dict]:
def to_input_list(inputs) -> List[proto_mir.Input]:
"""Convert inputs to a list in MIR format."""
input_list = []
for party_inputs in inputs.values():
for program_input, program_type in party_inputs.values():
input_list.append(
{
"name": program_input.name,
"type": program_type,
"party": program_input.party.name,
"doc": program_input.doc,
"source_ref_index": program_input.source_ref.to_index(),
}
proto_mir.Input(
name=program_input.name,
type=program_type,
party=program_input.party.name,
doc=program_input.doc,
source_ref_index=program_input.source_ref.to_index(),
)
)
return input_list


def to_literal_list(literals: Dict[str, Tuple[str, object]]) -> List[Dict]:
def to_literal_list(
literals: Dict[str, Tuple[str, proto_ty.NadaType]],
) -> List[proto_mir.Literal]:
"""Convert literals to a list in MIR format."""
literal_list = []
for name, (value, ty) in literals.items():
literal_list.append(
{
"name": name,
"value": str(value),
"type": ty,
}
proto_mir.Literal(
name=name,
value=value,
type=ty,
)
)
return literal_list


def to_mir_function_list(functions: Dict[int, NadaFunctionASTOperation]) -> List[Dict]:
def to_mir_function_list(
functions: Dict[int, NadaFunctionASTOperation],
) -> List[proto_mir.NadaFunction]:
"""Convert functions to a list in MIR format.
From a starting dictionary of functions, it traverses each one of them,
Expand Down Expand Up @@ -194,6 +188,7 @@ def to_mir_function_list(functions: Dict[int, NadaFunctionASTOperation]) -> List
if extra_functions:
stack.extend(extra_functions.values())
functions.update(extra_functions)

mir_functions.append(function.to_mir(function_operations))
return mir_functions

Expand All @@ -220,7 +215,7 @@ class CompilerException(Exception):

def traverse_and_process_operations(
operation_id: int,
operations: Dict[int, Dict],
operations: Dict[int, proto_op.Operation],
functions: Dict[int, NadaFunctionASTOperation],
) -> Dict[int, NadaFunctionASTOperation]:
"""Traverses the AST operations finding all the operation tree rooted at the given
Expand Down Expand Up @@ -266,7 +261,7 @@ def traverse_and_process_operations(
class ProcessOperationOutput:
"""Output of the process_operation function"""

mir: Dict[str, Dict]
mir: proto_op.Operation
extra_function: Optional[NadaFunctionASTOperation]


Expand Down Expand Up @@ -312,9 +307,7 @@ def process_operation(
elif isinstance(operation, LiteralASTOperation):
LITERALS[operation.literal_index] = (str(operation.value), operation.ty)
processed_operation = ProcessOperationOutput(operation.to_mir(), None)
elif isinstance(
operation, (MapASTOperation, ReduceASTOperation, NadaFunctionCallASTOperation)
):
elif isinstance(operation, (MapASTOperation, ReduceASTOperation)):
extra_fn = None
if operation.fn not in functions:
extra_fn = AST_OPERATIONS[operation.fn]
Expand Down
4 changes: 3 additions & 1 deletion nada_dsl/future/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from dataclasses import dataclass

from nada_mir_proto.nillion.nada.types import v1 as proto_ty

from nada_dsl import SourceRef
from nada_dsl.ast_util import AST_OPERATIONS, CastASTOperation, next_operation_id
from nada_dsl.nada_types import AllTypes, AllTypesType
Expand All @@ -21,7 +23,7 @@ def __init__(self, target: AllTypes, to: AllTypes, source_ref: SourceRef):
self.to = to
self.source_ref = source_ref

def store_in_ast(self, ty: object):
def store_in_ast(self, ty: proto_ty.NadaType):
"""Store object in AST"""
AST_OPERATIONS[self.id] = CastASTOperation(
id=self.id, target=self.target, ty=ty, source_ref=self.source_ref
Expand Down
Loading

0 comments on commit 1deb69d

Please sign in to comment.