Skip to content

Commit

Permalink
Optimize consecutive message field assignments
Browse files Browse the repository at this point in the history
Ref. #1120
  • Loading branch information
treiher committed Aug 16, 2022
1 parent 3322003 commit ac00822
Show file tree
Hide file tree
Showing 16 changed files with 1,349 additions and 86 deletions.
136 changes: 134 additions & 2 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2745,6 +2745,138 @@ def variables(self) -> List["Variable"]:
return result


class DeltaMessageAggregate(Expr):
"""For internal use only."""

def __init__(
self,
identifier: StrID,
field_values: Mapping[StrID, Expr],
type_: rty.Type = rty.Undefined(),
location: Location = None,
) -> None:
super().__init__(type_, location)
self.identifier = ID(identifier)
self.field_values = {ID(k): v for k, v in field_values.items()}

def _update_str(self) -> None:
field_values = (
", ".join([f"{k} => {self.field_values[k]}" for k in self.field_values])
if self.field_values
else "null message"
)
self._str = intern(f"{self.identifier} with delta {field_values}")

def _check_type_subexpr(self) -> RecordFluxError:
if not isinstance(self.type_, rty.Message):
error = RecordFluxError()

for d in self.field_values.values():
error += d.check_type_instance(rty.Any)

return error

error = RecordFluxError()
field_combinations = set(self.type_.field_combinations)
fields: tuple[str, ...] = ()

for i, (field, expr) in enumerate(self.field_values.items()):
if field not in self.type_.fields:
error.extend(
[
(
f'invalid field "{field}" for {self.type_}',
Subsystem.MODEL,
Severity.ERROR,
field.location,
),
*_similar_field_names(field, self.type_.fields, field.location),
]
)
continue

field_type = self.type_.types[field]

if field_type == rty.OPAQUE:
if not any(
r.field == field and expr.type_.is_compatible(r.sdu)
for r in self.type_.refinements
):
error += expr.check_type(field_type)
else:
error += expr.check_type(field_type)

fields = (*fields, str(field))
field_combinations = {
c
for c in field_combinations
if any(fields == c[i : len(fields) + i] for i in range(len(c) - len(fields) + 1))
}

if not field_combinations:
error.extend(
[
(
f'invalid position for field "{field}" of {self.type_}',
Subsystem.MODEL,
Severity.ERROR,
field.location,
)
],
)
break

return error

def __neg__(self) -> Expr:
raise NotImplementedError

def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]:
return [
*([self] if match(self) else []),
*[e for v in self.field_values.values() for e in v.findall(match)],
]

def simplified(self) -> Expr:
return DeltaMessageAggregate(
self.identifier,
{k: self.field_values[k].simplified() for k in self.field_values},
self.type_,
self.location,
)

def substituted(
self, func: Callable[[Expr], Expr] = None, mapping: Mapping[Name, Expr] = None
) -> Expr:
func = substitution(mapping or {}, func)
expr = func(self)
if isinstance(expr, DeltaMessageAggregate):
return expr.__class__(
expr.identifier,
{k: expr.field_values[k].substituted(func) for k in expr.field_values},
type_=expr.type_,
location=expr.location,
)
return expr

@property
def precedence(self) -> Precedence:
raise NotImplementedError

def ada_expr(self) -> ada.Expr:
raise NotImplementedError

@lru_cache(maxsize=None)
def z3expr(self) -> z3.ExprRef:
raise NotImplementedError

def variables(self) -> List[Variable]:
result = []
for v in self.field_values.values():
result.extend(v.variables())
return result


class Binding(Expr):
def __init__(self, expr: Expr, data: Mapping[StrID, Expr], location: Location = None) -> None:
super().__init__(expr.type_, location)
Expand Down Expand Up @@ -2850,12 +2982,12 @@ def _entity_name(expr: Expr) -> str:
else "type"
if isinstance(expr, Conversion)
else "message"
if isinstance(expr, MessageAggregate)
if isinstance(expr, (MessageAggregate, DeltaMessageAggregate))
else "expression"
)
expr_name = (
str(expr.identifier)
if isinstance(expr, (Variable, Call, Conversion, MessageAggregate))
if isinstance(expr, (Variable, Call, Conversion, MessageAggregate, DeltaMessageAggregate))
else str(expr)
)
return f'{expr_type} "{expr_name}"'
Expand Down
97 changes: 96 additions & 1 deletion rflx/generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,11 @@ def _assign( # pylint: disable = too-many-arguments
target, expression, exception_handler, is_global, state
)

if isinstance(expression, expr.DeltaMessageAggregate):
return self._assign_to_delta_message_aggregate(
target, expression, exception_handler, is_global, state
)

if isinstance(target_type, rty.Message):
for v in expression.findall(
lambda x: isinstance(x, expr.Variable) and x.identifier == target
Expand Down Expand Up @@ -2424,6 +2429,85 @@ def _assign_to_message_aggregate(

return assign_to_message_aggregate

def _assign_to_delta_message_aggregate(
self,
target: ID,
delta_message_aggregate: expr.DeltaMessageAggregate,
exception_handler: ExceptionHandler,
is_global: Callable[[ID], bool],
state: ID,
) -> Sequence[Statement]:
assert isinstance(delta_message_aggregate.type_, rty.Message)

self._session_context.used_types_body.append(const.TYPES_BIT_LENGTH)

target_type_id = delta_message_aggregate.type_.identifier
target_context = context_id(target, is_global)

fields = list(delta_message_aggregate.field_values)
first_field = fields[0]
last_field = fields[-1]

required_space, required_space_precondition = self._required_space(
self._message_subpath_size(delta_message_aggregate), is_global, state
)

return [
self._raise_exception_if(
Not(
Call(
target_type_id * "Valid_Next",
[
Variable(target_context),
Variable(target_type_id * model.Field(first_field).affixed_name),
],
)
),
f'trying to set message fields "{first_field}" to "{last_field}" although'
f' "{first_field}" is not valid next field',
exception_handler,
),
*(
[
self._raise_exception_if(
Not(required_space_precondition),
"violated precondition for calculating required space for setting message"
f' fields "{first_field}" to "{last_field}" (one of the message arguments'
" is invalid or has a too small buffer)",
exception_handler,
)
]
if required_space_precondition
else []
),
self._raise_exception_if(
Less(
Call(
target_type_id * "Available_Space",
[
Variable(target_context),
Variable(target_type_id * model.Field(first_field).affixed_name),
],
),
required_space,
),
f'insufficient space for setting message fields "{first_field}" to "{last_field}"',
exception_handler,
),
*[
s
for f, v in delta_message_aggregate.field_values.items()
for s in self._set_message_field(
target_context,
f,
delta_message_aggregate.type_,
v,
exception_handler,
is_global,
)
],
]

def _assign_to_head(
self,
target: ID,
Expand Down Expand Up @@ -3434,14 +3518,25 @@ def _message_size(self, message_aggregate: expr.MessageAggregate) -> expr.Expr:
assert isinstance(message, model.Message)
return message.size({model.Field(f): v for f, v in message_aggregate.field_values.items()})

def _message_subpath_size(
self, delta_message_aggregate: expr.DeltaMessageAggregate
) -> expr.Expr:
assert isinstance(delta_message_aggregate.type_, rty.Message)
message = self._model_type(delta_message_aggregate.type_.identifier)
assert isinstance(message, model.Message)
return message.size(
{model.Field(f): v for f, v in delta_message_aggregate.field_values.items()},
delta_message_aggregate.identifier,
subpath=True,
)

def _required_space(
self, size: expr.Expr, is_global: Callable[[ID], bool], state: ID
) -> tuple[Expr, Optional[Expr]]:
required_space = (
size.substituted(
lambda x: expr.Call(const.TYPES_BIT_LENGTH, [x])
if (isinstance(x, expr.Variable) and isinstance(x.type_, rty.AnyInteger))
or (isinstance(x, expr.Selected) and x.type_ != rty.OPAQUE)
else x
)
.substituted(self._substitution(is_global))
Expand Down
Loading

0 comments on commit ac00822

Please sign in to comment.