Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix futures #28469

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion compiler/ast/src/passes/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,12 @@ pub trait ExpressionVisitor<'a> {
Default::default()
}

fn visit_struct_init(&mut self, _input: &'a StructExpression, _additional: &Self::AdditionalInput) -> Self::Output {
fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output {
for StructVariableInitializer { expression, .. } in input.members.iter() {
if let Some(expression) = expression {
self.visit_expression(expression, additional);
}
}
Default::default()
}

Expand Down
5 changes: 3 additions & 2 deletions compiler/parser/src/parser/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,10 @@ impl<N: Network> ParserContext<'_, N> {
let name = self.expect_identifier()?;
self.expect(&Token::Colon)?;

let type_ = self.parse_type()?.0;
let (type_, type_span) = self.parse_type()?;
let span = name.span() + type_span;

Ok(functions::Input { identifier: name, mode, type_, span: name.span, id: self.node_builder.next_id() })
Ok(functions::Input { identifier: name, mode, type_, span, id: self.node_builder.next_id() })
}

/// Returns an [`Output`] AST node if the next tokens represent a function output.
Expand Down
4 changes: 4 additions & 0 deletions compiler/passes/src/static_analysis/analyze_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ impl<'a, N: Network> ProgramVisitor<'a> for StaticAnalyzer<'a, N> {
// Set `non_async_external_call_seen` to false.
self.non_async_external_call_seen = false;

if matches!(self.variant, Some(Variant::AsyncFunction) | Some(Variant::AsyncTransition)) {
super::future_checker::future_check_function(function, self.type_table, self.handler);
}

// If the function is an async function, initialize the await checker.
if self.variant == Some(Variant::AsyncFunction) {
// Initialize the list of input futures. Each one must be awaited before the end of the function.
Expand Down
167 changes: 167 additions & 0 deletions compiler/passes/src/static_analysis/future_checker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::TypeTable;

use leo_ast::{CoreFunction, Expression, ExpressionVisitor, Function, Node, StatementVisitor, Type};
use leo_errors::{StaticAnalyzerError, emitter::Handler};

/// Error if futures are used improperly.
///
/// This prevents, for instance, a bare call which creates an unused future.
pub fn future_check_function(function: &Function, type_table: &TypeTable, handler: &Handler) {
let mut future_checker = FutureChecker { type_table, handler };
future_checker.visit_block(&function.block);
}

#[derive(Clone, Copy, Debug, Default)]
enum Position {
#[default]
Misc,
Await,
TupleAccess,
Return,
FunctionArgument,
LastTupleLiteral,
Definition,
}

struct FutureChecker<'a> {
type_table: &'a TypeTable,
handler: &'a Handler,
}

impl<'a> FutureChecker<'a> {
fn emit_err(&self, err: StaticAnalyzerError) {
self.handler.emit_err(err);
}
}

impl<'a> ExpressionVisitor<'a> for FutureChecker<'a> {
type AdditionalInput = Position;
type Output = ();

fn visit_expression(&mut self, input: &'a Expression, additional: &Self::AdditionalInput) -> Self::Output {
use Position::*;
let is_call = matches!(input, Expression::Call(..));
match self.type_table.get(&input.id()) {
Some(Type::Future(..)) if is_call => {
// A call producing a Future may appear in any of these positions.
if !matches!(additional, Await | Return | FunctionArgument | LastTupleLiteral | Definition) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Future(..)) => {
// A Future expression that's not a call may appear in any of these positions.
if !matches!(additional, Await | Return | FunctionArgument | LastTupleLiteral | TupleAccess) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Tuple(tuple)) if !matches!(tuple.elements().last(), Some(Type::Future(_))) => {}
Some(Type::Tuple(..)) if is_call => {
// A call producing a Tuple ending in a Future may appear in any of these positions.
if !matches!(additional, Return | Definition) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
Some(Type::Tuple(..)) => {
// A Tuple ending in a Future that's not a call may appear in any of these positions.
if !matches!(additional, Return | TupleAccess) {
self.emit_err(StaticAnalyzerError::misplaced_future(input.span()));
}
}
_ => {}
}

match input {
Expression::Access(access) => self.visit_access(access, &Position::Misc),
Expression::Array(array) => self.visit_array(array, &Position::Misc),
Expression::Binary(binary) => self.visit_binary(binary, &Position::Misc),
Expression::Call(call) => self.visit_call(call, &Position::Misc),
Expression::Cast(cast) => self.visit_cast(cast, &Position::Misc),
Expression::Struct(struct_) => self.visit_struct_init(struct_, &Position::Misc),
Expression::Err(err) => self.visit_err(err, &Position::Misc),
Expression::Identifier(identifier) => self.visit_identifier(identifier, &Position::Misc),
Expression::Literal(literal) => self.visit_literal(literal, &Position::Misc),
Expression::Locator(locator) => self.visit_locator(locator, &Position::Misc),
Expression::Ternary(ternary) => self.visit_ternary(ternary, &Position::Misc),
Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
Expression::Unary(unary) => self.visit_unary(unary, &Position::Misc),
Expression::Unit(unit) => self.visit_unit(unit, &Position::Misc),
}
}

fn visit_access(
&mut self,
input: &'a leo_ast::AccessExpression,
_additional: &Self::AdditionalInput,
) -> Self::Output {
match input {
leo_ast::AccessExpression::Array(array) => {
self.visit_expression(&array.array, &Position::Misc);
self.visit_expression(&array.index, &Position::Misc);
}
leo_ast::AccessExpression::AssociatedFunction(function) => {
let core_function = CoreFunction::from_symbols(function.variant.name, function.name.name)
.expect("Typechecking guarantees that this function exists.");
let position =
if core_function == CoreFunction::FutureAwait { Position::Await } else { Position::Misc };
function.arguments.iter().for_each(|arg| {
self.visit_expression(arg, &position);
});
}
leo_ast::AccessExpression::Member(member) => {
self.visit_expression(&member.inner, &Position::Misc);
}
leo_ast::AccessExpression::Tuple(tuple) => {
self.visit_expression(&tuple.tuple, &Position::TupleAccess);
}
_ => {}
}

Default::default()
}

fn visit_call(&mut self, input: &'a leo_ast::CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
input.arguments.iter().for_each(|expr| {
self.visit_expression(expr, &Position::FunctionArgument);
});
Default::default()
}

fn visit_tuple(&mut self, input: &'a leo_ast::TupleExpression, additional: &Self::AdditionalInput) -> Self::Output {
let next_position = match additional {
Position::Definition | Position::Return => Position::LastTupleLiteral,
_ => Position::Misc,
};
let mut iter = input.elements.iter().peekable();
while let Some(expr) = iter.next() {
let position = if iter.peek().is_some() { &Position::Misc } else { &next_position };
self.visit_expression(expr, position);
}
Default::default()
}
}

impl<'a> StatementVisitor<'a> for FutureChecker<'a> {
fn visit_definition(&mut self, input: &'a leo_ast::DefinitionStatement) {
self.visit_expression(&input.value, &Position::Definition);
}

fn visit_return(&mut self, input: &'a leo_ast::ReturnStatement) {
self.visit_expression(&input.expression, &Position::Return);
}
}
2 changes: 2 additions & 0 deletions compiler/passes/src/static_analysis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

mod future_checker;

mod await_checker;

pub mod analyze_expression;
Expand Down
42 changes: 24 additions & 18 deletions compiler/passes/src/type_checking/check_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
if let Some(expected) = expected {
self.check_eq_types(&Some(actual.clone()), &Some(expected.clone()), access.span());
}

// Return type of tuple index.
return Some(actual);
}
Expand Down Expand Up @@ -671,6 +670,30 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
let ty = self.visit_expression(argument, &Some(expected.type_().clone()))?;
// Extract information about futures that are being consumed.
if func.variant == Variant::AsyncFunction && matches!(expected.type_(), Type::Future(_)) {
// Consume the future.
let option_name = match argument {
Expression::Identifier(id) => Some(id.name),
Expression::Access(AccessExpression::Tuple(tuple_access)) => {
if let Expression::Identifier(id) = &*tuple_access.tuple {
Some(id.name)
} else {
None
}
}
_ => None,
};

if let Some(name) = option_name {
match self.scope_state.futures.shift_remove(&name) {
Some(future) => {
self.scope_state.call_location = Some(future.clone());
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(name, argument.span()));
}
}
}

match argument {
Expression::Identifier(_)
| Expression::Call(_)
Expand Down Expand Up @@ -853,23 +876,6 @@ impl<'a, N: Network> ExpressionVisitor<'a> for TypeChecker<'a, N> {
fn visit_identifier(&mut self, input: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output {
let var = self.symbol_table.borrow().lookup_variable(Location::new(None, input.name)).cloned();
if let Some(var) = &var {
if matches!(var.type_, Type::Future(_)) && matches!(expected, Some(Type::Future(_))) {
if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.is_call {
// Consume future.
match self.scope_state.futures.shift_remove(&input.name) {
Some(future) => {
self.scope_state.call_location = Some(future.clone());
return Some(var.type_.clone());
}
None => {
self.emit_err(TypeCheckerError::unknown_future_consumed(input.name, input.span));
}
}
} else {
// Case where accessing input argument of future. Ex `f.1`.
return Some(var.type_.clone());
}
}
Some(self.assert_and_return_type(var.type_.clone(), expected, input.span()))
} else {
self.emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()));
Expand Down
16 changes: 9 additions & 7 deletions compiler/passes/src/type_checking/check_statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ use leo_ast::{
};
use leo_errors::TypeCheckerError;

use itertools::Itertools;

impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
fn visit_statement(&mut self, input: &'a Statement) {
// No statements can follow a return statement.
Expand Down Expand Up @@ -248,7 +246,7 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
// Insert the variables into the symbol table.
match &input.place {
Expression::Identifier(identifier) => {
self.insert_variable(inferred_type.clone(), identifier, input.type_.clone(), 0, identifier.span)
self.insert_variable(inferred_type.clone(), identifier, input.type_.clone(), identifier.span)
}
Expression::Tuple(tuple_expression) => {
let tuple_type = match &input.type_ {
Expand All @@ -265,17 +263,21 @@ impl<'a, N: Network> StatementVisitor<'a> for TypeChecker<'a, N> {
));
}

for ((index, expr), type_) in
tuple_expression.elements.iter().enumerate().zip_eq(tuple_type.elements().iter())
{
for i in 0..tuple_expression.elements.len() {
let inferred = if let Some(Type::Tuple(inferred_tuple)) = &inferred_type {
inferred_tuple.elements().get(i).cloned()
} else {
None
};
let expr = &tuple_expression.elements[i];
let identifier = match expr {
Expression::Identifier(identifier) => identifier,
_ => {
return self
.emit_err(TypeCheckerError::lhs_tuple_element_must_be_an_identifier(expr.span()));
}
};
self.insert_variable(inferred_type.clone(), identifier, type_.clone(), index, identifier.span);
self.insert_variable(inferred, identifier, tuple_type.elements()[i].clone(), identifier.span);
}
}
_ => self.emit_err(TypeCheckerError::lhs_must_be_identifier_or_tuple(input.place.span())),
Expand Down
56 changes: 29 additions & 27 deletions compiler/passes/src/type_checking/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1220,10 +1220,21 @@ impl<'a, N: Network> TypeChecker<'a, N> {
}
}

// Add function inputs to the symbol table. Futures have already been added.
if !matches!(&input_var.type_(), &Type::Future(_)) {
if matches!(&input_var.type_(), Type::Future(_)) {
// Future parameters may only appear in async functions.
if !matches!(self.scope_state.variant, Some(Variant::AsyncFunction)) {
self.emit_err(TypeCheckerError::no_future_parameters(input_var.span()));
}
}

let location = Location::new(None, input_var.identifier().name);
if !matches!(&input_var.type_(), Type::Future(_))
|| self.symbol_table.borrow().lookup_variable_in_current_scope(location.clone()).is_none()
{
// Add function inputs to the symbol table. If inference happened properly above, futures were already added.
// But if a future was not added, add it now so as not to give confusing error messages.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(
Location::new(None, input_var.identifier().name),
location.clone(),
self.scope_state.program_name,
VariableSymbol {
type_: input_var.type_().clone(),
Expand Down Expand Up @@ -1302,32 +1313,23 @@ impl<'a, N: Network> TypeChecker<'a, N> {
}

/// Inserts variable to symbol table.
pub(crate) fn insert_variable(
&mut self,
inferred_type: Option<Type>,
name: &Identifier,
type_: Type,
index: usize,
span: Span,
) {
let ty: Type = if let Type::Future(_) = type_ {
// Need to insert the fully inferred future type, or else will just be default future type.
let ret = match inferred_type.unwrap() {
Type::Future(future) => Type::Future(future),
Type::Tuple(tuple) => match tuple.elements().get(index) {
Some(Type::Future(future)) => Type::Future(future.clone()),
_ => unreachable!("Parsing guarantees that the inferred type is a future."),
},
_ => {
unreachable!("TYC guarantees that the inferred type is a future, or tuple containing futures.")
}
};
// Insert future into list of futures for the function.
pub(crate) fn insert_variable(&mut self, inferred_type: Option<Type>, name: &Identifier, type_: Type, span: Span) {
let is_future = match &type_ {
Type::Future(..) => true,
Type::Tuple(tuple_type) if matches!(tuple_type.elements().last(), Some(Type::Future(..))) => true,
_ => false,
};

if is_future {
self.scope_state.futures.insert(name.name, self.scope_state.call_location.clone().unwrap());
ret
} else {
type_
}

let ty = match (is_future, inferred_type) {
(false, _) => type_,
(true, Some(inferred)) => inferred,
(true, None) => unreachable!("Type checking guarantees the inferred type is present"),
};

// Insert the variable into the symbol table.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(
Location::new(None, name.name),
Expand Down
Loading
Loading