Skip to content

Commit

Permalink
Fix future consumption checking.
Browse files Browse the repository at this point in the history
Now futures are tracked properly, whether they're stored directly
in a variable or in a tuple.

Also, error if a future is used improperly.
  • Loading branch information
mikebenfield committed Dec 3, 2024
1 parent 9c7cb34 commit 2760c4c
Show file tree
Hide file tree
Showing 14 changed files with 1,068 additions and 551 deletions.
1,227 changes: 725 additions & 502 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion compiler/ast/src/passes/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,15 @@ pub trait ExpressionVisitor<'a> {
Default::default()
}

fn visit_struct_init(&mut self, _input: &'a StructExpression, _additional: &Self::AdditionalInput) -> Self::Output {
fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output {
let StructExpression { name, members, .. } = input;
self.visit_identifier(name, additional);
for StructVariableInitializer { identifier, expression, .. } in members {
self.visit_identifier(identifier, additional);
if let Some(expression) = expression {
self.visit_expression(expression, additional);
}
}
Default::default()
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/passes/src/static_analysis/analyze_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ impl<'a, N: Network> ProgramVisitor<'a> for StaticAnalyzer<'a, N> {
// Set `non_async_external_call_seen` to false.
self.non_async_external_call_seen = false;

if matches!(self.variant, Some(Variant::AsyncFunction) | Some(Variant::AsyncTransition)) {
super::future_checker::future_check_function(function, self.type_table, self.handler);
}

// If the function is an async function, initialize the await checker.
if self.variant == Some(Variant::AsyncFunction) {
// Initialize the list of input futures. Each one must be awaited before the end of the function.
Expand Down
167 changes: 167 additions & 0 deletions compiler/passes/src/static_analysis/future_checker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::TypeTable;

use leo_ast::{CoreFunction, Expression, ExpressionVisitor, Function, Node, StatementVisitor, Type};
use leo_errors::{StaticAnalyzerError, emitter::Handler};

/// Error if futures are used improperly.
///
/// This prevents, for instance, a bare call which creates an unused future.
pub fn future_check_function(function: &Function, type_table: &TypeTable, handler: &Handler) {
let mut future_checker = FutureChecker { type_table, handler };
future_checker.visit_block(&function.block);
}

#[derive(Clone, Copy, Debug, Default)]
enum Position {
#[default]
Misc,
Await,
TupleAccess,
Return,
FunctionArgument,
LastTupleLiteral,
Definition,
}

struct FutureChecker<'a> {
type_table: &'a TypeTable,
handler: &'a Handler,
}

impl<'a> FutureChecker<'a> {
fn emit_err(&self, err: StaticAnalyzerError) {
self.handler.emit_err(err);
}
}

impl<'a> ExpressionVisitor<'a> for FutureChecker<'a> {
type AdditionalInput = Position;
type Output = ();

fn visit_expression(&mut self, input: &'a Expression, additional: &Self::AdditionalInput) -> Self::Output {
use Position::*;
let is_call = matches!(input, Expression::Call(..));
match self.type_table.get(&input.id()) {
Some(Type::Future(..)) if is_call => {
// A call producing a Future may appear in any of these positions.
if !matches!(additional, Await | Return | FunctionArgument | LastTupleLiteral | Definition) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Future(..)) => {
// A Future expression that's not a call may appear in any of these positions.
if !matches!(additional, Await | Return | FunctionArgument | LastTupleLiteral | TupleAccess) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Tuple(tuple)) if !matches!(tuple.elements().last(), Some(Type::Future(_))) => {}
Some(Type::Tuple(..)) if is_call => {
// A call producing a Tuple ending in a Future may appear in any of these positions.
if !matches!(additional, Return | Definition) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Tuple(..)) => {
// A Tuple ending in a Future that's not a call may appear in any of these positions.
if !matches!(additional, Return | TupleAccess) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
_ => {}
}

match input {
Expression::Access(access) => self.visit_access(access, &Position::Misc),
Expression::Array(array) => self.visit_array(array, &Position::Misc),
Expression::Binary(binary) => self.visit_binary(binary, &Position::Misc),
Expression::Call(call) => self.visit_call(call, &Position::Misc),
Expression::Cast(cast) => self.visit_cast(cast, &Position::Misc),
Expression::Struct(struct_) => self.visit_struct_init(struct_, &Position::Misc),
Expression::Err(err) => self.visit_err(err, &Position::Misc),
Expression::Identifier(identifier) => self.visit_identifier(identifier, &Position::Misc),
Expression::Literal(literal) => self.visit_literal(literal, &Position::Misc),
Expression::Locator(locator) => self.visit_locator(locator, &Position::Misc),
Expression::Ternary(ternary) => self.visit_ternary(ternary, &Position::Misc),
Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
Expression::Unary(unary) => self.visit_unary(unary, &Position::Misc),
Expression::Unit(unit) => self.visit_unit(unit, &Position::Misc),
}
}

fn visit_access(
&mut self,
input: &'a leo_ast::AccessExpression,
_additional: &Self::AdditionalInput,
) -> Self::Output {
match input {
leo_ast::AccessExpression::Array(array) => {
self.visit_expression(&array.array, &Position::Misc);
self.visit_expression(&array.index, &Position::Misc);
}
leo_ast::AccessExpression::AssociatedFunction(function) => {
let core_function = CoreFunction::from_symbols(function.variant.name, function.name.name)
.expect("Typechecking guarantees that this function exists.");
let position =
if core_function == CoreFunction::FutureAwait { Position::Await } else { Position::Misc };
function.arguments.iter().for_each(|arg| {
self.visit_expression(arg, &position);
});
}
leo_ast::AccessExpression::Member(member) => {
self.visit_expression(&member.inner, &Position::Misc);
}
leo_ast::AccessExpression::Tuple(tuple) => {
self.visit_expression(&tuple.tuple, &Position::TupleAccess);
}
_ => {}
}

Default::default()
}

fn visit_call(&mut self, input: &'a leo_ast::CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
input.arguments.iter().for_each(|expr| {
self.visit_expression(expr, &Position::FunctionArgument);
});
Default::default()
}

fn visit_tuple(&mut self, input: &'a leo_ast::TupleExpression, additional: &Self::AdditionalInput) -> Self::Output {
let next_position = match additional {
Position::Definition | Position::Return => Position::LastTupleLiteral,
_ => Position::Misc,
};
let mut iter = input.elements.iter().peekable();
while let Some(expr) = iter.next() {
let position = if iter.peek().is_some() { &Position::Misc } else { &next_position };
self.visit_expression(expr, position);
}
Default::default()
}
}

impl<'a> StatementVisitor<'a> for FutureChecker<'a> {
fn visit_definition(&mut self, input: &'a leo_ast::DefinitionStatement) {
self.visit_expression(&input.value, &Position::Definition);
}

fn visit_return(&mut self, input: &'a leo_ast::ReturnStatement) {
self.visit_expression(&input.expression, &Position::Return);
}
}
2 changes: 2 additions & 0 deletions compiler/passes/src/static_analysis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

mod future_checker;

mod await_checker;

pub mod analyze_expression;
Expand Down
42 changes: 24 additions & 18 deletions compiler/passes/src/type_checking/check_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
if let Some(expected) = expected {
self.check_eq_types(&Some(actual.clone()), &Some(expected.clone()), access.span());
}

// Return type of tuple index.
return Some(actual);
}
Expand Down Expand Up @@ -671,6 +670,30 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
let ty = self.visit_expression(argument, &Some(expected.type_().clone()))?;
// Extract information about futures that are being consumed.
if func.variant == Variant::AsyncFunction && matches!(expected.type_(), Type::Future(_)) {
// Consume the future.
let option_name = match argument {
Expression::Identifier(id) => Some(id.name),
Expression::Access(AccessExpression::Tuple(tuple_access)) => {
if let Expression::Identifier(id) = &*tuple_access.tuple {
Some(id.name)
} else {
None
}
}
_ => None,
};

if let Some(name) = option_name {
match self.scope_state.futures.shift_remove(&name) {
Some(future) => {
self.scope_state.call_location = Some(future.clone());
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(name, argument.span()));
}
}
}

match argument {
Expression::Identifier(_)
| Expression::Call(_)
Expand Down Expand Up @@ -853,23 +876,6 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
fn visit_identifier(&mut self, input: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output {
let var = self.symbol_table.borrow().lookup_variable(Location::new(None, input.name)).cloned();
if let Some(var) = &var {
if matches!(var.type_, Type::Future(_)) && matches!(expected, Some(Type::Future(_))) {
if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.is_call {
// Consume future.
match self.scope_state.futures.shift_remove(&input.name) {
Some(future) => {
self.scope_state.call_location = Some(future.clone());
return Some(var.type_.clone());
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(input.name, input.span));
}
}
} else {
// Case where accessing input argument of future. Ex `f.1`.
return Some(var.type_.clone());
}
}
Some(self.assert_and_return_type(var.type_.clone(), expected, input.span()))
} else {
self.emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()));
Expand Down
16 changes: 9 additions & 7 deletions compiler/passes/src/type_checking/check_statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ use leo_ast::{
};
use leo_errors::TypeCheckerError;

use itertools::Itertools;

impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
fn visit_statement(&mut self, input: &'a Statement) {
// No statements can follow a return statement.
Expand Down Expand Up @@ -248,7 +246,7 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
// Insert the variables into the symbol table.
match &input.place {
Expression::Identifier(identifier) => {
self.insert_variable(inferred_type.clone(), identifier, input.type_.clone(), 0, identifier.span)
self.insert_variable(inferred_type.clone(), identifier, input.type_.clone(), identifier.span)
}
Expression::Tuple(tuple_expression) => {
let tuple_type = match &input.type_ {
Expand All @@ -265,17 +263,21 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
));
}

for ((index, expr), type_) in
tuple_expression.elements.iter().enumerate().zip_eq(tuple_type.elements().iter())
{
for i in 0..tuple_expression.elements.len() {
let inferred = if let Some(Type::Tuple(inferred_tuple)) = &inferred_type {
inferred_tuple.elements().get(i).cloned()
} else {
None
};
let expr = &tuple_expression.elements[i];
let identifier = match expr {
Expression::Identifier(identifier) => identifier,
_ => {
return self
.emit_err(TypeCheckerError::lhs_tuple_element_must_be_an_identifier(expr.span()));
}
};
self.insert_variable(inferred_type.clone(), identifier, type_.clone(), index, identifier.span);
self.insert_variable(inferred, &identifier, tuple_type.elements()[i].clone(), identifier.span);
}
}
_ => self.emit_err(TypeCheckerError::lhs_must_be_identifier_or_tuple(input.place.span())),
Expand Down
35 changes: 13 additions & 22 deletions compiler/passes/src/type_checking/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,32 +1313,23 @@ impl<'a, N: Network> TypeChecker<'a, N> {
}

/// Inserts variable to symbol table.
pub(crate) fn insert_variable(
&mut self,
inferred_type: Option<Type>,
name: &Identifier,
type_: Type,
index: usize,
span: Span,
) {
let ty: Type = if let Type::Future(_) = type_ {
// Need to insert the fully inferred future type, or else will just be default future type.
let ret = match inferred_type.unwrap() {
Type::Future(future) => Type::Future(future),
Type::Tuple(tuple) => match tuple.elements().get(index) {
Some(Type::Future(future)) => Type::Future(future.clone()),
_ => unreachable!("Parsing guarantees that the inferred type is a future."),
},
_ => {
unreachable!("TYC guarantees that the inferred type is a future, or tuple containing futures.")
}
};
// Insert future into list of futures for the function.
pub(crate) fn insert_variable(&mut self, inferred_type: Option<Type>, name: &Identifier, type_: Type, span: Span) {
let is_future = match &type_ {
Type::Future(..) => true,
Type::Tuple(tuple_type) if matches!(tuple_type.elements().last(), Some(Type::Future(..))) => true,
_ => false,
};

if is_future {
self.scope_state.futures.insert(name.name, self.scope_state.call_location.clone().unwrap());
ret
}

let ty: Type = if is_future {
inferred_type.expect("Type checking guarantees the inferred type is present")
} else {
type_
};

// Insert the variable into the symbol table.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(
Location::new(None, name.name),
Expand Down
7 changes: 7 additions & 0 deletions errors/src/errors/static_analyzer/static_analyzer_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,11 @@ create_messages!(
msg: format!("The call to {function_name} will result in failed executions on-chain."),
help: Some("There is a subtle error that occurs if an async transition call follows a non-async transition call, and the async call returns a `Future` that itself takes a `Future` as an input. See See `https://github.com/AleoNet/snarkVM/issues/2570` for more context.".to_string()),
}

@formatted
misplaced_future {
args: (),
msg: "A future may not be used in this way".to_string(),
help: Some("Futures should be created and consumed without being moved or reassigned.".to_string()),
}
);
2 changes: 1 addition & 1 deletion tests/expectations/compiler/futures/future_in_tuple.out
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function transfer_private_to_public:
finalize transfer_private_to_public:
assert.eq 1u8 1u8;
""", errors = "", warnings = "" },
{ initial_symbol_table = "baa9875274a09ad91eb08326f18797401a6e98c32388e75b3b406a539acab343", type_checked_symbol_table = "610cf3eeddc2789f19854347864fbaae2dc10ecc0aae8034fe3eaaa2394ba89f", unrolled_symbol_table = "610cf3eeddc2789f19854347864fbaae2dc10ecc0aae8034fe3eaaa2394ba89f", initial_ast = "fc9f1985c1e0441e9423e67cfd4cb8252178ccc236dfabae17187c5a5cc98ebe", unrolled_ast = "c6fdd37447ee674a058e7fe314096c0df8cf0c02f307ff499e0f08b76cdc6709", ssa_ast = "d26ea69b3993a2a3c4b2660a27706c51383f9b01357d27adf6275a5dfffe6e9d", flattened_ast = "5741efe1907a4da96fbad021b725a22e8c3365fa61b2413b06743c3ed01cda35", destructured_ast = "496bea9fd498c2d4ac9d93dd143beb403e13fdf59fc2ff842d8ff932883feda1", inlined_ast = "7c87cc964f8225fd91c634c8683ee0b09aaa301cb29ab85cadc4e4aea65253ba", dce_ast = "7c87cc964f8225fd91c634c8683ee0b09aaa301cb29ab85cadc4e4aea65253ba", bytecode = """
{ initial_symbol_table = "baa9875274a09ad91eb08326f18797401a6e98c32388e75b3b406a539acab343", type_checked_symbol_table = "6cc6e544cd0fac9b595d1236775033d8f492c506570b174bd4958280c45238fb", unrolled_symbol_table = "6cc6e544cd0fac9b595d1236775033d8f492c506570b174bd4958280c45238fb", initial_ast = "fc9f1985c1e0441e9423e67cfd4cb8252178ccc236dfabae17187c5a5cc98ebe", unrolled_ast = "c6fdd37447ee674a058e7fe314096c0df8cf0c02f307ff499e0f08b76cdc6709", ssa_ast = "d26ea69b3993a2a3c4b2660a27706c51383f9b01357d27adf6275a5dfffe6e9d", flattened_ast = "5741efe1907a4da96fbad021b725a22e8c3365fa61b2413b06743c3ed01cda35", destructured_ast = "496bea9fd498c2d4ac9d93dd143beb403e13fdf59fc2ff842d8ff932883feda1", inlined_ast = "7c87cc964f8225fd91c634c8683ee0b09aaa301cb29ab85cadc4e4aea65253ba", dce_ast = "7c87cc964f8225fd91c634c8683ee0b09aaa301cb29ab85cadc4e4aea65253ba", bytecode = """
import credits.aleo;
program test_credits.aleo;

Expand Down
11 changes: 11 additions & 0 deletions tests/expectations/compiler/futures/future_in_tuple_check_fail.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace = "Compile"
expectation = "Fail"
outputs = ["""
Error [ETYC0372104]: Not all futures were consumed: result2
--> compiler-test:9:27
|
9 | return (result.0, finish(result.1));
| ^^^^^^^^^^^^^^^^
|
= Make sure all futures are consumed exactly once. Consume by passing to an async function call.
"""]
Loading

0 comments on commit 2760c4c

Please sign in to comment.