aboutsummaryrefslogtreecommitdiff
path: root/src/utils.nim
diff options
context:
space:
mode:
Diffstat (limited to 'src/utils.nim')
-rw-r--r--src/utils.nim116
1 files changed, 102 insertions, 14 deletions
diff --git a/src/utils.nim b/src/utils.nim
index 986a4a0..662168a 100644
--- a/src/utils.nim
+++ b/src/utils.nim
@@ -1,5 +1,5 @@
import
- future, hashes, options, os, osproc, posix, strutils, tables
+ future, hashes, options, os, osproc, posix, sequtils, strutils, tables
type
HaltError* = object of Exception
@@ -9,6 +9,15 @@ type
color*: Option[bool]
error*: bool
+ User* = tuple[
+ name: string,
+ uid: int,
+ gid: int,
+ groups: seq[int],
+ home: string,
+ shell: string
+ ]
+
const
pkgLibDir* = getenv("PROG_PKGLIBDIR")
localStateDir* = getenv("PROG_LOCALSTATEDIR")
@@ -133,13 +142,14 @@ template blockSignals*(signals: openArray[cint],
finally:
unblock()
-proc forkWait*(call: () -> int): int =
+proc forkWaitInternal(call: () -> int, beforeWait: () -> void): int =
blockSignals(interruptSignals, unblock):
let pid = fork()
if pid == 0:
unblock()
quit(call())
else:
+ beforeWait()
var status: cint = 1
discard waitpid(pid, status, 0)
if WIFEXITED(status):
@@ -148,14 +158,46 @@ proc forkWait*(call: () -> int): int =
discard kill(getpid(), status)
return 1
-proc runProgram*(args: varargs[string]): seq[string] =
- let output = execProcess(args[0], @args[1 .. ^1], options = {})
- if output.len == 0:
- @[]
- elif output.len > 0 and $output[^1] == "\n":
- output[0 .. ^2].split("\n")
- else:
- output.split("\n")
+proc forkWait*(call: () -> int): int =
+ forkWaitInternal(call, proc = discard)
+
+proc forkWaitRedirect*(call: () -> int): tuple[output: seq[string], code: int] =
+ var fd: array[2, cint]
+ discard pipe(fd)
+
+ var data = newSeq[char]()
+
+ let code = forkWaitInternal(() => (block:
+ discard close(fd[0])
+ discard close(1)
+ discard dup(fd[1])
+ discard close(fd[1])
+ discard close(0)
+ discard open("/dev/null")
+ discard close(2)
+ discard open("/dev/null")
+ call()), () => (block:
+ discard close(fd[1])
+ var buffer: array[80, char]
+ while true:
+ let count = read(fd[0], addr(buffer[0]), buffer.len)
+ if count <= 0:
+ break
+ data &= buffer[0 .. count - 1]
+ discard close(fd[0])))
+
+ var output = newStringOfCap(data.len)
+ for c in data:
+ output &= c
+
+ let lines = if output.len == 0:
+ @[]
+ elif output.len > 0 and $output[^1] == "\n":
+ output[0 .. ^2].split("\n")
+ else:
+ output.split("\n")
+
+ (lines, code)
proc setenv*(name: cstring, value: cstring, override: cint): cint
{.importc, header: "<stdlib.h>".}
@@ -163,14 +205,60 @@ proc setenv*(name: cstring, value: cstring, override: cint): cint
proc unsetenv*(name: cstring): cint
{.importc, header: "<stdlib.h>".}
-proc getUser*: (int, string) =
- let uid = getuid()
+proc getgrouplist*(user: cstring, group: Gid, groups: ptr cint, ngroups: var cint): cint
+ {.importc, header: "<grp.h>".}
+
+proc setgroups*(size: csize, groups: ptr cint): cint
+ {.importc, header: "<grp.h>".}
+
+proc getUser(uid: int): User =
while true:
var pw = getpwent()
if pw == nil:
+ endpwent()
raise newException(SystemError, "")
- if pw.pw_uid == uid:
- return (uid.int, $pw.pw_name)
+ if pw.pw_uid.int == uid:
+ var groups: array[100, cint]
+ var ngroups: cint = 100
+ if getgrouplist(pw.pw_name, pw.pw_gid, addr(groups[0]), ngroups) < 0:
+ raise newException(SystemError, "")
+ else:
+ let groupsSeq = groups[0 .. ngroups - 1].map(x => x.int)
+ let res = ($pw.pw_name, pw.pw_uid.int, pw.pw_gid.int, groupsSeq,
+ $pw.pw_dir, $pw.pw_shell)
+ endpwent()
+ return res
+
+let currentUser* = getUser(getuid().int)
+
+let initialUser* = try:
+ let sudoUid = getenv("SUDO_UID")
+ let polkitUid = getenv("PKEXEC_UID")
+
+ let uidString = if sudoUid != nil and sudoUid.len > 0:
+ some($sudoUid)
+ elif polkitUid != nil and polkitUid.len > 0:
+ some($polkitUid)
+ else:
+ none(string)
+
+ let uid = uidString.get.parseInt
+ if uid == 0: none(User) else: some(getUser(uid))
+except:
+ none(User)
+
+proc canDropPrivileges*(): bool =
+ initialUser.isSome
+
+proc dropPrivileges*() =
+ if initialUser.isSome:
+ let user = initialUser.unsafeGet
+ var groups = user.groups.map(x => x.cint)
+ discard setgroups(user.groups.len, addr(groups[0]));
+ discard setgid((Gid) user.gid)
+ discard setuid((Uid) user.uid)
+ discard setenv("HOME", user.home, 1)
+ discard setenv("SHELL", user.shell, 1)
proc toString*[T](arr: array[T, char], length: Option[int]): string =
var workLength = length.get(T.high + 1)