From 07448425a0e3c946dacf9d71a49de69f8cc30f58 Mon Sep 17 00:00:00 2001 From: JJ Date: Thu, 13 Apr 2023 15:09:36 -0700 Subject: split bidirectional checking and conversion + defaults out of simple.rs --- src/ast.rs | 55 ++++++++++++++++ src/bidirectional.rs | 120 +++++++++++++++++++++++++++++++++++ src/lib.rs | 3 +- src/main.rs | 10 +-- src/simple.rs | 167 ------------------------------------------------- tests/test_checking.rs | 2 +- 6 files changed, 184 insertions(+), 173 deletions(-) create mode 100644 src/bidirectional.rs diff --git a/src/ast.rs b/src/ast.rs index 0a0c042..492f8af 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -119,3 +119,58 @@ impl fmt::Display for Term { // return Ok(()); // } // } + +/// Convert a term into its corresponding type. +pub fn convert(term: &Term) -> Result { + match term { + Term::Unit() => Ok(Type::Unit), + Term::Boolean(_) => Ok(Type::Boolean), + Term::Natural(_) => Ok(Type::Natural), + Term::Integer(_) => Ok(Type::Integer), + Term::Float(_) => Ok(Type::Float), + Term::String { len, cap, data } => Ok(Type::String), + Term::Enum { val, data } => data.get(*val) + .ok_or_else(|| "enum value out of range!".to_string()).cloned(), + Term::Record(data) => { + let mut result = HashMap::new(); + for (key, val) in data { + result.insert(key.clone(), convert(val)?); + } + return Ok(Type::Record(result)); + }, + Term::Function(func) => match *func.clone() { + Expression::Annotation { expr, kind } => match kind { + Type::Function { from, to } => Ok(Type::Function { from, to }), + _ => Err("function term value not a function!".to_string()) + } + _ => Err("function term value does not have an annotation!".to_string()) + } + } +} + +/// Get the default value of a type. Throws an error if it doesn't exist. +pub fn default(kind: &Type) -> Result { + match kind { + Type::Empty => Err("attempting to take the default term for empty".to_string()), + Type::Error => Err("attempting to take the default term for error".to_string()), + Type::Unit => Ok(Term::Unit()), + Type::Boolean => Ok(Term::Boolean(false)), + Type::Natural => Ok(Term::Natural(0)), + Type::Integer => Ok(Term::Integer(0)), + Type::Float => Ok(Term::Float(0.0)), + Type::String => Ok(Term::String { len: 0, cap: 0, data: vec!()}), + Type::Enum(data) => match data.len() { + 0 => Err("attempting to get a default term of an enum with no variants!".to_string()), + _ => Ok(Term::Enum { val: 0, data: data.clone() }) + }, + Type::Record(data) => { + let mut result = HashMap::new(); + for (key, val) in data { + result.insert(key.clone(), default(&val)?); + } + return Ok(Term::Record(result)); + }, + Type::Function { from, to } => + Err("attempting to take the default term of a function type".to_string()), + } +} diff --git a/src/bidirectional.rs b/src/bidirectional.rs new file mode 100644 index 0000000..493ae2f --- /dev/null +++ b/src/bidirectional.rs @@ -0,0 +1,120 @@ +// Simple bidirectional type checking + +use crate::ast::*; + +/// Checking judgement: takes an expression and a type to check against and calls out to `infer` as needed. +pub fn check(context: &Context, expression: Expression, target: &Type) -> Result<(), String> { + match expression { + // fall through to inference mode + Expression::Annotation { expr, kind } => { + let result = infer(context, Expression::Annotation { expr, kind })?; + return match subtype(&result, &target) { + true => Ok(()), + false => Err(format!("inferred type {result} does not match target {target}")) + } + }, + // Bt-CheckInfer + Expression::Constant { term } => match subtype(&convert(&term)?, &target) { + true => Ok(()), + false => Err(format!("constant is of wrong type, expected {target}")) + // false => Ok(()) // all our constants are Empty for now + }, + // Bt-CheckInfer + Expression::Variable { id } => match context.get(&id) { + Some(term) if subtype(&convert(term)?, &target) => Ok(()), + Some(_) => Err(format!("variable {id} is of wrong type")), + None => Err(format!("failed to find variable {id} in context")) + }, + // Bt-Abs + Expression::Abstraction { param, func } => match target { + Type::Function { from, to } => { + let mut context = context.clone(); + context.insert(param, default(from)?); + return check(&context, *func, &to); + }, + _ => Err(format!("attempting to check an abstraction with a non-function type {target}")) + }, + // fall through to inference mode + Expression::Application { func, arg } => { + let result = &infer(context, Expression::Application { func, arg })?; + return match subtype(&result, &target) { + true => Ok(()), + false => Err(format!("inferred type {result} does not match {target}")) + } + }, + // T-If + Expression::Conditional { if_cond, if_then, if_else } => { + check(context, *if_cond, &Type::Boolean)?; + check(context, *if_then, &target)?; + check(context, *if_else, &target)?; + return Ok(()); + } + } +} + +/// Inference judgement: takes an expression and attempts to infer the associated type. +pub fn infer(context: &Context, expression: Expression) -> Result { + match expression { + // Bt-Ann + Expression::Annotation { expr, kind } => check(context, *expr, &kind).map(|x| kind), + // Bt-True / Bt-False / etc + Expression::Constant { term } => convert(&term), + // Bt-Var + Expression::Variable { id } => match context.get(&id) { + Some(term) => infer(&Context::new(), Expression::Constant { term: term.clone() }), + None => Err(format!("failed to find variable in context {context:?}")) + }, + // Bt-App + Expression::Application { func, arg } => match infer(context, *func)? { + Type::Function { from, to } => check(context, *arg, &*from).map(|x| *to), + _ => Err(format!("application abstraction is not a function type")) + }, + // inference from an abstraction is always an error + // we could try and infer the func without adding the parameter to scope: + // but this is overwhelmingly likely to be an error, so just report it now. + Expression::Abstraction { param, func } => + Err(format!("attempting to infer from an abstraction")), + // idk + Expression::Conditional { if_cond, if_then, if_else } => { + check(context, *if_cond, &Type::Boolean)?; + let if_then = infer(context, *if_then)?; + let if_else = infer(context, *if_else)?; + if subtype(&if_then, &if_else) && subtype(&if_else, &if_then) { + Ok(if_then) // fixme: should be the join + } else { + Err(format!("if clauses of different types: {if_then} and {if_else}")) + } + } + } +} + +/// The subtyping relation between any two types. +pub fn subtype(is: &Type, of: &Type) -> bool { + match (is, of) { + (Type::Record(is_fields), Type::Record(of_fields)) => { + // width, depth, and permutation + for (key, of_value) in of_fields { + match is_fields.get(key) { + Some(is_value) => { + if !subtype(is_value, of_value) { + return false; + } + } + None => return false + } + } + return true; + }, + (Type::Enum(is_variants), Type::Enum(of_variants)) => false, // fixme + (Type::Function { from: is_from, to: is_to }, + Type::Function { from: of_from, to: of_to }) => { + subtype(of_from, is_from) && subtype(is_to, of_to) + }, + (Type::Natural, Type::Integer) => true, // obviously not, but let's pretend + (_, Type::Empty) => true, + (Type::Error, _) => true, + (_, _) if is == of => true, + (_, _) => false + } +} + diff --git a/src/lib.rs b/src/lib.rs index dce0be7..617f158 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ -#![allow(unused_variables)] +#![allow(unused_variables, non_upper_case_globals)] pub mod ast; +pub mod bidirectional; // pub mod classes; // pub mod effects; pub mod parser; diff --git a/src/main.rs b/src/main.rs index 4504d87..a33e963 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,8 @@ use std::io::{Write, stdout, stdin}; use chrysanthemum::*; use chrysanthemum::ast::*; +use chrysanthemum::bidirectional::*; +use chrysanthemum::simple::*; fn main() { println!("chrysanthemum"); @@ -22,7 +24,7 @@ fn main() { input.clear(); stdin().read_line(&mut input).unwrap(); - match simple::infer(&empty_context, parser::parse_lambda(&input).unwrap()) { + match infer(&empty_context, parser::parse_lambda(&input).unwrap()) { Ok(kind) => println!("infers! {}", kind), Err(e) => println!("{:?}", e), } @@ -34,10 +36,10 @@ fn main() { input.clear(); stdin().read_line(&mut input).unwrap(); - let kind = simple::infer(&empty_context, parser::parse(&input)); + let kind = infer(&empty_context, parser::parse_lambda(&input).unwrap()); match kind { Ok(kind) => { - match simple::check(&empty_context, parser::parse_lambda(&input).unwrap(), &kind) { + match check(&empty_context, parser::parse_lambda(&input).unwrap(), &kind) { Ok(_) => println!("checks!"), Err(e) => println!("{:?}", e), } @@ -52,7 +54,7 @@ fn main() { input.clear(); stdin().read_line(&mut input).unwrap(); - match simple::execute(&empty_context, parser::parse_lambda(&input).unwrap()) { + match execute(&empty_context, parser::parse_lambda(&input).unwrap()) { Ok(term) => println!("{}", term), Err(e) => println!("{:?}", e) } diff --git a/src/simple.rs b/src/simple.rs index c67fe2e..8d87dbf 100644 --- a/src/simple.rs +++ b/src/simple.rs @@ -1,91 +1,4 @@ -// Simple bidirectional type checking - use crate::ast::*; -use std::collections::HashMap; - -pub fn check(context: &Context, expression: Expression, target: &Type) -> Result<(), String> { - match expression { - // fall through to inference mode - Expression::Annotation { expr, kind } => { - let result = infer(context, Expression::Annotation { expr, kind })?; - return match subtype(&result, &target) { - true => Ok(()), - false => Err(format!("inferred type {result} does not match target {target}")) - } - }, - // Bt-CheckInfer - Expression::Constant { term } => match subtype(&convert(&term)?, &target) { - true => Ok(()), - false => Err(format!("constant is of wrong type, expected {target}")) - // false => Ok(()) // all our constants are Empty for now - }, - // Bt-CheckInfer - Expression::Variable { id } => match context.get(&id) { - Some(term) if subtype(&convert(term)?, &target) => Ok(()), - Some(_) => Err(format!("variable {id} is of wrong type")), - None => Err(format!("failed to find variable {id} in context")) - }, - // Bt-Abs - Expression::Abstraction { param, func } => match target { - Type::Function { from, to } => { - let mut context = context.clone(); - context.insert(param, default(from)?); - return check(&context, *func, &to); - }, - _ => Err(format!("attempting to check an abstraction with a non-function type {target}")) - }, - // fall through to inference mode - Expression::Application { func, arg } => { - let result = &infer(context, Expression::Application { func, arg })?; - return match subtype(&result, &target) { - true => Ok(()), - false => Err(format!("inferred type {result} does not match {target}")) - } - }, - // T-If - Expression::Conditional { if_cond, if_then, if_else } => { - check(context, *if_cond, &Type::Boolean)?; - check(context, *if_then, &target)?; - check(context, *if_else, &target)?; - return Ok(()); - } - } -} - -pub fn infer(context: &Context, expression: Expression) -> Result { - match expression { - // Bt-Ann - Expression::Annotation { expr, kind } => check(context, *expr, &kind).map(|x| kind), - // Bt-True / Bt-False / etc - Expression::Constant { term } => convert(&term), - // Bt-Var - Expression::Variable { id } => match context.get(&id) { - Some(term) => infer(&Context::new(), Expression::Constant { term: term.clone() }), - None => Err(format!("failed to find variable in context {context:?}")) - }, - // Bt-App - Expression::Application { func, arg } => match infer(context, *func)? { - Type::Function { from, to } => check(context, *arg, &*from).map(|x| *to), - _ => Err(format!("application abstraction is not a function type")) - }, - // inference from an abstraction is always an error - // we could try and infer the func without adding the parameter to scope: - // but this is overwhelmingly likely to be an error, so just report it now. - Expression::Abstraction { param, func } => - Err(format!("attempting to infer from an abstraction")), - // idk - Expression::Conditional { if_cond, if_then, if_else } => { - check(context, *if_cond, &Type::Boolean)?; - let if_then = infer(context, *if_then)?; - let if_else = infer(context, *if_else)?; - if subtype(&if_then, &if_else) && subtype(&if_else, &if_then) { - Ok(if_then) // fixme: should be the join - } else { - Err(format!("if clauses of different types: {if_then} and {if_else}")) - } - } - } -} /// Evaluates an expression given a context (of variables) to a term, or fails. pub fn execute(context: &Context, expression: Expression) -> Result { @@ -117,83 +30,3 @@ pub fn execute(context: &Context, expression: Expression) -> Result bool { - match (is, of) { - (Type::Record(is_fields), Type::Record(of_fields)) => { - // width, depth, and permutation - for (key, of_value) in of_fields { - match is_fields.get(key) { - Some(is_value) => { - if !subtype(is_value, of_value) { - return false; - } - } - None => return false - } - } - return true; - }, - (Type::Function { from: is_from, to: is_to }, - Type::Function { from: of_from, to: of_to }) => { - subtype(of_from, is_from) && subtype(is_to, of_to) - }, - (Type::Natural, Type::Integer) => true, // obviously not, but let's pretend - (_, Type::Empty) => true, - (Type::Error, _) => true, - (_, _) if is == of => true, - (_, _) => false - } -} - -/// Convert a term into its corresponding type. -pub fn convert(term: &Term) -> Result { - match term { - Term::Unit() => Ok(Type::Unit), - Term::Boolean(_) => Ok(Type::Boolean), - Term::Natural(_) => Ok(Type::Natural), - Term::Integer(_) => Ok(Type::Integer), - Term::Float(_) => Ok(Type::Float), - Term::String { len, cap, data } => Ok(Type::String), - Term::Enum { val, data } => data.get(*val) - .ok_or_else(|| "enum value out of range!".to_string()).cloned(), - Term::Record(data) => { - let mut result = HashMap::new(); - for (key, val) in data { - result.insert(key.clone(), convert(val)?); - } - return Ok(Type::Record(result)); - }, - Term::Function(func) => match infer(&Context::new(), *func.clone()) { - Ok(Type::Function { from, to }) => Ok(Type::Function { from, to }), - _ => Err("function term value not a function!".to_string()) - } - } -} - -/// Get the default value of a type. Throws an error if it doesn't exist. -pub fn default(kind: &Type) -> Result { - match kind { - Type::Empty => Err("attempting to take the default term for empty".to_string()), - Type::Error => Err("attempting to take the default term for error".to_string()), - Type::Unit => Ok(Term::Unit()), - Type::Boolean => Ok(Term::Boolean(false)), - Type::Natural => Ok(Term::Natural(0)), - Type::Integer => Ok(Term::Integer(0)), - Type::Float => Ok(Term::Float(0.0)), - Type::String => Ok(Term::String { len: 0, cap: 0, data: vec!()}), - Type::Enum(data) => match data.len() { - 0 => Err("attempting to get a default term of an enum with no variants!".to_string()), - _ => Ok(Term::Enum { val: 0, data: data.clone() }) - }, - Type::Record(data) => { - let mut result = HashMap::new(); - for (key, val) in data { - result.insert(key.clone(), default(&val)?); - } - return Ok(Term::Record(result)); - }, - Type::Function { from, to } => - Err("attempting to take the default term of a function type".to_string()), - } -} diff --git a/tests/test_checking.rs b/tests/test_checking.rs index 286283b..e553a4e 100644 --- a/tests/test_checking.rs +++ b/tests/test_checking.rs @@ -1,8 +1,8 @@ #![allow(non_upper_case_globals)] use chrysanthemum::ast::*; +use chrysanthemum::bidirectional::*; use chrysanthemum::parser::*; -use chrysanthemum::simple::*; use chrysanthemum::util::*; // rust you KNOW these are &'static strs -- cgit v1.2.3-70-g09d2