diff options
author | Emiliano Ciavatta | 2020-10-08 12:59:50 +0000 |
---|---|---|
committer | Emiliano Ciavatta | 2020-10-08 12:59:50 +0000 |
commit | 584e25f7940954a51e260a21ca5a819ff4b834d3 (patch) | |
tree | 6fb65a7ab169d0240fbe795cd25be94ca8604b4f /connection_streams_controller.go | |
parent | 2b2b8e66e7244592672c283fe7bb5d9a1fd9da99 (diff) |
Implement download_as and export as pwntools
Diffstat (limited to 'connection_streams_controller.go')
-rw-r--r-- | connection_streams_controller.go | 178 |
1 files changed, 162 insertions, 16 deletions
diff --git a/connection_streams_controller.go b/connection_streams_controller.go index 9d73b0e..98f2aca 100644 --- a/connection_streams_controller.go +++ b/connection_streams_controller.go @@ -3,14 +3,19 @@ package main import ( "bytes" "context" + "fmt" "github.com/eciavatta/caronte/parsers" log "github.com/sirupsen/logrus" + "strings" "time" ) -const InitialPayloadsSize = 1024 -const DefaultQueryFormatLimit = 8024 -const InitialRegexSlicesCount = 8 +const ( + initialPayloadsSize = 1024 + defaultQueryFormatLimit = 8024 + initialRegexSlicesCount = 8 + pwntoolsMaxServerBytes = 20 +) type ConnectionStream struct { ID RowID `bson:"_id"` @@ -26,7 +31,7 @@ type ConnectionStream struct { type PatternSlice [2]uint64 -type Payload struct { +type Message struct { FromClient bool `json:"from_client"` Content string `json:"content"` Metadata parsers.Metadata `json:"metadata"` @@ -42,12 +47,17 @@ type RegexSlice struct { To uint64 `json:"to"` } -type QueryFormat struct { +type GetMessageFormat struct { Format string `form:"format"` Skip uint64 `form:"skip"` Limit uint64 `form:"limit"` } +type DownloadMessageFormat struct { + Format string `form:"format"` + Type string `form:"type"` +} + type ConnectionStreamsController struct { storage Storage } @@ -58,13 +68,18 @@ func NewConnectionStreamsController(storage Storage) ConnectionStreamsController } } -func (csc ConnectionStreamsController) GetConnectionPayload(c context.Context, connectionID RowID, - format QueryFormat) []*Payload { - payloads := make([]*Payload, 0, InitialPayloadsSize) +func (csc ConnectionStreamsController) GetConnectionMessages(c context.Context, connectionID RowID, + format GetMessageFormat) ([]*Message, bool) { + connection := csc.getConnection(c, connectionID) + if connection.ID.IsZero() { + return nil, false + } + + payloads := make([]*Message, 0, initialPayloadsSize) var clientIndex, serverIndex, globalIndex uint64 if format.Limit <= 0 { - format.Limit = DefaultQueryFormatLimit + format.Limit = defaultQueryFormatLimit } var clientBlocksIndex, serverBlocksIndex int @@ -79,8 +94,8 @@ func (csc ConnectionStreamsController) GetConnectionPayload(c context.Context, c return serverBlocksIndex < len(serverStream.BlocksIndexes) } - var payload *Payload - payloadsBuffer := make([]*Payload, 0, 16) + var payload *Message + payloadsBuffer := make([]*Message, 0, 16) contentChunkBuffer := new(bytes.Buffer) var lastContentSlice []byte var sideChanged, lastClient, lastServer bool @@ -97,7 +112,7 @@ func (csc ConnectionStreamsController) GetConnectionPayload(c context.Context, c } size := uint64(end - start) - payload = &Payload{ + payload = &Message{ FromClient: true, Content: DecodeBytes(clientStream.Payload[start:end], format.Format), Index: start, @@ -121,7 +136,7 @@ func (csc ConnectionStreamsController) GetConnectionPayload(c context.Context, c } size := uint64(end - start) - payload = &Payload{ + payload = &Message{ FromClient: false, Content: DecodeBytes(serverStream.Payload[start:end], format.Format), Index: start, @@ -178,11 +193,118 @@ func (csc ConnectionStreamsController) GetConnectionPayload(c context.Context, c if globalIndex > format.Skip+format.Limit { // problem: the last chunk is not parsed, but can be ok because it is not finished updateMetadata() - return payloads + return payloads, true } } - return payloads + return payloads, true +} + +func (csc ConnectionStreamsController) DownloadConnectionMessages(c context.Context, connectionID RowID, + format DownloadMessageFormat) (string, bool) { + connection := csc.getConnection(c, connectionID) + if connection.ID.IsZero() { + return "", false + } + + var sb strings.Builder + includeClient, includeServer := format.Type != "only_server", format.Type != "only_client" + isPwntools := format.Type == "pwntools" + + var clientBlocksIndex, serverBlocksIndex int + var clientDocumentIndex, serverDocumentIndex int + var clientStream ConnectionStream + if includeClient { + clientStream = csc.getConnectionStream(c, connectionID, true, clientDocumentIndex) + } + var serverStream ConnectionStream + if includeServer { + serverStream = csc.getConnectionStream(c, connectionID, false, serverDocumentIndex) + } + + hasClientBlocks := func() bool { + return clientBlocksIndex < len(clientStream.BlocksIndexes) + } + hasServerBlocks := func() bool { + return serverBlocksIndex < len(serverStream.BlocksIndexes) + } + + if isPwntools { + if format.Format == "base32" || format.Format == "base64" { + sb.WriteString("import base64\n") + } + sb.WriteString("from pwn import *\n\n") + sb.WriteString(fmt.Sprintf("p = remote('%s', %d)\n", connection.DestinationIP, connection.DestinationPort)) + } + + lastIsClient, lastIsServer := true, true + for !clientStream.ID.IsZero() || !serverStream.ID.IsZero() { + if hasClientBlocks() && (!hasServerBlocks() || // next payload is from client + clientStream.BlocksTimestamps[clientBlocksIndex].UnixNano() <= + serverStream.BlocksTimestamps[serverBlocksIndex].UnixNano()) { + start := clientStream.BlocksIndexes[clientBlocksIndex] + end := 0 + if clientBlocksIndex < len(clientStream.BlocksIndexes)-1 { + end = clientStream.BlocksIndexes[clientBlocksIndex+1] + } else { + end = len(clientStream.Payload) + } + + if !lastIsClient { + sb.WriteString("\n") + } + lastIsClient = true + lastIsServer = false + if isPwntools { + sb.WriteString(decodePwntools(clientStream.Payload[start:end], true, format.Format)) + } else { + sb.WriteString(DecodeBytes(clientStream.Payload[start:end], format.Format)) + } + clientBlocksIndex++ + } else { // next payload is from server + start := serverStream.BlocksIndexes[serverBlocksIndex] + end := 0 + if serverBlocksIndex < len(serverStream.BlocksIndexes)-1 { + end = serverStream.BlocksIndexes[serverBlocksIndex+1] + } else { + end = len(serverStream.Payload) + } + + if !lastIsServer { + sb.WriteString("\n") + } + lastIsClient = false + lastIsServer = true + if isPwntools { + sb.WriteString(decodePwntools(serverStream.Payload[start:end], false, format.Format)) + } else { + sb.WriteString(DecodeBytes(serverStream.Payload[start:end], format.Format)) + } + serverBlocksIndex++ + } + + if includeClient && !hasClientBlocks() { + clientDocumentIndex++ + clientBlocksIndex = 0 + clientStream = csc.getConnectionStream(c, connectionID, true, clientDocumentIndex) + } + if includeServer && !hasServerBlocks() { + serverDocumentIndex++ + serverBlocksIndex = 0 + serverStream = csc.getConnectionStream(c, connectionID, false, serverDocumentIndex) + } + } + + return sb.String(), true +} + +func (csc ConnectionStreamsController) getConnection(c context.Context, connectionID RowID) Connection { + var connection Connection + if err := csc.storage.Find(Connections).Context(c).Filter(OrderedDocument{{"_id", connectionID}}). + First(&connection); err != nil { + log.WithError(err).WithField("id", connectionID).Panic("failed to get connection") + } + return connection } func (csc ConnectionStreamsController) getConnectionStream(c context.Context, connectionID RowID, fromClient bool, @@ -199,7 +321,7 @@ func (csc ConnectionStreamsController) getConnectionStream(c context.Context, co } func findMatchesBetween(patternMatches map[uint][]PatternSlice, from, to uint64) []RegexSlice { - regexSlices := make([]RegexSlice, 0, InitialRegexSlicesCount) + regexSlices := make([]RegexSlice, 0, initialRegexSlicesCount) for _, slices := range patternMatches { for _, slice := range slices { if from > slice[1] || to <= slice[0] { @@ -225,3 +347,27 @@ func findMatchesBetween(patternMatches map[uint][]PatternSlice, from, to uint64) } return regexSlices } + +func decodePwntools(payload []byte, isClient bool, format string) string { + if !isClient && len(payload) > pwntoolsMaxServerBytes { + payload = payload[len(payload)-pwntoolsMaxServerBytes:] + } + + var content string + switch format { + case "hex": + content = fmt.Sprintf("bytes.fromhex('%s')", DecodeBytes(payload, format)) + case "base32": + content = fmt.Sprintf("base64.b32decode('%s')", DecodeBytes(payload, format)) + case "base64": + content = fmt.Sprintf("base64.b64decode('%s')", DecodeBytes(payload, format)) + default: + content = fmt.Sprintf("'%s'", strings.Replace(DecodeBytes(payload, "ascii"), "'", "\\'", -1)) + } + + if isClient { + return fmt.Sprintf("p.send(%s)\n", content) + } else { + return fmt.Sprintf("p.recvuntil(%s)\n", content) + } +} |