diff options
Diffstat (limited to 'connection_handler.go')
-rw-r--r-- | connection_handler.go | 261 |
1 files changed, 147 insertions, 114 deletions
diff --git a/connection_handler.go b/connection_handler.go index 2e2fa84..4fc2a95 100644 --- a/connection_handler.go +++ b/connection_handler.go @@ -1,95 +1,161 @@ package main import ( - "context" "encoding/binary" - "fmt" "github.com/flier/gohs/hyperscan" "github.com/google/gopacket" "github.com/google/gopacket/tcpassembly" - "log" + log "github.com/sirupsen/logrus" "sync" "time" ) +const initialConnectionsCapacity = 1024 +const initialScannersCapacity = 1024 + type BiDirectionalStreamFactory struct { - storage Storage - serverIp gopacket.Endpoint - connections map[StreamKey]ConnectionHandler - mConnections sync.Mutex - patterns hyperscan.StreamDatabase - mPatterns sync.Mutex - scratches []*hyperscan.Scratch + storage Storage + serverIp gopacket.Endpoint + connections map[StreamFlow]ConnectionHandler + mConnections sync.Mutex + rulesManager *RulesManager + rulesDatabase RulesDatabase + mRulesDatabase sync.Mutex + scanners []Scanner } -type StreamKey [4]gopacket.Endpoint +type StreamFlow [4]gopacket.Endpoint + +type Scanner struct { + scratch *hyperscan.Scratch + version RowID +} type ConnectionHandler interface { Complete(handler *StreamHandler) Storage() Storage - Context() context.Context - Patterns() hyperscan.StreamDatabase + PatternsDatabase() hyperscan.StreamDatabase } type connectionHandlerImpl struct { - storage Storage - net, transport gopacket.Flow - initiator StreamKey - connectionKey string + factory *BiDirectionalStreamFactory + connectionFlow StreamFlow mComplete sync.Mutex otherStream *StreamHandler - context context.Context - patterns hyperscan.StreamDatabase +} + +func NewBiDirectionalStreamFactory(storage Storage, serverIP gopacket.Endpoint, + rulesManager *RulesManager) *BiDirectionalStreamFactory { + + factory := &BiDirectionalStreamFactory{ + storage: storage, + serverIp: serverIP, + connections: make(map[StreamFlow]ConnectionHandler, initialConnectionsCapacity), + mConnections: sync.Mutex{}, + rulesManager: rulesManager, + mRulesDatabase: sync.Mutex{}, + scanners: make([]Scanner, 0, initialScannersCapacity), + } + + go factory.updateRulesDatabaseService() + return factory +} + +func (factory *BiDirectionalStreamFactory) updateRulesDatabaseService() { + for { + select { + case rulesDatabase, ok := <-factory.rulesManager.databaseUpdated: + if !ok { + return + } + factory.mRulesDatabase.Lock() + scanners := factory.scanners + factory.scanners = factory.scanners[:0] + + for _, s := range scanners { + err := s.scratch.Realloc(rulesDatabase.database) + if err != nil { + log.WithError(err).Error("failed to realloc an existing scanner") + } else { + s.version = rulesDatabase.version + factory.scanners = append(factory.scanners, s) + } + } + + factory.rulesDatabase = rulesDatabase + factory.mRulesDatabase.Unlock() + } + } +} + +func (factory *BiDirectionalStreamFactory) takeScanner() Scanner { + factory.mRulesDatabase.Lock() + defer factory.mRulesDatabase.Unlock() + + if len(factory.scanners) == 0 { + scratch, err := hyperscan.NewScratch(factory.rulesDatabase.database) + if err != nil { + log.WithError(err).Fatal("failed to alloc a new scratch") + } + + return Scanner{ + scratch: scratch, + version: factory.rulesDatabase.version, + } + } + + index := len(factory.scanners) - 1 + scanner := factory.scanners[index] + factory.scanners = factory.scanners[:index] + + return scanner +} + +func (factory *BiDirectionalStreamFactory) releaseScanner(scanner Scanner) { + factory.mRulesDatabase.Lock() + defer factory.mRulesDatabase.Unlock() + + if scanner.version != factory.rulesDatabase.version { + err := scanner.scratch.Realloc(factory.rulesDatabase.database) + if err != nil { + log.WithError(err).Error("failed to realloc an existing scanner") + return + } + } + factory.scanners = append(factory.scanners, scanner) } func (factory *BiDirectionalStreamFactory) New(net, transport gopacket.Flow) tcpassembly.Stream { - key := StreamKey{net.Src(), net.Dst(), transport.Src(), transport.Dst()} - invertedKey := StreamKey{net.Dst(), net.Src(), transport.Dst(), transport.Src()} + flow := StreamFlow{net.Src(), net.Dst(), transport.Src(), transport.Dst()} + invertedFlow := StreamFlow{net.Dst(), net.Src(), transport.Dst(), transport.Src()} factory.mConnections.Lock() - connection, isPresent := factory.connections[invertedKey] + connection, isPresent := factory.connections[invertedFlow] if isPresent { - delete(factory.connections, invertedKey) + delete(factory.connections, invertedFlow) } else { - var initiator StreamKey + var connectionFlow StreamFlow if net.Src() == factory.serverIp { - initiator = invertedKey + connectionFlow = invertedFlow } else { - initiator = key + connectionFlow = flow } connection = &connectionHandlerImpl{ - storage: factory.storage, - net: net, - transport: transport, - initiator: initiator, - mComplete: sync.Mutex{}, - context: context.Background(), - patterns: factory.patterns, + connectionFlow: connectionFlow, + mComplete: sync.Mutex{}, + factory: factory, } - factory.connections[key] = connection + factory.connections[flow] = connection } factory.mConnections.Unlock() - streamHandler := NewStreamHandler(connection, key, factory.takeScratch()) + streamHandler := NewStreamHandler(connection, flow, factory.takeScanner()) return &streamHandler } -func (factory *BiDirectionalStreamFactory) UpdatePatternsDatabase(database hyperscan.StreamDatabase) { - factory.mPatterns.Lock() - factory.patterns = database - - for _, s := range factory.scratches { - err := s.Realloc(database) - if err != nil { - fmt.Println("failed to realloc an existing scratch") - } - } - - factory.mPatterns.Unlock() -} - func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { + ch.factory.releaseScanner(handler.scanner) ch.mComplete.Lock() if ch.otherStream == nil { ch.otherStream = handler @@ -112,7 +178,7 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { } var client, server *StreamHandler - if handler.streamKey == ch.initiator { + if handler.streamFlow == ch.connectionFlow { client = handler server = ch.otherStream } else { @@ -120,81 +186,48 @@ func (ch *connectionHandlerImpl) Complete(handler *StreamHandler) { server = handler } - ch.generateConnectionKey(startedAt) - - _, err := ch.storage.Insert("connections").Context(ch.context).One(OrderedDocument{ - {"_id", ch.connectionKey}, - {"ip_src", ch.initiator[0].String()}, - {"ip_dst", ch.initiator[1].String()}, - {"port_src", ch.initiator[2].String()}, - {"port_dst", ch.initiator[3].String()}, - {"started_at", startedAt}, - {"closed_at", closedAt}, - {"client_bytes", client.streamLength}, - {"server_bytes", server.streamLength}, - {"client_documents", len(client.documentsKeys)}, - {"server_documents", len(server.documentsKeys)}, - {"processed_at", time.Now()}, - }) + connectionID := ch.Storage().NewCustomRowID(ch.connectionFlow.Hash(), startedAt) + connection := Connection{ + ID: connectionID, + SourceIP: ch.connectionFlow[0].String(), + DestinationIP: ch.connectionFlow[1].String(), + SourcePort: binary.BigEndian.Uint16(ch.connectionFlow[2].Raw()), + DestinationPort: binary.BigEndian.Uint16(ch.connectionFlow[3].Raw()), + StartedAt: startedAt, + ClosedAt: closedAt, + ClientPackets: client.packetsCount, + ServerPackets: client.packetsCount, + ClientBytes: client.streamLength, + ServerBytes: server.streamLength, + ClientDocuments: len(client.documentsIDs), + ServerDocuments: len(server.documentsIDs), + ProcessedAt: time.Now(), + } + _, err := ch.Storage().Insert(Connections).One(connection) if err != nil { - log.Println("error inserting document on collection connections with _id = ", ch.connectionKey) + log.WithError(err).WithField("connection", connection).Error("failed to insert a connection") + return } - streamsIds := append(client.documentsKeys, server.documentsKeys...) - n, err := ch.storage.Update("connection_streams"). - Context(ch.context). + streamsIds := append(client.documentsIDs, server.documentsIDs...) + n, err := ch.Storage().Update(ConnectionStreams). Filter(OrderedDocument{{"_id", UnorderedDocument{"$in": streamsIds}}}). - Many(UnorderedDocument{"connection_id": ch.connectionKey}) + Many(UnorderedDocument{"connection_id": connectionID}) if err != nil { - log.Println("failed to update connection streams", err) - } - if int(n) != len(streamsIds) { - log.Println("failed to update all connections streams") + log.WithError(err).WithField("connection", connection).Error("failed to update connection streams") + } else if int(n) != len(streamsIds) { + log.WithError(err).WithField("connection", connection).Error("failed to update all connections streams") } } func (ch *connectionHandlerImpl) Storage() Storage { - return ch.storage -} - -func (ch *connectionHandlerImpl) Context() context.Context { - return ch.context -} - -func (ch *connectionHandlerImpl) Patterns() hyperscan.StreamDatabase { - return ch.patterns + return ch.factory.storage } -func (ch *connectionHandlerImpl) generateConnectionKey(firstPacketSeen time.Time) { - hash := make([]byte, 16) - binary.BigEndian.PutUint64(hash, uint64(firstPacketSeen.UnixNano())) - binary.BigEndian.PutUint64(hash[8:], ch.net.FastHash()^ch.transport.FastHash()) - - ch.connectionKey = fmt.Sprintf("%x", hash) -} - -func (factory *BiDirectionalStreamFactory) takeScratch() *hyperscan.Scratch { - factory.mPatterns.Lock() - defer factory.mPatterns.Unlock() - - if len(factory.scratches) == 0 { - scratch, err := hyperscan.NewScratch(factory.patterns) - if err != nil { - fmt.Println("failed to alloc a new scratch") - } - - return scratch - } - - index := len(factory.scratches) - 1 - scratch := factory.scratches[index] - factory.scratches = factory.scratches[:index] - - return scratch +func (ch *connectionHandlerImpl) PatternsDatabase() hyperscan.StreamDatabase { + return ch.factory.rulesDatabase.database } -func (factory *BiDirectionalStreamFactory) releaseScratch(scratch *hyperscan.Scratch) { - factory.mPatterns.Lock() - factory.scratches = append(factory.scratches, scratch) - factory.mPatterns.Unlock() +func (sf StreamFlow) Hash() uint64 { + return sf[0].FastHash() ^ sf[1].FastHash() ^ sf[2].FastHash() ^ sf[3].FastHash() } |