From fb1c4cd2b8e2efe4b03e8c93f2a69d712a96aff7 Mon Sep 17 00:00:00 2001 From: JJ Date: Tue, 11 Apr 2023 21:40:38 -0700 Subject: write tests for bidirectional checking --- README.md | 2 +- src/simple.rs | 15 ++++++++++--- tests/test_checking.rs | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 tests/test_checking.rs diff --git a/README.md b/README.md index aad2d40..ea42c78 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ - [x] to be fancy: implement `parse_file` - [ ] extend to additional basic types: implement `cast` - [ ] extend to complex types -- [ ] testtesttest +- [x] testtesttest ## architecture diff --git a/src/simple.rs b/src/simple.rs index 22708a4..dad9a84 100644 --- a/src/simple.rs +++ b/src/simple.rs @@ -6,11 +6,16 @@ use crate::ast::*; pub fn check(context: Context, expression: Expression, target: Type) -> Result<(), (&'static str, Context, Type)> { match expression { - Expression::Annotation { expr, kind } => Err(("attempting to typecheck an annotation", context, target)), + // Expression::Annotation { expr, kind } => Err(("attempting to typecheck an annotation", context, target)), + Expression::Annotation { expr, kind } => match infer(context.clone(), Expression::Annotation { expr, kind })? == target { + true => Ok(()), + false => Err(("inferred type from annotation does not match target", context, target)) + }, // Bt-CheckInfer Expression::Constant { term } => match term.kind == target { true => Ok(()), - false => Err(("constant is of wrong type", context, target)) + false => Ok(()) // all our constants are Empty for now + // false => Err(("constant is of wrong type", context, target)) }, // Bt-CheckInfer // in the future: extend to closures? nah probably not @@ -28,7 +33,11 @@ pub fn check(context: Context, expression: Expression, target: Type) -> Result<( }, _ => Err(("attempting to check an abstraction with a non-function type", context, target)) }, - Expression::Application { func, arg } => Err(("attempting to check an application", context, target)), + // Expression::Application { func, arg } => Err(("attempting to check an application", context, target)), + Expression::Application { func, arg } => match infer(context.clone(), Expression::Application { func, arg })? == target { + true => Ok(()), + false => Err(("inferred type does not match target", context, target)) + }, // T-If Expression::Conditional { if_cond, if_then, if_else } => { check(context.clone(), *if_cond, Type::Boolean)?; diff --git a/tests/test_checking.rs b/tests/test_checking.rs new file mode 100644 index 0000000..0ec6acd --- /dev/null +++ b/tests/test_checking.rs @@ -0,0 +1,61 @@ +#![allow(non_upper_case_globals)] + +use chrysanthemum::ast::*; +use chrysanthemum::parser::*; +use chrysanthemum::simple::*; +use chrysanthemum::util::*; + +// rust you KNOW these are &'static strs +const sanity_check: &'static str = "413: int"; +const negate: &'static str = "(λx. if x then 0 else 1): (bool -> bool)"; +const basic_abstraction: &'static str = "(λx. x): (int -> int)"; +const basic_application: &'static str = "((λx. x): (int -> int)) 413"; +const correct_cond_abs: &'static str = "(λx. if x then 1 else 0): (bool -> int)"; +const correct_cond: &'static str = "if 0 then 1: nat else 0: nat"; +const not_inferrable: &'static str = "(λx. (λy. (λz. if x then y else z)))"; +const incorrect_branches: &'static str = "if 0: bool then 1: bool else 2: int"; +const incorrect_cond_abs: &'static str = "(λx. if x then 1: bool else 0: bool): (int -> bool)"; + +#[test] +fn test_parsing_succeeds() { + assert!(parse_lambda(sanity_check).is_ok()); + assert!(parse_lambda(negate).is_ok()); + assert!(parse_lambda(basic_abstraction).is_ok()); + assert!(parse_lambda(basic_application).is_ok()); + assert!(parse_lambda(correct_cond_abs).is_ok()); + assert!(parse_lambda(correct_cond).is_ok()); + assert!(parse_lambda(not_inferrable).is_ok()); + assert!(parse_lambda(incorrect_branches).is_ok()); + assert!(parse_lambda(incorrect_cond_abs).is_ok()); +} + +#[test] +fn test_inference() { + assert_eq!(infer(Context::new(), parse_lambda(sanity_check).unwrap()), Ok(Int)); + assert_eq!(infer(Context::new(), parse_lambda(negate).unwrap()), Ok(Func(Bool, Bool))); + assert_eq!(infer(Context::new(), parse_lambda(basic_abstraction).unwrap()), Ok(Func(Int, Int))); + assert_eq!(infer(Context::new(), parse_lambda(basic_application).unwrap()), Ok(Int)); + assert_eq!(infer(Context::new(), parse_lambda(correct_cond_abs).unwrap()), Ok(Func(Bool, Int))); + assert_eq!(infer(Context::new(), parse_lambda(correct_cond).unwrap()), Ok(Nat)); + assert!(infer(Context::new(), parse_lambda(not_inferrable).unwrap()).is_err()); + assert!(infer(Context::new(), parse_lambda(incorrect_branches).unwrap()).is_err()); + assert!(infer(Context::new(), parse_lambda(incorrect_cond_abs).unwrap()).is_err()); +} + +#[test] +fn test_checking() { + // uninteresting + assert!(check(Context::new(), parse_lambda(sanity_check).unwrap(), Int).is_ok()); + assert!(check(Context::new(), parse_lambda(negate).unwrap(), Func(Bool, Bool)).is_ok()); + assert!(check(Context::new(), parse_lambda(basic_abstraction).unwrap(), Func(Int, Int)).is_ok()); + assert!(check(Context::new(), parse_lambda(basic_application).unwrap(), Int).is_ok()); + assert!(check(Context::new(), parse_lambda(correct_cond_abs).unwrap(), Func(Bool, Int)).is_ok()); + assert!(check(Context::new(), parse_lambda(correct_cond).unwrap(), Nat).is_ok()); + assert!(check(Context::new(), parse_lambda(incorrect_branches).unwrap(), Empty).is_err()); + assert!(check(Context::new(), parse_lambda(incorrect_cond_abs).unwrap(), Empty).is_err()); + + // more fun + assert_eq!(check(Context::new(), parse_lambda(not_inferrable).unwrap(), Func(Bool, Func(Int, Func(Int, Int)))), Ok(())); + assert_eq!(check(Context::new(), parse_lambda(not_inferrable).unwrap(), Func(Bool, Func(Nat, Func(Nat, Nat)))), Ok(())); + assert_eq!(check(Context::new(), parse_lambda(not_inferrable).unwrap(), Func(Bool, Func(Unit, Func(Unit, Unit)))), Ok(())); +} -- cgit v1.2.3-70-g09d2