diff options
-rw-r--r-- | README.md | 33 | ||||
-rw-r--r-- | src/ast.rs | 181 | ||||
-rw-r--r-- | src/bidirectional.rs | 203 | ||||
-rw-r--r-- | src/classes.rs | 3 | ||||
-rw-r--r-- | src/main.rs | 8 | ||||
-rw-r--r-- | src/parser.rs | 65 | ||||
-rw-r--r-- | src/simple.rs | 53 | ||||
-rw-r--r-- | src/util.rs | 4 | ||||
-rw-r--r-- | tests/test_checking.rs | 40 | ||||
-rw-r--r-- | tests/test_execution.rs | 17 | ||||
-rw-r--r-- | tests/test_parser.rs | 44 |
11 files changed, 342 insertions, 309 deletions
@@ -1,4 +1,10 @@ -# chrysanthemum: a simple type system +# chrysanthemum + +chrysanthemum is a simple language with a type system, initially written as a term project for CPSC 539. +It implements a number of features from the excellent *Types and Programming Languages*, including: +- The simply typed lambda calculus +- Bidirectional type checking and subtyping support +- A somewhat complex type system: including support for `unit`, `bool`, `nat`, `int`, `float`, `str`, `union`, `struct`, `empty`, and `err` types ## todo @@ -8,19 +14,30 @@ - [x] bidirectional typechecking: implement `infer` and `check` - [x] extend to additional basic types: refactor `Term` - [ ] extend to complex types: implement `access` -- [ ] simple effects: extend `ast` - [ ] type classes: implement `monomorphize` +- [ ] simple effects: extend `ast` - [x] testtesttest ## architecture ```bash src/ -src/main.rs # the user facing program -src/parser.rs # parses user programs into proper data structures -src/ast.rs # the fundamental representation of the program -src/simple.rs # the core of the lambda calculus: checking, inference, evaluation -src/effects.rs # code for effects idk -src/classes.rs # a monomorphization pass for type classes +src/main.rs # the user facing program +src/simple.rs # the simple lambda calculus: execution +src/ast.rs # the fundamental representation of types and terms +src/bidirectional.rs # the core of the language: checking, inference +src/unification.rs # an alternate core: checking and inference by unification +src/parser.rs # parses user programs into proper data structures +src/monomorphize.rs # a monomorphization pass for type classes +src/effects.rs # code for effects idk test/ # various tests ``` + +## bibliography + +- [TAPL](https://www.cis.upenn.edu/~bcpierce/tapl/) +- [Bidirectional Typing Rules: A Tutorial](https://www.davidchristiansen.dk/tutorials/bidirectional.pdf) +- [Bidirectional Typechecking](https://research.cs.queensu.ca/home/jana/bitype.pdf) +- [Typechecking for Higher-Rank Polymorphism](https://arxiv.org/pdf/1306.6032.pdf) +- [Bidirectional Type Class Instances](https://arxiv.org/pdf/1906.12242.pdf) +- [How to make ad-hoc polymorphism less ad-hoc](https://dl.acm.org/doi/pdf/10.1145/75277.75283) @@ -1,10 +1,10 @@ -// The abstract syntax tree. All supported types go here. - -use core::fmt; use std::collections::HashMap; +pub type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>; pub type Identifier = String; -pub type Context = HashMap<Identifier, Term>; + +#[derive(Debug, Clone, PartialEq)] +pub struct Context(HashMap<Identifier, Term>); // note: built-in functions do NOT go here! #[derive(Debug, Clone, PartialEq)] @@ -18,6 +18,7 @@ pub enum Expression { Conditional{if_cond: Box<Expression>, if_then: Box<Expression>, if_else: Box<Expression>} } +/// All supported types. #[derive(Debug, Clone, PartialEq)] pub enum Type { Empty, @@ -28,11 +29,12 @@ pub enum Type { Integer, Float, String, - Enum(Vec<Type>), - Record(HashMap<Identifier, Type>), + Union(Vec<Type>), + Struct(HashMap<Identifier, Type>), Function{from: Box<Type>, to: Box<Type>}, } +/// Data associated with a type. #[derive(Debug, Clone, PartialEq)] pub enum Term { Unit(), @@ -41,13 +43,71 @@ pub enum Term { Integer(isize), Float(f32), String{len: usize, cap: usize, data: Vec<usize>}, - Enum{val: usize, data: Box<Term>}, // is this right? - Record(HashMap<Identifier, Term>), // is this right? + Union{val: usize, data: Box<Term>}, // is this right? + Struct(HashMap<Identifier, Term>), // is this right? Function(Box<Expression>) // this should allow us to bind functions } -impl fmt::Display for Expression { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Term { + /// Convert a term into its corresponding type. + pub fn convert(&self) -> Result<Type> { + match self { + Term::Unit() => Ok(Type::Unit), + Term::Boolean(_) => Ok(Type::Boolean), + Term::Natural(_) => Ok(Type::Natural), + Term::Integer(_) => Ok(Type::Integer), + Term::Float(_) => Ok(Type::Float), + Term::String { len, cap, data } => Ok(Type::String), + Term::Union { val, data } => data.convert(), + Term::Struct(data) => { + let mut result = HashMap::new(); + for (key, val) in data { + result.insert(key.clone(), val.convert()?); + } + return Ok(Type::Struct(result)); + }, + Term::Function(func) => match *func.clone() { + Expression::Annotation { expr, kind } => match kind { + Type::Function { from, to } => Ok(Type::Function { from, to }), + _ => Err("function term value not a function!".into()) + } + _ => Err("function term value does not have an annotation!".into()) + } + } + } +} + +impl Type { + /// Get the default value of a type. Throws an error if it doesn't exist. + pub fn default(&self) -> Result<Term> { + match self { + Type::Empty => Err("attempting to take the default term for empty".into()), + Type::Error => Err("attempting to take the default term for error".into()), + Type::Unit => Ok(Term::Unit()), + Type::Boolean => Ok(Term::Boolean(false)), + Type::Natural => Ok(Term::Natural(0)), + Type::Integer => Ok(Term::Integer(0)), + Type::Float => Ok(Term::Float(0.0)), + Type::String => Ok(Term::String { len: 0, cap: 0, data: vec!()}), + Type::Union(data) => match data.len() { + 0 => Err("attempting to get a default term of an enum with no variants!".into()), + _ => Ok(Term::Union { val: 0, data: Box::new(data.get(0).unwrap().default()?) }) + }, + Type::Struct(data) => { + let mut result = HashMap::new(); + for (key, val) in data { + result.insert(key.clone(), val.default()?); + } + return Ok(Term::Struct(result)); + }, + Type::Function { from, to } => + Err("attempting to take the default term of a function type".into()), + } + } +} + +impl core::fmt::Display for Expression { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Expression::Annotation { expr, kind } => write!(f, "({}: {})", expr, kind), Expression::Constant { term } => write!(f, "'{:?}", term), @@ -59,8 +119,8 @@ impl fmt::Display for Expression { } } -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl core::fmt::Display for Type { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Type::Empty => write!(f, "⊤"), Type::Error => write!(f, "⊥"), @@ -70,15 +130,15 @@ impl fmt::Display for Type { Type::Integer => write!(f, "int"), Type::Float => write!(f, "float"), Type::String => write!(f, "str"), - Type::Enum(data) => write!(f, "({:?})", data), - Type::Record(data) => write!(f, "{{{:?}}}", data), + Type::Union(data) => write!(f, "({:?})", data), + Type::Struct(data) => write!(f, "{{{:?}}}", data), Type::Function { from, to } => write!(f, "{}->{}", from, to), } } } -impl fmt::Display for Term { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl core::fmt::Display for Term { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Term::Unit() => write!(f, "∅"), Term::Boolean(term) => write!(f, "{}", term), @@ -86,90 +146,21 @@ impl fmt::Display for Term { Term::Integer(term) => write!(f, "{}", term), Term::Float(term) => write!(f, "{}", term), Term::String { len, cap, data } => write!(f, "\"{:?}\"", data), - Term::Enum { val, data } => write!(f, "{:?}", data), - Term::Record(term) => write!(f, "{:?}", term), + Term::Union { val, data } => write!(f, "{:?}", data), + Term::Struct(term) => write!(f, "{:?}", term), Term::Function(expr) => write!(f, "{}", *expr), } } } -// hatehatehate that you can't implement a trait for foreign types -// impl<T> fmt::Display for Vec<T> where T: fmt::Display { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// for (i, val) in self.enumerate() { -// if i == 0 { -// write!(f, "{}", val); -// } else { -// write!(f, ",{}", val); -// } -// } -// return Ok(()); -// } -// } - -// impl<T, U> fmt::Display for HashMap<T, U> where T: fmt::Display, U: fmt::Display { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// for (i, (key, val)) in self.enumerate() { -// if i == 0 { -// write!(f, "{}={}", key, val); -// } else { -// write!(f, ",{}={}", key, val); -// } -// } -// return Ok(()); -// } -// } - -/// Convert a term into its corresponding type. -pub fn convert(term: &Term) -> Result<Type, String> { - match term { - Term::Unit() => Ok(Type::Unit), - Term::Boolean(_) => Ok(Type::Boolean), - Term::Natural(_) => Ok(Type::Natural), - Term::Integer(_) => Ok(Type::Integer), - Term::Float(_) => Ok(Type::Float), - Term::String { len, cap, data } => Ok(Type::String), - Term::Enum { val, data } => convert(data), - Term::Record(data) => { - let mut result = HashMap::new(); - for (key, val) in data { - result.insert(key.clone(), convert(val)?); - } - return Ok(Type::Record(result)); - }, - Term::Function(func) => match *func.clone() { - Expression::Annotation { expr, kind } => match kind { - Type::Function { from, to } => Ok(Type::Function { from, to }), - _ => Err("function term value not a function!".to_string()) - } - _ => Err("function term value does not have an annotation!".to_string()) - } +impl Context { + pub fn new() -> Self { + Context(HashMap::new()) } -} - -/// Get the default value of a type. Throws an error if it doesn't exist. -pub fn default(kind: &Type) -> Result<Term, String> { - match kind { - Type::Empty => Err("attempting to take the default term for empty".to_string()), - Type::Error => Err("attempting to take the default term for error".to_string()), - Type::Unit => Ok(Term::Unit()), - Type::Boolean => Ok(Term::Boolean(false)), - Type::Natural => Ok(Term::Natural(0)), - Type::Integer => Ok(Term::Integer(0)), - Type::Float => Ok(Term::Float(0.0)), - Type::String => Ok(Term::String { len: 0, cap: 0, data: vec!()}), - Type::Enum(data) => match data.len() { - 0 => Err("attempting to get a default term of an enum with no variants!".to_string()), - _ => Ok(Term::Enum { val: 0, data: Box::new(default(data.get(0).unwrap())?) }) - }, - Type::Record(data) => { - let mut result = HashMap::new(); - for (key, val) in data { - result.insert(key.clone(), default(val)?); - } - return Ok(Term::Record(result)); - }, - Type::Function { from, to } => - Err("attempting to take the default term of a function type".to_string()), + pub fn get(&self, k: &Identifier) -> Option<&Term> { + self.0.get(k) + } + pub fn insert(&mut self, k: Identifier, v: Term) -> Option<Term> { + self.0.insert(k, v) } } diff --git a/src/bidirectional.rs b/src/bidirectional.rs index 493ae2f..d62267c 100644 --- a/src/bidirectional.rs +++ b/src/bidirectional.rs @@ -2,119 +2,122 @@ use crate::ast::*; -/// Checking judgement: takes an expression and a type to check against and calls out to `infer` as needed. -pub fn check(context: &Context, expression: Expression, target: &Type) -> Result<(), String> { - match expression { - // fall through to inference mode - Expression::Annotation { expr, kind } => { - let result = infer(context, Expression::Annotation { expr, kind })?; - return match subtype(&result, &target) { - true => Ok(()), - false => Err(format!("inferred type {result} does not match target {target}")) - } - }, - // Bt-CheckInfer - Expression::Constant { term } => match subtype(&convert(&term)?, &target) { - true => Ok(()), - false => Err(format!("constant is of wrong type, expected {target}")) - // false => Ok(()) // all our constants are Empty for now - }, - // Bt-CheckInfer - Expression::Variable { id } => match context.get(&id) { - Some(term) if subtype(&convert(term)?, &target) => Ok(()), - Some(_) => Err(format!("variable {id} is of wrong type")), - None => Err(format!("failed to find variable {id} in context")) - }, - // Bt-Abs - Expression::Abstraction { param, func } => match target { - Type::Function { from, to } => { - let mut context = context.clone(); - context.insert(param, default(from)?); - return check(&context, *func, &to); +impl Context { + /// Checking judgement: takes an expression and a type to check against and calls out to `infer` as needed. + pub fn check(&self, expression: Expression, target: &Type) -> Result<()> { + match expression { + // fall through to inference mode + Expression::Annotation { expr, kind } => { + let result = self.infer(Expression::Annotation { expr, kind })?; + return match result.subtype(&target) { + true => Ok(()), + false => Err(format!("inferred type {result} does not match target {target}").into()) + } }, - _ => Err(format!("attempting to check an abstraction with a non-function type {target}")) - }, - // fall through to inference mode - Expression::Application { func, arg } => { - let result = &infer(context, Expression::Application { func, arg })?; - return match subtype(&result, &target) { + // Bt-CheckInfer + Expression::Constant { term } => match &term.convert()?.subtype(&target) { true => Ok(()), - false => Err(format!("inferred type {result} does not match {target}")) + false => Err(format!("constant is of wrong type, expected {target}").into()) + // false => Ok(()) // all our constants are Empty for now + }, + // Bt-CheckInfer + Expression::Variable { id } => match self.get(&id) { + Some(term) if term.convert()?.subtype(&target) => Ok(()), + Some(_) => Err(format!("variable {id} is of wrong type").into()), + None => Err(format!("failed to find variable {id} in context").into()) + }, + // Bt-Abs + Expression::Abstraction { param, func } => match target { + Type::Function { from, to } => { + let mut context = self.clone(); + context.insert(param, from.default()?); + return context.check(*func, &to); + }, + _ => Err(format!("attempting to check an abstraction with a non-function type {target}").into()) + }, + // fall through to inference mode + Expression::Application { func, arg } => { + let result = &self.infer(Expression::Application { func, arg })?; + return match result.subtype(&target) { + true => Ok(()), + false => Err(format!("inferred type {result} does not match {target}").into()) + } + }, + // T-If + Expression::Conditional { if_cond, if_then, if_else } => { + self.check(*if_cond, &Type::Boolean)?; + self.check(*if_then, &target)?; + self.check(*if_else, &target)?; + return Ok(()); } - }, - // T-If - Expression::Conditional { if_cond, if_then, if_else } => { - check(context, *if_cond, &Type::Boolean)?; - check(context, *if_then, &target)?; - check(context, *if_else, &target)?; - return Ok(()); } } -} -/// Inference judgement: takes an expression and attempts to infer the associated type. -pub fn infer(context: &Context, expression: Expression) -> Result<Type, String> { - match expression { - // Bt-Ann - Expression::Annotation { expr, kind } => check(context, *expr, &kind).map(|x| kind), - // Bt-True / Bt-False / etc - Expression::Constant { term } => convert(&term), - // Bt-Var - Expression::Variable { id } => match context.get(&id) { - Some(term) => infer(&Context::new(), Expression::Constant { term: term.clone() }), - None => Err(format!("failed to find variable in context {context:?}")) - }, - // Bt-App - Expression::Application { func, arg } => match infer(context, *func)? { - Type::Function { from, to } => check(context, *arg, &*from).map(|x| *to), - _ => Err(format!("application abstraction is not a function type")) - }, - // inference from an abstraction is always an error - // we could try and infer the func without adding the parameter to scope: - // but this is overwhelmingly likely to be an error, so just report it now. - Expression::Abstraction { param, func } => - Err(format!("attempting to infer from an abstraction")), - // idk - Expression::Conditional { if_cond, if_then, if_else } => { - check(context, *if_cond, &Type::Boolean)?; - let if_then = infer(context, *if_then)?; - let if_else = infer(context, *if_else)?; - if subtype(&if_then, &if_else) && subtype(&if_else, &if_then) { - Ok(if_then) // fixme: should be the join - } else { - Err(format!("if clauses of different types: {if_then} and {if_else}")) + /// Inference judgement: takes an expression and attempts to infer the associated type. + pub fn infer(&self, expression: Expression) -> Result<Type> { + match expression { + // Bt-Ann + Expression::Annotation { expr, kind } => self.check(*expr, &kind).map(|x| kind), + // Bt-True / Bt-False / etc + Expression::Constant { term } => term.convert(), + // Bt-Var + Expression::Variable { id } => match self.get(&id) { + Some(term) => Context::new().infer(Expression::Constant { term: term.clone() }), + None => Err(format!("failed to find variable in context {self:?}").into()) + }, + // Bt-App + Expression::Application { func, arg } => match self.infer(*func)? { + Type::Function { from, to } => self.check(*arg, &*from).map(|x| *to), + _ => Err(format!("application abstraction is not a function type").into()) + }, + // inference from an abstraction is always an error + // we could try and infer the func without adding the parameter to scope: + // but this is overwhelmingly likely to be an error, so just report it now. + Expression::Abstraction { param, func } => + Err(format!("attempting to infer from an abstraction").into()), + // idk + Expression::Conditional { if_cond, if_then, if_else } => { + self.check(*if_cond, &Type::Boolean)?; + let if_then = self.infer(*if_then)?; + let if_else = self.infer(*if_else)?; + if if_then.subtype(&if_else) && if_else.subtype(&if_then) { + Ok(if_then) // fixme: should be the join + } else { + Err(format!("if clauses of different types: {if_then} and {if_else}").into()) + } } } } } -/// The subtyping relation between any two types. -pub fn subtype(is: &Type, of: &Type) -> bool { - match (is, of) { - (Type::Record(is_fields), Type::Record(of_fields)) => { - // width, depth, and permutation - for (key, of_value) in of_fields { - match is_fields.get(key) { - Some(is_value) => { - if !subtype(is_value, of_value) { - return false; +impl Type { + /// The subtyping relation between any two types. + pub fn subtype(&self, other: &Self) -> bool { + match (self, other) { + (Type::Struct(is_fields), Type::Struct(of_fields)) => { + // width, depth, and permutation + for (key, of_value) in of_fields { + match is_fields.get(key) { + Some(is_value) => { + if !is_value.subtype(of_value) { + return false; + } } + None => return false } - None => return false } - } - return true; - }, - (Type::Enum(is_variants), Type::Enum(of_variants)) => false, // fixme - (Type::Function { from: is_from, to: is_to }, - Type::Function { from: of_from, to: of_to }) => { - subtype(of_from, is_from) && subtype(is_to, of_to) - }, - (Type::Natural, Type::Integer) => true, // obviously not, but let's pretend - (_, Type::Empty) => true, - (Type::Error, _) => true, - (_, _) if is == of => true, - (_, _) => false + return true; + }, + (Type::Union(is_variants), Type::Union(of_variants)) => false, // fixme + (Type::Function { from: is_from, to: is_to }, + Type::Function { from: of_from, to: of_to }) => { + of_from.subtype(is_from) && is_to.subtype(of_to) + }, + (Type::Natural, Type::Integer) => true, // obviously not, but let's pretend + (_, Type::Empty) => true, + (Type::Error, _) => true, + (_, _) if self == other => true, + (_, _) => false + } } } - diff --git a/src/classes.rs b/src/classes.rs deleted file mode 100644 index 38a4847..0000000 --- a/src/classes.rs +++ /dev/null @@ -1,3 +0,0 @@ -// Typeclass pass: monomorphize based on usage, pretty much - - diff --git a/src/main.rs b/src/main.rs index a33e963..b93e778 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,7 +24,7 @@ fn main() { input.clear(); stdin().read_line(&mut input).unwrap(); - match infer(&empty_context, parser::parse_lambda(&input).unwrap()) { + match empty_context.infer(parser::parse_lambda(&input).unwrap()) { Ok(kind) => println!("infers! {}", kind), Err(e) => println!("{:?}", e), } @@ -36,10 +36,10 @@ fn main() { input.clear(); stdin().read_line(&mut input).unwrap(); - let kind = infer(&empty_context, parser::parse_lambda(&input).unwrap()); + let kind = empty_context.infer(parser::parse_lambda(&input).unwrap()); match kind { Ok(kind) => { - match check(&empty_context, parser::parse_lambda(&input).unwrap(), &kind) { + match empty_context.check(parser::parse_lambda(&input).unwrap(), &kind) { Ok(_) => println!("checks!"), Err(e) => println!("{:?}", e), } @@ -54,7 +54,7 @@ fn main() { input.clear(); stdin().read_line(&mut input).unwrap(); - match execute(&empty_context, parser::parse_lambda(&input).unwrap()) { + match empty_context.execute(parser::parse_lambda(&input).unwrap()) { Ok(term) => println!("{}", term), Err(e) => println!("{:?}", e) } diff --git a/src/parser.rs b/src/parser.rs index d3895a4..298e09d 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,7 +2,7 @@ use crate::ast::*; use multipeek::multipeek; /// Parses a lambda-calculus-like language into an AST. -pub fn parse_lambda(input: &str) -> Result<Expression, peg::error::ParseError<peg::str::LineCol>> { +pub fn parse_lambda(input: &str) -> Result<Expression> { // this is kinda awful, i miss my simple nim pegs peg::parser! { grammar lambda() for str { @@ -98,7 +98,7 @@ pub fn parse_lambda(input: &str) -> Result<Expression, peg::error::ParseError<pe } } } - return lambda::expr(input.trim()); + return Ok(lambda::expr(input.trim())?); } const operators: [char; 17] = @@ -115,15 +115,15 @@ pub enum Token { Value(String), Char(char), String(String), - Comment(String), - Token(String), // catch-all + Comment(String), // unused tbh + // Token(String), // catch-all ScopeBegin, // { ScopeEnd, // } ExprEnd, // ; } /// Properly lexes a whitespace-oriented language into a series of tokens. -pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { +pub fn lex(input: &str) -> Result<Vec<Token>> { enum State { Default, Char, @@ -142,6 +142,7 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { let mut buffer = String::new(); let mut result = Vec::new(); + // .next() advances the iterator, .peek() does not let mut input = multipeek(input.chars()); // multipeek my beloved while let Some(c) = input.next() { match state { @@ -151,9 +152,36 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { result.push(parse_token(&buffer)?); buffer.clear(); }, - ' ' => todo!(), - '\n' => todo!(), - '\t' => return Err("Tabs are not supported!"), + ' ' => { + // hmm + todo!() + }, + '\n' => { + if let Some(previous) = result.last() { + match previous { + // same scope, no seperator + Token::Operator(_) => (), + // do we have this tbh???? + Token::Keyword(_) => return Err("lines shouldn't end with a keyword i think".into()), + // i think uhh + Token::Identifier(_) | Token::Value(_) | Token::Char(_) | Token::String(_) => + result.push(Token::ExprEnd), + // always scope?? + Token::Separator(_) => result.push(Token::ScopeBegin), + // idk + Token::Comment(_) | Token::ScopeBegin | Token::ScopeEnd | Token::ExprEnd => + return Err("uhh idk always scope lol".into()), + } + } + todo!() + }, + '\t' => return Err("Tabs are not supported!".into()), + _ if indent.blank => { + indent.blank = false; + // indentation check + todo!(); + buffer.push(c); + }, '\'' => { result.push(parse_token(&buffer)?); buffer.clear(); @@ -187,12 +215,6 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { indent.blank = false; } } - _ if indent.blank => { - indent.blank = false; - // indentation check - todo!(); - buffer.push(c); - } _ => buffer.push(c) }, State::Char => match c { @@ -205,11 +227,11 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { Some('t') => result.push(Token::Char('\t')), Some('\"') => result.push(Token::Char('\"')), Some('\'') => result.push(Token::Char('\'')), - _ => return Err("Invalid string escape sequence!"), + _ => return Err("Invalid string escape sequence!".into()), } state = State::Default; if input.next() != Some('\'') { - return Err("Invalid character sequence!") + return Err("Invalid character sequence!".into()) } }, '\'' => { @@ -220,7 +242,7 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { result.push(Token::Char(c)); state = State::Default; if input.next() != Some('\'') { - return Err("Invalid character sequence!") + return Err("Invalid character sequence!".into()) } } }, @@ -233,7 +255,7 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { Some('t') => buffer.push('\t'), Some('\"') => buffer.push('\"'), Some('\'') => buffer.push('\''), - _ => return Err("Invalid string escape sequence!"), + _ => return Err("Invalid string escape sequence!".into()), }, '\"' => { state = State::Default; @@ -255,7 +277,8 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { State::Comment => match c { '\n' => { state = State::Default; - result.push(Token::Comment(buffer.to_string())); + // result.push(Token::Comment(buffer.to_string())); + buffer.clear(); }, _ => buffer.push(c) }, @@ -264,7 +287,7 @@ pub fn lex(input: &str) -> Result<Vec<Token>, &'static str> { return Ok(result); } -fn parse_token(token: &str) -> Result<Token, &'static str> { +fn parse_token(token: &str) -> Result<Token> { if keywords.contains(&token) { Ok(Token::Keyword(token.to_string())) } else if is_operator(token) { @@ -274,7 +297,7 @@ fn parse_token(token: &str) -> Result<Token, &'static str> { } else if is_identifier(token) { Ok(Token::Identifier(token.to_string())) } else { - Err("Could not parse token!") + Err("Could not parse token!".into()) } } diff --git a/src/simple.rs b/src/simple.rs index 8d87dbf..e763221 100644 --- a/src/simple.rs +++ b/src/simple.rs @@ -1,32 +1,33 @@ use crate::ast::*; -/// Evaluates an expression given a context (of variables) to a term, or fails. -pub fn execute(context: &Context, expression: Expression) -> Result<Term, String> { - match expression { - Expression::Annotation { expr, .. } => execute(context, *expr), - Expression::Constant { term } => Ok(term), - Expression::Variable { id } => match context.get(&id) { - Some(term) => Ok(term.clone()), - None => Err(format!("no such variable in context {context:?}")) - }, - Expression::Abstraction { param, func } => - Err(format!("attempting to execute an abstraction ({}){}", param, func)), - Expression::Application { func, arg } => match *func { - Expression::Abstraction { param, func } => { - let value = execute(context, *arg)?; - let mut context = context.clone(); - context.insert(param, value); - return execute(&context, *func); - } - _ => Err(format!("attempting to execute an application to non-abstraction {}", *func)) - }, - Expression::Conditional { if_cond, if_then, if_else } => { - match execute(context, *if_cond)? { - Term::Boolean(true) => execute(context, *if_then), - Term::Boolean(false) => execute(context, *if_else), - term => Err(format!("invalid type {} for a conditional", convert(&term)?)) +impl Context { + /// Evaluates an expression given a context (of variables) to a term, or fails. + pub fn execute(&self, expression: Expression) -> Result<Term> { + match expression { + Expression::Annotation { expr, .. } => self.execute(*expr), + Expression::Constant { term } => Ok(term), + Expression::Variable { id } => match self.get(&id) { + Some(term) => Ok(term.clone()), + None => Err(format!("no such variable in context {self:?}").into()) + }, + Expression::Abstraction { param, func } => + Err(format!("attempting to execute an abstraction ({}){}", param, func).into()), + Expression::Application { func, arg } => match *func { + Expression::Abstraction { param, func } => { + let value = self.execute(*arg)?; + let mut context = self.clone(); + context.insert(param, value); + return context.execute(*func); + } + _ => Err(format!("attempting to execute an application to non-abstraction {}", *func).into()) + }, + Expression::Conditional { if_cond, if_then, if_else } => { + match self.execute(*if_cond)? { + Term::Boolean(true) => self.execute(*if_then), + Term::Boolean(false) => self.execute(*if_else), + term => Err(format!("invalid type {} for a conditional", &term.convert()?).into()) + } } } } } - diff --git a/src/util.rs b/src/util.rs index a27f32c..39ac263 100644 --- a/src/util.rs +++ b/src/util.rs @@ -74,6 +74,6 @@ pub fn Str(len: usize, cap: usize, data: Vec<usize>) -> Term { return Term::String { len, cap, data } } -pub fn Enum(val: usize, data: Term) -> Term { - return Term::Enum { val, data: Box::new(data) } +pub fn Union(val: usize, data: Term) -> Term { + return Term::Union { val, data: Box::new(data) } } diff --git a/tests/test_checking.rs b/tests/test_checking.rs index e553a4e..b0f4140 100644 --- a/tests/test_checking.rs +++ b/tests/test_checking.rs @@ -32,15 +32,15 @@ fn test_parsing_succeeds() { #[test] fn test_inference() { let context = Context::new(); - assert_eq!(infer(&context, parse_lambda(sanity_check).unwrap()), Ok(Int)); - assert_eq!(infer(&context, parse_lambda(negate).unwrap()), Ok(Func(Bool, Bool))); - assert_eq!(infer(&context, parse_lambda(basic_abstraction).unwrap()), Ok(Func(Int, Int))); - assert_eq!(infer(&context, parse_lambda(basic_application).unwrap()), Ok(Int)); - assert_eq!(infer(&context, parse_lambda(correct_cond_abs).unwrap()), Ok(Func(Bool, Int))); - assert_eq!(infer(&context, parse_lambda(correct_cond).unwrap()), Ok(Nat)); - assert!(infer(&context, parse_lambda(not_inferrable).unwrap()).is_err()); - assert!(infer(&context, parse_lambda(incorrect_branches).unwrap()).is_err()); - assert!(infer(&context, parse_lambda(incorrect_cond_abs).unwrap()).is_err()); + assert_eq!(context.infer(parse_lambda(sanity_check).unwrap()).unwrap(), Int); + assert_eq!(context.infer(parse_lambda(negate).unwrap()).unwrap(), Func(Bool, Bool)); + assert_eq!(context.infer(parse_lambda(basic_abstraction).unwrap()).unwrap(), Func(Int, Int)); + assert_eq!(context.infer(parse_lambda(basic_application).unwrap()).unwrap(), Int); + assert_eq!(context.infer(parse_lambda(correct_cond_abs).unwrap()).unwrap(), Func(Bool, Int)); + assert_eq!(context.infer(parse_lambda(correct_cond).unwrap()).unwrap(), Nat); + assert!(context.infer(parse_lambda(not_inferrable).unwrap()).is_err()); + assert!(context.infer(parse_lambda(incorrect_branches).unwrap()).is_err()); + assert!(context.infer(parse_lambda(incorrect_cond_abs).unwrap()).is_err()); } #[test] @@ -48,17 +48,17 @@ fn test_checking() { let context = Context::new(); // uninteresting - assert!(check(&context, parse_lambda(sanity_check).unwrap(), &Int).is_ok()); - assert!(check(&context, parse_lambda(negate).unwrap(), &Func(Bool, Bool)).is_ok()); - assert!(check(&context, parse_lambda(basic_abstraction).unwrap(), &Func(Int, Int)).is_ok()); - assert!(check(&context, parse_lambda(basic_application).unwrap(), &Int).is_ok()); - assert!(check(&context, parse_lambda(correct_cond_abs).unwrap(), &Func(Bool, Int)).is_ok()); - assert!(check(&context, parse_lambda(correct_cond).unwrap(), &Nat).is_ok()); - assert!(check(&context, parse_lambda(incorrect_branches).unwrap(), &Unit).is_err()); - assert!(check(&context, parse_lambda(incorrect_cond_abs).unwrap(), &Error).is_err()); + assert!(context.check(parse_lambda(sanity_check).unwrap(), &Int).is_ok()); + assert!(context.check(parse_lambda(negate).unwrap(), &Func(Bool, Bool)).is_ok()); + assert!(context.check(parse_lambda(basic_abstraction).unwrap(), &Func(Int, Int)).is_ok()); + assert!(context.check(parse_lambda(basic_application).unwrap(), &Int).is_ok()); + assert!(context.check(parse_lambda(correct_cond_abs).unwrap(), &Func(Bool, Int)).is_ok()); + assert!(context.check(parse_lambda(correct_cond).unwrap(), &Nat).is_ok()); + assert!(context.check(parse_lambda(incorrect_branches).unwrap(), &Unit).is_err()); + assert!(context.check(parse_lambda(incorrect_cond_abs).unwrap(), &Error).is_err()); // more fun - assert!(check(&context, parse_lambda(not_inferrable).unwrap(), &Func(Bool, Func(Int, Func(Int, Int)))).is_ok()); - assert!(check(&context, parse_lambda(not_inferrable).unwrap(), &Func(Bool, Func(Nat, Func(Nat, Nat)))).is_ok()); - assert!(check(&context, parse_lambda(not_inferrable).unwrap(), &Func(Bool, Func(Unit, Func(Unit, Unit)))).is_ok()); + assert!(context.check(parse_lambda(not_inferrable).unwrap(), &Func(Bool, Func(Int, Func(Int, Int)))).is_ok()); + assert!(context.check(parse_lambda(not_inferrable).unwrap(), &Func(Bool, Func(Nat, Func(Nat, Nat)))).is_ok()); + assert!(context.check(parse_lambda(not_inferrable).unwrap(), &Func(Bool, Func(Unit, Func(Unit, Unit)))).is_ok()); } diff --git a/tests/test_execution.rs b/tests/test_execution.rs index d9b1c5e..44df5ae 100644 --- a/tests/test_execution.rs +++ b/tests/test_execution.rs @@ -5,10 +5,10 @@ use chrysanthemum::util::*; #[test] fn test_simple() { let context = Context::new(); - assert_eq!(execute(&context, Const(Term::Boolean(false))), Ok(Term::Boolean(false))); - assert_eq!(execute(&context, Const(Term::Natural(123))), Ok(Term::Natural(123))); - assert_eq!(execute(&context, Const(Term::Integer(123))), Ok(Term::Integer(123))); - assert!(execute(&context, Var("x")).is_err()); + assert_eq!(context.execute(Const(Term::Boolean(false))).unwrap(), Term::Boolean(false)); + assert_eq!(context.execute(Const(Term::Natural(123))).unwrap(), Term::Natural(123)); + assert_eq!(context.execute(Const(Term::Integer(123))).unwrap(), Term::Integer(123)); + assert!(context.execute(Var("x")).is_err()); } #[test] @@ -16,8 +16,9 @@ fn test_complex() { let mut context = Context::new(); context.insert(String::from("x"), Term::Natural(413)); context.insert(String::from("y"), Term::Boolean(true)); - assert_eq!(execute(&context, Var("x")), Ok(Term::Natural(413))); - assert_eq!(execute(&context, Cond(Var("y"), Const(Term::Integer(612)), Var("x"))), Ok(Term::Integer(612))); - assert_eq!(execute(&context, - App(Abs("z", Cond(Const(Term::Boolean(false)), Var("x"), Var("z"))), Const(Term::Integer(1025)))), Ok(Term::Integer(1025))); + assert_eq!(context.execute(Var("x")).unwrap(), Term::Natural(413)); + assert_eq!(context.execute(Cond(Var("y"), Const(Term::Integer(612)), + Var("x"))).unwrap(), Term::Integer(612)); + assert_eq!(context.execute(App(Abs("z", Cond(Const(Term::Boolean(false)), + Var("x"), Var("z"))), Const(Term::Integer(1025)))).unwrap(), Term::Integer(1025)); } diff --git a/tests/test_parser.rs b/tests/test_parser.rs index 9ac8f32..84f6395 100644 --- a/tests/test_parser.rs +++ b/tests/test_parser.rs @@ -6,9 +6,9 @@ use chrysanthemum::util::*; #[test] fn test_simple_phrases() { - assert_eq!(parse_lambda("-123"), Ok(Const(Term::Integer(-123)))); - assert_eq!(parse_lambda("x12"), Ok(Var("x12"))); - assert_eq!(parse_lambda("x12x2"), Ok(Var("x12x2"))); + assert_eq!(parse_lambda("-123").unwrap(), Const(Term::Integer(-123))); + assert_eq!(parse_lambda("x12").unwrap(), Var("x12")); + assert_eq!(parse_lambda("x12x2").unwrap(), Var("x12x2")); // so i _don't_ want these to be valid identifiers: // but i actually have no idea why my peg is rejecting them lmao assert!(parse_lambda("12x").is_err()); @@ -17,36 +17,36 @@ fn test_simple_phrases() { #[test] fn test_simple_annotations() { - assert_eq!(parse_lambda("t: int"), Ok(Ann(Var("t"), Int))); - assert_eq!(parse_lambda("12: nat"), Ok(Ann(Const(Term::Natural(12)), Nat))); + assert_eq!(parse_lambda("t: int").unwrap(), Ann(Var("t"), Int)); + assert_eq!(parse_lambda("12: nat").unwrap(), Ann(Const(Term::Natural(12)), Nat)); assert!(parse_lambda("t: fake").is_err()); } #[test] fn test_simple_expressions() { - assert_eq!(parse_lambda("λx.y"), Ok(Abs("x", Var("y")))); - assert_eq!(parse_lambda("λ x.y"), Ok(Abs("x", Var("y")))); - assert_eq!(parse_lambda("λx.y"), Ok(Abs("x", Var("y")))); - assert_eq!(parse_lambda("lambda x . y"), Ok(Abs("x", Var("y")))); - assert_eq!(parse_lambda("(λx.y)"), Ok(Abs("x", Var("y")))); - assert_eq!(parse_lambda("(λx.y) x"), Ok(App(Abs("x", Var("y")), Var("x")))); - assert_eq!(parse_lambda("(λx.y) x"), Ok(App(Abs("x", Var("y")), Var("x")))); - assert_eq!(parse_lambda("if x then y else z"), Ok(Cond(Var("x"), Var("y"), Var("z")))); - assert_eq!(parse_lambda("if xeme then yak else zebra"), Ok(Cond(Var("xeme"), Var("yak"), Var("zebra")))); - assert_eq!(parse_lambda("if 413 then 612 else 1025"), Ok(Cond(Const(Term::Natural(413)), Const(Term::Natural(612)), Const(Term::Natural(1025))))); // invalid, but should parse + assert_eq!(parse_lambda("λx.y").unwrap(), Abs("x", Var("y"))); + assert_eq!(parse_lambda("λ x.y").unwrap(), Abs("x", Var("y"))); + assert_eq!(parse_lambda("λx.y").unwrap(), Abs("x", Var("y"))); + assert_eq!(parse_lambda("lambda x . y").unwrap(), Abs("x", Var("y"))); + assert_eq!(parse_lambda("(λx.y)").unwrap(), Abs("x", Var("y"))); + assert_eq!(parse_lambda("(λx.y) x").unwrap(), App(Abs("x", Var("y")), Var("x"))); + assert_eq!(parse_lambda("(λx.y) x").unwrap(), App(Abs("x", Var("y")), Var("x"))); + assert_eq!(parse_lambda("if x then y else z").unwrap(), Cond(Var("x"), Var("y"), Var("z"))); + assert_eq!(parse_lambda("if xeme then yak else zebra").unwrap(), Cond(Var("xeme"), Var("yak"), Var("zebra"))); + assert_eq!(parse_lambda("if 413 then 612 else 1025").unwrap(), Cond(Const(Term::Natural(413)), Const(Term::Natural(612)), Const(Term::Natural(1025)))); // invalid, but should parse } #[test] fn test_complex_expressions() { - assert_eq!(parse_lambda("(λy.if y then false else true) z"), Ok(App(Abs("y", Cond(Var("y"), Const(Term::Boolean(false)), Const(Term::Boolean(true)))), Var("z")))); + assert_eq!(parse_lambda("(λy.if y then false else true) z").unwrap(), App(Abs("y", Cond(Var("y"), Const(Term::Boolean(false)), Const(Term::Boolean(true)))), Var("z"))); } #[test] fn test_complex_annotations() { - assert_eq!(parse_lambda("(lambda x . y) : int"), Ok(Ann(Abs("x", Var("y")), Int))); - assert_eq!(parse_lambda("((lambda x. y): (int -> int)) -413: int"), Ok(App(Ann(Abs("x", Var("y")), Func(Int, Int) ), Ann(Const(Term::Integer(-413)), Int)))); - assert_eq!(parse_lambda("if false: bool then true: bool else 2: int"), Ok(Cond(Ann(Const(Term::Boolean(false)), Bool), Ann(Const(Term::Boolean(true)), Bool), Ann(Const(Term::Natural(2)), Int)))); - assert_eq!(parse_lambda("(lambda x. if x then true: bool else false: bool): (int -> bool)"), Ok(Ann(Abs("x", Cond(Var("x"), Ann(Const(Term::Boolean(true)), Bool), Ann(Const(Term::Boolean(false)), Bool))), Func(Int, Bool)))); - assert_eq!(parse_lambda("(lambda x. if x then 1: int else 0: int): (bool -> int)"), Ok(Ann(Abs("x", Cond(Var("x"), Ann(Const(Term::Natural(1)), Int), Ann(Const(Term::Natural(0)), Int))), Func(Bool, Int)))); - assert_eq!(parse_lambda("(lambda x. if x then false else true): (bool -> bool)"), Ok(Ann(Abs("x", Cond(Var("x"), Const(Term::Boolean(false)), Const(Term::Boolean(true)))), Func(Bool, Bool)))); + assert_eq!(parse_lambda("(lambda x . y) : int").unwrap(), Ann(Abs("x", Var("y")), Int)); + assert_eq!(parse_lambda("((lambda x. y): (int -> int)) -413: int").unwrap(), App(Ann(Abs("x", Var("y")), Func(Int, Int) ), Ann(Const(Term::Integer(-413)), Int))); + assert_eq!(parse_lambda("if false: bool then true: bool else 2: int").unwrap(), Cond(Ann(Const(Term::Boolean(false)), Bool), Ann(Const(Term::Boolean(true)), Bool), Ann(Const(Term::Natural(2)), Int))); + assert_eq!(parse_lambda("(lambda x. if x then true: bool else false: bool): (int -> bool)").unwrap(), Ann(Abs("x", Cond(Var("x"), Ann(Const(Term::Boolean(true)), Bool), Ann(Const(Term::Boolean(false)), Bool))), Func(Int, Bool))); + assert_eq!(parse_lambda("(lambda x. if x then 1: int else 0: int): (bool -> int)").unwrap(), Ann(Abs("x", Cond(Var("x"), Ann(Const(Term::Natural(1)), Int), Ann(Const(Term::Natural(0)), Int))), Func(Bool, Int))); + assert_eq!(parse_lambda("(lambda x. if x then false else true): (bool -> bool)").unwrap(), Ann(Abs("x", Cond(Var("x"), Const(Term::Boolean(false)), Const(Term::Boolean(true)))), Func(Bool, Bool))); } |