From cad818673d9ebd35b7d2c1a61f1b58e72fc5d134 Mon Sep 17 00:00:00 2001 From: Edgar Date: Mon, 22 Jan 2024 15:47:56 -0300 Subject: [PATCH 1/9] struct parsing and type --- crates/concrete_ast/src/structs.rs | 1 + crates/concrete_ast/src/types.rs | 2 +- crates/concrete_codegen_mlir/src/codegen.rs | 114 +------- crates/concrete_codegen_mlir/src/lib.rs | 1 + .../src/scope_context.rs | 256 ++++++++++++++++++ crates/concrete_parser/src/grammar.lalrpop | 24 ++ examples/structs.con | 26 ++ 7 files changed, 315 insertions(+), 109 deletions(-) create mode 100644 crates/concrete_codegen_mlir/src/scope_context.rs create mode 100644 examples/structs.con diff --git a/crates/concrete_ast/src/structs.rs b/crates/concrete_ast/src/structs.rs index bd36ab6..9d134b0 100644 --- a/crates/concrete_ast/src/structs.rs +++ b/crates/concrete_ast/src/structs.rs @@ -5,6 +5,7 @@ use crate::{ #[derive(Clone, Debug, Eq, PartialEq)] pub struct StructDecl { + pub is_pub: bool, pub doc_string: Option, pub name: Ident, pub type_params: Vec, diff --git a/crates/concrete_ast/src/types.rs b/crates/concrete_ast/src/types.rs index 2ae7c87..93116a8 100644 --- a/crates/concrete_ast/src/types.rs +++ b/crates/concrete_ast/src/types.rs @@ -21,7 +21,7 @@ pub enum TypeSpec { }, Array { of_type: Box, - size: Option, + size: Option, is_ref: Option, span: Span, }, diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 0db203f..825ddfd 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -20,13 +20,15 @@ use melior::{ ir::{ attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType, MemRefType}, - Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, - ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike, }, Context as MeliorContext, }; -use crate::ast_helper::{AstHelper, ModuleInfo}; +use crate::{ + ast_helper::{AstHelper, ModuleInfo}, + scope_context::ScopeContext, +}; pub fn compile_program( session: &Session, @@ -71,14 +73,6 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { } } -#[derive(Debug, Clone)] -struct ScopeContext<'ctx, 'parent: 'ctx> { - pub locals: HashMap>, - pub function: Option, - pub imports: HashMap>, - pub module_info: &'parent ModuleInfo<'parent>, -} - struct BlockHelper<'ctx, 'region: 'ctx> { region: &'region Region<'ctx>, blocks_arena: &'region Bump, @@ -94,102 +88,6 @@ impl<'ctx, 'region> BlockHelper<'ctx, 'region> { } } -impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { - /// Returns the symbol name from a local name. - pub fn get_symbol_name(&self, local_name: &str) -> String { - if local_name == "main" { - return local_name.to_string(); - } - - if let Some(module) = self.imports.get(local_name) { - // a import - module.get_symbol_name(local_name) - } else { - let mut result = self.module_info.name.clone(); - - result.push_str("::"); - result.push_str(local_name); - - result - } - } - - pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> { - if let Some(module) = self.imports.get(local_name) { - // a import - module.functions.get(local_name).copied() - } else { - self.module_info.functions.get(local_name).copied() - } - } - - fn resolve_type( - &self, - context: &'ctx MeliorContext, - name: &str, - ) -> Result, Box> { - Ok(match name { - "u64" | "i64" => IntegerType::new(context, 64).into(), - "u32" | "i32" => IntegerType::new(context, 32).into(), - "u16" | "i16" => IntegerType::new(context, 16).into(), - "u8" | "i8" => IntegerType::new(context, 8).into(), - "f32" => Type::float32(context), - "f64" => Type::float64(context), - "bool" => IntegerType::new(context, 1).into(), - _ => todo!("custom type lookup"), - }) - } - - fn resolve_type_spec( - &self, - context: &'ctx MeliorContext, - spec: &TypeSpec, - ) -> Result, Box> { - match spec.is_ref() { - Some(_) => { - Ok( - MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None) - .into(), - ) - } - None => self.resolve_type_spec_ref(context, spec), - } - } - - /// Resolves the type this ref points to. - fn resolve_type_spec_ref( - &self, - context: &'ctx MeliorContext, - spec: &TypeSpec, - ) -> Result, Box> { - Ok(match spec { - TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?, - TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, - TypeSpec::Array { .. } => { - todo!("implement arrays") - } - }) - } - - fn is_type_signed(&self, type_info: &TypeSpec) -> bool { - let signed = ["i8", "i16", "i32", "i64", "i128"]; - match type_info { - TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Array { .. } => unreachable!(), - } - } - - fn is_float(&self, type_info: &TypeSpec) -> bool { - let signed = ["f32", "f64"]; - match type_info { - TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Array { .. } => unreachable!(), - } - } -} - fn compile_module( session: &Session, context: &MeliorContext, @@ -235,7 +133,7 @@ fn compile_module( let op = compile_function_def(session, context, &scope_ctx, info)?; body.append_operation(op); } - ModuleDefItem::Struct(_) => todo!(), + ModuleDefItem::Struct(_) => {} ModuleDefItem::Type(_) => todo!(), ModuleDefItem::Module(info) => { let module_info = module_info.modules.get(&info.name.name).unwrap_or_else(|| { diff --git a/crates/concrete_codegen_mlir/src/lib.rs b/crates/concrete_codegen_mlir/src/lib.rs index 372e73f..244da61 100644 --- a/crates/concrete_codegen_mlir/src/lib.rs +++ b/crates/concrete_codegen_mlir/src/lib.rs @@ -36,6 +36,7 @@ mod error; pub mod linker; mod module; mod pass_manager; +mod scope_context; /// Returns the object file path. pub fn compile(session: &Session, program: &Program) -> Result> { diff --git a/crates/concrete_codegen_mlir/src/scope_context.rs b/crates/concrete_codegen_mlir/src/scope_context.rs new file mode 100644 index 0000000..2ab15a0 --- /dev/null +++ b/crates/concrete_codegen_mlir/src/scope_context.rs @@ -0,0 +1,256 @@ +use std::{collections::HashMap, error::Error}; + +use concrete_ast::{functions::FunctionDef, structs::StructDecl, types::TypeSpec}; +use melior::{ + dialect::llvm, + ir::{ + r#type::{IntegerType, MemRefType}, + Type, + }, + Context as MeliorContext, +}; + +use crate::{ast_helper::ModuleInfo, codegen::LocalVar}; + +#[derive(Debug, Clone)] +pub struct ScopeContext<'ctx, 'parent: 'ctx> { + pub locals: HashMap>, + pub function: Option, + pub imports: HashMap>, + pub module_info: &'parent ModuleInfo<'parent>, +} + +impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { + /// Returns the symbol name from a local name. + pub fn get_symbol_name(&self, local_name: &str) -> String { + if local_name == "main" { + return local_name.to_string(); + } + + if let Some(module) = self.imports.get(local_name) { + // a import + module.get_symbol_name(local_name) + } else { + let mut result = self.module_info.name.clone(); + + result.push_str("::"); + result.push_str(local_name); + + result + } + } + + pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> { + if let Some(module) = self.imports.get(local_name) { + // a import + module.functions.get(local_name).copied() + } else { + self.module_info.functions.get(local_name).copied() + } + } + + /// Returns the size in bytes for a type. + fn get_type_size(&self, type_info: &TypeSpec) -> Result> { + Ok(match type_info { + TypeSpec::Simple { name, .. } => match name.name.as_str() { + "u128" | "i128" => 16, + "u64" | "i64" => 8, + "u32" | "i32" => 4, + "u16" | "i16" => 2, + "u8" | "i8" => 1, + "f64" => 8, + "f32" => 4, + "bool" => 1, + name => { + if let Some(x) = self.module_info.structs.get(name) { + let mut size = 0u32; + + for field in &x.fields { + let ty_size = self.get_type_size(&field.r#type)?; + let ty_align = self.get_align_for_size(ty_size); + + // Calculate padding needed. + let size_rounded_up = size.wrapping_add(ty_align).wrapping_sub(1) + & !ty_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + size += pad; + size += ty_size; + } + + let struct_align = self.get_align_for_size(size); + let size_rounded_up = size.wrapping_add(struct_align).wrapping_sub(1) + & !struct_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + size + pad + } else if let Some(module) = self.imports.get(name) { + // a import + self.get_type_size( + &module.types.get(name).expect("failed to find type").value, + )? + } else { + self.get_type_size( + &self + .module_info + .types + .get(name) + .expect("failed to find type") + .value, + )? + } + } + }, + TypeSpec::Generic { .. } => todo!(), + TypeSpec::Array { of_type, size, .. } => { + self.get_type_size(of_type)? * size.unwrap_or(1u32) + } + }) + } + + fn get_align_for_size(&self, size: u32) -> u32 { + if size <= 1 { + 1 + } else if size <= 2 { + 2 + } else if size <= 4 { + 4 + } else { + 8 + } + } + + fn get_struct_type( + &self, + context: &'ctx MeliorContext, + strct: &StructDecl, + ) -> Result, Box> { + let mut fields = Vec::with_capacity(strct.fields.len()); + + let mut size: u32 = 0; + + for field in &strct.fields { + let ty = self.resolve_type_spec(context, &field.r#type)?; + let ty_size = self.get_type_size(&field.r#type)?; + let ty_align = self.get_align_for_size(ty_size); + + // Calculate padding needed. + let size_rounded_up = + size.wrapping_add(ty_align).wrapping_sub(1) & !ty_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + if pad > 0 { + fields.push(llvm::r#type::array( + IntegerType::new(context, 8).into(), + pad, + )); + } + + size += pad; + size += ty_size; + fields.push(ty); + } + + let struct_align = self.get_align_for_size(size); + + // Calculate padding needed for whole struct. + let size_rounded_up = + size.wrapping_add(struct_align).wrapping_sub(1) & !struct_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + if pad > 0 { + fields.push(llvm::r#type::array( + IntegerType::new(context, 8).into(), + pad, + )); + } + + Ok(llvm::r#type::r#struct(context, &fields, false)) + } + + pub fn resolve_type( + &self, + context: &'ctx MeliorContext, + name: &str, + ) -> Result, Box> { + Ok(match name { + "u64" | "i64" => IntegerType::new(context, 64).into(), + "u32" | "i32" => IntegerType::new(context, 32).into(), + "u16" | "i16" => IntegerType::new(context, 16).into(), + "u8" | "i8" => IntegerType::new(context, 8).into(), + "f32" => Type::float32(context), + "f64" => Type::float64(context), + "bool" => IntegerType::new(context, 1).into(), + name => { + if let Some(strct) = self.module_info.structs.get(name) { + self.get_struct_type(context, strct)? + } else if let Some(module) = self.imports.get(name) { + // a import + self.resolve_type_spec( + context, + &module.types.get(name).expect("failed to find type").value, + )? + } else { + self.resolve_type_spec( + context, + &self + .module_info + .types + .get(name) + .expect("failed to find type") + .value, + )? + } + } + }) + } + + pub fn resolve_type_spec( + &self, + context: &'ctx MeliorContext, + spec: &TypeSpec, + ) -> Result, Box> { + match spec.is_ref() { + Some(_) => { + Ok( + MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None) + .into(), + ) + } + None => self.resolve_type_spec_ref(context, spec), + } + } + + /// Resolves the type this ref points to. + pub fn resolve_type_spec_ref( + &self, + context: &'ctx MeliorContext, + spec: &TypeSpec, + ) -> Result, Box> { + Ok(match spec { + TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?, + TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, + TypeSpec::Array { .. } => { + todo!("implement arrays") + } + }) + } + + pub fn is_type_signed(&self, type_info: &TypeSpec) -> bool { + let signed = ["i8", "i16", "i32", "i64", "i128"]; + match type_info { + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), + } + } + + pub fn is_float(&self, type_info: &TypeSpec) -> bool { + let signed = ["f32", "f64"]; + match type_info { + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), + } + } +} diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index 7bcda21..590f8eb 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -235,6 +235,9 @@ pub(crate) ModuleDefItem: ast::modules::ModuleDefItem = { => { ast::modules::ModuleDefItem::Module(<>) }, + => { + ast::modules::ModuleDefItem::Struct(<>) + }, } // Constants @@ -253,6 +256,27 @@ pub(crate) ConstantDef: ast::constants::ConstantDef = { }, } +// - Structs + +pub(crate) Field: ast::structs::Field = { + ":" => ast::structs::Field { + name, + r#type: type_spec, + doc_string: None, + } +} + +pub(crate) Struct: ast::structs::StructDecl = { + "struct" + "{" > "}" => ast::structs::StructDecl { + doc_string: None, + is_pub: is_pub.is_some(), + name, + fields: fields, + type_params: type_params.unwrap_or_else(Vec::new), + } +} + // -- Functions pub(crate) FunctionRetType: ast::types::TypeSpec = { diff --git a/examples/structs.con b/examples/structs.con new file mode 100644 index 0000000..e2cd359 --- /dev/null +++ b/examples/structs.con @@ -0,0 +1,26 @@ +mod Fibonacci { + + struct Node { + a: i32, + b: i32, + } + + struct Node2 { + a: i32, + b: i64, + } + + struct Nod3 { + a: i64, + b: i32, + } + + struct Node5 { + a: i8, + b: i32, + } + + fn main() -> i64 { + return 1; + } +} From b7f701465efa94585dd678bb9a987bd32ebd17f0 Mon Sep 17 00:00:00 2001 From: Edgar Date: Mon, 22 Jan 2024 16:43:55 -0300 Subject: [PATCH 2/9] progress --- crates/concrete_ast/src/expressions.rs | 3 +++ crates/concrete_ast/src/statements.rs | 19 +++++++++++++ crates/concrete_codegen_mlir/src/codegen.rs | 25 ++++++++++------- .../src/scope_context.rs | 12 ++++++--- crates/concrete_parser/src/grammar.lalrpop | 27 +++++++++++++++++-- examples/structs.con | 6 +++++ 6 files changed, 77 insertions(+), 15 deletions(-) diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index 12f8752..e016274 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -1,6 +1,9 @@ +use std::collections::HashMap; + use crate::{ common::Ident, statements::Statement, + structs::Field, types::{RefType, TypeSpec}, }; diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index 8076a7a..4ef2fe6 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -1,6 +1,7 @@ use crate::{ common::Ident, expressions::{Expression, FnCallOp, IfExpr, MatchExpr, PathOp}, + structs::Field, types::TypeSpec, }; @@ -26,6 +27,24 @@ pub enum LetStmtTarget { pub struct LetStmt { pub is_mutable: bool, pub target: LetStmtTarget, + pub value: LetValue, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum LetValue { + Expr(Expression), + StructConstruct(StructConstruct), +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct StructConstruct { + pub name: Ident, + pub fields: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct FieldConstruct { + pub name: Ident, pub value: Expression, } diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 825ddfd..6e7a817 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -7,7 +7,7 @@ use concrete_ast::{ }, functions::FunctionDef, modules::{Module, ModuleDefItem}, - statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement, WhileStmt}, + statements::{AssignStmt, LetStmt, LetStmtTarget, LetValue, ReturnStmt, Statement, WhileStmt}, types::TypeSpec, Program, }; @@ -430,15 +430,20 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( LetStmtTarget::Simple { name, r#type } => { let location = get_location(context, session, name.span.from); - let value = compile_expression( - session, - context, - scope_ctx, - helper, - block, - &info.value, - Some(r#type), - )?; + let value = match &info.value { + LetValue::Expr(value) => compile_expression( + session, + context, + scope_ctx, + helper, + block, + value, + Some(r#type), + )?, + LetValue::StructConstruct(_) => { + todo!() + } + }; let memref_type = MemRefType::new(value.r#type(), &[], None, None); let alloca: Value = block diff --git a/crates/concrete_codegen_mlir/src/scope_context.rs b/crates/concrete_codegen_mlir/src/scope_context.rs index 2ab15a0..876d35e 100644 --- a/crates/concrete_codegen_mlir/src/scope_context.rs +++ b/crates/concrete_codegen_mlir/src/scope_context.rs @@ -120,13 +120,15 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { } } + /// Returns the struct type along with the field indexes. fn get_struct_type( &self, context: &'ctx MeliorContext, strct: &StructDecl, - ) -> Result, Box> { + ) -> Result<(Type<'ctx>, HashMap), Box> { let mut fields = Vec::with_capacity(strct.fields.len()); + let mut field_indexes = HashMap::new(); let mut size: u32 = 0; for field in &strct.fields { @@ -148,6 +150,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { size += pad; size += ty_size; + field_indexes.insert(field.name.name.clone(), fields.len()); fields.push(ty); } @@ -165,7 +168,10 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { )); } - Ok(llvm::r#type::r#struct(context, &fields, false)) + Ok(( + llvm::r#type::r#struct(context, &fields, false), + field_indexes, + )) } pub fn resolve_type( @@ -183,7 +189,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { "bool" => IntegerType::new(context, 1).into(), name => { if let Some(strct) = self.module_info.structs.get(name) { - self.get_struct_type(context, strct)? + self.get_struct_type(context, strct)?.0 } else if let Some(module) = self.imports.get(name) { // a import self.resolve_type_spec( diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index 590f8eb..a6142ee 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -150,7 +150,7 @@ pub(crate) TypeSpec: ast::types::TypeSpec = { }, "[" )?> "]" => ast::types::TypeSpec::Array { of_type: Box::new(of_type), - size, + size: size.map(|x| x.try_into().expect("size is bigger than u32::MAX")), is_ref, span: Span::new(lo, hi), } @@ -480,8 +480,31 @@ pub(crate) LetStmt: ast::statements::LetStmt = { name, r#type: target_type }, - value + value: ast::statements::LetValue::Expr(value) }, + "let" ":" "=" => ast::statements::LetStmt { + is_mutable: is_mutable.is_some(), + target: ast::statements::LetStmtTarget::Simple { + name, + r#type: target_type + }, + value: ast::statements::LetValue::StructConstruct(value) + }, +} + +pub(crate) StructConstruct: ast::statements::StructConstruct = { + "{" > "}" => ast::statements::StructConstruct { + name, + fields, + } +} + + +pub(crate) FieldConstruct: ast::statements::FieldConstruct = { + ":" => ast::statements::FieldConstruct { + name, + value + } } pub(crate) AssignStmt: ast::statements::AssignStmt = { diff --git a/examples/structs.con b/examples/structs.con index e2cd359..218632f 100644 --- a/examples/structs.con +++ b/examples/structs.con @@ -21,6 +21,12 @@ mod Fibonacci { } fn main() -> i64 { + + let x: Node = Node { + a: 2, + b: 2, + }; + return 1; } } From 856ffc7d0567928b0cb8968170c6942b734d8a04 Mon Sep 17 00:00:00 2001 From: Edgar Date: Mon, 22 Jan 2024 17:50:21 -0300 Subject: [PATCH 3/9] works --- crates/concrete_ast/src/expressions.rs | 3 - crates/concrete_ast/src/statements.rs | 1 - crates/concrete_codegen_mlir/src/codegen.rs | 119 ++++++++++++++++-- crates/concrete_codegen_mlir/src/context.rs | 2 +- .../src/scope_context.rs | 2 +- crates/concrete_parser/src/grammar.lalrpop | 2 +- examples/structs.con | 29 ++--- 7 files changed, 122 insertions(+), 36 deletions(-) diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index e016274..12f8752 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -1,9 +1,6 @@ -use std::collections::HashMap; - use crate::{ common::Ident, statements::Statement, - structs::Field, types::{RefType, TypeSpec}, }; diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index 4ef2fe6..f316fc5 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -1,7 +1,6 @@ use crate::{ common::Ident, expressions::{Expression, FnCallOp, IfExpr, MatchExpr, PathOp}, - structs::Field, types::TypeSpec, }; diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 6e7a817..40db12f 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -3,7 +3,8 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ expressions::{ - ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, ValueExpr, + ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, PathSegment, + ValueExpr, }, functions::FunctionDef, modules::{Module, ModuleDefItem}, @@ -15,11 +16,14 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, func, memref, + cf, func, llvm, memref, }, ir::{ - attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, - r#type::{FunctionType, IntegerType, MemRefType}, + attribute::{ + DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, + TypeAttribute, + }, + r#type::{id, FunctionType, IntegerType, MemRefType}, Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike, }, Context as MeliorContext, @@ -440,8 +444,62 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( value, Some(r#type), )?, - LetValue::StructConstruct(_) => { - todo!() + LetValue::StructConstruct(struct_construct) => { + let struct_decl = scope_ctx + .module_info + .structs + .get(&struct_construct.name.name) + .unwrap_or_else(|| { + panic!("failed to find struct {:?}", struct_construct.name.name) + }); + assert_eq!( + struct_construct.fields.len(), + struct_decl.fields.len(), + "struct field len mismatch" + ); + let (ty, field_indexes) = scope_ctx.get_struct_type(context, struct_decl)?; + let mut struct_value = block + .append_operation(llvm::undef(ty, location)) + .result(0)? + .into(); + let field_types: HashMap = struct_decl + .fields + .iter() + .map(|x| (x.name.name.clone(), &x.r#type)) + .collect(); + + for field in &struct_construct.fields { + let field_idx = field_indexes.get(&field.name.name).unwrap_or_else(|| { + panic!( + "failed to find field {:?} for struct {:?}", + field.name.name, struct_construct.name.name + ) + }); + + let field_ty = field_types.get(&field.name.name).expect("field not found"); + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &field.value, + Some(field_ty), + )?; + + struct_value = block + .append_operation(llvm::insert_value( + context, + struct_value, + DenseI64ArrayAttribute::new(context, &[(*field_idx) as i64]), + value, + location, + )) + .result(0)? + .into(); + } + + struct_value } }; let memref_type = MemRefType::new(value.r#type(), &[], None, None); @@ -806,7 +864,7 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( let location = get_location(context, session, path.first.span.from); - let value = if local.alloca { + let mut value = if local.alloca { block .append_operation(memref::load(local.value, &[], location)) .result(0)? @@ -815,7 +873,52 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( local.value }; - Ok(value) + if path.extra.is_empty() { + Ok(value) + } else { + let mut current_type_spec = &local.type_spec; + + for extra in &path.extra { + match extra { + PathSegment::FieldAccess(ident) => { + let (struct_decl, (_, field_indexes)) = match current_type_spec { + TypeSpec::Simple { name, .. } => { + let struct_decl = + scope_ctx.module_info.structs.get(&name.name).unwrap(); + ( + struct_decl, + scope_ctx.get_struct_type(context, struct_decl)?, + ) + } + _ => unreachable!(), + }; + + let field = struct_decl + .fields + .iter() + .find(|x| x.name.name == ident.name) + .unwrap(); + let field_idx = *field_indexes.get(&ident.name).unwrap(); + let field_ty = scope_ctx.resolve_type_spec(context, &field.r#type)?; + + current_type_spec = &field.r#type; + value = block + .append_operation(llvm::extract_value( + context, + value, + DenseI64ArrayAttribute::new(context, &[field_idx as i64]), + field_ty, + location, + )) + .result(0)? + .into(); + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + + Ok(value) + } } fn compile_deref<'ctx, 'parent: 'ctx>( diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 70961fd..db65869 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -43,7 +43,7 @@ impl Context { super::codegen::compile_program(session, &self.melior_context, &melior_module, program)?; - let print_flags = OperationPrintingFlags::new().enable_debug_info(true, true); + let print_flags = OperationPrintingFlags::new().enable_debug_info(false, true); tracing::debug!( "MLIR Code before passes:\n{}", melior_module diff --git a/crates/concrete_codegen_mlir/src/scope_context.rs b/crates/concrete_codegen_mlir/src/scope_context.rs index 876d35e..5f97eb1 100644 --- a/crates/concrete_codegen_mlir/src/scope_context.rs +++ b/crates/concrete_codegen_mlir/src/scope_context.rs @@ -121,7 +121,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { } /// Returns the struct type along with the field indexes. - fn get_struct_type( + pub fn get_struct_type( &self, context: &'ctx MeliorContext, strct: &StructDecl, diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index a6142ee..ae623e4 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -272,7 +272,7 @@ pub(crate) Struct: ast::structs::StructDecl = { doc_string: None, is_pub: is_pub.is_some(), name, - fields: fields, + fields, type_params: type_params.unwrap_or_else(Vec::new), } } diff --git a/examples/structs.con b/examples/structs.con index 218632f..254509a 100644 --- a/examples/structs.con +++ b/examples/structs.con @@ -1,32 +1,19 @@ -mod Fibonacci { - +mod Structs { struct Node { a: i32, b: i32, } - struct Node2 { - a: i32, - b: i64, - } - - struct Nod3 { - a: i64, - b: i32, - } - - struct Node5 { - a: i8, - b: i32, + fn main() -> i32 { + let x: Node = create_node(2, 4); + return x.a + x.b; } - fn main() -> i64 { - + fn create_node(a: i32, b: i32) -> Node { let x: Node = Node { - a: 2, - b: 2, + a: a, + b: b, }; - - return 1; + return x; } } From 936a95510d031a02b3ca738792dc23e4bd92dbca Mon Sep 17 00:00:00 2001 From: Edgar Date: Mon, 22 Jan 2024 18:18:18 -0300 Subject: [PATCH 4/9] check --- crates/concrete_codegen_mlir/src/codegen.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 40db12f..c0991ec 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -23,7 +23,7 @@ use melior::{ DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute, }, - r#type::{id, FunctionType, IntegerType, MemRefType}, + r#type::{FunctionType, IntegerType, MemRefType}, Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike, }, Context as MeliorContext, From e805b6a566148838cd0662a4075fdea6ed53098e Mon Sep 17 00:00:00 2001 From: Edgar Date: Mon, 22 Jan 2024 19:02:52 -0300 Subject: [PATCH 5/9] struct assign --- crates/concrete_codegen_mlir/src/codegen.rs | 133 +++++++++++++++++--- examples/structs.con | 3 +- 2 files changed, 121 insertions(+), 15 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index c0991ec..2d20d5d 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -16,15 +16,18 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, func, llvm, memref, + cf, func, + llvm::{self, r#type::opaque_pointer, LoadStoreOptions}, + memref, }, ir::{ attribute::{ - DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, - TypeAttribute, + DenseI32ArrayAttribute, DenseI64ArrayAttribute, FlatSymbolRefAttribute, + IntegerAttribute, StringAttribute, TypeAttribute, }, r#type::{FunctionType, IntegerType, MemRefType}, - Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, + ValueLike, }, Context as MeliorContext, }; @@ -548,17 +551,119 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( let location = get_location(context, session, info.target.first.span.from); - let value = compile_expression( - session, - context, - scope_ctx, - helper, - block, - &info.value, - Some(&local.type_spec), - )?; + if info.target.extra.is_empty() { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(&local.type_spec), + )?; + + block.append_operation(memref::store(value, local.value, &[], location)); + } else { + let mut current_type_spec = &local.type_spec; + + // todo: instead of loading, use memref.extract_aligned_pointer_as_index + + let target_ptr = block + .append_operation( + melior::dialect::ods::memref::extract_aligned_pointer_as_index( + context, + Type::index(context), + local.value, + location, + ) + .into(), + ) + .result(0)? + .into(); + + let target_ptr = block + .append_operation(arith::index_cast( + target_ptr, + IntegerType::new(context, 64).into(), + location, + )) + .result(0)? + .into(); + + let mut target_ptr = block + .append_operation( + melior::dialect::ods::llvm::inttoptr( + context, + opaque_pointer(context), + target_ptr, + location, + ) + .into(), + ) + .result(0)? + .into(); + + let mut extra_it = info.target.extra.iter().peekable(); + + while let Some(extra) = extra_it.next() { + match extra { + PathSegment::FieldAccess(ident) => { + let (struct_decl, (struct_ty, field_indexes)) = match current_type_spec { + TypeSpec::Simple { name, .. } => { + let struct_decl = + scope_ctx.module_info.structs.get(&name.name).unwrap(); + ( + struct_decl, + scope_ctx.get_struct_type(context, struct_decl)?, + ) + } + _ => unreachable!(), + }; + + let field = struct_decl + .fields + .iter() + .find(|x| x.name.name == ident.name) + .unwrap(); + let field_idx = *field_indexes.get(&ident.name).unwrap(); - block.append_operation(memref::store(value, local.value, &[], location)); + current_type_spec = &field.r#type; + target_ptr = block + .append_operation(llvm::get_element_ptr( + context, + target_ptr, + DenseI32ArrayAttribute::new(context, &[field_idx as i32]), + struct_ty, + opaque_pointer(context), + location, + )) + .result(0)? + .into(); + + if extra_it.peek().is_none() { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(current_type_spec), + )?; + + block.append_operation(llvm::store( + context, + value, + target_ptr, + location, + LoadStoreOptions::default(), + )); + } + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + } Ok(()) } diff --git a/examples/structs.con b/examples/structs.con index 254509a..a31c069 100644 --- a/examples/structs.con +++ b/examples/structs.con @@ -5,7 +5,8 @@ mod Structs { } fn main() -> i32 { - let x: Node = create_node(2, 4); + let mut x: Node = create_node(2, 4); + x.a = 5; return x.a + x.b; } From b2b8bd9e99d7cc28c51e99020ffc26d7dda4c3f9 Mon Sep 17 00:00:00 2001 From: Edgar Date: Tue, 23 Jan 2024 14:43:58 -0300 Subject: [PATCH 6/9] progress --- crates/concrete_ast/src/statements.rs | 1 + crates/concrete_codegen_mlir/src/codegen.rs | 30 ++++++++++++++------- crates/concrete_parser/src/grammar.lalrpop | 7 ++--- crates/concrete_parser/src/lib.rs | 12 ++++++--- 4 files changed, 34 insertions(+), 16 deletions(-) diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index f316fc5..4c0a3cd 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -56,6 +56,7 @@ pub struct ReturnStmt { pub struct AssignStmt { pub target: PathOp, pub value: Expression, + pub is_deref: bool, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 2d20d5d..90b2151 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -9,7 +9,7 @@ use concrete_ast::{ functions::FunctionDef, modules::{Module, ModuleDefItem}, statements::{AssignStmt, LetStmt, LetStmtTarget, LetValue, ReturnStmt, Statement, WhileStmt}, - types::TypeSpec, + types::{RefType, TypeSpec}, Program, }; use concrete_session::Session; @@ -60,22 +60,25 @@ pub struct LocalVar<'ctx, 'parent: 'ctx> { // If it's none its on a register, otherwise allocated on the stack. pub alloca: bool, pub value: Value<'ctx, 'parent>, + pub is_mut: bool, } impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { - pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec, is_mut: bool) -> Self { Self { value, type_spec, alloca: false, + is_mut, } } - pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec, is_mut: bool) -> Self { Self { value, type_spec, alloca: true, + is_mut, } } } @@ -221,7 +224,7 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( for (i, param) in info.decl.params.iter().enumerate() { scope_ctx.locals.insert( param.name.name.clone(), - LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), + LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone(), false), ); } @@ -521,9 +524,10 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( block.append_operation(memref::store(value, alloca, &[], location)); - scope_ctx - .locals - .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); + scope_ctx.locals.insert( + name.name.clone(), + LocalVar::alloca(alloca, r#type.clone(), info.is_mutable), + ); Ok(()) } @@ -547,6 +551,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( .expect("local should exist") .clone(); + assert!(local.is_mut, "can only mutate mutable variables"); assert!(local.alloca, "can only mutate local stack variables"); let location = get_location(context, session, info.target.first.span.from); @@ -562,12 +567,17 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( Some(&local.type_spec), )?; - block.append_operation(memref::store(value, local.value, &[], location)); + match local.type_spec.is_ref() { + Some(RefType::MutBorrow) => {} + Some(RefType::Borrow) => {} + None => { + block.append_operation(memref::store(value, local.value, &[], location)); + } + } } else { let mut current_type_spec = &local.type_spec; - // todo: instead of loading, use memref.extract_aligned_pointer_as_index - + // get a ptr to the field let target_ptr = block .append_operation( melior::dialect::ods::memref::extract_aligned_pointer_as_index( diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index ae623e4..e8777b9 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -470,7 +470,7 @@ pub(crate) Statement: ast::statements::Statement = { ";" => ast::statements::Statement::Let(<>), ";" => ast::statements::Statement::Assign(<>), ";" => ast::statements::Statement::FnCall(<>), - ";"? => ast::statements::Statement::Return(<>), + ";" => ast::statements::Statement::Return(<>), } pub(crate) LetStmt: ast::statements::LetStmt = { @@ -508,9 +508,10 @@ pub(crate) FieldConstruct: ast::statements::FieldConstruct = { } pub(crate) AssignStmt: ast::statements::AssignStmt = { - "=" => ast::statements::AssignStmt { + "=" => ast::statements::AssignStmt { target, - value + value, + is_deref: is_deref.is_some(), }, } diff --git a/crates/concrete_parser/src/lib.rs b/crates/concrete_parser/src/lib.rs index d0b4e1b..61e1c89 100644 --- a/crates/concrete_parser/src/lib.rs +++ b/crates/concrete_parser/src/lib.rs @@ -59,7 +59,9 @@ mod ModuleName { x = x % 2; match x { - 0 -> return 2, + 0 -> { + return 2; + }, 1 -> { let y: u64 = x * 2; return y * 10; @@ -97,8 +99,12 @@ mod ModuleName { let source = r##"mod FactorialModule { pub fn factorial(x: u64) -> u64 { return match x { - 0 -> return 1, - n -> return n * factorial(n-1), + 0 -> { + return 1; + }, + n -> { + return n * factorial(n-1); + }, }; } }"##; From f22b1bccb14e46cdac89d5e2bf103e9cb9970fcd Mon Sep 17 00:00:00 2001 From: Edgar Date: Tue, 23 Jan 2024 18:09:07 -0300 Subject: [PATCH 7/9] fix --- crates/concrete_ast/src/types.rs | 40 ++++++++ crates/concrete_check/src/lib.rs | 4 +- crates/concrete_codegen_mlir/src/codegen.rs | 103 +++++++++++++++++--- crates/concrete_driver/tests/programs.rs | 70 +++++++++++++ examples/complex.con | 34 +++++++ examples/mutborrow.con | 11 +++ examples/structs.con | 7 +- 7 files changed, 251 insertions(+), 18 deletions(-) create mode 100644 examples/complex.con create mode 100644 examples/mutborrow.con diff --git a/crates/concrete_ast/src/types.rs b/crates/concrete_ast/src/types.rs index 1b0d386..970fb3b 100644 --- a/crates/concrete_ast/src/types.rs +++ b/crates/concrete_ast/src/types.rs @@ -36,6 +36,10 @@ impl TypeSpec { } } + pub fn is_mut_ref(&self) -> bool { + matches!(self.is_ref(), Some(RefType::MutBorrow)) + } + pub fn get_name(&self) -> String { match self { TypeSpec::Simple { name, .. } => name.name.clone(), @@ -43,6 +47,42 @@ impl TypeSpec { TypeSpec::Array { of_type, .. } => format!("[{}]", of_type.get_name()), } } + + pub fn to_nonref_type(&self) -> TypeSpec { + match self { + TypeSpec::Simple { + name, + is_ref: _, + span, + } => TypeSpec::Simple { + name: name.clone(), + is_ref: None, + span: *span, + }, + TypeSpec::Generic { + name, + is_ref: _, + type_params, + span, + } => TypeSpec::Generic { + name: name.clone(), + is_ref: None, + type_params: type_params.clone(), + span: *span, + }, + TypeSpec::Array { + of_type, + size, + is_ref: _, + span, + } => TypeSpec::Array { + of_type: of_type.clone(), + size: *size, + is_ref: None, + span: *span, + }, + } + } } #[derive(Clone, Debug, Eq, Hash, PartialEq)] diff --git a/crates/concrete_check/src/lib.rs b/crates/concrete_check/src/lib.rs index a5f02c7..61ca6a5 100644 --- a/crates/concrete_check/src/lib.rs +++ b/crates/concrete_check/src/lib.rs @@ -155,7 +155,9 @@ impl<'parent> ScopeContext<'parent> { "f64" => name.to_string(), "bool" => name.to_string(), name => { - if let Some(module) = self.imports.get(name) { + if let Some(x) = self.module_info.structs.get(name) { + name.to_string() + } else if let Some(module) = self.imports.get(name) { // a import self.resolve_type_spec(&module.types.get(name)?.value)? } else { diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 8e29aba..ba6f790 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -9,7 +9,7 @@ use concrete_ast::{ functions::FunctionDef, modules::{Module, ModuleDefItem}, statements::{AssignStmt, LetStmt, LetStmtTarget, LetValue, ReturnStmt, Statement, WhileStmt}, - types::{RefType, TypeSpec}, + types::TypeSpec, Program, }; use concrete_check::ast_helper::{AstHelper, ModuleInfo}; @@ -228,6 +228,10 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( fn_block = compile_statement(session, context, &mut scope_ctx, &helper, fn_block, stmt)?; } + + if fn_block.terminator().is_none() { + fn_block.append_operation(func::r#return(&[], location)); + } } let fn_name = scope_ctx.get_symbol_name(&info.decl.name.name); @@ -539,6 +543,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( block: &'parent Block<'ctx>, info: &AssignStmt, ) -> Result<(), Box> { + tracing::debug!("compiling assign for {:?}", info.target); // todo: implement properly for structs, right now only really works for simple variables. let local = scope_ctx @@ -547,12 +552,31 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( .expect("local should exist") .clone(); - assert!(local.is_mut, "can only mutate mutable variables"); - assert!(local.alloca, "can only mutate local stack variables"); + assert!( + local.is_mut || local.type_spec.is_mut_ref(), + "can only mutate mutable or ref mut variables" + ); + assert!( + local.type_spec.is_mut_ref() || local.alloca, + "can only mutate local stack variables" + ); let location = get_location(context, session, info.target.first.span.from); if info.target.extra.is_empty() { + let mut target_value = local.value; + let mut type_spec = local.type_spec.clone(); + + if info.is_deref { + assert!(local.type_spec.is_mut_ref(), "can only mutate mutable refs"); + if local.alloca { + target_value = block + .append_operation(memref::load(local.value, &[], location)) + .result(0)? + .into() + } + type_spec = type_spec.to_nonref_type(); + } let value = compile_expression( session, context, @@ -560,16 +584,10 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( helper, block, &info.value, - Some(&local.type_spec), + Some(&type_spec), )?; - match local.type_spec.is_ref() { - Some(RefType::MutBorrow) => {} - Some(RefType::Borrow) => {} - None => { - block.append_operation(memref::store(value, local.value, &[], location)); - } - } + block.append_operation(memref::store(value, target_value, &[], location)); } else { let mut current_type_spec = &local.type_spec; @@ -614,7 +632,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( while let Some(extra) = extra_it.next() { match extra { PathSegment::FieldAccess(ident) => { - let (struct_decl, (struct_ty, field_indexes)) = match current_type_spec { + let (struct_decl, (_, field_indexes)) = match current_type_spec { TypeSpec::Simple { name, .. } => { let struct_decl = scope_ctx.module_info.structs.get(&name.name).unwrap(); @@ -630,8 +648,9 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( .fields .iter() .find(|x| x.name.name == ident.name) - .unwrap(); + .unwrap_or_else(|| panic!("failed to find field {:?}", ident.name)); let field_idx = *field_indexes.get(&ident.name).unwrap(); + let field_ty = scope_ctx.resolve_type_spec(context, &field.r#type)?; current_type_spec = &field.r#type; target_ptr = block @@ -639,7 +658,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( context, target_ptr, DenseI32ArrayAttribute::new(context, &[field_idx as i32]), - struct_ty, + field_ty, opaque_pointer(context), location, )) @@ -717,7 +736,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>( session, context, scope_ctx, _helper, block, value, type_info, ), Expression::FnCall(value) => { - compile_fn_call(session, context, scope_ctx, _helper, block, value) + compile_fn_call_with_return(session, context, scope_ctx, _helper, block, value) } Expression::Match(_) => todo!(), Expression::If(_) => todo!(), @@ -906,6 +925,54 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>( _helper: &BlockHelper<'ctx, 'parent>, block: &'parent Block<'ctx>, info: &FnCallOp, +) -> Result<(), Box> { + tracing::debug!("compiling fncall: {:?}", info); + let mut args = Vec::with_capacity(info.args.len()); + let location = get_location(context, session, info.target.span.from); + + let target_fn = scope_ctx + .get_function(&info.target.name) + .expect("function not found") + .clone(); + + assert_eq!( + info.args.len(), + target_fn.decl.params.len(), + "parameter length doesnt match" + ); + + for (arg, arg_info) in info.args.iter().zip(&target_fn.decl.params) { + let value = compile_expression( + session, + context, + scope_ctx, + _helper, + block, + arg, + Some(&arg_info.r#type), + )?; + args.push(value); + } + + let fn_name = scope_ctx.get_symbol_name(&info.target.name); + + block.append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, &fn_name), + &args, + &[], + location, + )); + Ok(()) +} + +fn compile_fn_call_with_return<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, + info: &FnCallOp, ) -> Result, Box> { tracing::debug!("compiling fncall: {:?}", info); let mut args = Vec::with_capacity(info.args.len()); @@ -1061,6 +1128,8 @@ fn compile_deref<'ctx, 'parent: 'ctx>( .into(); } + // todo: handle deref for struct fields + Ok(value) } @@ -1083,5 +1152,7 @@ fn compile_asref<'ctx, 'parent: 'ctx>( panic!("can only take refs to non register values"); } - Ok(local.value) + // todo: handle asref for struct fields + + Ok(dbg!(local.value)) } diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index fecc2f3..2918bab 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -158,3 +158,73 @@ fn test_reference() { let code = output.status.code().unwrap(); assert_eq!(code, 2); } + +#[test] +fn test_mut_reference() { + let source = r#" + mod Simple { + fn main(argc: i64) -> i64 { + let mut x: i64 = 2; + change(&mut x); + return x; + } + + fn change(a: &mut i64) { + *a = 4; + } + } + "#; + + let result = compile_program(source, "mut_ref", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 4); +} + +#[test] +fn test_structs() { + let source = r#" + mod Structs { + + struct Leaf { + x: i32, + y: i64, + } + struct Node { + a: Leaf, + b: Leaf, + } + + fn main() -> i32 { + let a: Leaf = Leaf { + x: 1, + y: 2, + }; + let b: Leaf = Leaf { + x: 1, + y: 2, + }; + let mut x: Node = Node { + a: a, + b: b, + }; + x.a.x = 2; + modify(&mut x); + return x.a.x + x.b.x; + } + + fn modify(node: &mut Node) { + node.a.x = 1; + node.b.x = 1; + } + } + + "#; + + let result = compile_program(source, "structs", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 2); +} diff --git a/examples/complex.con b/examples/complex.con new file mode 100644 index 0000000..e982712 --- /dev/null +++ b/examples/complex.con @@ -0,0 +1,34 @@ +mod Structs { + + struct Leaf { + x: i32, + y: i64, + } + struct Node { + a: Leaf, + b: Leaf, + } + + fn main() -> i32 { + let a: Leaf = Leaf { + x: 1, + y: 2, + }; + let b: Leaf = Leaf { + x: 1, + y: 2, + }; + let mut x: Node = Node { + a: a, + b: b, + }; + x.a.x = 2; + modify(&mut x); + return x.a.x + x.b.x; + } + + fn modify(node: &mut Node) { + node.a.x = 3; + node.b.x = 3; + } +} diff --git a/examples/mutborrow.con b/examples/mutborrow.con new file mode 100644 index 0000000..75a4fb2 --- /dev/null +++ b/examples/mutborrow.con @@ -0,0 +1,11 @@ +mod Simple { + fn main(argc: i64) -> i64 { + let mut x: i64 = 2; + change(&mut x); + return x; + } + + fn change(a: &mut i64) { + *a = 4; + } +} diff --git a/examples/structs.con b/examples/structs.con index a31c069..95fa1d5 100644 --- a/examples/structs.con +++ b/examples/structs.con @@ -6,7 +6,8 @@ mod Structs { fn main() -> i32 { let mut x: Node = create_node(2, 4); - x.a = 5; + x.a = 100; + modify_node(&mut x); return x.a + x.b; } @@ -17,4 +18,8 @@ mod Structs { }; return x; } + + fn modify_node(x: &mut Node) { + x.a = 1; + } } From e94e2ee2d784c35e8e9d82c97f4bca79df58a4b5 Mon Sep 17 00:00:00 2001 From: Edgar Date: Tue, 23 Jan 2024 18:10:18 -0300 Subject: [PATCH 8/9] dbg --- crates/concrete_codegen_mlir/src/codegen.rs | 2 +- crates/concrete_driver/tests/programs.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index ba6f790..ab7cf29 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1154,5 +1154,5 @@ fn compile_asref<'ctx, 'parent: 'ctx>( // todo: handle asref for struct fields - Ok(dbg!(local.value)) + Ok(local.value) } diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index 2918bab..c41b6c2 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -215,8 +215,8 @@ fn test_structs() { } fn modify(node: &mut Node) { - node.a.x = 1; - node.b.x = 1; + node.a.x = 3; + node.b.x = 3; } } @@ -226,5 +226,5 @@ fn test_structs() { let output = run_program(&result.binary_file).expect("failed to run"); let code = output.status.code().unwrap(); - assert_eq!(code, 2); + assert_eq!(code, 6); } From 7b5151d1322419872074f6d9d0a3e95dc87d2faf Mon Sep 17 00:00:00 2001 From: Edgar Date: Wed, 24 Jan 2024 10:06:10 -0300 Subject: [PATCH 9/9] progress --- .../src/scope_context.rs | 38 ++++++++++++++++++- examples/aabb.con | 20 ++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 examples/aabb.con diff --git a/crates/concrete_codegen_mlir/src/scope_context.rs b/crates/concrete_codegen_mlir/src/scope_context.rs index a016af9..6d077c9 100644 --- a/crates/concrete_codegen_mlir/src/scope_context.rs +++ b/crates/concrete_codegen_mlir/src/scope_context.rs @@ -1,6 +1,11 @@ use std::{collections::HashMap, error::Error}; -use concrete_ast::{functions::FunctionDef, structs::StructDecl, types::TypeSpec}; +use concrete_ast::{ + expressions::{Expression, PathSegment, ValueExpr}, + functions::FunctionDef, + structs::StructDecl, + types::TypeSpec, +}; use concrete_check::ast_helper::ModuleInfo; use melior::{ dialect::llvm, @@ -260,4 +265,35 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { TypeSpec::Array { .. } => unreachable!(), } } + + pub fn get_expr_type(&self, exp: &Expression) -> Option { + match exp { + Expression::Value(value) => match value { + ValueExpr::Path(path) => { + let first = self.locals.get(&path.first.name)?; + + if path.extra.is_empty() { + Some(first.type_spec.clone()) + } else { + let mut current = &first.type_spec; + for extra in &path.extra { + match extra { + PathSegment::FieldAccess(ident) => { + let st = self.module_info.structs.get(&ident.name)?; + let field = st.fields.get(ident.name); + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + } + } + _ => None, + }, + Expression::FnCall(_) => todo!(), + Expression::Match(_) => todo!(), + Expression::If(_) => todo!(), + Expression::UnaryOp(_, _) => todo!(), + Expression::BinaryOp(_, _, _) => todo!(), + } + } } diff --git a/examples/aabb.con b/examples/aabb.con new file mode 100644 index 0000000..40762a2 --- /dev/null +++ b/examples/aabb.con @@ -0,0 +1,20 @@ +mod RectCheck { + struct Point2D { + x: i64, + y: i64, + } + + struct Rect2D { + pos: Point2D, + size: Point2D, + } + + pub fn is_point_inbounds(rect: &Rect2D, point: &Point2D) -> bool { + if point.x >= rect.pos.x && point.x <= (rect.pos.x + rect.size.x) + && point.y >= rect.pos.y && point.y <= (rect.pos.y + rect.size.y) { + return true; + } else { + return false; + } + } +}