diff options
author | Emiliano Ciavatta | 2020-04-09 10:37:48 +0000 |
---|---|---|
committer | Emiliano Ciavatta | 2020-04-09 10:37:48 +0000 |
commit | 7ca2f30a0eb21e22071f4e6b04a5207fa273d283 (patch) | |
tree | 63acb98147ffda7606bdf81abe2894e5f8363bd9 | |
parent | 0520dab47d61e2c4de246459bf4f5c72d69182d3 (diff) |
Refactor connection_handler
-rw-r--r-- | caronte.go | 6 | ||||
-rw-r--r-- | connection_handler.go | 261 | ||||
-rw-r--r-- | connections.go | 20 | ||||
-rw-r--r-- | rules_manager.go | 12 | ||||
-rw-r--r-- | stream_handler.go | 51 | ||||
-rw-r--r-- | stream_handler_test.go | 13 |
6 files changed, 216 insertions, 147 deletions
@@ -4,7 +4,7 @@ import ( "flag" "fmt" "github.com/gin-gonic/gin" - "log" + log "github.com/sirupsen/logrus" ) func main() { @@ -21,13 +21,13 @@ func main() { storage := NewMongoStorage(*mongoHost, *mongoPort, *dbName) err := storage.Connect(nil) if err != nil { - log.Panicln("failed to connect to MongoDB:", err) + log.WithError(err).Fatal("failed to connect to MongoDB") } router := gin.Default() ApplicationRoutes(router) err = router.Run(fmt.Sprintf("%s:%v", *bindAddress, *bindPort)) if err != nil { - log.Panicln("failed to create the server:", err) + log.WithError(err).Fatal("failed to create the server") } } diff --git a/connection_handler.go b/connection_handler.go index 2e2fa84..4fc2a95 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -1,95 +1,161 @@ package main import ( - "context" "encoding/binary" - "fmt" "github.com/flier/gohs/hyperscan" "github.com/google/gopacket" "github.com/google/gopacket/tcpassembly" - "log" + log "github.com/sirupsen/logrus" "sync" "time" ) +const initialConnectionsCapacity = 1024 +const initialScannersCapacity = 1024 + type BiDirectionalStreamFactory struct { - storage Storage - serverIp gopacket.Endpoint - connections map[StreamKey]ConnectionHandler - mConnections sync.Mutex - patterns hyperscan.StreamDatabase - mPatterns sync.Mutex - scratches []*hyperscan.Scratch + storage Storage + serverIp gopacket.Endpoint + connections map[StreamFlow]ConnectionHandler + mConnections sync.Mutex + rulesManager *RulesManager + rulesDatabase RulesDatabase + mRulesDatabase sync.Mutex + scanners []Scanner } -type StreamKey [4]gopacket.Endpoint +type StreamFlow [4]gopacket.Endpoint + +type Scanner struct { + scratch *hyperscan.Scratch + version RowID +} type ConnectionHandler interface { Complete(handler *StreamHandler) Storage() Storage - Context() context.Context - Patterns() hyperscan.StreamDatabase + PatternsDatabase() hyperscan.StreamDatabase } type connectionHandlerImpl struct { - storage Storage - net, transport gopacket.Flow - initiator StreamKey - connectionKey string + factory *BiDirectionalStreamFactory + connectionFlow StreamFlow mComplete sync.Mutex otherStream *StreamHandler - context context.Context - patterns hyperscan.StreamDatabase +} + +func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint, + rulesManager *RulesManager) *BiDirectionalStreamFactory { + + factory := &BiDirectionalStreamFactory{ + storage: storage, + serverIp: serverIP, + connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity), + mConnections: sync.Mutex{}, + rulesManager: rulesManager, + mRulesDatabase: sync.Mutex{}, + scanners: make([]Scanner, 0, initialScannersCapacity), + } + + go factory.updateRulesDatabaseService() + return factory +} + +func (factory *BiDirectionalStreamFactory) updateRulesDatabaseService() { + for { + select { + case rulesDatabase, ok := <-factory.rulesManager.databaseUpdated: + if !ok { + return + } + factory.mRulesDatabase.Lock() + scanners := factory.scanners + factory.scanners = factory.scanners[:0] + + for _, s := range scanners { + err := s.scratch.Realloc(rulesDatabase.database) + if err != nil { + log.WithError(err).Error("failed to realloc an existing scanner") + } else { + s.version = rulesDatabase.version + factory.scanners = append(factory.scanners, s) + } + } + + factory.rulesDatabase = rulesDatabase + factory.mRulesDatabase.Unlock() + } + } +} + +func (factory *BiDirectionalStreamFactory) takeScanner() Scanner { + factory.mRulesDatabase.Lock() + defer factory.mRulesDatabase.Unlock() + + if len(factory.scanners) == 0 { + scratch, err := hyperscan.NewScratch(factory.rulesDatabase.database) + if err != nil { + log.WithError(err).Fatal("failed to alloc a new scratch") + } + + return Scanner{ + scratch: scratch, + version: factory.rulesDatabase.version, + } + } + + index := len(factory.scanners) - 1 + scanner := factory.scanners[index] + factory.scanners = factory.scanners[:index] + + return scanner +} + +func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) { + factory.mRulesDatabase.Lock() + defer factory.mRulesDatabase.Unlock() + + if scanner.version != factory.rulesDatabase.version { + err := scanner.scratch.Realloc(factory.rulesDatabase.database) + if err != nil { + log.WithError(err).Error("failed to realloc an existing scanner") + return + } + } + factory.scanners = append(factory.scanners, scanner) } func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { - key := StreamKey{net.Src(), net.Dst(), transport.Src(), transport.Dst()} - invertedKey := StreamKey{net.Dst(), net.Src(), transport.Dst(), transport.Src()} + flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()} + invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()} factory.mConnections.Lock() - connection, isPresent := factory.connections[invertedKey] + connection, isPresent := factory.connections[invertedFlow] if isPresent { - delete(factory.connections, invertedKey) + delete(factory.connections, invertedFlow) } else { - var initiator StreamKey + var connectionFlow StreamFlow if net.Src() == factory.serverIp { - initiator = invertedKey + connectionFlow = invertedFlow } else { - initiator = key + connectionFlow = flow } connection = &connectionHandlerImpl{ - storage: factory.storage, - net: net, - transport: transport, - initiator: initiator, - mComplete: sync.Mutex{}, - context: context.Background(), - patterns: factory.patterns, + connectionFlow: connectionFlow, + mComplete: sync.Mutex{}, + factory: factory, } - factory.connections[key] = connection + factory.connections[flow] = connection } factory.mConnections.Unlock() - streamHandler := NewStreamHandler(connection, key, factory.takeScratch()) + streamHandler := NewStreamHandler(connection, flow, factory.takeScanner()) return &streamHandler } -func (factory *BiDirectionalStreamFactory) UpdatePatternsDatabase(database hyperscan.StreamDatabase) { - factory.mPatterns.Lock() - factory.patterns = database - - for _, s := range factory.scratches { - err := s.Realloc(database) - if err != nil { - fmt.Println("failed to realloc an existing scratch") - } - } - - factory.mPatterns.Unlock() -} - func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { + ch.factory.releaseScanner(handler.scanner) ch.mComplete.Lock() if ch.otherStream == nil { ch.otherStream = handler @@ -112,7 +178,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { } var client, server *StreamHandler - if handler.streamKey == ch.initiator { + if handler.streamFlow == ch.connectionFlow { client = handler server = ch.otherStream } else { @@ -120,81 +186,48 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { server = handler } - ch.generateConnectionKey(startedAt) - - _, err := ch.storage.Insert("connections").Context(ch.context).One(OrderedDocument{ - {"_id", ch.connectionKey}, - {"ip_src", ch.initiator[0].String()}, - {"ip_dst", ch.initiator[1].String()}, - {"port_src", ch.initiator[2].String()}, - {"port_dst", ch.initiator[3].String()}, - {"started_at", startedAt}, - {"closed_at", closedAt}, - {"client_bytes", client.streamLength}, - {"server_bytes", server.streamLength}, - {"client_documents", len(client.documentsKeys)}, - {"server_documents", len(server.documentsKeys)}, - {"processed_at", time.Now()}, - }) + connectionID := ch.Storage().NewCustomRowID(ch.connectionFlow.Hash(), startedAt) + connection := Connection{ + ID: connectionID, + SourceIP: ch.connectionFlow[0].String(), + DestinationIP: ch.connectionFlow[1].String(), + SourcePort: binary.BigEndian.Uint16(ch.connectionFlow[2].Raw()), + DestinationPort: binary.BigEndian.Uint16(ch.connectionFlow[3].Raw()), + StartedAt: startedAt, + ClosedAt: closedAt, + ClientPackets: client.packetsCount, + ServerPackets: client.packetsCount, + ClientBytes: client.streamLength, + ServerBytes: server.streamLength, + ClientDocuments: len(client.documentsIDs), + ServerDocuments: len(server.documentsIDs), + ProcessedAt: time.Now(), + } + _, err := ch.Storage().Insert(Connections).One(connection) if err != nil { - log.Println("error inserting document on collection connections with _id = ", ch.connectionKey) + log.WithError(err).WithField("connection", connection).Error("failed to insert a connection") + return } - streamsIds := append(client.documentsKeys, server.documentsKeys...) - n, err := ch.storage.Update("connection_streams"). - Context(ch.context). + streamsIds := append(client.documentsIDs, server.documentsIDs...) + n, err := ch.Storage().Update(ConnectionStreams). Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIds}}}). - Many(UnorderedDocument{"connection_id": ch.connectionKey}) + Many(UnorderedDocument{"connection_id": connectionID}) if err != nil { - log.Println("failed to update connection streams", err) - } - if int(n) != len(streamsIds) { - log.Println("failed to update all connections streams") + log.WithError(err).WithField("connection", connection).Error("failed to update connection streams") + } else if int(n) != len(streamsIds) { + log.WithError(err).WithField("connection", connection).Error("failed to update all connections streams") } } func (ch *connectionHandlerImpl) Storage() Storage { - return ch.storage -} - -func (ch *connectionHandlerImpl) Context() context.Context { - return ch.context -} - -func (ch *connectionHandlerImpl) Patterns() hyperscan.StreamDatabase { - return ch.patterns + return ch.factory.storage } -func (ch *connectionHandlerImpl) generateConnectionKey(firstPacketSeen time.Time) { - hash := make([]byte, 16) - binary.BigEndian.PutUint64(hash, uint64(firstPacketSeen.UnixNano())) - binary.BigEndian.PutUint64(hash[8:], ch.net.FastHash()^ch.transport.FastHash()) - - ch.connectionKey = fmt.Sprintf("%x", hash) -} - -func (factory *BiDirectionalStreamFactory) takeScratch() *hyperscan.Scratch { - factory.mPatterns.Lock() - defer factory.mPatterns.Unlock() - - if len(factory.scratches) == 0 { - scratch, err := hyperscan.NewScratch(factory.patterns) - if err != nil { - fmt.Println("failed to alloc a new scratch") - } - - return scratch - } - - index := len(factory.scratches) - 1 - scratch := factory.scratches[index] - factory.scratches = factory.scratches[:index] - - return scratch +func (ch *connectionHandlerImpl) PatternsDatabase() hyperscan.StreamDatabase { + return ch.factory.rulesDatabase.database } -func (factory *BiDirectionalStreamFactory) releaseScratch(scratch *hyperscan.Scratch) { - factory.mPatterns.Lock() - factory.scratches = append(factory.scratches, scratch) - factory.mPatterns.Unlock() +func (sf StreamFlow) Hash() uint64 { + return sf[0].FastHash() ^ sf[1].FastHash() ^ sf[2].FastHash() ^ sf[3].FastHash() } diff --git a/connections.go b/connections.go new file mode 100644 index 0000000..e3adaf2 --- /dev/null +++ b/connections.go @@ -0,0 +1,20 @@ +package main + +import "time" + +type Connection struct { + ID RowID `json:"id" bson:"_id"` + SourceIP string `json:"ip_src" bson:"ip_src"` + DestinationIP string `json:"ip_dst" bson:"ip_dst"` + SourcePort uint16 `json:"port_src" bson:"port_src"` + DestinationPort uint16 `json:"port_dst" bson:"port_dst"` + StartedAt time.Time `json:"started_at" bson:"started_at"` + ClosedAt time.Time `json:"closed_at" bson:"closed_at"` + ClientPackets int `json:"client_packets" bson:"client_packets"` + ServerPackets int `json:"server_packets" bson:"server_packets"` + ClientBytes int `json:"client_bytes" bson:"client_bytes"` + ServerBytes int `json:"server_bytes" bson:"server_bytes"` + ClientDocuments int `json:"client_documents" bson:"client_documents"` + ServerDocuments int `json:"server_documents" bson:"server_documents"` + ProcessedAt time.Time `json:"processed_at" bson:"processed_at"` +} diff --git a/rules_manager.go b/rules_manager.go index e5a8d38..69439e0 100644 --- a/rules_manager.go +++ b/rules_manager.go @@ -52,6 +52,11 @@ type Rule struct { Version int64 `json:"version" bson:"version"` } +type RulesDatabase struct { + database hyperscan.StreamDatabase + version RowID +} + type RulesManager struct { storage Storage rules map[string]Rule @@ -59,7 +64,7 @@ type RulesManager struct { ruleIndex int patterns map[string]Pattern mPatterns sync.Mutex - databaseUpdated chan interface{} + databaseUpdated chan RulesDatabase } func NewRulesManager(storage Storage) RulesManager { @@ -153,7 +158,10 @@ func (rm RulesManager) generateDatabase(version RowID) error { return err } - rm.databaseUpdated <- database + rm.databaseUpdated <- RulesDatabase{ + database: database, + version: version, + } return nil } diff --git a/stream_handler.go b/stream_handler.go index 3fafa21..2d80f60 100644 --- a/stream_handler.go +++ b/stream_handler.go @@ -18,7 +18,7 @@ const InitialPatternSliceSize = 8 // method: type StreamHandler struct { connection ConnectionHandler - streamKey StreamKey + streamFlow StreamFlow buffer *bytes.Buffer indexes []int timestamps []time.Time @@ -26,28 +26,31 @@ type StreamHandler struct { currentIndex int firstPacketSeen time.Time lastPacketSeen time.Time - documentsKeys []RowID + documentsIDs []RowID streamLength int + packetsCount int patternStream hyperscan.Stream patternMatches map[uint][]PatternSlice + scanner Scanner } // NewReaderStream returns a new StreamHandler object. -func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hyperscan.Scratch) StreamHandler { +func NewStreamHandler(connection ConnectionHandler, streamFlow StreamFlow, scanner Scanner) StreamHandler { handler := StreamHandler{ connection: connection, - streamKey: key, + streamFlow: streamFlow, buffer: new(bytes.Buffer), indexes: make([]int, 0, InitialBlockCount), timestamps: make([]time.Time, 0, InitialBlockCount), lossBlocks: make([]bool, 0, InitialBlockCount), - documentsKeys: make([]RowID, 0, 1), // most of the time the stream fit in one document + documentsIDs: make([]RowID, 0, 1), // most of the time the stream fit in one document patternMatches: make(map[uint][]PatternSlice, 10), // TODO: change with exactly value + scanner: scanner, } - stream, err := connection.Patterns().Open(0, scratch, handler.onMatch, nil) + stream, err := connection.PatternsDatabase().Open(0, scanner.scratch, handler.onMatch, nil) if err != nil { - log.WithField("streamKey", key).WithError(err).Error("failed to create a stream") + log.WithField("streamFlow", streamFlow).WithError(err).Error("failed to create a stream") } handler.patternStream = stream @@ -57,6 +60,8 @@ func NewStreamHandler(connection ConnectionHandler, key StreamKey, scratch *hype // Reassembled implements tcpassembly.Stream's Reassembled function. func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) { for _, r := range reassembly { + sh.packetsCount++ + skip := r.Skip if r.Start { skip = 0 @@ -78,7 +83,7 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) { n, err := sh.buffer.Write(r.Bytes[skip:]) if err != nil { log.WithError(err).Error("failed to copy bytes from a Reassemble") - return + continue } sh.indexes = append(sh.indexes, sh.currentIndex) sh.timestamps = append(sh.timestamps, r.Seen) @@ -86,18 +91,22 @@ func (sh *StreamHandler) Reassembled(reassembly []tcpassembly.Reassembly) { sh.currentIndex += n sh.streamLength += n - err = sh.patternStream.Scan(r.Bytes) - if err != nil { - log.WithError(err).Error("failed to scan packet buffer") + if sh.patternStream != nil { + err = sh.patternStream.Scan(r.Bytes) + if err != nil { + log.WithError(err).Error("failed to scan packet buffer") + } } } } // ReassemblyComplete implements tcpassembly.Stream's ReassemblyComplete function. func (sh *StreamHandler) ReassemblyComplete() { - err := sh.patternStream.Close() - if err != nil { - log.WithError(err).Error("failed to close pattern stream") + if sh.patternStream != nil { + err := sh.patternStream.Close() + if err != nil { + log.WithError(err).Error("failed to close pattern stream") + } } if sh.currentIndex > 0 { @@ -140,16 +149,14 @@ func (sh *StreamHandler) onMatch(id uint, from uint64, to uint64, flags uint, co } func (sh *StreamHandler) storageCurrentDocument() { - payload := (sh.streamKey[0].FastHash()^sh.streamKey[1].FastHash()^sh.streamKey[2].FastHash()^ - sh.streamKey[3].FastHash())&uint64(0xffffffffffffff00) | uint64(len(sh.documentsKeys)) // LOL - streamKey := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen) + payload := sh.streamFlow.Hash()&uint64(0xffffffffffffff00) | uint64(len(sh.documentsIDs)) // LOL + streamID := sh.connection.Storage().NewCustomRowID(payload, sh.firstPacketSeen) _, err := sh.connection.Storage().Insert(ConnectionStreams). - Context(sh.connection.Context()). One(ConnectionStream{ - ID: streamKey, + ID: streamID, ConnectionID: ZeroRowID, - DocumentIndex: len(sh.documentsKeys), + DocumentIndex: len(sh.documentsIDs), Payload: sh.buffer.Bytes(), BlocksIndexes: sh.indexes, BlocksTimestamps: sh.timestamps, @@ -159,7 +166,7 @@ func (sh *StreamHandler) storageCurrentDocument() { if err != nil { log.WithError(err).Error("failed to insert connection stream") + } else { + sh.documentsIDs = append(sh.documentsIDs, streamID) } - - sh.documentsKeys = append(sh.documentsKeys, streamKey) } diff --git a/stream_handler_test.go b/stream_handler_test.go index ece3190..962048f 100644 --- a/stream_handler_test.go +++ b/stream_handler_test.go @@ -39,7 +39,7 @@ func TestReassemblingEmptyStream(t *testing.T) { assert.Zero(t, streamHandler.currentIndex) assert.Zero(t, streamHandler.firstPacketSeen) assert.Zero(t, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsKeys, 0) + assert.Len(t, streamHandler.documentsIDs, 0) assert.Zero(t, streamHandler.streamLength) assert.Len(t, streamHandler.patternMatches, 0) @@ -125,7 +125,7 @@ func TestReassemblingSingleDocument(t *testing.T) { assert.Equal(t, len(data), streamHandler.currentIndex) assert.Equal(t, firstTime, streamHandler.firstPacketSeen) assert.Equal(t, lastTime, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsKeys, 1) + assert.Len(t, streamHandler.documentsIDs, 1) assert.Equal(t, len(data), streamHandler.streamLength) assert.Len(t, streamHandler.patternMatches, 0) @@ -210,7 +210,7 @@ func TestReassemblingMultipleDocuments(t *testing.T) { assert.Equal(t, MaxDocumentSize, streamHandler.currentIndex) assert.Equal(t, firstTime, streamHandler.firstPacketSeen) assert.Equal(t, lastTime, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsKeys, 2) + assert.Len(t, streamHandler.documentsIDs, 2) assert.Equal(t, len(data), streamHandler.streamLength) assert.Len(t, streamHandler.patternMatches, 0) @@ -287,7 +287,7 @@ func TestReassemblingPatternMatching(t *testing.T) { assert.Equal(t, len(payload), streamHandler.currentIndex) assert.Equal(t, seen, streamHandler.firstPacketSeen) assert.Equal(t, seen, streamHandler.lastPacketSeen) - assert.Len(t, streamHandler.documentsKeys, 1) + assert.Len(t, streamHandler.documentsIDs, 1) assert.Equal(t, len(payload), streamHandler.streamLength) assert.Equal(t, true, completed, "completed") @@ -310,7 +310,8 @@ func createTestStreamHandler(wrapper *TestStorageWrapper, patterns hyperscan.Str srcPort := layers.NewTCPPortEndpoint(srcPort) dstPort := layers.NewTCPPortEndpoint(dstPort) - return NewStreamHandler(testConnectionHandler, StreamKey{srcIp, dstIp, srcPort, dstPort}, scratch) + scanner := Scanner{scratch: scratch, version: ZeroRowID} + return NewStreamHandler(testConnectionHandler, StreamFlow{srcIp, dstIp, srcPort, dstPort}, scanner) } type testConnectionHandler struct { @@ -327,7 +328,7 @@ func (tch *testConnectionHandler) Context() context.Context { return tch.wrapper.Context } -func (tch *testConnectionHandler) Patterns() hyperscan.StreamDatabase { +func (tch *testConnectionHandler) PatternsDatabase() hyperscan.StreamDatabase { return tch.patterns } |