diff options
Diffstat (limited to 'src/utils.nim')
-rw-r--r-- | src/utils.nim | 116 |
1 files changed, 102 insertions, 14 deletions
diff --git a/src/utils.nim b/src/utils.nim index 986a4a0..662168a 100644 --- a/src/utils.nim +++ b/src/utils.nim @@ -1,5 +1,5 @@ import - future, hashes, options, os, osproc, posix, strutils, tables + future, hashes, options, os, osproc, posix, sequtils, strutils, tables type HaltError* = object of Exception @@ -9,6 +9,15 @@ type color*: Option[bool] error*: bool + User* = tuple[ + name: string, + uid: int, + gid: int, + groups: seq[int], + home: string, + shell: string + ] + const pkgLibDir* = getenv("PROG_PKGLIBDIR") localStateDir* = getenv("PROG_LOCALSTATEDIR") @@ -133,13 +142,14 @@ template blockSignals*(signals: openArray[cint], finally: unblock() -proc forkWait*(call: () -> int): int = +proc forkWaitInternal(call: () -> int, beforeWait: () -> void): int = blockSignals(interruptSignals, unblock): let pid = fork() if pid == 0: unblock() quit(call()) else: + beforeWait() var status: cint = 1 discard waitpid(pid, status, 0) if WIFEXITED(status): @@ -148,14 +158,46 @@ proc forkWait*(call: () -> int): int = discard kill(getpid(), status) return 1 -proc runProgram*(args: varargs[string]): seq[string] = - let output = execProcess(args[0], @args[1 .. ^1], options = {}) - if output.len == 0: - @[] - elif output.len > 0 and $output[^1] == "\n": - output[0 .. ^2].split("\n") - else: - output.split("\n") +proc forkWait*(call: () -> int): int = + forkWaitInternal(call, proc = discard) + +proc forkWaitRedirect*(call: () -> int): tuple[output: seq[string], code: int] = + var fd: array[2, cint] + discard pipe(fd) + + var data = newSeq[char]() + + let code = forkWaitInternal(() => (block: + discard close(fd[0]) + discard close(1) + discard dup(fd[1]) + discard close(fd[1]) + discard close(0) + discard open("/dev/null") + discard close(2) + discard open("/dev/null") + call()), () => (block: + discard close(fd[1]) + var buffer: array[80, char] + while true: + let count = read(fd[0], addr(buffer[0]), buffer.len) + if count <= 0: + break + data &= buffer[0 .. count - 1] + discard close(fd[0]))) + + var output = newStringOfCap(data.len) + for c in data: + output &= c + + let lines = if output.len == 0: + @[] + elif output.len > 0 and $output[^1] == "\n": + output[0 .. ^2].split("\n") + else: + output.split("\n") + + (lines, code) proc setenv*(name: cstring, value: cstring, override: cint): cint {.importc, header: "<stdlib.h>".} @@ -163,14 +205,60 @@ proc setenv*(name: cstring, value: cstring, override: cint): cint proc unsetenv*(name: cstring): cint {.importc, header: "<stdlib.h>".} -proc getUser*: (int, string) = - let uid = getuid() +proc getgrouplist*(user: cstring, group: Gid, groups: ptr cint, ngroups: var cint): cint + {.importc, header: "<grp.h>".} + +proc setgroups*(size: csize, groups: ptr cint): cint + {.importc, header: "<grp.h>".} + +proc getUser(uid: int): User = while true: var pw = getpwent() if pw == nil: + endpwent() raise newException(SystemError, "") - if pw.pw_uid == uid: - return (uid.int, $pw.pw_name) + if pw.pw_uid.int == uid: + var groups: array[100, cint] + var ngroups: cint = 100 + if getgrouplist(pw.pw_name, pw.pw_gid, addr(groups[0]), ngroups) < 0: + raise newException(SystemError, "") + else: + let groupsSeq = groups[0 .. ngroups - 1].map(x => x.int) + let res = ($pw.pw_name, pw.pw_uid.int, pw.pw_gid.int, groupsSeq, + $pw.pw_dir, $pw.pw_shell) + endpwent() + return res + +let currentUser* = getUser(getuid().int) + +let initialUser* = try: + let sudoUid = getenv("SUDO_UID") + let polkitUid = getenv("PKEXEC_UID") + + let uidString = if sudoUid != nil and sudoUid.len > 0: + some($sudoUid) + elif polkitUid != nil and polkitUid.len > 0: + some($polkitUid) + else: + none(string) + + let uid = uidString.get.parseInt + if uid == 0: none(User) else: some(getUser(uid)) +except: + none(User) + +proc canDropPrivileges*(): bool = + initialUser.isSome + +proc dropPrivileges*() = + if initialUser.isSome: + let user = initialUser.unsafeGet + var groups = user.groups.map(x => x.cint) + discard setgroups(user.groups.len, addr(groups[0])); + discard setgid((Gid) user.gid) + discard setuid((Uid) user.uid) + discard setenv("HOME", user.home, 1) + discard setenv("SHELL", user.shell, 1) proc toString*[T](arr: array[T, char], length: Option[int]): string = var workLength = length.get(T.high + 1) |