Skip to content

Commit

Permalink
Merge pull request #28469 from ProvableHQ/fix-futures
Browse files Browse the repository at this point in the history
Fix futures
  • Loading branch information
d0cd authored Dec 16, 2024
2 parents b90f958 + 7b88c3e commit 402a8e4
Show file tree
Hide file tree
Showing 48 changed files with 514 additions and 115 deletions.
7 changes: 6 additions & 1 deletion compiler/ast/src/passes/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ pub trait ExpressionVisitor<'a> {
Default::default()
}

fn visit_struct_init(&mut self, _input: &'a StructExpression, _additional: &Self::AdditionalInput) -> Self::Output {
fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output {
for StructVariableInitializer { expression, .. } in input.members.iter() {
if let Some(expression) = expression {
self.visit_expression(expression, additional);
}
}
Default::default()
}

Expand Down
5 changes: 3 additions & 2 deletions compiler/parser/src/parser/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,10 @@ impl<N: Network> ParserContext<'_, N> {
let name = self.expect_identifier()?;
self.expect(&Token::Colon)?;

let type_ = self.parse_type()?.0;
let (type_, type_span) = self.parse_type()?;
let span = name.span() + type_span;

Ok(functions::Input { identifier: name, mode, type_, span: name.span, id: self.node_builder.next_id() })
Ok(functions::Input { identifier: name, mode, type_, span, id: self.node_builder.next_id() })
}

/// Returns an [`Output`] AST node if the next tokens represent a function output.
Expand Down
4 changes: 4 additions & 0 deletions compiler/passes/src/static_analysis/analyze_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ impl<'a, N: Network> ProgramVisitor<'a> for StaticAnalyzer<'a, N> {
// Set `non_async_external_call_seen` to false.
self.non_async_external_call_seen = false;

if matches!(self.variant, Some(Variant::AsyncFunction) | Some(Variant::AsyncTransition)) {
super::future_checker::future_check_function(function, self.type_table, self.handler);
}

// If the function is an async function, initialize the await checker.
if self.variant == Some(Variant::AsyncFunction) {
// Initialize the list of input futures. Each one must be awaited before the end of the function.
Expand Down
167 changes: 167 additions & 0 deletions compiler/passes/src/static_analysis/future_checker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::TypeTable;

use leo_ast::{CoreFunction, Expression, ExpressionVisitor, Function, Node, StatementVisitor, Type};
use leo_errors::{StaticAnalyzerError, emitter::Handler};

/// Error if futures are used improperly.
///
/// This prevents, for instance, a bare call which creates an unused future.
pub fn future_check_function(function: &Function, type_table: &TypeTable, handler: &Handler) {
let mut future_checker = FutureChecker { type_table, handler };
future_checker.visit_block(&function.block);
}

#[derive(Clone, Copy, Debug, Default)]
enum Position {
#[default]
Misc,
Await,
TupleAccess,
Return,
FunctionArgument,
LastTupleLiteral,
Definition,
}

struct FutureChecker<'a> {
type_table: &'a TypeTable,
handler: &'a Handler,
}

impl<'a> FutureChecker<'a> {
fn emit_err(&self, err: StaticAnalyzerError) {
self.handler.emit_err(err);
}
}

impl<'a> ExpressionVisitor<'a> for FutureChecker<'a> {
type AdditionalInput = Position;
type Output = ();

fn visit_expression(&mut self, input: &'a Expression, additional: &Self::AdditionalInput) -> Self::Output {
use Position::*;
let is_call = matches!(input, Expression::Call(..));
match self.type_table.get(&input.id()) {
Some(Type::Future(..)) if is_call => {
// A call producing a Future may appear in any of these positions.
if !matches!(additional, Await | Return | FunctionArgument | LastTupleLiteral | Definition) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Future(..)) => {
// A Future expression that's not a call may appear in any of these positions.
if !matches!(additional, Await | Return | FunctionArgument | LastTupleLiteral | TupleAccess) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Tuple(tuple)) if !matches!(tuple.elements().last(), Some(Type::Future(_))) => {}
Some(Type::Tuple(..)) if is_call => {
// A call producing a Tuple ending in a Future may appear in any of these positions.
if !matches!(additional, Return | Definition) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Tuple(..)) => {
// A Tuple ending in a Future that's not a call may appear in any of these positions.
if !matches!(additional, Return | TupleAccess) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
_ => {}
}

match input {
Expression::Access(access) => self.visit_access(access, &Position::Misc),
Expression::Array(array) => self.visit_array(array, &Position::Misc),
Expression::Binary(binary) => self.visit_binary(binary, &Position::Misc),
Expression::Call(call) => self.visit_call(call, &Position::Misc),
Expression::Cast(cast) => self.visit_cast(cast, &Position::Misc),
Expression::Struct(struct_) => self.visit_struct_init(struct_, &Position::Misc),
Expression::Err(err) => self.visit_err(err, &Position::Misc),
Expression::Identifier(identifier) => self.visit_identifier(identifier, &Position::Misc),
Expression::Literal(literal) => self.visit_literal(literal, &Position::Misc),
Expression::Locator(locator) => self.visit_locator(locator, &Position::Misc),
Expression::Ternary(ternary) => self.visit_ternary(ternary, &Position::Misc),
Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
Expression::Unary(unary) => self.visit_unary(unary, &Position::Misc),
Expression::Unit(unit) => self.visit_unit(unit, &Position::Misc),
}
}

fn visit_access(
&mut self,
input: &'a leo_ast::AccessExpression,
_additional: &Self::AdditionalInput,
) -> Self::Output {
match input {
leo_ast::AccessExpression::Array(array) => {
self.visit_expression(&array.array, &Position::Misc);
self.visit_expression(&array.index, &Position::Misc);
}
leo_ast::AccessExpression::AssociatedFunction(function) => {
let core_function = CoreFunction::from_symbols(function.variant.name, function.name.name)
.expect("Typechecking guarantees that this function exists.");
let position =
if core_function == CoreFunction::FutureAwait { Position::Await } else { Position::Misc };
function.arguments.iter().for_each(|arg| {
self.visit_expression(arg, &position);
});
}
leo_ast::AccessExpression::Member(member) => {
self.visit_expression(&member.inner, &Position::Misc);
}
leo_ast::AccessExpression::Tuple(tuple) => {
self.visit_expression(&tuple.tuple, &Position::TupleAccess);
}
_ => {}
}

Default::default()
}

fn visit_call(&mut self, input: &'a leo_ast::CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
input.arguments.iter().for_each(|expr| {
self.visit_expression(expr, &Position::FunctionArgument);
});
Default::default()
}

fn visit_tuple(&mut self, input: &'a leo_ast::TupleExpression, additional: &Self::AdditionalInput) -> Self::Output {
let next_position = match additional {
Position::Definition | Position::Return => Position::LastTupleLiteral,
_ => Position::Misc,
};
let mut iter = input.elements.iter().peekable();
while let Some(expr) = iter.next() {
let position = if iter.peek().is_some() { &Position::Misc } else { &next_position };
self.visit_expression(expr, position);
}
Default::default()
}
}

impl<'a> StatementVisitor<'a> for FutureChecker<'a> {
fn visit_definition(&mut self, input: &'a leo_ast::DefinitionStatement) {
self.visit_expression(&input.value, &Position::Definition);
}

fn visit_return(&mut self, input: &'a leo_ast::ReturnStatement) {
self.visit_expression(&input.expression, &Position::Return);
}
}
2 changes: 2 additions & 0 deletions compiler/passes/src/static_analysis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

mod future_checker;

mod await_checker;

pub mod analyze_expression;
Expand Down
42 changes: 24 additions & 18 deletions compiler/passes/src/type_checking/check_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
if let Some(expected) = expected {
self.check_eq_types(&Some(actual.clone()), &Some(expected.clone()), access.span());
}

// Return type of tuple index.
return Some(actual);
}
Expand Down Expand Up @@ -675,6 +674,30 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
let ty = self.visit_expression(argument, &Some(expected.type_().clone()))?;
// Extract information about futures that are being consumed.
if func.variant == Variant::AsyncFunction && matches!(expected.type_(), Type::Future(_)) {
// Consume the future.
let option_name = match argument {
Expression::Identifier(id) => Some(id.name),
Expression::Access(AccessExpression::Tuple(tuple_access)) => {
if let Expression::Identifier(id) = &*tuple_access.tuple {
Some(id.name)
} else {
None
}
}
_ => None,
};

if let Some(name) = option_name {
match self.scope_state.futures.shift_remove(&name) {
Some(future) => {
self.scope_state.call_location = Some(future.clone());
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(name, argument.span()));
}
}
}

match argument {
Expression::Identifier(_)
| Expression::Call(_)
Expand Down Expand Up @@ -857,23 +880,6 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
fn visit_identifier(&mut self, input: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output {
let var = self.symbol_table.borrow().lookup_variable(Location::new(None, input.name)).cloned();
if let Some(var) = &var {
if matches!(var.type_, Type::Future(_)) && matches!(expected, Some(Type::Future(_))) {
if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.is_call {
// Consume future.
match self.scope_state.futures.shift_remove(&input.name) {
Some(future) => {
self.scope_state.call_location = Some(future.clone());
return Some(var.type_.clone());
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(input.name, input.span));
}
}
} else {
// Case where accessing input argument of future. Ex `f.1`.
return Some(var.type_.clone());
}
}
Some(self.assert_and_return_type(var.type_.clone(), expected, input.span()))
} else {
self.emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()));
Expand Down
16 changes: 9 additions & 7 deletions compiler/passes/src/type_checking/check_statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ use leo_ast::{
};
use leo_errors::TypeCheckerError;

use itertools::Itertools;

impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
fn visit_statement(&mut self, input: &'a Statement) {
// No statements can follow a return statement.
Expand Down Expand Up @@ -248,7 +246,7 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
// Insert the variables into the symbol table.
match &input.place {
Expression::Identifier(identifier) => {
self.insert_variable(inferred_type.clone(), identifier, input.type_.clone(), 0, identifier.span)
self.insert_variable(inferred_type.clone(), identifier, input.type_.clone(), identifier.span)
}
Expression::Tuple(tuple_expression) => {
let tuple_type = match &input.type_ {
Expand All @@ -265,17 +263,21 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
));
}

for ((index, expr), type_) in
tuple_expression.elements.iter().enumerate().zip_eq(tuple_type.elements().iter())
{
for i in 0..tuple_expression.elements.len() {
let inferred = if let Some(Type::Tuple(inferred_tuple)) = &inferred_type {
inferred_tuple.elements().get(i).cloned()
} else {
None
};
let expr = &tuple_expression.elements[i];
let identifier = match expr {
Expression::Identifier(identifier) => identifier,
_ => {
return self
.emit_err(TypeCheckerError::lhs_tuple_element_must_be_an_identifier(expr.span()));
}
};
self.insert_variable(inferred_type.clone(), identifier, type_.clone(), index, identifier.span);
self.insert_variable(inferred, identifier, tuple_type.elements()[i].clone(), identifier.span);
}
}
_ => self.emit_err(TypeCheckerError::lhs_must_be_identifier_or_tuple(input.place.span())),
Expand Down
56 changes: 29 additions & 27 deletions compiler/passes/src/type_checking/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1220,10 +1220,21 @@ impl<'a, N: Network> TypeChecker<'a, N> {
}
}

// Add function inputs to the symbol table. Futures have already been added.
if !matches!(&input_var.type_(), &Type::Future(_)) {
if matches!(&input_var.type_(), Type::Future(_)) {
// Future parameters may only appear in async functions.
if !matches!(self.scope_state.variant, Some(Variant::AsyncFunction)) {
self.emit_err(TypeCheckerError::no_future_parameters(input_var.span()));
}
}

let location = Location::new(None, input_var.identifier().name);
if !matches!(&input_var.type_(), Type::Future(_))
|| self.symbol_table.borrow().lookup_variable_in_current_scope(location.clone()).is_none()
{
// Add function inputs to the symbol table. If inference happened properly above, futures were already added.
// But if a future was not added, add it now so as not to give confusing error messages.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(
Location::new(None, input_var.identifier().name),
location.clone(),
self.scope_state.program_name,
VariableSymbol {
type_: input_var.type_().clone(),
Expand Down Expand Up @@ -1302,32 +1313,23 @@ impl<'a, N: Network> TypeChecker<'a, N> {
}

/// Inserts variable to symbol table.
pub(crate) fn insert_variable(
&mut self,
inferred_type: Option<Type>,
name: &Identifier,
type_: Type,
index: usize,
span: Span,
) {
let ty: Type = if let Type::Future(_) = type_ {
// Need to insert the fully inferred future type, or else will just be default future type.
let ret = match inferred_type.unwrap() {
Type::Future(future) => Type::Future(future),
Type::Tuple(tuple) => match tuple.elements().get(index) {
Some(Type::Future(future)) => Type::Future(future.clone()),
_ => unreachable!("Parsing guarantees that the inferred type is a future."),
},
_ => {
unreachable!("TYC guarantees that the inferred type is a future, or tuple containing futures.")
}
};
// Insert future into list of futures for the function.
pub(crate) fn insert_variable(&mut self, inferred_type: Option<Type>, name: &Identifier, type_: Type, span: Span) {
let is_future = match &type_ {
Type::Future(..) => true,
Type::Tuple(tuple_type) if matches!(tuple_type.elements().last(), Some(Type::Future(..))) => true,
_ => false,
};

if is_future {
self.scope_state.futures.insert(name.name, self.scope_state.call_location.clone().unwrap());
ret
} else {
type_
}

let ty = match (is_future, inferred_type) {
(false, _) => type_,
(true, Some(inferred)) => inferred,
(true, None) => unreachable!("Type checking guarantees the inferred type is present"),
};

// Insert the variable into the symbol table.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(
Location::new(None, name.name),
Expand Down
Loading

0 comments on commit 402a8e4

Please sign in to comment.