aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJJ2023-04-12 04:40:38 +0000
committerJJ2023-04-12 04:40:38 +0000
commitfb1c4cd2b8e2efe4b03e8c93f2a69d712a96aff7 (patch)
tree60c3586142dc9efe74a2e9715e3f10ee5d0886cc
parentadeaac24d9519454028c3a7bd3787cbb59a4ed14 (diff)
write tests for bidirectional checking
-rw-r--r--README.md2
-rw-r--r--src/simple.rs15
-rw-r--r--tests/test_checking.rs61
3 files changed, 74 insertions, 4 deletions
diff --git a/README.md b/README.md
index aad2d40..ea42c78 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
- [x] to be fancy: implement `parse_file`
- [ ] extend to additional basic types: implement `cast`
- [ ] extend to complex types
-- [ ] testtesttest
+- [x] testtesttest
## architecture
diff --git a/src/simple.rs b/src/simple.rs
index 22708a4..dad9a84 100644
--- a/src/simple.rs
+++ b/src/simple.rs
@@ -6,11 +6,16 @@ use crate::ast::*;
pub fn check(context: Context, expression: Expression, target: Type) -> Result<(), (&'static str, Context, Type)> {
match expression {
- Expression::Annotation { expr, kind } => Err(("attempting to typecheck an annotation", context, target)),
+ // Expression::Annotation { expr, kind } => Err(("attempting to typecheck an annotation", context, target)),
+ Expression::Annotation { expr, kind } => match infer(context.clone(), Expression::Annotation { expr, kind })? == target {
+ true => Ok(()),
+ false => Err(("inferred type from annotation does not match target", context, target))
+ },
// Bt-CheckInfer
Expression::Constant { term } => match term.kind == target {
true => Ok(()),
- false => Err(("constant is of wrong type", context, target))
+ false => Ok(()) // all our constants are Empty for now
+ // false => Err(("constant is of wrong type", context, target))
},
// Bt-CheckInfer
// in the future: extend to closures? nah probably not
@@ -28,7 +33,11 @@ pub fn check(context: Context, expression: Expression, target: Type) -> Result<(
},
_ => Err(("attempting to check an abstraction with a non-function type", context, target))
},
- Expression::Application { func, arg } => Err(("attempting to check an application", context, target)),
+ // Expression::Application { func, arg } => Err(("attempting to check an application", context, target)),
+ Expression::Application { func, arg } => match infer(context.clone(), Expression::Application { func, arg })? == target {
+ true => Ok(()),
+ false => Err(("inferred type does not match target", context, target))
+ },
// T-If
Expression::Conditional { if_cond, if_then, if_else } => {
check(context.clone(), *if_cond, Type::Boolean)?;
diff --git a/tests/test_checking.rs b/tests/test_checking.rs
new file mode 100644
index 0000000..0ec6acd
--- /dev/null
+++ b/tests/test_checking.rs
@@ -0,0 +1,61 @@
+#![allow(non_upper_case_globals)]
+
+use chrysanthemum::ast::*;
+use chrysanthemum::parser::*;
+use chrysanthemum::simple::*;
+use chrysanthemum::util::*;
+
+// rust you KNOW these are &'static strs
+const sanity_check: &'static str = "413: int";
+const negate: &'static str = "(λx. if x then 0 else 1): (bool -> bool)";
+const basic_abstraction: &'static str = "(λx. x): (int -> int)";
+const basic_application: &'static str = "((λx. x): (int -> int)) 413";
+const correct_cond_abs: &'static str = "(λx. if x then 1 else 0): (bool -> int)";
+const correct_cond: &'static str = "if 0 then 1: nat else 0: nat";
+const not_inferrable: &'static str = "(λx. (λy. (λz. if x then y else z)))";
+const incorrect_branches: &'static str = "if 0: bool then 1: bool else 2: int";
+const incorrect_cond_abs: &'static str = "(λx. if x then 1: bool else 0: bool): (int -> bool)";
+
+#[test]
+fn test_parsing_succeeds() {
+ assert!(parse_lambda(sanity_check).is_ok());
+ assert!(parse_lambda(negate).is_ok());
+ assert!(parse_lambda(basic_abstraction).is_ok());
+ assert!(parse_lambda(basic_application).is_ok());
+ assert!(parse_lambda(correct_cond_abs).is_ok());
+ assert!(parse_lambda(correct_cond).is_ok());
+ assert!(parse_lambda(not_inferrable).is_ok());
+ assert!(parse_lambda(incorrect_branches).is_ok());
+ assert!(parse_lambda(incorrect_cond_abs).is_ok());
+}
+
+#[test]
+fn test_inference() {
+ assert_eq!(infer(Context::new(), parse_lambda(sanity_check).unwrap()), Ok(Int));
+ assert_eq!(infer(Context::new(), parse_lambda(negate).unwrap()), Ok(Func(Bool, Bool)));
+ assert_eq!(infer(Context::new(), parse_lambda(basic_abstraction).unwrap()), Ok(Func(Int, Int)));
+ assert_eq!(infer(Context::new(), parse_lambda(basic_application).unwrap()), Ok(Int));
+ assert_eq!(infer(Context::new(), parse_lambda(correct_cond_abs).unwrap()), Ok(Func(Bool, Int)));
+ assert_eq!(infer(Context::new(), parse_lambda(correct_cond).unwrap()), Ok(Nat));
+ assert!(infer(Context::new(), parse_lambda(not_inferrable).unwrap()).is_err());
+ assert!(infer(Context::new(), parse_lambda(incorrect_branches).unwrap()).is_err());
+ assert!(infer(Context::new(), parse_lambda(incorrect_cond_abs).unwrap()).is_err());
+}
+
+#[test]
+fn test_checking() {
+ // uninteresting
+ assert!(check(Context::new(), parse_lambda(sanity_check).unwrap(), Int).is_ok());
+ assert!(check(Context::new(), parse_lambda(negate).unwrap(), Func(Bool, Bool)).is_ok());
+ assert!(check(Context::new(), parse_lambda(basic_abstraction).unwrap(), Func(Int, Int)).is_ok());
+ assert!(check(Context::new(), parse_lambda(basic_application).unwrap(), Int).is_ok());
+ assert!(check(Context::new(), parse_lambda(correct_cond_abs).unwrap(), Func(Bool, Int)).is_ok());
+ assert!(check(Context::new(), parse_lambda(correct_cond).unwrap(), Nat).is_ok());
+ assert!(check(Context::new(), parse_lambda(incorrect_branches).unwrap(), Empty).is_err());
+ assert!(check(Context::new(), parse_lambda(incorrect_cond_abs).unwrap(), Empty).is_err());
+
+ // more fun
+ assert_eq!(check(Context::new(), parse_lambda(not_inferrable).unwrap(), Func(Bool, Func(Int, Func(Int, Int)))), Ok(()));
+ assert_eq!(check(Context::new(), parse_lambda(not_inferrable).unwrap(), Func(Bool, Func(Nat, Func(Nat, Nat)))), Ok(()));
+ assert_eq!(check(Context::new(), parse_lambda(not_inferrable).unwrap(), Func(Bool, Func(Unit, Func(Unit, Unit)))), Ok(()));
+}