Skip to content

Commit

Permalink
bootloader for recusive with poseidon
Browse files Browse the repository at this point in the history
  • Loading branch information
chudkowsky committed Dec 10, 2024
1 parent 45bd7de commit 3d6476c
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,8 @@ struct BuiltinData {
output: felt,
pedersen: felt,
range_check: felt,
ecdsa: felt,
bitwise: felt,
ec_op: felt,
keccak: felt,
poseidon: felt,
range_check96: felt,
add_mod: felt,
mul_mod: felt,
}

// Computes the hash of a program.
Expand Down Expand Up @@ -133,14 +127,8 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}(
output=output_ptr + 2,
pedersen=cast(pedersen_ptr, felt),
range_check=input_builtin_ptrs.range_check,
ecdsa=input_builtin_ptrs.ecdsa,
bitwise=input_builtin_ptrs.bitwise,
ec_op=input_builtin_ptrs.ec_op,
keccak=input_builtin_ptrs.keccak,
poseidon=cast(poseidon_ptr, felt),
range_check96=input_builtin_ptrs.range_check96,
add_mod=input_builtin_ptrs.add_mod,
mul_mod=input_builtin_ptrs.mul_mod,
);

// Call select_input_builtins to get the relevant input builtin pointers for the task.
Expand Down
99 changes: 99 additions & 0 deletions src/starkware/cairo/bootloaders/simple_bootloader/objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import dataclasses
from abc import abstractmethod
from dataclasses import field
from typing import ClassVar, Dict, List, Optional, Type
import marshmallow
import marshmallow.fields as mfields
import marshmallow_dataclass
from marshmallow_oneofschema import OneOfSchema
from starkware.cairo.lang.compiler.program import Program, ProgramBase, StrippedProgram
from starkware.cairo.lang.vm.cairo_pie import CairoPie
from starkware.starkware_utils.marshmallow_dataclass_fields import additional_metadata
from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass


class Task:
@abstractmethod
def get_program(self) -> ProgramBase:
"""
Returns the task's Cairo program.
"""


class TaskSpec:
"""
Contains task's specification.
"""

@abstractmethod
def load_task(self, memory=None, args_start=None, args_len=None) -> "Task":
"""
Returns the corresponding task.
"""


@marshmallow_dataclass.dataclass(frozen=True)
class RunProgramTask(TaskSpec, Task):
TYPE: ClassVar[str] = "RunProgramTask"
program: Program
program_input: dict
use_poseidon: bool

def get_program(self) -> Program:
return self.program

def load_task(self, memory=None, args_start=None, args_len=None) -> "Task":
return self


@marshmallow_dataclass.dataclass(frozen=True)
class CairoPiePath(TaskSpec):
TYPE: ClassVar[str] = "CairoPiePath"
path: str
use_poseidon: bool

def load_task(self, memory=None, args_start=None, args_len=None) -> "CairoPieTask":
"""
Loads the PIE to memory.
"""
return CairoPieTask(
cairo_pie=CairoPie.from_file(self.path), use_poseidon=self.use_poseidon
)


class TaskSchema(OneOfSchema):
"""
Schema for Task/CairoPiePath/Cairo1ProgramPath/CairoSierra
OneOfSchema adds a "type" field.
"""

type_schemas: Dict[str, Type[marshmallow.Schema]] = {
RunProgramTask.TYPE: RunProgramTask.Schema,
CairoPiePath.TYPE: CairoPiePath.Schema,
}

def get_obj_type(self, obj):
return obj.TYPE


@dataclasses.dataclass(frozen=True)
class CairoPieTask(Task):
cairo_pie: CairoPie
use_poseidon: bool

def get_program(self) -> StrippedProgram:
return self.cairo_pie.program


@marshmallow_dataclass.dataclass(frozen=True)
class SimpleBootloaderInput(ValidatedMarshmallowDataclass):
tasks: List[TaskSpec] = field(
metadata=additional_metadata(
marshmallow_field=mfields.List(mfields.Nested(TaskSchema))
)
)
fact_topologies_path: Optional[str]

# If true, the bootloader will put all the outputs in a single page, ignoring the
# tasks' fact topologies.
single_page: bool
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@ func run_simple_bootloader{
output_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
ecdsa_ptr,
bitwise_ptr,
ec_op_ptr,
keccak_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr,
add_mod_ptr,
mul_mod_ptr,
}() {
alloc_locals;
local task_range_check_ptr;
Expand All @@ -47,43 +41,25 @@ func run_simple_bootloader{
output=cast(output_ptr, felt),
pedersen=cast(pedersen_ptr, felt),
range_check=task_range_check_ptr,
ecdsa=ecdsa_ptr,
bitwise=bitwise_ptr,
ec_op=ec_op_ptr,
keccak=keccak_ptr,
poseidon=cast(poseidon_ptr, felt),
range_check96=range_check96_ptr,
add_mod=add_mod_ptr,
mul_mod=mul_mod_ptr,
);

// A struct containing the encoding of each builtin.
local builtin_encodings: BuiltinData = BuiltinData(
output='output',
pedersen='pedersen',
range_check='range_check',
ecdsa='ecdsa',
bitwise='bitwise',
ec_op='ec_op',
keccak='keccak',
poseidon='poseidon',
range_check96='range_check96',
add_mod='add_mod',
mul_mod='mul_mod',
);

local builtin_instance_sizes: BuiltinData = BuiltinData(
output=1,
pedersen=3,
range_check=1,
ecdsa=2,
bitwise=5,
ec_op=7,
keccak=16,
poseidon=6,
range_check96=1,
add_mod=7,
mul_mod=7,
);

// Call execute_tasks.
Expand All @@ -108,14 +84,8 @@ func run_simple_bootloader{
let output_ptr = cast(builtin_ptrs.output, felt*);
let pedersen_ptr = cast(builtin_ptrs.pedersen, HashBuiltin*);
let range_check_ptr = builtin_ptrs.range_check;
let ecdsa_ptr = builtin_ptrs.ecdsa;
let bitwise_ptr = builtin_ptrs.bitwise;
let ec_op_ptr = builtin_ptrs.ec_op;
let keccak_ptr = builtin_ptrs.keccak;
let poseidon_ptr = cast(builtin_ptrs.poseidon, PoseidonBuiltin*);
let range_check96_ptr = builtin_ptrs.range_check96;
let add_mod_ptr = builtin_ptrs.add_mod;
let mul_mod_ptr = builtin_ptrs.mul_mod;

// 'execute_tasks' runs untrusted code and uses the range_check builtin to verify that
// the builtin pointers were advanced correctly by said code.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
%builtins output pedersen range_check ecdsa bitwise ec_op keccak poseidon range_check96 add_mod mul_mod
%builtins output pedersen range_check bitwise poseidon

from starkware.cairo.bootloaders.simple_bootloader.run_simple_bootloader import (
run_simple_bootloader,
Expand All @@ -10,14 +10,8 @@ func main{
output_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
ecdsa_ptr,
bitwise_ptr,
ec_op_ptr,
keccak_ptr,
poseidon_ptr: PoseidonBuiltin*,
range_check96_ptr,
add_mod_ptr,
mul_mod_ptr,
}() {
%{
from starkware.cairo.bootloaders.simple_bootloader.objects import SimpleBootloaderInput
Expand Down
Loading

0 comments on commit 3d6476c

Please sign in to comment.