aboutsummaryrefslogtreecommitdiff
path: root/src/listcomp.nim
blob: c9f8dc031d3e4891a4dd0f93c5461fdbe672d90d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import macros

type
  ListComprehension = object ## Field-less marker type; `lc` is its instance.

# `lc` carries no data. It exists so that the `|` template and the `[]` macro
# declared on `ListComprehension` can be selected when written against `lc`.
var
  lc*: ListComprehension ## Entry point used as ``lc[expr | (clauses), typ]``.

template `|`*(lc: ListComprehension, comp: untyped): untyped =
  ## Evaluates to `lc` itself. The right-hand operand `comp` is accepted as
  ## untyped syntax and is dropped without being expanded.
  ##
  ## NOTE(review): presumably provided so that ``lc | ...`` is a valid
  ## expression on its own -- confirm the intended use. The `[]` macro
  ## inspects the `|` node syntactically and does not expand this template.
  lc

macro `[]`*(lc: ListComprehension, comp, typ: untyped): untyped =
  ## List comprehension; expands to an expression that yields a `seq[typ]`.
  ##
  ## `comp` must have the form ``expr | (clause, clause, ...)``. Each clause
  ## is either a generator ``name <- iterable``, which becomes a ``for`` loop,
  ## or any other expression, which becomes an ``if`` guard (this includes
  ## plain identifiers and literals such as ``true``). Clauses nest from left
  ## to right, so later clauses may use names bound by earlier generators.
  ## `typ` is the type that will be stored inside the result seq.
  ##
  ## A `comp` that is not an infix ``|`` expression is rejected with a
  ## compile-time error reported at `comp`.
  ##
  ## The generated code runs inside an anonymous proc whose `result` is the
  ## seq under construction, so `result` written inside `comp` refers to that
  ## seq rather than to an enclosing routine's `result`.
  ##
  ## .. code-block:: nim
  ##
  ##   echo lc[x | (x <- 1..10, x mod 2 == 0), int]
  ##
  ##   const n = 20
  ##   echo lc[(x,y,z) | (x <- 1..n, y <- x..n, z <- y..n, x*x + y*y == z*z),
  ##           tuple[a,b,c: int]]

  # Validate the overall shape: an infix node `lhs | rhs`.
  expectKind(comp, nnkInfix)
  expectLen(comp, 3)
  if not comp[0].eqIdent("|"):
    error("list comprehension must have the form `expr | (clauses)`", comp)

  let
    yielded = comp[1]
    clauses = comp[2]

  # Innermost statement: `result.add(expr)`.
  var body = newCall(
    newDotExpr(newIdentNode("result"), newIdentNode("add")),
    yielded)

  # Wrap from the last clause outwards so the first clause ends up outermost.
  for i in countdown(clauses.len - 1, 0):
    let clause = clauses[i]
    if clause.kind == nnkInfix and clause[0].eqIdent("<-"):
      # Generator `name <- iterable`  ->  `for name in iterable: body`.
      expectLen(clause, 3)
      body = newNimNode(nnkForStmt).add(clause[1], clause[2], body)
    else:
      # Anything else is a guard  ->  `if clause: body`. No child-count check
      # here: leaf nodes (identifiers, literals) are valid conditions too.
      body = newIfStmt((clause, body))

  # `result = @[]`
  let initSeq = newAssignment(
    newIdentNode("result"),
    newNimNode(nnkPrefix).add(newIdentNode("@"), newNimNode(nnkBracket)))

  # `proc (): seq[typ] = (result = @[]; <loops>)`
  let lambda = newNimNode(nnkLambda).add(
    newEmptyNode(),                       # name
    newEmptyNode(),                       # term-rewriting pattern
    newEmptyNode(),                       # generic parameters
    newNimNode(nnkFormalParams).add(      # return type only, no parameters
      newNimNode(nnkBracketExpr).add(newIdentNode("seq"), typ)),
    newEmptyNode(),                       # pragmas
    newEmptyNode(),                       # reserved slot
    newStmtList(initSeq, body))

  # Call the parenthesised lambda immediately: `(proc ...)()`.
  result = newNimNode(nnkCall).add(newNimNode(nnkPar).add(lambda))


when isMainModule:
  # Self-test, compiled only when this file is the main module. `doAssert` is
  # used instead of `assert` so the checks are not compiled out by
  # `-d:danger` or `--assertions:off`.

  # One generator plus one guard: the even numbers in 1..10.
  let evens = lc[x | (x <- 1..10, x mod 2 == 0), int]
  doAssert evens == @[2, 4, 6, 8, 10]

  # Three nested generators plus one guard: Pythagorean triples up to `n`.
  const n = 20
  let triples = lc[(x,y,z) | (x <- 1..n, y <- x..n, z <- y..n,
                              x*x + y*y == z*z),
                   tuple[a,b,c: int]]
  doAssert triples == @[
    (a: 3, b: 4, c: 5), (a: 5, b: 12, c: 13), (a: 6, b: 8, c: 10),
    (a: 8, b: 15, c: 17), (a: 9, b: 12, c: 15), (a: 12, b: 16, c: 20)]