aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJJ2024-08-31 21:00:35 +0000
committerJJ2024-08-31 21:00:46 +0000
commit62717b52649129131ae69381e22b2bc09ff93dc7 (patch)
treeb5db31f980222506a273ee066ce9a9233f32a12f
parentb6c33b87b64bcbbd7555849031d1597966716634 (diff)
stlc-full: implement multisets
-rw-r--r--stlc-full.rkt70
1 files changed, 61 insertions, 9 deletions
diff --git a/stlc-full.rkt b/stlc-full.rkt
index 3df2fa8..b355292 100644
--- a/stlc-full.rkt
+++ b/stlc-full.rkt
@@ -149,6 +149,56 @@
[`(ptr ,l ,a) (and (level? l) (symbol? a))]
[_ #f]))
+;; Creates a new multiset from a list.
+(define/contract (list->multiset l) (-> list? dict?)
+ (foldl
+ (λ (x acc)
+ (if (dict-has-key? acc x)
+ (dict-set acc x (+ (dict-ref acc x) 1))
+ (dict-set acc x 1)))
+ #hash() l))
+
+;; Creates a new list from a multiset.
+(define/contract (multiset->list m) (-> dict? list?)
+ (foldl
+ (λ (x acc)
+ (append acc (build-list (dict-ref m x) (λ (_) x))))
+ '() (dict-keys m)))
+
+;; Adds a symbol to a multiset.
+(define/contract (multiset-add m1 s) (-> dict? symbol? dict?)
+ (if (dict-has-key? m1 s)
+ (dict-set m1 s (+ (dict-ref m1 s) 1))
+ (dict-set m1 s 1)))
+
+;; Queries two multisets for equality.
+(define/contract (multiset-eq m1 m2) (-> dict? dict? boolean?)
+ (if (equal? (length m1) (length m2)) #f
+ (foldl
+ (λ (x acc)
+ (if (and acc (dict-has-key? m1 x))
+ (equal? (dict-ref m1 x) (dict-ref m2 x))
+ acc))
+ #t (dict-keys m2))))
+
+;; Unions two multisets. Shared members take the maximum count of each other.
+(define/contract (multiset-union m1 m2) (-> dict? dict? dict?)
+ (foldl
+ (λ (x acc)
+ (if (dict-has-key? acc x)
+ (dict-set acc x (max (dict-ref acc x) (dict-ref m2 x)))
+ (dict-set acc x (dict-ref m2 x))))
+ m1 (dict-keys m2)))
+
+;; Intersects two multisets. Shared members take the minimum count of each other.
+(define/contract (multiset-intersect m1 m2) (-> dict? dict? dict?)
+ (foldl
+ (λ (x acc)
+ (if (dict-has-key? m1 x)
+ (dict-set acc x (min (dict-ref m1 x) (dict-ref m2 x)))
+ acc))
+ #hash() (dict-keys m2)))
+
;; Checks if a level is at its "base" form.
(define/contract (level-base? l) (-> level? boolean?)
(match l
@@ -166,7 +216,7 @@
[(`(+ ,s1 ... ,n1) `(+ ,s2 ... ,n2)) #:when (and (natural? n1) (natural? n2))
(and (equal? n1 n2) (level-eq? `(+ ,@s1) `(+ ,@s2)))]
[(`(+ ,s1 ...) `(+ ,s2 ...))
- todo] ; need to check for equality as multisets
+ (multiset-eq (list->multiset s1) (list->multiset s2))]
[(_ _) #f]))
;; Levels can carry natural numbers, and so we define a stratification between them.
@@ -182,7 +232,8 @@
[(`(+ ,s1 ... ,n) `(+ ,s2 ...)) #:when (natural? n)
(level-eq? `(+ ,@s1) `(+ ,@s2))]
[(`(+ ,s1 ...) `(+ ,s2 ...))
- todo] ; this needs to inspect the levels... ugh
+ (multiset-eq (list->multiset s1)
+ (multiset-intersect (list->multiset s1) (list->multiset s2)))]
[(_ _) #f]))
;; We define a maximum of two levels.
@@ -196,15 +247,15 @@
[(`(+ ,s1 ... ,n1) `(+ ,s2 ... ,n2)) #:when (and (natural? n1) (natural? n2))
(if (equal? s1 s2)
`(+ ,@s1 ,(max n1 n2))
- (level-intersect `(+ ,@s1) `(+ ,@s2)))]
+ (level-union `(+ ,@s1) `(+ ,@s2)))]
[(`(+ ,s1 ... ,n) `(+ ,s2 ...)) #:when (natural? n)
(if (level-geq? s1 s2)
`(+ ,@s1 ,n)
- (level-intersect `(+ ,@s1) `(+ ,@s2)))]
+ (level-union `(+ ,@s1) `(+ ,@s2)))]
[(`(+ ,s1 ...) `(+ ,s2 ... ,n)) #:when (natural? n)
(if (level-geq? s2 s1)
`(+ ,@s2 ,n)
- (level-intersect `(+ ,@s1) `(+ ,@s2)))]
+ (level-union `(+ ,@s1) `(+ ,@s2)))]
[(`(+ ,s ... ,n1) n2) #:when (and (natural? n1) (natural? n2))
`(+ ,s ... ,n1)]
[(n1 `(+ ,s ... ,n2)) #:when (and (natural? n1) (natural? n2))
@@ -214,14 +265,15 @@
[(n `(+ ,s ...)) #:when (natural? n)
`(+ ,@s ,n)]
[(`(+ ,s1 ...) `(+ ,s2 ...))
- (level-intersect `(+ ,@s1) `(+ ,@s2))]))
+ (level-union `(+ ,@s1) `(+ ,@s2))]))
-;; A helper function to perform intersection of levels.
-(define/contract (level-intersect l1 l2) (-> level-base? level-base? level-base?)
+;; A helper function to perform the union of levels.
+(define/contract (level-union l1 l2) (-> level-base? level-base? level-base?)
(match* (l1 l2)
[(0 l) l]
[(l 0) l]
- [(`(+ ,s1 ...) `(+ ,s2 ...)) todo])) ; we need multisets. this is gross.
+ [(`(+ ,s1 ...) `(+ ,s2 ...))
+ `(+ ,@(multiset->list (multiset-union (list->multiset s1) (list->multiset s2))))]))
;; We define addition in terms of our syntactic constructs.
(define/contract (level-add l1 l2) (-> level? level? level?)