diff options
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | caronte.go | 8 | ||||
-rw-r--r-- | caronte_test.go | 50 | ||||
-rw-r--r-- | connection_handler.go | 200 | ||||
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | pcap_importer.go | 9 | ||||
-rw-r--r-- | storage.go | 57 | ||||
-rw-r--r-- | storage_test.go | 44 | ||||
-rw-r--r-- | stream_factory.go | 29 | ||||
-rw-r--r-- | stream_handler.go | 310 | ||||
-rw-r--r-- | stream_handler_test.go | 193 |
11 files changed, 631 insertions, 272 deletions
@@ -2,6 +2,7 @@ [![Build Status](https://travis-ci.com/eciavatta/caronte.svg?branch=develop)](https://travis-ci.com/eciavatta/caronte) [![codecov](https://codecov.io/gh/eciavatta/caronte/branch/develop/graph/badge.svg)](https://codecov.io/gh/eciavatta/caronte) +[![GPL License][license-shield]][license-url] <img align="left" src="https://divinacommedia.weebly.com/uploads/5/5/2/3/5523249/1299707879.jpg"> Caronte is a tool to analyze the network flow during capture the flag events of type attack/defence. @@ -9,4 +10,3 @@ It reassembles TCP packets captured in pcap files to rebuild TCP connections, an The patterns can be defined as regex or using protocol specific rules. The connection flows are saved into a database and can be visualized with the web application. REST API are also provided. -Packets can be captured locally on the same machine or can be imported remotely. The streams of bytes extracted from the TCP payload of packets are processed by [Hyperscan](https://github.com/intel/hyperscan), an high-performance regular expression matching library. // TODO
\ No newline at end of file @@ -2,17 +2,19 @@ package main import ( "fmt" + "net" ) func main() { - // testStorage() - storage := NewStorage("localhost", 27017, "testing") + // pattern.Flags |= hyperscan.SomLeftMost + + storage := NewMongoStorage("localhost", 27017, "testing") err := storage.Connect(nil) if err != nil { panic(err) } - importer := NewPcapImporter(&storage, "10.10.10.10") + importer := NewPcapImporter(storage, net.ParseIP("10.10.10.10")) sessionId, err := importer.ImportPcap("capture_00459_20190627165500.pcap") if err != nil { diff --git a/caronte_test.go b/caronte_test.go new file mode 100644 index 0000000..cbf867d --- /dev/null +++ b/caronte_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "crypto/sha256" + "fmt" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "os" + "testing" + "time" +) + +var storage Storage +var testContext context.Context + +func TestMain(m *testing.M) { + mongoHost, ok := os.LookupEnv("MONGO_HOST") + if !ok { + mongoHost = "localhost" + } + mongoPort, ok := os.LookupEnv("MONGO_PORT") + if !ok { + mongoPort = "27017" + } + + uniqueDatabaseName := sha256.Sum256([]byte(time.Now().String())) + + client, err := mongo.NewClient(options.Client().ApplyURI(fmt.Sprintf("mongodb://%s:%v", mongoHost, mongoPort))) + if err != nil { + panic("failed to create mongo client") + } + + db := client.Database(fmt.Sprintf("%x", uniqueDatabaseName[:31])) + mongoStorage := MongoStorage{ + client: client, + collections: map[string]*mongo.Collection{testCollection: db.Collection(testCollection)}, + } + + testContext, _ = context.WithTimeout(context.Background(), 10 * time.Second) + + err = mongoStorage.Connect(nil) + if err != nil { + panic(err) + } + storage = &mongoStorage + + exitCode := m.Run() + os.Exit(exitCode) +} diff --git a/connection_handler.go b/connection_handler.go new file mode 100644 index 0000000..36dca6d --- /dev/null +++ b/connection_handler.go @@ -0,0 +1,200 @@ +package main + +import ( + "context" + "encoding/binary" + "fmt" + "github.com/flier/gohs/hyperscan" + "github.com/google/gopacket" + "github.com/google/gopacket/tcpassembly" + "log" + "sync" + "time" +) + +type BiDirectionalStreamFactory struct { + storage Storage + serverIp gopacket.Endpoint + connections map[StreamKey]ConnectionHandler + mConnections sync.Mutex + patterns hyperscan.StreamDatabase + mPatterns sync.Mutex + scratches []*hyperscan.Scratch +} + +type StreamKey [4]gopacket.Endpoint + +type ConnectionHandler interface { + Complete(handler *StreamHandler) + Storage() Storage + Context() context.Context + Patterns() hyperscan.StreamDatabase +} + +type connectionHandlerImpl struct { + storage Storage + net, transport gopacket.Flow + initiator StreamKey + connectionKey string + mComplete sync.Mutex + otherStream *StreamHandler + context context.Context + patterns hyperscan.StreamDatabase +} + +func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { + key := StreamKey{net.Src(), net.Dst(), transport.Src(), transport.Dst()} + invertedKey := StreamKey{net.Dst(), net.Src(), transport.Dst(), transport.Src()} + + factory.mConnections.Lock() + connection, isPresent := factory.connections[invertedKey] + if isPresent { + delete(factory.connections, invertedKey) + } else { + var initiator StreamKey + if net.Src() == factory.serverIp { + initiator = invertedKey + } else { + initiator = key + } + connection = &connectionHandlerImpl{ + storage: factory.storage, + net: net, + transport: transport, + initiator: initiator, + mComplete: sync.Mutex{}, + context: context.Background(), + patterns : factory.patterns, + } + factory.connections[key] = connection + } + factory.mConnections.Unlock() + + streamHandler := NewStreamHandler(connection, key, factory.takeScratch()) + + return &streamHandler +} + +func (factory *BiDirectionalStreamFactory) UpdatePatternsDatabase(database hyperscan.StreamDatabase) { + factory.mPatterns.Lock() + factory.patterns = database + + for _, s := range factory.scratches { + err := s.Realloc(database) + if err != nil { + fmt.Println("failed to realloc an existing scratch") + } + } + + factory.mPatterns.Unlock() +} + +func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { + ch.mComplete.Lock() + if ch.otherStream == nil { + ch.otherStream = handler + ch.mComplete.Unlock() + return + } + ch.mComplete.Unlock() + + var startedAt, closedAt time.Time + if handler.firstPacketSeen.Before(ch.otherStream.firstPacketSeen) { + startedAt = handler.firstPacketSeen + } else { + startedAt = ch.otherStream.firstPacketSeen + } + + if handler.lastPacketSeen.After(ch.otherStream.lastPacketSeen) { + closedAt = handler.lastPacketSeen + } else { + closedAt = ch.otherStream.lastPacketSeen + } + + var client, server *StreamHandler + if handler.streamKey == ch.initiator { + client = handler + server = ch.otherStream + } else { + client = ch.otherStream + server = handler + } + + ch.generateConnectionKey(startedAt) + + _, err := ch.storage.InsertOne(ch.context, "connections", OrderedDocument{ + {"_id", ch.connectionKey}, + {"ip_src", ch.initiator[0].String()}, + {"ip_dst", ch.initiator[1].String()}, + {"port_src", ch.initiator[2].String()}, + {"port_dst", ch.initiator[3].String()}, + {"started_at", startedAt}, + {"closed_at", closedAt}, + {"client_bytes", client.streamLength}, + {"server_bytes", server.streamLength}, + {"client_documents", len(client.documentsKeys)}, + {"server_documents", len(server.documentsKeys)}, + {"processed_at", time.Now()}, + }) + if err != nil { + log.Println("error inserting document on collection connections with _id = ", ch.connectionKey) + } + + streamsIds := append(client.documentsKeys, server.documentsKeys...) + n, err := ch.storage.UpdateOne(ch.context, "connection_streams", + UnorderedDocument{"_id": UnorderedDocument{"$in": streamsIds}}, + UnorderedDocument{"connection_id": ch.connectionKey}, + false) + if err != nil { + log.Println("failed to update connection streams", err) + } + if n != len(streamsIds) { + log.Println("failed to update all connections streams") + } +} + +func (ch *connectionHandlerImpl) Storage() Storage { + return ch.storage +} + +func (ch *connectionHandlerImpl) Context() context.Context { + return ch.context +} + +func (ch *connectionHandlerImpl) Patterns() hyperscan.StreamDatabase { + return ch.patterns +} + +func (ch *connectionHandlerImpl) generateConnectionKey(firstPacketSeen time.Time) { + hash := make([]byte, 16) + binary.BigEndian.PutUint64(hash, uint64(firstPacketSeen.UnixNano())) + binary.BigEndian.PutUint64(hash[8:], ch.net.FastHash()^ch.transport.FastHash()) + + ch.connectionKey = fmt.Sprintf("%x", hash) +} + +func (factory *BiDirectionalStreamFactory) takeScratch() *hyperscan.Scratch { + factory.mPatterns.Lock() + defer factory.mPatterns.Unlock() + + if len(factory.scratches) == 0 { + scratch, err := hyperscan.NewScratch(factory.patterns) + if err != nil { + fmt.Println("failed to alloc a new scratch") + } + + return scratch + } + + index := len(factory.scratches) - 1 + scratch := factory.scratches[index] + factory.scratches = factory.scratches[:index] + + return scratch +} + +func (factory *BiDirectionalStreamFactory) releaseScratch(scratch *hyperscan.Scratch) { + factory.mPatterns.Lock() + factory.scratches = append(factory.scratches, scratch) + factory.mPatterns.Unlock() +} @@ -5,6 +5,7 @@ go 1.14 require ( github.com/flier/gohs v1.0.0 github.com/google/gopacket v1.1.17 + github.com/stretchr/testify v1.3.0 go.mongodb.org/mongo-driver v1.3.1 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 ) diff --git a/pcap_importer.go b/pcap_importer.go index 9428b29..c6260de 100644 --- a/pcap_importer.go +++ b/pcap_importer.go @@ -24,7 +24,7 @@ const importedPcapsCollectionName = "imported_pcaps" type PcapImporter struct { - storage *Storage + storage Storage streamPool *tcpassembly.StreamPool assemblers []*tcpassembly.Assembler sessions map[string]context.CancelFunc @@ -36,10 +36,11 @@ type PcapImporter struct { type flowCount [2]int -func NewPcapImporter(storage *Storage, serverIp string) *PcapImporter { +func NewPcapImporter(storage Storage, serverIp net.IP) *PcapImporter { + serverEndpoint := layers.NewIPEndpoint(serverIp) streamFactory := &BiDirectionalStreamFactory{ storage: storage, - serverIp: serverIp, + serverIp: serverEndpoint, } streamPool := tcpassembly.NewStreamPool(streamFactory) @@ -50,7 +51,7 @@ func NewPcapImporter(storage *Storage, serverIp string) *PcapImporter { sessions: make(map[string]context.CancelFunc), mAssemblers: sync.Mutex{}, mSessions: sync.Mutex{}, - serverIp: layers.NewIPEndpoint(net.ParseIP(serverIp)), + serverIp: serverEndpoint, } } @@ -14,7 +14,13 @@ import ( const defaultConnectionTimeout = 10*time.Second const defaultOperationTimeout = 3*time.Second -type Storage struct { +type Storage interface { + InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) + UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) + FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) +} + +type MongoStorage struct { client *mongo.Client collections map[string]*mongo.Collection } @@ -22,7 +28,7 @@ type Storage struct { type OrderedDocument = bson.D type UnorderedDocument = bson.M -func NewStorage(uri string, port int, database string) Storage { +func NewMongoStorage(uri string, port int, database string) *MongoStorage { opt := options.Client() opt.ApplyURI(fmt.Sprintf("mongodb://%s:%v", uri, port)) client, err := mongo.NewClient(opt) @@ -36,13 +42,13 @@ func NewStorage(uri string, port int, database string) Storage { "connections": db.Collection("connections"), } - return Storage{ + return &MongoStorage{ client: client, collections: colls, } } -func (storage *Storage) Connect(ctx context.Context) error { +func (storage *MongoStorage) Connect(ctx context.Context) error { if ctx == nil { ctx, _ = context.WithTimeout(context.Background(), defaultConnectionTimeout) } @@ -50,7 +56,7 @@ func (storage *Storage) Connect(ctx context.Context) error { return storage.client.Connect(ctx) } -func (storage *Storage) InsertOne(ctx context.Context, collectionName string, +func (storage *MongoStorage) InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) { collection, ok := storage.collections[collectionName] @@ -70,7 +76,7 @@ func (storage *Storage) InsertOne(ctx context.Context, collectionName string, return result.InsertedID, nil } -func (storage *Storage) UpdateOne(ctx context.Context, collectionName string, +func (storage *MongoStorage) UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) { collection, ok := storage.collections[collectionName] @@ -97,7 +103,30 @@ func (storage *Storage) UpdateOne(ctx context.Context, collectionName string, return result.ModifiedCount == 1, nil } -func (storage *Storage) FindOne(ctx context.Context, collectionName string, +func (storage *MongoStorage) UpdateMany(ctx context.Context, collectionName string, + filter interface{}, update interface {}, upsert bool) (interface{}, error) { + + collection, ok := storage.collections[collectionName] + if !ok { + return nil, errors.New("invalid collection: " + collectionName) + } + + if ctx == nil { + ctx, _ = context.WithTimeout(context.Background(), defaultOperationTimeout) + } + + opts := options.Update().SetUpsert(upsert) + update = bson.D{{"$set", update}} + + result, err := collection.UpdateMany(ctx, filter, update, opts) + if err != nil { + return nil, err + } + + return result.ModifiedCount, nil +} + +func (storage *MongoStorage) FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) { collection, ok := storage.collections[collectionName] @@ -121,17 +150,3 @@ func (storage *Storage) FindOne(ctx context.Context, collectionName string, return result, nil } - - -func testStorage() { - storage := NewStorage("localhost", 27017, "testing") - _ = storage.Connect(nil) - - id, err := storage.InsertOne(nil, "connections", bson.M{"_id": "provaaa"}) - if err != nil { - panic(err) - } else { - fmt.Println(id) - } - -} diff --git a/storage_test.go b/storage_test.go index 6b36833..b46b60a 100644 --- a/storage_test.go +++ b/storage_test.go @@ -1,21 +1,11 @@ package main import ( - "crypto/sha256" - "fmt" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" - "golang.org/x/net/context" - "os" "testing" - "time" ) const testCollection = "characters" -var storage Storage -var testContext context.Context - func testInsert(t *testing.T) { // insert a document in an invalid connection insertedId, err := storage.InsertOne(testContext, "invalid_collection", @@ -98,37 +88,3 @@ func TestBasicOperations(t *testing.T) { t.Run("testInsert", testInsert) t.Run("testFindOne", testFindOne) } - -func TestMain(m *testing.M) { - mongoHost, ok := os.LookupEnv("MONGO_HOST") - if !ok { - mongoHost = "localhost" - } - mongoPort, ok := os.LookupEnv("MONGO_PORT") - if !ok { - mongoPort = "27017" - } - - uniqueDatabaseName := sha256.Sum256([]byte(time.Now().String())) - - client, err := mongo.NewClient(options.Client().ApplyURI(fmt.Sprintf("mongodb://%s:%v", mongoHost, mongoPort))) - if err != nil { - panic("failed to create mongo client") - } - - db := client.Database(fmt.Sprintf("%x", uniqueDatabaseName[:31])) - storage = Storage{ - client: client, - collections: map[string]*mongo.Collection{testCollection: db.Collection(testCollection)}, - } - - testContext, _ = context.WithTimeout(context.Background(), 10 * time.Second) - - err = storage.Connect(nil) - if err != nil { - panic(err) - } - - exitCode := m.Run() - os.Exit(exitCode) -} diff --git a/stream_factory.go b/stream_factory.go deleted file mode 100644 index e1d76a4..0000000 --- a/stream_factory.go +++ /dev/null @@ -1,29 +0,0 @@ -package main - -import ( - "github.com/google/gopacket" - "github.com/google/gopacket/tcpassembly" -) - -type BiDirectionalStreamFactory struct { - storage *Storage - serverIp string -} - -// httpStream will handle the actual decoding of http requests. -type uniDirectionalStream struct { - net, transport gopacket.Flow - r StreamHandler -} - -func (h *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { - hstream := &uniDirectionalStream{ - net: net, - transport: transport, - r: NewStreamHandler(), - } - // go hstream.run() // Important... we must guarantee that data from the tcpreader stream is read. - - // StreamHandler implements tcpassembly.Stream, so we can return a pointer to it. - return &hstream.r -} diff --git a/stream_handler.go b/stream_handler.go index ad59856..80d91d6 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -1,206 +1,176 @@ -// Package tcpreader provides an implementation for tcpassembly.Stream which presents -// the caller with an io.Reader for easy processing. -// -// The assembly package handles packet data reordering, but its output is -// library-specific, thus not usable by the majority of external Go libraries. -// The io.Reader interface, on the other hand, is used throughout much of Go -// code as an easy mechanism for reading in data streams and decoding them. For -// example, the net/http package provides the ReadRequest function, which can -// parse an HTTP request from a live data stream, just what we'd want when -// sniffing HTTP traffic. Using StreamHandler, this is relatively easy to set -// up: -// -// // Create our StreamFactory -// type httpStreamFactory struct {} -// func (f *httpStreamFactory) New(a, b gopacket.Flow) { -// r := tcpreader.NewReaderStream(false) -// go printRequests(r) -// return &r -// } -// func printRequests(r io.Reader) { -// // Convert to bufio, since that's what ReadRequest wants. -// buf := bufio.NewReader(r) -// for { -// if req, err := http.ReadRequest(buf); err == io.EOF { -// return -// } else if err != nil { -// log.Println("Error parsing HTTP requests:", err) -// } else { -// fmt.Println("HTTP REQUEST:", req) -// fmt.Println("Body contains", tcpreader.DiscardBytesToEOF(req.Body), "bytes") -// } -// } -// } -// -// Using just this code, we're able to reference a powerful, built-in library -// for HTTP request parsing to do all the dirty-work of parsing requests from -// the wire in real-time. Pass this stream factory to an tcpassembly.StreamPool, -// start up an tcpassembly.Assembler, and you're good to go! package main import ( -"errors" -"github.com/google/gopacket/tcpassembly" -"io" + "bytes" + "encoding/binary" + "fmt" + "github.com/flier/gohs/hyperscan" + "github.com/google/gopacket/tcpassembly" + "log" + "time" ) -var discardBuffer = make([]byte, 4096) - -// DiscardBytesToFirstError will read in all bytes up to the first error -// reported by the given reader, then return the number of bytes discarded -// and the error encountered. -func DiscardBytesToFirstError(r io.Reader) (discarded int, err error) { - for { - n, e := r.Read(discardBuffer) - discarded += n - if e != nil { - return discarded, e - } - } -} +const MaxDocumentSize = 1024 * 1024 +const InitialBlockCount = 1024 +const InitialPatternSliceSize = 8 -// DiscardBytesToEOF will read in all bytes from a Reader until it -// encounters an io.EOF, then return the number of bytes. Be careful -// of this... if used on a Reader that returns a non-io.EOF error -// consistently, this will loop forever discarding that error while -// it waits for an EOF. -func DiscardBytesToEOF(r io.Reader) (discarded int) { - for { - n, e := DiscardBytesToFirstError(r) - discarded += n - if e == io.EOF { - return - } - } -} - -// StreamHandler implements both tcpassembly.Stream and io.Reader. You can use it -// as a building block to make simple, easy stream handlers. -// // IMPORTANT: If you use a StreamHandler, you MUST read ALL BYTES from it, // quickly. Not reading available bytes will block TCP stream reassembly. It's // a common pattern to do this by starting a goroutine in the factory's New // method: -// -// type myStreamHandler struct { -// r StreamHandler -// } -// func (m *myStreamHandler) run() { -// // Do something here that reads all of the StreamHandler, or your assembly -// // will block. -// fmt.Println(tcpreader.DiscardBytesToEOF(&m.r)) -// } -// func (f *myStreamFactory) New(a, b gopacket.Flow) tcpassembly.Stream { -// s := &myStreamHandler{} -// go s.run() -// // Return the StreamHandler as the stream that assembly should populate. -// return &s.r -// } type StreamHandler struct { - ReaderStreamOptions - reassembled chan []tcpassembly.Reassembly - done chan bool - current []tcpassembly.Reassembly - closed bool - lossReported bool - first bool - initiated bool + connection ConnectionHandler + streamKey StreamKey + buffer *bytes.Buffer + indexes []int + timestamps []time.Time + lossBlocks []bool + currentIndex int + firstPacketSeen time.Time + lastPacketSeen time.Time + documentsKeys []string + streamLength int + patternStream hyperscan.Stream + patternMatches map[uint][]PatternSlice } -// ReaderStreamOptions provides user-resettable options for a StreamHandler. -type ReaderStreamOptions struct { - // LossErrors determines whether this stream will return - // ReaderStreamDataLoss errors from its Read function whenever it - // determines data has been lost. - LossErrors bool -} +type PatternSlice [2]uint64 // NewReaderStream returns a new StreamHandler object. -func NewStreamHandler() StreamHandler { - r := StreamHandler{ - reassembled: make(chan []tcpassembly.Reassembly), - done: make(chan bool), - first: true, - initiated: true, +func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hyperscan.Scratch) StreamHandler { + handler := StreamHandler{ + connection: connection, + streamKey: key, + buffer: new(bytes.Buffer), + indexes: make([]int, 0, InitialBlockCount), + timestamps: make([]time.Time, 0, InitialBlockCount), + lossBlocks: make([]bool, 0, InitialBlockCount), + documentsKeys: make([]string, 0, 1), // most of the time the stream fit in one document + patternMatches: make(map[uint][]PatternSlice, 10), // TODO: change with exactly value } - return r + + stream, err := connection.Patterns().Open(0, scratch, handler.onMatch, nil) + if err != nil { + log.Println("failed to create a stream: ", err) + } + handler.patternStream = stream + + return handler } // Reassembled implements tcpassembly.Stream's Reassembled function. -func (r *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) { - if !r.initiated { - panic("StreamHandler not created via NewReaderStream") +func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) { + for _, r := range reassembly { + skip := r.Skip + if r.Start { + skip = 0 + sh.firstPacketSeen = r.Seen + } + if r.End { + sh.lastPacketSeen = r.Seen + } + + reassemblyLen := len(r.Bytes) + if reassemblyLen == 0 { + continue + } + + if sh.buffer.Len()+len(r.Bytes)-skip > MaxDocumentSize { + sh.storageCurrentDocument() + sh.resetCurrentDocument() + } + n, err := sh.buffer.Write(r.Bytes[skip:]) + if err != nil { + log.Println("error while copying bytes from Reassemble in stream_handler") + return + } + sh.indexes = append(sh.indexes, sh.currentIndex) + sh.timestamps = append(sh.timestamps, r.Seen) + sh.lossBlocks = append(sh.lossBlocks, skip != 0) + sh.currentIndex += n + sh.streamLength += n + + err = sh.patternStream.Scan(r.Bytes) + if err != nil { + log.Println("failed to scan packet buffer: ", err) + } } - r.reassembled <- reassembly - <-r.done } // ReassemblyComplete implements tcpassembly.Stream's ReassemblyComplete function. -func (r *StreamHandler) ReassemblyComplete() { - close(r.reassembled) - close(r.done) -} +func (sh *StreamHandler) ReassemblyComplete() { + err := sh.patternStream.Close() + if err != nil { + log.Println("failed to close pattern stream: ", err) + } -// stripEmpty strips empty reassembly slices off the front of its current set of -// slices. -func (r *StreamHandler) stripEmpty() { - for len(r.current) > 0 && len(r.current[0].Bytes) == 0 { - r.current = r.current[1:] - r.lossReported = false + if sh.currentIndex > 0 { + sh.storageCurrentDocument() } + sh.connection.Complete(sh) } -// DataLost is returned by the StreamHandler's Read function when it encounters -// a Reassembly with Skip != 0. -var DataLost = errors.New("lost data") - -// Read implements io.Reader's Read function. -// Given a byte slice, it will either copy a non-zero number of bytes into -// that slice and return the number of bytes and a nil error, or it will -// leave slice p as is and return 0, io.EOF. -func (r *StreamHandler) Read(p []byte) (int, error) { - if !r.initiated { - panic("StreamHandler not created via NewReaderStream") - } - var ok bool - r.stripEmpty() - for !r.closed && len(r.current) == 0 { - if r.first { - r.first = false - } else { - r.done <- true - } - if r.current, ok = <-r.reassembled; ok { - r.stripEmpty() - } else { - r.closed = true - } +func (sh *StreamHandler) resetCurrentDocument() { + sh.buffer.Reset() + sh.indexes = sh.indexes[:0] + sh.timestamps = sh.timestamps[:0] + sh.lossBlocks = sh.lossBlocks[:0] + sh.currentIndex = 0 + + for _, val := range sh.patternMatches { + val = val[:0] } - if len(r.current) > 0 { - current := &r.current[0] - if r.LossErrors && !r.lossReported && current.Skip != 0 { - r.lossReported = true - return 0, DataLost +} + +func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, context interface{}) error { + patternSlices, isPresent := sh.patternMatches[id] + if isPresent { + if len(patternSlices) > 0 { + lastElement := &patternSlices[len(patternSlices)-1] + if lastElement[0] == from { // make the regex greedy to match the maximum number of chars + lastElement[1] = to + return nil + } } - length := copy(p, current.Bytes) - current.Bytes = current.Bytes[length:] - return length, nil + // new from == new match + sh.patternMatches[id] = append(patternSlices, PatternSlice{from, to}) + } else { + patternSlices = make([]PatternSlice, InitialPatternSliceSize) + patternSlices[0] = PatternSlice{from, to} + sh.patternMatches[id] = patternSlices } - return 0, io.EOF + + return nil } -// Close implements io.Closer's Close function, making StreamHandler a -// io.ReadCloser. It discards all remaining bytes in the reassembly in a -// manner that's safe for the assembler (IE: it doesn't block). -func (r *StreamHandler) Close() error { - r.current = nil - r.closed = true - for { - if _, ok := <-r.reassembled; !ok { - return nil - } - r.done <- true +func (sh *StreamHandler) storageCurrentDocument() { + streamKey := sh.generateDocumentKey() + + _, err := sh.connection.Storage().InsertOne(sh.connection.Context(), "connection_streams", OrderedDocument{ + {"_id", streamKey}, + {"connection_id", nil}, + {"document_index", len(sh.documentsKeys)}, + {"payload", sh.buffer.Bytes()}, + {"blocks_indexes", sh.indexes}, + {"blocks_timestamps", sh.timestamps}, + {"blocks_loss", sh.lossBlocks}, + {"pattern_matches", sh.patternMatches}, + }) + + if err != nil { + log.Println("failed to insert connection stream: ", err) } + + sh.documentsKeys = append(sh.documentsKeys, streamKey) } +func (sh *StreamHandler) generateDocumentKey() string { + hash := make([]byte, 16) + endpointsHash := sh.streamKey[0].FastHash() ^ sh.streamKey[1].FastHash() ^ + sh.streamKey[2].FastHash() ^ sh.streamKey[3].FastHash() + binary.BigEndian.PutUint64(hash, endpointsHash) + binary.BigEndian.PutUint64(hash[8:], uint64(sh.timestamps[0].UnixNano())) + binary.BigEndian.PutUint16(hash[8:], uint16(len(sh.documentsKeys))) + return fmt.Sprintf("%x", hash) +} diff --git a/stream_handler_test.go b/stream_handler_test.go new file mode 100644 index 0000000..a1004dc --- /dev/null +++ b/stream_handler_test.go @@ -0,0 +1,193 @@ +package main + +import ( + "context" + "errors" + "github.com/flier/gohs/hyperscan" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/tcpassembly" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math/rand" + "net" + "testing" + "time" +) + +const testSrcIp = "10.10.10.100" +const testDstIp = "10.10.10.1" +const srcPort = 44444 +const dstPort = 8080 + + +func TestReassemblingEmptyStream(t *testing.T) { + patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/nope/", 0)) + require.Nil(t, err) + streamHandler := createTestStreamHandler(t, testStorage{}, patterns) + + streamHandler.Reassembled([]tcpassembly.Reassembly{}) + assert.Len(t, streamHandler.indexes, 0, "indexes") + assert.Len(t, streamHandler.timestamps, 0, "timestamps") + assert.Len(t, streamHandler.lossBlocks, 0) + assert.Zero(t, streamHandler.currentIndex) + assert.Zero(t, streamHandler.firstPacketSeen) + assert.Zero(t, streamHandler.lastPacketSeen) + assert.Len(t, streamHandler.documentsKeys, 0) + assert.Zero(t, streamHandler.streamLength) + assert.Len(t, streamHandler.patternMatches, 0) + + expected := 0 + streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { + expected = 42 + } + streamHandler.ReassemblyComplete() + assert.Equal(t, 42, expected) +} + +func TestReassemblingSingleDocumentStream(t *testing.T) { + patterns, err := hyperscan.NewStreamDatabase(hyperscan.NewPattern("/impossible_to_match/", 0)) + require.Nil(t, err) + storage := &testStorage{} + streamHandler := createTestStreamHandler(t, storage, patterns) + + payloadLen := 256 + firstTime := time.Unix(1000000000, 0) + middleTime := time.Unix(1000000010, 0) + lastTime := time.Unix(1000000020, 0) + data := make([]byte, MaxDocumentSize) + rand.Read(data) + reassembles := make([]tcpassembly.Reassembly, MaxDocumentSize / payloadLen) + indexes := make([]int, MaxDocumentSize / payloadLen) + timestamps := make([]time.Time, MaxDocumentSize / payloadLen) + lossBlocks := make([]bool, MaxDocumentSize / payloadLen) + for i := 0; i < len(reassembles); i++ { + var seen time.Time + if i == 0 { + seen = firstTime + } else if i == len(reassembles)-1 { + seen = lastTime + } else { + seen = middleTime + } + + reassembles[i] = tcpassembly.Reassembly{ + Bytes: data[i*payloadLen:(i+1)*payloadLen], + Skip: 0, + Start: i == 0, + End: i == len(reassembles)-1, + Seen: seen, + } + indexes[i] = i*payloadLen + timestamps[i] = seen + } + + inserted := false + storage.insertFunc = func(ctx context.Context, collectionName string, document interface{}) (i interface{}, err error) { + od := document.(OrderedDocument) + assert.Equal(t, "connection_streams", collectionName) + assert.Equal(t, "bb41a60281cfae830000b6b3a7640000", od[0].Value) + assert.Equal(t, nil, od[1].Value) + assert.Equal(t, 0, od[2].Value) + assert.Equal(t, data, od[3].Value) + assert.Equal(t, indexes, od[4].Value) + assert.Equal(t, timestamps, od[5].Value) + assert.Equal(t, lossBlocks, od[6].Value) + assert.Len(t, od[7].Value, 0) + inserted = true + return nil, nil + } + + streamHandler.Reassembled(reassembles) + if !assert.Equal(t, false, inserted) { + inserted = false + } + + assert.Equal(t, data, streamHandler.buffer.Bytes(), "buffer should contains the same bytes of reassembles") + assert.Equal(t, indexes, streamHandler.indexes, "indexes") + assert.Equal(t, timestamps, streamHandler.timestamps, "timestamps") + assert.Equal(t, lossBlocks, streamHandler.lossBlocks, "lossBlocks") + assert.Equal(t, len(data), streamHandler.currentIndex) + assert.Equal(t, firstTime, streamHandler.firstPacketSeen) + assert.Equal(t, lastTime, streamHandler.lastPacketSeen) + assert.Len(t, streamHandler.documentsKeys, 0) + assert.Equal(t, len(data), streamHandler.streamLength) + assert.Len(t, streamHandler.patternMatches, 0) + + completed := false + streamHandler.connection.(*testConnectionHandler).onComplete = func(handler *StreamHandler) { + completed = true + } + streamHandler.ReassemblyComplete() + assert.Equal(t, true, inserted, "inserted") + assert.Equal(t, true, completed, "completed") +} + + +func createTestStreamHandler(t *testing.T, storage Storage, patterns hyperscan.StreamDatabase) StreamHandler { + testConnectionHandler := &testConnectionHandler{ + storage: storage, + context: context.Background(), + patterns: patterns, + } + + scratch, err := hyperscan.NewScratch(patterns) + require.Nil(t, err) + + srcIp := layers.NewIPEndpoint(net.ParseIP(testSrcIp)) + dstIp := layers.NewIPEndpoint(net.ParseIP(testDstIp)) + srcPort := layers.NewTCPPortEndpoint(srcPort) + dstPort := layers.NewTCPPortEndpoint(dstPort) + + return NewStreamHandler(testConnectionHandler, StreamKey{srcIp, dstIp, srcPort, dstPort}, scratch) +} + +type testConnectionHandler struct { + storage Storage + context context.Context + patterns hyperscan.StreamDatabase + onComplete func(*StreamHandler) +} + +func (tch *testConnectionHandler) Storage() Storage { + return tch.storage +} + +func (tch *testConnectionHandler) Context() context.Context { + return tch.context +} + +func (tch *testConnectionHandler) Patterns() hyperscan.StreamDatabase { + return tch.patterns +} + +func (tch *testConnectionHandler) Complete(handler *StreamHandler) { + tch.onComplete(handler) +} + +type testStorage struct { + insertFunc func(ctx context.Context, collectionName string, document interface{}) (interface{}, error) + updateOne func(ctx context.Context, collectionName string, filter interface{}, update interface {}, upsert bool) (interface{}, error) + findOne func(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) +} + +func (ts testStorage) InsertOne(ctx context.Context, collectionName string, document interface{}) (interface{}, error) { + if ts.insertFunc != nil { + return ts.insertFunc(ctx, collectionName, document) + } + return nil, errors.New("not implemented") +} + +func (ts testStorage) UpdateOne(ctx context.Context, collectionName string, filter interface{}, update interface {}, + upsert bool) (interface{}, error) { + if ts.updateOne != nil { + return ts.updateOne(ctx, collectionName, filter, update, upsert) + } + return nil, errors.New("not implemented") +} + +func (ts testStorage) FindOne(ctx context.Context, collectionName string, filter interface{}) (UnorderedDocument, error) { + if ts.insertFunc != nil { + return ts.findOne(ctx, collectionName, filter) + } + return nil, errors.New("not implemented") +} |