aboutsummaryrefslogtreecommitdiff
path: root/src/listcomp.nim
diff options
context:
space:
mode:
authorshirleyquirk2020-08-04 23:53:21 +0000
committerGitHub2020-08-04 23:53:21 +0000
commit731f8d7692bfe08f5fd5890a98797f9b1b8d065d (patch)
tree512a18c53d4e7e6878ea134a2a307e3caec2278d /src/listcomp.nim
parent29e12415d4bae7a08fa9f3024d809b5a3be76ba1 (diff)
parentb2321b0ded6f4a9803daf9dbcbd88d56321a9305 (diff)
Merge branch 'lc_useVersion' into collect
Diffstat (limited to 'src/listcomp.nim')
-rw-r--r--src/listcomp.nim69
1 files changed, 69 insertions, 0 deletions
diff --git a/src/listcomp.nim b/src/listcomp.nim
new file mode 100644
index 0000000..c9f8dc0
--- /dev/null
+++ b/src/listcomp.nim
@@ -0,0 +1,69 @@
+import macros
+
+type ListComprehension = object
+var lc* : ListComprehension
+
+template `|`*(lc: ListComprehension, comp: untyped): untyped = lc
+
+macro `[]`*(lc: ListComprehension, comp, typ: untyped): untyped =
+ ## List comprehension, returns a sequence. `comp` is the actual list
+ ## comprehension, for example ``x | (x <- 1..10, x mod 2 == 0)``. `typ` is
+ ## the type that will be stored inside the result seq.
+ ##
+ ## .. code-block:: nim
+ ##
+ ## echo lc[x | (x <- 1..10, x mod 2 == 0), int]
+ ##
+ ## const n = 20
+ ## echo lc[(x,y,z) | (x <- 1..n, y <- x..n, z <- y..n, x*x + y*y == z*z),
+ ## tuple[a,b,c: int]]
+
+ expectLen(comp, 3)
+ expectKind(comp, nnkInfix)
+ assert($comp[0] == "|")
+
+ result = newCall(
+ newDotExpr(
+ newIdentNode("result"),
+ newIdentNode("add")),
+ comp[1])
+
+ for i in countdown(comp[2].len-1, 0):
+ let x = comp[2][i]
+ expectMinLen(x, 1)
+ if x[0].kind == nnkIdent and x[0].strVal == "<-":
+ expectLen(x, 3)
+ result = newNimNode(nnkForStmt).add(x[1], x[2], result)
+ else:
+ result = newIfStmt((x, result))
+
+ result = newNimNode(nnkCall).add(
+ newNimNode(nnkPar).add(
+ newNimNode(nnkLambda).add(
+ newEmptyNode(),
+ newEmptyNode(),
+ newEmptyNode(),
+ newNimNode(nnkFormalParams).add(
+ newNimNode(nnkBracketExpr).add(
+ newIdentNode("seq"),
+ typ)),
+ newEmptyNode(),
+ newEmptyNode(),
+ newStmtList(
+ newAssignment(
+ newIdentNode("result"),
+ newNimNode(nnkPrefix).add(
+ newIdentNode("@"),
+ newNimNode(nnkBracket))),
+ result))))
+
+
+when isMainModule:
+ var a = lc[x | (x <- 1..10, x mod 2 == 0), int]
+ assert a == @[2, 4, 6, 8, 10]
+
+ const n = 20
+ var b = lc[(x,y,z) | (x <- 1..n, y <- x..n, z <- y..n, x*x + y*y == z*z),
+ tuple[a,b,c: int]]
+ assert b == @[(a: 3, b: 4, c: 5), (a: 5, b: 12, c: 13), (a: 6, b: 8, c: 10),
+(a: 8, b: 15, c: 17), (a: 9, b: 12, c: 15), (a: 12, b: 16, c: 20)]